1
0
Fork 0
forked from gorb/backend

fix: hack around websocket spec to make tokens work

This commit is contained in:
Radical 2025-05-26 19:41:32 +02:00
parent 5d26f94cdd
commit efa0cd555f
2 changed files with 50 additions and 8 deletions

View file

@ -1,4 +1,8 @@
use actix_web::{Error, HttpRequest, HttpResponse, get, rt, web}; use actix_web::{
Error, HttpRequest, HttpResponse, get,
http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL},
rt, web,
};
use actix_ws::AggregatedMessage; use actix_ws::AggregatedMessage;
use futures_util::StreamExt as _; use futures_util::StreamExt as _;
use uuid::Uuid; use uuid::Uuid;
@ -7,7 +11,7 @@ use crate::{
Data, Data,
api::v1::auth::check_access_token, api::v1::auth::check_access_token,
structs::{Channel, Member}, structs::{Channel, Member},
utils::get_auth_header, utils::get_ws_protocol_header,
}; };
#[get("{uuid}/channels/{channel_uuid}/socket")] #[get("{uuid}/channels/{channel_uuid}/socket")]
@ -21,7 +25,7 @@ pub async fn echo(
let headers = req.headers(); let headers = req.headers();
// Retrieve auth header // Retrieve auth header
let auth_header = get_auth_header(headers)?; let auth_header = get_ws_protocol_header(headers)?;
// Get uuids from path // Get uuids from path
let (guild_uuid, channel_uuid) = path.into_inner(); let (guild_uuid, channel_uuid) = path.into_inner();
@ -46,7 +50,7 @@ pub async fn echo(
.await?; .await?;
} }
let (res, mut session_1, stream) = actix_ws::handle(&req, stream)?; let (mut res, mut session_1, stream) = actix_ws::handle(&req, stream)?;
let mut stream = stream let mut stream = stream
.aggregate_continuations() .aggregate_continuations()
@ -77,10 +81,7 @@ pub async fn echo(
while let Some(msg) = stream.next().await { while let Some(msg) = stream.next().await {
match msg { match msg {
Ok(AggregatedMessage::Text(text)) => { Ok(AggregatedMessage::Text(text)) => {
let mut conn = data let mut conn = data.cache_pool.get_multiplexed_tokio_connection().await?;
.cache_pool
.get_multiplexed_tokio_connection()
.await?;
redis::cmd("PUBLISH") redis::cmd("PUBLISH")
.arg(&[channel_uuid.to_string(), text.to_string()]) .arg(&[channel_uuid.to_string(), text.to_string()])
@ -109,6 +110,13 @@ pub async fn echo(
Ok::<(), crate::error::Error>(()) Ok::<(), crate::error::Error>(())
}); });
let headers = res.headers_mut();
headers.append(
SEC_WEBSOCKET_PROTOCOL,
HeaderValue::from_str("Authorization")?,
);
// respond immediately with response connected to WS session // respond immediately with response connected to WS session
Ok(res) Ok(res)
} }

View file

@ -45,6 +45,40 @@ pub fn get_auth_header(headers: &HeaderMap) -> Result<&str, Error> {
Ok(auth_value.unwrap()) Ok(auth_value.unwrap())
} }
pub fn get_ws_protocol_header(headers: &HeaderMap) -> Result<&str, Error> {
let auth_token = headers.get(actix_web::http::header::SEC_WEBSOCKET_PROTOCOL);
if auth_token.is_none() {
return Err(Error::Unauthorized(
"No authorization header provided".to_string(),
));
}
let auth_raw = auth_token.unwrap().to_str()?;
let mut auth = auth_raw.split_whitespace();
let response_proto = auth.next();
let auth_value = auth.next();
if response_proto.is_none() {
return Err(Error::BadRequest(
"Sec-WebSocket-Protocol header is empty".to_string(),
));
} else if response_proto.is_some_and(|rp| rp != "Authorization,") {
return Err(Error::BadRequest(
"First protocol should be Authorization".to_string(),
));
}
if auth_value.is_none() {
return Err(Error::BadRequest("No token provided".to_string()));
}
Ok(auth_value.unwrap())
}
pub fn refresh_token_cookie(refresh_token: String) -> Cookie<'static> { pub fn refresh_token_cookie(refresh_token: String) -> Cookie<'static> {
Cookie::build("refresh_token", refresh_token) Cookie::build("refresh_token", refresh_token)
.http_only(true) .http_only(true)