diff --git a/src/api/mod.rs b/src/api/mod.rs index b79c824..80dc442 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,11 +1,2 @@ -use actix_web::Scope; -use actix_web::web; - -mod v1; -mod versions; - -pub fn web() -> Scope { - web::scope("/api") - .service(v1::web()) - .service(versions::res) -} +pub mod v1; +pub mod versions; diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 0ea3d83..3be5474 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -1,17 +1,17 @@ use std::time::{SystemTime, UNIX_EPOCH}; -use actix_web::{error, post, web, Error, HttpResponse}; +use actix_web::{Error, HttpResponse, error, post, web}; use argon2::{PasswordHash, PasswordVerifier}; use futures::StreamExt; use log::error; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use crate::{ - api::v1::auth::{EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX}, crypto::{generate_access_token, generate_refresh_token}, utils::refresh_token_cookie, Data + Data, + api::v1::auth::{EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX}, + crypto::{generate_access_token, generate_refresh_token}, }; -use super::Response; - #[derive(Deserialize)] struct LoginInformation { username: String, @@ -19,6 +19,12 @@ struct LoginInformation { device_name: String, } +#[derive(Serialize)] +pub struct Response { + pub access_token: String, + pub refresh_token: String, +} + const MAX_SIZE: usize = 262_144; #[post("/login")] @@ -154,7 +160,7 @@ async fn login( .as_secs() as i64; if let Err(error) = sqlx::query(&format!( - "INSERT INTO refresh_tokens (token, uuid, created_at, device_name) VALUES ($1, '{}', $2, $3 )", + "INSERT INTO refresh_tokens (token, uuid, created, device_name) VALUES ($1, '{}', $2, $3 )", uuid )) .bind(&refresh_token) @@ -168,7 +174,7 @@ async fn login( } if let Err(error) = sqlx::query(&format!( - "INSERT INTO access_tokens (token, refresh_token, uuid, created_at) VALUES ($1, $2, '{}', $3 )", + "INSERT INTO access_tokens (token, refresh_token, uuid, created) VALUES ($1, $2, '{}', $3 )", uuid )) .bind(&access_token) @@ -181,7 +187,8 @@ async fn login( return HttpResponse::InternalServerError().finish() } - HttpResponse::Ok().cookie(refresh_token_cookie(refresh_token)).json(Response { + HttpResponse::Ok().json(Response { access_token, + refresh_token, }) } diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index bfd32af..ff74c6b 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -7,7 +7,6 @@ use std::{ use actix_web::{HttpResponse, Scope, web}; use log::error; use regex::Regex; -use serde::Serialize; use sqlx::Postgres; use uuid::Uuid; @@ -16,11 +15,6 @@ mod refresh; mod register; mod revoke; -#[derive(Serialize)] -struct Response { - access_token: String, -} - static EMAIL_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"[-A-Za-z0-9!#$%&'*+/=?^_`{|}~]+(?:\.[-A-Za-z0-9!#$%&'*+/=?^_`{|}~]+)*@(?:[A-Za-z0-9](?:[-A-Za-z0-9]*[A-Za-z0-9])?\.)+[A-Za-z0-9](?:[-A-Za-z0-9]*[A-Za-z0-9])?").unwrap() }); @@ -40,36 +34,33 @@ pub fn web() -> Scope { } pub async fn check_access_token( - access_token: &str, + access_token: String, pool: &sqlx::Pool, ) -> Result { - let row = - sqlx::query_as("SELECT CAST(uuid as VARCHAR), created_at FROM access_tokens WHERE token = $1") - .bind(&access_token) - .fetch_one(pool) - .await; + let row = sqlx::query_as( + "SELECT CAST(uuid as VARCHAR), created FROM access_tokens WHERE token = $1", + ) + .bind(&access_token) + .fetch_one(pool) + .await; if let Err(error) = row { - if error.to_string() - == "no rows returned by a query that expected to return at least one row" - { - return Err(HttpResponse::Unauthorized().finish()); + if error.to_string() == "no rows returned by a query that expected to return at least one row" { + return Err(HttpResponse::Unauthorized().finish()) } error!("{}", error); - return Err(HttpResponse::InternalServerError().json( - r#"{ "error": "Unhandled exception occured, contact the server administrator" }"#, - )); + return Err(HttpResponse::InternalServerError().json(r#"{ "error": "Unhandled exception occured, contact the server administrator" }"#)) } - let (uuid, created_at): (String, i64) = row.unwrap(); + let (uuid, created): (String, i64) = row.unwrap(); let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() as i64; - let lifetime = current_time - created_at; + let lifetime = current_time - created; if lifetime > 3600 { return Err(HttpResponse::Unauthorized().finish()); diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index 8c3e7d0..5ac2402 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -1,22 +1,40 @@ -use actix_web::{post, web, Error, HttpRequest, HttpResponse}; +use actix_web::{Error, HttpResponse, error, post, web}; +use futures::StreamExt; use log::error; +use serde::{Deserialize, Serialize}; use std::time::{SystemTime, UNIX_EPOCH}; use crate::{ - crypto::{generate_access_token, generate_refresh_token}, utils::refresh_token_cookie, Data + Data, + crypto::{generate_access_token, generate_refresh_token}, }; -use super::Response; +#[derive(Deserialize)] +struct RefreshRequest { + refresh_token: String, +} + +#[derive(Serialize)] +struct Response { + refresh_token: String, + access_token: String, +} + +const MAX_SIZE: usize = 262_144; #[post("/refresh")] -pub async fn res(req: HttpRequest, data: web::Data) -> Result { - let recv_refresh_token_cookie = req.cookie("refresh_token"); - - if let None = recv_refresh_token_cookie { - return Ok(HttpResponse::Unauthorized().finish()) +pub async fn res(mut payload: web::Payload, data: web::Data) -> Result { + let mut body = web::BytesMut::new(); + while let Some(chunk) = payload.next().await { + let chunk = chunk?; + // limit max size of in-memory payload + if (body.len() + chunk.len()) > MAX_SIZE { + return Err(error::ErrorBadRequest("overflow")); + } + body.extend_from_slice(&chunk); } - let mut refresh_token = String::from(recv_refresh_token_cookie.unwrap().value()); + let refresh_request = serde_json::from_slice::(&body)?; let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -24,29 +42,33 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result 2592000 { if let Err(error) = sqlx::query("DELETE FROM refresh_tokens WHERE token = $1") - .bind(&refresh_token) + .bind(&refresh_request.refresh_token) .execute(&data.pool) .await { error!("{}", error); } - let mut refresh_token_cookie = refresh_token_cookie(refresh_token); - - refresh_token_cookie.make_removal(); - - return Ok(HttpResponse::Unauthorized().cookie(refresh_token_cookie).finish()); + return Ok(HttpResponse::Unauthorized().finish()); } let current_time = SystemTime::now() @@ -54,6 +76,8 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result 1987200 { let new_refresh_token = generate_refresh_token(); @@ -64,7 +88,7 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result) -> Result) -> Result) -> Result) -> Result { diff --git a/src/api/v1/auth/revoke.rs b/src/api/v1/auth/revoke.rs index 9ebbb30..f3285c4 100644 --- a/src/api/v1/auth/revoke.rs +++ b/src/api/v1/auth/revoke.rs @@ -1,13 +1,14 @@ -use actix_web::{Error, HttpRequest, HttpResponse, error, post, web}; +use actix_web::{Error, HttpResponse, error, post, web}; use argon2::{PasswordHash, PasswordVerifier}; use futures::{StreamExt, future}; use log::error; use serde::{Deserialize, Serialize}; -use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; +use crate::{Data, api::v1::auth::check_access_token}; #[derive(Deserialize)] struct RevokeRequest { + access_token: String, password: String, device_name: String, } @@ -26,19 +27,7 @@ impl Response { const MAX_SIZE: usize = 262_144; #[post("/revoke")] -pub async fn res( - req: HttpRequest, - mut payload: web::Payload, - data: web::Data, -) -> Result { - let headers = req.headers(); - - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } - +pub async fn res(mut payload: web::Payload, data: web::Data) -> Result { let mut body = web::BytesMut::new(); while let Some(chunk) = payload.next().await { let chunk = chunk?; @@ -51,7 +40,7 @@ pub async fn res( let revoke_request = serde_json::from_slice::(&body)?; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let authorized = check_access_token(revoke_request.access_token, &data.pool).await; if let Err(error) = authorized { return Ok(error); @@ -105,9 +94,16 @@ pub async fn res( let tokens: Vec = tokens_raw.unwrap(); + let mut access_tokens_delete = vec![]; let mut refresh_tokens_delete = vec![]; for token in tokens { + access_tokens_delete.push( + sqlx::query("DELETE FROM access_tokens WHERE refresh_token = $1") + .bind(token.clone()) + .execute(&data.pool), + ); + refresh_tokens_delete.push( sqlx::query("DELETE FROM refresh_tokens WHERE token = $1") .bind(token.clone()) @@ -115,16 +111,29 @@ pub async fn res( ); } - let results = future::join_all(refresh_tokens_delete).await; + let results_access_tokens = future::join_all(access_tokens_delete).await; + let results_refresh_tokens = future::join_all(refresh_tokens_delete).await; - let errors: Vec<&Result> = - results + let access_tokens_errors: Vec<&Result> = + results_access_tokens + .iter() + .filter(|r| r.is_err()) + .collect(); + let refresh_tokens_errors: Vec<&Result> = + results_refresh_tokens .iter() .filter(|r| r.is_err()) .collect(); - if !errors.is_empty() { - error!("{:?}", errors); + if !access_tokens_errors.is_empty() && !refresh_tokens_errors.is_empty() { + error!("{:?}", access_tokens_errors); + error!("{:?}", refresh_tokens_errors); + return Ok(HttpResponse::InternalServerError().finish()); + } else if !access_tokens_errors.is_empty() { + error!("{:?}", access_tokens_errors); + return Ok(HttpResponse::InternalServerError().finish()); + } else if !refresh_tokens_errors.is_empty() { + error!("{:?}", refresh_tokens_errors); return Ok(HttpResponse::InternalServerError().finish()); } diff --git a/src/api/v1/users/me.rs b/src/api/v1/users/me.rs index f641678..18e6ba8 100644 --- a/src/api/v1/users/me.rs +++ b/src/api/v1/users/me.rs @@ -1,8 +1,14 @@ -use actix_web::{Error, HttpRequest, HttpResponse, get, web}; +use actix_web::{Error, HttpResponse, error, post, web}; +use futures::StreamExt; use log::error; -use serde::Serialize; +use serde::{Deserialize, Serialize}; -use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; +use crate::{Data, api::v1::auth::check_access_token}; + +#[derive(Deserialize)] +struct AuthenticationRequest { + access_token: String, +} #[derive(Serialize)] struct Response { @@ -11,17 +17,26 @@ struct Response { display_name: String, } -#[get("/me")] -pub async fn res(req: HttpRequest, data: web::Data) -> Result { - let headers = req.headers(); +const MAX_SIZE: usize = 262_144; - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); +#[post("/me")] +pub async fn res( + mut payload: web::Payload, + data: web::Data, +) -> Result { + let mut body = web::BytesMut::new(); + while let Some(chunk) = payload.next().await { + let chunk = chunk?; + // limit max size of in-memory payload + if (body.len() + chunk.len()) > MAX_SIZE { + return Err(error::ErrorBadRequest("overflow")); + } + body.extend_from_slice(&chunk); } - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let authentication_request = serde_json::from_slice::(&body)?; + + let authorized = check_access_token(authentication_request.access_token, &data.pool).await; if let Err(error) = authorized { return Ok(error); diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index d7cb1c6..937d857 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -1,16 +1,18 @@ -use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; -use actix_web::{get, web, Error, HttpRequest, HttpResponse, Scope}; +use actix_web::{error, post, web, Error, HttpResponse, Scope}; +use futures::StreamExt; use log::error; use serde::{Deserialize, Serialize}; use sqlx::prelude::FromRow; +use crate::{Data, api::v1::auth::check_access_token}; mod me; mod uuid; #[derive(Deserialize)] -struct RequestQuery { - start: Option, - amount: Option, +struct Request { + access_token: String, + start: i32, + amount: i32, } #[derive(Serialize, FromRow)] @@ -21,6 +23,8 @@ struct Response { email: String, } +const MAX_SIZE: usize = 262_144; + pub fn web() -> Scope { web::scope("/users") .service(res) @@ -28,33 +32,36 @@ pub fn web() -> Scope { .service(uuid::res) } -#[get("")] +#[post("")] pub async fn res( - req: HttpRequest, - request_query: web::Query, + mut payload: web::Payload, data: web::Data, ) -> Result { - let headers = req.headers(); - - let auth_header = get_auth_header(headers); - - let start = request_query.start.unwrap_or(0); - - let amount = request_query.amount.unwrap_or(10); - - if amount > 100 { - return Ok(HttpResponse::BadRequest().finish()); + let mut body = web::BytesMut::new(); + while let Some(chunk) = payload.next().await { + let chunk = chunk?; + // limit max size of in-memory payload + if (body.len() + chunk.len()) > MAX_SIZE { + return Err(error::ErrorBadRequest("overflow")); + } + body.extend_from_slice(&chunk); } - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let request = serde_json::from_slice::(&body)?; + + if request.amount > 100 { + return Ok(HttpResponse::BadRequest().finish()) + } + + let authorized = check_access_token(request.access_token, &data.pool).await; if let Err(error) = authorized { return Ok(error); } let row = sqlx::query_as("SELECT CAST(uuid AS VARCHAR), username, display_name, email FROM users ORDER BY username LIMIT $1 OFFSET $2") - .bind(amount) - .bind(start) + .bind(request.amount) + .bind(request.start) .fetch_all(&data.pool) .await; @@ -67,3 +74,4 @@ pub async fn res( Ok(HttpResponse::Ok().json(accounts)) } + diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index f4c1f13..41d87cc 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -1,9 +1,15 @@ -use actix_web::{Error, HttpRequest, HttpResponse, get, web}; +use actix_web::{Error, HttpResponse, error, post, web}; +use futures::StreamExt; use log::error; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; +use crate::{Data, api::v1::auth::check_access_token}; + +#[derive(Deserialize)] +struct AuthenticationRequest { + access_token: String, +} #[derive(Serialize)] struct Response { @@ -12,23 +18,29 @@ struct Response { display_name: String, } -#[get("/{uuid}")] +const MAX_SIZE: usize = 262_144; + +#[post("/{uuid}")] pub async fn res( - req: HttpRequest, + mut payload: web::Payload, path: web::Path<(Uuid,)>, data: web::Data, ) -> Result { - let headers = req.headers(); + let mut body = web::BytesMut::new(); + while let Some(chunk) = payload.next().await { + let chunk = chunk?; + // limit max size of in-memory payload + if (body.len() + chunk.len()) > MAX_SIZE { + return Err(error::ErrorBadRequest("overflow")); + } + body.extend_from_slice(&chunk); + } let uuid = path.into_inner().0; - let auth_header = get_auth_header(headers); + let authentication_request = serde_json::from_slice::(&body)?; - if let Err(error) = auth_header { - return Ok(error); - } - - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let authorized = check_access_token(authentication_request.access_token, &data.pool).await; if let Err(error) = authorized { return Ok(error); diff --git a/src/main.rs b/src/main.rs index e967021..4c909b1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,6 @@ mod config; use config::{Config, ConfigBuilder}; mod api; pub mod crypto; -pub mod utils; type Error = Box; @@ -64,14 +63,14 @@ async fn main() -> Result<(), Error> { CREATE TABLE IF NOT EXISTS refresh_tokens ( token varchar(64) PRIMARY KEY UNIQUE NOT NULL, uuid uuid NOT NULL REFERENCES users(uuid), - created_at int8 NOT NULL, + created int8 NOT NULL, device_name varchar(16) NOT NULL ); CREATE TABLE IF NOT EXISTS access_tokens ( token varchar(32) PRIMARY KEY UNIQUE NOT NULL, - refresh_token varchar(64) UNIQUE NOT NULL REFERENCES refresh_tokens(token) ON UPDATE CASCADE ON DELETE CASCADE, + refresh_token varchar(64) UNIQUE NOT NULL REFERENCES refresh_tokens(token), uuid uuid NOT NULL REFERENCES users(uuid), - created_at int8 NOT NULL + created int8 NOT NULL ) "#, ) @@ -89,7 +88,8 @@ async fn main() -> Result<(), Error> { HttpServer::new(move || { App::new() .app_data(web::Data::new(data.clone())) - .service(api::web()) + .service(api::versions::res) + .service(api::v1::web()) }) .bind((web.url, web.port))? .run() diff --git a/src/utils.rs b/src/utils.rs deleted file mode 100644 index b432d19..0000000 --- a/src/utils.rs +++ /dev/null @@ -1,33 +0,0 @@ -use actix_web::{cookie::{time::Duration, Cookie, SameSite}, http::header::HeaderMap, HttpResponse}; - -pub fn get_auth_header(headers: &HeaderMap) -> Result<&str, HttpResponse> { - let auth_token = headers.get(actix_web::http::header::AUTHORIZATION); - - if let None = auth_token { - return Err(HttpResponse::Unauthorized().finish()); - } - - let auth = auth_token.unwrap().to_str(); - - if let Err(error) = auth { - return Err(HttpResponse::Unauthorized().json(format!(r#" {{ "error": "{}" }} "#, error))); - } - - let auth_value = auth.unwrap().split_whitespace().nth(1); - - if let None = auth_value { - return Err(HttpResponse::BadRequest().finish()); - } - - Ok(auth_value.unwrap()) -} - -pub fn refresh_token_cookie(refresh_token: String) -> Cookie<'static> { - Cookie::build("refresh_token", refresh_token) - .http_only(true) - .secure(true) - .same_site(SameSite::None) - .path("/api") - .max_age(Duration::days(30)) - .finish() -}