From efa0cd555f41858dbd35470860754cedad517625 Mon Sep 17 00:00:00 2001 From: Radical Date: Mon, 26 May 2025 19:41:32 +0200 Subject: [PATCH] fix: hack around websocket spec to make tokens work --- .../v1/servers/uuid/channels/uuid/socket.rs | 24 ++++++++----- src/utils.rs | 34 +++++++++++++++++++ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/src/api/v1/servers/uuid/channels/uuid/socket.rs b/src/api/v1/servers/uuid/channels/uuid/socket.rs index 3300e6c..c16efa7 100644 --- a/src/api/v1/servers/uuid/channels/uuid/socket.rs +++ b/src/api/v1/servers/uuid/channels/uuid/socket.rs @@ -1,4 +1,8 @@ -use actix_web::{Error, HttpRequest, HttpResponse, get, rt, web}; +use actix_web::{ + Error, HttpRequest, HttpResponse, get, + http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}, + rt, web, +}; use actix_ws::AggregatedMessage; use futures_util::StreamExt as _; use uuid::Uuid; @@ -7,7 +11,7 @@ use crate::{ Data, api::v1::auth::check_access_token, structs::{Channel, Member}, - utils::get_auth_header, + utils::get_ws_protocol_header, }; #[get("{uuid}/channels/{channel_uuid}/socket")] @@ -21,7 +25,7 @@ pub async fn echo( let headers = req.headers(); // Retrieve auth header - let auth_header = get_auth_header(headers)?; + let auth_header = get_ws_protocol_header(headers)?; // Get uuids from path let (guild_uuid, channel_uuid) = path.into_inner(); @@ -46,7 +50,7 @@ pub async fn echo( .await?; } - let (res, mut session_1, stream) = actix_ws::handle(&req, stream)?; + let (mut res, mut session_1, stream) = actix_ws::handle(&req, stream)?; let mut stream = stream .aggregate_continuations() @@ -77,10 +81,7 @@ pub async fn echo( while let Some(msg) = stream.next().await { match msg { Ok(AggregatedMessage::Text(text)) => { - let mut conn = data - .cache_pool - .get_multiplexed_tokio_connection() - .await?; + let mut conn = data.cache_pool.get_multiplexed_tokio_connection().await?; redis::cmd("PUBLISH") .arg(&[channel_uuid.to_string(), text.to_string()]) @@ -109,6 +110,13 @@ pub async fn echo( Ok::<(), crate::error::Error>(()) }); + let headers = res.headers_mut(); + + headers.append( + SEC_WEBSOCKET_PROTOCOL, + HeaderValue::from_str("Authorization")?, + ); + // respond immediately with response connected to WS session Ok(res) } diff --git a/src/utils.rs b/src/utils.rs index f9a5705..631b003 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -45,6 +45,40 @@ pub fn get_auth_header(headers: &HeaderMap) -> Result<&str, Error> { Ok(auth_value.unwrap()) } +pub fn get_ws_protocol_header(headers: &HeaderMap) -> Result<&str, Error> { + let auth_token = headers.get(actix_web::http::header::SEC_WEBSOCKET_PROTOCOL); + + if auth_token.is_none() { + return Err(Error::Unauthorized( + "No authorization header provided".to_string(), + )); + } + + let auth_raw = auth_token.unwrap().to_str()?; + + let mut auth = auth_raw.split_whitespace(); + + let response_proto = auth.next(); + + let auth_value = auth.next(); + + if response_proto.is_none() { + return Err(Error::BadRequest( + "Sec-WebSocket-Protocol header is empty".to_string(), + )); + } else if response_proto.is_some_and(|rp| rp != "Authorization,") { + return Err(Error::BadRequest( + "First protocol should be Authorization".to_string(), + )); + } + + if auth_value.is_none() { + return Err(Error::BadRequest("No token provided".to_string())); + } + + Ok(auth_value.unwrap()) +} + pub fn refresh_token_cookie(refresh_token: String) -> Cookie<'static> { Cookie::build("refresh_token", refresh_token) .http_only(true)