axum rewrite #35
6 changed files with 91 additions and 70 deletions
|
@ -38,7 +38,7 @@ bindet = "0.3.2"
|
|||
bunny-api-tokio = { version = "0.4", features = ["edge_storage"], default-features = false }
|
||||
|
||||
# Web Server
|
||||
axum = { version = "0.8.4", features = ["macros", "multipart"] }
|
||||
axum = { version = "0.8.4", features = ["macros", "multipart", "ws"] }
|
||||
tower-http = { version = "0.6.6", features = ["cors"] }
|
||||
axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"] }
|
||||
socketioxide = { version = "0.17.2", features = ["state"] }
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::sync::Arc;
|
|||
|
||||
use axum::{
|
||||
Router,
|
||||
routing::{delete, get, patch},
|
||||
routing::{any, delete, get, patch},
|
||||
};
|
||||
//use socketioxide::SocketIo;
|
||||
|
||||
|
@ -15,5 +15,6 @@ pub fn router() -> Router<Arc<AppState>> {
|
|||
.route("/{uuid}", get(uuid::get))
|
||||
.route("/{uuid}", delete(uuid::delete))
|
||||
.route("/{uuid}", patch(uuid::patch))
|
||||
.route("/{uuid}/socket", any(uuid::socket::ws))
|
||||
.route("/{uuid}/messages", get(uuid::messages::get))
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
//! `/api/v1/channels/{uuid}` Channel specific endpoints
|
||||
|
||||
pub mod messages;
|
||||
//pub mod socket;
|
||||
pub mod socket;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
|
|
|
@ -1,18 +1,21 @@
|
|||
use actix_web::{
|
||||
Error, HttpRequest, HttpResponse, get,
|
||||
http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL},
|
||||
rt, web,
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
extract::{Path, State, WebSocketUpgrade, ws::Message},
|
||||
http::HeaderMap,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use actix_ws::AggregatedMessage;
|
||||
use futures::SinkExt;
|
||||
use futures_util::StreamExt as _;
|
||||
use serde::Deserialize;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
Data,
|
||||
AppState,
|
||||
api::v1::auth::check_access_token,
|
||||
error::Error,
|
||||
objects::{Channel, Member},
|
||||
utils::{get_ws_protocol_header, global_checks},
|
||||
utils::global_checks,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
@ -21,100 +24,114 @@ struct MessageBody {
|
|||
reply_to: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[get("/{uuid}/socket")]
|
||||
pub async fn ws(
|
||||
req: HttpRequest,
|
||||
path: web::Path<(Uuid,)>,
|
||||
stream: web::Payload,
|
||||
data: web::Data<Data>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
// Get all headers
|
||||
let headers = req.headers();
|
||||
|
||||
ws: WebSocketUpgrade,
|
||||
State(app_state): State<Arc<AppState>>,
|
||||
Path(channel_uuid): Path<Uuid>,
|
||||
headers: HeaderMap,
|
||||
) -> Result<impl IntoResponse, Error> {
|
||||
// Retrieve auth header
|
||||
let auth_header = get_ws_protocol_header(headers)?;
|
||||
let auth_token = headers.get(axum::http::header::SEC_WEBSOCKET_PROTOCOL);
|
||||
|
||||
// Get uuid from path
|
||||
let channel_uuid = path.into_inner().0;
|
||||
if auth_token.is_none() {
|
||||
return Err(Error::Unauthorized(
|
||||
"No authorization header provided".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut conn = data.pool.get().await.map_err(crate::error::Error::from)?;
|
||||
let auth_raw = auth_token.unwrap().to_str()?;
|
||||
|
||||
let mut auth = auth_raw.split_whitespace();
|
||||
|
||||
let response_proto = auth.next();
|
||||
|
||||
let auth_value = auth.next();
|
||||
|
||||
if response_proto.is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
"Sec-WebSocket-Protocol header is empty".to_string(),
|
||||
));
|
||||
} else if response_proto.is_some_and(|rp| rp != "Authorization,") {
|
||||
return Err(Error::BadRequest(
|
||||
"First protocol should be Authorization".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if auth_value.is_none() {
|
||||
return Err(Error::BadRequest("No token provided".to_string()));
|
||||
}
|
||||
|
||||
let auth_header = auth_value.unwrap();
|
||||
|
||||
let mut conn = app_state
|
||||
.pool
|
||||
.get()
|
||||
.await
|
||||
.map_err(crate::error::Error::from)?;
|
||||
|
||||
// Authorize client using auth header
|
||||
let uuid = check_access_token(auth_header, &mut conn).await?;
|
||||
|
||||
global_checks(&data, uuid).await?;
|
||||
global_checks(&app_state, uuid).await?;
|
||||
|
||||
let channel = Channel::fetch_one(&data, channel_uuid).await?;
|
||||
let channel = Channel::fetch_one(&app_state, channel_uuid).await?;
|
||||
|
||||
Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?;
|
||||
|
||||
let (mut res, mut session_1, stream) = actix_ws::handle(&req, stream)?;
|
||||
|
||||
let mut stream = stream
|
||||
.aggregate_continuations()
|
||||
// aggregate continuation frames up to 1MiB
|
||||
.max_continuation_size(2_usize.pow(20));
|
||||
|
||||
let mut pubsub = data
|
||||
let mut pubsub = app_state
|
||||
.cache_pool
|
||||
.get_async_pubsub()
|
||||
.await
|
||||
.map_err(crate::error::Error::from)?;
|
||||
|
||||
let mut session_2 = session_1.clone();
|
||||
let mut res = ws.on_upgrade(async move |socket| {
|
||||
let (mut sender, mut receiver) = socket.split();
|
||||
|
||||
rt::spawn(async move {
|
||||
pubsub.subscribe(channel_uuid.to_string()).await?;
|
||||
while let Some(msg) = pubsub.on_message().next().await {
|
||||
let payload: String = msg.get_payload()?;
|
||||
session_1.text(payload).await?;
|
||||
}
|
||||
tokio::spawn(async move {
|
||||
pubsub.subscribe(channel_uuid.to_string()).await?;
|
||||
while let Some(msg) = pubsub.on_message().next().await {
|
||||
let payload: String = msg.get_payload()?;
|
||||
sender.send(payload.into()).await?;
|
||||
}
|
||||
|
||||
Ok::<(), crate::error::Error>(())
|
||||
});
|
||||
|
||||
// start task but don't wait for it
|
||||
rt::spawn(async move {
|
||||
// receive messages from websocket
|
||||
while let Some(msg) = stream.next().await {
|
||||
match msg {
|
||||
Ok(AggregatedMessage::Text(text)) => {
|
||||
let mut conn = data.cache_pool.get_multiplexed_tokio_connection().await?;
|
||||
Ok::<(), crate::error::Error>(())
|
||||
});
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(msg) = receiver.next().await {
|
||||
if let Ok(Message::Text(text)) = msg {
|
||||
let message_body: MessageBody = serde_json::from_str(&text)?;
|
||||
|
||||
let message = channel
|
||||
.new_message(&data, uuid, message_body.message, message_body.reply_to)
|
||||
.new_message(
|
||||
&app_state,
|
||||
uuid,
|
||||
message_body.message,
|
||||
message_body.reply_to,
|
||||
)
|
||||
.await?;
|
||||
|
||||
redis::cmd("PUBLISH")
|
||||
.arg(&[channel_uuid.to_string(), serde_json::to_string(&message)?])
|
||||
.exec_async(&mut conn)
|
||||
.exec_async(
|
||||
&mut app_state
|
||||
.cache_pool
|
||||
.get_multiplexed_tokio_connection()
|
||||
.await?,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(AggregatedMessage::Binary(bin)) => {
|
||||
// echo binary message
|
||||
session_2.binary(bin).await?;
|
||||
}
|
||||
|
||||
Ok(AggregatedMessage::Ping(msg)) => {
|
||||
// respond to PING frame with PONG frame
|
||||
session_2.pong(&msg).await?;
|
||||
}
|
||||
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok::<(), crate::error::Error>(())
|
||||
Ok::<(), crate::error::Error>(())
|
||||
});
|
||||
});
|
||||
|
||||
let headers = res.headers_mut();
|
||||
|
||||
headers.append(
|
||||
SEC_WEBSOCKET_PROTOCOL,
|
||||
HeaderValue::from_str("Authorization")?,
|
||||
axum::http::header::SEC_WEBSOCKET_PROTOCOL,
|
||||
"Authorization".parse()?,
|
||||
);
|
||||
|
||||
// respond immediately with response connected to WS session
|
||||
|
|
|
@ -17,7 +17,6 @@ mod users;
|
|||
pub fn router(app_state: Arc<AppState>) -> Router<Arc<AppState>> {
|
||||
let router_with_auth = Router::new()
|
||||
.nest("/users", users::router())
|
||||
.nest("/channels", channels::router())
|
||||
.nest("/guilds", guilds::router())
|
||||
.nest("/invites", invites::router())
|
||||
.nest("/me", me::router())
|
||||
|
@ -28,6 +27,7 @@ pub fn router(app_state: Arc<AppState>) -> Router<Arc<AppState>> {
|
|||
|
||||
Router::new()
|
||||
.route("/stats", get(stats::res))
|
||||
.nest("/auth", auth::router(app_state))
|
||||
.nest("/auth", auth::router(app_state.clone()))
|
||||
.nest("/channels", channels::router(app_state))
|
||||
.merge(router_with_auth)
|
||||
}
|
||||
|
|
|
@ -83,6 +83,9 @@ pub enum Error {
|
|||
TooManyRequests(String),
|
||||
#[error("{0}")]
|
||||
InternalServerError(String),
|
||||
// TODO: remove when doing socket.io
|
||||
#[error(transparent)]
|
||||
AxumError(#[from] axum::Error),
|
||||
}
|
||||
|
||||
impl IntoResponse for Error {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue