diff --git a/src/api/mod.rs b/src/api/mod.rs index 988ee45..a00d1e5 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -2,15 +2,15 @@ use std::sync::Arc; -use axum::{routing::get, Router}; +use axum::{Router, routing::get}; use crate::AppState; mod v1; mod versions; -pub fn router(path: &str, app_state: Arc) -> Router> { +pub fn router(path: &str) -> Router> { Router::new() .route(&format!("{path}/versions"), get(versions::versions)) - .nest(&format!("{path}/v1"), v1::router(app_state)) + .nest(&format!("{path}/v1"), v1::router()) } diff --git a/src/api/v1/auth/devices.rs b/src/api/v1/auth/devices.rs index 336a52f..a3c12d1 100644 --- a/src/api/v1/auth/devices.rs +++ b/src/api/v1/auth/devices.rs @@ -2,14 +2,20 @@ use std::sync::Arc; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; +use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use diesel::{ExpressionMethods, QueryDsl, Queryable, Selectable, SelectableHelper}; use diesel_async::RunQueryDsl; use serde::Serialize; -use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, schema::refresh_tokens::{self, dsl}, AppState + AppState, + api::v1::auth::check_access_token, + error::Error, + schema::refresh_tokens::{self, dsl}, }; #[derive(Serialize, Selectable, Queryable)] @@ -36,12 +42,16 @@ struct Device { /// ``` pub async fn get( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + let devices: Vec = dsl::refresh_tokens .filter(dsl::uuid.eq(uuid)) .select(Device::as_select()) - .get_results(&mut app_state.pool.get().await?) + .get_results(&mut conn) .await?; Ok((StatusCode::OK, Json(devices))) diff --git a/src/api/v1/auth/logout.rs b/src/api/v1/auth/logout.rs index 977d452..906afcc 100644 --- a/src/api/v1/auth/logout.rs +++ b/src/api/v1/auth/logout.rs @@ -38,6 +38,8 @@ pub async fn res( ))? .to_owned(); + let access_token_cookie = jar.get("access_token"); + let refresh_token = String::from(refresh_token_cookie.value_trimmed()); let mut conn = app_state.pool.get().await?; @@ -61,5 +63,13 @@ pub async fn res( HeaderValue::from_str(&refresh_token_cookie.to_string())?, ); + if let Some(cookie) = access_token_cookie { + let mut cookie = cookie.clone(); + cookie.make_removal(); + response + .headers_mut() + .append("Set-Cookie", HeaderValue::from_str(&cookie.to_string())?); + } + Ok(response) } diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index 899d6d2..59d7a8e 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -4,9 +4,9 @@ use std::{ }; use axum::{ - extract::{Request, State}, middleware::{from_fn_with_state, Next}, response::IntoResponse, routing::{delete, get, post}, Router + Router, + routing::{delete, get, post}, }; -use axum_extra::{headers::{authorization::Bearer, Authorization}, TypedHeader}; use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; use serde::Serialize; @@ -30,22 +30,18 @@ pub struct Response { } -pub fn router(app_state: Arc) -> Router> { - let router_with_auth = Router::new() - .route("/verify-email", get(verify_email::get)) - .route("/verify-email", post(verify_email::post)) - .route("/revoke", post(revoke::post)) - .route("/devices", get(devices::get)) - .layer(from_fn_with_state(app_state, CurrentUser::check_auth_layer)); - +pub fn router() -> Router> { Router::new() .route("/register", post(register::post)) .route("/login", post(login::response)) .route("/logout", delete(logout::res)) .route("/refresh", post(refresh::post)) + .route("/revoke", post(revoke::post)) + .route("/verify-email", get(verify_email::get)) + .route("/verify-email", post(verify_email::post)) .route("/reset-password", get(reset_password::get)) .route("/reset-password", post(reset_password::post)) - .merge(router_with_auth) + .route("/devices", get(devices::get)) } pub async fn check_access_token(access_token: &str, conn: &mut Conn) -> Result { @@ -72,20 +68,3 @@ pub async fn check_access_token(access_token: &str, conn: &mut Conn) -> Result(pub Uuid); - -impl CurrentUser { - pub async fn check_auth_layer( - State(app_state): State>, - TypedHeader(auth): TypedHeader>, - mut req: Request, - next: Next - ) -> Result { - let current_user = CurrentUser(check_access_token(auth.token(), &mut app_state.pool.get().await?).await?); - - req.extensions_mut().insert(current_user); - Ok(next.run(req).await) - } -} diff --git a/src/api/v1/auth/revoke.rs b/src/api/v1/auth/revoke.rs index b59172e..50aa6d2 100644 --- a/src/api/v1/auth/revoke.rs +++ b/src/api/v1/auth/revoke.rs @@ -1,14 +1,21 @@ use std::sync::Arc; use argon2::{PasswordHash, PasswordVerifier}; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; +use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; +use axum_extra::{ + TypedHeader, + headers::authorization::{Authorization, Bearer}, +}; use diesel::{ExpressionMethods, QueryDsl, delete}; use diesel_async::RunQueryDsl; use serde::Deserialize; -use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, schema::{refresh_tokens::{self, dsl as rdsl}, users::dsl as udsl}, AppState + AppState, + api::v1::auth::check_access_token, + error::Error, + schema::refresh_tokens::{self, dsl as rdsl}, + schema::users::dsl as udsl, }; #[derive(Deserialize)] @@ -21,11 +28,13 @@ pub struct RevokeRequest { #[axum::debug_handler] pub async fn post( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, Json(revoke_request): Json, ) -> Result { let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + let database_password: String = udsl::users .filter(udsl::uuid.eq(uuid)) .select(udsl::password) diff --git a/src/api/v1/auth/verify_email.rs b/src/api/v1/auth/verify_email.rs index 1270966..28aa1ab 100644 --- a/src/api/v1/auth/verify_email.rs +++ b/src/api/v1/auth/verify_email.rs @@ -5,14 +5,20 @@ use std::sync::Arc; use axum::{ extract::{Query, State}, http::StatusCode, - response::IntoResponse, Extension, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use chrono::{Duration, Utc}; use serde::Deserialize; -use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{EmailToken, Me}, AppState + AppState, + api::v1::auth::check_access_token, + error::Error, + objects::{EmailToken, Me}, }; #[derive(Deserialize)] @@ -41,10 +47,12 @@ pub struct QueryParams { pub async fn get( State(app_state): State>, Query(query): Query, - Extension(CurrentUser(uuid)): Extension> + TypedHeader(auth): TypedHeader>, ) -> Result { let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + let me = Me::get(&mut conn, uuid).await?; if me.email_verified { @@ -79,9 +87,13 @@ pub async fn get( /// pub async fn post( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension> + TypedHeader(auth): TypedHeader>, ) -> Result { - let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + + let me = Me::get(&mut conn, uuid).await?; if me.email_verified { return Ok(StatusCode::NO_CONTENT); diff --git a/src/api/v1/channels/mod.rs b/src/api/v1/channels/mod.rs index 24b62f7..dc82b86 100644 --- a/src/api/v1/channels/mod.rs +++ b/src/api/v1/channels/mod.rs @@ -11,9 +11,14 @@ use crate::AppState; mod uuid; pub fn router() -> Router> { + //let (layer, io) = SocketIo::new_layer(); + + //io.ns("/{uuid}/socket", uuid::socket::ws); + Router::new() .route("/{uuid}", get(uuid::get)) .route("/{uuid}", delete(uuid::delete)) .route("/{uuid}", patch(uuid::patch)) .route("/{uuid}/messages", get(uuid::messages::get)) + //.layer(layer) } diff --git a/src/api/v1/channels/uuid/messages.rs b/src/api/v1/channels/uuid/messages.rs index 0297bbc..8c12ee0 100644 --- a/src/api/v1/channels/uuid/messages.rs +++ b/src/api/v1/channels/uuid/messages.rs @@ -3,11 +3,22 @@ use std::sync::Arc; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member}, utils::global_checks, AppState + AppState, + api::v1::auth::check_access_token, + error::Error, + objects::{Channel, Member}, + utils::global_checks, }; use ::uuid::Uuid; use axum::{ - extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, Extension, Json + Json, + extract::{Path, Query, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use serde::Deserialize; @@ -51,13 +62,17 @@ pub async fn get( State(app_state): State>, Path(channel_uuid): Path, Query(message_request): Query, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + global_checks(&app_state, uuid).await?; let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; + Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; let messages = channel .fetch_messages(&app_state, message_request.amount, message_request.offset) diff --git a/src/api/v1/channels/uuid/mod.rs b/src/api/v1/channels/uuid/mod.rs index c1560f0..3ce91c3 100644 --- a/src/api/v1/channels/uuid/mod.rs +++ b/src/api/v1/channels/uuid/mod.rs @@ -7,28 +7,38 @@ use std::sync::Arc; use crate::{ AppState, - api::v1::auth::CurrentUser, + api::v1::auth::check_access_token, error::Error, objects::{Channel, Member, Permissions}, utils::global_checks, }; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; - use serde::Deserialize; use uuid::Uuid; pub async fn get( State(app_state): State>, Path(channel_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + global_checks(&app_state, uuid).await?; let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; + Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; Ok((StatusCode::OK, Json(channel))) } @@ -36,13 +46,17 @@ pub async fn get( pub async fn delete( State(app_state): State>, Path(channel_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + global_checks(&app_state, uuid).await?; let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; + let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) @@ -94,14 +108,18 @@ pub struct NewInfo { pub async fn patch( State(app_state): State>, Path(channel_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, Json(new_info): Json, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + global_checks(&app_state, uuid).await?; let mut channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; + let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) diff --git a/src/api/v1/guilds/mod.rs b/src/api/v1/guilds/mod.rs index dbee589..18a117f 100644 --- a/src/api/v1/guilds/mod.rs +++ b/src/api/v1/guilds/mod.rs @@ -3,15 +3,26 @@ use std::sync::Arc; use axum::{ - extract::State, http::StatusCode, response::IntoResponse, routing::{get, post}, Extension, Json, Router + Json, Router, + extract::State, + http::StatusCode, + response::IntoResponse, + routing::{get, post}, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use serde::Deserialize; -use ::uuid::Uuid; mod uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{Guild, StartAmountQuery}, utils::global_checks, AppState + AppState, + api::v1::auth::check_access_token, + error::Error, + objects::{Guild, StartAmountQuery}, + utils::global_checks, }; #[derive(Deserialize)] @@ -52,10 +63,14 @@ pub fn router() -> Router> { /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps pub async fn new( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, Json(guild_info): Json, ) -> Result { - let guild = Guild::new(&mut app_state.pool.get().await?, guild_info.name.clone(), uuid).await?; + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + + let guild = Guild::new(&mut conn, guild_info.name.clone(), uuid).await?; Ok((StatusCode::OK, Json(guild))) } @@ -109,12 +124,15 @@ pub async fn new( /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps pub async fn get_guilds( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, Json(request_query): Json, ) -> Result { let start = request_query.start.unwrap_or(0); + let amount = request_query.amount.unwrap_or(10); + let uuid = check_access_token(auth.token(), &mut app_state.pool.get().await?).await?; + global_checks(&app_state, uuid).await?; let guilds = Guild::fetch_amount(&app_state.pool, start, amount).await?; diff --git a/src/api/v1/guilds/uuid/channels.rs b/src/api/v1/guilds/uuid/channels.rs index a28aa6c..0104566 100644 --- a/src/api/v1/guilds/uuid/channels.rs +++ b/src/api/v1/guilds/uuid/channels.rs @@ -2,12 +2,23 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use serde::Deserialize; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member, Permissions}, utils::{global_checks, order_by_is_above}, AppState + AppState, + api::v1::auth::check_access_token, + error::Error, + objects::{Channel, Member, Permissions}, + utils::{global_checks, order_by_is_above}, }; #[derive(Deserialize)] @@ -19,11 +30,15 @@ pub struct ChannelInfo { pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + global_checks(&app_state, uuid).await?; - Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; if let Ok(cache_hit) = app_state .get_cache_key(format!("{guild_uuid}_channels")) @@ -50,12 +65,16 @@ pub async fn get( pub async fn create( State(app_state): State>, Path(guild_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, Json(channel_info): Json, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + global_checks(&app_state, uuid).await?; - let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) diff --git a/src/api/v1/guilds/uuid/invites/mod.rs b/src/api/v1/guilds/uuid/invites/mod.rs index 2070452..7703cf7 100644 --- a/src/api/v1/guilds/uuid/invites/mod.rs +++ b/src/api/v1/guilds/uuid/invites/mod.rs @@ -1,14 +1,21 @@ use std::sync::Arc; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use serde::Deserialize; use uuid::Uuid; use crate::{ AppState, - api::v1::auth::CurrentUser, + api::v1::auth::check_access_token, error::Error, objects::{Guild, Member, Permissions}, utils::global_checks, @@ -22,12 +29,14 @@ pub struct InviteRequest { pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -40,13 +49,15 @@ pub async fn get( pub async fn create( State(app_state): State>, Path(guild_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, Json(invite_request): Json, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member diff --git a/src/api/v1/guilds/uuid/members.rs b/src/api/v1/guilds/uuid/members.rs index 6c8b980..bd2f853 100644 --- a/src/api/v1/guilds/uuid/members.rs +++ b/src/api/v1/guilds/uuid/members.rs @@ -2,12 +2,19 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use crate::{ AppState, - api::v1::auth::CurrentUser, + api::v1::auth::check_access_token, error::Error, objects::{Me, Member}, utils::global_checks, @@ -16,12 +23,14 @@ use crate::{ pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let me = Me::get(&mut conn, uuid).await?; diff --git a/src/api/v1/guilds/uuid/mod.rs b/src/api/v1/guilds/uuid/mod.rs index c5a809f..0a27123 100644 --- a/src/api/v1/guilds/uuid/mod.rs +++ b/src/api/v1/guilds/uuid/mod.rs @@ -3,7 +3,15 @@ use std::sync::Arc; use axum::{ - extract::{Multipart, Path, State}, http::StatusCode, response::IntoResponse, routing::{get, patch, post}, Extension, Json, Router + Json, Router, + extract::{Multipart, Path, State}, + http::StatusCode, + response::IntoResponse, + routing::{get, patch, post}, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use bytes::Bytes; use uuid::Uuid; @@ -15,7 +23,7 @@ mod roles; use crate::{ AppState, - api::v1::auth::CurrentUser, + api::v1::auth::check_access_token, error::Error, objects::{Guild, Member, Permissions}, utils::global_checks, @@ -76,12 +84,14 @@ pub fn router() -> Router> { pub async fn get_guild( State(app_state): State>, Path(guild_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -95,13 +105,15 @@ pub async fn get_guild( pub async fn edit( State(app_state): State>, Path(guild_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, mut multipart: Multipart, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member diff --git a/src/api/v1/guilds/uuid/roles/mod.rs b/src/api/v1/guilds/uuid/roles/mod.rs index 5331143..12960c2 100644 --- a/src/api/v1/guilds/uuid/roles/mod.rs +++ b/src/api/v1/guilds/uuid/roles/mod.rs @@ -2,13 +2,20 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use serde::Deserialize; use crate::{ AppState, - api::v1::auth::CurrentUser, + api::v1::auth::check_access_token, error::Error, objects::{Member, Permissions, Role}, utils::{global_checks, order_by_is_above}, @@ -24,12 +31,12 @@ pub struct RoleInfo { pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; if let Ok(cache_hit) = app_state.get_cache_key(format!("{guild_uuid}_roles")).await { @@ -50,13 +57,15 @@ pub async fn get( pub async fn create( State(app_state): State>, Path(guild_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, Json(role_info): Json, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member diff --git a/src/api/v1/guilds/uuid/roles/uuid.rs b/src/api/v1/guilds/uuid/roles/uuid.rs index 91300bf..a62a5b4 100644 --- a/src/api/v1/guilds/uuid/roles/uuid.rs +++ b/src/api/v1/guilds/uuid/roles/uuid.rs @@ -2,12 +2,19 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use crate::{ AppState, - api::v1::auth::CurrentUser, + api::v1::auth::check_access_token, error::Error, objects::{Member, Role}, utils::global_checks, @@ -16,12 +23,14 @@ use crate::{ pub async fn get( State(app_state): State>, Path((guild_uuid, role_uuid)): Path<(Uuid, Uuid)>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; if let Ok(cache_hit) = app_state.get_cache_key(format!("{role_uuid}")).await { diff --git a/src/api/v1/invites/id.rs b/src/api/v1/invites/id.rs index c752177..b832557 100644 --- a/src/api/v1/invites/id.rs +++ b/src/api/v1/invites/id.rs @@ -1,13 +1,19 @@ use std::sync::Arc; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; -use uuid::Uuid; use crate::{ AppState, - api::v1::auth::CurrentUser, + api::v1::auth::check_access_token, error::Error, objects::{Guild, Invite, Member}, utils::global_checks, @@ -29,12 +35,14 @@ pub async fn get( pub async fn join( State(app_state): State>, Path(invite_id): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + let invite = Invite::fetch_one(&mut conn, invite_id).await?; let guild = Guild::fetch_one(&mut conn, invite.guild_uuid).await?; diff --git a/src/api/v1/me/friends/mod.rs b/src/api/v1/me/friends/mod.rs index 63284a8..8a7851c 100644 --- a/src/api/v1/me/friends/mod.rs +++ b/src/api/v1/me/friends/mod.rs @@ -1,23 +1,34 @@ use std::sync::Arc; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; +use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use serde::Deserialize; -use ::uuid::Uuid; pub mod uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::Me, utils::{global_checks, user_uuid_from_username}, AppState + AppState, + api::v1::auth::check_access_token, + error::Error, + objects::Me, + utils::{global_checks, user_uuid_from_username}, }; /// Returns a list of users that are your friends pub async fn get( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + global_checks(&app_state, uuid).await?; - let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; + let me = Me::get(&mut conn, uuid).await?; let friends = me.get_friends(&app_state).await?; @@ -50,13 +61,15 @@ pub struct UserReq { /// pub async fn post( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, Json(user_request): Json, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + let me = Me::get(&mut conn, uuid).await?; let target_uuid = user_uuid_from_username(&mut conn, &user_request.username).await?; diff --git a/src/api/v1/me/friends/uuid.rs b/src/api/v1/me/friends/uuid.rs index 5a32386..8d40f26 100644 --- a/src/api/v1/me/friends/uuid.rs +++ b/src/api/v1/me/friends/uuid.rs @@ -3,23 +3,29 @@ use std::sync::Arc; use axum::{ extract::{Path, State}, http::StatusCode, - response::IntoResponse, Extension, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use uuid::Uuid; use crate::{ - AppState, api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, + AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, }; pub async fn delete( State(app_state): State>, Path(friend_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + let me = Me::get(&mut conn, uuid).await?; me.remove_friend(&mut conn, friend_uuid).await?; diff --git a/src/api/v1/me/guilds.rs b/src/api/v1/me/guilds.rs index a2d2111..adfe845 100644 --- a/src/api/v1/me/guilds.rs +++ b/src/api/v1/me/guilds.rs @@ -2,11 +2,14 @@ use std::sync::Arc; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; -use uuid::Uuid; +use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use crate::{ - AppState, api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, + AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, }; /// `GET /api/v1/me/guilds` Returns all guild memberships in a list @@ -56,12 +59,14 @@ use crate::{ /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps pub async fn get( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { - global_checks(&app_state, uuid).await?; - let mut conn = app_state.pool.get().await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + let me = Me::get(&mut conn, uuid).await?; let memberships = me.fetch_memberships(&mut conn).await?; diff --git a/src/api/v1/me/mod.rs b/src/api/v1/me/mod.rs index ce577d4..e9680bc 100644 --- a/src/api/v1/me/mod.rs +++ b/src/api/v1/me/mod.rs @@ -1,14 +1,21 @@ use std::sync::Arc; use axum::{ - extract::{DefaultBodyLimit, Multipart, State}, http::StatusCode, response::IntoResponse, routing::{delete, get, patch, post}, Extension, Json, Router + Json, Router, + extract::{DefaultBodyLimit, Multipart, State}, + http::StatusCode, + response::IntoResponse, + routing::{delete, get, patch, post}, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use bytes::Bytes; use serde::Deserialize; -use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, AppState + AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, }; mod friends; @@ -31,9 +38,13 @@ pub fn router() -> Router> { pub async fn get_me( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { - let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + + let me = Me::get(&mut conn, uuid).await?; Ok((StatusCode::OK, Json(me))) } @@ -49,9 +60,13 @@ struct NewInfo { pub async fn update( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, mut multipart: Multipart, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + let mut json_raw: Option = None; let mut avatar: Option = None; @@ -73,7 +88,7 @@ pub async fn update( global_checks(&app_state, uuid).await?; } - let mut me = Me::get(&mut app_state.pool.get().await?, uuid).await?; + let mut me = Me::get(&mut conn, uuid).await?; if let Some(avatar) = avatar { me.set_avatar(&app_state, app_state.config.bunny.cdn_url.clone(), avatar) diff --git a/src/api/v1/mod.rs b/src/api/v1/mod.rs index f3e4305..4e8654b 100644 --- a/src/api/v1/mod.rs +++ b/src/api/v1/mod.rs @@ -2,9 +2,9 @@ use std::sync::Arc; -use axum::{middleware::from_fn_with_state, routing::get, Router}; +use axum::{routing::get, Router}; -use crate::{api::v1::auth::CurrentUser, AppState}; +use crate::AppState; mod auth; mod channels; @@ -14,17 +14,13 @@ mod me; mod stats; mod users; -pub fn router(app_state: Arc) -> Router> { - let router_with_auth = Router::new() +pub fn router() -> Router> { + Router::new() + .route("/stats", get(stats::res)) + .nest("/auth", auth::router()) .nest("/users", users::router()) .nest("/channels", channels::router()) .nest("/guilds", guilds::router()) .nest("/invites", invites::router()) .nest("/me", me::router()) - .layer(from_fn_with_state(app_state.clone(), CurrentUser::check_auth_layer)); - - Router::new() - .route("/stats", get(stats::res)) - .nest("/auth", auth::router(app_state)) - .merge(router_with_auth) } diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index 82f2125..f0d09c5 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -3,12 +3,23 @@ use std::sync::Arc; use axum::{ - extract::{Query, State}, http::StatusCode, response::IntoResponse, routing::get, Extension, Json, Router + Json, Router, + extract::{Query, State}, + http::StatusCode, + response::IntoResponse, + routing::get, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; -use ::uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{StartAmountQuery, User}, utils::global_checks, AppState + AppState, + api::v1::auth::check_access_token, + error::Error, + objects::{StartAmountQuery, User}, + utils::global_checks, }; mod uuid; @@ -52,7 +63,7 @@ pub fn router() -> Router> { pub async fn users( State(app_state): State>, Query(request_query): Query, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { let start = request_query.start.unwrap_or(0); @@ -62,9 +73,13 @@ pub async fn users( return Ok(StatusCode::BAD_REQUEST.into_response()); } + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + global_checks(&app_state, uuid).await?; - let users = User::fetch_amount(&mut app_state.pool.get().await?, start, amount).await?; + let users = User::fetch_amount(&mut conn, start, amount).await?; Ok((StatusCode::OK, Json(users)).into_response()) } diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index 2bdcfac..1b7d43b 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -3,12 +3,23 @@ use std::sync::Arc; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{Me, User}, utils::global_checks, AppState + AppState, + api::v1::auth::check_access_token, + error::Error, + objects::{Me, User}, + utils::global_checks, }; /// `GET /api/v1/users/{uuid}` Returns user with the given UUID @@ -30,11 +41,15 @@ use crate::{ pub async fn get( State(app_state): State>, Path(user_uuid): Path, - Extension(CurrentUser(uuid)): Extension>, + TypedHeader(auth): TypedHeader>, ) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + global_checks(&app_state, uuid).await?; - let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; + let me = Me::get(&mut conn, uuid).await?; let user = User::fetch_one_with_friendship(&app_state, &me, user_uuid).await?; diff --git a/src/main.rs b/src/main.rs index 8e6effc..ab37924 100644 --- a/src/main.rs +++ b/src/main.rs @@ -163,7 +163,7 @@ async fn main() -> Result<(), Error> { // build our application with a route let app = Router::new() // `GET /` goes to `root` - .merge(api::router(web.backend_url.path().trim_end_matches("/"), app_state.clone())) + .merge(api::router(web.backend_url.path().trim_end_matches("/"))) .with_state(app_state) .layer(cors) .layer(socket_io);