Compare commits

...

2 commits

Author SHA1 Message Date
bb9c14db3d switch to headers for auth 2025-05-04 19:05:51 +02:00
b7a1043081 use created_at instead of created 2025-05-04 19:05:31 +02:00
9 changed files with 103 additions and 166 deletions

View file

@ -160,7 +160,7 @@ async fn login(
.as_secs() as i64; .as_secs() as i64;
if let Err(error) = sqlx::query(&format!( if let Err(error) = sqlx::query(&format!(
"INSERT INTO refresh_tokens (token, uuid, created, device_name) VALUES ($1, '{}', $2, $3 )", "INSERT INTO refresh_tokens (token, uuid, created_at, device_name) VALUES ($1, '{}', $2, $3 )",
uuid uuid
)) ))
.bind(&refresh_token) .bind(&refresh_token)
@ -174,7 +174,7 @@ async fn login(
} }
if let Err(error) = sqlx::query(&format!( if let Err(error) = sqlx::query(&format!(
"INSERT INTO access_tokens (token, refresh_token, uuid, created) VALUES ($1, $2, '{}', $3 )", "INSERT INTO access_tokens (token, refresh_token, uuid, created_at) VALUES ($1, $2, '{}', $3 )",
uuid uuid
)) ))
.bind(&access_token) .bind(&access_token)

View file

@ -34,33 +34,36 @@ pub fn web() -> Scope {
} }
pub async fn check_access_token( pub async fn check_access_token(
access_token: String, access_token: &str,
pool: &sqlx::Pool<Postgres>, pool: &sqlx::Pool<Postgres>,
) -> Result<Uuid, HttpResponse> { ) -> Result<Uuid, HttpResponse> {
let row = sqlx::query_as( let row =
"SELECT CAST(uuid as VARCHAR), created FROM access_tokens WHERE token = $1", sqlx::query_as("SELECT CAST(uuid as VARCHAR), created_at FROM access_tokens WHERE token = $1")
) .bind(&access_token)
.bind(&access_token) .fetch_one(pool)
.fetch_one(pool) .await;
.await;
if let Err(error) = row { if let Err(error) = row {
if error.to_string() == "no rows returned by a query that expected to return at least one row" { if error.to_string()
return Err(HttpResponse::Unauthorized().finish()) == "no rows returned by a query that expected to return at least one row"
{
return Err(HttpResponse::Unauthorized().finish());
} }
error!("{}", error); error!("{}", error);
return Err(HttpResponse::InternalServerError().json(r#"{ "error": "Unhandled exception occured, contact the server administrator" }"#)) return Err(HttpResponse::InternalServerError().json(
r#"{ "error": "Unhandled exception occured, contact the server administrator" }"#,
));
} }
let (uuid, created): (String, i64) = row.unwrap(); let (uuid, created_at): (String, i64) = row.unwrap();
let current_time = SystemTime::now() let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
.unwrap() .unwrap()
.as_secs() as i64; .as_secs() as i64;
let lifetime = current_time - created; let lifetime = current_time - created_at;
if lifetime > 3600 { if lifetime > 3600 {
return Err(HttpResponse::Unauthorized().finish()); return Err(HttpResponse::Unauthorized().finish());

View file

@ -1,7 +1,6 @@
use actix_web::{Error, HttpResponse, error, post, web}; use actix_web::{post, web, Error, HttpRequest, HttpResponse};
use futures::StreamExt;
use log::error; use log::error;
use serde::{Deserialize, Serialize}; use serde::Serialize;
use std::time::{SystemTime, UNIX_EPOCH}; use std::time::{SystemTime, UNIX_EPOCH};
use crate::{ use crate::{
@ -9,32 +8,21 @@ use crate::{
crypto::{generate_access_token, generate_refresh_token}, crypto::{generate_access_token, generate_refresh_token},
}; };
#[derive(Deserialize)]
struct RefreshRequest {
refresh_token: String,
}
#[derive(Serialize)] #[derive(Serialize)]
struct Response { struct Response {
refresh_token: String, refresh_token: String,
access_token: String, access_token: String,
} }
const MAX_SIZE: usize = 262_144;
#[post("/refresh")] #[post("/refresh")]
pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<HttpResponse, Error> { pub async fn res(req: HttpRequest, data: web::Data<Data>) -> Result<HttpResponse, Error> {
let mut body = web::BytesMut::new(); let refresh_token_cookie = req.cookie("refresh_token");
while let Some(chunk) = payload.next().await {
let chunk = chunk?; if let None = refresh_token_cookie {
// limit max size of in-memory payload return Ok(HttpResponse::Unauthorized().finish())
if (body.len() + chunk.len()) > MAX_SIZE {
return Err(error::ErrorBadRequest("overflow"));
}
body.extend_from_slice(&chunk);
} }
let refresh_request = serde_json::from_slice::<RefreshRequest>(&body)?; let mut refresh_token = String::from(refresh_token_cookie.unwrap().value());
let current_time = SystemTime::now() let current_time = SystemTime::now()
.duration_since(UNIX_EPOCH) .duration_since(UNIX_EPOCH)
@ -42,26 +30,18 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
.as_secs() as i64; .as_secs() as i64;
if let Ok(row) = if let Ok(row) =
sqlx::query_as("SELECT CAST(uuid as VARCHAR), created FROM refresh_tokens WHERE token = $1") sqlx::query_scalar("SELECT created_at FROM refresh_tokens WHERE token = $1")
.bind(&refresh_request.refresh_token) .bind(&refresh_token)
.fetch_one(&data.pool) .fetch_one(&data.pool)
.await .await
{ {
let (uuid, created): (String, i64) = row; let created_at: i64 = row;
if let Err(error) = sqlx::query("DELETE FROM access_tokens WHERE refresh_token = $1") let lifetime = current_time - created_at;
.bind(&refresh_request.refresh_token)
.execute(&data.pool)
.await
{
error!("{}", error);
}
let lifetime = current_time - created;
if lifetime > 2592000 { if lifetime > 2592000 {
if let Err(error) = sqlx::query("DELETE FROM refresh_tokens WHERE token = $1") if let Err(error) = sqlx::query("DELETE FROM refresh_tokens WHERE token = $1")
.bind(&refresh_request.refresh_token) .bind(&refresh_token)
.execute(&data.pool) .execute(&data.pool)
.await .await
{ {
@ -76,8 +56,6 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
.unwrap() .unwrap()
.as_secs() as i64; .as_secs() as i64;
let mut refresh_token = refresh_request.refresh_token;
if lifetime > 1987200 { if lifetime > 1987200 {
let new_refresh_token = generate_refresh_token(); let new_refresh_token = generate_refresh_token();
@ -88,7 +66,7 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
let new_refresh_token = new_refresh_token.unwrap(); let new_refresh_token = new_refresh_token.unwrap();
match sqlx::query("UPDATE refresh_tokens SET token = $1, created = $2 WHERE token = $3") match sqlx::query("UPDATE refresh_tokens SET token = $1, created_at = $2 WHERE token = $3")
.bind(&new_refresh_token) .bind(&new_refresh_token)
.bind(current_time) .bind(current_time)
.bind(&refresh_token) .bind(&refresh_token)
@ -113,10 +91,10 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
let access_token = access_token.unwrap(); let access_token = access_token.unwrap();
if let Err(error) = sqlx::query(&format!("INSERT INTO access_tokens (token, refresh_token, uuid, created) VALUES ($1, $2, '{}', $3 )", uuid)) if let Err(error) = sqlx::query("UPDATE access_tokens SET token = $1, created_at = $2 WHERE refresh_token = $3")
.bind(&access_token) .bind(&access_token)
.bind(&refresh_token)
.bind(current_time) .bind(current_time)
.bind(&refresh_token)
.execute(&data.pool) .execute(&data.pool)
.await { .await {
error!("{}", error); error!("{}", error);

View file

@ -139,7 +139,7 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
.unwrap() .unwrap()
.as_secs() as i64; .as_secs() as i64;
if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, uuid, created, device_name) VALUES ($1, '{}', $2, $3 )", uuid)) if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, uuid, created_at, device_name) VALUES ($1, '{}', $2, $3 )", uuid))
.bind(&refresh_token) .bind(&refresh_token)
.bind(current_time) .bind(current_time)
.bind(account_information.device_name) .bind(account_information.device_name)
@ -149,7 +149,7 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
return Ok(HttpResponse::InternalServerError().finish()) return Ok(HttpResponse::InternalServerError().finish())
} }
if let Err(error) = sqlx::query(&format!("INSERT INTO access_tokens (token, refresh_token, uuid, created) VALUES ($1, $2, '{}', $3 )", uuid)) if let Err(error) = sqlx::query(&format!("INSERT INTO access_tokens (token, refresh_token, uuid, created_at) VALUES ($1, $2, '{}', $3 )", uuid))
.bind(&access_token) .bind(&access_token)
.bind(&refresh_token) .bind(&refresh_token)
.bind(current_time) .bind(current_time)

View file

@ -1,14 +1,13 @@
use actix_web::{Error, HttpResponse, error, post, web}; use actix_web::{Error, HttpRequest, HttpResponse, error, post, web};
use argon2::{PasswordHash, PasswordVerifier}; use argon2::{PasswordHash, PasswordVerifier};
use futures::{StreamExt, future}; use futures::{StreamExt, future};
use log::error; use log::error;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{Data, api::v1::auth::check_access_token}; use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header};
#[derive(Deserialize)] #[derive(Deserialize)]
struct RevokeRequest { struct RevokeRequest {
access_token: String,
password: String, password: String,
device_name: String, device_name: String,
} }
@ -27,7 +26,19 @@ impl Response {
const MAX_SIZE: usize = 262_144; const MAX_SIZE: usize = 262_144;
#[post("/revoke")] #[post("/revoke")]
pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<HttpResponse, Error> { pub async fn res(
req: HttpRequest,
mut payload: web::Payload,
data: web::Data<Data>,
) -> Result<HttpResponse, Error> {
let headers = req.headers();
let auth_header = get_auth_header(headers);
if let Err(error) = auth_header {
return Ok(error);
}
let mut body = web::BytesMut::new(); let mut body = web::BytesMut::new();
while let Some(chunk) = payload.next().await { while let Some(chunk) = payload.next().await {
let chunk = chunk?; let chunk = chunk?;
@ -40,7 +51,7 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
let revoke_request = serde_json::from_slice::<RevokeRequest>(&body)?; let revoke_request = serde_json::from_slice::<RevokeRequest>(&body)?;
let authorized = check_access_token(revoke_request.access_token, &data.pool).await; let authorized = check_access_token(auth_header.unwrap(), &data.pool).await;
if let Err(error) = authorized { if let Err(error) = authorized {
return Ok(error); return Ok(error);
@ -94,16 +105,9 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
let tokens: Vec<String> = tokens_raw.unwrap(); let tokens: Vec<String> = tokens_raw.unwrap();
let mut access_tokens_delete = vec![];
let mut refresh_tokens_delete = vec![]; let mut refresh_tokens_delete = vec![];
for token in tokens { for token in tokens {
access_tokens_delete.push(
sqlx::query("DELETE FROM access_tokens WHERE refresh_token = $1")
.bind(token.clone())
.execute(&data.pool),
);
refresh_tokens_delete.push( refresh_tokens_delete.push(
sqlx::query("DELETE FROM refresh_tokens WHERE token = $1") sqlx::query("DELETE FROM refresh_tokens WHERE token = $1")
.bind(token.clone()) .bind(token.clone())
@ -111,29 +115,16 @@ pub async fn res(mut payload: web::Payload, data: web::Data<Data>) -> Result<Htt
); );
} }
let results_access_tokens = future::join_all(access_tokens_delete).await; let results = future::join_all(refresh_tokens_delete).await;
let results_refresh_tokens = future::join_all(refresh_tokens_delete).await;
let access_tokens_errors: Vec<&Result<sqlx::postgres::PgQueryResult, sqlx::Error>> = let errors: Vec<&Result<sqlx::postgres::PgQueryResult, sqlx::Error>> =
results_access_tokens results
.iter()
.filter(|r| r.is_err())
.collect();
let refresh_tokens_errors: Vec<&Result<sqlx::postgres::PgQueryResult, sqlx::Error>> =
results_refresh_tokens
.iter() .iter()
.filter(|r| r.is_err()) .filter(|r| r.is_err())
.collect(); .collect();
if !access_tokens_errors.is_empty() && !refresh_tokens_errors.is_empty() { if !errors.is_empty() {
error!("{:?}", access_tokens_errors); error!("{:?}", errors);
error!("{:?}", refresh_tokens_errors);
return Ok(HttpResponse::InternalServerError().finish());
} else if !access_tokens_errors.is_empty() {
error!("{:?}", access_tokens_errors);
return Ok(HttpResponse::InternalServerError().finish());
} else if !refresh_tokens_errors.is_empty() {
error!("{:?}", refresh_tokens_errors);
return Ok(HttpResponse::InternalServerError().finish()); return Ok(HttpResponse::InternalServerError().finish());
} }

View file

@ -1,14 +1,8 @@
use actix_web::{Error, HttpResponse, error, post, web}; use actix_web::{Error, HttpRequest, HttpResponse, get, web};
use futures::StreamExt;
use log::error; use log::error;
use serde::{Deserialize, Serialize}; use serde::Serialize;
use crate::{Data, api::v1::auth::check_access_token}; use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header};
#[derive(Deserialize)]
struct AuthenticationRequest {
access_token: String,
}
#[derive(Serialize)] #[derive(Serialize)]
struct Response { struct Response {
@ -17,26 +11,17 @@ struct Response {
display_name: String, display_name: String,
} }
const MAX_SIZE: usize = 262_144; #[get("/me")]
pub async fn res(req: HttpRequest, data: web::Data<Data>) -> Result<HttpResponse, Error> {
let headers = req.headers();
#[post("/me")] let auth_header = get_auth_header(headers);
pub async fn res(
mut payload: web::Payload, if let Err(error) = auth_header {
data: web::Data<Data>, return Ok(error);
) -> Result<HttpResponse, Error> {
let mut body = web::BytesMut::new();
while let Some(chunk) = payload.next().await {
let chunk = chunk?;
// limit max size of in-memory payload
if (body.len() + chunk.len()) > MAX_SIZE {
return Err(error::ErrorBadRequest("overflow"));
}
body.extend_from_slice(&chunk);
} }
let authentication_request = serde_json::from_slice::<AuthenticationRequest>(&body)?; let authorized = check_access_token(auth_header.unwrap(), &data.pool).await;
let authorized = check_access_token(authentication_request.access_token, &data.pool).await;
if let Err(error) = authorized { if let Err(error) = authorized {
return Ok(error); return Ok(error);

View file

@ -1,18 +1,16 @@
use actix_web::{error, post, web, Error, HttpResponse, Scope}; use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header};
use futures::StreamExt; use actix_web::{get, web, Error, HttpRequest, HttpResponse, Scope};
use log::error; use log::error;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::prelude::FromRow; use sqlx::prelude::FromRow;
use crate::{Data, api::v1::auth::check_access_token};
mod me; mod me;
mod uuid; mod uuid;
#[derive(Deserialize)] #[derive(Deserialize)]
struct Request { struct RequestQuery {
access_token: String, start: Option<i32>,
start: i32, amount: Option<i32>,
amount: i32,
} }
#[derive(Serialize, FromRow)] #[derive(Serialize, FromRow)]
@ -23,8 +21,6 @@ struct Response {
email: String, email: String,
} }
const MAX_SIZE: usize = 262_144;
pub fn web() -> Scope { pub fn web() -> Scope {
web::scope("/users") web::scope("/users")
.service(res) .service(res)
@ -32,36 +28,33 @@ pub fn web() -> Scope {
.service(uuid::res) .service(uuid::res)
} }
#[post("")] #[get("")]
pub async fn res( pub async fn res(
mut payload: web::Payload, req: HttpRequest,
request_query: web::Query<RequestQuery>,
data: web::Data<Data>, data: web::Data<Data>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let mut body = web::BytesMut::new(); let headers = req.headers();
while let Some(chunk) = payload.next().await {
let chunk = chunk?; let auth_header = get_auth_header(headers);
// limit max size of in-memory payload
if (body.len() + chunk.len()) > MAX_SIZE { let start = request_query.start.unwrap_or(0);
return Err(error::ErrorBadRequest("overflow"));
} let amount = request_query.amount.unwrap_or(10);
body.extend_from_slice(&chunk);
if amount > 100 {
return Ok(HttpResponse::BadRequest().finish());
} }
let request = serde_json::from_slice::<Request>(&body)?; let authorized = check_access_token(auth_header.unwrap(), &data.pool).await;
if request.amount > 100 {
return Ok(HttpResponse::BadRequest().finish())
}
let authorized = check_access_token(request.access_token, &data.pool).await;
if let Err(error) = authorized { if let Err(error) = authorized {
return Ok(error); return Ok(error);
} }
let row = sqlx::query_as("SELECT CAST(uuid AS VARCHAR), username, display_name, email FROM users ORDER BY username LIMIT $1 OFFSET $2") let row = sqlx::query_as("SELECT CAST(uuid AS VARCHAR), username, display_name, email FROM users ORDER BY username LIMIT $1 OFFSET $2")
.bind(request.amount) .bind(amount)
.bind(request.start) .bind(start)
.fetch_all(&data.pool) .fetch_all(&data.pool)
.await; .await;
@ -74,4 +67,3 @@ pub async fn res(
Ok(HttpResponse::Ok().json(accounts)) Ok(HttpResponse::Ok().json(accounts))
} }

View file

@ -1,15 +1,9 @@
use actix_web::{Error, HttpResponse, error, post, web}; use actix_web::{Error, HttpRequest, HttpResponse, get, web};
use futures::StreamExt;
use log::error; use log::error;
use serde::{Deserialize, Serialize}; use serde::Serialize;
use uuid::Uuid; use uuid::Uuid;
use crate::{Data, api::v1::auth::check_access_token}; use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header};
#[derive(Deserialize)]
struct AuthenticationRequest {
access_token: String,
}
#[derive(Serialize)] #[derive(Serialize)]
struct Response { struct Response {
@ -18,29 +12,23 @@ struct Response {
display_name: String, display_name: String,
} }
const MAX_SIZE: usize = 262_144; #[get("/{uuid}")]
#[post("/{uuid}")]
pub async fn res( pub async fn res(
mut payload: web::Payload, req: HttpRequest,
path: web::Path<(Uuid,)>, path: web::Path<(Uuid,)>,
data: web::Data<Data>, data: web::Data<Data>,
) -> Result<HttpResponse, Error> { ) -> Result<HttpResponse, Error> {
let mut body = web::BytesMut::new(); let headers = req.headers();
while let Some(chunk) = payload.next().await {
let chunk = chunk?;
// limit max size of in-memory payload
if (body.len() + chunk.len()) > MAX_SIZE {
return Err(error::ErrorBadRequest("overflow"));
}
body.extend_from_slice(&chunk);
}
let uuid = path.into_inner().0; let uuid = path.into_inner().0;
let authentication_request = serde_json::from_slice::<AuthenticationRequest>(&body)?; let auth_header = get_auth_header(headers);
let authorized = check_access_token(authentication_request.access_token, &data.pool).await; if let Err(error) = auth_header {
return Ok(error);
}
let authorized = check_access_token(auth_header.unwrap(), &data.pool).await;
if let Err(error) = authorized { if let Err(error) = authorized {
return Ok(error); return Ok(error);

View file

@ -64,14 +64,14 @@ async fn main() -> Result<(), Error> {
CREATE TABLE IF NOT EXISTS refresh_tokens ( CREATE TABLE IF NOT EXISTS refresh_tokens (
token varchar(64) PRIMARY KEY UNIQUE NOT NULL, token varchar(64) PRIMARY KEY UNIQUE NOT NULL,
uuid uuid NOT NULL REFERENCES users(uuid), uuid uuid NOT NULL REFERENCES users(uuid),
created int8 NOT NULL, created_at int8 NOT NULL,
device_name varchar(16) NOT NULL device_name varchar(16) NOT NULL
); );
CREATE TABLE IF NOT EXISTS access_tokens ( CREATE TABLE IF NOT EXISTS access_tokens (
token varchar(32) PRIMARY KEY UNIQUE NOT NULL, token varchar(32) PRIMARY KEY UNIQUE NOT NULL,
refresh_token varchar(64) UNIQUE NOT NULL REFERENCES refresh_tokens(token), refresh_token varchar(64) UNIQUE NOT NULL REFERENCES refresh_tokens(token) ON UPDATE CASCADE ON DELETE CASCADE,
uuid uuid NOT NULL REFERENCES users(uuid), uuid uuid NOT NULL REFERENCES users(uuid),
created int8 NOT NULL created_at int8 NOT NULL
) )
"#, "#,
) )