diff --git a/src/api/v1/auth/register.rs b/src/api/v1/auth/register.rs index 545e5aa..807fab8 100644 --- a/src/api/v1/auth/register.rs +++ b/src/api/v1/auth/register.rs @@ -159,7 +159,7 @@ pub async fn post( .await?; if let Some(initial_guild) = app_state.config.instance.initial_guild { - Member::new(&mut conn, &app_state.cache_pool, uuid, initial_guild).await?; + Member::new(&app_state, uuid, initial_guild).await?; } let mut response = ( diff --git a/src/api/v1/auth/reset_password.rs b/src/api/v1/auth/reset_password.rs index 35c4b41..bac465c 100644 --- a/src/api/v1/auth/reset_password.rs +++ b/src/api/v1/auth/reset_password.rs @@ -38,17 +38,11 @@ pub async fn get( State(app_state): State>, query: Query, ) -> Result { - let mut conn = app_state.pool.get().await?; - - if let Ok(password_reset_token) = PasswordResetToken::get_with_identifier( - &mut conn, - &app_state.cache_pool, - query.identifier.clone(), - ) - .await + if let Ok(password_reset_token) = + PasswordResetToken::get_with_identifier(&app_state, query.identifier.clone()).await { if Utc::now().signed_duration_since(password_reset_token.created_at) > Duration::hours(1) { - password_reset_token.delete(&app_state.cache_pool).await?; + password_reset_token.delete(&app_state).await?; } else { return Err(Error::TooManyRequests( "Please allow 1 hour before sending a new email".to_string(), @@ -56,7 +50,7 @@ pub async fn get( } } - PasswordResetToken::new(&mut conn, &app_state, query.identifier.clone()).await?; + PasswordResetToken::new(&app_state, query.identifier.clone()).await?; Ok(StatusCode::OK) } @@ -93,14 +87,10 @@ pub async fn post( reset_password: Json, ) -> Result { let password_reset_token = - PasswordResetToken::get(&app_state.cache_pool, reset_password.token.clone()).await?; + PasswordResetToken::get(&app_state, reset_password.token.clone()).await?; password_reset_token - .set_password( - &mut app_state.pool.get().await?, - &app_state, - reset_password.password.clone(), - ) + .set_password(&app_state, reset_password.password.clone()) .await?; Ok(StatusCode::OK) diff --git a/src/api/v1/auth/verify_email.rs b/src/api/v1/auth/verify_email.rs index 1cb8aef..0801768 100644 --- a/src/api/v1/auth/verify_email.rs +++ b/src/api/v1/auth/verify_email.rs @@ -55,7 +55,7 @@ pub async fn get( return Ok(StatusCode::NO_CONTENT); } - let email_token = EmailToken::get(&app_state.cache_pool, me.uuid).await?; + let email_token = EmailToken::get(&app_state, me.uuid).await?; if query.token != email_token.token { return Ok(StatusCode::UNAUTHORIZED); @@ -63,7 +63,7 @@ pub async fn get( me.verify_email(&mut conn).await?; - email_token.delete(&app_state.cache_pool).await?; + email_token.delete(&app_state).await?; Ok(StatusCode::OK) } @@ -91,9 +91,9 @@ pub async fn post( return Ok(StatusCode::NO_CONTENT); } - if let Ok(email_token) = EmailToken::get(&app_state.cache_pool, me.uuid).await { + if let Ok(email_token) = EmailToken::get(&app_state, me.uuid).await { if Utc::now().signed_duration_since(email_token.created_at) > Duration::hours(1) { - email_token.delete(&app_state.cache_pool).await?; + email_token.delete(&app_state).await?; } else { return Err(Error::TooManyRequests( "Please allow 1 hour before sending a new email".to_string(), diff --git a/src/api/v1/channels/uuid/messages.rs b/src/api/v1/channels/uuid/messages.rs index 1f9010d..b8f0ad6 100644 --- a/src/api/v1/channels/uuid/messages.rs +++ b/src/api/v1/channels/uuid/messages.rs @@ -60,21 +60,14 @@ pub async fn get( Query(message_request): Query, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; - - Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; let messages = channel - .fetch_messages( - &mut conn, - &app_state.cache_pool, - message_request.amount, - message_request.offset, - ) + .fetch_messages(&app_state, message_request.amount, message_request.offset) .await?; Ok((StatusCode::OK, Json(messages))) diff --git a/src/api/v1/channels/uuid/mod.rs b/src/api/v1/channels/uuid/mod.rs index f5566b3..373742e 100644 --- a/src/api/v1/channels/uuid/mod.rs +++ b/src/api/v1/channels/uuid/mod.rs @@ -27,13 +27,11 @@ pub async fn get( Path(channel_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; - - Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; Ok((StatusCode::OK, Json(channel))) } @@ -43,19 +41,19 @@ pub async fn delete( Path(channel_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; - - let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + let member = + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid) + .await?; member - .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageChannel) + .check_permission(&app_state, Permissions::ManageChannel) .await?; - channel.delete(&mut conn, &app_state.cache_pool).await?; + channel.delete(&app_state).await?; Ok(StatusCode::OK) } @@ -104,37 +102,31 @@ pub async fn patch( Extension(CurrentUser(uuid)): Extension>, Json(new_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let mut channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; - - let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + let member = + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid) + .await?; member - .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageChannel) + .check_permission(&app_state, Permissions::ManageChannel) .await?; if let Some(new_name) = &new_info.name { - channel - .set_name(&mut conn, &app_state.cache_pool, new_name.to_string()) - .await?; + channel.set_name(&app_state, new_name.to_string()).await?; } if let Some(new_description) = &new_info.description { channel - .set_description( - &mut conn, - &app_state.cache_pool, - new_description.to_string(), - ) + .set_description(&app_state, new_description.to_string()) .await?; } if let Some(new_is_above) = &new_info.is_above { channel - .set_description(&mut conn, &app_state.cache_pool, new_is_above.to_string()) + .set_description(&app_state, new_is_above.to_string()) .await?; } diff --git a/src/api/v1/channels/uuid/socket.rs b/src/api/v1/channels/uuid/socket.rs index ac04301..dd020e3 100644 --- a/src/api/v1/channels/uuid/socket.rs +++ b/src/api/v1/channels/uuid/socket.rs @@ -71,9 +71,9 @@ pub async fn ws( // Authorize client using auth header let uuid = check_access_token(auth_header, &mut conn).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + global_checks(&app_state, uuid).await?; - let channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; + let channel = Channel::fetch_one(&app_state, channel_uuid).await?; Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; @@ -103,8 +103,7 @@ pub async fn ws( let message = channel .new_message( - &mut conn, - &app_state.cache_pool, + &app_state, uuid, message_body.message, message_body.reply_to, diff --git a/src/api/v1/guilds/mod.rs b/src/api/v1/guilds/mod.rs index 5b9f089..8118522 100644 --- a/src/api/v1/guilds/mod.rs +++ b/src/api/v1/guilds/mod.rs @@ -128,11 +128,9 @@ pub async fn get_guilds( let start = request_query.start.unwrap_or(0); let amount = request_query.amount.unwrap_or(10); - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; - - let guilds = Guild::fetch_amount(&mut conn, start, amount).await?; + let guilds = Guild::fetch_amount(&app_state.pool, start, amount).await?; Ok((StatusCode::OK, Json(guilds))) } diff --git a/src/api/v1/guilds/uuid/bans.rs b/src/api/v1/guilds/uuid/bans.rs index 2e31a59..29d5a05 100644 --- a/src/api/v1/guilds/uuid/bans.rs +++ b/src/api/v1/guilds/uuid/bans.rs @@ -21,13 +21,13 @@ pub async fn get( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let caller = Member::check_membership(&mut conn, uuid, guild_uuid).await?; caller - .check_permission(&mut conn, &app_state.cache_pool, Permissions::BanMember) + .check_permission(&app_state, Permissions::BanMember) .await?; let all_guild_bans = GuildBan::fetch_all(&mut conn, guild_uuid).await?; @@ -40,13 +40,13 @@ pub async fn unban( Path((guild_uuid, user_uuid)): Path<(Uuid, Uuid)>, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let caller = Member::check_membership(&mut conn, uuid, guild_uuid).await?; caller - .check_permission(&mut conn, &app_state.cache_pool, Permissions::BanMember) + .check_permission(&app_state, Permissions::BanMember) .await?; let ban = GuildBan::fetch_one(&mut conn, guild_uuid, user_uuid).await?; diff --git a/src/api/v1/guilds/uuid/channels.rs b/src/api/v1/guilds/uuid/channels.rs index 1cd7f78..82368b9 100644 --- a/src/api/v1/guilds/uuid/channels.rs +++ b/src/api/v1/guilds/uuid/channels.rs @@ -14,7 +14,7 @@ use crate::{ api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member, Permissions}, - utils::{CacheFns, global_checks, order_by_is_above}, + utils::{global_checks, order_by_is_above}, }; #[derive(Deserialize)] @@ -28,26 +28,23 @@ pub async fn get( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; - - Member::check_membership(&mut conn, uuid, guild_uuid).await?; + Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; if let Ok(cache_hit) = app_state - .cache_pool - .get_cache_key::>(format!("{guild_uuid}_channels")) + .get_cache_key(format!("{guild_uuid}_channels")) .await + && let Ok(channels) = serde_json::from_str::>(&cache_hit) { - return Ok((StatusCode::OK, Json(cache_hit)).into_response()); + return Ok((StatusCode::OK, Json(channels)).into_response()); } - let channels = Channel::fetch_all(&mut conn, guild_uuid).await?; + let channels = Channel::fetch_all(&app_state.pool, guild_uuid).await?; let channels_ordered = order_by_is_above(channels).await?; app_state - .cache_pool .set_cache_key( format!("{guild_uuid}_channels"), channels_ordered.clone(), @@ -64,19 +61,17 @@ pub async fn create( Extension(CurrentUser(uuid)): Extension>, Json(channel_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; - - let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; + let member = + Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; member - .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageChannel) + .check_permission(&app_state, Permissions::ManageChannel) .await?; let channel = Channel::new( - &mut conn, - &app_state.cache_pool, + &app_state, guild_uuid, channel_info.name.clone(), channel_info.description.clone(), diff --git a/src/api/v1/guilds/uuid/invites/mod.rs b/src/api/v1/guilds/uuid/invites/mod.rs index fa06f44..649fc16 100644 --- a/src/api/v1/guilds/uuid/invites/mod.rs +++ b/src/api/v1/guilds/uuid/invites/mod.rs @@ -27,9 +27,9 @@ pub async fn get( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; @@ -46,14 +46,14 @@ pub async fn create( Extension(CurrentUser(uuid)): Extension>, Json(invite_request): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&mut conn, &app_state.cache_pool, Permissions::CreateInvite) + .check_permission(&app_state, Permissions::CreateInvite) .await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; diff --git a/src/api/v1/guilds/uuid/members.rs b/src/api/v1/guilds/uuid/members.rs index 56710af..3ae10f7 100644 --- a/src/api/v1/guilds/uuid/members.rs +++ b/src/api/v1/guilds/uuid/members.rs @@ -21,15 +21,15 @@ pub async fn get( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; let me = Me::get(&mut conn, uuid).await?; - let members = Member::fetch_all(&mut conn, &app_state.cache_pool, &me, guild_uuid).await?; + let members = Member::fetch_all(&app_state, &me, guild_uuid).await?; Ok((StatusCode::OK, Json(members))) } diff --git a/src/api/v1/guilds/uuid/mod.rs b/src/api/v1/guilds/uuid/mod.rs index 53f469b..65a7c76 100644 --- a/src/api/v1/guilds/uuid/mod.rs +++ b/src/api/v1/guilds/uuid/mod.rs @@ -86,9 +86,9 @@ pub async fn get_guild( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; @@ -106,14 +106,14 @@ pub async fn edit( Extension(CurrentUser(uuid)): Extension>, mut multipart: Multipart, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageGuild) + .check_permission(&app_state, Permissions::ManageGuild) .await?; let mut guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -131,7 +131,14 @@ pub async fn edit( } if let Some(icon) = icon { - guild.set_icon(&mut conn, &app_state, icon).await?; + guild + .set_icon( + &app_state.bunny_storage, + &mut conn, + app_state.config.bunny.cdn_url.clone(), + icon, + ) + .await?; } Ok(StatusCode::OK) diff --git a/src/api/v1/guilds/uuid/roles/mod.rs b/src/api/v1/guilds/uuid/roles/mod.rs index d3660ce..0e496a0 100644 --- a/src/api/v1/guilds/uuid/roles/mod.rs +++ b/src/api/v1/guilds/uuid/roles/mod.rs @@ -14,7 +14,7 @@ use crate::{ api::v1::auth::CurrentUser, error::Error, objects::{Member, Permissions, Role}, - utils::{CacheFns, global_checks, order_by_is_above}, + utils::{global_checks, order_by_is_above}, }; pub mod uuid; @@ -29,18 +29,16 @@ pub async fn get( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; - if let Ok(cache_hit) = app_state - .cache_pool - .get_cache_key::>(format!("{guild_uuid}_roles")) - .await + if let Ok(cache_hit) = app_state.get_cache_key(format!("{guild_uuid}_roles")).await + && let Ok(roles) = serde_json::from_str::>(&cache_hit) { - return Ok((StatusCode::OK, Json(cache_hit)).into_response()); + return Ok((StatusCode::OK, Json(roles)).into_response()); } let roles = Role::fetch_all(&mut conn, guild_uuid).await?; @@ -48,7 +46,6 @@ pub async fn get( let roles_ordered = order_by_is_above(roles).await?; app_state - .cache_pool .set_cache_key(format!("{guild_uuid}_roles"), roles_ordered.clone(), 1800) .await?; @@ -61,14 +58,14 @@ pub async fn create( Extension(CurrentUser(uuid)): Extension>, Json(role_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageRole) + .check_permission(&app_state, Permissions::ManageRole) .await?; let role = Role::new(&mut conn, guild_uuid, role_info.name.clone()).await?; diff --git a/src/api/v1/guilds/uuid/roles/uuid.rs b/src/api/v1/guilds/uuid/roles/uuid.rs index e7890d0..732d553 100644 --- a/src/api/v1/guilds/uuid/roles/uuid.rs +++ b/src/api/v1/guilds/uuid/roles/uuid.rs @@ -13,7 +13,7 @@ use crate::{ api::v1::auth::CurrentUser, error::Error, objects::{Member, Role}, - utils::{CacheFns, global_checks}, + utils::global_checks, }; pub async fn get( @@ -21,24 +21,21 @@ pub async fn get( Path((guild_uuid, role_uuid)): Path<(Uuid, Uuid)>, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; - if let Ok(cache_hit) = app_state - .cache_pool - .get_cache_key::(format!("{role_uuid}")) - .await + if let Ok(cache_hit) = app_state.get_cache_key(format!("{role_uuid}")).await + && let Ok(role) = serde_json::from_str::(&cache_hit) { - return Ok((StatusCode::OK, Json(cache_hit)).into_response()); + return Ok((StatusCode::OK, Json(role)).into_response()); } let role = Role::fetch_one(&mut conn, role_uuid).await?; app_state - .cache_pool .set_cache_key(format!("{role_uuid}"), role.clone(), 60) .await?; diff --git a/src/api/v1/invites/id.rs b/src/api/v1/invites/id.rs index 99f177f..72ceea4 100644 --- a/src/api/v1/invites/id.rs +++ b/src/api/v1/invites/id.rs @@ -34,15 +34,15 @@ pub async fn join( Path(invite_id): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let invite = Invite::fetch_one(&mut conn, invite_id).await?; let guild = Guild::fetch_one(&mut conn, invite.guild_uuid).await?; - Member::new(&mut conn, &app_state.cache_pool, uuid, guild.uuid).await?; + Member::new(&app_state, uuid, guild.uuid).await?; Ok((StatusCode::OK, Json(guild))) } diff --git a/src/api/v1/me/friends/mod.rs b/src/api/v1/me/friends/mod.rs index 904a1f5..a56f8d4 100644 --- a/src/api/v1/me/friends/mod.rs +++ b/src/api/v1/me/friends/mod.rs @@ -19,13 +19,11 @@ pub async fn get( State(app_state): State>, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; - let me = Me::get(&mut conn, uuid).await?; - - let friends = me.get_friends(&mut conn, &app_state.cache_pool).await?; + let friends = me.get_friends(&app_state).await?; Ok((StatusCode::OK, Json(friends))) } @@ -59,9 +57,9 @@ pub async fn post( Extension(CurrentUser(uuid)): Extension>, Json(user_request): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let me = Me::get(&mut conn, uuid).await?; diff --git a/src/api/v1/me/friends/uuid.rs b/src/api/v1/me/friends/uuid.rs index 35f0742..5367435 100644 --- a/src/api/v1/me/friends/uuid.rs +++ b/src/api/v1/me/friends/uuid.rs @@ -17,9 +17,9 @@ pub async fn delete( Path(friend_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let me = Me::get(&mut conn, uuid).await?; diff --git a/src/api/v1/me/guilds.rs b/src/api/v1/me/guilds.rs index 42d5c21..88dfad9 100644 --- a/src/api/v1/me/guilds.rs +++ b/src/api/v1/me/guilds.rs @@ -58,9 +58,9 @@ pub async fn get( State(app_state): State>, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let me = Me::get(&mut conn, uuid).await?; diff --git a/src/api/v1/me/mod.rs b/src/api/v1/me/mod.rs index 86d3d9e..e167d14 100644 --- a/src/api/v1/me/mod.rs +++ b/src/api/v1/me/mod.rs @@ -73,41 +73,36 @@ pub async fn update( let json = json_raw.unwrap_or_default(); - let mut conn = app_state.pool.get().await?; - if avatar.is_some() || json.username.is_some() || json.display_name.is_some() { - global_checks(&mut conn, &app_state.config, uuid).await?; + global_checks(&app_state, uuid).await?; } - let mut me = Me::get(&mut conn, uuid).await?; + let mut me = Me::get(&mut app_state.pool.get().await?, uuid).await?; if let Some(avatar) = avatar { - me.set_avatar(&mut conn, &app_state, avatar).await?; - } - - if let Some(username) = &json.username { - me.set_username(&mut conn, &app_state.cache_pool, username.clone()) + me.set_avatar(&app_state, app_state.config.bunny.cdn_url.clone(), avatar) .await?; } + if let Some(username) = &json.username { + me.set_username(&app_state, username.clone()).await?; + } + if let Some(display_name) = &json.display_name { - me.set_display_name(&mut conn, &app_state.cache_pool, display_name.clone()) + me.set_display_name(&app_state, display_name.clone()) .await?; } if let Some(email) = &json.email { - me.set_email(&mut conn, &app_state.cache_pool, email.clone()) - .await?; + me.set_email(&app_state, email.clone()).await?; } if let Some(pronouns) = &json.pronouns { - me.set_pronouns(&mut conn, &app_state.cache_pool, pronouns.clone()) - .await?; + me.set_pronouns(&app_state, pronouns.clone()).await?; } if let Some(about) = &json.about { - me.set_about(&mut conn, &app_state.cache_pool, about.clone()) - .await?; + me.set_about(&app_state, about.clone()).await?; } Ok(StatusCode::OK) diff --git a/src/api/v1/members/uuid/ban.rs b/src/api/v1/members/uuid/ban.rs index b959efa..dfe53f6 100644 --- a/src/api/v1/members/uuid/ban.rs +++ b/src/api/v1/members/uuid/ban.rs @@ -29,16 +29,16 @@ pub async fn post( Extension(CurrentUser(uuid)): Extension>, Json(payload): Json, ) -> Result { + global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - global_checks(&mut conn, &app_state.config, uuid).await?; - - let member = Member::fetch_one_with_member(&mut conn, &app_state.cache_pool, None, member_uuid).await?; + let member = Member::fetch_one_with_member(&app_state, None, member_uuid).await?; let caller = Member::check_membership(&mut conn, uuid, member.guild_uuid).await?; caller - .check_permission(&mut conn, &app_state.cache_pool, Permissions::BanMember) + .check_permission(&app_state, Permissions::BanMember) .await?; member.ban(&mut conn, &payload.reason).await?; diff --git a/src/api/v1/members/uuid/mod.rs b/src/api/v1/members/uuid/mod.rs index 2bdd1ba..42a4418 100644 --- a/src/api/v1/members/uuid/mod.rs +++ b/src/api/v1/members/uuid/mod.rs @@ -25,13 +25,13 @@ pub async fn get( Path(member_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let me = Me::get(&mut conn, uuid).await?; - let member = Member::fetch_one_with_member(&mut conn, &app_state.cache_pool, Some(&me), member_uuid).await?; + let member = Member::fetch_one_with_member(&app_state, Some(&me), member_uuid).await?; Member::check_membership(&mut conn, uuid, member.guild_uuid).await?; Ok((StatusCode::OK, Json(member))) @@ -42,18 +42,18 @@ pub async fn delete( Path(member_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let mut conn = app_state.pool.get().await?; let me = Me::get(&mut conn, uuid).await?; - let member = Member::fetch_one_with_member(&mut conn, &app_state.cache_pool, Some(&me), member_uuid).await?; + let member = Member::fetch_one_with_member(&app_state, Some(&me), member_uuid).await?; let deleter = Member::check_membership(&mut conn, uuid, member.guild_uuid).await?; deleter - .check_permission(&mut conn, &app_state.cache_pool, Permissions::KickMember) + .check_permission(&app_state, Permissions::KickMember) .await?; member.delete(&mut conn).await?; diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index 999e13f..a4b93ce 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -70,11 +70,9 @@ pub async fn users( return Ok(StatusCode::BAD_REQUEST.into_response()); } - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; - - let users = User::fetch_amount(&mut conn, start, amount).await?; + let users = User::fetch_amount(&mut app_state.pool.get().await?, start, amount).await?; Ok((StatusCode::OK, Json(users)).into_response()) } diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index e015c3c..cee6df0 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -39,14 +39,11 @@ pub async fn get( Path(user_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - global_checks(&mut conn, &app_state.config, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; - let me = Me::get(&mut conn, uuid).await?; - - let user = - User::fetch_one_with_friendship(&mut conn, &app_state.cache_pool, &me, user_uuid).await?; + let user = User::fetch_one_with_friendship(&app_state, &me, user_uuid).await?; Ok((StatusCode::OK, Json(user))) } diff --git a/src/objects/channel.rs b/src/objects/channel.rs index 03a2cf6..cacb153 100644 --- a/src/objects/channel.rs +++ b/src/objects/channel.rs @@ -2,15 +2,15 @@ use diesel::{ ExpressionMethods, Insertable, QueryDsl, Queryable, Selectable, SelectableHelper, delete, insert_into, update, }; -use diesel_async::RunQueryDsl; +use diesel_async::{RunQueryDsl, pooled_connection::AsyncDieselConnectionManager}; use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - Conn, + AppState, Conn, error::Error, schema::{channel_permissions, channels, messages}, - utils::{CHANNEL_REGEX, CacheFns, order_by_is_above}, + utils::{CHANNEL_REGEX, order_by_is_above}, }; use super::{HasIsAbove, HasUuid, Message, load_or_empty, message::MessageBuilder}; @@ -79,44 +79,49 @@ impl HasIsAbove for Channel { } impl Channel { - pub async fn fetch_all(conn: &mut Conn, guild_uuid: Uuid) -> Result, Error> { + pub async fn fetch_all( + pool: &deadpool::managed::Pool< + AsyncDieselConnectionManager, + Conn, + >, + guild_uuid: Uuid, + ) -> Result, Error> { + let mut conn = pool.get().await?; + use channels::dsl; let channel_builders: Vec = load_or_empty( dsl::channels .filter(dsl::guild_uuid.eq(guild_uuid)) .select(ChannelBuilder::as_select()) - .load(conn) + .load(&mut conn) .await, )?; - let mut channels = vec![]; + let channel_futures = channel_builders.iter().map(async move |c| { + let mut conn = pool.get().await?; + c.clone().build(&mut conn).await + }); - for builder in channel_builders { - channels.push(builder.build(conn).await?); - } - - Ok(channels) + futures_util::future::try_join_all(channel_futures).await } - pub async fn fetch_one( - conn: &mut Conn, - cache_pool: &redis::Client, - channel_uuid: Uuid, - ) -> Result { - if let Ok(cache_hit) = cache_pool.get_cache_key(channel_uuid.to_string()).await { - return Ok(cache_hit); + pub async fn fetch_one(app_state: &AppState, channel_uuid: Uuid) -> Result { + if let Ok(cache_hit) = app_state.get_cache_key(channel_uuid.to_string()).await { + return Ok(serde_json::from_str(&cache_hit)?); } + let mut conn = app_state.pool.get().await?; + use channels::dsl; let channel_builder: ChannelBuilder = dsl::channels .filter(dsl::uuid.eq(channel_uuid)) .select(ChannelBuilder::as_select()) - .get_result(conn) + .get_result(&mut conn) .await?; - let channel = channel_builder.build(conn).await?; + let channel = channel_builder.build(&mut conn).await?; - cache_pool + app_state .set_cache_key(channel_uuid.to_string(), channel.clone(), 60) .await?; @@ -124,8 +129,7 @@ impl Channel { } pub async fn new( - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, guild_uuid: Uuid, name: String, description: Option, @@ -134,9 +138,11 @@ impl Channel { return Err(Error::BadRequest("Channel name is invalid".to_string())); } + let mut conn = app_state.pool.get().await?; + let channel_uuid = Uuid::now_v7(); - let channels = Self::fetch_all(conn, guild_uuid).await?; + let channels = Self::fetch_all(&app_state.pool, guild_uuid).await?; let channels_ordered = order_by_is_above(channels).await?; @@ -152,7 +158,7 @@ impl Channel { insert_into(channels::table) .values(new_channel.clone()) - .execute(conn) + .execute(&mut conn) .await?; if let Some(old_last_channel) = last_channel { @@ -160,7 +166,7 @@ impl Channel { update(channels::table) .filter(dsl::uuid.eq(old_last_channel.uuid)) .set(dsl::is_above.eq(new_channel.uuid)) - .execute(conn) + .execute(&mut conn) .await?; } @@ -174,16 +180,16 @@ impl Channel { permissions: vec![], }; - cache_pool + app_state .set_cache_key(channel_uuid.to_string(), channel.clone(), 1800) .await?; - if cache_pool - .get_cache_key::>(format!("{guild_uuid}_channels")) + if app_state + .get_cache_key(format!("{guild_uuid}_channels")) .await .is_ok() { - cache_pool + app_state .del_cache_key(format!("{guild_uuid}_channels")) .await?; } @@ -191,12 +197,14 @@ impl Channel { Ok(channel) } - pub async fn delete(self, conn: &mut Conn, cache_pool: &redis::Client) -> Result<(), Error> { + pub async fn delete(self, app_state: &AppState) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; + use channels::dsl; match update(channels::table) .filter(dsl::is_above.eq(self.uuid)) .set(dsl::is_above.eq(None::)) - .execute(conn) + .execute(&mut conn) .await { Ok(r) => Ok(r), @@ -206,13 +214,13 @@ impl Channel { delete(channels::table) .filter(dsl::uuid.eq(self.uuid)) - .execute(conn) + .execute(&mut conn) .await?; match update(channels::table) .filter(dsl::is_above.eq(self.uuid)) .set(dsl::is_above.eq(self.is_above)) - .execute(conn) + .execute(&mut conn) .await { Ok(r) => Ok(r), @@ -220,20 +228,16 @@ impl Channel { Err(e) => Err(e), }?; - if cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - cache_pool.del_cache_key(self.uuid.to_string()).await?; + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await?; } - if cache_pool - .get_cache_key::>(format!("{}_channels", self.guild_uuid)) + if app_state + .get_cache_key(format!("{}_channels", self.guild_uuid)) .await .is_ok() { - cache_pool + app_state .del_cache_key(format!("{}_channels", self.guild_uuid)) .await?; } @@ -243,36 +247,32 @@ impl Channel { pub async fn fetch_messages( &self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, amount: i64, offset: i64, ) -> Result, Error> { + let mut conn = app_state.pool.get().await?; + use messages::dsl; - let message_builders: Vec = load_or_empty( + let messages: Vec = load_or_empty( dsl::messages .filter(dsl::channel_uuid.eq(self.uuid)) .select(MessageBuilder::as_select()) .order(dsl::uuid.desc()) .limit(amount) .offset(offset) - .load(conn) + .load(&mut conn) .await, )?; - let mut messages = vec![]; + let message_futures = messages.iter().map(async move |b| b.build(app_state).await); - for builder in message_builders { - messages.push(builder.build(conn, cache_pool).await?); - } - - Ok(messages) + futures_util::future::try_join_all(message_futures).await } pub async fn new_message( &self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, user_uuid: Uuid, message: String, reply_to: Option, @@ -287,101 +287,66 @@ impl Channel { reply_to, }; + let mut conn = app_state.pool.get().await?; + insert_into(messages::table) .values(message.clone()) - .execute(conn) + .execute(&mut conn) .await?; - message.build(conn, cache_pool).await + message.build(app_state).await } - pub async fn set_name( - &mut self, - conn: &mut Conn, - cache_pool: &redis::Client, - new_name: String, - ) -> Result<(), Error> { + pub async fn set_name(&mut self, app_state: &AppState, new_name: String) -> Result<(), Error> { if !CHANNEL_REGEX.is_match(&new_name) { return Err(Error::BadRequest("Channel name is invalid".to_string())); } + let mut conn = app_state.pool.get().await?; + use channels::dsl; update(channels::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::name.eq(&new_name)) - .execute(conn) + .execute(&mut conn) .await?; self.name = new_name; - if cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - cache_pool.del_cache_key(self.uuid.to_string()).await?; - } - - if cache_pool - .get_cache_key::>(format!("{}_channels", self.guild_uuid)) - .await - .is_ok() - { - cache_pool - .del_cache_key(format!("{}_channels", self.guild_uuid)) - .await?; - } - Ok(()) } pub async fn set_description( &mut self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, new_description: String, ) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; + use channels::dsl; update(channels::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::description.eq(&new_description)) - .execute(conn) + .execute(&mut conn) .await?; self.description = Some(new_description); - if cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - cache_pool.del_cache_key(self.uuid.to_string()).await?; - } - - if cache_pool - .get_cache_key::>(format!("{}_channels", self.guild_uuid)) - .await - .is_ok() - { - cache_pool - .del_cache_key(format!("{}_channels", self.guild_uuid)) - .await?; - } - Ok(()) } pub async fn move_channel( &mut self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, new_is_above: Uuid, ) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; + use channels::dsl; let old_above_uuid: Option = match dsl::channels .filter(dsl::is_above.eq(self.uuid)) .select(dsl::uuid) - .get_result(conn) + .get_result(&mut conn) .await { Ok(r) => Ok(Some(r)), @@ -393,14 +358,14 @@ impl Channel { update(channels::table) .filter(dsl::uuid.eq(uuid)) .set(dsl::is_above.eq(None::)) - .execute(conn) + .execute(&mut conn) .await?; } match update(channels::table) .filter(dsl::is_above.eq(new_is_above)) .set(dsl::is_above.eq(self.uuid)) - .execute(conn) + .execute(&mut conn) .await { Ok(r) => Ok(r), @@ -411,37 +376,19 @@ impl Channel { update(channels::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::is_above.eq(new_is_above)) - .execute(conn) + .execute(&mut conn) .await?; if let Some(uuid) = old_above_uuid { update(channels::table) .filter(dsl::uuid.eq(uuid)) .set(dsl::is_above.eq(self.is_above)) - .execute(conn) + .execute(&mut conn) .await?; } self.is_above = Some(new_is_above); - if cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - cache_pool.del_cache_key(self.uuid.to_string()).await?; - } - - if cache_pool - .get_cache_key::>(format!("{}_channels", self.guild_uuid)) - .await - .is_ok() - { - cache_pool - .del_cache_key(format!("{}_channels", self.guild_uuid)) - .await?; - } - Ok(()) } } diff --git a/src/objects/email_token.rs b/src/objects/email_token.rs index c826620..64d2fdb 100644 --- a/src/objects/email_token.rs +++ b/src/objects/email_token.rs @@ -3,11 +3,7 @@ use lettre::message::MultiPart; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{ - AppState, - error::Error, - utils::{CacheFns, generate_token}, -}; +use crate::{AppState, error::Error, utils::generate_token}; use super::Me; @@ -19,10 +15,12 @@ pub struct EmailToken { } impl EmailToken { - pub async fn get(cache_pool: &redis::Client, user_uuid: Uuid) -> Result { - let email_token = cache_pool - .get_cache_key(format!("{user_uuid}_email_verify")) - .await?; + pub async fn get(app_state: &AppState, user_uuid: Uuid) -> Result { + let email_token = serde_json::from_str( + &app_state + .get_cache_key(format!("{user_uuid}_email_verify")) + .await?, + )?; Ok(email_token) } @@ -39,7 +37,6 @@ impl EmailToken { }; app_state - .cache_pool .set_cache_key(format!("{}_email_verify", me.uuid), email_token, 86400) .await?; @@ -62,8 +59,8 @@ impl EmailToken { Ok(()) } - pub async fn delete(&self, cache_pool: &redis::Client) -> Result<(), Error> { - cache_pool + pub async fn delete(&self, app_state: &AppState) -> Result<(), Error> { + app_state .del_cache_key(format!("{}_email_verify", self.user_uuid)) .await?; diff --git a/src/objects/guild.rs b/src/objects/guild.rs index 9640e28..9514e49 100644 --- a/src/objects/guild.rs +++ b/src/objects/guild.rs @@ -3,14 +3,14 @@ use diesel::{ ExpressionMethods, Insertable, QueryDsl, Queryable, Selectable, SelectableHelper, insert_into, update, }; -use diesel_async::RunQueryDsl; +use diesel_async::{RunQueryDsl, pooled_connection::AsyncDieselConnectionManager}; use serde::Serialize; use tokio::task; use url::Url; use uuid::Uuid; use crate::{ - AppState, Conn, + Conn, error::Error, schema::{guild_members, guilds, invites}, utils::image_check, @@ -68,11 +68,16 @@ impl Guild { } pub async fn fetch_amount( - conn: &mut Conn, + pool: &deadpool::managed::Pool< + AsyncDieselConnectionManager, + Conn, + >, offset: i64, amount: i64, ) -> Result, Error> { // Fetch guild data from database + let mut conn = pool.get().await?; + use guilds::dsl; let guild_builders: Vec = load_or_empty( dsl::guilds @@ -80,17 +85,18 @@ impl Guild { .order_by(dsl::uuid) .offset(offset) .limit(amount) - .load(conn) + .load(&mut conn) .await, )?; - let mut guilds = vec![]; + // Process each guild concurrently + let guild_futures = guild_builders.iter().map(async move |g| { + let mut conn = pool.get().await?; + g.clone().build(&mut conn).await + }); - for builder in guild_builders { - guilds.push(builder.build(conn).await?); - } - - Ok(guilds) + // Execute all futures concurrently and collect results + futures_util::future::try_join_all(guild_futures).await } pub async fn new(conn: &mut Conn, name: String, owner_uuid: Uuid) -> Result { @@ -182,8 +188,9 @@ impl Guild { // FIXME: Horrible security pub async fn set_icon( &mut self, + bunny_storage: &bunny_api_tokio::EdgeStorageClient, conn: &mut Conn, - app_state: &AppState, + cdn_url: Url, icon: Bytes, ) -> Result<(), Error> { let icon_clone = icon.clone(); @@ -192,14 +199,14 @@ impl Guild { if let Some(icon) = &self.icon { let relative_url = icon.path().trim_start_matches('/'); - app_state.bunny_storage.delete(relative_url).await?; + bunny_storage.delete(relative_url).await?; } let path = format!("icons/{}/{}.{}", self.uuid, Uuid::now_v7(), image_type); - app_state.bunny_storage.upload(path.clone(), icon).await?; + bunny_storage.upload(path.clone(), icon).await?; - let icon_url = app_state.config.bunny.cdn_url.join(&path)?; + let icon_url = cdn_url.join(&path)?; use guilds::dsl; update(guilds::table) diff --git a/src/objects/me.rs b/src/objects/me.rs index d03e08b..a0b399d 100644 --- a/src/objects/me.rs +++ b/src/objects/me.rs @@ -14,7 +14,7 @@ use crate::{ error::Error, objects::{Friend, FriendRequest, User}, schema::{friend_requests, friends, guild_members, guilds, users}, - utils::{CacheFns, EMAIL_REGEX, USERNAME_REGEX, image_check}, + utils::{EMAIL_REGEX, USERNAME_REGEX, image_check}, }; use super::{Guild, guild::GuildBuilder, load_or_empty, member::MemberBuilder}; @@ -75,13 +75,15 @@ impl Me { pub async fn set_avatar( &mut self, - conn: &mut Conn, app_state: &AppState, + cdn_url: Url, avatar: Bytes, ) -> Result<(), Error> { let avatar_clone = avatar.clone(); let image_type = task::spawn_blocking(move || image_check(avatar_clone)).await??; + let mut conn = app_state.pool.get().await?; + if let Some(avatar) = &self.avatar { let avatar_url: Url = avatar.parse()?; @@ -94,25 +96,17 @@ impl Me { app_state.bunny_storage.upload(path.clone(), avatar).await?; - let avatar_url = app_state.config.bunny.cdn_url.join(&path)?; + let avatar_url = cdn_url.join(&path)?; use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::avatar.eq(avatar_url.as_str())) - .execute(conn) + .execute(&mut conn) .await?; - if app_state - .cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - app_state - .cache_pool - .del_cache_key(self.uuid.to_string()) - .await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } self.avatar = Some(avatar_url.to_string()); @@ -133,8 +127,7 @@ impl Me { pub async fn set_username( &mut self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, new_username: String, ) -> Result<(), Error> { if !USERNAME_REGEX.is_match(&new_username) @@ -144,19 +137,17 @@ impl Me { return Err(Error::BadRequest("Invalid username".to_string())); } + let mut conn = app_state.pool.get().await?; + use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::username.eq(new_username.as_str())) - .execute(conn) + .execute(&mut conn) .await?; - if cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - cache_pool.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } self.username = new_username; @@ -166,10 +157,11 @@ impl Me { pub async fn set_display_name( &mut self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, new_display_name: String, ) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; + let new_display_name_option = if new_display_name.is_empty() { None } else { @@ -180,15 +172,11 @@ impl Me { update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::display_name.eq(&new_display_name_option)) - .execute(conn) + .execute(&mut conn) .await?; - if cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - cache_pool.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } self.display_name = new_display_name_option; @@ -198,14 +186,15 @@ impl Me { pub async fn set_email( &mut self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, new_email: String, ) -> Result<(), Error> { if !EMAIL_REGEX.is_match(&new_email) { return Err(Error::BadRequest("Invalid username".to_string())); } + let mut conn = app_state.pool.get().await?; + use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) @@ -213,15 +202,11 @@ impl Me { dsl::email.eq(new_email.as_str()), dsl::email_verified.eq(false), )) - .execute(conn) + .execute(&mut conn) .await?; - if cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - cache_pool.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } self.email = new_email; @@ -231,23 +216,20 @@ impl Me { pub async fn set_pronouns( &mut self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, new_pronouns: String, ) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; + use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set((dsl::pronouns.eq(new_pronouns.as_str()),)) - .execute(conn) + .execute(&mut conn) .await?; - if cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - cache_pool.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } Ok(()) @@ -255,23 +237,20 @@ impl Me { pub async fn set_about( &mut self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, new_about: String, ) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; + use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set((dsl::about.eq(new_about.as_str()),)) - .execute(conn) + .execute(&mut conn) .await?; - if cache_pool - .get_cache_key::(self.uuid.to_string()) - .await - .is_ok() - { - cache_pool.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } Ok(()) @@ -387,18 +366,16 @@ impl Me { Ok(()) } - pub async fn get_friends( - &self, - conn: &mut Conn, - cache_pool: &redis::Client, - ) -> Result, Error> { + pub async fn get_friends(&self, app_state: &AppState) -> Result, Error> { use friends::dsl; + let mut conn = app_state.pool.get().await?; + let friends1 = load_or_empty( dsl::friends .filter(dsl::uuid1.eq(self.uuid)) .select(Friend::as_select()) - .load(conn) + .load(&mut conn) .await, )?; @@ -406,21 +383,21 @@ impl Me { dsl::friends .filter(dsl::uuid2.eq(self.uuid)) .select(Friend::as_select()) - .load(conn) + .load(&mut conn) .await, )?; - let mut friends = vec![]; + let friend_futures = friends1.iter().map(async move |friend| { + User::fetch_one_with_friendship(app_state, self, friend.uuid2).await + }); - for friend in friends1 { - friends - .push(User::fetch_one_with_friendship(conn, cache_pool, self, friend.uuid2).await?); - } + let mut friends = futures_util::future::try_join_all(friend_futures).await?; - for friend in friends2 { - friends - .push(User::fetch_one_with_friendship(conn, cache_pool, self, friend.uuid1).await?); - } + let friend_futures = friends2.iter().map(async move |friend| { + User::fetch_one_with_friendship(app_state, self, friend.uuid1).await + }); + + friends.append(&mut futures_util::future::try_join_all(friend_futures).await?); Ok(friends) } diff --git a/src/objects/member.rs b/src/objects/member.rs index 3eb8c4d..dc35e08 100644 --- a/src/objects/member.rs +++ b/src/objects/member.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - Conn, + AppState, Conn, error::Error, objects::{GuildBan, Me, Permissions, Role}, schema::{guild_bans, guild_members}, @@ -27,18 +27,13 @@ pub struct MemberBuilder { } impl MemberBuilder { - pub async fn build( - &self, - conn: &mut Conn, - cache_pool: &redis::Client, - me: Option<&Me>, - ) -> Result { + pub async fn build(&self, app_state: &AppState, me: Option<&Me>) -> Result { let user; if let Some(me) = me { - user = User::fetch_one_with_friendship(conn, cache_pool, me, self.user_uuid).await?; + user = User::fetch_one_with_friendship(app_state, me, self.user_uuid).await?; } else { - user = User::fetch_one(conn, cache_pool, self.user_uuid).await?; + user = User::fetch_one(app_state, self.user_uuid).await?; } Ok(Member { @@ -53,12 +48,11 @@ impl MemberBuilder { pub async fn check_permission( &self, - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, permission: Permissions, ) -> Result<(), Error> { if !self.is_owner { - let roles = Role::fetch_from_member(conn, cache_pool, self.uuid).await?; + let roles = Role::fetch_from_member(app_state, self.uuid).await?; let allowed = roles.iter().any(|r| r.permissions & permission as i64 != 0); if !allowed { return Err(Error::Forbidden("Not allowed".to_string())); @@ -108,71 +102,74 @@ impl Member { } pub async fn fetch_one( - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, me: &Me, user_uuid: Uuid, guild_uuid: Uuid, ) -> Result { + let mut conn = app_state.pool.get().await?; + use guild_members::dsl; let member: MemberBuilder = dsl::guild_members .filter(dsl::user_uuid.eq(user_uuid)) .filter(dsl::guild_uuid.eq(guild_uuid)) .select(MemberBuilder::as_select()) - .get_result(conn) + .get_result(&mut conn) .await?; - member.build(conn, cache_pool, Some(me)).await + member.build(app_state, Some(me)).await } pub async fn fetch_one_with_member( - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, me: Option<&Me>, uuid: Uuid, ) -> Result { + let mut conn = app_state.pool.get().await?; + use guild_members::dsl; let member: MemberBuilder = dsl::guild_members .filter(dsl::uuid.eq(uuid)) .select(MemberBuilder::as_select()) - .get_result(conn) + .get_result(&mut conn) .await?; - member.build(conn, cache_pool, me).await + member.build(app_state, me).await } pub async fn fetch_all( - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, me: &Me, guild_uuid: Uuid, ) -> Result, Error> { + let mut conn = app_state.pool.get().await?; + use guild_members::dsl; let member_builders: Vec = load_or_empty( dsl::guild_members .filter(dsl::guild_uuid.eq(guild_uuid)) .select(MemberBuilder::as_select()) - .load(conn) + .load(&mut conn) .await, )?; let mut members = vec![]; for builder in member_builders { - members.push(builder.build(conn, cache_pool, Some(me)).await?); + members.push(builder.build(app_state, Some(me)).await?); } Ok(members) } pub async fn new( - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, user_uuid: Uuid, guild_uuid: Uuid, ) -> Result { - let banned = GuildBan::fetch_one(conn, guild_uuid, user_uuid).await; + let mut conn = app_state.pool.get().await?; + let banned = GuildBan::fetch_one(&mut conn, guild_uuid, user_uuid).await; match banned { Ok(_) => Err(Error::Forbidden("User banned".to_string())), Err(Error::SqlError(diesel::result::Error::NotFound)) => Ok(()), @@ -191,10 +188,10 @@ impl Member { insert_into(guild_members::table) .values(&member) - .execute(conn) + .execute(&mut conn) .await?; - member.build(conn, cache_pool, None).await + member.build(app_state, None).await } pub async fn delete(self, conn: &mut Conn) -> Result<(), Error> { diff --git a/src/objects/message.rs b/src/objects/message.rs index a5224e0..caff969 100644 --- a/src/objects/message.rs +++ b/src/objects/message.rs @@ -2,7 +2,7 @@ use diesel::{Insertable, Queryable, Selectable}; use serde::Serialize; use uuid::Uuid; -use crate::{Conn, error::Error, schema::messages}; +use crate::{AppState, error::Error, schema::messages}; use super::User; @@ -18,12 +18,8 @@ pub struct MessageBuilder { } impl MessageBuilder { - pub async fn build( - &self, - conn: &mut Conn, - cache_pool: &redis::Client, - ) -> Result { - let user = User::fetch_one(conn, cache_pool, self.user_uuid).await?; + pub async fn build(&self, app_state: &AppState) -> Result { + let user = User::fetch_one(app_state, self.user_uuid).await?; Ok(Message { uuid: self.uuid, diff --git a/src/objects/password_reset_token.rs b/src/objects/password_reset_token.rs index ca5c62f..04ff43c 100644 --- a/src/objects/password_reset_token.rs +++ b/src/objects/password_reset_token.rs @@ -10,10 +10,10 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - AppState, Conn, + AppState, error::Error, schema::users, - utils::{CacheFns, PASSWORD_REGEX, generate_token, global_checks, user_uuid_from_identifier}, + utils::{PASSWORD_REGEX, generate_token, global_checks, user_uuid_from_identifier}, }; #[derive(Serialize, Deserialize)] @@ -24,49 +24,50 @@ pub struct PasswordResetToken { } impl PasswordResetToken { - pub async fn get( - cache_pool: &redis::Client, - token: String, - ) -> Result { - let user_uuid: Uuid = cache_pool.get_cache_key(token.to_string()).await?; - let password_reset_token = cache_pool - .get_cache_key(format!("{user_uuid}_password_reset")) - .await?; + pub async fn get(app_state: &AppState, token: String) -> Result { + let user_uuid: Uuid = + serde_json::from_str(&app_state.get_cache_key(token.to_string()).await?)?; + let password_reset_token = serde_json::from_str( + &app_state + .get_cache_key(format!("{user_uuid}_password_reset")) + .await?, + )?; Ok(password_reset_token) } pub async fn get_with_identifier( - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, identifier: String, ) -> Result { - let user_uuid = user_uuid_from_identifier(conn, &identifier).await?; + let mut conn = app_state.pool.get().await?; - let password_reset_token = cache_pool - .get_cache_key(format!("{user_uuid}_password_reset")) - .await?; + let user_uuid = user_uuid_from_identifier(&mut conn, &identifier).await?; + + let password_reset_token = serde_json::from_str( + &app_state + .get_cache_key(format!("{user_uuid}_password_reset")) + .await?, + )?; Ok(password_reset_token) } #[allow(clippy::new_ret_no_self)] - pub async fn new( - conn: &mut Conn, - app_state: &AppState, - identifier: String, - ) -> Result<(), Error> { + pub async fn new(app_state: &AppState, identifier: String) -> Result<(), Error> { let token = generate_token::<32>()?; - let user_uuid = user_uuid_from_identifier(conn, &identifier).await?; + let mut conn = app_state.pool.get().await?; - global_checks(conn, &app_state.config, user_uuid).await?; + let user_uuid = user_uuid_from_identifier(&mut conn, &identifier).await?; + + global_checks(app_state, user_uuid).await?; use users::dsl as udsl; let (username, email_address): (String, String) = udsl::users .filter(udsl::uuid.eq(user_uuid)) .select((udsl::username, udsl::email)) - .get_result(conn) + .get_result(&mut conn) .await?; let password_reset_token = PasswordResetToken { @@ -76,7 +77,6 @@ impl PasswordResetToken { }; app_state - .cache_pool .set_cache_key( format!("{user_uuid}_password_reset"), password_reset_token, @@ -84,7 +84,6 @@ impl PasswordResetToken { ) .await?; app_state - .cache_pool .set_cache_key(token.clone(), user_uuid, 86400) .await?; @@ -107,12 +106,7 @@ impl PasswordResetToken { Ok(()) } - pub async fn set_password( - &self, - conn: &mut Conn, - app_state: &AppState, - password: String, - ) -> Result<(), Error> { + pub async fn set_password(&self, app_state: &AppState, password: String) -> Result<(), Error> { if !PASSWORD_REGEX.is_match(&password) { return Err(Error::BadRequest( "Please provide a valid password".to_string(), @@ -126,17 +120,19 @@ impl PasswordResetToken { .hash_password(password.as_bytes(), &salt) .map_err(|e| Error::PasswordHashError(e.to_string()))?; + let mut conn = app_state.pool.get().await?; + use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.user_uuid)) .set(dsl::password.eq(hashed_password.to_string())) - .execute(conn) + .execute(&mut conn) .await?; let (username, email_address): (String, String) = dsl::users .filter(dsl::uuid.eq(self.user_uuid)) .select((dsl::username, dsl::email)) - .get_result(conn) + .get_result(&mut conn) .await?; let login_page = app_state.config.web.frontend_url.join("login")?; @@ -153,14 +149,14 @@ impl PasswordResetToken { app_state.mail_client.send_mail(email).await?; - self.delete(&app_state.cache_pool).await + self.delete(app_state).await } - pub async fn delete(&self, cache_pool: &redis::Client) -> Result<(), Error> { - cache_pool + pub async fn delete(&self, app_state: &AppState) -> Result<(), Error> { + app_state .del_cache_key(format!("{}_password_reset", &self.user_uuid)) .await?; - cache_pool.del_cache_key(self.token.to_string()).await?; + app_state.del_cache_key(self.token.to_string()).await?; Ok(()) } diff --git a/src/objects/role.rs b/src/objects/role.rs index 46f54f6..4a4009b 100644 --- a/src/objects/role.rs +++ b/src/objects/role.rs @@ -7,10 +7,10 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - Conn, + AppState, Conn, error::Error, schema::{role_members, roles}, - utils::{CacheFns, order_by_is_above}, + utils::order_by_is_above, }; use super::{HasIsAbove, HasUuid, load_or_empty}; @@ -75,33 +75,34 @@ impl Role { } pub async fn fetch_from_member( - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, member_uuid: Uuid, ) -> Result, Error> { - if let Ok(roles) = cache_pool + if let Ok(roles) = app_state .get_cache_key(format!("{member_uuid}_roles")) .await { - return Ok(roles); + return Ok(serde_json::from_str(&roles)?); } + let mut conn = app_state.pool.get().await?; + use role_members::dsl; let role_memberships: Vec = load_or_empty( dsl::role_members .filter(dsl::member_uuid.eq(member_uuid)) .select(RoleMember::as_select()) - .load(conn) + .load(&mut conn) .await, )?; let mut roles = vec![]; for membership in role_memberships { - roles.push(membership.fetch_role(conn).await?); + roles.push(membership.fetch_role(&mut conn).await?); } - cache_pool + app_state .set_cache_key(format!("{member_uuid}_roles"), roles.clone(), 300) .await?; diff --git a/src/objects/user.rs b/src/objects/user.rs index a686c39..c1f164d 100644 --- a/src/objects/user.rs +++ b/src/objects/user.rs @@ -4,7 +4,7 @@ use diesel_async::RunQueryDsl; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{Conn, error::Error, objects::Me, schema::users, utils::CacheFns}; +use crate::{AppState, Conn, error::Error, objects::Me, schema::users}; use super::load_or_empty; @@ -46,25 +46,23 @@ pub struct User { } impl User { - pub async fn fetch_one( - conn: &mut Conn, - cache_pool: &redis::Client, - user_uuid: Uuid, - ) -> Result { - if let Ok(cache_hit) = cache_pool.get_cache_key(user_uuid.to_string()).await { - return Ok(cache_hit); + pub async fn fetch_one(app_state: &AppState, user_uuid: Uuid) -> Result { + let mut conn = app_state.pool.get().await?; + + if let Ok(cache_hit) = app_state.get_cache_key(user_uuid.to_string()).await { + return Ok(serde_json::from_str(&cache_hit)?); } use users::dsl; let user_builder: UserBuilder = dsl::users .filter(dsl::uuid.eq(user_uuid)) .select(UserBuilder::as_select()) - .get_result(conn) + .get_result(&mut conn) .await?; let user = user_builder.build(); - cache_pool + app_state .set_cache_key(user_uuid.to_string(), user.clone(), 1800) .await?; @@ -72,14 +70,15 @@ impl User { } pub async fn fetch_one_with_friendship( - conn: &mut Conn, - cache_pool: &redis::Client, + app_state: &AppState, me: &Me, user_uuid: Uuid, ) -> Result { - let mut user = Self::fetch_one(conn, cache_pool, user_uuid).await?; + let mut conn = app_state.pool.get().await?; - if let Some(friend) = me.friends_with(conn, user_uuid).await? { + let mut user = Self::fetch_one(app_state, user_uuid).await?; + + if let Some(friend) = me.friends_with(&mut conn, user_uuid).await? { user.friends_since = Some(friend.accepted_at); } diff --git a/src/utils.rs b/src/utils.rs index 7ef880e..e1df906 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -8,13 +8,14 @@ use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; use getrandom::fill; use hex::encode; +use redis::RedisError; use regex::Regex; -use serde::{Serialize, de::DeserializeOwned}; +use serde::Serialize; use time::Duration; use uuid::Uuid; use crate::{ - Conn, + AppState, Conn, config::Config, error::Error, objects::{HasIsAbove, HasUuid}, @@ -114,13 +115,15 @@ pub async fn user_uuid_from_username(conn: &mut Conn, username: &String) -> Resu } } -pub async fn global_checks(conn: &mut Conn, config: &Config, user_uuid: Uuid) -> Result<(), Error> { - if config.instance.require_email_verification { +pub async fn global_checks(app_state: &AppState, user_uuid: Uuid) -> Result<(), Error> { + if app_state.config.instance.require_email_verification { + let mut conn = app_state.pool.get().await?; + use users::dsl; let email_verified: bool = dsl::users .filter(dsl::uuid.eq(user_uuid)) .select(dsl::email_verified) - .get_result(conn) + .get_result(&mut conn) .await?; if !email_verified { @@ -158,28 +161,14 @@ where Ok(ordered) } -#[allow(async_fn_in_trait)] -pub trait CacheFns { - async fn set_cache_key( - &self, - key: String, - value: impl Serialize, - expire: u32, - ) -> Result<(), Error>; - async fn get_cache_key(&self, key: String) -> Result - where - T: DeserializeOwned; - async fn del_cache_key(&self, key: String) -> Result<(), Error>; -} - -impl CacheFns for redis::Client { - async fn set_cache_key( +impl AppState { + pub async fn set_cache_key( &self, key: String, value: impl Serialize, expire: u32, ) -> Result<(), Error> { - let mut conn = self.get_multiplexed_tokio_connection().await?; + let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?; let key_encoded = encode(key); @@ -198,31 +187,26 @@ impl CacheFns for redis::Client { Ok(()) } - async fn get_cache_key(&self, key: String) -> Result - where - T: DeserializeOwned, - { - let mut conn = self.get_multiplexed_tokio_connection().await?; + pub async fn get_cache_key(&self, key: String) -> Result { + let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?; let key_encoded = encode(key); - let res: String = redis::cmd("GET") + redis::cmd("GET") .arg(key_encoded) .query_async(&mut conn) - .await?; - - Ok(serde_json::from_str(&res)?) + .await } - async fn del_cache_key(&self, key: String) -> Result<(), Error> { - let mut conn = self.get_multiplexed_tokio_connection().await?; + pub async fn del_cache_key(&self, key: String) -> Result<(), RedisError> { + let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?; let key_encoded = encode(key); - Ok(redis::cmd("DEL") + redis::cmd("DEL") .arg(key_encoded) .query_async(&mut conn) - .await?) + .await } }