From 2fb7e7781f42c419c0d9f31de04b5a784e4c11f7 Mon Sep 17 00:00:00 2001 From: Radical Date: Sun, 20 Jul 2025 18:11:08 +0200 Subject: [PATCH] feat: reimplement old websocket --- Cargo.toml | 2 +- src/api/v1/channels/mod.rs | 3 +- src/api/v1/channels/uuid/mod.rs | 2 +- src/api/v1/channels/uuid/socket.rs | 147 ++++++++++++++++------------- src/api/v1/mod.rs | 4 +- src/error.rs | 3 + 6 files changed, 91 insertions(+), 70 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2b9962c..e0c83bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ bindet = "0.3.2" bunny-api-tokio = { version = "0.4", features = ["edge_storage"], default-features = false } # Web Server -axum = { version = "0.8.4", features = ["macros", "multipart"] } +axum = { version = "0.8.4", features = ["macros", "multipart", "ws"] } tower-http = { version = "0.6.6", features = ["cors"] } axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"] } socketioxide = { version = "0.17.2", features = ["state"] } diff --git a/src/api/v1/channels/mod.rs b/src/api/v1/channels/mod.rs index 24b62f7..cc033af 100644 --- a/src/api/v1/channels/mod.rs +++ b/src/api/v1/channels/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use axum::{ Router, - routing::{delete, get, patch}, + routing::{any, delete, get, patch}, }; //use socketioxide::SocketIo; @@ -15,5 +15,6 @@ pub fn router() -> Router> { .route("/{uuid}", get(uuid::get)) .route("/{uuid}", delete(uuid::delete)) .route("/{uuid}", patch(uuid::patch)) + .route("/{uuid}/socket", any(uuid::socket::ws)) .route("/{uuid}/messages", get(uuid::messages::get)) } diff --git a/src/api/v1/channels/uuid/mod.rs b/src/api/v1/channels/uuid/mod.rs index 5c88a29..373742e 100644 --- a/src/api/v1/channels/uuid/mod.rs +++ b/src/api/v1/channels/uuid/mod.rs @@ -1,7 +1,7 @@ //! `/api/v1/channels/{uuid}` Channel specific endpoints pub mod messages; -//pub mod socket; +pub mod socket; use std::sync::Arc; diff --git a/src/api/v1/channels/uuid/socket.rs b/src/api/v1/channels/uuid/socket.rs index 7233f39..46a7334 100644 --- a/src/api/v1/channels/uuid/socket.rs +++ b/src/api/v1/channels/uuid/socket.rs @@ -1,18 +1,21 @@ -use actix_web::{ - Error, HttpRequest, HttpResponse, get, - http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}, - rt, web, +use std::sync::Arc; + +use axum::{ + extract::{Path, State, WebSocketUpgrade, ws::Message}, + http::HeaderMap, + response::IntoResponse, }; -use actix_ws::AggregatedMessage; +use futures::SinkExt; use futures_util::StreamExt as _; use serde::Deserialize; use uuid::Uuid; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, + error::Error, objects::{Channel, Member}, - utils::{get_ws_protocol_header, global_checks}, + utils::global_checks, }; #[derive(Deserialize)] @@ -21,100 +24,114 @@ struct MessageBody { reply_to: Option, } -#[get("/{uuid}/socket")] pub async fn ws( - req: HttpRequest, - path: web::Path<(Uuid,)>, - stream: web::Payload, - data: web::Data, -) -> Result { - // Get all headers - let headers = req.headers(); - + ws: WebSocketUpgrade, + State(app_state): State>, + Path(channel_uuid): Path, + headers: HeaderMap, +) -> Result { // Retrieve auth header - let auth_header = get_ws_protocol_header(headers)?; + let auth_token = headers.get(axum::http::header::SEC_WEBSOCKET_PROTOCOL); - // Get uuid from path - let channel_uuid = path.into_inner().0; + if auth_token.is_none() { + return Err(Error::Unauthorized( + "No authorization header provided".to_string(), + )); + } - let mut conn = data.pool.get().await.map_err(crate::error::Error::from)?; + let auth_raw = auth_token.unwrap().to_str()?; + + let mut auth = auth_raw.split_whitespace(); + + let response_proto = auth.next(); + + let auth_value = auth.next(); + + if response_proto.is_none() { + return Err(Error::BadRequest( + "Sec-WebSocket-Protocol header is empty".to_string(), + )); + } else if response_proto.is_some_and(|rp| rp != "Authorization,") { + return Err(Error::BadRequest( + "First protocol should be Authorization".to_string(), + )); + } + + if auth_value.is_none() { + return Err(Error::BadRequest("No token provided".to_string())); + } + + let auth_header = auth_value.unwrap(); + + let mut conn = app_state + .pool + .get() + .await + .map_err(crate::error::Error::from)?; // Authorize client using auth header let uuid = check_access_token(auth_header, &mut conn).await?; - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; - let channel = Channel::fetch_one(&data, channel_uuid).await?; + let channel = Channel::fetch_one(&app_state, channel_uuid).await?; Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; - let (mut res, mut session_1, stream) = actix_ws::handle(&req, stream)?; - - let mut stream = stream - .aggregate_continuations() - // aggregate continuation frames up to 1MiB - .max_continuation_size(2_usize.pow(20)); - - let mut pubsub = data + let mut pubsub = app_state .cache_pool .get_async_pubsub() .await .map_err(crate::error::Error::from)?; - let mut session_2 = session_1.clone(); + let mut res = ws.on_upgrade(async move |socket| { + let (mut sender, mut receiver) = socket.split(); - rt::spawn(async move { - pubsub.subscribe(channel_uuid.to_string()).await?; - while let Some(msg) = pubsub.on_message().next().await { - let payload: String = msg.get_payload()?; - session_1.text(payload).await?; - } + tokio::spawn(async move { + pubsub.subscribe(channel_uuid.to_string()).await?; + while let Some(msg) = pubsub.on_message().next().await { + let payload: String = msg.get_payload()?; + sender.send(payload.into()).await?; + } - Ok::<(), crate::error::Error>(()) - }); - - // start task but don't wait for it - rt::spawn(async move { - // receive messages from websocket - while let Some(msg) = stream.next().await { - match msg { - Ok(AggregatedMessage::Text(text)) => { - let mut conn = data.cache_pool.get_multiplexed_tokio_connection().await?; + Ok::<(), crate::error::Error>(()) + }); + tokio::spawn(async move { + while let Some(msg) = receiver.next().await { + if let Ok(Message::Text(text)) = msg { let message_body: MessageBody = serde_json::from_str(&text)?; let message = channel - .new_message(&data, uuid, message_body.message, message_body.reply_to) + .new_message( + &app_state, + uuid, + message_body.message, + message_body.reply_to, + ) .await?; redis::cmd("PUBLISH") .arg(&[channel_uuid.to_string(), serde_json::to_string(&message)?]) - .exec_async(&mut conn) + .exec_async( + &mut app_state + .cache_pool + .get_multiplexed_tokio_connection() + .await?, + ) .await?; } - - Ok(AggregatedMessage::Binary(bin)) => { - // echo binary message - session_2.binary(bin).await?; - } - - Ok(AggregatedMessage::Ping(msg)) => { - // respond to PING frame with PONG frame - session_2.pong(&msg).await?; - } - - _ => {} } - } - Ok::<(), crate::error::Error>(()) + Ok::<(), crate::error::Error>(()) + }); }); let headers = res.headers_mut(); headers.append( - SEC_WEBSOCKET_PROTOCOL, - HeaderValue::from_str("Authorization")?, + axum::http::header::SEC_WEBSOCKET_PROTOCOL, + "Authorization".parse()?, ); // respond immediately with response connected to WS session diff --git a/src/api/v1/mod.rs b/src/api/v1/mod.rs index 5ca9558..860944c 100644 --- a/src/api/v1/mod.rs +++ b/src/api/v1/mod.rs @@ -17,7 +17,6 @@ mod users; pub fn router(app_state: Arc) -> Router> { let router_with_auth = Router::new() .nest("/users", users::router()) - .nest("/channels", channels::router()) .nest("/guilds", guilds::router()) .nest("/invites", invites::router()) .nest("/me", me::router()) @@ -28,6 +27,7 @@ pub fn router(app_state: Arc) -> Router> { Router::new() .route("/stats", get(stats::res)) - .nest("/auth", auth::router(app_state)) + .nest("/auth", auth::router(app_state.clone())) + .nest("/channels", channels::router(app_state)) .merge(router_with_auth) } diff --git a/src/error.rs b/src/error.rs index 1b8f27c..d6f7a12 100644 --- a/src/error.rs +++ b/src/error.rs @@ -83,6 +83,9 @@ pub enum Error { TooManyRequests(String), #[error("{0}")] InternalServerError(String), + // TODO: remove when doing socket.io + #[error(transparent)] + AxumError(#[from] axum::Error), } impl IntoResponse for Error {