diff --git a/Cargo.toml b/Cargo.toml index e34d9b6..4e2f58d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ lto = true codegen-units = 1 [dependencies] +actix-cors = "0.7.1" actix-web = "4.10" argon2 = { version = "0.5.3", features = ["std"] } clap = { version = "4.5.37", features = ["derive"] } @@ -21,6 +22,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" simple_logger = "5.0.0" sqlx = { version = "0.8", features = ["runtime-tokio", "tls-native-tls", "postgres"] } +redis = { version = "0.30", features= ["tokio-comp"] } toml = "0.8" url = { version = "2.5", features = ["serde"] } uuid = { version = "1.16", features = ["serde", "v7"] } diff --git a/Dockerfile b/Dockerfile index 7867f8b..d9a0389 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,6 +18,12 @@ RUN useradd --create-home --home-dir /gorb gorb USER gorb -ENV DATABASE_USERNAME="gorb" DATABASE_PASSWORD="gorb" DATABASE="gorb" DATABASE_HOST="localhost" DATABASE_PORT="5432" +ENV DATABASE_USERNAME="gorb" \ +DATABASE_PASSWORD="gorb" \ +DATABASE="gorb" \ +DATABASE_HOST="database" \ +DATABASE_PORT="5432" \ +CACHE_DB_HOST="valkey" \ +CACHE_DB_PORT="6379" ENTRYPOINT ["/usr/bin/entrypoint.sh"] diff --git a/compose.dev.yml b/compose.dev.yml index 02f46a3..d064beb 100644 --- a/compose.dev.yml +++ b/compose.dev.yml @@ -34,3 +34,8 @@ services: - POSTGRES_USER=gorb - POSTGRES_PASSWORD=gorb - POSTGRES_DB=gorb + valkey: + image: valkey/valkey + restart: always + networks: + - gorb diff --git a/compose.yml b/compose.yml index 4544dea..84e6695 100644 --- a/compose.yml +++ b/compose.yml @@ -32,3 +32,8 @@ services: - POSTGRES_USER=gorb - POSTGRES_PASSWORD=gorb - POSTGRES_DB=gorb + valkey: + image: valkey/valkey + restart: always + networks: + - gorb diff --git a/entrypoint.sh b/entrypoint.sh index 63bfa84..a212f8e 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -16,6 +16,10 @@ password = "${DATABASE_PASSWORD}" database = "${DATABASE}" host = "${DATABASE_HOST}" port = ${DATABASE_PORT} + +[cache_database] +host = "${CACHE_DB_HOST}" +port = ${CACHE_DB_PORT} EOF fi diff --git a/src/api/mod.rs b/src/api/mod.rs index 80dc442..b79c824 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,2 +1,11 @@ -pub mod v1; -pub mod versions; +use actix_web::Scope; +use actix_web::web; + +mod v1; +mod versions; + +pub fn web() -> Scope { + web::scope("/api") + .service(v1::web()) + .service(versions::res) +} diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 3be5474..bc6af8c 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -1,17 +1,17 @@ use std::time::{SystemTime, UNIX_EPOCH}; -use actix_web::{Error, HttpResponse, error, post, web}; +use actix_web::{error, post, web, Error, HttpResponse}; use argon2::{PasswordHash, PasswordVerifier}; use futures::StreamExt; use log::error; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use crate::{ - Data, - api::v1::auth::{EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX}, - crypto::{generate_access_token, generate_refresh_token}, + api::v1::auth::{EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX}, utils::{generate_access_token, generate_refresh_token, refresh_token_cookie}, Data }; +use super::Response; + #[derive(Deserialize)] struct LoginInformation { username: String, @@ -19,12 +19,6 @@ struct LoginInformation { device_name: String, } -#[derive(Serialize)] -pub struct Response { - pub access_token: String, - pub refresh_token: String, -} - const MAX_SIZE: usize = 262_144; #[post("/login")] @@ -160,7 +154,7 @@ async fn login( .as_secs() as i64; if let Err(error) = sqlx::query(&format!( - "INSERT INTO refresh_tokens (token, uuid, created, device_name) VALUES ($1, '{}', $2, $3 )", + "INSERT INTO refresh_tokens (token, uuid, created_at, device_name) VALUES ($1, '{}', $2, $3 )", uuid )) .bind(&refresh_token) @@ -174,7 +168,7 @@ async fn login( } if let Err(error) = sqlx::query(&format!( - "INSERT INTO access_tokens (token, refresh_token, uuid, created) VALUES ($1, $2, '{}', $3 )", + "INSERT INTO access_tokens (token, refresh_token, uuid, created_at) VALUES ($1, $2, '{}', $3 )", uuid )) .bind(&access_token) @@ -187,8 +181,7 @@ async fn login( return HttpResponse::InternalServerError().finish() } - HttpResponse::Ok().json(Response { + HttpResponse::Ok().cookie(refresh_token_cookie(refresh_token)).json(Response { access_token, - refresh_token, }) } diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index ff74c6b..25910de 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -7,6 +7,7 @@ use std::{ use actix_web::{HttpResponse, Scope, web}; use log::error; use regex::Regex; +use serde::Serialize; use sqlx::Postgres; use uuid::Uuid; @@ -15,12 +16,16 @@ mod refresh; mod register; mod revoke; +#[derive(Serialize)] +struct Response { + access_token: String, +} + static EMAIL_REGEX: LazyLock = LazyLock::new(|| { Regex::new(r"[-A-Za-z0-9!#$%&'*+/=?^_`{|}~]+(?:\.[-A-Za-z0-9!#$%&'*+/=?^_`{|}~]+)*@(?:[A-Za-z0-9](?:[-A-Za-z0-9]*[A-Za-z0-9])?\.)+[A-Za-z0-9](?:[-A-Za-z0-9]*[A-Za-z0-9])?").unwrap() }); -// FIXME: This regex doesnt seem to be working -static USERNAME_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"[a-zA-Z0-9.-_]").unwrap()); +static USERNAME_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"^[a-z0-9_.-]+$").unwrap()); // Password is expected to be hashed using SHA3-384 static PASSWORD_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"[0-9a-f]{96}").unwrap()); @@ -34,33 +39,36 @@ pub fn web() -> Scope { } pub async fn check_access_token( - access_token: String, + access_token: &str, pool: &sqlx::Pool, ) -> Result { - let row = sqlx::query_as( - "SELECT CAST(uuid as VARCHAR), created FROM access_tokens WHERE token = $1", - ) - .bind(&access_token) - .fetch_one(pool) - .await; + let row = + sqlx::query_as("SELECT CAST(uuid as VARCHAR), created_at FROM access_tokens WHERE token = $1") + .bind(&access_token) + .fetch_one(pool) + .await; if let Err(error) = row { - if error.to_string() == "no rows returned by a query that expected to return at least one row" { - return Err(HttpResponse::Unauthorized().finish()) + if error.to_string() + == "no rows returned by a query that expected to return at least one row" + { + return Err(HttpResponse::Unauthorized().finish()); } error!("{}", error); - return Err(HttpResponse::InternalServerError().json(r#"{ "error": "Unhandled exception occured, contact the server administrator" }"#)) + return Err(HttpResponse::InternalServerError().json( + r#"{ "error": "Unhandled exception occured, contact the server administrator" }"#, + )); } - let (uuid, created): (String, i64) = row.unwrap(); + let (uuid, created_at): (String, i64) = row.unwrap(); let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() as i64; - let lifetime = current_time - created; + let lifetime = current_time - created_at; if lifetime > 3600 { return Err(HttpResponse::Unauthorized().finish()); diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index 5ac2402..008420b 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -1,40 +1,22 @@ -use actix_web::{Error, HttpResponse, error, post, web}; -use futures::StreamExt; +use actix_web::{post, web, Error, HttpRequest, HttpResponse}; use log::error; -use serde::{Deserialize, Serialize}; use std::time::{SystemTime, UNIX_EPOCH}; use crate::{ - Data, - crypto::{generate_access_token, generate_refresh_token}, + utils::{generate_access_token, generate_refresh_token, refresh_token_cookie}, Data }; -#[derive(Deserialize)] -struct RefreshRequest { - refresh_token: String, -} - -#[derive(Serialize)] -struct Response { - refresh_token: String, - access_token: String, -} - -const MAX_SIZE: usize = 262_144; +use super::Response; #[post("/refresh")] -pub async fn res(mut payload: web::Payload, data: web::Data) -> Result { - let mut body = web::BytesMut::new(); - while let Some(chunk) = payload.next().await { - let chunk = chunk?; - // limit max size of in-memory payload - if (body.len() + chunk.len()) > MAX_SIZE { - return Err(error::ErrorBadRequest("overflow")); - } - body.extend_from_slice(&chunk); +pub async fn res(req: HttpRequest, data: web::Data) -> Result { + let recv_refresh_token_cookie = req.cookie("refresh_token"); + + if let None = recv_refresh_token_cookie { + return Ok(HttpResponse::Unauthorized().finish()) } - let refresh_request = serde_json::from_slice::(&body)?; + let mut refresh_token = String::from(recv_refresh_token_cookie.unwrap().value()); let current_time = SystemTime::now() .duration_since(UNIX_EPOCH) @@ -42,33 +24,29 @@ pub async fn res(mut payload: web::Payload, data: web::Data) -> Result 2592000 { if let Err(error) = sqlx::query("DELETE FROM refresh_tokens WHERE token = $1") - .bind(&refresh_request.refresh_token) + .bind(&refresh_token) .execute(&data.pool) .await { error!("{}", error); } - return Ok(HttpResponse::Unauthorized().finish()); + let mut refresh_token_cookie = refresh_token_cookie(refresh_token); + + refresh_token_cookie.make_removal(); + + return Ok(HttpResponse::Unauthorized().cookie(refresh_token_cookie).finish()); } let current_time = SystemTime::now() @@ -76,8 +54,6 @@ pub async fn res(mut payload: web::Payload, data: web::Data) -> Result 1987200 { let new_refresh_token = generate_refresh_token(); @@ -88,7 +64,7 @@ pub async fn res(mut payload: web::Payload, data: web::Data) -> Result) -> Result) -> Result) -> Result) -> Result { diff --git a/src/api/v1/auth/revoke.rs b/src/api/v1/auth/revoke.rs index f3285c4..9ebbb30 100644 --- a/src/api/v1/auth/revoke.rs +++ b/src/api/v1/auth/revoke.rs @@ -1,14 +1,13 @@ -use actix_web::{Error, HttpResponse, error, post, web}; +use actix_web::{Error, HttpRequest, HttpResponse, error, post, web}; use argon2::{PasswordHash, PasswordVerifier}; use futures::{StreamExt, future}; use log::error; use serde::{Deserialize, Serialize}; -use crate::{Data, api::v1::auth::check_access_token}; +use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; #[derive(Deserialize)] struct RevokeRequest { - access_token: String, password: String, device_name: String, } @@ -27,7 +26,19 @@ impl Response { const MAX_SIZE: usize = 262_144; #[post("/revoke")] -pub async fn res(mut payload: web::Payload, data: web::Data) -> Result { +pub async fn res( + req: HttpRequest, + mut payload: web::Payload, + data: web::Data, +) -> Result { + let headers = req.headers(); + + let auth_header = get_auth_header(headers); + + if let Err(error) = auth_header { + return Ok(error); + } + let mut body = web::BytesMut::new(); while let Some(chunk) = payload.next().await { let chunk = chunk?; @@ -40,7 +51,7 @@ pub async fn res(mut payload: web::Payload, data: web::Data) -> Result(&body)?; - let authorized = check_access_token(revoke_request.access_token, &data.pool).await; + let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; if let Err(error) = authorized { return Ok(error); @@ -94,16 +105,9 @@ pub async fn res(mut payload: web::Payload, data: web::Data) -> Result = tokens_raw.unwrap(); - let mut access_tokens_delete = vec![]; let mut refresh_tokens_delete = vec![]; for token in tokens { - access_tokens_delete.push( - sqlx::query("DELETE FROM access_tokens WHERE refresh_token = $1") - .bind(token.clone()) - .execute(&data.pool), - ); - refresh_tokens_delete.push( sqlx::query("DELETE FROM refresh_tokens WHERE token = $1") .bind(token.clone()) @@ -111,29 +115,16 @@ pub async fn res(mut payload: web::Payload, data: web::Data) -> Result> = - results_access_tokens - .iter() - .filter(|r| r.is_err()) - .collect(); - let refresh_tokens_errors: Vec<&Result> = - results_refresh_tokens + let errors: Vec<&Result> = + results .iter() .filter(|r| r.is_err()) .collect(); - if !access_tokens_errors.is_empty() && !refresh_tokens_errors.is_empty() { - error!("{:?}", access_tokens_errors); - error!("{:?}", refresh_tokens_errors); - return Ok(HttpResponse::InternalServerError().finish()); - } else if !access_tokens_errors.is_empty() { - error!("{:?}", access_tokens_errors); - return Ok(HttpResponse::InternalServerError().finish()); - } else if !refresh_tokens_errors.is_empty() { - error!("{:?}", refresh_tokens_errors); + if !errors.is_empty() { + error!("{:?}", errors); return Ok(HttpResponse::InternalServerError().finish()); } diff --git a/src/api/v1/users/me.rs b/src/api/v1/users/me.rs index 18e6ba8..f641678 100644 --- a/src/api/v1/users/me.rs +++ b/src/api/v1/users/me.rs @@ -1,14 +1,8 @@ -use actix_web::{Error, HttpResponse, error, post, web}; -use futures::StreamExt; +use actix_web::{Error, HttpRequest, HttpResponse, get, web}; use log::error; -use serde::{Deserialize, Serialize}; +use serde::Serialize; -use crate::{Data, api::v1::auth::check_access_token}; - -#[derive(Deserialize)] -struct AuthenticationRequest { - access_token: String, -} +use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; #[derive(Serialize)] struct Response { @@ -17,26 +11,17 @@ struct Response { display_name: String, } -const MAX_SIZE: usize = 262_144; +#[get("/me")] +pub async fn res(req: HttpRequest, data: web::Data) -> Result { + let headers = req.headers(); -#[post("/me")] -pub async fn res( - mut payload: web::Payload, - data: web::Data, -) -> Result { - let mut body = web::BytesMut::new(); - while let Some(chunk) = payload.next().await { - let chunk = chunk?; - // limit max size of in-memory payload - if (body.len() + chunk.len()) > MAX_SIZE { - return Err(error::ErrorBadRequest("overflow")); - } - body.extend_from_slice(&chunk); + let auth_header = get_auth_header(headers); + + if let Err(error) = auth_header { + return Ok(error); } - let authentication_request = serde_json::from_slice::(&body)?; - - let authorized = check_access_token(authentication_request.access_token, &data.pool).await; + let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; if let Err(error) = authorized { return Ok(error); diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index 937d857..d7cb1c6 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -1,18 +1,16 @@ -use actix_web::{error, post, web, Error, HttpResponse, Scope}; -use futures::StreamExt; +use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; +use actix_web::{get, web, Error, HttpRequest, HttpResponse, Scope}; use log::error; use serde::{Deserialize, Serialize}; use sqlx::prelude::FromRow; -use crate::{Data, api::v1::auth::check_access_token}; mod me; mod uuid; #[derive(Deserialize)] -struct Request { - access_token: String, - start: i32, - amount: i32, +struct RequestQuery { + start: Option, + amount: Option, } #[derive(Serialize, FromRow)] @@ -23,8 +21,6 @@ struct Response { email: String, } -const MAX_SIZE: usize = 262_144; - pub fn web() -> Scope { web::scope("/users") .service(res) @@ -32,36 +28,33 @@ pub fn web() -> Scope { .service(uuid::res) } -#[post("")] +#[get("")] pub async fn res( - mut payload: web::Payload, + req: HttpRequest, + request_query: web::Query, data: web::Data, ) -> Result { - let mut body = web::BytesMut::new(); - while let Some(chunk) = payload.next().await { - let chunk = chunk?; - // limit max size of in-memory payload - if (body.len() + chunk.len()) > MAX_SIZE { - return Err(error::ErrorBadRequest("overflow")); - } - body.extend_from_slice(&chunk); + let headers = req.headers(); + + let auth_header = get_auth_header(headers); + + let start = request_query.start.unwrap_or(0); + + let amount = request_query.amount.unwrap_or(10); + + if amount > 100 { + return Ok(HttpResponse::BadRequest().finish()); } - let request = serde_json::from_slice::(&body)?; - - if request.amount > 100 { - return Ok(HttpResponse::BadRequest().finish()) - } - - let authorized = check_access_token(request.access_token, &data.pool).await; + let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; if let Err(error) = authorized { return Ok(error); } let row = sqlx::query_as("SELECT CAST(uuid AS VARCHAR), username, display_name, email FROM users ORDER BY username LIMIT $1 OFFSET $2") - .bind(request.amount) - .bind(request.start) + .bind(amount) + .bind(start) .fetch_all(&data.pool) .await; @@ -74,4 +67,3 @@ pub async fn res( Ok(HttpResponse::Ok().json(accounts)) } - diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index 41d87cc..5e4db39 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -1,51 +1,45 @@ -use actix_web::{Error, HttpResponse, error, post, web}; -use futures::StreamExt; +use actix_web::{Error, HttpRequest, HttpResponse, get, web}; use log::error; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use uuid::Uuid; -use crate::{Data, api::v1::auth::check_access_token}; +use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; -#[derive(Deserialize)] -struct AuthenticationRequest { - access_token: String, -} - -#[derive(Serialize)] +#[derive(Serialize, Clone)] struct Response { uuid: String, username: String, display_name: String, } -const MAX_SIZE: usize = 262_144; - -#[post("/{uuid}")] +#[get("/{uuid}")] pub async fn res( - mut payload: web::Payload, + req: HttpRequest, path: web::Path<(Uuid,)>, data: web::Data, ) -> Result { - let mut body = web::BytesMut::new(); - while let Some(chunk) = payload.next().await { - let chunk = chunk?; - // limit max size of in-memory payload - if (body.len() + chunk.len()) > MAX_SIZE { - return Err(error::ErrorBadRequest("overflow")); - } - body.extend_from_slice(&chunk); - } + let headers = req.headers(); let uuid = path.into_inner().0; - let authentication_request = serde_json::from_slice::(&body)?; + let auth_header = get_auth_header(headers); - let authorized = check_access_token(authentication_request.access_token, &data.pool).await; + if let Err(error) = auth_header { + return Ok(error); + } + + let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; if let Err(error) = authorized { return Ok(error); } + let cache_result = data.get_cache_key(uuid.to_string()).await; + + if let Ok(cache_hit) = cache_result { + return Ok(HttpResponse::Ok().content_type("application/json").body(cache_hit)) + } + let row = sqlx::query_as(&format!( "SELECT username, display_name FROM users WHERE uuid = '{}'", uuid @@ -60,9 +54,18 @@ pub async fn res( let (username, display_name): (String, Option) = row.unwrap(); - Ok(HttpResponse::Ok().json(Response { + let user = Response { uuid: uuid.to_string(), username, display_name: display_name.unwrap_or_default(), - })) + }; + + let cache_result = data.set_cache_key(uuid.to_string(), user.clone(), 1800).await; + + if let Err(error) = cache_result { + error!("{}", error); + return Ok(HttpResponse::InternalServerError().finish()); + } + + Ok(HttpResponse::Ok().json(user)) } diff --git a/src/config.rs b/src/config.rs index a2a6192..65a5965 100644 --- a/src/config.rs +++ b/src/config.rs @@ -7,6 +7,7 @@ use tokio::fs::read_to_string; #[derive(Debug, Deserialize)] pub struct ConfigBuilder { database: Database, + cache_database: CacheDatabase, web: Option, } @@ -19,6 +20,15 @@ pub struct Database { port: u16, } +#[derive(Debug, Deserialize, Clone)] +pub struct CacheDatabase { + username: Option, + password: Option, + host: String, + database: Option, + port: u16, +} + #[derive(Debug, Deserialize)] struct WebBuilder { url: Option, @@ -51,6 +61,7 @@ impl ConfigBuilder { Config { database: self.database, + cache_database: self.cache_database, web, } } @@ -59,6 +70,7 @@ impl ConfigBuilder { #[derive(Debug, Clone)] pub struct Config { pub database: Database, + pub cache_database: CacheDatabase, pub web: Web, } @@ -78,3 +90,33 @@ impl Database { .port(self.port) } } + +impl CacheDatabase { + pub fn url(&self) -> String { + let mut url = String::from("redis://"); + + if let Some(username) = &self.username { + url += username; + } + + if let Some(password) = &self.password { + url += ":"; + url += password; + } + + if self.username.is_some() || self.password.is_some() { + url += "@"; + } + + url += &self.host; + url += ":"; + url += &self.port.to_string(); + + if let Some(database) = &self.database { + url += "/"; + url += database; + } + + url + } +} diff --git a/src/crypto.rs b/src/crypto.rs deleted file mode 100644 index c4d96c8..0000000 --- a/src/crypto.rs +++ /dev/null @@ -1,14 +0,0 @@ -use getrandom::fill; -use hex::encode; - -pub fn generate_access_token() -> Result { - let mut buf = [0u8; 16]; - fill(&mut buf)?; - Ok(encode(buf)) -} - -pub fn generate_refresh_token() -> Result { - let mut buf = [0u8; 32]; - fill(&mut buf)?; - Ok(encode(buf)) -} diff --git a/src/main.rs b/src/main.rs index 4c909b1..7f21e2e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,4 @@ +use actix_cors::Cors; use actix_web::{App, HttpServer, web}; use argon2::Argon2; use clap::Parser; @@ -7,7 +8,8 @@ use std::time::SystemTime; mod config; use config::{Config, ConfigBuilder}; mod api; -pub mod crypto; + +pub mod utils; type Error = Box; @@ -21,6 +23,7 @@ struct Args { #[derive(Clone)] struct Data { pub pool: Pool, + pub cache_pool: redis::Client, pub _config: Config, pub argon2: Argon2<'static>, pub start_time: SystemTime, @@ -42,6 +45,8 @@ async fn main() -> Result<(), Error> { let pool = PgPool::connect_with(config.database.connect_options()).await?; + let cache_pool = redis::Client::open(config.cache_database.url())?; + /* TODO: Figure out if a table should be used here and if not then what. Also figure out if these should be different types from what they currently are and if we should add more "constraints" @@ -63,14 +68,14 @@ async fn main() -> Result<(), Error> { CREATE TABLE IF NOT EXISTS refresh_tokens ( token varchar(64) PRIMARY KEY UNIQUE NOT NULL, uuid uuid NOT NULL REFERENCES users(uuid), - created int8 NOT NULL, + created_at int8 NOT NULL, device_name varchar(16) NOT NULL ); CREATE TABLE IF NOT EXISTS access_tokens ( token varchar(32) PRIMARY KEY UNIQUE NOT NULL, - refresh_token varchar(64) UNIQUE NOT NULL REFERENCES refresh_tokens(token), + refresh_token varchar(64) UNIQUE NOT NULL REFERENCES refresh_tokens(token) ON UPDATE CASCADE ON DELETE CASCADE, uuid uuid NOT NULL REFERENCES users(uuid), - created int8 NOT NULL + created_at int8 NOT NULL ) "#, ) @@ -79,17 +84,46 @@ async fn main() -> Result<(), Error> { let data = Data { pool, + cache_pool, _config: config, // TODO: Possibly implement "pepper" into this (thinking it could generate one if it doesnt exist and store it on disk) argon2: Argon2::default(), start_time: SystemTime::now(), }; + HttpServer::new(move || { + // Set CORS headers + let cors = Cors::default() + /* + Set Allowed-Control-Allow-Origin header to whatever + the request's Origin header is. Must be done like this + rather than setting it to "*" due to CORS not allowing + sending of credentials (cookies) with wildcard origin. + */ + .allowed_origin_fn(|_origin, _req_head| { + true + }) + /* + Allows any request method in CORS preflight requests. + This will be restricted to only ones actually in use later. + */ + .allow_any_method() + /* + Allows any header(s) in request in CORS preflight requests. + This wll be restricted to only ones actually in use later. + */ + .allow_any_header() + /* + Allows browser to include cookies in requests. + This is needed for receiving the secure HttpOnly refresh_token cookie. + */ + .supports_credentials(); + App::new() .app_data(web::Data::new(data.clone())) - .service(api::versions::res) - .service(api::v1::web()) + .wrap(cors) + .service(api::web()) }) .bind((web.url, web.port))? .run() diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..15e5e2e --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,74 @@ +use actix_web::{cookie::{time::Duration, Cookie, SameSite}, http::header::HeaderMap, HttpResponse}; +use getrandom::fill; +use hex::encode; +use redis::RedisError; +use serde::Serialize; + +use crate::Data; + +pub fn get_auth_header(headers: &HeaderMap) -> Result<&str, HttpResponse> { + let auth_token = headers.get(actix_web::http::header::AUTHORIZATION); + + if let None = auth_token { + return Err(HttpResponse::Unauthorized().finish()); + } + + let auth = auth_token.unwrap().to_str(); + + if let Err(error) = auth { + return Err(HttpResponse::Unauthorized().json(format!(r#" {{ "error": "{}" }} "#, error))); + } + + let auth_value = auth.unwrap().split_whitespace().nth(1); + + if let None = auth_value { + return Err(HttpResponse::BadRequest().finish()); + } + + Ok(auth_value.unwrap()) +} + +pub fn refresh_token_cookie(refresh_token: String) -> Cookie<'static> { + Cookie::build("refresh_token", refresh_token) + .http_only(true) + .secure(true) + .same_site(SameSite::None) + .path("/api") + .max_age(Duration::days(30)) + .finish() +} + +pub fn generate_access_token() -> Result { + let mut buf = [0u8; 16]; + fill(&mut buf)?; + Ok(encode(buf)) +} + +pub fn generate_refresh_token() -> Result { + let mut buf = [0u8; 32]; + fill(&mut buf)?; + Ok(encode(buf)) +} + +impl Data { + pub async fn set_cache_key(&self, key: String, value: impl Serialize, expire: u32) -> Result<(), RedisError> { + let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?; + + let key_encoded = encode(key); + + let value_json = serde_json::to_string(&value).unwrap(); + + redis::cmd("SET",).arg(&[key_encoded.clone(), value_json]).exec_async(&mut conn).await?; + + redis::cmd("EXPIRE").arg(&[key_encoded, expire.to_string()]).exec_async(&mut conn).await + } + + pub async fn get_cache_key(&self, key: String) -> Result { + let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?; + + let key_encoded = encode(key); + + redis::cmd("GET").arg(key_encoded).query_async(&mut conn).await + } +} +