diff --git a/src/api/mod.rs b/src/api/mod.rs index a00d1e5..988ee45 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -2,15 +2,15 @@ use std::sync::Arc; -use axum::{Router, routing::get}; +use axum::{routing::get, Router}; use crate::AppState; mod v1; mod versions; -pub fn router(path: &str) -> Router> { +pub fn router(path: &str, app_state: Arc) -> Router> { Router::new() .route(&format!("{path}/versions"), get(versions::versions)) - .nest(&format!("{path}/v1"), v1::router()) + .nest(&format!("{path}/v1"), v1::router(app_state)) } diff --git a/src/api/v1/auth/devices.rs b/src/api/v1/auth/devices.rs index a3c12d1..336a52f 100644 --- a/src/api/v1/auth/devices.rs +++ b/src/api/v1/auth/devices.rs @@ -2,20 +2,14 @@ use std::sync::Arc; -use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, -}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; use diesel::{ExpressionMethods, QueryDsl, Queryable, Selectable, SelectableHelper}; use diesel_async::RunQueryDsl; use serde::Serialize; +use uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - schema::refresh_tokens::{self, dsl}, + api::v1::auth::CurrentUser, error::Error, schema::refresh_tokens::{self, dsl}, AppState }; #[derive(Serialize, Selectable, Queryable)] @@ -42,16 +36,12 @@ struct Device { /// ``` pub async fn get( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - let devices: Vec = dsl::refresh_tokens .filter(dsl::uuid.eq(uuid)) .select(Device::as_select()) - .get_results(&mut conn) + .get_results(&mut app_state.pool.get().await?) .await?; Ok((StatusCode::OK, Json(devices))) diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index 59d7a8e..899d6d2 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -4,9 +4,9 @@ use std::{ }; use axum::{ - Router, - routing::{delete, get, post}, + extract::{Request, State}, middleware::{from_fn_with_state, Next}, response::IntoResponse, routing::{delete, get, post}, Router }; +use axum_extra::{headers::{authorization::Bearer, Authorization}, TypedHeader}; use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; use serde::Serialize; @@ -30,18 +30,22 @@ pub struct Response { } -pub fn router() -> Router> { +pub fn router(app_state: Arc) -> Router> { + let router_with_auth = Router::new() + .route("/verify-email", get(verify_email::get)) + .route("/verify-email", post(verify_email::post)) + .route("/revoke", post(revoke::post)) + .route("/devices", get(devices::get)) + .layer(from_fn_with_state(app_state, CurrentUser::check_auth_layer)); + Router::new() .route("/register", post(register::post)) .route("/login", post(login::response)) .route("/logout", delete(logout::res)) .route("/refresh", post(refresh::post)) - .route("/revoke", post(revoke::post)) - .route("/verify-email", get(verify_email::get)) - .route("/verify-email", post(verify_email::post)) .route("/reset-password", get(reset_password::get)) .route("/reset-password", post(reset_password::post)) - .route("/devices", get(devices::get)) + .merge(router_with_auth) } pub async fn check_access_token(access_token: &str, conn: &mut Conn) -> Result { @@ -68,3 +72,20 @@ pub async fn check_access_token(access_token: &str, conn: &mut Conn) -> Result(pub Uuid); + +impl CurrentUser { + pub async fn check_auth_layer( + State(app_state): State>, + TypedHeader(auth): TypedHeader>, + mut req: Request, + next: Next + ) -> Result { + let current_user = CurrentUser(check_access_token(auth.token(), &mut app_state.pool.get().await?).await?); + + req.extensions_mut().insert(current_user); + Ok(next.run(req).await) + } +} diff --git a/src/api/v1/auth/revoke.rs b/src/api/v1/auth/revoke.rs index 50aa6d2..b59172e 100644 --- a/src/api/v1/auth/revoke.rs +++ b/src/api/v1/auth/revoke.rs @@ -1,21 +1,14 @@ use std::sync::Arc; use argon2::{PasswordHash, PasswordVerifier}; -use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; -use axum_extra::{ - TypedHeader, - headers::authorization::{Authorization, Bearer}, -}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; use diesel::{ExpressionMethods, QueryDsl, delete}; use diesel_async::RunQueryDsl; use serde::Deserialize; +use uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - schema::refresh_tokens::{self, dsl as rdsl}, - schema::users::dsl as udsl, + api::v1::auth::CurrentUser, error::Error, schema::{refresh_tokens::{self, dsl as rdsl}, users::dsl as udsl}, AppState }; #[derive(Deserialize)] @@ -28,13 +21,11 @@ pub struct RevokeRequest { #[axum::debug_handler] pub async fn post( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(revoke_request): Json, ) -> Result { let mut conn = app_state.pool.get().await?; - let uuid = check_access_token(auth.token(), &mut conn).await?; - let database_password: String = udsl::users .filter(udsl::uuid.eq(uuid)) .select(udsl::password) diff --git a/src/api/v1/auth/verify_email.rs b/src/api/v1/auth/verify_email.rs index 28aa1ab..1270966 100644 --- a/src/api/v1/auth/verify_email.rs +++ b/src/api/v1/auth/verify_email.rs @@ -5,20 +5,14 @@ use std::sync::Arc; use axum::{ extract::{Query, State}, http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + response::IntoResponse, Extension, }; use chrono::{Duration, Utc}; use serde::Deserialize; +use uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{EmailToken, Me}, + api::v1::auth::CurrentUser, error::Error, objects::{EmailToken, Me}, AppState }; #[derive(Deserialize)] @@ -47,12 +41,10 @@ pub struct QueryParams { pub async fn get( State(app_state): State>, Query(query): Query, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension> ) -> Result { let mut conn = app_state.pool.get().await?; - let uuid = check_access_token(auth.token(), &mut conn).await?; - let me = Me::get(&mut conn, uuid).await?; if me.email_verified { @@ -87,13 +79,9 @@ pub async fn get( /// pub async fn post( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension> ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - - let me = Me::get(&mut conn, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; if me.email_verified { return Ok(StatusCode::NO_CONTENT); diff --git a/src/api/v1/channels/mod.rs b/src/api/v1/channels/mod.rs index dc82b86..24b62f7 100644 --- a/src/api/v1/channels/mod.rs +++ b/src/api/v1/channels/mod.rs @@ -11,14 +11,9 @@ use crate::AppState; mod uuid; pub fn router() -> Router> { - //let (layer, io) = SocketIo::new_layer(); - - //io.ns("/{uuid}/socket", uuid::socket::ws); - Router::new() .route("/{uuid}", get(uuid::get)) .route("/{uuid}", delete(uuid::delete)) .route("/{uuid}", patch(uuid::patch)) .route("/{uuid}/messages", get(uuid::messages::get)) - //.layer(layer) } diff --git a/src/api/v1/channels/uuid/messages.rs b/src/api/v1/channels/uuid/messages.rs index 8c12ee0..0297bbc 100644 --- a/src/api/v1/channels/uuid/messages.rs +++ b/src/api/v1/channels/uuid/messages.rs @@ -3,22 +3,11 @@ use std::sync::Arc; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{Channel, Member}, - utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member}, utils::global_checks, AppState }; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, Query, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use serde::Deserialize; @@ -62,17 +51,13 @@ pub async fn get( State(app_state): State>, Path(channel_uuid): Path, Query(message_request): Query, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; let messages = channel .fetch_messages(&app_state, message_request.amount, message_request.offset) diff --git a/src/api/v1/channels/uuid/mod.rs b/src/api/v1/channels/uuid/mod.rs index 3ce91c3..c1560f0 100644 --- a/src/api/v1/channels/uuid/mod.rs +++ b/src/api/v1/channels/uuid/mod.rs @@ -7,38 +7,28 @@ use std::sync::Arc; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member, Permissions}, utils::global_checks, }; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; + use serde::Deserialize; use uuid::Uuid; pub async fn get( State(app_state): State>, Path(channel_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; Ok((StatusCode::OK, Json(channel))) } @@ -46,17 +36,13 @@ pub async fn get( pub async fn delete( State(app_state): State>, Path(channel_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) @@ -108,18 +94,14 @@ pub struct NewInfo { pub async fn patch( State(app_state): State>, Path(channel_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(new_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; let mut channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) diff --git a/src/api/v1/guilds/mod.rs b/src/api/v1/guilds/mod.rs index 18a117f..dbee589 100644 --- a/src/api/v1/guilds/mod.rs +++ b/src/api/v1/guilds/mod.rs @@ -3,26 +3,15 @@ use std::sync::Arc; use axum::{ - Json, Router, - extract::State, - http::StatusCode, - response::IntoResponse, - routing::{get, post}, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::State, http::StatusCode, response::IntoResponse, routing::{get, post}, Extension, Json, Router }; use serde::Deserialize; +use ::uuid::Uuid; mod uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{Guild, StartAmountQuery}, - utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::{Guild, StartAmountQuery}, utils::global_checks, AppState }; #[derive(Deserialize)] @@ -63,14 +52,10 @@ pub fn router() -> Router> { /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps pub async fn new( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(guild_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - - let guild = Guild::new(&mut conn, guild_info.name.clone(), uuid).await?; + let guild = Guild::new(&mut app_state.pool.get().await?, guild_info.name.clone(), uuid).await?; Ok((StatusCode::OK, Json(guild))) } @@ -124,15 +109,12 @@ pub async fn new( /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps pub async fn get_guilds( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(request_query): Json, ) -> Result { let start = request_query.start.unwrap_or(0); - let amount = request_query.amount.unwrap_or(10); - let uuid = check_access_token(auth.token(), &mut app_state.pool.get().await?).await?; - global_checks(&app_state, uuid).await?; let guilds = Guild::fetch_amount(&app_state.pool, start, amount).await?; diff --git a/src/api/v1/guilds/uuid/channels.rs b/src/api/v1/guilds/uuid/channels.rs index 0104566..a28aa6c 100644 --- a/src/api/v1/guilds/uuid/channels.rs +++ b/src/api/v1/guilds/uuid/channels.rs @@ -2,23 +2,12 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use serde::Deserialize; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{Channel, Member, Permissions}, - utils::{global_checks, order_by_is_above}, + api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member, Permissions}, utils::{global_checks, order_by_is_above}, AppState }; #[derive(Deserialize)] @@ -30,15 +19,11 @@ pub struct ChannelInfo { pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - Member::check_membership(&mut conn, uuid, guild_uuid).await?; + Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; if let Ok(cache_hit) = app_state .get_cache_key(format!("{guild_uuid}_channels")) @@ -65,16 +50,12 @@ pub async fn get( pub async fn create( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(channel_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; + let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) diff --git a/src/api/v1/guilds/uuid/invites/mod.rs b/src/api/v1/guilds/uuid/invites/mod.rs index 7703cf7..2070452 100644 --- a/src/api/v1/guilds/uuid/invites/mod.rs +++ b/src/api/v1/guilds/uuid/invites/mod.rs @@ -1,21 +1,14 @@ use std::sync::Arc; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use serde::Deserialize; use uuid::Uuid; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Guild, Member, Permissions}, utils::global_checks, @@ -29,14 +22,12 @@ pub struct InviteRequest { pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -49,15 +40,13 @@ pub async fn get( pub async fn create( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(invite_request): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member diff --git a/src/api/v1/guilds/uuid/members.rs b/src/api/v1/guilds/uuid/members.rs index bd2f853..6c8b980 100644 --- a/src/api/v1/guilds/uuid/members.rs +++ b/src/api/v1/guilds/uuid/members.rs @@ -2,19 +2,12 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Me, Member}, utils::global_checks, @@ -23,14 +16,12 @@ use crate::{ pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let me = Me::get(&mut conn, uuid).await?; diff --git a/src/api/v1/guilds/uuid/mod.rs b/src/api/v1/guilds/uuid/mod.rs index 0a27123..c5a809f 100644 --- a/src/api/v1/guilds/uuid/mod.rs +++ b/src/api/v1/guilds/uuid/mod.rs @@ -3,15 +3,7 @@ use std::sync::Arc; use axum::{ - Json, Router, - extract::{Multipart, Path, State}, - http::StatusCode, - response::IntoResponse, - routing::{get, patch, post}, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Multipart, Path, State}, http::StatusCode, response::IntoResponse, routing::{get, patch, post}, Extension, Json, Router }; use bytes::Bytes; use uuid::Uuid; @@ -23,7 +15,7 @@ mod roles; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Guild, Member, Permissions}, utils::global_checks, @@ -84,14 +76,12 @@ pub fn router() -> Router> { pub async fn get_guild( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -105,15 +95,13 @@ pub async fn get_guild( pub async fn edit( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, mut multipart: Multipart, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member diff --git a/src/api/v1/guilds/uuid/roles/mod.rs b/src/api/v1/guilds/uuid/roles/mod.rs index 12960c2..5331143 100644 --- a/src/api/v1/guilds/uuid/roles/mod.rs +++ b/src/api/v1/guilds/uuid/roles/mod.rs @@ -2,20 +2,13 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use serde::Deserialize; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Member, Permissions, Role}, utils::{global_checks, order_by_is_above}, @@ -31,11 +24,11 @@ pub struct RoleInfo { pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - let uuid = check_access_token(auth.token(), &mut conn).await?; + let mut conn = app_state.pool.get().await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; @@ -57,15 +50,13 @@ pub async fn get( pub async fn create( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(role_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member diff --git a/src/api/v1/guilds/uuid/roles/uuid.rs b/src/api/v1/guilds/uuid/roles/uuid.rs index a62a5b4..91300bf 100644 --- a/src/api/v1/guilds/uuid/roles/uuid.rs +++ b/src/api/v1/guilds/uuid/roles/uuid.rs @@ -2,19 +2,12 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Member, Role}, utils::global_checks, @@ -23,14 +16,12 @@ use crate::{ pub async fn get( State(app_state): State>, Path((guild_uuid, role_uuid)): Path<(Uuid, Uuid)>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; if let Ok(cache_hit) = app_state.get_cache_key(format!("{role_uuid}")).await { diff --git a/src/api/v1/invites/id.rs b/src/api/v1/invites/id.rs index b832557..c752177 100644 --- a/src/api/v1/invites/id.rs +++ b/src/api/v1/invites/id.rs @@ -1,19 +1,13 @@ use std::sync::Arc; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; +use uuid::Uuid; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Guild, Invite, Member}, utils::global_checks, @@ -35,14 +29,12 @@ pub async fn get( pub async fn join( State(app_state): State>, Path(invite_id): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let invite = Invite::fetch_one(&mut conn, invite_id).await?; let guild = Guild::fetch_one(&mut conn, invite.guild_uuid).await?; diff --git a/src/api/v1/me/friends/mod.rs b/src/api/v1/me/friends/mod.rs index 8a7851c..63284a8 100644 --- a/src/api/v1/me/friends/mod.rs +++ b/src/api/v1/me/friends/mod.rs @@ -1,34 +1,23 @@ use std::sync::Arc; -use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, -}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; use serde::Deserialize; +use ::uuid::Uuid; pub mod uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::Me, - utils::{global_checks, user_uuid_from_username}, + api::v1::auth::CurrentUser, error::Error, objects::Me, utils::{global_checks, user_uuid_from_username}, AppState }; /// Returns a list of users that are your friends pub async fn get( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - let me = Me::get(&mut conn, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; let friends = me.get_friends(&app_state).await?; @@ -61,15 +50,13 @@ pub struct UserReq { /// pub async fn post( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(user_request): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let me = Me::get(&mut conn, uuid).await?; let target_uuid = user_uuid_from_username(&mut conn, &user_request.username).await?; diff --git a/src/api/v1/me/friends/uuid.rs b/src/api/v1/me/friends/uuid.rs index 8d40f26..5a32386 100644 --- a/src/api/v1/me/friends/uuid.rs +++ b/src/api/v1/me/friends/uuid.rs @@ -3,29 +3,23 @@ use std::sync::Arc; use axum::{ extract::{Path, State}, http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + response::IntoResponse, Extension, }; use uuid::Uuid; use crate::{ - AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, + AppState, api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, }; pub async fn delete( State(app_state): State>, Path(friend_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let me = Me::get(&mut conn, uuid).await?; me.remove_friend(&mut conn, friend_uuid).await?; diff --git a/src/api/v1/me/guilds.rs b/src/api/v1/me/guilds.rs index adfe845..a2d2111 100644 --- a/src/api/v1/me/guilds.rs +++ b/src/api/v1/me/guilds.rs @@ -2,14 +2,11 @@ use std::sync::Arc; -use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, -}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; +use uuid::Uuid; use crate::{ - AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, + AppState, api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, }; /// `GET /api/v1/me/guilds` Returns all guild memberships in a list @@ -59,14 +56,12 @@ use crate::{ /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps pub async fn get( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let me = Me::get(&mut conn, uuid).await?; let memberships = me.fetch_memberships(&mut conn).await?; diff --git a/src/api/v1/me/mod.rs b/src/api/v1/me/mod.rs index e9680bc..ce577d4 100644 --- a/src/api/v1/me/mod.rs +++ b/src/api/v1/me/mod.rs @@ -1,21 +1,14 @@ use std::sync::Arc; use axum::{ - Json, Router, - extract::{DefaultBodyLimit, Multipart, State}, - http::StatusCode, - response::IntoResponse, - routing::{delete, get, patch, post}, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{DefaultBodyLimit, Multipart, State}, http::StatusCode, response::IntoResponse, routing::{delete, get, patch, post}, Extension, Json, Router }; use bytes::Bytes; use serde::Deserialize; +use uuid::Uuid; use crate::{ - AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, AppState }; mod friends; @@ -38,13 +31,9 @@ pub fn router() -> Router> { pub async fn get_me( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - - let me = Me::get(&mut conn, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; Ok((StatusCode::OK, Json(me))) } @@ -60,13 +49,9 @@ struct NewInfo { pub async fn update( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, mut multipart: Multipart, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - let mut json_raw: Option = None; let mut avatar: Option = None; @@ -88,7 +73,7 @@ pub async fn update( global_checks(&app_state, uuid).await?; } - let mut me = Me::get(&mut conn, uuid).await?; + let mut me = Me::get(&mut app_state.pool.get().await?, uuid).await?; if let Some(avatar) = avatar { me.set_avatar(&app_state, app_state.config.bunny.cdn_url.clone(), avatar) diff --git a/src/api/v1/mod.rs b/src/api/v1/mod.rs index 4e8654b..f3e4305 100644 --- a/src/api/v1/mod.rs +++ b/src/api/v1/mod.rs @@ -2,9 +2,9 @@ use std::sync::Arc; -use axum::{routing::get, Router}; +use axum::{middleware::from_fn_with_state, routing::get, Router}; -use crate::AppState; +use crate::{api::v1::auth::CurrentUser, AppState}; mod auth; mod channels; @@ -14,13 +14,17 @@ mod me; mod stats; mod users; -pub fn router() -> Router> { - Router::new() - .route("/stats", get(stats::res)) - .nest("/auth", auth::router()) +pub fn router(app_state: Arc) -> Router> { + let router_with_auth = Router::new() .nest("/users", users::router()) .nest("/channels", channels::router()) .nest("/guilds", guilds::router()) .nest("/invites", invites::router()) .nest("/me", me::router()) + .layer(from_fn_with_state(app_state.clone(), CurrentUser::check_auth_layer)); + + Router::new() + .route("/stats", get(stats::res)) + .nest("/auth", auth::router(app_state)) + .merge(router_with_auth) } diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index f0d09c5..82f2125 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -3,23 +3,12 @@ use std::sync::Arc; use axum::{ - Json, Router, - extract::{Query, State}, - http::StatusCode, - response::IntoResponse, - routing::get, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Query, State}, http::StatusCode, response::IntoResponse, routing::get, Extension, Json, Router }; +use ::uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{StartAmountQuery, User}, - utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::{StartAmountQuery, User}, utils::global_checks, AppState }; mod uuid; @@ -63,7 +52,7 @@ pub fn router() -> Router> { pub async fn users( State(app_state): State>, Query(request_query): Query, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { let start = request_query.start.unwrap_or(0); @@ -73,13 +62,9 @@ pub async fn users( return Ok(StatusCode::BAD_REQUEST.into_response()); } - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - let users = User::fetch_amount(&mut conn, start, amount).await?; + let users = User::fetch_amount(&mut app_state.pool.get().await?, start, amount).await?; Ok((StatusCode::OK, Json(users)).into_response()) } diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index 1b7d43b..2bdcfac 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -3,23 +3,12 @@ use std::sync::Arc; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{Me, User}, - utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::{Me, User}, utils::global_checks, AppState }; /// `GET /api/v1/users/{uuid}` Returns user with the given UUID @@ -41,15 +30,11 @@ use crate::{ pub async fn get( State(app_state): State>, Path(user_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - let me = Me::get(&mut conn, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; let user = User::fetch_one_with_friendship(&app_state, &me, user_uuid).await?; diff --git a/src/main.rs b/src/main.rs index ab37924..8e6effc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -163,7 +163,7 @@ async fn main() -> Result<(), Error> { // build our application with a route let app = Router::new() // `GET /` goes to `root` - .merge(api::router(web.backend_url.path().trim_end_matches("/"))) + .merge(api::router(web.backend_url.path().trim_end_matches("/"), app_state.clone())) .with_state(app_state) .layer(cors) .layer(socket_io);