feat: reimplement old websocket
This commit is contained in:
parent
a602c2624f
commit
2fb7e7781f
6 changed files with 91 additions and 70 deletions
|
@ -38,7 +38,7 @@ bindet = "0.3.2"
|
||||||
bunny-api-tokio = { version = "0.4", features = ["edge_storage"], default-features = false }
|
bunny-api-tokio = { version = "0.4", features = ["edge_storage"], default-features = false }
|
||||||
|
|
||||||
# Web Server
|
# Web Server
|
||||||
axum = { version = "0.8.4", features = ["macros", "multipart"] }
|
axum = { version = "0.8.4", features = ["macros", "multipart", "ws"] }
|
||||||
tower-http = { version = "0.6.6", features = ["cors"] }
|
tower-http = { version = "0.6.6", features = ["cors"] }
|
||||||
axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"] }
|
axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"] }
|
||||||
socketioxide = { version = "0.17.2", features = ["state"] }
|
socketioxide = { version = "0.17.2", features = ["state"] }
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
Router,
|
Router,
|
||||||
routing::{delete, get, patch},
|
routing::{any, delete, get, patch},
|
||||||
};
|
};
|
||||||
//use socketioxide::SocketIo;
|
//use socketioxide::SocketIo;
|
||||||
|
|
||||||
|
@ -15,5 +15,6 @@ pub fn router() -> Router<Arc<AppState>> {
|
||||||
.route("/{uuid}", get(uuid::get))
|
.route("/{uuid}", get(uuid::get))
|
||||||
.route("/{uuid}", delete(uuid::delete))
|
.route("/{uuid}", delete(uuid::delete))
|
||||||
.route("/{uuid}", patch(uuid::patch))
|
.route("/{uuid}", patch(uuid::patch))
|
||||||
|
.route("/{uuid}/socket", any(uuid::socket::ws))
|
||||||
.route("/{uuid}/messages", get(uuid::messages::get))
|
.route("/{uuid}/messages", get(uuid::messages::get))
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
//! `/api/v1/channels/{uuid}` Channel specific endpoints
|
//! `/api/v1/channels/{uuid}` Channel specific endpoints
|
||||||
|
|
||||||
pub mod messages;
|
pub mod messages;
|
||||||
//pub mod socket;
|
pub mod socket;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
|
|
@ -1,18 +1,21 @@
|
||||||
use actix_web::{
|
use std::sync::Arc;
|
||||||
Error, HttpRequest, HttpResponse, get,
|
|
||||||
http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL},
|
use axum::{
|
||||||
rt, web,
|
extract::{Path, State, WebSocketUpgrade, ws::Message},
|
||||||
|
http::HeaderMap,
|
||||||
|
response::IntoResponse,
|
||||||
};
|
};
|
||||||
use actix_ws::AggregatedMessage;
|
use futures::SinkExt;
|
||||||
use futures_util::StreamExt as _;
|
use futures_util::StreamExt as _;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
Data,
|
AppState,
|
||||||
api::v1::auth::check_access_token,
|
api::v1::auth::check_access_token,
|
||||||
|
error::Error,
|
||||||
objects::{Channel, Member},
|
objects::{Channel, Member},
|
||||||
utils::{get_ws_protocol_header, global_checks},
|
utils::global_checks,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
@ -21,100 +24,114 @@ struct MessageBody {
|
||||||
reply_to: Option<Uuid>,
|
reply_to: Option<Uuid>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/{uuid}/socket")]
|
|
||||||
pub async fn ws(
|
pub async fn ws(
|
||||||
req: HttpRequest,
|
ws: WebSocketUpgrade,
|
||||||
path: web::Path<(Uuid,)>,
|
State(app_state): State<Arc<AppState>>,
|
||||||
stream: web::Payload,
|
Path(channel_uuid): Path<Uuid>,
|
||||||
data: web::Data<Data>,
|
headers: HeaderMap,
|
||||||
) -> Result<HttpResponse, Error> {
|
) -> Result<impl IntoResponse, Error> {
|
||||||
// Get all headers
|
|
||||||
let headers = req.headers();
|
|
||||||
|
|
||||||
// Retrieve auth header
|
// Retrieve auth header
|
||||||
let auth_header = get_ws_protocol_header(headers)?;
|
let auth_token = headers.get(axum::http::header::SEC_WEBSOCKET_PROTOCOL);
|
||||||
|
|
||||||
// Get uuid from path
|
if auth_token.is_none() {
|
||||||
let channel_uuid = path.into_inner().0;
|
return Err(Error::Unauthorized(
|
||||||
|
"No authorization header provided".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let mut conn = data.pool.get().await.map_err(crate::error::Error::from)?;
|
let auth_raw = auth_token.unwrap().to_str()?;
|
||||||
|
|
||||||
|
let mut auth = auth_raw.split_whitespace();
|
||||||
|
|
||||||
|
let response_proto = auth.next();
|
||||||
|
|
||||||
|
let auth_value = auth.next();
|
||||||
|
|
||||||
|
if response_proto.is_none() {
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
"Sec-WebSocket-Protocol header is empty".to_string(),
|
||||||
|
));
|
||||||
|
} else if response_proto.is_some_and(|rp| rp != "Authorization,") {
|
||||||
|
return Err(Error::BadRequest(
|
||||||
|
"First protocol should be Authorization".to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth_value.is_none() {
|
||||||
|
return Err(Error::BadRequest("No token provided".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let auth_header = auth_value.unwrap();
|
||||||
|
|
||||||
|
let mut conn = app_state
|
||||||
|
.pool
|
||||||
|
.get()
|
||||||
|
.await
|
||||||
|
.map_err(crate::error::Error::from)?;
|
||||||
|
|
||||||
// Authorize client using auth header
|
// Authorize client using auth header
|
||||||
let uuid = check_access_token(auth_header, &mut conn).await?;
|
let uuid = check_access_token(auth_header, &mut conn).await?;
|
||||||
|
|
||||||
global_checks(&data, uuid).await?;
|
global_checks(&app_state, uuid).await?;
|
||||||
|
|
||||||
let channel = Channel::fetch_one(&data, channel_uuid).await?;
|
let channel = Channel::fetch_one(&app_state, channel_uuid).await?;
|
||||||
|
|
||||||
Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?;
|
Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?;
|
||||||
|
|
||||||
let (mut res, mut session_1, stream) = actix_ws::handle(&req, stream)?;
|
let mut pubsub = app_state
|
||||||
|
|
||||||
let mut stream = stream
|
|
||||||
.aggregate_continuations()
|
|
||||||
// aggregate continuation frames up to 1MiB
|
|
||||||
.max_continuation_size(2_usize.pow(20));
|
|
||||||
|
|
||||||
let mut pubsub = data
|
|
||||||
.cache_pool
|
.cache_pool
|
||||||
.get_async_pubsub()
|
.get_async_pubsub()
|
||||||
.await
|
.await
|
||||||
.map_err(crate::error::Error::from)?;
|
.map_err(crate::error::Error::from)?;
|
||||||
|
|
||||||
let mut session_2 = session_1.clone();
|
let mut res = ws.on_upgrade(async move |socket| {
|
||||||
|
let (mut sender, mut receiver) = socket.split();
|
||||||
|
|
||||||
rt::spawn(async move {
|
tokio::spawn(async move {
|
||||||
pubsub.subscribe(channel_uuid.to_string()).await?;
|
pubsub.subscribe(channel_uuid.to_string()).await?;
|
||||||
while let Some(msg) = pubsub.on_message().next().await {
|
while let Some(msg) = pubsub.on_message().next().await {
|
||||||
let payload: String = msg.get_payload()?;
|
let payload: String = msg.get_payload()?;
|
||||||
session_1.text(payload).await?;
|
sender.send(payload.into()).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok::<(), crate::error::Error>(())
|
Ok::<(), crate::error::Error>(())
|
||||||
});
|
});
|
||||||
|
|
||||||
// start task but don't wait for it
|
tokio::spawn(async move {
|
||||||
rt::spawn(async move {
|
while let Some(msg) = receiver.next().await {
|
||||||
// receive messages from websocket
|
if let Ok(Message::Text(text)) = msg {
|
||||||
while let Some(msg) = stream.next().await {
|
|
||||||
match msg {
|
|
||||||
Ok(AggregatedMessage::Text(text)) => {
|
|
||||||
let mut conn = data.cache_pool.get_multiplexed_tokio_connection().await?;
|
|
||||||
|
|
||||||
let message_body: MessageBody = serde_json::from_str(&text)?;
|
let message_body: MessageBody = serde_json::from_str(&text)?;
|
||||||
|
|
||||||
let message = channel
|
let message = channel
|
||||||
.new_message(&data, uuid, message_body.message, message_body.reply_to)
|
.new_message(
|
||||||
|
&app_state,
|
||||||
|
uuid,
|
||||||
|
message_body.message,
|
||||||
|
message_body.reply_to,
|
||||||
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
redis::cmd("PUBLISH")
|
redis::cmd("PUBLISH")
|
||||||
.arg(&[channel_uuid.to_string(), serde_json::to_string(&message)?])
|
.arg(&[channel_uuid.to_string(), serde_json::to_string(&message)?])
|
||||||
.exec_async(&mut conn)
|
.exec_async(
|
||||||
|
&mut app_state
|
||||||
|
.cache_pool
|
||||||
|
.get_multiplexed_tokio_connection()
|
||||||
|
.await?,
|
||||||
|
)
|
||||||
.await?;
|
.await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(AggregatedMessage::Binary(bin)) => {
|
|
||||||
// echo binary message
|
|
||||||
session_2.binary(bin).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(AggregatedMessage::Ping(msg)) => {
|
|
||||||
// respond to PING frame with PONG frame
|
|
||||||
session_2.pong(&msg).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok::<(), crate::error::Error>(())
|
Ok::<(), crate::error::Error>(())
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
let headers = res.headers_mut();
|
let headers = res.headers_mut();
|
||||||
|
|
||||||
headers.append(
|
headers.append(
|
||||||
SEC_WEBSOCKET_PROTOCOL,
|
axum::http::header::SEC_WEBSOCKET_PROTOCOL,
|
||||||
HeaderValue::from_str("Authorization")?,
|
"Authorization".parse()?,
|
||||||
);
|
);
|
||||||
|
|
||||||
// respond immediately with response connected to WS session
|
// respond immediately with response connected to WS session
|
||||||
|
|
|
@ -17,7 +17,6 @@ mod users;
|
||||||
pub fn router(app_state: Arc<AppState>) -> Router<Arc<AppState>> {
|
pub fn router(app_state: Arc<AppState>) -> Router<Arc<AppState>> {
|
||||||
let router_with_auth = Router::new()
|
let router_with_auth = Router::new()
|
||||||
.nest("/users", users::router())
|
.nest("/users", users::router())
|
||||||
.nest("/channels", channels::router())
|
|
||||||
.nest("/guilds", guilds::router())
|
.nest("/guilds", guilds::router())
|
||||||
.nest("/invites", invites::router())
|
.nest("/invites", invites::router())
|
||||||
.nest("/me", me::router())
|
.nest("/me", me::router())
|
||||||
|
@ -28,6 +27,7 @@ pub fn router(app_state: Arc<AppState>) -> Router<Arc<AppState>> {
|
||||||
|
|
||||||
Router::new()
|
Router::new()
|
||||||
.route("/stats", get(stats::res))
|
.route("/stats", get(stats::res))
|
||||||
.nest("/auth", auth::router(app_state))
|
.nest("/auth", auth::router(app_state.clone()))
|
||||||
|
.nest("/channels", channels::router(app_state))
|
||||||
.merge(router_with_auth)
|
.merge(router_with_auth)
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,6 +83,9 @@ pub enum Error {
|
||||||
TooManyRequests(String),
|
TooManyRequests(String),
|
||||||
#[error("{0}")]
|
#[error("{0}")]
|
||||||
InternalServerError(String),
|
InternalServerError(String),
|
||||||
|
// TODO: remove when doing socket.io
|
||||||
|
#[error(transparent)]
|
||||||
|
AxumError(#[from] axum::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for Error {
|
impl IntoResponse for Error {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue