feat: reimplement old websocket

This commit is contained in:
Radical 2025-07-20 18:11:08 +02:00
parent a602c2624f
commit 2fb7e7781f
6 changed files with 91 additions and 70 deletions

View file

@ -38,7 +38,7 @@ bindet = "0.3.2"
bunny-api-tokio = { version = "0.4", features = ["edge_storage"], default-features = false }
# Web Server
axum = { version = "0.8.4", features = ["macros", "multipart"] }
axum = { version = "0.8.4", features = ["macros", "multipart", "ws"] }
tower-http = { version = "0.6.6", features = ["cors"] }
axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"] }
socketioxide = { version = "0.17.2", features = ["state"] }

View file

@ -2,7 +2,7 @@ use std::sync::Arc;
use axum::{
Router,
routing::{delete, get, patch},
routing::{any, delete, get, patch},
};
//use socketioxide::SocketIo;
@ -15,5 +15,6 @@ pub fn router() -> Router<Arc<AppState>> {
.route("/{uuid}", get(uuid::get))
.route("/{uuid}", delete(uuid::delete))
.route("/{uuid}", patch(uuid::patch))
.route("/{uuid}/socket", any(uuid::socket::ws))
.route("/{uuid}/messages", get(uuid::messages::get))
}

View file

@ -1,7 +1,7 @@
//! `/api/v1/channels/{uuid}` Channel specific endpoints
pub mod messages;
//pub mod socket;
pub mod socket;
use std::sync::Arc;

View file

@ -1,18 +1,21 @@
use actix_web::{
Error, HttpRequest, HttpResponse, get,
http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL},
rt, web,
use std::sync::Arc;
use axum::{
extract::{Path, State, WebSocketUpgrade, ws::Message},
http::HeaderMap,
response::IntoResponse,
};
use actix_ws::AggregatedMessage;
use futures::SinkExt;
use futures_util::StreamExt as _;
use serde::Deserialize;
use uuid::Uuid;
use crate::{
Data,
AppState,
api::v1::auth::check_access_token,
error::Error,
objects::{Channel, Member},
utils::{get_ws_protocol_header, global_checks},
utils::global_checks,
};
#[derive(Deserialize)]
@ -21,100 +24,114 @@ struct MessageBody {
reply_to: Option<Uuid>,
}
#[get("/{uuid}/socket")]
pub async fn ws(
req: HttpRequest,
path: web::Path<(Uuid,)>,
stream: web::Payload,
data: web::Data<Data>,
) -> Result<HttpResponse, Error> {
// Get all headers
let headers = req.headers();
ws: WebSocketUpgrade,
State(app_state): State<Arc<AppState>>,
Path(channel_uuid): Path<Uuid>,
headers: HeaderMap,
) -> Result<impl IntoResponse, Error> {
// Retrieve auth header
let auth_header = get_ws_protocol_header(headers)?;
let auth_token = headers.get(axum::http::header::SEC_WEBSOCKET_PROTOCOL);
// Get uuid from path
let channel_uuid = path.into_inner().0;
if auth_token.is_none() {
return Err(Error::Unauthorized(
"No authorization header provided".to_string(),
));
}
let mut conn = data.pool.get().await.map_err(crate::error::Error::from)?;
let auth_raw = auth_token.unwrap().to_str()?;
let mut auth = auth_raw.split_whitespace();
let response_proto = auth.next();
let auth_value = auth.next();
if response_proto.is_none() {
return Err(Error::BadRequest(
"Sec-WebSocket-Protocol header is empty".to_string(),
));
} else if response_proto.is_some_and(|rp| rp != "Authorization,") {
return Err(Error::BadRequest(
"First protocol should be Authorization".to_string(),
));
}
if auth_value.is_none() {
return Err(Error::BadRequest("No token provided".to_string()));
}
let auth_header = auth_value.unwrap();
let mut conn = app_state
.pool
.get()
.await
.map_err(crate::error::Error::from)?;
// Authorize client using auth header
let uuid = check_access_token(auth_header, &mut conn).await?;
global_checks(&data, uuid).await?;
global_checks(&app_state, uuid).await?;
let channel = Channel::fetch_one(&data, channel_uuid).await?;
let channel = Channel::fetch_one(&app_state, channel_uuid).await?;
Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?;
let (mut res, mut session_1, stream) = actix_ws::handle(&req, stream)?;
let mut stream = stream
.aggregate_continuations()
// aggregate continuation frames up to 1MiB
.max_continuation_size(2_usize.pow(20));
let mut pubsub = data
let mut pubsub = app_state
.cache_pool
.get_async_pubsub()
.await
.map_err(crate::error::Error::from)?;
let mut session_2 = session_1.clone();
let mut res = ws.on_upgrade(async move |socket| {
let (mut sender, mut receiver) = socket.split();
rt::spawn(async move {
pubsub.subscribe(channel_uuid.to_string()).await?;
while let Some(msg) = pubsub.on_message().next().await {
let payload: String = msg.get_payload()?;
session_1.text(payload).await?;
}
tokio::spawn(async move {
pubsub.subscribe(channel_uuid.to_string()).await?;
while let Some(msg) = pubsub.on_message().next().await {
let payload: String = msg.get_payload()?;
sender.send(payload.into()).await?;
}
Ok::<(), crate::error::Error>(())
});
// start task but don't wait for it
rt::spawn(async move {
// receive messages from websocket
while let Some(msg) = stream.next().await {
match msg {
Ok(AggregatedMessage::Text(text)) => {
let mut conn = data.cache_pool.get_multiplexed_tokio_connection().await?;
Ok::<(), crate::error::Error>(())
});
tokio::spawn(async move {
while let Some(msg) = receiver.next().await {
if let Ok(Message::Text(text)) = msg {
let message_body: MessageBody = serde_json::from_str(&text)?;
let message = channel
.new_message(&data, uuid, message_body.message, message_body.reply_to)
.new_message(
&app_state,
uuid,
message_body.message,
message_body.reply_to,
)
.await?;
redis::cmd("PUBLISH")
.arg(&[channel_uuid.to_string(), serde_json::to_string(&message)?])
.exec_async(&mut conn)
.exec_async(
&mut app_state
.cache_pool
.get_multiplexed_tokio_connection()
.await?,
)
.await?;
}
Ok(AggregatedMessage::Binary(bin)) => {
// echo binary message
session_2.binary(bin).await?;
}
Ok(AggregatedMessage::Ping(msg)) => {
// respond to PING frame with PONG frame
session_2.pong(&msg).await?;
}
_ => {}
}
}
Ok::<(), crate::error::Error>(())
Ok::<(), crate::error::Error>(())
});
});
let headers = res.headers_mut();
headers.append(
SEC_WEBSOCKET_PROTOCOL,
HeaderValue::from_str("Authorization")?,
axum::http::header::SEC_WEBSOCKET_PROTOCOL,
"Authorization".parse()?,
);
// respond immediately with response connected to WS session

View file

@ -17,7 +17,6 @@ mod users;
pub fn router(app_state: Arc<AppState>) -> Router<Arc<AppState>> {
let router_with_auth = Router::new()
.nest("/users", users::router())
.nest("/channels", channels::router())
.nest("/guilds", guilds::router())
.nest("/invites", invites::router())
.nest("/me", me::router())
@ -28,6 +27,7 @@ pub fn router(app_state: Arc<AppState>) -> Router<Arc<AppState>> {
Router::new()
.route("/stats", get(stats::res))
.nest("/auth", auth::router(app_state))
.nest("/auth", auth::router(app_state.clone()))
.nest("/channels", channels::router(app_state))
.merge(router_with_auth)
}

View file

@ -83,6 +83,9 @@ pub enum Error {
TooManyRequests(String),
#[error("{0}")]
InternalServerError(String),
// TODO: remove when doing socket.io
#[error(transparent)]
AxumError(#[from] axum::Error),
}
impl IntoResponse for Error {