diff --git a/src/api/v1/auth/register.rs b/src/api/v1/auth/register.rs index 807fab8..545e5aa 100644 --- a/src/api/v1/auth/register.rs +++ b/src/api/v1/auth/register.rs @@ -159,7 +159,7 @@ pub async fn post( .await?; if let Some(initial_guild) = app_state.config.instance.initial_guild { - Member::new(&app_state, uuid, initial_guild).await?; + Member::new(&mut conn, &app_state.cache_pool, uuid, initial_guild).await?; } let mut response = ( diff --git a/src/api/v1/auth/reset_password.rs b/src/api/v1/auth/reset_password.rs index bac465c..35c4b41 100644 --- a/src/api/v1/auth/reset_password.rs +++ b/src/api/v1/auth/reset_password.rs @@ -38,11 +38,17 @@ pub async fn get( State(app_state): State>, query: Query, ) -> Result { - if let Ok(password_reset_token) = - PasswordResetToken::get_with_identifier(&app_state, query.identifier.clone()).await + let mut conn = app_state.pool.get().await?; + + if let Ok(password_reset_token) = PasswordResetToken::get_with_identifier( + &mut conn, + &app_state.cache_pool, + query.identifier.clone(), + ) + .await { if Utc::now().signed_duration_since(password_reset_token.created_at) > Duration::hours(1) { - password_reset_token.delete(&app_state).await?; + password_reset_token.delete(&app_state.cache_pool).await?; } else { return Err(Error::TooManyRequests( "Please allow 1 hour before sending a new email".to_string(), @@ -50,7 +56,7 @@ pub async fn get( } } - PasswordResetToken::new(&app_state, query.identifier.clone()).await?; + PasswordResetToken::new(&mut conn, &app_state, query.identifier.clone()).await?; Ok(StatusCode::OK) } @@ -87,10 +93,14 @@ pub async fn post( reset_password: Json, ) -> Result { let password_reset_token = - PasswordResetToken::get(&app_state, reset_password.token.clone()).await?; + PasswordResetToken::get(&app_state.cache_pool, reset_password.token.clone()).await?; password_reset_token - .set_password(&app_state, reset_password.password.clone()) + .set_password( + &mut app_state.pool.get().await?, + &app_state, + reset_password.password.clone(), + ) .await?; Ok(StatusCode::OK) diff --git a/src/api/v1/auth/verify_email.rs b/src/api/v1/auth/verify_email.rs index 0801768..1cb8aef 100644 --- a/src/api/v1/auth/verify_email.rs +++ b/src/api/v1/auth/verify_email.rs @@ -55,7 +55,7 @@ pub async fn get( return Ok(StatusCode::NO_CONTENT); } - let email_token = EmailToken::get(&app_state, me.uuid).await?; + let email_token = EmailToken::get(&app_state.cache_pool, me.uuid).await?; if query.token != email_token.token { return Ok(StatusCode::UNAUTHORIZED); @@ -63,7 +63,7 @@ pub async fn get( me.verify_email(&mut conn).await?; - email_token.delete(&app_state).await?; + email_token.delete(&app_state.cache_pool).await?; Ok(StatusCode::OK) } @@ -91,9 +91,9 @@ pub async fn post( return Ok(StatusCode::NO_CONTENT); } - if let Ok(email_token) = EmailToken::get(&app_state, me.uuid).await { + if let Ok(email_token) = EmailToken::get(&app_state.cache_pool, me.uuid).await { if Utc::now().signed_duration_since(email_token.created_at) > Duration::hours(1) { - email_token.delete(&app_state).await?; + email_token.delete(&app_state.cache_pool).await?; } else { return Err(Error::TooManyRequests( "Please allow 1 hour before sending a new email".to_string(), diff --git a/src/api/v1/channels/uuid/messages.rs b/src/api/v1/channels/uuid/messages.rs index b8f0ad6..1f9010d 100644 --- a/src/api/v1/channels/uuid/messages.rs +++ b/src/api/v1/channels/uuid/messages.rs @@ -60,14 +60,21 @@ pub async fn get( Query(message_request): Query, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - let channel = Channel::fetch_one(&app_state, channel_uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; - Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; + let channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; + + Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; let messages = channel - .fetch_messages(&app_state, message_request.amount, message_request.offset) + .fetch_messages( + &mut conn, + &app_state.cache_pool, + message_request.amount, + message_request.offset, + ) .await?; Ok((StatusCode::OK, Json(messages))) diff --git a/src/api/v1/channels/uuid/mod.rs b/src/api/v1/channels/uuid/mod.rs index 373742e..f5566b3 100644 --- a/src/api/v1/channels/uuid/mod.rs +++ b/src/api/v1/channels/uuid/mod.rs @@ -27,11 +27,13 @@ pub async fn get( Path(channel_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - let channel = Channel::fetch_one(&app_state, channel_uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; - Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; + let channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; + + Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; Ok((StatusCode::OK, Json(channel))) } @@ -41,19 +43,19 @@ pub async fn delete( Path(channel_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - let channel = Channel::fetch_one(&app_state, channel_uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; - let member = - Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid) - .await?; + let channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; + + let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; member - .check_permission(&app_state, Permissions::ManageChannel) + .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageChannel) .await?; - channel.delete(&app_state).await?; + channel.delete(&mut conn, &app_state.cache_pool).await?; Ok(StatusCode::OK) } @@ -102,31 +104,37 @@ pub async fn patch( Extension(CurrentUser(uuid)): Extension>, Json(new_info): Json, ) -> Result { - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - let mut channel = Channel::fetch_one(&app_state, channel_uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; - let member = - Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid) - .await?; + let mut channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; + + let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; member - .check_permission(&app_state, Permissions::ManageChannel) + .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageChannel) .await?; if let Some(new_name) = &new_info.name { - channel.set_name(&app_state, new_name.to_string()).await?; + channel + .set_name(&mut conn, &app_state.cache_pool, new_name.to_string()) + .await?; } if let Some(new_description) = &new_info.description { channel - .set_description(&app_state, new_description.to_string()) + .set_description( + &mut conn, + &app_state.cache_pool, + new_description.to_string(), + ) .await?; } if let Some(new_is_above) = &new_info.is_above { channel - .set_description(&app_state, new_is_above.to_string()) + .set_description(&mut conn, &app_state.cache_pool, new_is_above.to_string()) .await?; } diff --git a/src/api/v1/channels/uuid/socket.rs b/src/api/v1/channels/uuid/socket.rs index dd020e3..ac04301 100644 --- a/src/api/v1/channels/uuid/socket.rs +++ b/src/api/v1/channels/uuid/socket.rs @@ -71,9 +71,9 @@ pub async fn ws( // Authorize client using auth header let uuid = check_access_token(auth_header, &mut conn).await?; - global_checks(&app_state, uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; - let channel = Channel::fetch_one(&app_state, channel_uuid).await?; + let channel = Channel::fetch_one(&mut conn, &app_state.cache_pool, channel_uuid).await?; Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; @@ -103,7 +103,8 @@ pub async fn ws( let message = channel .new_message( - &app_state, + &mut conn, + &app_state.cache_pool, uuid, message_body.message, message_body.reply_to, diff --git a/src/api/v1/guilds/mod.rs b/src/api/v1/guilds/mod.rs index 8118522..5b9f089 100644 --- a/src/api/v1/guilds/mod.rs +++ b/src/api/v1/guilds/mod.rs @@ -128,9 +128,11 @@ pub async fn get_guilds( let start = request_query.start.unwrap_or(0); let amount = request_query.amount.unwrap_or(10); - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - let guilds = Guild::fetch_amount(&app_state.pool, start, amount).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + + let guilds = Guild::fetch_amount(&mut conn, start, amount).await?; Ok((StatusCode::OK, Json(guilds))) } diff --git a/src/api/v1/guilds/uuid/channels.rs b/src/api/v1/guilds/uuid/channels.rs index 82368b9..1cd7f78 100644 --- a/src/api/v1/guilds/uuid/channels.rs +++ b/src/api/v1/guilds/uuid/channels.rs @@ -14,7 +14,7 @@ use crate::{ api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member, Permissions}, - utils::{global_checks, order_by_is_above}, + utils::{CacheFns, global_checks, order_by_is_above}, }; #[derive(Deserialize)] @@ -28,23 +28,26 @@ pub async fn get( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + + Member::check_membership(&mut conn, uuid, guild_uuid).await?; if let Ok(cache_hit) = app_state - .get_cache_key(format!("{guild_uuid}_channels")) + .cache_pool + .get_cache_key::>(format!("{guild_uuid}_channels")) .await - && let Ok(channels) = serde_json::from_str::>(&cache_hit) { - return Ok((StatusCode::OK, Json(channels)).into_response()); + return Ok((StatusCode::OK, Json(cache_hit)).into_response()); } - let channels = Channel::fetch_all(&app_state.pool, guild_uuid).await?; + let channels = Channel::fetch_all(&mut conn, guild_uuid).await?; let channels_ordered = order_by_is_above(channels).await?; app_state + .cache_pool .set_cache_key( format!("{guild_uuid}_channels"), channels_ordered.clone(), @@ -61,17 +64,19 @@ pub async fn create( Extension(CurrentUser(uuid)): Extension>, Json(channel_info): Json, ) -> Result { - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - let member = - Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&app_state, Permissions::ManageChannel) + .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageChannel) .await?; let channel = Channel::new( - &app_state, + &mut conn, + &app_state.cache_pool, guild_uuid, channel_info.name.clone(), channel_info.description.clone(), diff --git a/src/api/v1/guilds/uuid/invites/mod.rs b/src/api/v1/guilds/uuid/invites/mod.rs index 649fc16..fa06f44 100644 --- a/src/api/v1/guilds/uuid/invites/mod.rs +++ b/src/api/v1/guilds/uuid/invites/mod.rs @@ -27,10 +27,10 @@ pub async fn get( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -46,14 +46,14 @@ pub async fn create( Extension(CurrentUser(uuid)): Extension>, Json(invite_request): Json, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&app_state, Permissions::CreateInvite) + .check_permission(&mut conn, &app_state.cache_pool, Permissions::CreateInvite) .await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; diff --git a/src/api/v1/guilds/uuid/members.rs b/src/api/v1/guilds/uuid/members.rs index 3ae10f7..56710af 100644 --- a/src/api/v1/guilds/uuid/members.rs +++ b/src/api/v1/guilds/uuid/members.rs @@ -21,15 +21,15 @@ pub async fn get( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let me = Me::get(&mut conn, uuid).await?; - let members = Member::fetch_all(&app_state, &me, guild_uuid).await?; + let members = Member::fetch_all(&mut conn, &app_state.cache_pool, &me, guild_uuid).await?; Ok((StatusCode::OK, Json(members))) } diff --git a/src/api/v1/guilds/uuid/mod.rs b/src/api/v1/guilds/uuid/mod.rs index 52f0b64..d49e56a 100644 --- a/src/api/v1/guilds/uuid/mod.rs +++ b/src/api/v1/guilds/uuid/mod.rs @@ -82,10 +82,10 @@ pub async fn get_guild( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -102,14 +102,14 @@ pub async fn edit( Extension(CurrentUser(uuid)): Extension>, mut multipart: Multipart, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&app_state, Permissions::ManageGuild) + .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageGuild) .await?; let mut guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -127,14 +127,7 @@ pub async fn edit( } if let Some(icon) = icon { - guild - .set_icon( - &app_state.bunny_storage, - &mut conn, - app_state.config.bunny.cdn_url.clone(), - icon, - ) - .await?; + guild.set_icon(&mut conn, &app_state, icon).await?; } Ok(StatusCode::OK) diff --git a/src/api/v1/guilds/uuid/roles/mod.rs b/src/api/v1/guilds/uuid/roles/mod.rs index 0e496a0..d3660ce 100644 --- a/src/api/v1/guilds/uuid/roles/mod.rs +++ b/src/api/v1/guilds/uuid/roles/mod.rs @@ -14,7 +14,7 @@ use crate::{ api::v1::auth::CurrentUser, error::Error, objects::{Member, Permissions, Role}, - utils::{global_checks, order_by_is_above}, + utils::{CacheFns, global_checks, order_by_is_above}, }; pub mod uuid; @@ -29,16 +29,18 @@ pub async fn get( Path(guild_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; - if let Ok(cache_hit) = app_state.get_cache_key(format!("{guild_uuid}_roles")).await - && let Ok(roles) = serde_json::from_str::>(&cache_hit) + if let Ok(cache_hit) = app_state + .cache_pool + .get_cache_key::>(format!("{guild_uuid}_roles")) + .await { - return Ok((StatusCode::OK, Json(roles)).into_response()); + return Ok((StatusCode::OK, Json(cache_hit)).into_response()); } let roles = Role::fetch_all(&mut conn, guild_uuid).await?; @@ -46,6 +48,7 @@ pub async fn get( let roles_ordered = order_by_is_above(roles).await?; app_state + .cache_pool .set_cache_key(format!("{guild_uuid}_roles"), roles_ordered.clone(), 1800) .await?; @@ -58,14 +61,14 @@ pub async fn create( Extension(CurrentUser(uuid)): Extension>, Json(role_info): Json, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&app_state, Permissions::ManageRole) + .check_permission(&mut conn, &app_state.cache_pool, Permissions::ManageRole) .await?; let role = Role::new(&mut conn, guild_uuid, role_info.name.clone()).await?; diff --git a/src/api/v1/guilds/uuid/roles/uuid.rs b/src/api/v1/guilds/uuid/roles/uuid.rs index 732d553..e7890d0 100644 --- a/src/api/v1/guilds/uuid/roles/uuid.rs +++ b/src/api/v1/guilds/uuid/roles/uuid.rs @@ -13,7 +13,7 @@ use crate::{ api::v1::auth::CurrentUser, error::Error, objects::{Member, Role}, - utils::global_checks, + utils::{CacheFns, global_checks}, }; pub async fn get( @@ -21,21 +21,24 @@ pub async fn get( Path((guild_uuid, role_uuid)): Path<(Uuid, Uuid)>, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; - if let Ok(cache_hit) = app_state.get_cache_key(format!("{role_uuid}")).await - && let Ok(role) = serde_json::from_str::(&cache_hit) + if let Ok(cache_hit) = app_state + .cache_pool + .get_cache_key::(format!("{role_uuid}")) + .await { - return Ok((StatusCode::OK, Json(role)).into_response()); + return Ok((StatusCode::OK, Json(cache_hit)).into_response()); } let role = Role::fetch_one(&mut conn, role_uuid).await?; app_state + .cache_pool .set_cache_key(format!("{role_uuid}"), role.clone(), 60) .await?; diff --git a/src/api/v1/invites/id.rs b/src/api/v1/invites/id.rs index 72ceea4..99f177f 100644 --- a/src/api/v1/invites/id.rs +++ b/src/api/v1/invites/id.rs @@ -34,15 +34,15 @@ pub async fn join( Path(invite_id): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + let invite = Invite::fetch_one(&mut conn, invite_id).await?; let guild = Guild::fetch_one(&mut conn, invite.guild_uuid).await?; - Member::new(&app_state, uuid, guild.uuid).await?; + Member::new(&mut conn, &app_state.cache_pool, uuid, guild.uuid).await?; Ok((StatusCode::OK, Json(guild))) } diff --git a/src/api/v1/me/friends/mod.rs b/src/api/v1/me/friends/mod.rs index a56f8d4..904a1f5 100644 --- a/src/api/v1/me/friends/mod.rs +++ b/src/api/v1/me/friends/mod.rs @@ -19,11 +19,13 @@ pub async fn get( State(app_state): State>, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; - let friends = me.get_friends(&app_state).await?; + let me = Me::get(&mut conn, uuid).await?; + + let friends = me.get_friends(&mut conn, &app_state.cache_pool).await?; Ok((StatusCode::OK, Json(friends))) } @@ -57,10 +59,10 @@ pub async fn post( Extension(CurrentUser(uuid)): Extension>, Json(user_request): Json, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + let me = Me::get(&mut conn, uuid).await?; let target_uuid = user_uuid_from_username(&mut conn, &user_request.username).await?; diff --git a/src/api/v1/me/friends/uuid.rs b/src/api/v1/me/friends/uuid.rs index 5367435..35f0742 100644 --- a/src/api/v1/me/friends/uuid.rs +++ b/src/api/v1/me/friends/uuid.rs @@ -17,10 +17,10 @@ pub async fn delete( Path(friend_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + let me = Me::get(&mut conn, uuid).await?; me.remove_friend(&mut conn, friend_uuid).await?; diff --git a/src/api/v1/me/guilds.rs b/src/api/v1/me/guilds.rs index 88dfad9..42d5c21 100644 --- a/src/api/v1/me/guilds.rs +++ b/src/api/v1/me/guilds.rs @@ -58,10 +58,10 @@ pub async fn get( State(app_state): State>, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + let me = Me::get(&mut conn, uuid).await?; let memberships = me.fetch_memberships(&mut conn).await?; diff --git a/src/api/v1/me/mod.rs b/src/api/v1/me/mod.rs index e167d14..86d3d9e 100644 --- a/src/api/v1/me/mod.rs +++ b/src/api/v1/me/mod.rs @@ -73,36 +73,41 @@ pub async fn update( let json = json_raw.unwrap_or_default(); + let mut conn = app_state.pool.get().await?; + if avatar.is_some() || json.username.is_some() || json.display_name.is_some() { - global_checks(&app_state, uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; } - let mut me = Me::get(&mut app_state.pool.get().await?, uuid).await?; + let mut me = Me::get(&mut conn, uuid).await?; if let Some(avatar) = avatar { - me.set_avatar(&app_state, app_state.config.bunny.cdn_url.clone(), avatar) - .await?; + me.set_avatar(&mut conn, &app_state, avatar).await?; } if let Some(username) = &json.username { - me.set_username(&app_state, username.clone()).await?; + me.set_username(&mut conn, &app_state.cache_pool, username.clone()) + .await?; } if let Some(display_name) = &json.display_name { - me.set_display_name(&app_state, display_name.clone()) + me.set_display_name(&mut conn, &app_state.cache_pool, display_name.clone()) .await?; } if let Some(email) = &json.email { - me.set_email(&app_state, email.clone()).await?; + me.set_email(&mut conn, &app_state.cache_pool, email.clone()) + .await?; } if let Some(pronouns) = &json.pronouns { - me.set_pronouns(&app_state, pronouns.clone()).await?; + me.set_pronouns(&mut conn, &app_state.cache_pool, pronouns.clone()) + .await?; } if let Some(about) = &json.about { - me.set_about(&app_state, about.clone()).await?; + me.set_about(&mut conn, &app_state.cache_pool, about.clone()) + .await?; } Ok(StatusCode::OK) diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index a4b93ce..999e13f 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -70,9 +70,11 @@ pub async fn users( return Ok(StatusCode::BAD_REQUEST.into_response()); } - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - let users = User::fetch_amount(&mut app_state.pool.get().await?, start, amount).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; + + let users = User::fetch_amount(&mut conn, start, amount).await?; Ok((StatusCode::OK, Json(users)).into_response()) } diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index cee6df0..e015c3c 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -39,11 +39,14 @@ pub async fn get( Path(user_uuid): Path, Extension(CurrentUser(uuid)): Extension>, ) -> Result { - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; - let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; + global_checks(&mut conn, &app_state.config, uuid).await?; - let user = User::fetch_one_with_friendship(&app_state, &me, user_uuid).await?; + let me = Me::get(&mut conn, uuid).await?; + + let user = + User::fetch_one_with_friendship(&mut conn, &app_state.cache_pool, &me, user_uuid).await?; Ok((StatusCode::OK, Json(user))) } diff --git a/src/objects/channel.rs b/src/objects/channel.rs index cacb153..03a2cf6 100644 --- a/src/objects/channel.rs +++ b/src/objects/channel.rs @@ -2,15 +2,15 @@ use diesel::{ ExpressionMethods, Insertable, QueryDsl, Queryable, Selectable, SelectableHelper, delete, insert_into, update, }; -use diesel_async::{RunQueryDsl, pooled_connection::AsyncDieselConnectionManager}; +use diesel_async::RunQueryDsl; use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - AppState, Conn, + Conn, error::Error, schema::{channel_permissions, channels, messages}, - utils::{CHANNEL_REGEX, order_by_is_above}, + utils::{CHANNEL_REGEX, CacheFns, order_by_is_above}, }; use super::{HasIsAbove, HasUuid, Message, load_or_empty, message::MessageBuilder}; @@ -79,49 +79,44 @@ impl HasIsAbove for Channel { } impl Channel { - pub async fn fetch_all( - pool: &deadpool::managed::Pool< - AsyncDieselConnectionManager, - Conn, - >, - guild_uuid: Uuid, - ) -> Result, Error> { - let mut conn = pool.get().await?; - + pub async fn fetch_all(conn: &mut Conn, guild_uuid: Uuid) -> Result, Error> { use channels::dsl; let channel_builders: Vec = load_or_empty( dsl::channels .filter(dsl::guild_uuid.eq(guild_uuid)) .select(ChannelBuilder::as_select()) - .load(&mut conn) + .load(conn) .await, )?; - let channel_futures = channel_builders.iter().map(async move |c| { - let mut conn = pool.get().await?; - c.clone().build(&mut conn).await - }); + let mut channels = vec![]; - futures_util::future::try_join_all(channel_futures).await - } - - pub async fn fetch_one(app_state: &AppState, channel_uuid: Uuid) -> Result { - if let Ok(cache_hit) = app_state.get_cache_key(channel_uuid.to_string()).await { - return Ok(serde_json::from_str(&cache_hit)?); + for builder in channel_builders { + channels.push(builder.build(conn).await?); } - let mut conn = app_state.pool.get().await?; + Ok(channels) + } + + pub async fn fetch_one( + conn: &mut Conn, + cache_pool: &redis::Client, + channel_uuid: Uuid, + ) -> Result { + if let Ok(cache_hit) = cache_pool.get_cache_key(channel_uuid.to_string()).await { + return Ok(cache_hit); + } use channels::dsl; let channel_builder: ChannelBuilder = dsl::channels .filter(dsl::uuid.eq(channel_uuid)) .select(ChannelBuilder::as_select()) - .get_result(&mut conn) + .get_result(conn) .await?; - let channel = channel_builder.build(&mut conn).await?; + let channel = channel_builder.build(conn).await?; - app_state + cache_pool .set_cache_key(channel_uuid.to_string(), channel.clone(), 60) .await?; @@ -129,7 +124,8 @@ impl Channel { } pub async fn new( - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, guild_uuid: Uuid, name: String, description: Option, @@ -138,11 +134,9 @@ impl Channel { return Err(Error::BadRequest("Channel name is invalid".to_string())); } - let mut conn = app_state.pool.get().await?; - let channel_uuid = Uuid::now_v7(); - let channels = Self::fetch_all(&app_state.pool, guild_uuid).await?; + let channels = Self::fetch_all(conn, guild_uuid).await?; let channels_ordered = order_by_is_above(channels).await?; @@ -158,7 +152,7 @@ impl Channel { insert_into(channels::table) .values(new_channel.clone()) - .execute(&mut conn) + .execute(conn) .await?; if let Some(old_last_channel) = last_channel { @@ -166,7 +160,7 @@ impl Channel { update(channels::table) .filter(dsl::uuid.eq(old_last_channel.uuid)) .set(dsl::is_above.eq(new_channel.uuid)) - .execute(&mut conn) + .execute(conn) .await?; } @@ -180,16 +174,16 @@ impl Channel { permissions: vec![], }; - app_state + cache_pool .set_cache_key(channel_uuid.to_string(), channel.clone(), 1800) .await?; - if app_state - .get_cache_key(format!("{guild_uuid}_channels")) + if cache_pool + .get_cache_key::>(format!("{guild_uuid}_channels")) .await .is_ok() { - app_state + cache_pool .del_cache_key(format!("{guild_uuid}_channels")) .await?; } @@ -197,14 +191,12 @@ impl Channel { Ok(channel) } - pub async fn delete(self, app_state: &AppState) -> Result<(), Error> { - let mut conn = app_state.pool.get().await?; - + pub async fn delete(self, conn: &mut Conn, cache_pool: &redis::Client) -> Result<(), Error> { use channels::dsl; match update(channels::table) .filter(dsl::is_above.eq(self.uuid)) .set(dsl::is_above.eq(None::)) - .execute(&mut conn) + .execute(conn) .await { Ok(r) => Ok(r), @@ -214,13 +206,13 @@ impl Channel { delete(channels::table) .filter(dsl::uuid.eq(self.uuid)) - .execute(&mut conn) + .execute(conn) .await?; match update(channels::table) .filter(dsl::is_above.eq(self.uuid)) .set(dsl::is_above.eq(self.is_above)) - .execute(&mut conn) + .execute(conn) .await { Ok(r) => Ok(r), @@ -228,16 +220,20 @@ impl Channel { Err(e) => Err(e), }?; - if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { - app_state.del_cache_key(self.uuid.to_string()).await?; - } - - if app_state - .get_cache_key(format!("{}_channels", self.guild_uuid)) + if cache_pool + .get_cache_key::(self.uuid.to_string()) .await .is_ok() { - app_state + cache_pool.del_cache_key(self.uuid.to_string()).await?; + } + + if cache_pool + .get_cache_key::>(format!("{}_channels", self.guild_uuid)) + .await + .is_ok() + { + cache_pool .del_cache_key(format!("{}_channels", self.guild_uuid)) .await?; } @@ -247,32 +243,36 @@ impl Channel { pub async fn fetch_messages( &self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, amount: i64, offset: i64, ) -> Result, Error> { - let mut conn = app_state.pool.get().await?; - use messages::dsl; - let messages: Vec = load_or_empty( + let message_builders: Vec = load_or_empty( dsl::messages .filter(dsl::channel_uuid.eq(self.uuid)) .select(MessageBuilder::as_select()) .order(dsl::uuid.desc()) .limit(amount) .offset(offset) - .load(&mut conn) + .load(conn) .await, )?; - let message_futures = messages.iter().map(async move |b| b.build(app_state).await); + let mut messages = vec![]; - futures_util::future::try_join_all(message_futures).await + for builder in message_builders { + messages.push(builder.build(conn, cache_pool).await?); + } + + Ok(messages) } pub async fn new_message( &self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, user_uuid: Uuid, message: String, reply_to: Option, @@ -287,66 +287,101 @@ impl Channel { reply_to, }; - let mut conn = app_state.pool.get().await?; - insert_into(messages::table) .values(message.clone()) - .execute(&mut conn) + .execute(conn) .await?; - message.build(app_state).await + message.build(conn, cache_pool).await } - pub async fn set_name(&mut self, app_state: &AppState, new_name: String) -> Result<(), Error> { + pub async fn set_name( + &mut self, + conn: &mut Conn, + cache_pool: &redis::Client, + new_name: String, + ) -> Result<(), Error> { if !CHANNEL_REGEX.is_match(&new_name) { return Err(Error::BadRequest("Channel name is invalid".to_string())); } - let mut conn = app_state.pool.get().await?; - use channels::dsl; update(channels::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::name.eq(&new_name)) - .execute(&mut conn) + .execute(conn) .await?; self.name = new_name; + if cache_pool + .get_cache_key::(self.uuid.to_string()) + .await + .is_ok() + { + cache_pool.del_cache_key(self.uuid.to_string()).await?; + } + + if cache_pool + .get_cache_key::>(format!("{}_channels", self.guild_uuid)) + .await + .is_ok() + { + cache_pool + .del_cache_key(format!("{}_channels", self.guild_uuid)) + .await?; + } + Ok(()) } pub async fn set_description( &mut self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, new_description: String, ) -> Result<(), Error> { - let mut conn = app_state.pool.get().await?; - use channels::dsl; update(channels::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::description.eq(&new_description)) - .execute(&mut conn) + .execute(conn) .await?; self.description = Some(new_description); + if cache_pool + .get_cache_key::(self.uuid.to_string()) + .await + .is_ok() + { + cache_pool.del_cache_key(self.uuid.to_string()).await?; + } + + if cache_pool + .get_cache_key::>(format!("{}_channels", self.guild_uuid)) + .await + .is_ok() + { + cache_pool + .del_cache_key(format!("{}_channels", self.guild_uuid)) + .await?; + } + Ok(()) } pub async fn move_channel( &mut self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, new_is_above: Uuid, ) -> Result<(), Error> { - let mut conn = app_state.pool.get().await?; - use channels::dsl; let old_above_uuid: Option = match dsl::channels .filter(dsl::is_above.eq(self.uuid)) .select(dsl::uuid) - .get_result(&mut conn) + .get_result(conn) .await { Ok(r) => Ok(Some(r)), @@ -358,14 +393,14 @@ impl Channel { update(channels::table) .filter(dsl::uuid.eq(uuid)) .set(dsl::is_above.eq(None::)) - .execute(&mut conn) + .execute(conn) .await?; } match update(channels::table) .filter(dsl::is_above.eq(new_is_above)) .set(dsl::is_above.eq(self.uuid)) - .execute(&mut conn) + .execute(conn) .await { Ok(r) => Ok(r), @@ -376,19 +411,37 @@ impl Channel { update(channels::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::is_above.eq(new_is_above)) - .execute(&mut conn) + .execute(conn) .await?; if let Some(uuid) = old_above_uuid { update(channels::table) .filter(dsl::uuid.eq(uuid)) .set(dsl::is_above.eq(self.is_above)) - .execute(&mut conn) + .execute(conn) .await?; } self.is_above = Some(new_is_above); + if cache_pool + .get_cache_key::(self.uuid.to_string()) + .await + .is_ok() + { + cache_pool.del_cache_key(self.uuid.to_string()).await?; + } + + if cache_pool + .get_cache_key::>(format!("{}_channels", self.guild_uuid)) + .await + .is_ok() + { + cache_pool + .del_cache_key(format!("{}_channels", self.guild_uuid)) + .await?; + } + Ok(()) } } diff --git a/src/objects/email_token.rs b/src/objects/email_token.rs index 64d2fdb..c826620 100644 --- a/src/objects/email_token.rs +++ b/src/objects/email_token.rs @@ -3,7 +3,11 @@ use lettre::message::MultiPart; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{AppState, error::Error, utils::generate_token}; +use crate::{ + AppState, + error::Error, + utils::{CacheFns, generate_token}, +}; use super::Me; @@ -15,12 +19,10 @@ pub struct EmailToken { } impl EmailToken { - pub async fn get(app_state: &AppState, user_uuid: Uuid) -> Result { - let email_token = serde_json::from_str( - &app_state - .get_cache_key(format!("{user_uuid}_email_verify")) - .await?, - )?; + pub async fn get(cache_pool: &redis::Client, user_uuid: Uuid) -> Result { + let email_token = cache_pool + .get_cache_key(format!("{user_uuid}_email_verify")) + .await?; Ok(email_token) } @@ -37,6 +39,7 @@ impl EmailToken { }; app_state + .cache_pool .set_cache_key(format!("{}_email_verify", me.uuid), email_token, 86400) .await?; @@ -59,8 +62,8 @@ impl EmailToken { Ok(()) } - pub async fn delete(&self, app_state: &AppState) -> Result<(), Error> { - app_state + pub async fn delete(&self, cache_pool: &redis::Client) -> Result<(), Error> { + cache_pool .del_cache_key(format!("{}_email_verify", self.user_uuid)) .await?; diff --git a/src/objects/guild.rs b/src/objects/guild.rs index 9514e49..9640e28 100644 --- a/src/objects/guild.rs +++ b/src/objects/guild.rs @@ -3,14 +3,14 @@ use diesel::{ ExpressionMethods, Insertable, QueryDsl, Queryable, Selectable, SelectableHelper, insert_into, update, }; -use diesel_async::{RunQueryDsl, pooled_connection::AsyncDieselConnectionManager}; +use diesel_async::RunQueryDsl; use serde::Serialize; use tokio::task; use url::Url; use uuid::Uuid; use crate::{ - Conn, + AppState, Conn, error::Error, schema::{guild_members, guilds, invites}, utils::image_check, @@ -68,16 +68,11 @@ impl Guild { } pub async fn fetch_amount( - pool: &deadpool::managed::Pool< - AsyncDieselConnectionManager, - Conn, - >, + conn: &mut Conn, offset: i64, amount: i64, ) -> Result, Error> { // Fetch guild data from database - let mut conn = pool.get().await?; - use guilds::dsl; let guild_builders: Vec = load_or_empty( dsl::guilds @@ -85,18 +80,17 @@ impl Guild { .order_by(dsl::uuid) .offset(offset) .limit(amount) - .load(&mut conn) + .load(conn) .await, )?; - // Process each guild concurrently - let guild_futures = guild_builders.iter().map(async move |g| { - let mut conn = pool.get().await?; - g.clone().build(&mut conn).await - }); + let mut guilds = vec![]; - // Execute all futures concurrently and collect results - futures_util::future::try_join_all(guild_futures).await + for builder in guild_builders { + guilds.push(builder.build(conn).await?); + } + + Ok(guilds) } pub async fn new(conn: &mut Conn, name: String, owner_uuid: Uuid) -> Result { @@ -188,9 +182,8 @@ impl Guild { // FIXME: Horrible security pub async fn set_icon( &mut self, - bunny_storage: &bunny_api_tokio::EdgeStorageClient, conn: &mut Conn, - cdn_url: Url, + app_state: &AppState, icon: Bytes, ) -> Result<(), Error> { let icon_clone = icon.clone(); @@ -199,14 +192,14 @@ impl Guild { if let Some(icon) = &self.icon { let relative_url = icon.path().trim_start_matches('/'); - bunny_storage.delete(relative_url).await?; + app_state.bunny_storage.delete(relative_url).await?; } let path = format!("icons/{}/{}.{}", self.uuid, Uuid::now_v7(), image_type); - bunny_storage.upload(path.clone(), icon).await?; + app_state.bunny_storage.upload(path.clone(), icon).await?; - let icon_url = cdn_url.join(&path)?; + let icon_url = app_state.config.bunny.cdn_url.join(&path)?; use guilds::dsl; update(guilds::table) diff --git a/src/objects/me.rs b/src/objects/me.rs index a0b399d..d03e08b 100644 --- a/src/objects/me.rs +++ b/src/objects/me.rs @@ -14,7 +14,7 @@ use crate::{ error::Error, objects::{Friend, FriendRequest, User}, schema::{friend_requests, friends, guild_members, guilds, users}, - utils::{EMAIL_REGEX, USERNAME_REGEX, image_check}, + utils::{CacheFns, EMAIL_REGEX, USERNAME_REGEX, image_check}, }; use super::{Guild, guild::GuildBuilder, load_or_empty, member::MemberBuilder}; @@ -75,15 +75,13 @@ impl Me { pub async fn set_avatar( &mut self, + conn: &mut Conn, app_state: &AppState, - cdn_url: Url, avatar: Bytes, ) -> Result<(), Error> { let avatar_clone = avatar.clone(); let image_type = task::spawn_blocking(move || image_check(avatar_clone)).await??; - let mut conn = app_state.pool.get().await?; - if let Some(avatar) = &self.avatar { let avatar_url: Url = avatar.parse()?; @@ -96,17 +94,25 @@ impl Me { app_state.bunny_storage.upload(path.clone(), avatar).await?; - let avatar_url = cdn_url.join(&path)?; + let avatar_url = app_state.config.bunny.cdn_url.join(&path)?; use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::avatar.eq(avatar_url.as_str())) - .execute(&mut conn) + .execute(conn) .await?; - if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { - app_state.del_cache_key(self.uuid.to_string()).await? + if app_state + .cache_pool + .get_cache_key::(self.uuid.to_string()) + .await + .is_ok() + { + app_state + .cache_pool + .del_cache_key(self.uuid.to_string()) + .await? } self.avatar = Some(avatar_url.to_string()); @@ -127,7 +133,8 @@ impl Me { pub async fn set_username( &mut self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, new_username: String, ) -> Result<(), Error> { if !USERNAME_REGEX.is_match(&new_username) @@ -137,17 +144,19 @@ impl Me { return Err(Error::BadRequest("Invalid username".to_string())); } - let mut conn = app_state.pool.get().await?; - use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::username.eq(new_username.as_str())) - .execute(&mut conn) + .execute(conn) .await?; - if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { - app_state.del_cache_key(self.uuid.to_string()).await? + if cache_pool + .get_cache_key::(self.uuid.to_string()) + .await + .is_ok() + { + cache_pool.del_cache_key(self.uuid.to_string()).await? } self.username = new_username; @@ -157,11 +166,10 @@ impl Me { pub async fn set_display_name( &mut self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, new_display_name: String, ) -> Result<(), Error> { - let mut conn = app_state.pool.get().await?; - let new_display_name_option = if new_display_name.is_empty() { None } else { @@ -172,11 +180,15 @@ impl Me { update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set(dsl::display_name.eq(&new_display_name_option)) - .execute(&mut conn) + .execute(conn) .await?; - if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { - app_state.del_cache_key(self.uuid.to_string()).await? + if cache_pool + .get_cache_key::(self.uuid.to_string()) + .await + .is_ok() + { + cache_pool.del_cache_key(self.uuid.to_string()).await? } self.display_name = new_display_name_option; @@ -186,15 +198,14 @@ impl Me { pub async fn set_email( &mut self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, new_email: String, ) -> Result<(), Error> { if !EMAIL_REGEX.is_match(&new_email) { return Err(Error::BadRequest("Invalid username".to_string())); } - let mut conn = app_state.pool.get().await?; - use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) @@ -202,11 +213,15 @@ impl Me { dsl::email.eq(new_email.as_str()), dsl::email_verified.eq(false), )) - .execute(&mut conn) + .execute(conn) .await?; - if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { - app_state.del_cache_key(self.uuid.to_string()).await? + if cache_pool + .get_cache_key::(self.uuid.to_string()) + .await + .is_ok() + { + cache_pool.del_cache_key(self.uuid.to_string()).await? } self.email = new_email; @@ -216,20 +231,23 @@ impl Me { pub async fn set_pronouns( &mut self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, new_pronouns: String, ) -> Result<(), Error> { - let mut conn = app_state.pool.get().await?; - use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set((dsl::pronouns.eq(new_pronouns.as_str()),)) - .execute(&mut conn) + .execute(conn) .await?; - if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { - app_state.del_cache_key(self.uuid.to_string()).await? + if cache_pool + .get_cache_key::(self.uuid.to_string()) + .await + .is_ok() + { + cache_pool.del_cache_key(self.uuid.to_string()).await? } Ok(()) @@ -237,20 +255,23 @@ impl Me { pub async fn set_about( &mut self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, new_about: String, ) -> Result<(), Error> { - let mut conn = app_state.pool.get().await?; - use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.uuid)) .set((dsl::about.eq(new_about.as_str()),)) - .execute(&mut conn) + .execute(conn) .await?; - if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { - app_state.del_cache_key(self.uuid.to_string()).await? + if cache_pool + .get_cache_key::(self.uuid.to_string()) + .await + .is_ok() + { + cache_pool.del_cache_key(self.uuid.to_string()).await? } Ok(()) @@ -366,16 +387,18 @@ impl Me { Ok(()) } - pub async fn get_friends(&self, app_state: &AppState) -> Result, Error> { + pub async fn get_friends( + &self, + conn: &mut Conn, + cache_pool: &redis::Client, + ) -> Result, Error> { use friends::dsl; - let mut conn = app_state.pool.get().await?; - let friends1 = load_or_empty( dsl::friends .filter(dsl::uuid1.eq(self.uuid)) .select(Friend::as_select()) - .load(&mut conn) + .load(conn) .await, )?; @@ -383,21 +406,21 @@ impl Me { dsl::friends .filter(dsl::uuid2.eq(self.uuid)) .select(Friend::as_select()) - .load(&mut conn) + .load(conn) .await, )?; - let friend_futures = friends1.iter().map(async move |friend| { - User::fetch_one_with_friendship(app_state, self, friend.uuid2).await - }); + let mut friends = vec![]; - let mut friends = futures_util::future::try_join_all(friend_futures).await?; + for friend in friends1 { + friends + .push(User::fetch_one_with_friendship(conn, cache_pool, self, friend.uuid2).await?); + } - let friend_futures = friends2.iter().map(async move |friend| { - User::fetch_one_with_friendship(app_state, self, friend.uuid1).await - }); - - friends.append(&mut futures_util::future::try_join_all(friend_futures).await?); + for friend in friends2 { + friends + .push(User::fetch_one_with_friendship(conn, cache_pool, self, friend.uuid1).await?); + } Ok(friends) } diff --git a/src/objects/member.rs b/src/objects/member.rs index 50b76b0..b7befdc 100644 --- a/src/objects/member.rs +++ b/src/objects/member.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - AppState, Conn, + Conn, error::Error, objects::{Me, Permissions, Role}, schema::guild_members, @@ -26,13 +26,18 @@ pub struct MemberBuilder { } impl MemberBuilder { - pub async fn build(&self, app_state: &AppState, me: Option<&Me>) -> Result { + pub async fn build( + &self, + conn: &mut Conn, + cache_pool: &redis::Client, + me: Option<&Me>, + ) -> Result { let user; if let Some(me) = me { - user = User::fetch_one_with_friendship(app_state, me, self.user_uuid).await?; + user = User::fetch_one_with_friendship(conn, cache_pool, me, self.user_uuid).await?; } else { - user = User::fetch_one(app_state, self.user_uuid).await?; + user = User::fetch_one(conn, cache_pool, self.user_uuid).await?; } Ok(Member { @@ -47,11 +52,12 @@ impl MemberBuilder { pub async fn check_permission( &self, - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, permission: Permissions, ) -> Result<(), Error> { if !self.is_owner { - let roles = Role::fetch_from_member(app_state, self.uuid).await?; + let roles = Role::fetch_from_member(conn, cache_pool, self.uuid).await?; let allowed = roles.iter().any(|r| r.permissions & permission as i64 != 0); if !allowed { return Err(Error::Forbidden("Not allowed".to_string())); @@ -101,56 +107,53 @@ impl Member { } pub async fn fetch_one( - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, me: &Me, user_uuid: Uuid, guild_uuid: Uuid, ) -> Result { - let mut conn = app_state.pool.get().await?; - use guild_members::dsl; let member: MemberBuilder = dsl::guild_members .filter(dsl::user_uuid.eq(user_uuid)) .filter(dsl::guild_uuid.eq(guild_uuid)) .select(MemberBuilder::as_select()) - .get_result(&mut conn) + .get_result(conn) .await?; - member.build(app_state, Some(me)).await + member.build(conn, cache_pool, Some(me)).await } pub async fn fetch_all( - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, me: &Me, guild_uuid: Uuid, ) -> Result, Error> { - let mut conn = app_state.pool.get().await?; - use guild_members::dsl; let member_builders: Vec = load_or_empty( dsl::guild_members .filter(dsl::guild_uuid.eq(guild_uuid)) .select(MemberBuilder::as_select()) - .load(&mut conn) + .load(conn) .await, )?; let mut members = vec![]; for builder in member_builders { - members.push(builder.build(app_state, Some(me)).await?); + members.push(builder.build(conn, cache_pool, Some(me)).await?); } Ok(members) } pub async fn new( - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, user_uuid: Uuid, guild_uuid: Uuid, ) -> Result { - let mut conn = app_state.pool.get().await?; - let member_uuid = Uuid::now_v7(); let member = MemberBuilder { @@ -163,9 +166,9 @@ impl Member { insert_into(guild_members::table) .values(&member) - .execute(&mut conn) + .execute(conn) .await?; - member.build(app_state, None).await + member.build(conn, cache_pool, None).await } } diff --git a/src/objects/message.rs b/src/objects/message.rs index caff969..a5224e0 100644 --- a/src/objects/message.rs +++ b/src/objects/message.rs @@ -2,7 +2,7 @@ use diesel::{Insertable, Queryable, Selectable}; use serde::Serialize; use uuid::Uuid; -use crate::{AppState, error::Error, schema::messages}; +use crate::{Conn, error::Error, schema::messages}; use super::User; @@ -18,8 +18,12 @@ pub struct MessageBuilder { } impl MessageBuilder { - pub async fn build(&self, app_state: &AppState) -> Result { - let user = User::fetch_one(app_state, self.user_uuid).await?; + pub async fn build( + &self, + conn: &mut Conn, + cache_pool: &redis::Client, + ) -> Result { + let user = User::fetch_one(conn, cache_pool, self.user_uuid).await?; Ok(Message { uuid: self.uuid, diff --git a/src/objects/password_reset_token.rs b/src/objects/password_reset_token.rs index 04ff43c..ca5c62f 100644 --- a/src/objects/password_reset_token.rs +++ b/src/objects/password_reset_token.rs @@ -10,10 +10,10 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - AppState, + AppState, Conn, error::Error, schema::users, - utils::{PASSWORD_REGEX, generate_token, global_checks, user_uuid_from_identifier}, + utils::{CacheFns, PASSWORD_REGEX, generate_token, global_checks, user_uuid_from_identifier}, }; #[derive(Serialize, Deserialize)] @@ -24,50 +24,49 @@ pub struct PasswordResetToken { } impl PasswordResetToken { - pub async fn get(app_state: &AppState, token: String) -> Result { - let user_uuid: Uuid = - serde_json::from_str(&app_state.get_cache_key(token.to_string()).await?)?; - let password_reset_token = serde_json::from_str( - &app_state - .get_cache_key(format!("{user_uuid}_password_reset")) - .await?, - )?; + pub async fn get( + cache_pool: &redis::Client, + token: String, + ) -> Result { + let user_uuid: Uuid = cache_pool.get_cache_key(token.to_string()).await?; + let password_reset_token = cache_pool + .get_cache_key(format!("{user_uuid}_password_reset")) + .await?; Ok(password_reset_token) } pub async fn get_with_identifier( - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, identifier: String, ) -> Result { - let mut conn = app_state.pool.get().await?; + let user_uuid = user_uuid_from_identifier(conn, &identifier).await?; - let user_uuid = user_uuid_from_identifier(&mut conn, &identifier).await?; - - let password_reset_token = serde_json::from_str( - &app_state - .get_cache_key(format!("{user_uuid}_password_reset")) - .await?, - )?; + let password_reset_token = cache_pool + .get_cache_key(format!("{user_uuid}_password_reset")) + .await?; Ok(password_reset_token) } #[allow(clippy::new_ret_no_self)] - pub async fn new(app_state: &AppState, identifier: String) -> Result<(), Error> { + pub async fn new( + conn: &mut Conn, + app_state: &AppState, + identifier: String, + ) -> Result<(), Error> { let token = generate_token::<32>()?; - let mut conn = app_state.pool.get().await?; + let user_uuid = user_uuid_from_identifier(conn, &identifier).await?; - let user_uuid = user_uuid_from_identifier(&mut conn, &identifier).await?; - - global_checks(app_state, user_uuid).await?; + global_checks(conn, &app_state.config, user_uuid).await?; use users::dsl as udsl; let (username, email_address): (String, String) = udsl::users .filter(udsl::uuid.eq(user_uuid)) .select((udsl::username, udsl::email)) - .get_result(&mut conn) + .get_result(conn) .await?; let password_reset_token = PasswordResetToken { @@ -77,6 +76,7 @@ impl PasswordResetToken { }; app_state + .cache_pool .set_cache_key( format!("{user_uuid}_password_reset"), password_reset_token, @@ -84,6 +84,7 @@ impl PasswordResetToken { ) .await?; app_state + .cache_pool .set_cache_key(token.clone(), user_uuid, 86400) .await?; @@ -106,7 +107,12 @@ impl PasswordResetToken { Ok(()) } - pub async fn set_password(&self, app_state: &AppState, password: String) -> Result<(), Error> { + pub async fn set_password( + &self, + conn: &mut Conn, + app_state: &AppState, + password: String, + ) -> Result<(), Error> { if !PASSWORD_REGEX.is_match(&password) { return Err(Error::BadRequest( "Please provide a valid password".to_string(), @@ -120,19 +126,17 @@ impl PasswordResetToken { .hash_password(password.as_bytes(), &salt) .map_err(|e| Error::PasswordHashError(e.to_string()))?; - let mut conn = app_state.pool.get().await?; - use users::dsl; update(users::table) .filter(dsl::uuid.eq(self.user_uuid)) .set(dsl::password.eq(hashed_password.to_string())) - .execute(&mut conn) + .execute(conn) .await?; let (username, email_address): (String, String) = dsl::users .filter(dsl::uuid.eq(self.user_uuid)) .select((dsl::username, dsl::email)) - .get_result(&mut conn) + .get_result(conn) .await?; let login_page = app_state.config.web.frontend_url.join("login")?; @@ -149,14 +153,14 @@ impl PasswordResetToken { app_state.mail_client.send_mail(email).await?; - self.delete(app_state).await + self.delete(&app_state.cache_pool).await } - pub async fn delete(&self, app_state: &AppState) -> Result<(), Error> { - app_state + pub async fn delete(&self, cache_pool: &redis::Client) -> Result<(), Error> { + cache_pool .del_cache_key(format!("{}_password_reset", &self.user_uuid)) .await?; - app_state.del_cache_key(self.token.to_string()).await?; + cache_pool.del_cache_key(self.token.to_string()).await?; Ok(()) } diff --git a/src/objects/role.rs b/src/objects/role.rs index ea70686..01e4738 100644 --- a/src/objects/role.rs +++ b/src/objects/role.rs @@ -7,10 +7,10 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - AppState, Conn, + Conn, error::Error, schema::{role_members, roles}, - utils::order_by_is_above, + utils::{CacheFns, order_by_is_above}, }; use super::{HasIsAbove, HasUuid, load_or_empty}; @@ -75,34 +75,33 @@ impl Role { } pub async fn fetch_from_member( - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, member_uuid: Uuid, ) -> Result, Error> { - if let Ok(roles) = app_state + if let Ok(roles) = cache_pool .get_cache_key(format!("{member_uuid}_roles")) .await { - return Ok(serde_json::from_str(&roles)?); + return Ok(roles); } - let mut conn = app_state.pool.get().await?; - use role_members::dsl; let role_memberships: Vec = load_or_empty( dsl::role_members .filter(dsl::member_uuid.eq(member_uuid)) .select(RoleMember::as_select()) - .load(&mut conn) + .load(conn) .await, )?; let mut roles = vec![]; for membership in role_memberships { - roles.push(membership.fetch_role(&mut conn).await?); + roles.push(membership.fetch_role(conn).await?); } - app_state + cache_pool .set_cache_key(format!("{member_uuid}_roles"), roles.clone(), 300) .await?; diff --git a/src/objects/user.rs b/src/objects/user.rs index c1f164d..a686c39 100644 --- a/src/objects/user.rs +++ b/src/objects/user.rs @@ -4,7 +4,7 @@ use diesel_async::RunQueryDsl; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{AppState, Conn, error::Error, objects::Me, schema::users}; +use crate::{Conn, error::Error, objects::Me, schema::users, utils::CacheFns}; use super::load_or_empty; @@ -46,23 +46,25 @@ pub struct User { } impl User { - pub async fn fetch_one(app_state: &AppState, user_uuid: Uuid) -> Result { - let mut conn = app_state.pool.get().await?; - - if let Ok(cache_hit) = app_state.get_cache_key(user_uuid.to_string()).await { - return Ok(serde_json::from_str(&cache_hit)?); + pub async fn fetch_one( + conn: &mut Conn, + cache_pool: &redis::Client, + user_uuid: Uuid, + ) -> Result { + if let Ok(cache_hit) = cache_pool.get_cache_key(user_uuid.to_string()).await { + return Ok(cache_hit); } use users::dsl; let user_builder: UserBuilder = dsl::users .filter(dsl::uuid.eq(user_uuid)) .select(UserBuilder::as_select()) - .get_result(&mut conn) + .get_result(conn) .await?; let user = user_builder.build(); - app_state + cache_pool .set_cache_key(user_uuid.to_string(), user.clone(), 1800) .await?; @@ -70,15 +72,14 @@ impl User { } pub async fn fetch_one_with_friendship( - app_state: &AppState, + conn: &mut Conn, + cache_pool: &redis::Client, me: &Me, user_uuid: Uuid, ) -> Result { - let mut conn = app_state.pool.get().await?; + let mut user = Self::fetch_one(conn, cache_pool, user_uuid).await?; - let mut user = Self::fetch_one(app_state, user_uuid).await?; - - if let Some(friend) = me.friends_with(&mut conn, user_uuid).await? { + if let Some(friend) = me.friends_with(conn, user_uuid).await? { user.friends_since = Some(friend.accepted_at); } diff --git a/src/utils.rs b/src/utils.rs index e1df906..7ef880e 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -8,14 +8,13 @@ use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; use getrandom::fill; use hex::encode; -use redis::RedisError; use regex::Regex; -use serde::Serialize; +use serde::{Serialize, de::DeserializeOwned}; use time::Duration; use uuid::Uuid; use crate::{ - AppState, Conn, + Conn, config::Config, error::Error, objects::{HasIsAbove, HasUuid}, @@ -115,15 +114,13 @@ pub async fn user_uuid_from_username(conn: &mut Conn, username: &String) -> Resu } } -pub async fn global_checks(app_state: &AppState, user_uuid: Uuid) -> Result<(), Error> { - if app_state.config.instance.require_email_verification { - let mut conn = app_state.pool.get().await?; - +pub async fn global_checks(conn: &mut Conn, config: &Config, user_uuid: Uuid) -> Result<(), Error> { + if config.instance.require_email_verification { use users::dsl; let email_verified: bool = dsl::users .filter(dsl::uuid.eq(user_uuid)) .select(dsl::email_verified) - .get_result(&mut conn) + .get_result(conn) .await?; if !email_verified { @@ -161,14 +158,28 @@ where Ok(ordered) } -impl AppState { - pub async fn set_cache_key( +#[allow(async_fn_in_trait)] +pub trait CacheFns { + async fn set_cache_key( + &self, + key: String, + value: impl Serialize, + expire: u32, + ) -> Result<(), Error>; + async fn get_cache_key(&self, key: String) -> Result + where + T: DeserializeOwned; + async fn del_cache_key(&self, key: String) -> Result<(), Error>; +} + +impl CacheFns for redis::Client { + async fn set_cache_key( &self, key: String, value: impl Serialize, expire: u32, ) -> Result<(), Error> { - let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?; + let mut conn = self.get_multiplexed_tokio_connection().await?; let key_encoded = encode(key); @@ -187,26 +198,31 @@ impl AppState { Ok(()) } - pub async fn get_cache_key(&self, key: String) -> Result { - let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?; + async fn get_cache_key(&self, key: String) -> Result + where + T: DeserializeOwned, + { + let mut conn = self.get_multiplexed_tokio_connection().await?; let key_encoded = encode(key); - redis::cmd("GET") + let res: String = redis::cmd("GET") .arg(key_encoded) .query_async(&mut conn) - .await + .await?; + + Ok(serde_json::from_str(&res)?) } - pub async fn del_cache_key(&self, key: String) -> Result<(), RedisError> { - let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?; + async fn del_cache_key(&self, key: String) -> Result<(), Error> { + let mut conn = self.get_multiplexed_tokio_connection().await?; let key_encoded = encode(key); - redis::cmd("DEL") + Ok(redis::cmd("DEL") .arg(key_encoded) .query_async(&mut conn) - .await + .await?) } }