From b4bb83b7f5504aefed1d0fb82b6d9771ee615cc8 Mon Sep 17 00:00:00 2001 From: Radical Date: Thu, 7 Aug 2025 21:38:01 +0200 Subject: [PATCH] feat: add event enum for sending and receiving messages on the socket Added in message editing and deleting with this change --- src/api/v1/channels/uuid/socket.rs | 195 +++++++++++++++++++++++++---- src/objects/mod.rs | 2 +- 2 files changed, 173 insertions(+), 24 deletions(-) diff --git a/src/api/v1/channels/uuid/socket.rs b/src/api/v1/channels/uuid/socket.rs index ac04301..b6d449b 100644 --- a/src/api/v1/channels/uuid/socket.rs +++ b/src/api/v1/channels/uuid/socket.rs @@ -5,24 +5,60 @@ use axum::{ http::HeaderMap, response::IntoResponse, }; +use diesel::{ExpressionMethods, QueryDsl, SelectableHelper, delete, update}; +use diesel_async::RunQueryDsl; use futures_util::{SinkExt, StreamExt}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ AppState, api::v1::auth::check_access_token, error::Error, - objects::{Channel, Member}, + objects::{self, Channel, Member, message::MessageBuilder}, + schema::messages, utils::global_checks, }; #[derive(Deserialize)] -struct MessageBody { - message: String, +#[serde(tag = "event")] +enum ReceiveEvent { + MessageSend { entity: MessageSend }, + MessageEdit { entity: MessageEdit }, + MessageDelete { entity: MessageDelete }, +} + +#[derive(Deserialize)] +struct MessageSend { + text: String, reply_to: Option, } +#[derive(Deserialize)] +struct MessageEdit { + uuid: Uuid, + text: String, +} + +#[derive(Deserialize, Serialize)] +struct MessageDelete { + uuid: Uuid, +} + +#[derive(Serialize)] +#[serde(tag = "event")] +enum SendEvent { + MessageSend { entity: objects::Message }, + MessageEdit { entity: objects::Message }, + MessageDelete { entity: MessageDelete }, + Error { entity: SendError }, +} + +#[derive(Serialize)] +struct SendError { + message: String, +} + pub async fn ws( ws: WebSocketUpgrade, State(app_state): State>, @@ -99,27 +135,140 @@ pub async fn ws( tokio::spawn(async move { while let Some(msg) = receiver.next().await { if let Ok(Message::Text(text)) = msg { - let message_body: MessageBody = serde_json::from_str(&text)?; + let message_body: ReceiveEvent = serde_json::from_str(&text)?; - let message = channel - .new_message( - &mut conn, - &app_state.cache_pool, - uuid, - message_body.message, - message_body.reply_to, - ) - .await?; + match message_body { + ReceiveEvent::MessageSend { entity } => { + let message = channel + .new_message( + &mut app_state.pool.get().await?, + &app_state.cache_pool, + uuid, + entity.text, + entity.reply_to, + ) + .await?; - redis::cmd("PUBLISH") - .arg(&[channel_uuid.to_string(), serde_json::to_string(&message)?]) - .exec_async( - &mut app_state - .cache_pool - .get_multiplexed_tokio_connection() - .await?, - ) - .await?; + redis::cmd("PUBLISH") + .arg(&[ + channel_uuid.to_string(), + serde_json::to_string(&SendEvent::MessageSend { + entity: message, + })?, + ]) + .exec_async( + &mut app_state + .cache_pool + .get_multiplexed_tokio_connection() + .await?, + ) + .await?; + } + ReceiveEvent::MessageEdit { entity } => { + use messages::dsl; + let mut message: MessageBuilder = dsl::messages + .filter(dsl::uuid.eq(entity.uuid)) + .select(MessageBuilder::as_select()) + .get_result(&mut app_state.pool.get().await?) + .await?; + + if uuid != message.user_uuid { + redis::cmd("PUBLISH") + .arg(&[ + channel_uuid.to_string(), + serde_json::to_string(&SendEvent::Error { + entity: SendError { + message: "Not allowed".to_string(), + }, + })?, + ]) + .exec_async( + &mut app_state + .cache_pool + .get_multiplexed_tokio_connection() + .await?, + ) + .await?; + + continue; + } + + update(messages::table) + .filter(dsl::uuid.eq(entity.uuid)) + .set(dsl::message.eq(&entity.text)) + .execute(&mut app_state.pool.get().await?) + .await?; + + message.message = entity.text; + + redis::cmd("PUBLISH") + .arg(&[ + channel_uuid.to_string(), + serde_json::to_string(&SendEvent::MessageEdit { + entity: message + .build( + &mut app_state.pool.get().await?, + &app_state.cache_pool, + ) + .await?, + })?, + ]) + .exec_async( + &mut app_state + .cache_pool + .get_multiplexed_tokio_connection() + .await?, + ) + .await?; + } + ReceiveEvent::MessageDelete { entity } => { + use messages::dsl; + let message: MessageBuilder = dsl::messages + .filter(dsl::uuid.eq(entity.uuid)) + .select(MessageBuilder::as_select()) + .get_result(&mut app_state.pool.get().await?) + .await?; + + if uuid != message.user_uuid { + redis::cmd("PUBLISH") + .arg(&[ + channel_uuid.to_string(), + serde_json::to_string(&SendEvent::Error { + entity: SendError { + message: "Not allowed".to_string(), + }, + })?, + ]) + .exec_async( + &mut app_state + .cache_pool + .get_multiplexed_tokio_connection() + .await?, + ) + .await?; + + continue; + } + + delete(messages::table) + .filter(dsl::uuid.eq(entity.uuid)) + .execute(&mut app_state.pool.get().await?) + .await?; + + redis::cmd("PUBLISH") + .arg(&[ + channel_uuid.to_string(), + serde_json::to_string(&SendEvent::MessageDelete { entity })?, + ]) + .exec_async( + &mut app_state + .cache_pool + .get_multiplexed_tokio_connection() + .await?, + ) + .await?; + } + } } } diff --git a/src/objects/mod.rs b/src/objects/mod.rs index 5a013ca..50fb4cb 100644 --- a/src/objects/mod.rs +++ b/src/objects/mod.rs @@ -15,7 +15,7 @@ mod guild; mod invite; mod me; mod member; -mod message; +pub mod message; mod password_reset_token; mod role; mod user;