diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 81d1370..25ce5ae 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -16,10 +16,9 @@ struct LoginInformation { } #[derive(Serialize)] -struct Response { - access_token: String, - expires_in: u64, - refresh_token: String, +pub struct Response { + pub access_token: String, + pub refresh_token: String, } const MAX_SIZE: usize = 262_144; @@ -44,7 +43,7 @@ pub async fn response(mut payload: web::Payload, data: web::Data) -> Resul let username_regex = Regex::new(r"[a-zA-Z0-9.-_]").unwrap(); // Password is expected to be hashed using SHA3-384 - let password_regex = Regex::new(r"/[0-9a-f]{96}/i").unwrap(); + let password_regex = Regex::new(r"[0-9a-f]{96}").unwrap(); if !password_regex.is_match(&login_information.password) { return Ok(HttpResponse::Forbidden().json(r#"{ "password_hashed": false }"#)); @@ -53,14 +52,14 @@ pub async fn response(mut payload: web::Payload, data: web::Data) -> Resul if email_regex.is_match(&login_information.username) { if let Ok(row) = sqlx::query_as("SELECT CAST(uuid as VARCHAR), password FROM users WHERE email = $1").bind(login_information.username).fetch_one(&data.pool).await { let (uuid, password): (String, String) = row; - return Ok(login(data.clone(), uuid, login_information.password, password).await) + return Ok(login(data.clone(), uuid, login_information.password, password, login_information.device_name).await) } return Ok(HttpResponse::Unauthorized().finish()) } else if username_regex.is_match(&login_information.username) { if let Ok(row) = sqlx::query_as("SELECT CAST(uuid as VARCHAR), password FROM users WHERE username = $1").bind(login_information.username).fetch_one(&data.pool).await { let (uuid, password): (String, String) = row; - return Ok(login(data.clone(), uuid, login_information.password, password).await) + return Ok(login(data.clone(), uuid, login_information.password, password, login_information.device_name).await) } return Ok(HttpResponse::Unauthorized().finish()) @@ -69,7 +68,7 @@ pub async fn response(mut payload: web::Payload, data: web::Data) -> Resul Ok(HttpResponse::Unauthorized().finish()) } -async fn login(data: actix_web::web::Data, uuid: String, request_password: String, database_password: String) -> HttpResponse { +async fn login(data: actix_web::web::Data, uuid: String, request_password: String, database_password: String, device_name: String) -> HttpResponse { if let Ok(parsed_hash) = PasswordHash::new(&database_password) { if data.argon2.verify_password(request_password.as_bytes(), &parsed_hash).is_ok() { let refresh_token = generate_refresh_token(); @@ -91,16 +90,17 @@ async fn login(data: actix_web::web::Data, uuid: String, request_password: let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; - if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, uuid, created) VALUES ($1, '{}', $2 )", uuid)) + if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, uuid, created, device_name) VALUES ($1, '{}', $2, $3 )", uuid)) .bind(&refresh_token) .bind(current_time) + .bind(device_name) .execute(&data.pool) .await { eprintln!("{}", error); return HttpResponse::InternalServerError().finish() } - if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, refresh_token, uuid, created) VALUES ($1, $2, '{}', $3 )", uuid)) + if let Err(error) = sqlx::query(&format!("INSERT INTO access_tokens (token, refresh_token, uuid, created) VALUES ($1, $2, '{}', $3 )", uuid)) .bind(&access_token) .bind(&refresh_token) .bind(current_time) @@ -111,9 +111,8 @@ async fn login(data: actix_web::web::Data, uuid: String, request_password: } return HttpResponse::Ok().json(Response { - access_token: "bogus".to_string(), - expires_in: 0, - refresh_token: "bogus".to_string(), + access_token, + refresh_token, }) } diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index 4cd06e7..a217d90 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -7,18 +7,20 @@ use uuid::Uuid; mod register; mod login; mod refresh; +mod revoke; pub fn web() -> Scope { web::scope("/auth") .service(register::res) .service(login::response) .service(refresh::res) + .service(revoke::res) } -pub async fn check_access_token(access_token: String, pool: sqlx::Pool) -> Result { +pub async fn check_access_token<'a>(access_token: String, pool: &'a sqlx::Pool) -> Result { match sqlx::query_as("SELECT CAST(uuid as VARCHAR), created FROM access_tokens WHERE token = $1") .bind(&access_token) - .fetch_one(&pool) + .fetch_one(&*pool) .await { Ok(row) => { let (uuid, created): (String, i64) = row; diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index b924d43..b6af68a 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -71,7 +71,7 @@ pub async fn res(mut payload: web::Payload, data: web::Data) -> Result) -> Result Self { + Self { + deleted + } + } +} + +const MAX_SIZE: usize = 262_144; + +#[post("/revoke")] +pub async fn res(mut payload: web::Payload, data: web::Data) -> Result { + let mut body = web::BytesMut::new(); + while let Some(chunk) = payload.next().await { + let chunk = chunk?; + // limit max size of in-memory payload + if (body.len() + chunk.len()) > MAX_SIZE { + return Err(error::ErrorBadRequest("overflow")); + } + body.extend_from_slice(&chunk); + } + + let revoke_request = serde_json::from_slice::(&body)?; + + let authorized = check_access_token(revoke_request.access_token, &data.pool).await; + + if authorized.is_err() { + return Ok(authorized.unwrap_err()) + } + + let uuid = authorized.unwrap(); + + let database_password_raw = sqlx::query_scalar(&format!("SELECT password FROM users WHERE uuid = '{}'", uuid)) + .fetch_one(&data.pool) + .await; + + if database_password_raw.is_err() { + eprintln!("{}", database_password_raw.unwrap_err()); + return Ok(HttpResponse::InternalServerError().json(Response::new(false))); + } + + let database_password: String = database_password_raw.unwrap(); + + let hashed_password_raw = PasswordHash::new(&database_password); + + if hashed_password_raw.is_err() { + eprintln!("{}", hashed_password_raw.unwrap_err()); + return Ok(HttpResponse::InternalServerError().json(Response::new(false))); + } + + let hashed_password = hashed_password_raw.unwrap(); + + if data.argon2.verify_password(revoke_request.password.as_bytes(), &hashed_password).is_err() { + return Ok(HttpResponse::Unauthorized().finish()) + } + + let tokens_raw = sqlx::query_scalar(&format!("SELECT token FROM refresh_tokens WHERE uuid = '{}' AND device_name = $1", uuid)) + .bind(revoke_request.device_name) + .fetch_all(&data.pool) + .await; + + if tokens_raw.is_err() { + eprintln!("{:?}", tokens_raw); + return Ok(HttpResponse::InternalServerError().json(Response::new(false))) + } + + let tokens: Vec = tokens_raw.unwrap(); + + let mut access_tokens_delete = vec![]; + let mut refresh_tokens_delete = vec![]; + + + for token in tokens { + access_tokens_delete.push(sqlx::query("DELETE FROM access_tokens WHERE refresh_token = $1") + .bind(token.clone()) + .execute(&data.pool)); + + refresh_tokens_delete.push(sqlx::query("DELETE FROM refresh_tokens WHERE token = $1") + .bind(token.clone()) + .execute(&data.pool)); + } + + let results_access_tokens = future::join_all(access_tokens_delete).await; + let results_refresh_tokens = future::join_all(refresh_tokens_delete).await; + + let access_tokens_errors: Vec<&Result> = results_access_tokens.iter().filter(|r| r.is_err()).collect(); + let refresh_tokens_errors: Vec<&Result> = results_refresh_tokens.iter().filter(|r| r.is_err()).collect(); + + if !access_tokens_errors.is_empty() && !refresh_tokens_errors.is_empty() { + println!("{:?}", access_tokens_errors); + println!("{:?}", refresh_tokens_errors); + return Ok(HttpResponse::InternalServerError().finish()) + } else if !access_tokens_errors.is_empty() { + println!("{:?}", access_tokens_errors); + return Ok(HttpResponse::InternalServerError().finish()) + } else if !refresh_tokens_errors.is_empty() { + println!("{:?}", refresh_tokens_errors); + return Ok(HttpResponse::InternalServerError().finish()) + } + + Ok(HttpResponse::Ok().json(Response::new(true))) +} diff --git a/src/api/v1/user.rs b/src/api/v1/user.rs index 9287e1b..25721b3 100644 --- a/src/api/v1/user.rs +++ b/src/api/v1/user.rs @@ -1,6 +1,7 @@ use actix_web::{error, post, web, Error, HttpResponse}; use serde::{Deserialize, Serialize}; use futures::StreamExt; +use uuid::Uuid; use crate::{api::v1::auth::check_access_token, Data}; @@ -34,33 +35,35 @@ pub async fn res(mut payload: web::Payload, path: web::Path<(String,)>, data: we let authentication_request = serde_json::from_slice::(&body)?; - let authorized = check_access_token(authentication_request.access_token, data.pool.clone()).await; + let authorized = check_access_token(authentication_request.access_token, &data.pool).await; if authorized.is_err() { return Ok(authorized.unwrap_err()) } - let uuid = authorized.unwrap(); + let mut uuid = authorized.unwrap(); - if request == "me" { - let row = sqlx::query_as(&format!("SELECT username, display_name FROM users WHERE uuid = '{}'", uuid)) - .fetch_one(&data.pool) - .await - .unwrap(); + if request != "me" { + let requested_uuid = Uuid::parse_str(&request); - let (username, display_name): (String, Option) = row; - - return Ok(HttpResponse::Ok().json(Response { uuid: uuid.to_string(), username, display_name: display_name.unwrap_or_default() })) - } else { - println!("{}", request); - if let Ok(row) = sqlx::query_as(&format!("SELECT CAST(uuid as VARCHAR), username, display_name FROM users WHERE uuid = '{}'", request)) - .fetch_one(&data.pool) - .await { - let (uuid, username, display_name): (String, String, Option) = row; - - return Ok(HttpResponse::Ok().json(Response { uuid, username, display_name: display_name.unwrap_or_default() })) + if requested_uuid.is_err() { + return Ok(HttpResponse::BadRequest().json(r#"{ "error": "UUID is invalid!" }"#)) } - Ok(HttpResponse::NotFound().finish()) + uuid = requested_uuid.unwrap() } + + + let row = sqlx::query_as(&format!("SELECT username, display_name FROM users WHERE uuid = '{}'", uuid)) + .fetch_one(&data.pool) + .await; + + if row.is_err() { + eprintln!("{}", row.unwrap_err()); + return Ok(HttpResponse::InternalServerError().finish()) + } + + let (username, display_name): (String, Option) = row.unwrap(); + + Ok(HttpResponse::Ok().json(Response { uuid: uuid.to_string(), username, display_name: display_name.unwrap_or_default() })) } diff --git a/src/main.rs b/src/main.rs index 467cc29..0a364bb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,7 +21,7 @@ struct Args { #[derive(Clone)] struct Data { pub pool: Pool, - pub config: Config, + pub _config: Config, pub argon2: Argon2<'static>, pub start_time: SystemTime, } @@ -56,7 +56,8 @@ async fn main() -> Result<(), Error> { CREATE TABLE IF NOT EXISTS refresh_tokens ( token varchar(64) PRIMARY KEY UNIQUE NOT NULL, uuid uuid NOT NULL REFERENCES users(uuid), - created int8 NOT NULL + created int8 NOT NULL, + device_name varchar(16) NOT NULL ); CREATE TABLE IF NOT EXISTS access_tokens ( token varchar(32) PRIMARY KEY UNIQUE NOT NULL, @@ -70,7 +71,7 @@ async fn main() -> Result<(), Error> { let data = Data { pool, - config, + _config: config, // TODO: Possibly implement "pepper" into this (thinking it could generate one if it doesnt exist and store it on disk) argon2: Argon2::default(), start_time: SystemTime::now(),