backend/src/api/v1/channels/uuid/socket.rs
Radical 1c07957c4e
All checks were successful
ci/woodpecker/push/build-and-publish Pipeline was successful
ci/woodpecker/pr/build-and-publish Pipeline was successful
ci/woodpecker/pull_request_closed/build-and-publish Pipeline was successful
refactor: small dependency optimizations
2025-07-20 18:45:50 +02:00

138 lines
3.8 KiB
Rust

use std::sync::Arc;
use axum::{
extract::{Path, State, WebSocketUpgrade, ws::Message},
http::HeaderMap,
response::IntoResponse,
};
use futures_util::{SinkExt, StreamExt};
use serde::Deserialize;
use uuid::Uuid;
use crate::{
AppState,
api::v1::auth::check_access_token,
error::Error,
objects::{Channel, Member},
utils::global_checks,
};
/// JSON payload a client sends over the socket as a text frame to post a
/// new message in the channel.
#[derive(Deserialize)]
struct MessageBody {
/// Message content; handed to `Channel::new_message` as-is.
message: String,
/// Optional UUID forwarded to `Channel::new_message`; presumably the message
/// being replied to (confirm against `Channel::new_message`).
reply_to: Option<Uuid>,
}
pub async fn ws(
ws: WebSocketUpgrade,
State(app_state): State<Arc<AppState>>,
Path(channel_uuid): Path<Uuid>,
headers: HeaderMap,
) -> Result<impl IntoResponse, Error> {
// Retrieve auth header
let auth_token = headers.get(axum::http::header::SEC_WEBSOCKET_PROTOCOL);
if auth_token.is_none() {
return Err(Error::Unauthorized(
"No authorization header provided".to_string(),
));
}
let auth_raw = auth_token.unwrap().to_str()?;
let mut auth = auth_raw.split_whitespace();
let response_proto = auth.next();
let auth_value = auth.next();
if response_proto.is_none() {
return Err(Error::BadRequest(
"Sec-WebSocket-Protocol header is empty".to_string(),
));
} else if response_proto.is_some_and(|rp| rp != "Authorization,") {
return Err(Error::BadRequest(
"First protocol should be Authorization".to_string(),
));
}
if auth_value.is_none() {
return Err(Error::BadRequest("No token provided".to_string()));
}
let auth_header = auth_value.unwrap();
let mut conn = app_state
.pool
.get()
.await
.map_err(crate::error::Error::from)?;
// Authorize client using auth header
let uuid = check_access_token(auth_header, &mut conn).await?;
global_checks(&app_state, uuid).await?;
let channel = Channel::fetch_one(&app_state, channel_uuid).await?;
Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?;
let mut pubsub = app_state
.cache_pool
.get_async_pubsub()
.await
.map_err(crate::error::Error::from)?;
let mut res = ws.on_upgrade(async move |socket| {
let (mut sender, mut receiver) = socket.split();
tokio::spawn(async move {
pubsub.subscribe(channel_uuid.to_string()).await?;
while let Some(msg) = pubsub.on_message().next().await {
let payload: String = msg.get_payload()?;
sender.send(payload.into()).await?;
}
Ok::<(), crate::error::Error>(())
});
tokio::spawn(async move {
while let Some(msg) = receiver.next().await {
if let Ok(Message::Text(text)) = msg {
let message_body: MessageBody = serde_json::from_str(&text)?;
let message = channel
.new_message(
&app_state,
uuid,
message_body.message,
message_body.reply_to,
)
.await?;
redis::cmd("PUBLISH")
.arg(&[channel_uuid.to_string(), serde_json::to_string(&message)?])
.exec_async(
&mut app_state
.cache_pool
.get_multiplexed_tokio_connection()
.await?,
)
.await?;
}
}
Ok::<(), crate::error::Error>(())
});
});
let headers = res.headers_mut();
headers.append(
axum::http::header::SEC_WEBSOCKET_PROTOCOL,
"Authorization".parse()?,
);
// respond immediately with response connected to WS session
Ok(res)
}