Compare commits

...

8 commits

Author SHA1 Message Date
22ab3d8a04 feat: add a way to revoke refresh_tokens using device_name 2025-05-01 20:19:39 +02:00
a89d705239 feat: use device_name in refresh_tokens table 2025-05-01 20:19:18 +02:00
c009d578a7 perf: optimize user fetching code 2025-05-01 20:18:39 +02:00
705abeb643 fix: fix password regex
this should probably be moved to its own function so we can change it on the fly
2025-05-01 20:17:59 +02:00
1646e60e65 fix: underscore unused config var in data 2025-05-01 20:15:38 +02:00
2864196584 perf: avoid cloning when checking access 2025-05-01 20:12:02 +02:00
7b86706793 perf: dont needlessly update uuid in token 2025-05-01 19:19:35 +02:00
aea640a64c style: use the same response for login/register 2025-05-01 19:18:44 +02:00
7 changed files with 167 additions and 45 deletions

View file

@ -16,10 +16,9 @@ struct LoginInformation {
} }
#[derive(Serialize)] #[derive(Serialize)]
struct Response { pub struct Response {
access_token: String, pub access_token: String,
expires_in: u64, pub refresh_token: String,
refresh_token: String,
} }
const MAX_SIZE: usize = 262_144; const MAX_SIZE: usize = 262_144;
@ -44,7 +43,7 @@ pub async fn response(mut payload: web::Payload, data: web::Data<Data>) -> Resul
let username_regex = Regex::new(r"[a-zA-Z0-9.-_]").unwrap(); let username_regex = Regex::new(r"[a-zA-Z0-9.-_]").unwrap();
// Password is expected to be hashed using SHA3-384 // Password is expected to be hashed using SHA3-384
let password_regex = Regex::new(r"/[0-9a-f]{96}/i").unwrap(); let password_regex = Regex::new(r"[0-9a-f]{96}").unwrap();
if !password_regex.is_match(&login_information.password) { if !password_regex.is_match(&login_information.password) {
return Ok(HttpResponse::Forbidden().json(r#"{ "password_hashed": false }"#)); return Ok(HttpResponse::Forbidden().json(r#"{ "password_hashed": false }"#));
@ -53,14 +52,14 @@ pub async fn response(mut payload: web::Payload, data: web::Data<Data>) -> Resul
if email_regex.is_match(&login_information.username) { if email_regex.is_match(&login_information.username) {
if let Ok(row) = sqlx::query_as("SELECT CAST(uuid as VARCHAR), password FROM users WHERE email = $1").bind(login_information.username).fetch_one(&data.pool).await { if let Ok(row) = sqlx::query_as("SELECT CAST(uuid as VARCHAR), password FROM users WHERE email = $1").bind(login_information.username).fetch_one(&data.pool).await {
let (uuid, password): (String, String) = row; let (uuid, password): (String, String) = row;
return Ok(login(data.clone(), uuid, login_information.password, password).await) return Ok(login(data.clone(), uuid, login_information.password, password, login_information.device_name).await)
} }
return Ok(HttpResponse::Unauthorized().finish()) return Ok(HttpResponse::Unauthorized().finish())
} else if username_regex.is_match(&login_information.username) { } else if username_regex.is_match(&login_information.username) {
if let Ok(row) = sqlx::query_as("SELECT CAST(uuid as VARCHAR), password FROM users WHERE username = $1").bind(login_information.username).fetch_one(&data.pool).await { if let Ok(row) = sqlx::query_as("SELECT CAST(uuid as VARCHAR), password FROM users WHERE username = $1").bind(login_information.username).fetch_one(&data.pool).await {
let (uuid, password): (String, String) = row; let (uuid, password): (String, String) = row;
return Ok(login(data.clone(), uuid, login_information.password, password).await) return Ok(login(data.clone(), uuid, login_information.password, password, login_information.device_name).await)
} }
return Ok(HttpResponse::Unauthorized().finish()) return Ok(HttpResponse::Unauthorized().finish())
@ -69,7 +68,7 @@ pub async fn response(mut payload: web::Payload, data: web::Data<Data>) -> Resul
Ok(HttpResponse::Unauthorized().finish()) Ok(HttpResponse::Unauthorized().finish())
} }
async fn login(data: actix_web::web::Data<Data>, uuid: String, request_password: String, database_password: String) -> HttpResponse { async fn login(data: actix_web::web::Data<Data>, uuid: String, request_password: String, database_password: String, device_name: String) -> HttpResponse {
if let Ok(parsed_hash) = PasswordHash::new(&database_password) { if let Ok(parsed_hash) = PasswordHash::new(&database_password) {
if data.argon2.verify_password(request_password.as_bytes(), &parsed_hash).is_ok() { if data.argon2.verify_password(request_password.as_bytes(), &parsed_hash).is_ok() {
let refresh_token = generate_refresh_token(); let refresh_token = generate_refresh_token();
@ -91,16 +90,17 @@ async fn login(data: actix_web::web::Data<Data>, uuid: String, request_password:
let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, uuid, created) VALUES ($1, '{}', $2 )", uuid)) if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, uuid, created, device_name) VALUES ($1, '{}', $2, $3 )", uuid))
.bind(&refresh_token) .bind(&refresh_token)
.bind(current_time) .bind(current_time)
.bind(device_name)
.execute(&data.pool) .execute(&data.pool)
.await { .await {
eprintln!("{}", error); eprintln!("{}", error);
return HttpResponse::InternalServerError().finish() return HttpResponse::InternalServerError().finish()
} }
if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, refresh_token, uuid, created) VALUES ($1, $2, '{}', $3 )", uuid)) if let Err(error) = sqlx::query(&format!("INSERT INTO access_tokens (token, refresh_token, uuid, created) VALUES ($1, $2, '{}', $3 )", uuid))
.bind(&access_token) .bind(&access_token)
.bind(&refresh_token) .bind(&refresh_token)
.bind(current_time) .bind(current_time)
@ -111,9 +111,8 @@ async fn login(data: actix_web::web::Data<Data>, uuid: String, request_password:
} }
return HttpResponse::Ok().json(Response { return HttpResponse::Ok().json(Response {
access_token: "bogus".to_string(), access_token,
expires_in: 0, refresh_token,
refresh_token: "bogus".to_string(),
}) })
} }

View file

@ -7,18 +7,20 @@ use uuid::Uuid;
mod register; mod register;
mod login; mod login;
mod refresh; mod refresh;
mod revoke;
pub fn web() -> Scope { pub fn web() -> Scope {
web::scope("/auth") web::scope("/auth")
.service(register::res) .service(register::res)
.service(login::response) .service(login::response)
.service(refresh::res) .service(refresh::res)
.service(revoke::res)
} }
pub async fn check_access_token(access_token: String, pool: sqlx::Pool<Postgres>) -> Result<Uuid, HttpResponse> { pub async fn check_access_token<'a>(access_token: String, pool: &'a sqlx::Pool<Postgres>) -> Result<Uuid, HttpResponse> {
match sqlx::query_as("SELECT CAST(uuid as VARCHAR), created FROM access_tokens WHERE token = $1") match sqlx::query_as("SELECT CAST(uuid as VARCHAR), created FROM access_tokens WHERE token = $1")
.bind(&access_token) .bind(&access_token)
.fetch_one(&pool) .fetch_one(&*pool)
.await { .await {
Ok(row) => { Ok(row) => {
let (uuid, created): (String, i64) = row; let (uuid, created): (String, i64) = row;

View file

@ -71,7 +71,7 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
let new_refresh_token = new_refresh_token.unwrap(); let new_refresh_token = new_refresh_token.unwrap();
match sqlx::query(&format!("UPDATE refresh_tokens SET token = $1, uuid = {}, created = $2 WHERE token = $3", uuid)) match sqlx::query("UPDATE refresh_tokens SET token = $1, created = $2 WHERE token = $3")
.bind(&new_refresh_token) .bind(&new_refresh_token)
.bind(&current_time) .bind(&current_time)
.bind(&refresh_token) .bind(&refresh_token)

View file

@ -8,6 +8,7 @@ use uuid::Uuid;
use argon2::{password_hash::{rand_core::OsRng, SaltString}, PasswordHasher}; use argon2::{password_hash::{rand_core::OsRng, SaltString}, PasswordHasher};
use crate::{crypto::{generate_access_token, generate_refresh_token}, Data}; use crate::{crypto::{generate_access_token, generate_refresh_token}, Data};
use super::login::Response;
#[derive(Deserialize)] #[derive(Deserialize)]
struct AccountInformation { struct AccountInformation {
@ -48,12 +49,6 @@ impl Default for ResponseError {
} }
} }
#[derive(Serialize)]
struct Response {
access_token: String,
refresh_token: String,
}
const MAX_SIZE: usize = 262_144; const MAX_SIZE: usize = 262_144;
#[post("/register")] #[post("/register")]
@ -137,9 +132,10 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64; let current_time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() as i64;
if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, uuid, created) VALUES ($1, '{}', $2 )", uuid)) if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, uuid, created, device_name) VALUES ($1, '{}', $2, $3 )", uuid))
.bind(&refresh_token) .bind(&refresh_token)
.bind(current_time) .bind(current_time)
.bind(account_information.device_name)
.execute(&data.pool) .execute(&data.pool)
.await { .await {
eprintln!("{}", error); eprintln!("{}", error);

121
src/api/v1/auth/revoke.rs Normal file
View file

@ -0,0 +1,121 @@
use actix_web::{error, post, web, Error, HttpResponse};
use argon2::{PasswordHash, PasswordVerifier};
use serde::{Deserialize, Serialize};
use futures::{future, StreamExt};
use crate::{api::v1::auth::check_access_token, Data};
#[derive(Deserialize)]
struct RevokeRequest {
access_token: String,
password: String,
device_name: String,
}
#[derive(Serialize)]
struct Response {
deleted: bool,
}
impl Response {
fn new(deleted: bool) -> Self {
Self {
deleted
}
}
}
const MAX_SIZE: usize = 262_144;
#[post("/revoke")]
pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<HttpResponse, Error> {
let mut body = web::BytesMut::new();
while let Some(chunk) = payload.next().await {
let chunk = chunk?;
// limit max size of in-memory payload
if (body.len() + chunk.len()) > MAX_SIZE {
return Err(error::ErrorBadRequest("overflow"));
}
body.extend_from_slice(&chunk);
}
let revoke_request = serde_json::from_slice::<RevokeRequest>(&body)?;
let authorized = check_access_token(revoke_request.access_token, &data.pool).await;
if authorized.is_err() {
return Ok(authorized.unwrap_err())
}
let uuid = authorized.unwrap();
let database_password_raw = sqlx::query_scalar(&format!("SELECT password FROM users WHERE uuid = '{}'", uuid))
.fetch_one(&data.pool)
.await;
if database_password_raw.is_err() {
eprintln!("{}", database_password_raw.unwrap_err());
return Ok(HttpResponse::InternalServerError().json(Response::new(false)));
}
let database_password: String = database_password_raw.unwrap();
let hashed_password_raw = PasswordHash::new(&database_password);
if hashed_password_raw.is_err() {
eprintln!("{}", hashed_password_raw.unwrap_err());
return Ok(HttpResponse::InternalServerError().json(Response::new(false)));
}
let hashed_password = hashed_password_raw.unwrap();
if data.argon2.verify_password(revoke_request.password.as_bytes(), &hashed_password).is_err() {
return Ok(HttpResponse::Unauthorized().finish())
}
let tokens_raw = sqlx::query_scalar(&format!("SELECT token FROM refresh_tokens WHERE uuid = '{}' AND device_name = $1", uuid))
.bind(revoke_request.device_name)
.fetch_all(&data.pool)
.await;
if tokens_raw.is_err() {
eprintln!("{:?}", tokens_raw);
return Ok(HttpResponse::InternalServerError().json(Response::new(false)))
}
let tokens: Vec<String> = tokens_raw.unwrap();
let mut access_tokens_delete = vec![];
let mut refresh_tokens_delete = vec![];
for token in tokens {
access_tokens_delete.push(sqlx::query("DELETE FROM access_tokens WHERE refresh_token = $1")
.bind(token.clone())
.execute(&data.pool));
refresh_tokens_delete.push(sqlx::query("DELETE FROM refresh_tokens WHERE token = $1")
.bind(token.clone())
.execute(&data.pool));
}
let results_access_tokens = future::join_all(access_tokens_delete).await;
let results_refresh_tokens = future::join_all(refresh_tokens_delete).await;
let access_tokens_errors: Vec<&Result<sqlx::postgres::PgQueryResult, sqlx::Error>> = results_access_tokens.iter().filter(|r| r.is_err()).collect();
let refresh_tokens_errors: Vec<&Result<sqlx::postgres::PgQueryResult, sqlx::Error>> = results_refresh_tokens.iter().filter(|r| r.is_err()).collect();
if !access_tokens_errors.is_empty() && !refresh_tokens_errors.is_empty() {
println!("{:?}", access_tokens_errors);
println!("{:?}", refresh_tokens_errors);
return Ok(HttpResponse::InternalServerError().finish())
} else if !access_tokens_errors.is_empty() {
println!("{:?}", access_tokens_errors);
return Ok(HttpResponse::InternalServerError().finish())
} else if !refresh_tokens_errors.is_empty() {
println!("{:?}", refresh_tokens_errors);
return Ok(HttpResponse::InternalServerError().finish())
}
Ok(HttpResponse::Ok().json(Response::new(true)))
}

View file

@ -1,6 +1,7 @@
use actix_web::{error, post, web, Error, HttpResponse}; use actix_web::{error, post, web, Error, HttpResponse};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use futures::StreamExt; use futures::StreamExt;
use uuid::Uuid;
use crate::{api::v1::auth::check_access_token, Data}; use crate::{api::v1::auth::check_access_token, Data};
@ -34,33 +35,35 @@ pub async fn res(mut payload: web::Payload, path: web::Path<(String,)>, data: we
let authentication_request = serde_json::from_slice::<AuthenticationRequest>(&body)?; let authentication_request = serde_json::from_slice::<AuthenticationRequest>(&body)?;
let authorized = check_access_token(authentication_request.access_token, data.pool.clone()).await; let authorized = check_access_token(authentication_request.access_token, &data.pool).await;
if authorized.is_err() { if authorized.is_err() {
return Ok(authorized.unwrap_err()) return Ok(authorized.unwrap_err())
} }
let uuid = authorized.unwrap(); let mut uuid = authorized.unwrap();
if request != "me" {
let requested_uuid = Uuid::parse_str(&request);
if requested_uuid.is_err() {
return Ok(HttpResponse::BadRequest().json(r#"{ "error": "UUID is invalid!" }"#))
}
uuid = requested_uuid.unwrap()
}
if request == "me" {
let row = sqlx::query_as(&format!("SELECT username, display_name FROM users WHERE uuid = '{}'", uuid)) let row = sqlx::query_as(&format!("SELECT username, display_name FROM users WHERE uuid = '{}'", uuid))
.fetch_one(&data.pool) .fetch_one(&data.pool)
.await .await;
.unwrap();
let (username, display_name): (String, Option<String>) = row; if row.is_err() {
eprintln!("{}", row.unwrap_err());
return Ok(HttpResponse::Ok().json(Response { uuid: uuid.to_string(), username, display_name: display_name.unwrap_or_default() })) return Ok(HttpResponse::InternalServerError().finish())
} else {
println!("{}", request);
if let Ok(row) = sqlx::query_as(&format!("SELECT CAST(uuid as VARCHAR), username, display_name FROM users WHERE uuid = '{}'", request))
.fetch_one(&data.pool)
.await {
let (uuid, username, display_name): (String, String, Option<String>) = row;
return Ok(HttpResponse::Ok().json(Response { uuid, username, display_name: display_name.unwrap_or_default() }))
} }
Ok(HttpResponse::NotFound().finish()) let (username, display_name): (String, Option<String>) = row.unwrap();
}
Ok(HttpResponse::Ok().json(Response { uuid: uuid.to_string(), username, display_name: display_name.unwrap_or_default() }))
} }

View file

@ -21,7 +21,7 @@ struct Args {
#[derive(Clone)] #[derive(Clone)]
struct Data { struct Data {
pub pool: Pool<Postgres>, pub pool: Pool<Postgres>,
pub config: Config, pub _config: Config,
pub argon2: Argon2<'static>, pub argon2: Argon2<'static>,
pub start_time: SystemTime, pub start_time: SystemTime,
} }
@ -56,7 +56,8 @@ async fn main() -> Result<(), Error> {
CREATE TABLE IF NOT EXISTS refresh_tokens ( CREATE TABLE IF NOT EXISTS refresh_tokens (
token varchar(64) PRIMARY KEY UNIQUE NOT NULL, token varchar(64) PRIMARY KEY UNIQUE NOT NULL,
uuid uuid NOT NULL REFERENCES users(uuid), uuid uuid NOT NULL REFERENCES users(uuid),
created int8 NOT NULL created int8 NOT NULL,
device_name varchar(16) NOT NULL
); );
CREATE TABLE IF NOT EXISTS access_tokens ( CREATE TABLE IF NOT EXISTS access_tokens (
token varchar(32) PRIMARY KEY UNIQUE NOT NULL, token varchar(32) PRIMARY KEY UNIQUE NOT NULL,
@ -70,7 +71,7 @@ async fn main() -> Result<(), Error> {
let data = Data { let data = Data {
pool, pool,
config, _config: config,
// TODO: Possibly implement "pepper" into this (thinking it could generate one if it doesnt exist and store it on disk) // TODO: Possibly implement "pepper" into this (thinking it could generate one if it doesnt exist and store it on disk)
argon2: Argon2::default(), argon2: Argon2::default(),
start_time: SystemTime::now(), start_time: SystemTime::now(),