From 27fbb6508e445b3eecd21dc9bc445b1eff72134d Mon Sep 17 00:00:00 2001 From: Radical Date: Wed, 21 May 2025 20:47:45 +0200 Subject: [PATCH 01/17] build: switch sqlx to diesel --- Cargo.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 33d01e7..e6dcd84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,6 @@ regex = "1.11" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" simple_logger = "5.0.0" -sqlx = { version = "0.8", features = ["runtime-tokio", "tls-native-tls", "postgres"] } redis = { version = "0.31.0", features= ["tokio-comp"] } tokio-tungstenite = { version = "0.26", features = ["native-tls", "url"] } toml = "0.8" @@ -30,6 +29,9 @@ uuid = { version = "1.16", features = ["serde", "v7"] } random-string = "1.1" actix-ws = "0.3.0" futures-util = "0.3.31" +deadpool = "0.12" +diesel = "2.2" +diesel-async = { version = "0.5", features = ["deadpool", "postgres"] } [dependencies.tokio] version = "1.44" -- 2.47.2 From b9c7bda2b15ea19754d046655fb92f5a6972f8e2 Mon Sep 17 00:00:00 2001 From: Radical Date: Wed, 21 May 2025 20:48:09 +0200 Subject: [PATCH 02/17] feat: use diesel in main fn and data struct --- src/main.rs | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/main.rs b/src/main.rs index fbad594..9036665 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,14 +3,19 @@ use actix_web::{App, HttpServer, web}; use argon2::Argon2; use clap::Parser; use simple_logger::SimpleLogger; -use sqlx::{PgPool, Pool, Postgres}; +use diesel_async::pooled_connection::AsyncDieselConnectionManager; +use diesel_async::pooled_connection::deadpool::Pool; +use diesel_async::RunQueryDsl; use std::time::SystemTime; mod config; use config::{Config, ConfigBuilder}; mod api; +type Conn = deadpool::managed::Object>; + pub mod structs; pub mod utils; +pub mod tables; type Error = Box; @@ -23,7 +28,7 @@ struct Args { #[derive(Clone)] pub struct Data { - pub pool: Pool, + pub pool: deadpool::managed::Pool, Conn>, pub cache_pool: redis::Client, pub _config: Config, pub argon2: Argon2<'static>, @@ -44,17 +49,21 @@ async fn main() -> Result<(), Error> { let web = config.web.clone(); - let pool = PgPool::connect_with(config.database.connect_options()).await?; + // create a new connection pool with the default config + let pool_config = AsyncDieselConnectionManager::::new(config.database.url()); + let pool = Pool::builder(pool_config).build()?; let cache_pool = redis::Client::open(config.cache_database.url())?; + let mut conn = pool.get().await?; + /* TODO: Figure out if a table should be used here and if not then what. Also figure out if these should be different types from what they currently are and if we should add more "constraints" TODO: References to time should be removed in favor of using the timestamp built in to UUIDv7 (apart from deleted_at in users) */ - sqlx::raw_sql( + diesel::sql_query( r#" CREATE TABLE IF NOT EXISTS users ( uuid uuid PRIMARY KEY NOT NULL, @@ -141,7 +150,7 @@ async fn main() -> Result<(), Error> { ); "#, ) - .execute(&pool) + .execute(&mut conn) .await?; /* -- 2.47.2 From 746949f0e54284f907ce414392bd62dc482a453e Mon Sep 17 00:00:00 2001 From: Radical Date: Wed, 21 May 2025 20:48:43 +0200 Subject: [PATCH 03/17] feat: use url format --- src/config.rs | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/config.rs b/src/config.rs index 65a5965..4e8fc21 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,7 +1,6 @@ use crate::Error; use log::debug; use serde::Deserialize; -use sqlx::postgres::PgConnectOptions; use tokio::fs::read_to_string; #[derive(Debug, Deserialize)] @@ -81,13 +80,24 @@ pub struct Web { } impl Database { - pub fn connect_options(&self) -> PgConnectOptions { - PgConnectOptions::new() - .database(&self.database) - .host(&self.host) - .username(&self.username) - .password(&self.password) - .port(self.port) + pub fn url(&self) -> String { + let mut url = String::from("postgres://"); + + url += &self.username; + + url += ":"; + url += &self.password; + + url += "@"; + + url += &self.host; + url += ":"; + url += &self.port.to_string(); + + url += "/"; + url += &self.database; + + url } } -- 2.47.2 From da804cd43637150379df28aabfaad2185628d66d Mon Sep 17 00:00:00 2001 From: Radical Date: Wed, 21 May 2025 20:49:13 +0200 Subject: [PATCH 04/17] feat: use diesel on Channel and ChannelPermission structs --- src/structs.rs | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/structs.rs b/src/structs.rs index 1b339b1..7cec7c9 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -1,14 +1,15 @@ use std::str::FromStr; use actix_web::HttpResponse; +use diesel::Selectable; use log::error; use serde::{Deserialize, Serialize}; -use sqlx::{Pool, Postgres, prelude::FromRow}; use uuid::Uuid; -use crate::Data; +use crate::{Conn, Data, tables::*}; -#[derive(Serialize, Deserialize, Clone)] +#[derive(Serialize, Deserialize, Clone, Selectable)] +#[diesel(table_name = channels)] pub struct Channel { pub uuid: Uuid, pub guild_uuid: Uuid, @@ -17,7 +18,7 @@ pub struct Channel { pub permissions: Vec, } -#[derive(Serialize, Clone, FromRow)] +#[derive(Serialize, Clone)] struct ChannelPermissionBuilder { role_uuid: String, permissions: i32, @@ -32,7 +33,8 @@ impl ChannelPermissionBuilder { } } -#[derive(Serialize, Deserialize, Clone, FromRow)] +#[derive(Serialize, Deserialize, Clone, Selectable)] +#[diesel(table_name = channel_permissions)] pub struct ChannelPermission { pub role_uuid: Uuid, pub permissions: i32, @@ -40,15 +42,10 @@ pub struct ChannelPermission { impl Channel { pub async fn fetch_all( - pool: &Pool, + conn: &mut Conn, guild_uuid: Uuid, ) -> Result, HttpResponse> { - let row = sqlx::query_as(&format!( - "SELECT CAST(uuid AS VARCHAR), name, description FROM channels WHERE guild_uuid = '{}'", - guild_uuid - )) - .fetch_all(pool) - .await; + if let Err(error) = row { error!("{}", error); -- 2.47.2 From f1d5b4316eeccac7be3ceee0d5d8f2e57d2cdf9d Mon Sep 17 00:00:00 2001 From: Radical Date: Wed, 21 May 2025 20:49:20 +0200 Subject: [PATCH 05/17] feat: add tables.rs --- src/tables.rs | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 src/tables.rs diff --git a/src/tables.rs b/src/tables.rs new file mode 100644 index 0000000..3dbd38b --- /dev/null +++ b/src/tables.rs @@ -0,0 +1,109 @@ +use diesel::table; + +table! { + users (uuid) { + uuid -> Uuid, + username -> VarChar, + display_name -> Nullable, + password -> VarChar, + email -> VarChar, + email_verified -> Bool, + is_deleted -> Bool, + deleted_at -> Int8, + } +} + +table! { + instance_permissions (uuid) { + uuid -> Uuid, + administrator -> Bool, + } +} + +table! { + refresh_tokens (token) { + token -> VarChar, + uuid -> Uuid, + created_at -> Int8, + device_name -> VarChar, + } +} + +table! { + access_tokens (token) { + token -> VarChar, + refresh_token -> VarChar, + uuid -> Uuid, + created_at -> Int8 + } +} + +table! { + guilds (uuid) { + uuid -> Uuid, + owner_uuid -> Uuid, + name -> VarChar, + description -> VarChar + } +} + +table! { + guild_members (uuid) { + uuid -> Uuid, + guild_uuid -> Uuid, + user_uuid -> Uuid, + nickname -> VarChar, + } +} + +table! { + roles (uuid, guild_uuid) { + uuid -> Uuid, + guild_uuid -> Uuid, + name -> VarChar, + color -> Int4, + position -> Int4, + permissions -> Int8, + } +} + +table! { + role_members (role_uuid, member_uuid) { + role_uuid -> Uuid, + member_uuid -> Uuid, + } +} + +table! { + channels (uuid) { + uuid -> Uuid, + guild_uuid -> Uuid, + name -> VarChar, + description -> VarChar, + } +} + +table! { + channel_permissions (channel_uuid, role_uuid) { + channel_uuid -> Uuid, + role_uuid -> Uuid, + permissions -> Int8, + } +} + +table! { + messages (uuid) { + uuid -> Uuid, + channel_uuid -> Uuid, + user_uuid -> Uuid, + message -> VarChar, + } +} + +table! { + invites (id) { + id -> VarChar, + guild_uuid -> Uuid, + user_uuid -> Uuid, + } +} -- 2.47.2 From a6d35b0ba2c29e99fcf445bcde038f80c1b3e0a4 Mon Sep 17 00:00:00 2001 From: Radical Date: Wed, 21 May 2025 21:49:01 +0200 Subject: [PATCH 06/17] feat: use diesel-cli instead of hand writing tables after reading the documentation, crazy right? I figured out i was making my life hard, this makes my life easy again --- diesel.toml | 9 + migrations/.keep | 0 .../down.sql | 6 + .../up.sql | 36 ++++ .../2025-05-21-192435_create_users/down.sql | 4 + .../2025-05-21-192435_create_users/up.sql | 20 +++ .../down.sql | 2 + .../up.sql | 5 + .../2025-05-21-193321_create_tokens/down.sql | 3 + .../2025-05-21-193321_create_tokens/up.sql | 13 ++ .../2025-05-21-193500_create_guilds/down.sql | 3 + .../2025-05-21-193500_create_guilds/up.sql | 13 ++ .../2025-05-21-193620_create_roles/down.sql | 3 + .../2025-05-21-193620_create_roles/up.sql | 15 ++ .../down.sql | 3 + .../2025-05-21-193745_create_channels/up.sql | 13 ++ .../down.sql | 2 + .../2025-05-21-193954_create_messages/up.sql | 7 + .../2025-05-21-194207_create_invites/down.sql | 2 + .../2025-05-21-194207_create_invites/up.sql | 6 + src/main.rs | 97 +---------- src/schema.rs | 156 ++++++++++++++++++ src/structs.rs | 2 +- src/tables.rs | 109 ------------ 24 files changed, 323 insertions(+), 206 deletions(-) create mode 100644 diesel.toml create mode 100644 migrations/.keep create mode 100644 migrations/00000000000000_diesel_initial_setup/down.sql create mode 100644 migrations/00000000000000_diesel_initial_setup/up.sql create mode 100644 migrations/2025-05-21-192435_create_users/down.sql create mode 100644 migrations/2025-05-21-192435_create_users/up.sql create mode 100644 migrations/2025-05-21-192936_create_instance_permissions/down.sql create mode 100644 migrations/2025-05-21-192936_create_instance_permissions/up.sql create mode 100644 migrations/2025-05-21-193321_create_tokens/down.sql create mode 100644 migrations/2025-05-21-193321_create_tokens/up.sql create mode 100644 migrations/2025-05-21-193500_create_guilds/down.sql create mode 100644 migrations/2025-05-21-193500_create_guilds/up.sql create mode 100644 migrations/2025-05-21-193620_create_roles/down.sql create mode 100644 migrations/2025-05-21-193620_create_roles/up.sql create mode 100644 migrations/2025-05-21-193745_create_channels/down.sql create mode 100644 migrations/2025-05-21-193745_create_channels/up.sql create mode 100644 migrations/2025-05-21-193954_create_messages/down.sql create mode 100644 migrations/2025-05-21-193954_create_messages/up.sql create mode 100644 migrations/2025-05-21-194207_create_invites/down.sql create mode 100644 migrations/2025-05-21-194207_create_invites/up.sql create mode 100644 src/schema.rs delete mode 100644 src/tables.rs diff --git a/diesel.toml b/diesel.toml new file mode 100644 index 0000000..a0d61bf --- /dev/null +++ b/diesel.toml @@ -0,0 +1,9 @@ +# For documentation on how to configure this file, +# see https://diesel.rs/guides/configuring-diesel-cli + +[print_schema] +file = "src/schema.rs" +custom_type_derives = ["diesel::query_builder::QueryId", "Clone"] + +[migrations_directory] +dir = "migrations" diff --git a/migrations/.keep b/migrations/.keep new file mode 100644 index 0000000..e69de29 diff --git a/migrations/00000000000000_diesel_initial_setup/down.sql b/migrations/00000000000000_diesel_initial_setup/down.sql new file mode 100644 index 0000000..a9f5260 --- /dev/null +++ b/migrations/00000000000000_diesel_initial_setup/down.sql @@ -0,0 +1,6 @@ +-- This file was automatically created by Diesel to setup helper functions +-- and other internal bookkeeping. This file is safe to edit, any future +-- changes will be added to existing projects as new migrations. + +DROP FUNCTION IF EXISTS diesel_manage_updated_at(_tbl regclass); +DROP FUNCTION IF EXISTS diesel_set_updated_at(); diff --git a/migrations/00000000000000_diesel_initial_setup/up.sql b/migrations/00000000000000_diesel_initial_setup/up.sql new file mode 100644 index 0000000..d68895b --- /dev/null +++ b/migrations/00000000000000_diesel_initial_setup/up.sql @@ -0,0 +1,36 @@ +-- This file was automatically created by Diesel to setup helper functions +-- and other internal bookkeeping. This file is safe to edit, any future +-- changes will be added to existing projects as new migrations. + + + + +-- Sets up a trigger for the given table to automatically set a column called +-- `updated_at` whenever the row is modified (unless `updated_at` was included +-- in the modified columns) +-- +-- # Example +-- +-- ```sql +-- CREATE TABLE users (id SERIAL PRIMARY KEY, updated_at TIMESTAMP NOT NULL DEFAULT NOW()); +-- +-- SELECT diesel_manage_updated_at('users'); +-- ``` +CREATE OR REPLACE FUNCTION diesel_manage_updated_at(_tbl regclass) RETURNS VOID AS $$ +BEGIN + EXECUTE format('CREATE TRIGGER set_updated_at BEFORE UPDATE ON %s + FOR EACH ROW EXECUTE PROCEDURE diesel_set_updated_at()', _tbl); +END; +$$ LANGUAGE plpgsql; + +CREATE OR REPLACE FUNCTION diesel_set_updated_at() RETURNS trigger AS $$ +BEGIN + IF ( + NEW IS DISTINCT FROM OLD AND + NEW.updated_at IS NOT DISTINCT FROM OLD.updated_at + ) THEN + NEW.updated_at := current_timestamp; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; diff --git a/migrations/2025-05-21-192435_create_users/down.sql b/migrations/2025-05-21-192435_create_users/down.sql new file mode 100644 index 0000000..a54826f --- /dev/null +++ b/migrations/2025-05-21-192435_create_users/down.sql @@ -0,0 +1,4 @@ +-- This file should undo anything in `up.sql` +DROP INDEX idx_unique_username_active; +DROP INDEX idx_unique_email_active; +DROP TABLE users; diff --git a/migrations/2025-05-21-192435_create_users/up.sql b/migrations/2025-05-21-192435_create_users/up.sql new file mode 100644 index 0000000..0262507 --- /dev/null +++ b/migrations/2025-05-21-192435_create_users/up.sql @@ -0,0 +1,20 @@ +-- Your SQL goes here +CREATE TABLE users ( + uuid uuid PRIMARY KEY NOT NULL, + username varchar(32) NOT NULL, + display_name varchar(64) DEFAULT NULL, + password varchar(512) NOT NULL, + email varchar(100) NOT NULL, + email_verified boolean NOT NULL DEFAULT FALSE, + is_deleted boolean NOT NULL DEFAULT FALSE, + deleted_at int8 DEFAULT NULL, + CONSTRAINT unique_username_active UNIQUE NULLS NOT DISTINCT (username, is_deleted), + CONSTRAINT unique_email_active UNIQUE NULLS NOT DISTINCT (email, is_deleted) +); + +CREATE UNIQUE INDEX idx_unique_username_active +ON users(username) +WHERE is_deleted = FALSE; +CREATE UNIQUE INDEX idx_unique_email_active +ON users(email) +WHERE is_deleted = FALSE; diff --git a/migrations/2025-05-21-192936_create_instance_permissions/down.sql b/migrations/2025-05-21-192936_create_instance_permissions/down.sql new file mode 100644 index 0000000..c72fb0f --- /dev/null +++ b/migrations/2025-05-21-192936_create_instance_permissions/down.sql @@ -0,0 +1,2 @@ +-- This file should undo anything in `up.sql` +DROP TABLE instance_permissions; diff --git a/migrations/2025-05-21-192936_create_instance_permissions/up.sql b/migrations/2025-05-21-192936_create_instance_permissions/up.sql new file mode 100644 index 0000000..f3dd755 --- /dev/null +++ b/migrations/2025-05-21-192936_create_instance_permissions/up.sql @@ -0,0 +1,5 @@ +-- Your SQL goes here +CREATE TABLE instance_permissions ( + uuid uuid PRIMARY KEY NOT NULL REFERENCES users(uuid), + administrator boolean NOT NULL DEFAULT FALSE +); diff --git a/migrations/2025-05-21-193321_create_tokens/down.sql b/migrations/2025-05-21-193321_create_tokens/down.sql new file mode 100644 index 0000000..4555fe6 --- /dev/null +++ b/migrations/2025-05-21-193321_create_tokens/down.sql @@ -0,0 +1,3 @@ +-- This file should undo anything in `up.sql` +DROP TABLE access_tokens; +DROP TABLE refresh_tokens; diff --git a/migrations/2025-05-21-193321_create_tokens/up.sql b/migrations/2025-05-21-193321_create_tokens/up.sql new file mode 100644 index 0000000..b3fb554 --- /dev/null +++ b/migrations/2025-05-21-193321_create_tokens/up.sql @@ -0,0 +1,13 @@ +-- Your SQL goes here +CREATE TABLE refresh_tokens ( + token varchar(64) PRIMARY KEY UNIQUE NOT NULL, + uuid uuid NOT NULL REFERENCES users(uuid), + created_at int8 NOT NULL, + device_name varchar(16) NOT NULL +); +CREATE TABLE access_tokens ( + token varchar(32) PRIMARY KEY UNIQUE NOT NULL, + refresh_token varchar(64) UNIQUE NOT NULL REFERENCES refresh_tokens(token) ON UPDATE CASCADE ON DELETE CASCADE, + uuid uuid NOT NULL REFERENCES users(uuid), + created_at int8 NOT NULL +); diff --git a/migrations/2025-05-21-193500_create_guilds/down.sql b/migrations/2025-05-21-193500_create_guilds/down.sql new file mode 100644 index 0000000..12ae87e --- /dev/null +++ b/migrations/2025-05-21-193500_create_guilds/down.sql @@ -0,0 +1,3 @@ +-- This file should undo anything in `up.sql` +DROP TABLE guild_members; +DROP TABLE guilds; diff --git a/migrations/2025-05-21-193500_create_guilds/up.sql b/migrations/2025-05-21-193500_create_guilds/up.sql new file mode 100644 index 0000000..268c597 --- /dev/null +++ b/migrations/2025-05-21-193500_create_guilds/up.sql @@ -0,0 +1,13 @@ +-- Your SQL goes here +CREATE TABLE guilds ( + uuid uuid PRIMARY KEY NOT NULL, + owner_uuid uuid NOT NULL REFERENCES users(uuid), + name VARCHAR(100) NOT NULL, + description VARCHAR(300) +); +CREATE TABLE guild_members ( + uuid uuid PRIMARY KEY NOT NULL, + guild_uuid uuid NOT NULL REFERENCES guilds(uuid) ON DELETE CASCADE, + user_uuid uuid NOT NULL REFERENCES users(uuid), + nickname VARCHAR(100) DEFAULT NULL +); diff --git a/migrations/2025-05-21-193620_create_roles/down.sql b/migrations/2025-05-21-193620_create_roles/down.sql new file mode 100644 index 0000000..f215a04 --- /dev/null +++ b/migrations/2025-05-21-193620_create_roles/down.sql @@ -0,0 +1,3 @@ +-- This file should undo anything in `up.sql` +DROP TABLE role_members; +DROP TABLE roles; diff --git a/migrations/2025-05-21-193620_create_roles/up.sql b/migrations/2025-05-21-193620_create_roles/up.sql new file mode 100644 index 0000000..55d051d --- /dev/null +++ b/migrations/2025-05-21-193620_create_roles/up.sql @@ -0,0 +1,15 @@ +-- Your SQL goes here +CREATE TABLE roles ( + uuid uuid UNIQUE NOT NULL, + guild_uuid uuid NOT NULL REFERENCES guilds(uuid) ON DELETE CASCADE, + name VARCHAR(50) NOT NULL, + color int NOT NULL DEFAULT 16777215, + position int NOT NULL, + permissions int8 NOT NULL DEFAULT 0, + PRIMARY KEY (uuid, guild_uuid) +); +CREATE TABLE role_members ( + role_uuid uuid NOT NULL REFERENCES roles(uuid) ON DELETE CASCADE, + member_uuid uuid NOT NULL REFERENCES guild_members(uuid) ON DELETE CASCADE, + PRIMARY KEY (role_uuid, member_uuid) +); diff --git a/migrations/2025-05-21-193745_create_channels/down.sql b/migrations/2025-05-21-193745_create_channels/down.sql new file mode 100644 index 0000000..6334604 --- /dev/null +++ b/migrations/2025-05-21-193745_create_channels/down.sql @@ -0,0 +1,3 @@ +-- This file should undo anything in `up.sql` +DROP TABLE channel_permissions; +DROP TABLE channels; diff --git a/migrations/2025-05-21-193745_create_channels/up.sql b/migrations/2025-05-21-193745_create_channels/up.sql new file mode 100644 index 0000000..2cce7f2 --- /dev/null +++ b/migrations/2025-05-21-193745_create_channels/up.sql @@ -0,0 +1,13 @@ +-- Your SQL goes here +CREATE TABLE channels ( + uuid uuid PRIMARY KEY NOT NULL, + guild_uuid uuid NOT NULL REFERENCES guilds(uuid) ON DELETE CASCADE, + name varchar(32) NOT NULL, + description varchar(500) NOT NULL +); +CREATE TABLE channel_permissions ( + channel_uuid uuid NOT NULL REFERENCES channels(uuid) ON DELETE CASCADE, + role_uuid uuid NOT NULL REFERENCES roles(uuid) ON DELETE CASCADE, + permissions int8 NOT NULL DEFAULT 0, + PRIMARY KEY (channel_uuid, role_uuid) +); diff --git a/migrations/2025-05-21-193954_create_messages/down.sql b/migrations/2025-05-21-193954_create_messages/down.sql new file mode 100644 index 0000000..bb9ce09 --- /dev/null +++ b/migrations/2025-05-21-193954_create_messages/down.sql @@ -0,0 +1,2 @@ +-- This file should undo anything in `up.sql` +DROP TABLE messages; diff --git a/migrations/2025-05-21-193954_create_messages/up.sql b/migrations/2025-05-21-193954_create_messages/up.sql new file mode 100644 index 0000000..1510974 --- /dev/null +++ b/migrations/2025-05-21-193954_create_messages/up.sql @@ -0,0 +1,7 @@ +-- Your SQL goes here +CREATE TABLE messages ( + uuid uuid PRIMARY KEY NOT NULL, + channel_uuid uuid NOT NULL REFERENCES channels(uuid) ON DELETE CASCADE, + user_uuid uuid NOT NULL REFERENCES users(uuid), + message varchar(4000) NOT NULL +); diff --git a/migrations/2025-05-21-194207_create_invites/down.sql b/migrations/2025-05-21-194207_create_invites/down.sql new file mode 100644 index 0000000..03b72de --- /dev/null +++ b/migrations/2025-05-21-194207_create_invites/down.sql @@ -0,0 +1,2 @@ +-- This file should undo anything in `up.sql` +DROP TABLE invites; diff --git a/migrations/2025-05-21-194207_create_invites/up.sql b/migrations/2025-05-21-194207_create_invites/up.sql new file mode 100644 index 0000000..795b39c --- /dev/null +++ b/migrations/2025-05-21-194207_create_invites/up.sql @@ -0,0 +1,6 @@ +-- Your SQL goes here +CREATE TABLE invites ( + id varchar(32) PRIMARY KEY NOT NULL, + guild_uuid uuid NOT NULL REFERENCES guilds(uuid) ON DELETE CASCADE, + user_uuid uuid NOT NULL REFERENCES users(uuid) +); diff --git a/src/main.rs b/src/main.rs index 9036665..0a9d493 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,7 +15,7 @@ type Conn = deadpool::managed::Object; @@ -57,101 +57,6 @@ async fn main() -> Result<(), Error> { let mut conn = pool.get().await?; - /* - TODO: Figure out if a table should be used here and if not then what. - Also figure out if these should be different types from what they currently are and if we should add more "constraints" - - TODO: References to time should be removed in favor of using the timestamp built in to UUIDv7 (apart from deleted_at in users) - */ - diesel::sql_query( - r#" - CREATE TABLE IF NOT EXISTS users ( - uuid uuid PRIMARY KEY NOT NULL, - username varchar(32) NOT NULL, - display_name varchar(64) DEFAULT NULL, - password varchar(512) NOT NULL, - email varchar(100) NOT NULL, - email_verified boolean NOT NULL DEFAULT FALSE, - is_deleted boolean NOT NULL DEFAULT FALSE, - deleted_at int8 DEFAULT NULL, - CONSTRAINT unique_username_active UNIQUE NULLS NOT DISTINCT (username, is_deleted), - CONSTRAINT unique_email_active UNIQUE NULLS NOT DISTINCT (email, is_deleted) - ); - CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_username_active - ON users(username) - WHERE is_deleted = FALSE; - CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_email_active - ON users(email) - WHERE is_deleted = FALSE; - CREATE TABLE IF NOT EXISTS instance_permissions ( - uuid uuid NOT NULL REFERENCES users(uuid), - administrator boolean NOT NULL DEFAULT FALSE - ); - CREATE TABLE IF NOT EXISTS refresh_tokens ( - token varchar(64) PRIMARY KEY UNIQUE NOT NULL, - uuid uuid NOT NULL REFERENCES users(uuid), - created_at int8 NOT NULL, - device_name varchar(16) NOT NULL - ); - CREATE TABLE IF NOT EXISTS access_tokens ( - token varchar(32) PRIMARY KEY UNIQUE NOT NULL, - refresh_token varchar(64) UNIQUE NOT NULL REFERENCES refresh_tokens(token) ON UPDATE CASCADE ON DELETE CASCADE, - uuid uuid NOT NULL REFERENCES users(uuid), - created_at int8 NOT NULL - ); - CREATE TABLE IF NOT EXISTS guilds ( - uuid uuid PRIMARY KEY NOT NULL, - owner_uuid uuid NOT NULL REFERENCES users(uuid), - name VARCHAR(100) NOT NULL, - description VARCHAR(300) - ); - CREATE TABLE IF NOT EXISTS guild_members ( - uuid uuid PRIMARY KEY NOT NULL, - guild_uuid uuid NOT NULL REFERENCES guilds(uuid) ON DELETE CASCADE, - user_uuid uuid NOT NULL REFERENCES users(uuid), - nickname VARCHAR(100) DEFAULT NULL - ); - CREATE TABLE IF NOT EXISTS roles ( - uuid uuid UNIQUE NOT NULL, - guild_uuid uuid NOT NULL REFERENCES guilds(uuid) ON DELETE CASCADE, - name VARCHAR(50) NOT NULL, - color int NOT NULL DEFAULT 16777215, - position int NOT NULL, - permissions int8 NOT NULL DEFAULT 0, - PRIMARY KEY (uuid, guild_uuid) - ); - CREATE TABLE IF NOT EXISTS role_members ( - role_uuid uuid NOT NULL REFERENCES roles(uuid) ON DELETE CASCADE, - member_uuid uuid NOT NULL REFERENCES guild_members(uuid) ON DELETE CASCADE, - PRIMARY KEY (role_uuid, member_uuid) - ); - CREATE TABLE IF NOT EXISTS channels ( - uuid uuid PRIMARY KEY NOT NULL, - guild_uuid uuid NOT NULL REFERENCES guilds(uuid) ON DELETE CASCADE, - name varchar(32) NOT NULL, - description varchar(500) NOT NULL - ); - CREATE TABLE IF NOT EXISTS channel_permissions ( - channel_uuid uuid NOT NULL REFERENCES channels(uuid) ON DELETE CASCADE, - role_uuid uuid NOT NULL REFERENCES roles(uuid) ON DELETE CASCADE, - permissions int8 NOT NULL DEFAULT 0, - PRIMARY KEY (channel_uuid, role_uuid) - ); - CREATE TABLE IF NOT EXISTS messages ( - uuid uuid PRIMARY KEY NOT NULL, - channel_uuid uuid NOT NULL REFERENCES channels(uuid) ON DELETE CASCADE, - user_uuid uuid NOT NULL REFERENCES users(uuid), - message varchar(4000) NOT NULL - ); - CREATE TABLE IF NOT EXISTS invites ( - id varchar(32) PRIMARY KEY NOT NULL, - guild_uuid uuid NOT NULL REFERENCES guilds(uuid) ON DELETE CASCADE, - user_uuid uuid NOT NULL REFERENCES users(uuid) - ); - "#, - ) - .execute(&mut conn) - .await?; /* **Stored for later possible use** diff --git a/src/schema.rs b/src/schema.rs new file mode 100644 index 0000000..f83018c --- /dev/null +++ b/src/schema.rs @@ -0,0 +1,156 @@ +// @generated automatically by Diesel CLI. + +diesel::table! { + access_tokens (token) { + #[max_length = 32] + token -> Varchar, + #[max_length = 64] + refresh_token -> Varchar, + uuid -> Uuid, + created_at -> Int8, + } +} + +diesel::table! { + channel_permissions (channel_uuid, role_uuid) { + channel_uuid -> Uuid, + role_uuid -> Uuid, + permissions -> Int8, + } +} + +diesel::table! { + channels (uuid) { + uuid -> Uuid, + guild_uuid -> Uuid, + #[max_length = 32] + name -> Varchar, + #[max_length = 500] + description -> Varchar, + } +} + +diesel::table! { + guild_members (uuid) { + uuid -> Uuid, + guild_uuid -> Uuid, + user_uuid -> Uuid, + #[max_length = 100] + nickname -> Nullable, + } +} + +diesel::table! { + guilds (uuid) { + uuid -> Uuid, + owner_uuid -> Uuid, + #[max_length = 100] + name -> Varchar, + #[max_length = 300] + description -> Nullable, + } +} + +diesel::table! { + instance_permissions (uuid) { + uuid -> Uuid, + administrator -> Bool, + } +} + +diesel::table! { + invites (id) { + #[max_length = 32] + id -> Varchar, + guild_uuid -> Uuid, + user_uuid -> Uuid, + } +} + +diesel::table! { + messages (uuid) { + uuid -> Uuid, + channel_uuid -> Uuid, + user_uuid -> Uuid, + #[max_length = 4000] + message -> Varchar, + } +} + +diesel::table! { + refresh_tokens (token) { + #[max_length = 64] + token -> Varchar, + uuid -> Uuid, + created_at -> Int8, + #[max_length = 16] + device_name -> Varchar, + } +} + +diesel::table! { + role_members (role_uuid, member_uuid) { + role_uuid -> Uuid, + member_uuid -> Uuid, + } +} + +diesel::table! { + roles (uuid, guild_uuid) { + uuid -> Uuid, + guild_uuid -> Uuid, + #[max_length = 50] + name -> Varchar, + color -> Int4, + position -> Int4, + permissions -> Int8, + } +} + +diesel::table! { + users (uuid) { + uuid -> Uuid, + #[max_length = 32] + username -> Varchar, + #[max_length = 64] + display_name -> Nullable, + #[max_length = 512] + password -> Varchar, + #[max_length = 100] + email -> Varchar, + email_verified -> Bool, + is_deleted -> Bool, + deleted_at -> Nullable, + } +} + +diesel::joinable!(access_tokens -> refresh_tokens (refresh_token)); +diesel::joinable!(access_tokens -> users (uuid)); +diesel::joinable!(channel_permissions -> channels (channel_uuid)); +diesel::joinable!(channels -> guilds (guild_uuid)); +diesel::joinable!(guild_members -> guilds (guild_uuid)); +diesel::joinable!(guild_members -> users (user_uuid)); +diesel::joinable!(guilds -> users (owner_uuid)); +diesel::joinable!(instance_permissions -> users (uuid)); +diesel::joinable!(invites -> guilds (guild_uuid)); +diesel::joinable!(invites -> users (user_uuid)); +diesel::joinable!(messages -> channels (channel_uuid)); +diesel::joinable!(messages -> users (user_uuid)); +diesel::joinable!(refresh_tokens -> users (uuid)); +diesel::joinable!(role_members -> guild_members (member_uuid)); +diesel::joinable!(roles -> guilds (guild_uuid)); + +diesel::allow_tables_to_appear_in_same_query!( + access_tokens, + channel_permissions, + channels, + guild_members, + guilds, + instance_permissions, + invites, + messages, + refresh_tokens, + role_members, + roles, + users, +); diff --git a/src/structs.rs b/src/structs.rs index 7cec7c9..b9fd471 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -6,7 +6,7 @@ use log::error; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{Conn, Data, tables::*}; +use crate::{Conn, Data, schema::*}; #[derive(Serialize, Deserialize, Clone, Selectable)] #[diesel(table_name = channels)] diff --git a/src/tables.rs b/src/tables.rs deleted file mode 100644 index 3dbd38b..0000000 --- a/src/tables.rs +++ /dev/null @@ -1,109 +0,0 @@ -use diesel::table; - -table! { - users (uuid) { - uuid -> Uuid, - username -> VarChar, - display_name -> Nullable, - password -> VarChar, - email -> VarChar, - email_verified -> Bool, - is_deleted -> Bool, - deleted_at -> Int8, - } -} - -table! { - instance_permissions (uuid) { - uuid -> Uuid, - administrator -> Bool, - } -} - -table! { - refresh_tokens (token) { - token -> VarChar, - uuid -> Uuid, - created_at -> Int8, - device_name -> VarChar, - } -} - -table! { - access_tokens (token) { - token -> VarChar, - refresh_token -> VarChar, - uuid -> Uuid, - created_at -> Int8 - } -} - -table! { - guilds (uuid) { - uuid -> Uuid, - owner_uuid -> Uuid, - name -> VarChar, - description -> VarChar - } -} - -table! { - guild_members (uuid) { - uuid -> Uuid, - guild_uuid -> Uuid, - user_uuid -> Uuid, - nickname -> VarChar, - } -} - -table! { - roles (uuid, guild_uuid) { - uuid -> Uuid, - guild_uuid -> Uuid, - name -> VarChar, - color -> Int4, - position -> Int4, - permissions -> Int8, - } -} - -table! { - role_members (role_uuid, member_uuid) { - role_uuid -> Uuid, - member_uuid -> Uuid, - } -} - -table! { - channels (uuid) { - uuid -> Uuid, - guild_uuid -> Uuid, - name -> VarChar, - description -> VarChar, - } -} - -table! { - channel_permissions (channel_uuid, role_uuid) { - channel_uuid -> Uuid, - role_uuid -> Uuid, - permissions -> Int8, - } -} - -table! { - messages (uuid) { - uuid -> Uuid, - channel_uuid -> Uuid, - user_uuid -> Uuid, - message -> VarChar, - } -} - -table! { - invites (id) { - id -> VarChar, - guild_uuid -> Uuid, - user_uuid -> Uuid, - } -} -- 2.47.2 From 2e1382c1d41261c7d82c70a843537777f06fff2c Mon Sep 17 00:00:00 2001 From: Radical Date: Thu, 22 May 2025 16:28:58 +0200 Subject: [PATCH 07/17] feat: make channel description nullable --- .../2025-05-21-203022_channel_description_nullable/down.sql | 4 ++++ .../2025-05-21-203022_channel_description_nullable/up.sql | 3 +++ src/schema.rs | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) create mode 100644 migrations/2025-05-21-203022_channel_description_nullable/down.sql create mode 100644 migrations/2025-05-21-203022_channel_description_nullable/up.sql diff --git a/migrations/2025-05-21-203022_channel_description_nullable/down.sql b/migrations/2025-05-21-203022_channel_description_nullable/down.sql new file mode 100644 index 0000000..73344b1 --- /dev/null +++ b/migrations/2025-05-21-203022_channel_description_nullable/down.sql @@ -0,0 +1,4 @@ +-- This file should undo anything in `up.sql` +UPDATE channels SET description = '' WHERE description IS NULL; +ALTER TABLE ONLY channels ALTER COLUMN description SET NOT NULL; +ALTER TABLE ONLY channels ALTER COLUMN description DROP DEFAULT; diff --git a/migrations/2025-05-21-203022_channel_description_nullable/up.sql b/migrations/2025-05-21-203022_channel_description_nullable/up.sql new file mode 100644 index 0000000..5ca6776 --- /dev/null +++ b/migrations/2025-05-21-203022_channel_description_nullable/up.sql @@ -0,0 +1,3 @@ +-- Your SQL goes here +ALTER TABLE ONLY channels ALTER COLUMN description DROP NOT NULL; +ALTER TABLE ONLY channels ALTER COLUMN description SET DEFAULT NULL; diff --git a/src/schema.rs b/src/schema.rs index f83018c..b3274fc 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -26,7 +26,7 @@ diesel::table! { #[max_length = 32] name -> Varchar, #[max_length = 500] - description -> Varchar, + description -> Nullable, } } -- 2.47.2 From c1885210fbb39b272fb1f25da215c1b0ae4c4536 Mon Sep 17 00:00:00 2001 From: Radical Date: Thu, 22 May 2025 16:29:57 +0200 Subject: [PATCH 08/17] feat: include migrations in binary Lets us change the schema and not worry about instance admins having to manually update their DB! --- Cargo.toml | 5 +++-- build.rs | 3 +++ src/main.rs | 18 +++++++++++++++--- 3 files changed, 21 insertions(+), 5 deletions(-) create mode 100644 build.rs diff --git a/Cargo.toml b/Cargo.toml index e6dcd84..9a91b30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,8 +30,9 @@ random-string = "1.1" actix-ws = "0.3.0" futures-util = "0.3.31" deadpool = "0.12" -diesel = "2.2" -diesel-async = { version = "0.5", features = ["deadpool", "postgres"] } +diesel = { version = "2.2", features = ["uuid"] } +diesel-async = { version = "0.5", features = ["deadpool", "postgres", "async-connection-wrapper"] } +diesel_migrations = { version = "2.2.0", features = ["postgres"] } [dependencies.tokio] version = "1.44" diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..284ad12 --- /dev/null +++ b/build.rs @@ -0,0 +1,3 @@ +fn main() { + println!("cargo:rerun-if-changed=migrations"); +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 0a9d493..10da4f1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,14 +5,16 @@ use clap::Parser; use simple_logger::SimpleLogger; use diesel_async::pooled_connection::AsyncDieselConnectionManager; use diesel_async::pooled_connection::deadpool::Pool; -use diesel_async::RunQueryDsl; use std::time::SystemTime; mod config; use config::{Config, ConfigBuilder}; -mod api; +use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; + +pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); type Conn = deadpool::managed::Object>; +mod api; pub mod structs; pub mod utils; pub mod schema; @@ -55,8 +57,18 @@ async fn main() -> Result<(), Error> { let cache_pool = redis::Client::open(config.cache_database.url())?; - let mut conn = pool.get().await?; + let database_url = config.database.url(); + tokio::task::spawn_blocking(move || { + use diesel::prelude::Connection; + use diesel_async::async_connection_wrapper::AsyncConnectionWrapper; + + + let mut conn = AsyncConnectionWrapper::::establish(&database_url)?; + + conn.run_pending_migrations(MIGRATIONS); + Ok::<_, Box>(()) + }).await?; /* **Stored for later possible use** -- 2.47.2 From 73ceea63b6f4be1b270775ada27479b09e84e017 Mon Sep 17 00:00:00 2001 From: Radical Date: Thu, 22 May 2025 16:31:38 +0200 Subject: [PATCH 09/17] feat: refactor structs.rs to diesel! --- src/structs.rs | 744 +++++++++++++++++++++---------------------------- 1 file changed, 320 insertions(+), 424 deletions(-) diff --git a/src/structs.rs b/src/structs.rs index b9fd471..547a852 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -1,15 +1,42 @@ -use std::str::FromStr; - use actix_web::HttpResponse; -use diesel::Selectable; +use diesel::{delete, insert_into, prelude::{Insertable, Queryable}, ExpressionMethods, QueryDsl, Selectable, SelectableHelper}; use log::error; use serde::{Deserialize, Serialize}; use uuid::Uuid; +use diesel_async::{pooled_connection::AsyncDieselConnectionManager, RunQueryDsl}; use crate::{Conn, Data, schema::*}; -#[derive(Serialize, Deserialize, Clone, Selectable)] +#[derive(Queryable, Selectable, Insertable, Clone)] #[diesel(table_name = channels)] +#[diesel(check_for_backend(diesel::pg::Pg))] +struct ChannelBuilder { + uuid: Uuid, + guild_uuid: Uuid, + name: String, + description: Option, +} + +impl ChannelBuilder { + async fn build(self, conn: &mut Conn) -> Result { + use self::channel_permissions::dsl::*; + let channel_permission: Vec = channel_permissions + .filter(channel_uuid.eq(self.uuid)) + .select((role_uuid, permissions)) + .load(conn) + .await?; + + Ok(Channel { + uuid: self.uuid, + guild_uuid: self.guild_uuid, + name: self.name, + description: self.description, + permissions: channel_permission, + }) + } +} + +#[derive(Serialize, Deserialize, Clone)] pub struct Channel { pub uuid: Uuid, pub guild_uuid: Uuid, @@ -18,116 +45,81 @@ pub struct Channel { pub permissions: Vec, } -#[derive(Serialize, Clone)] -struct ChannelPermissionBuilder { - role_uuid: String, - permissions: i32, -} - -impl ChannelPermissionBuilder { - fn build(&self) -> ChannelPermission { - ChannelPermission { - role_uuid: Uuid::from_str(&self.role_uuid).unwrap(), - permissions: self.permissions, - } - } -} - -#[derive(Serialize, Deserialize, Clone, Selectable)] +#[derive(Serialize, Deserialize, Clone, Queryable)] #[diesel(table_name = channel_permissions)] +#[diesel(check_for_backend(diesel::pg::Pg))] pub struct ChannelPermission { pub role_uuid: Uuid, - pub permissions: i32, + pub permissions: i64, } impl Channel { pub async fn fetch_all( - conn: &mut Conn, + pool: &deadpool::managed::Pool, Conn>, guild_uuid: Uuid, ) -> Result, HttpResponse> { - + let mut conn = pool.get().await.unwrap(); - if let Err(error) = row { + use channels::dsl; + let channel_builders_result: Result, diesel::result::Error> = dsl::channels + .filter(dsl::guild_uuid.eq(guild_uuid)) + .select(ChannelBuilder::as_select()) + .load(&mut conn) + .await; + + if let Err(error) = channel_builders_result { error!("{}", error); return Err(HttpResponse::InternalServerError().finish()); } - let channels: Vec<(String, String, Option)> = row.unwrap(); + let channel_builders = channel_builders_result.unwrap(); - let futures = channels.iter().map(async |t| { - let (uuid, name, description) = t.to_owned(); - - let row = sqlx::query_as(&format!("SELECT CAST(role_uuid AS VARCHAR), permissions FROM channel_permissions WHERE channel_uuid = '{}'", uuid)) - .fetch_all(pool) - .await; - - if let Err(error) = row { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()) - } - - let channel_permission_builders: Vec = row.unwrap(); - - Ok(Self { - uuid: Uuid::from_str(&uuid).unwrap(), - guild_uuid, - name, - description, - permissions: channel_permission_builders.iter().map(|b| b.build()).collect(), - }) + let channel_futures = channel_builders.iter().map(async move |c| { + let mut conn = pool.get().await?; + c.clone().build(&mut conn).await }); - let channels = futures::future::join_all(futures).await; + + let channels = futures::future::try_join_all(channel_futures).await; - let channels: Result, HttpResponse> = channels.into_iter().collect(); + if let Err(error) = channels { + error!("{}", error); - channels + return Err(HttpResponse::InternalServerError().finish()) + } + + Ok(channels.unwrap()) } pub async fn fetch_one( - pool: &Pool, - guild_uuid: Uuid, + conn: &mut Conn, channel_uuid: Uuid, ) -> Result { - let row = sqlx::query_as(&format!( - "SELECT name, description FROM channels WHERE guild_uuid = '{}' AND uuid = '{}'", - guild_uuid, channel_uuid - )) - .fetch_one(pool) - .await; - - if let Err(error) = row { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()); - } - - let (name, description): (String, Option) = row.unwrap(); - - let row = sqlx::query_as(&format!("SELECT CAST(role_uuid AS VARCHAR), permissions FROM channel_permissions WHERE channel_uuid = '{}'", channel_uuid)) - .fetch_all(pool) + use channels::dsl; + let channel_builder_result: Result = dsl::channels + .filter(dsl::uuid.eq(channel_uuid)) + .select(ChannelBuilder::as_select()) + .get_result(conn) .await; - if let Err(error) = row { + if let Err(error) = channel_builder_result { error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); + return Err(HttpResponse::InternalServerError().finish()) } - let channel_permission_builders: Vec = row.unwrap(); + let channel_builder = channel_builder_result.unwrap(); - Ok(Self { - uuid: channel_uuid, - guild_uuid, - name, - description, - permissions: channel_permission_builders - .iter() - .map(|b| b.build()) - .collect(), - }) + let channel = channel_builder.build(conn).await; + + if let Err(error) = channel { + error!("{}", error); + + return Err(HttpResponse::InternalServerError().finish()) + } + + Ok(channel.unwrap()) } pub async fn new( @@ -136,19 +128,28 @@ impl Channel { name: String, description: Option, ) -> Result { + let mut conn = data.pool.get().await.unwrap(); + let channel_uuid = Uuid::now_v7(); - let row = sqlx::query(&format!("INSERT INTO channels (uuid, guild_uuid, name, description) VALUES ('{}', '{}', $1, $2)", channel_uuid, guild_uuid)) - .bind(&name) - .bind(&description) - .execute(&data.pool) + let new_channel = ChannelBuilder { + uuid: channel_uuid, + guild_uuid: guild_uuid, + name: name.clone(), + description: description.clone(), + }; + + let insert_result = insert_into(channels::table) + .values(new_channel) + .execute(&mut conn) .await; - if let Err(error) = row { + if let Err(error) = insert_result { error!("{}", error); return Err(HttpResponse::InternalServerError().finish()); } + // returns different object because there's no reason to build the channelbuilder (wastes 1 database request) let channel = Self { uuid: channel_uuid, guild_uuid, @@ -176,13 +177,12 @@ impl Channel { Ok(channel) } - pub async fn delete(self, pool: &Pool) -> Result<(), HttpResponse> { - let result = sqlx::query(&format!( - "DELETE FROM channels WHERE channel_uuid = '{}'", - self.uuid - )) - .execute(pool) - .await; + pub async fn delete(self, conn: &mut Conn) -> Result<(), HttpResponse> { + use channels::dsl; + let result = delete(channels::table) + .filter(dsl::uuid.eq(self.uuid)) + .execute(conn) + .await; if let Err(error) = result { error!("{}", error); @@ -195,50 +195,53 @@ impl Channel { pub async fn fetch_messages( &self, - pool: &Pool, + conn: &mut Conn, amount: i64, offset: i64, ) -> Result, HttpResponse> { - let row = sqlx::query_as(&format!("SELECT CAST(uuid AS VARCHAR), CAST(user_uuid AS VARCHAR), CAST(channel_uuid AS VARCHAR), message FROM messages WHERE channel_uuid = '{}' ORDER BY uuid DESC LIMIT $1 OFFSET $2", self.uuid)) - .bind(amount) - .bind(offset) - .fetch_all(pool) + use messages::dsl; + let messages: Result, diesel::result::Error> = dsl::messages + .filter(dsl::channel_uuid.eq(self.uuid)) + .select(Message::as_select()) + .limit(amount) + .offset(offset) + .load(conn) .await; - if let Err(error) = row { + if let Err(error) = messages { error!("{}", error); return Err(HttpResponse::InternalServerError().finish()); } - let message_builders: Vec = row.unwrap(); - - Ok(message_builders.iter().map(|b| b.build()).collect()) + Ok(messages.unwrap()) } pub async fn new_message( &self, - pool: &Pool, + conn: &mut Conn, user_uuid: Uuid, message: String, ) -> Result { let message_uuid = Uuid::now_v7(); - let row = sqlx::query(&format!("INSERT INTO messages (uuid, channel_uuid, user_uuid, message) VALUES ('{}', '{}', '{}', $1)", message_uuid, self.uuid, user_uuid)) - .bind(&message) - .execute(pool) - .await; - - if let Err(error) = row { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } - - Ok(Message { + let message = Message { uuid: message_uuid, channel_uuid: self.uuid, user_uuid, message, - }) + }; + + let insert_result = insert_into(messages::table) + .values(message.clone()) + .execute(conn) + .await; + + if let Err(error) = insert_result { + error!("{}", error); + return Err(HttpResponse::InternalServerError().finish()); + } + + Ok(message) } } @@ -280,6 +283,34 @@ impl Permissions { } } +#[derive(Serialize, Queryable, Selectable, Insertable, Clone)] +#[diesel(table_name = guilds)] +#[diesel(check_for_backend(diesel::pg::Pg))] +struct GuildBuilder { + uuid: Uuid, + name: String, + description: Option, + owner_uuid: Uuid, +} + +impl GuildBuilder { + async fn build(self, conn: &mut Conn) -> Result { + let member_count = Member::count(conn, self.uuid).await?; + + let roles = Role::fetch_all(conn, self.uuid).await?; + + Ok(Guild { + uuid: self.uuid, + name: self.name, + description: self.description, + icon: String::from("bogus"), + owner_uuid: self.owner_uuid, + roles: roles, + member_count: member_count, + }) + } +} + #[derive(Serialize)] pub struct Guild { pub uuid: Uuid, @@ -292,85 +323,50 @@ pub struct Guild { } impl Guild { - pub async fn fetch_one(pool: &Pool, guild_uuid: Uuid) -> Result { - let row = sqlx::query_as(&format!( - "SELECT CAST(owner_uuid AS VARCHAR), name, description FROM guilds WHERE uuid = '{}'", - guild_uuid - )) - .fetch_one(pool) - .await; + pub async fn fetch_one(conn: &mut Conn, guild_uuid: Uuid) -> Result { + use guilds::dsl; + let guild_builder: Result = dsl::guilds + .filter(dsl::uuid.eq(guild_uuid)) + .select(GuildBuilder::as_select()) + .get_result(conn) + .await; - if let Err(error) = row { + if let Err(error) = guild_builder { error!("{}", error); return Err(HttpResponse::InternalServerError().finish()); } - let (owner_uuid_raw, name, description): (String, String, Option) = row.unwrap(); + let guild = guild_builder.unwrap().build(conn).await?; - let owner_uuid = Uuid::from_str(&owner_uuid_raw).unwrap(); - - let member_count = Member::count(pool, guild_uuid).await?; - - let roles = Role::fetch_all(pool, guild_uuid).await?; - - Ok(Self { - uuid: guild_uuid, - name, - description, - // FIXME: This isnt supposed to be bogus - icon: String::from("bogus"), - owner_uuid, - roles, - member_count, - }) + Ok(guild) } pub async fn fetch_amount( - pool: &Pool, - start: i32, - amount: i32, + pool: &deadpool::managed::Pool, Conn>, + offset: i64, + amount: i64, ) -> Result, HttpResponse> { // Fetch guild data from database - let rows = sqlx::query_as::<_, (String, String, String, Option)>( - "SELECT CAST(uuid AS VARCHAR), CAST(owner_uuid AS VARCHAR), name, description - FROM guilds - ORDER BY name - LIMIT $1 OFFSET $2", - ) - .bind(amount) - .bind(start) - .fetch_all(pool) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + let mut conn = pool.get().await.unwrap(); + + use guilds::dsl; + let guild_builders: Vec = dsl::guilds + .select(GuildBuilder::as_select()) + .order_by(dsl::uuid) + .offset(offset) + .limit(amount) + .load(&mut conn) + .await + .map_err(|error| { + error!("{}", error); + HttpResponse::InternalServerError().finish() + })?; // Process each guild concurrently - let guild_futures = rows.into_iter().map(|(guild_uuid_raw, owner_uuid_raw, name, description)| async move { - let uuid = Uuid::from_str(&guild_uuid_raw).map_err(|_| { - HttpResponse::BadRequest().body("Invalid guild UUID format") - })?; - - let owner_uuid = Uuid::from_str(&owner_uuid_raw).map_err(|_| { - HttpResponse::BadRequest().body("Invalid owner UUID format") - })?; - - let (member_count, roles) = tokio::try_join!( - Member::count(pool, uuid), - Role::fetch_all(pool, uuid) - )?; - - Ok::(Self { - uuid, - name, - description, - icon: String::from("bogus"), // FIXME: Replace with actual icon handling - owner_uuid, - roles, - member_count, - }) + let guild_futures = guild_builders.iter().map(async move |g| { + let mut conn = pool.get().await.unwrap(); + g.clone().build(&mut conn).await }); // Execute all futures concurrently and collect results @@ -378,49 +374,28 @@ impl Guild { } pub async fn new( - pool: &Pool, + conn: &mut Conn, name: String, description: Option, owner_uuid: Uuid, ) -> Result { let guild_uuid = Uuid::now_v7(); - let row = sqlx::query(&format!( - "INSERT INTO guilds (uuid, owner_uuid, name, description) VALUES ('{}', '{}', $1, $2)", - guild_uuid, owner_uuid - )) - .bind(&name) - .bind(&description) - .execute(pool) - .await; + let guild_builder = GuildBuilder { + uuid: guild_uuid, + name: name.clone(), + description: description.clone(), + owner_uuid, + }; - if let Err(error) = row { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } - - let row = sqlx::query(&format!( - "INSERT INTO guild_members (uuid, guild_uuid, user_uuid) VALUES ('{}', '{}', '{}')", - Uuid::now_v7(), - guild_uuid, - owner_uuid - )) - .execute(pool) - .await; - - if let Err(error) = row { - error!("{}", error); - - let row = sqlx::query(&format!("DELETE FROM guilds WHERE uuid = '{}'", guild_uuid)) - .execute(pool) - .await; - - if let Err(error) = row { + insert_into(guilds::table) + .values(guild_builder) + .execute(conn) + .await + .map_err(|error| { error!("{}", error); - } - - return Err(HttpResponse::InternalServerError().finish()); - } + HttpResponse::InternalServerError().finish() + })?; Ok(Guild { uuid: guild_uuid, @@ -433,168 +408,116 @@ impl Guild { }) } - pub async fn get_invites(&self, pool: &Pool) -> Result, HttpResponse> { - let invites = sqlx::query_as(&format!( - "SELECT (id, guild_uuid, user_uuid) FROM invites WHERE guild_uuid = '{}'", - self.uuid - )) - .fetch_all(pool) - .await; + pub async fn get_invites(&self, conn: &mut Conn) -> Result, HttpResponse> { + use invites::dsl; + let invites = dsl::invites + .filter(dsl::guild_uuid.eq(self.uuid)) + .select(Invite::as_select()) + .load(conn) + .await + .map_err(|error| { + error!("{}", error); + HttpResponse::InternalServerError().finish() + })?; - if let Err(error) = invites { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } - - Ok(invites - .unwrap() - .iter() - .map(|b: &InviteBuilder| b.build()) - .collect()) + Ok(invites) } pub async fn create_invite( &self, - pool: &Pool, + conn: &mut Conn, member: &Member, custom_id: Option, ) -> Result { let invite_id; - if custom_id.is_none() { - let charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; - - invite_id = random_string::generate(8, charset); - } else { - invite_id = custom_id.unwrap(); + if let Some(id) = custom_id { + invite_id = id; if invite_id.len() > 32 { return Err(HttpResponse::BadRequest().finish()); } + } else { + let charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + + invite_id = random_string::generate(8, charset); } - let result = sqlx::query(&format!( - "INSERT INTO invites (id, guild_uuid, user_uuid) VALUES ($1, '{}', '{}'", - self.uuid, member.user_uuid - )) - .bind(&invite_id) - .execute(pool) - .await; - - if let Err(error) = result { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } - - Ok(Invite { + let invite = Invite { id: invite_id, user_uuid: member.user_uuid, guild_uuid: self.uuid, - }) + }; + + insert_into(invites::table) + .values(invite.clone()) + .execute(conn) + .await + .map_err(|error| { + error!("{}", error); + HttpResponse::InternalServerError().finish() + })?; + + Ok(invite) } } -#[derive(FromRow)] -struct RoleBuilder { - uuid: String, - guild_uuid: String, - name: String, - color: i64, - position: i32, - permissions: i64, -} - -impl RoleBuilder { - fn build(&self) -> Role { - Role { - uuid: Uuid::from_str(&self.uuid).unwrap(), - guild_uuid: Uuid::from_str(&self.guild_uuid).unwrap(), - name: self.name.clone(), - color: self.color, - position: self.position, - permissions: self.permissions, - } - } -} - -#[derive(Serialize, Clone)] +#[derive(Serialize, Clone, Queryable, Selectable, Insertable)] +#[diesel(table_name = roles)] +#[diesel(check_for_backend(diesel::pg::Pg))] pub struct Role { uuid: Uuid, guild_uuid: Uuid, name: String, - color: i64, + color: i32, position: i32, permissions: i64, } impl Role { pub async fn fetch_all( - pool: &Pool, + conn: &mut Conn, guild_uuid: Uuid, ) -> Result, HttpResponse> { - let role_builders_result = sqlx::query_as(&format!("SELECT (uuid, guild_uuid, name, color, position, permissions) FROM roles WHERE guild_uuid = '{}'", guild_uuid)) - .fetch_all(pool) - .await; + use roles::dsl; + let roles: Vec = dsl::roles + .filter(dsl::guild_uuid.eq(guild_uuid)) + .select(Role::as_select()) + .load(conn) + .await + .map_err(|error| { + error!("{}", error); + HttpResponse::InternalServerError().finish() + })?; - if let Err(error) = role_builders_result { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()); - } - - let role_builders: Vec = role_builders_result.unwrap(); - - Ok(role_builders.iter().map(|b| b.build()).collect()) + Ok(roles) } pub async fn fetch_one( - pool: &Pool, + conn: &mut Conn, role_uuid: Uuid, - guild_uuid: Uuid, ) -> Result { - let row = sqlx::query_as(&format!("SELECT (name, color, position, permissions) FROM roles WHERE guild_uuid = '{}' AND uuid = '{}'", guild_uuid, role_uuid)) - .fetch_one(pool) - .await; + use roles::dsl; + let role: Role = dsl::roles + .filter(dsl::uuid.eq(role_uuid)) + .select(Role::as_select()) + .get_result(conn) + .await + .map_err(|error| { + error!("{}", error); + HttpResponse::InternalServerError().finish() + })?; - if let Err(error) = row { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()); - } - - let (name, color, position, permissions) = row.unwrap(); - - Ok(Role { - uuid: role_uuid, - guild_uuid, - name, - color, - position, - permissions, - }) + Ok(role) } pub async fn new( - pool: &Pool, + conn: &mut Conn, guild_uuid: Uuid, name: String, ) -> Result { let role_uuid = Uuid::now_v7(); - let row = sqlx::query(&format!( - "INSERT INTO channels (uuid, guild_uuid, name, position) VALUES ('{}', '{}', $1, $2)", - role_uuid, guild_uuid - )) - .bind(&name) - .bind(0) - .execute(pool) - .await; - - if let Err(error) = row { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } - - let role = Self { + let role = Role { uuid: role_uuid, guild_uuid, name, @@ -603,10 +526,22 @@ impl Role { permissions: 0, }; + insert_into(roles::table) + .values(role.clone()) + .execute(conn) + .await + .map_err(|error| { + error!("{}", error); + HttpResponse::InternalServerError().finish() + })?; + Ok(role) } } +#[derive(Queryable, Selectable, Insertable)] +#[diesel(table_name = guild_members)] +#[diesel(check_for_backend(diesel::pg::Pg))] pub struct Member { pub uuid: Uuid, pub nickname: Option, @@ -615,67 +550,63 @@ pub struct Member { } impl Member { - async fn count(pool: &Pool, guild_uuid: Uuid) -> Result { - let member_count = sqlx::query_scalar(&format!( - "SELECT COUNT(uuid) FROM guild_members WHERE guild_uuid = '{}'", - guild_uuid - )) - .fetch_one(pool) - .await; + async fn count(conn: &mut Conn, guild_uuid: Uuid) -> Result { + use guild_members::dsl; + let count: i64 = dsl::guild_members + .filter(dsl::guild_uuid.eq(guild_uuid)) + .count() + .get_result(conn) + .await + .map_err(|error| { + error!("{}", error); + HttpResponse::InternalServerError() + })?; - if let Err(error) = member_count { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()); - } - - Ok(member_count.unwrap()) + Ok(count) } pub async fn fetch_one( - pool: &Pool, + conn: &mut Conn, user_uuid: Uuid, guild_uuid: Uuid, ) -> Result { - let row = sqlx::query_as(&format!("SELECT CAST(uuid AS VARCHAR), nickname FROM guild_members WHERE guild_uuid = '{}' AND user_uuid = '{}'", guild_uuid, user_uuid)) - .fetch_one(pool) - .await; - - if let Err(error) = row { + use guild_members::dsl; + let member: Member = dsl::guild_members + .filter(dsl::user_uuid.eq(user_uuid)) + .filter(dsl::guild_uuid.eq(guild_uuid)) + .select(Member::as_select()) + .get_result(conn) + .await + .map_err(|error| { error!("{}", error); + HttpResponse::InternalServerError().finish() + })?; - return Err(HttpResponse::InternalServerError().finish()); - } - - let (uuid, nickname): (String, Option) = row.unwrap(); - - Ok(Self { - uuid: Uuid::from_str(&uuid).unwrap(), - nickname, - user_uuid, - guild_uuid, - }) + Ok(member) } pub async fn new( - pool: &Pool, + conn: &mut Conn, user_uuid: Uuid, guild_uuid: Uuid, ) -> Result { let member_uuid = Uuid::now_v7(); - let row = sqlx::query(&format!( - "INSERT INTO guild_members uuid, guild_uuid, user_uuid VALUES ('{}', '{}', '{}')", - member_uuid, guild_uuid, user_uuid - )) - .execute(pool) - .await; + let member = Member { + uuid: member_uuid, + guild_uuid, + user_uuid, + nickname: None, + }; - if let Err(error) = row { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()); - } + insert_into(guild_members::table) + .values(member) + .execute(conn) + .await + .map_err(|error| { + error!("{}", error); + HttpResponse::InternalServerError().finish() + })?; Ok(Self { uuid: member_uuid, @@ -686,26 +617,9 @@ impl Member { } } -#[derive(FromRow)] -struct MessageBuilder { - uuid: String, - channel_uuid: String, - user_uuid: String, - message: String, -} - -impl MessageBuilder { - fn build(&self) -> Message { - Message { - uuid: Uuid::from_str(&self.uuid).unwrap(), - channel_uuid: Uuid::from_str(&self.channel_uuid).unwrap(), - user_uuid: Uuid::from_str(&self.user_uuid).unwrap(), - message: self.message.clone(), - } - } -} - -#[derive(Serialize)] +#[derive(Clone, Serialize, Queryable, Selectable, Insertable)] +#[diesel(table_name = messages)] +#[diesel(check_for_backend(diesel::pg::Pg))] pub struct Message { uuid: Uuid, channel_uuid: Uuid, @@ -713,25 +627,8 @@ pub struct Message { message: String, } -#[derive(FromRow)] -pub struct InviteBuilder { - id: String, - user_uuid: String, - guild_uuid: String, -} - -impl InviteBuilder { - fn build(&self) -> Invite { - Invite { - id: self.id.clone(), - user_uuid: Uuid::from_str(&self.user_uuid).unwrap(), - guild_uuid: Uuid::from_str(&self.guild_uuid).unwrap(), - } - } -} - /// Server invite struct -#[derive(Serialize)] +#[derive(Clone, Serialize, Queryable, Selectable, Insertable)] pub struct Invite { /// case-sensitive alphanumeric string with a fixed length of 8 characters, can be up to 32 characters for custom invites id: String, @@ -742,20 +639,19 @@ pub struct Invite { } impl Invite { - pub async fn fetch_one(pool: &Pool, invite_id: String) -> Result { - let invite: Result = - sqlx::query_as("SELECT id, user_uuid, guild_uuid FROM invites WHERE id = $1") - .bind(invite_id) - .fetch_one(pool) - .await; + pub async fn fetch_one(conn: &mut Conn, invite_id: String) -> Result { + use invites::dsl; + let invite: Invite = dsl::invites + .filter(dsl::id.eq(invite_id)) + .select(Invite::as_select()) + .get_result(conn) + .await + .map_err(|error| { + error!("{}", error); + HttpResponse::InternalServerError().finish() + })?; - if let Err(error) = invite { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()); - } - - Ok(invite.unwrap().build()) + Ok(invite) } } -- 2.47.2 From fee46e143327476bf4c7f9bfffe4ebbbd6920ed6 Mon Sep 17 00:00:00 2001 From: Radical Date: Fri, 23 May 2025 12:52:41 +0200 Subject: [PATCH 10/17] feat: use thiserror for errors --- Cargo.toml | 1 + src/error.rs | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+) create mode 100644 src/error.rs diff --git a/Cargo.toml b/Cargo.toml index 9a91b30..8c03fec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ deadpool = "0.12" diesel = { version = "2.2", features = ["uuid"] } diesel-async = { version = "0.5", features = ["deadpool", "postgres", "async-connection-wrapper"] } diesel_migrations = { version = "2.2.0", features = ["postgres"] } +thiserror = "2.0.12" [dependencies.tokio] version = "1.44" diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..5d10251 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,79 @@ +use std::{io, time::SystemTimeError}; + +use actix_web::{error::ResponseError, http::{header::{ContentType, ToStrError}, StatusCode}, HttpResponse}; +use deadpool::managed::{BuildError, PoolError}; +use redis::RedisError; +use serde::Serialize; +use thiserror::Error; +use diesel::{result::Error as DieselError, ConnectionError}; +use diesel_async::pooled_connection::PoolError as DieselPoolError; +use tokio::task::JoinError; +use serde_json::Error as JsonError; +use toml::de::Error as TomlError; +use log::error; + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + SqlError(#[from] DieselError), + #[error(transparent)] + PoolError(#[from] PoolError), + #[error(transparent)] + BuildError(#[from] BuildError), + #[error(transparent)] + RedisError(#[from] RedisError), + #[error(transparent)] + ConnectionError(#[from] ConnectionError), + #[error(transparent)] + JoinError(#[from] JoinError), + #[error(transparent)] + IoError(#[from] io::Error), + #[error(transparent)] + TomlError(#[from] TomlError), + #[error(transparent)] + JsonError(#[from] JsonError), + #[error(transparent)] + SystemTimeError(#[from] SystemTimeError), + #[error(transparent)] + ToStrError(#[from] ToStrError), + #[error(transparent)] + RandomError(#[from] getrandom::Error), + #[error("{0}")] + PasswordHashError(String), + #[error("{0}")] + BadRequest(String), + #[error("{0}")] + Unauthorized(String), +} + +impl ResponseError for Error { + fn error_response(&self) -> HttpResponse { + error!("{}: {}", self.status_code(), self.to_string()); + + HttpResponse::build(self.status_code()) + .insert_header(ContentType::json()) + .json(WebError::new(self.to_string())) + } + + fn status_code(&self) -> StatusCode { + match *self { + Error::SqlError(DieselError::NotFound) => StatusCode::NOT_FOUND, + Error::BadRequest(_) => StatusCode::BAD_REQUEST, + Error::Unauthorized(_) => StatusCode::UNAUTHORIZED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} + +#[derive(Serialize)] +struct WebError { + message: String, +} + +impl WebError { + fn new(message: String) -> Self { + Self { + message, + } + } +} -- 2.47.2 From 3e698edf8cd8e498b60e24e74d44c96810a89207 Mon Sep 17 00:00:00 2001 From: Radical Date: Fri, 23 May 2025 12:54:10 +0200 Subject: [PATCH 11/17] feat: use new error type in main --- src/main.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main.rs b/src/main.rs index 10da4f1..7fb7087 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,6 +2,7 @@ use actix_cors::Cors; use actix_web::{App, HttpServer, web}; use argon2::Argon2; use clap::Parser; +use error::Error; use simple_logger::SimpleLogger; use diesel_async::pooled_connection::AsyncDieselConnectionManager; use diesel_async::pooled_connection::deadpool::Pool; @@ -18,8 +19,7 @@ mod api; pub mod structs; pub mod utils; pub mod schema; - -type Error = Box; +pub mod error; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] @@ -66,9 +66,9 @@ async fn main() -> Result<(), Error> { let mut conn = AsyncConnectionWrapper::::establish(&database_url)?; - conn.run_pending_migrations(MIGRATIONS); + conn.run_pending_migrations(MIGRATIONS)?; Ok::<_, Box>(()) - }).await?; + }).await?.unwrap(); /* **Stored for later possible use** -- 2.47.2 From 49db25e4548e81ae6777edaea9b664d45c87bf78 Mon Sep 17 00:00:00 2001 From: Radical Date: Fri, 23 May 2025 12:54:52 +0200 Subject: [PATCH 12/17] feat: use new error type in structs, utils and config --- src/config.rs | 2 +- src/structs.rs | 328 ++++++++++++++++++------------------------------- src/utils.rs | 35 +++--- 3 files changed, 144 insertions(+), 221 deletions(-) diff --git a/src/config.rs b/src/config.rs index 4e8fc21..3b537c9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,4 @@ -use crate::Error; +use crate::error::Error; use log::debug; use serde::Deserialize; use tokio::fs::read_to_string; diff --git a/src/structs.rs b/src/structs.rs index 547a852..3bc214b 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -1,11 +1,17 @@ -use actix_web::HttpResponse; use diesel::{delete, insert_into, prelude::{Insertable, Queryable}, ExpressionMethods, QueryDsl, Selectable, SelectableHelper}; -use log::error; use serde::{Deserialize, Serialize}; use uuid::Uuid; use diesel_async::{pooled_connection::AsyncDieselConnectionManager, RunQueryDsl}; -use crate::{Conn, Data, schema::*}; +use crate::{error::Error, Conn, Data, schema::*}; + +fn load_or_empty(query_result: Result, diesel::result::Error>) -> Result, diesel::result::Error> { + match query_result { + Ok(vec) => Ok(vec), + Err(diesel::result::Error::NotFound) => Ok(Vec::new()), + Err(e) => Err(e), + } +} #[derive(Queryable, Selectable, Insertable, Clone)] #[diesel(table_name = channels)] @@ -18,13 +24,15 @@ struct ChannelBuilder { } impl ChannelBuilder { - async fn build(self, conn: &mut Conn) -> Result { + async fn build(self, conn: &mut Conn) -> Result { use self::channel_permissions::dsl::*; - let channel_permission: Vec = channel_permissions - .filter(channel_uuid.eq(self.uuid)) - .select((role_uuid, permissions)) - .load(conn) - .await?; + let channel_permission: Vec = load_or_empty( + channel_permissions + .filter(channel_uuid.eq(self.uuid)) + .select(ChannelPermission::as_select()) + .load(conn) + .await + )?; Ok(Channel { uuid: self.uuid, @@ -45,7 +53,7 @@ pub struct Channel { pub permissions: Vec, } -#[derive(Serialize, Deserialize, Clone, Queryable)] +#[derive(Serialize, Deserialize, Clone, Queryable, Selectable)] #[diesel(table_name = channel_permissions)] #[diesel(check_for_backend(diesel::pg::Pg))] pub struct ChannelPermission { @@ -57,69 +65,38 @@ impl Channel { pub async fn fetch_all( pool: &deadpool::managed::Pool, Conn>, guild_uuid: Uuid, - ) -> Result, HttpResponse> { - let mut conn = pool.get().await.unwrap(); + ) -> Result, Error> { + let mut conn = pool.get().await?; use channels::dsl; - let channel_builders_result: Result, diesel::result::Error> = dsl::channels - .filter(dsl::guild_uuid.eq(guild_uuid)) - .select(ChannelBuilder::as_select()) - .load(&mut conn) - .await; - - if let Err(error) = channel_builders_result { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()); - } - - let channel_builders = channel_builders_result.unwrap(); + let channel_builders: Vec = load_or_empty( + dsl::channels + .filter(dsl::guild_uuid.eq(guild_uuid)) + .select(ChannelBuilder::as_select()) + .load(&mut conn) + .await + )?; let channel_futures = channel_builders.iter().map(async move |c| { let mut conn = pool.get().await?; c.clone().build(&mut conn).await }); - - let channels = futures::future::try_join_all(channel_futures).await; - - if let Err(error) = channels { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()) - } - - Ok(channels.unwrap()) + futures::future::try_join_all(channel_futures).await } pub async fn fetch_one( conn: &mut Conn, channel_uuid: Uuid, - ) -> Result { + ) -> Result { use channels::dsl; - let channel_builder_result: Result = dsl::channels + let channel_builder: ChannelBuilder = dsl::channels .filter(dsl::uuid.eq(channel_uuid)) .select(ChannelBuilder::as_select()) .get_result(conn) - .await; + .await?; - if let Err(error) = channel_builder_result { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()) - } - - let channel_builder = channel_builder_result.unwrap(); - - let channel = channel_builder.build(conn).await; - - if let Err(error) = channel { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()) - } - - Ok(channel.unwrap()) + channel_builder.build(conn).await } pub async fn new( @@ -127,8 +104,8 @@ impl Channel { guild_uuid: Uuid, name: String, description: Option, - ) -> Result { - let mut conn = data.pool.get().await.unwrap(); + ) -> Result { + let mut conn = data.pool.get().await?; let channel_uuid = Uuid::now_v7(); @@ -139,15 +116,10 @@ impl Channel { description: description.clone(), }; - let insert_result = insert_into(channels::table) + insert_into(channels::table) .values(new_channel) .execute(&mut conn) - .await; - - if let Err(error) = insert_result { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } + .await?; // returns different object because there's no reason to build the channelbuilder (wastes 1 database request) let channel = Self { @@ -158,37 +130,21 @@ impl Channel { permissions: vec![], }; - let cache_result = data + data .set_cache_key(channel_uuid.to_string(), channel.clone(), 1800) - .await; + .await?; - if let Err(error) = cache_result { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } - - let cache_deletion_result = data.del_cache_key(format!("{}_channels", guild_uuid)).await; - - if let Err(error) = cache_deletion_result { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } + data.del_cache_key(format!("{}_channels", guild_uuid)).await?; Ok(channel) } - pub async fn delete(self, conn: &mut Conn) -> Result<(), HttpResponse> { + pub async fn delete(self, conn: &mut Conn) -> Result<(), Error> { use channels::dsl; - let result = delete(channels::table) + delete(channels::table) .filter(dsl::uuid.eq(self.uuid)) .execute(conn) - .await; - - if let Err(error) = result { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()); - } + .await?; Ok(()) } @@ -198,22 +154,19 @@ impl Channel { conn: &mut Conn, amount: i64, offset: i64, - ) -> Result, HttpResponse> { + ) -> Result, Error> { use messages::dsl; - let messages: Result, diesel::result::Error> = dsl::messages - .filter(dsl::channel_uuid.eq(self.uuid)) - .select(Message::as_select()) - .limit(amount) - .offset(offset) - .load(conn) - .await; + let messages: Vec = load_or_empty( + dsl::messages + .filter(dsl::channel_uuid.eq(self.uuid)) + .select(Message::as_select()) + .limit(amount) + .offset(offset) + .load(conn) + .await + )?; - if let Err(error) = messages { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } - - Ok(messages.unwrap()) + Ok(messages) } pub async fn new_message( @@ -221,7 +174,7 @@ impl Channel { conn: &mut Conn, user_uuid: Uuid, message: String, - ) -> Result { + ) -> Result { let message_uuid = Uuid::now_v7(); let message = Message { @@ -231,15 +184,10 @@ impl Channel { message, }; - let insert_result = insert_into(messages::table) + insert_into(messages::table) .values(message.clone()) .execute(conn) - .await; - - if let Err(error) = insert_result { - error!("{}", error); - return Err(HttpResponse::InternalServerError().finish()); - } + .await?; Ok(message) } @@ -294,7 +242,7 @@ struct GuildBuilder { } impl GuildBuilder { - async fn build(self, conn: &mut Conn) -> Result { + async fn build(self, conn: &mut Conn) -> Result { let member_count = Member::count(conn, self.uuid).await?; let roles = Role::fetch_all(conn, self.uuid).await?; @@ -323,49 +271,39 @@ pub struct Guild { } impl Guild { - pub async fn fetch_one(conn: &mut Conn, guild_uuid: Uuid) -> Result { + pub async fn fetch_one(conn: &mut Conn, guild_uuid: Uuid) -> Result { use guilds::dsl; - let guild_builder: Result = dsl::guilds + let guild_builder: GuildBuilder = dsl::guilds .filter(dsl::uuid.eq(guild_uuid)) .select(GuildBuilder::as_select()) .get_result(conn) - .await; + .await?; - if let Err(error) = guild_builder { - error!("{}", error); - - return Err(HttpResponse::InternalServerError().finish()); - } - - let guild = guild_builder.unwrap().build(conn).await?; - - Ok(guild) + guild_builder.build(conn).await } pub async fn fetch_amount( pool: &deadpool::managed::Pool, Conn>, offset: i64, amount: i64, - ) -> Result, HttpResponse> { + ) -> Result, Error> { // Fetch guild data from database - let mut conn = pool.get().await.unwrap(); + let mut conn = pool.get().await?; use guilds::dsl; - let guild_builders: Vec = dsl::guilds - .select(GuildBuilder::as_select()) - .order_by(dsl::uuid) - .offset(offset) - .limit(amount) - .load(&mut conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + let guild_builders: Vec = load_or_empty( + dsl::guilds + .select(GuildBuilder::as_select()) + .order_by(dsl::uuid) + .offset(offset) + .limit(amount) + .load(&mut conn) + .await + )?; // Process each guild concurrently let guild_futures = guild_builders.iter().map(async move |g| { - let mut conn = pool.get().await.unwrap(); + let mut conn = pool.get().await?; g.clone().build(&mut conn).await }); @@ -378,7 +316,7 @@ impl Guild { name: String, description: Option, owner_uuid: Uuid, - ) -> Result { + ) -> Result { let guild_uuid = Uuid::now_v7(); let guild_builder = GuildBuilder { @@ -391,11 +329,21 @@ impl Guild { insert_into(guilds::table) .values(guild_builder) .execute(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + .await?; + + let member_uuid = Uuid::now_v7(); + + let member = Member { + uuid: member_uuid, + nickname: None, + user_uuid: owner_uuid, + guild_uuid, + }; + + insert_into(guild_members::table) + .values(member) + .execute(conn) + .await?; Ok(Guild { uuid: guild_uuid, @@ -408,17 +356,15 @@ impl Guild { }) } - pub async fn get_invites(&self, conn: &mut Conn) -> Result, HttpResponse> { + pub async fn get_invites(&self, conn: &mut Conn) -> Result, Error> { use invites::dsl; - let invites = dsl::invites - .filter(dsl::guild_uuid.eq(self.uuid)) - .select(Invite::as_select()) - .load(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + let invites = load_or_empty( + dsl::invites + .filter(dsl::guild_uuid.eq(self.uuid)) + .select(Invite::as_select()) + .load(conn) + .await + )?; Ok(invites) } @@ -428,13 +374,13 @@ impl Guild { conn: &mut Conn, member: &Member, custom_id: Option, - ) -> Result { + ) -> Result { let invite_id; if let Some(id) = custom_id { invite_id = id; if invite_id.len() > 32 { - return Err(HttpResponse::BadRequest().finish()); + return Err(Error::BadRequest("MAX LENGTH".to_string())) } } else { let charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; @@ -451,11 +397,7 @@ impl Guild { insert_into(invites::table) .values(invite.clone()) .execute(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + .await?; Ok(invite) } @@ -477,17 +419,15 @@ impl Role { pub async fn fetch_all( conn: &mut Conn, guild_uuid: Uuid, - ) -> Result, HttpResponse> { + ) -> Result, Error> { use roles::dsl; - let roles: Vec = dsl::roles - .filter(dsl::guild_uuid.eq(guild_uuid)) - .select(Role::as_select()) - .load(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + let roles: Vec = load_or_empty( + dsl::roles + .filter(dsl::guild_uuid.eq(guild_uuid)) + .select(Role::as_select()) + .load(conn) + .await + )?; Ok(roles) } @@ -495,17 +435,13 @@ impl Role { pub async fn fetch_one( conn: &mut Conn, role_uuid: Uuid, - ) -> Result { + ) -> Result { use roles::dsl; let role: Role = dsl::roles .filter(dsl::uuid.eq(role_uuid)) .select(Role::as_select()) .get_result(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + .await?; Ok(role) } @@ -514,7 +450,7 @@ impl Role { conn: &mut Conn, guild_uuid: Uuid, name: String, - ) -> Result { + ) -> Result { let role_uuid = Uuid::now_v7(); let role = Role { @@ -529,11 +465,7 @@ impl Role { insert_into(roles::table) .values(role.clone()) .execute(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + .await?; Ok(role) } @@ -550,17 +482,13 @@ pub struct Member { } impl Member { - async fn count(conn: &mut Conn, guild_uuid: Uuid) -> Result { + async fn count(conn: &mut Conn, guild_uuid: Uuid) -> Result { use guild_members::dsl; let count: i64 = dsl::guild_members .filter(dsl::guild_uuid.eq(guild_uuid)) .count() .get_result(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError() - })?; + .await?; Ok(count) } @@ -569,18 +497,14 @@ impl Member { conn: &mut Conn, user_uuid: Uuid, guild_uuid: Uuid, - ) -> Result { + ) -> Result { use guild_members::dsl; let member: Member = dsl::guild_members .filter(dsl::user_uuid.eq(user_uuid)) .filter(dsl::guild_uuid.eq(guild_uuid)) .select(Member::as_select()) .get_result(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + .await?; Ok(member) } @@ -589,7 +513,7 @@ impl Member { conn: &mut Conn, user_uuid: Uuid, guild_uuid: Uuid, - ) -> Result { + ) -> Result { let member_uuid = Uuid::now_v7(); let member = Member { @@ -602,11 +526,7 @@ impl Member { insert_into(guild_members::table) .values(member) .execute(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + .await?; Ok(Self { uuid: member_uuid, @@ -639,17 +559,13 @@ pub struct Invite { } impl Invite { - pub async fn fetch_one(conn: &mut Conn, invite_id: String) -> Result { + pub async fn fetch_one(conn: &mut Conn, invite_id: String) -> Result { use invites::dsl; let invite: Invite = dsl::invites .filter(dsl::id.eq(invite_id)) .select(Invite::as_select()) .get_result(conn) - .await - .map_err(|error| { - error!("{}", error); - HttpResponse::InternalServerError().finish() - })?; + .await?; Ok(invite) } @@ -657,6 +573,6 @@ impl Invite { #[derive(Deserialize)] pub struct StartAmountQuery { - pub start: Option, - pub amount: Option, + pub start: Option, + pub amount: Option, } diff --git a/src/utils.rs b/src/utils.rs index 77c5e0a..b7ddcc1 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,4 @@ use actix_web::{ - HttpResponse, cookie::{Cookie, SameSite, time::Duration}, http::header::HeaderMap, }; @@ -8,25 +7,31 @@ use hex::encode; use redis::RedisError; use serde::Serialize; -use crate::Data; +use crate::{error::Error, Data}; -pub fn get_auth_header(headers: &HeaderMap) -> Result<&str, HttpResponse> { +pub fn get_auth_header(headers: &HeaderMap) -> Result<&str, Error> { let auth_token = headers.get(actix_web::http::header::AUTHORIZATION); if auth_token.is_none() { - return Err(HttpResponse::Unauthorized().finish()); + return Err(Error::Unauthorized("No authorization header provided".to_string())); } - let auth = auth_token.unwrap().to_str(); + let auth_raw = auth_token.unwrap().to_str()?; - if let Err(error) = auth { - return Err(HttpResponse::Unauthorized().json(format!(r#" {{ "error": "{}" }} "#, error))); + let mut auth = auth_raw.split_whitespace(); + + let auth_type = auth.nth(0); + + let auth_value = auth.nth(0); + + if auth_type.is_none() { + return Err(Error::BadRequest("Authorization header is empty".to_string())); + } else if auth_type.is_some_and(|at| at != "Bearer") { + return Err(Error::BadRequest("Only token auth is supported".to_string())); } - - let auth_value = auth.unwrap().split_whitespace().nth(1); - + if auth_value.is_none() { - return Err(HttpResponse::BadRequest().finish()); + return Err(Error::BadRequest("No token provided".to_string())); } Ok(auth_value.unwrap()) @@ -60,12 +65,12 @@ impl Data { key: String, value: impl Serialize, expire: u32, - ) -> Result<(), RedisError> { + ) -> Result<(), Error> { let mut conn = self.cache_pool.get_multiplexed_tokio_connection().await?; let key_encoded = encode(key); - let value_json = serde_json::to_string(&value).unwrap(); + let value_json = serde_json::to_string(&value)?; redis::cmd("SET") .arg(&[key_encoded.clone(), value_json]) @@ -75,7 +80,9 @@ impl Data { redis::cmd("EXPIRE") .arg(&[key_encoded, expire.to_string()]) .exec_async(&mut conn) - .await + .await?; + + Ok(()) } pub async fn get_cache_key(&self, key: String) -> Result { -- 2.47.2 From bf51f623e47cb2877630ab63838b1e0847e0a803 Mon Sep 17 00:00:00 2001 From: Radical Date: Fri, 23 May 2025 12:55:27 +0200 Subject: [PATCH 13/17] feat: migrate to diesel and new error type in auth --- src/api/v1/auth/login.rs | 152 +++++++++++------------------------- src/api/v1/auth/mod.rs | 53 ++++++------- src/api/v1/auth/refresh.rs | 75 ++++++++---------- src/api/v1/auth/register.rs | 123 ++++++++++------------------- src/api/v1/auth/revoke.rs | 105 +++++-------------------- 5 files changed, 162 insertions(+), 346 deletions(-) diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 38d5449..8ad345e 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -1,14 +1,14 @@ use std::time::{SystemTime, UNIX_EPOCH}; -use actix_web::{Error, HttpResponse, post, web}; +use actix_web::{HttpResponse, post, web}; use argon2::{PasswordHash, PasswordVerifier}; -use log::error; +use diesel::{dsl::insert_into, ExpressionMethods, QueryDsl}; +use diesel_async::RunQueryDsl; use serde::Deserialize; +use uuid::Uuid; use crate::{ - Data, - api::v1::auth::{EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX}, - utils::{generate_access_token, generate_refresh_token, refresh_token_cookie}, + error::Error, api::v1::auth::{EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX}, schema::*, utils::{generate_access_token, generate_refresh_token, refresh_token_cookie}, Data }; use super::Response; @@ -29,66 +29,42 @@ pub async fn response( return Ok(HttpResponse::Forbidden().json(r#"{ "password_hashed": false }"#)); } + use users::dsl; + + let mut conn = data.pool.get().await?; + if EMAIL_REGEX.is_match(&login_information.username) { - let row = - sqlx::query_as("SELECT CAST(uuid as VARCHAR), password FROM users WHERE email = $1") - .bind(&login_information.username) - .fetch_one(&data.pool) - .await; + // FIXME: error handling, right now i just want this to work + let (uuid, password): (Uuid, String) = dsl::users + .filter(dsl::email.eq(&login_information.username)) + .select((dsl::uuid, dsl::password)) + .get_result(&mut conn) + .await?; - if let Err(error) = row { - if error.to_string() - == "no rows returned by a query that expected to return at least one row" - { - return Ok(HttpResponse::Unauthorized().finish()); - } - - error!("{}", error); - return Ok(HttpResponse::InternalServerError().json( - r#"{ "error": "Unhandled exception occured, contact the server administrator" }"#, - )); - } - - let (uuid, password): (String, String) = row.unwrap(); - - return Ok(login( + return login( data.clone(), uuid, login_information.password.clone(), password, login_information.device_name.clone(), ) - .await); + .await; } else if USERNAME_REGEX.is_match(&login_information.username) { - let row = - sqlx::query_as("SELECT CAST(uuid as VARCHAR), password FROM users WHERE username = $1") - .bind(&login_information.username) - .fetch_one(&data.pool) - .await; + // FIXME: error handling, right now i just want this to work + let (uuid, password): (Uuid, String) = dsl::users + .filter(dsl::username.eq(&login_information.username)) + .select((dsl::uuid, dsl::password)) + .get_result(&mut conn) + .await?; - if let Err(error) = row { - if error.to_string() - == "no rows returned by a query that expected to return at least one row" - { - return Ok(HttpResponse::Unauthorized().finish()); - } - - error!("{}", error); - return Ok(HttpResponse::InternalServerError().json( - r#"{ "error": "Unhandled exception occured, contact the server administrator" }"#, - )); - } - - let (uuid, password): (String, String) = row.unwrap(); - - return Ok(login( + return login( data.clone(), uuid, login_information.password.clone(), password, login_information.device_name.clone(), ) - .await); + .await; } Ok(HttpResponse::Unauthorized().finish()) @@ -96,79 +72,45 @@ pub async fn response( async fn login( data: actix_web::web::Data, - uuid: String, + uuid: Uuid, request_password: String, database_password: String, device_name: String, -) -> HttpResponse { - let parsed_hash_raw = PasswordHash::new(&database_password); +) -> Result { + let mut conn = data.pool.get().await?; - if let Err(error) = parsed_hash_raw { - error!("{}", error); - return HttpResponse::InternalServerError().finish(); - } - - let parsed_hash = parsed_hash_raw.unwrap(); + let parsed_hash = PasswordHash::new(&database_password).map_err(|e| Error::PasswordHashError(e.to_string()))?; if data .argon2 .verify_password(request_password.as_bytes(), &parsed_hash) .is_err() { - return HttpResponse::Unauthorized().finish(); + return Err(Error::Unauthorized("Wrong username or password".to_string())); } - let refresh_token_raw = generate_refresh_token(); - let access_token_raw = generate_access_token(); - - if let Err(error) = refresh_token_raw { - error!("{}", error); - return HttpResponse::InternalServerError().finish(); - } - - let refresh_token = refresh_token_raw.unwrap(); - - if let Err(error) = access_token_raw { - error!("{}", error); - return HttpResponse::InternalServerError().finish(); - } - - let access_token = access_token_raw.unwrap(); + let refresh_token = generate_refresh_token()?; + let access_token = generate_access_token()?; let current_time = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() + .duration_since(UNIX_EPOCH)? .as_secs() as i64; - if let Err(error) = sqlx::query(&format!( - "INSERT INTO refresh_tokens (token, uuid, created_at, device_name) VALUES ($1, '{}', $2, $3 )", - uuid - )) - .bind(&refresh_token) - .bind(current_time) - .bind(device_name) - .execute(&data.pool) - .await - { - error!("{}", error); - return HttpResponse::InternalServerError().finish(); - } + use refresh_tokens::dsl as rdsl; - if let Err(error) = sqlx::query(&format!( - "INSERT INTO access_tokens (token, refresh_token, uuid, created_at) VALUES ($1, $2, '{}', $3 )", - uuid - )) - .bind(&access_token) - .bind(&refresh_token) - .bind(current_time) - .execute(&data.pool) - .await - { - error!("{}", error); - return HttpResponse::InternalServerError().finish() - } + insert_into(refresh_tokens::table) + .values((rdsl::token.eq(&refresh_token), rdsl::uuid.eq(uuid), rdsl::created_at.eq(current_time), rdsl::device_name.eq(device_name))) + .execute(&mut conn) + .await?; - HttpResponse::Ok() + use access_tokens::dsl as adsl; + + insert_into(access_tokens::table) + .values((adsl::token.eq(&access_token), adsl::refresh_token.eq(&refresh_token), adsl::uuid.eq(uuid), adsl::created_at.eq(current_time))) + .execute(&mut conn) + .await?; + + Ok(HttpResponse::Ok() .cookie(refresh_token_cookie(refresh_token)) - .json(Response { access_token }) + .json(Response { access_token })) } diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index 326b2ef..249ec4b 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -1,16 +1,17 @@ use std::{ - str::FromStr, sync::LazyLock, time::{SystemTime, UNIX_EPOCH}, }; -use actix_web::{HttpResponse, Scope, web}; -use log::error; +use actix_web::{Scope, web}; +use diesel::{ExpressionMethods, QueryDsl}; +use diesel_async::RunQueryDsl; use regex::Regex; use serde::Serialize; -use sqlx::Postgres; use uuid::Uuid; +use crate::{error::Error, Conn, schema::access_tokens::dsl}; + mod login; mod refresh; mod register; @@ -40,40 +41,30 @@ pub fn web() -> Scope { pub async fn check_access_token( access_token: &str, - pool: &sqlx::Pool, -) -> Result { - let row = sqlx::query_as( - "SELECT CAST(uuid as VARCHAR), created_at FROM access_tokens WHERE token = $1", - ) - .bind(access_token) - .fetch_one(pool) - .await; - - if let Err(error) = row { - if error.to_string() - == "no rows returned by a query that expected to return at least one row" - { - return Err(HttpResponse::Unauthorized().finish()); - } - - error!("{}", error); - return Err(HttpResponse::InternalServerError().json( - r#"{ "error": "Unhandled exception occured, contact the server administrator" }"#, - )); - } - - let (uuid, created_at): (String, i64) = row.unwrap(); + conn: &mut Conn, +) -> Result { + let (uuid, created_at): (Uuid, i64) = dsl::access_tokens + .filter(dsl::token.eq(access_token)) + .select((dsl::uuid, dsl::created_at)) + .get_result(conn) + .await + .map_err(|error| { + if error == diesel::result::Error::NotFound { + Error::Unauthorized("Invalid access token".to_string()) + } else { + Error::from(error) + } + })?; let current_time = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() + .duration_since(UNIX_EPOCH)? .as_secs() as i64; let lifetime = current_time - created_at; if lifetime > 3600 { - return Err(HttpResponse::Unauthorized().finish()); + return Err(Error::Unauthorized("Invalid access token".to_string())); } - Ok(Uuid::from_str(&uuid).unwrap()) + Ok(uuid) } diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index cf1c4bb..468945d 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -1,10 +1,11 @@ -use actix_web::{Error, HttpRequest, HttpResponse, post, web}; +use actix_web::{HttpRequest, HttpResponse, post, web}; +use diesel::{delete, update, ExpressionMethods, QueryDsl}; +use diesel_async::RunQueryDsl; use log::error; use std::time::{SystemTime, UNIX_EPOCH}; use crate::{ - Data, - utils::{generate_access_token, generate_refresh_token, refresh_token_cookie}, + error::Error, schema::{access_tokens::{self, dsl}, refresh_tokens::{self, dsl as rdsl}}, utils::{generate_access_token, generate_refresh_token, refresh_token_cookie}, Data }; use super::Response; @@ -20,23 +21,23 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result(&mut conn) .await { - let created_at: i64 = row; - let lifetime = current_time - created_at; if lifetime > 2592000 { - if let Err(error) = sqlx::query("DELETE FROM refresh_tokens WHERE token = $1") - .bind(&refresh_token) - .execute(&data.pool) + if let Err(error) = delete(refresh_tokens::table) + .filter(rdsl::token.eq(&refresh_token)) + .execute(&mut conn) .await { error!("{}", error); @@ -52,8 +53,7 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result 1987200 { @@ -66,14 +66,14 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result { refresh_token = new_refresh_token; @@ -84,27 +84,16 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result { - let refresh_token = generate_refresh_token(); - let access_token = generate_access_token(); + .execute(&mut conn) + .await?; - if refresh_token.is_err() { - error!("{}", refresh_token.unwrap_err()); - return Ok(HttpResponse::InternalServerError().finish()); - } + let refresh_token = generate_refresh_token()?; + let access_token = generate_access_token()?; - let refresh_token = refresh_token.unwrap(); + let current_time = SystemTime::now() + .duration_since(UNIX_EPOCH)? + .as_secs() as i64; - if access_token.is_err() { - error!("{}", access_token.unwrap_err()); - return Ok(HttpResponse::InternalServerError().finish()); - } + insert_into(refresh_tokens::table) + .values(( + rdsl::token.eq(&refresh_token), + rdsl::uuid.eq(uuid), + rdsl::created_at.eq(current_time), + rdsl::device_name.eq(&account_information.device_name), + )) + .execute(&mut conn) + .await?; - let access_token = access_token.unwrap(); + insert_into(access_tokens::table) + .values(( + adsl::token.eq(&access_token), + adsl::refresh_token.eq(&refresh_token), + adsl::uuid.eq(uuid), + adsl::created_at.eq(current_time), + )) + .execute(&mut conn) + .await?; - let current_time = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as i64; - - if let Err(error) = sqlx::query(&format!("INSERT INTO refresh_tokens (token, uuid, created_at, device_name) VALUES ($1, '{}', $2, $3 )", uuid)) - .bind(&refresh_token) - .bind(current_time) - .bind(&account_information.device_name) - .execute(&data.pool) - .await { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()) - } - - if let Err(error) = sqlx::query(&format!("INSERT INTO access_tokens (token, refresh_token, uuid, created_at) VALUES ($1, $2, '{}', $3 )", uuid)) - .bind(&access_token) - .bind(&refresh_token) - .bind(current_time) - .execute(&data.pool) - .await { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()) - } - - HttpResponse::Ok() - .cookie(refresh_token_cookie(refresh_token)) - .json(Response { access_token }) - } - Err(error) => { - let err_msg = error.as_database_error().unwrap().message(); - - match err_msg { - err_msg - if err_msg.contains("unique") && err_msg.contains("username_key") => - { - HttpResponse::Forbidden().json(ResponseError { - gorb_id_available: false, - ..Default::default() - }) - } - err_msg if err_msg.contains("unique") && err_msg.contains("email_key") => { - HttpResponse::Forbidden().json(ResponseError { - email_available: false, - ..Default::default() - }) - } - _ => { - error!("{}", err_msg); - HttpResponse::InternalServerError().finish() - } - } - } - }, - ); + return Ok(HttpResponse::Ok() + .cookie(refresh_token_cookie(refresh_token)) + .json(Response { access_token })) } Ok(HttpResponse::InternalServerError().finish()) diff --git a/src/api/v1/auth/revoke.rs b/src/api/v1/auth/revoke.rs index a4f9196..116ed5c 100644 --- a/src/api/v1/auth/revoke.rs +++ b/src/api/v1/auth/revoke.rs @@ -1,10 +1,10 @@ -use actix_web::{Error, HttpRequest, HttpResponse, post, web}; +use actix_web::{HttpRequest, HttpResponse, post, web}; use argon2::{PasswordHash, PasswordVerifier}; -use futures::future; -use log::error; -use serde::{Deserialize, Serialize}; +use diesel::{delete, ExpressionMethods, QueryDsl}; +use diesel_async::RunQueryDsl; +use serde::Deserialize; -use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; +use crate::{api::v1::auth::check_access_token, error::Error, schema::users::dsl as udsl, schema::refresh_tokens::{self, dsl as rdsl}, utils::get_auth_header, Data}; #[derive(Deserialize)] struct RevokeRequest { @@ -12,17 +12,6 @@ struct RevokeRequest { device_name: String, } -#[derive(Serialize)] -struct Response { - deleted: bool, -} - -impl Response { - fn new(deleted: bool) -> Self { - Self { deleted } - } -} - // TODO: Should maybe be a delete request? #[post("/revoke")] pub async fn res( @@ -32,85 +21,33 @@ pub async fn res( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); + let auth_header = get_auth_header(headers)?; - if let Err(error) = auth_header { - return Ok(error); - } + let mut conn = data.pool.get().await?; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let uuid = check_access_token(auth_header, &mut conn).await?; - if let Err(error) = authorized { - return Ok(error); - } + let database_password: String = udsl::users + .filter(udsl::uuid.eq(uuid)) + .select(udsl::password) + .get_result(&mut conn) + .await?; - let uuid = authorized.unwrap(); - - let database_password_raw = sqlx::query_scalar(&format!( - "SELECT password FROM users WHERE uuid = '{}'", - uuid - )) - .fetch_one(&data.pool) - .await; - - if let Err(error) = database_password_raw { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().json(Response::new(false))); - } - - let database_password: String = database_password_raw.unwrap(); - - let hashed_password_raw = PasswordHash::new(&database_password); - - if let Err(error) = hashed_password_raw { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().json(Response::new(false))); - } - - let hashed_password = hashed_password_raw.unwrap(); + let hashed_password = PasswordHash::new(&database_password).map_err(|e| Error::PasswordHashError(e.to_string()))?; if data .argon2 .verify_password(revoke_request.password.as_bytes(), &hashed_password) .is_err() { - return Ok(HttpResponse::Unauthorized().finish()); + return Err(Error::Unauthorized("Wrong username or password".to_string())); } - let tokens_raw = sqlx::query_scalar(&format!( - "SELECT token FROM refresh_tokens WHERE uuid = '{}' AND device_name = $1", - uuid - )) - .bind(&revoke_request.device_name) - .fetch_all(&data.pool) - .await; + delete(refresh_tokens::table) + .filter(rdsl::uuid.eq(uuid)) + .filter(rdsl::device_name.eq(&revoke_request.device_name)) + .execute(&mut conn) + .await?; - if tokens_raw.is_err() { - error!("{:?}", tokens_raw); - return Ok(HttpResponse::InternalServerError().json(Response::new(false))); - } - - let tokens: Vec = tokens_raw.unwrap(); - - let mut refresh_tokens_delete = vec![]; - - for token in tokens { - refresh_tokens_delete.push( - sqlx::query("DELETE FROM refresh_tokens WHERE token = $1") - .bind(token.clone()) - .execute(&data.pool), - ); - } - - let results = future::join_all(refresh_tokens_delete).await; - - let errors: Vec<&Result> = - results.iter().filter(|r| r.is_err()).collect(); - - if !errors.is_empty() { - error!("{:?}", errors); - return Ok(HttpResponse::InternalServerError().finish()); - } - - Ok(HttpResponse::Ok().json(Response::new(true))) + Ok(HttpResponse::Ok().finish()) } -- 2.47.2 From 6190d762854138a915918e9376423243e88e33fe Mon Sep 17 00:00:00 2001 From: Radical Date: Fri, 23 May 2025 12:56:19 +0200 Subject: [PATCH 14/17] feat: migrate to diesel and new error type in servers --- src/api/v1/servers/mod.rs | 48 +++------- src/api/v1/servers/uuid/channels/mod.rs | 69 +++----------- .../v1/servers/uuid/channels/uuid/messages.rs | 57 +++--------- src/api/v1/servers/uuid/channels/uuid/mod.rs | 91 ++++--------------- .../v1/servers/uuid/channels/uuid/socket.rs | 56 +++--------- src/api/v1/servers/uuid/invites/mod.rs | 77 ++++------------ src/api/v1/servers/uuid/mod.rs | 31 ++----- src/api/v1/servers/uuid/roles/mod.rs | 78 ++++------------ src/api/v1/servers/uuid/roles/uuid.rs | 45 ++------- 9 files changed, 123 insertions(+), 429 deletions(-) diff --git a/src/api/v1/servers/mod.rs b/src/api/v1/servers/mod.rs index 7c74ff0..8e2e186 100644 --- a/src/api/v1/servers/mod.rs +++ b/src/api/v1/servers/mod.rs @@ -1,9 +1,9 @@ -use actix_web::{get, post, web, Error, HttpRequest, HttpResponse, Scope}; +use actix_web::{get, post, web, HttpRequest, HttpResponse, Scope}; use serde::Deserialize; mod uuid; -use crate::{api::v1::auth::check_access_token, structs::{Guild, StartAmountQuery}, utils::get_auth_header, Data}; +use crate::{error::Error, api::v1::auth::check_access_token, structs::{Guild, StartAmountQuery}, utils::get_auth_header, Data}; #[derive(Deserialize)] struct GuildInfo { @@ -26,33 +26,21 @@ pub async fn create( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); + let auth_header = get_auth_header(headers)?; - if let Err(error) = auth_header { - return Ok(error); - } + let mut conn = data.pool.get().await?; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; - - if let Err(error) = authorized { - return Ok(error); - } - - let uuid = authorized.unwrap(); + let uuid = check_access_token(auth_header, &mut conn).await?; let guild = Guild::new( - &data.pool, + &mut conn, guild_info.name.clone(), guild_info.description.clone(), uuid, ) - .await; + .await?; - if let Err(error) = guild { - return Ok(error); - } - - Ok(HttpResponse::Ok().json(guild.unwrap())) + Ok(HttpResponse::Ok().json(guild)) } #[get("")] @@ -63,28 +51,16 @@ pub async fn get( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); + let auth_header = get_auth_header(headers)?; let start = request_query.start.unwrap_or(0); let amount = request_query.amount.unwrap_or(10); - if let Err(error) = auth_header { - return Ok(error); - } + check_access_token(auth_header, &mut data.pool.get().await.unwrap()).await?; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let guilds = Guild::fetch_amount(&data.pool, start, amount).await?; - if let Err(error) = authorized { - return Ok(error); - } - - let guilds = Guild::fetch_amount(&data.pool, start, amount).await; - - if let Err(error) = guilds { - return Ok(error); - } - - Ok(HttpResponse::Ok().json(guilds.unwrap())) + Ok(HttpResponse::Ok().json(guilds)) } diff --git a/src/api/v1/servers/uuid/channels/mod.rs b/src/api/v1/servers/uuid/channels/mod.rs index 3e6a342..4348422 100644 --- a/src/api/v1/servers/uuid/channels/mod.rs +++ b/src/api/v1/servers/uuid/channels/mod.rs @@ -1,12 +1,12 @@ use crate::{ + error::Error, Data, api::v1::auth::check_access_token, structs::{Channel, Member}, utils::get_auth_header, }; use ::uuid::Uuid; -use actix_web::{Error, HttpRequest, HttpResponse, get, post, web}; -use log::error; +use actix_web::{HttpRequest, HttpResponse, get, post, web}; use serde::Deserialize; pub mod uuid; @@ -25,52 +25,27 @@ pub async fn get( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let guild_uuid = path.into_inner().0; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member { - return Ok(error); - } - - let cache_result = data.get_cache_key(format!("{}_channels", guild_uuid)).await; - - if let Ok(cache_hit) = cache_result { + if let Ok(cache_hit) = data.get_cache_key(format!("{}_channels", guild_uuid)).await { return Ok(HttpResponse::Ok() .content_type("application/json") .body(cache_hit)); } - let channels_result = Channel::fetch_all(&data.pool, guild_uuid).await; + let channels = Channel::fetch_all(&data.pool, guild_uuid).await?; - if let Err(error) = channels_result { - return Ok(error); - } - - let channels = channels_result.unwrap(); - - let cache_result = data + data .set_cache_key(format!("{}_channels", guild_uuid), channels.clone(), 1800) - .await; - - if let Err(error) = cache_result { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); - } + .await?; Ok(HttpResponse::Ok().json(channels)) } @@ -84,27 +59,15 @@ pub async fn create( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let guild_uuid = path.into_inner().0; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); - - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member { - return Ok(error); - } + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; // FIXME: Logic to check permissions, should probably be done in utils.rs @@ -116,9 +79,5 @@ pub async fn create( ) .await; - if let Err(error) = channel { - return Ok(error); - } - Ok(HttpResponse::Ok().json(channel.unwrap())) } diff --git a/src/api/v1/servers/uuid/channels/uuid/messages.rs b/src/api/v1/servers/uuid/channels/uuid/messages.rs index ff36a4f..954651e 100644 --- a/src/api/v1/servers/uuid/channels/uuid/messages.rs +++ b/src/api/v1/servers/uuid/channels/uuid/messages.rs @@ -1,12 +1,12 @@ use crate::{ + error::Error, Data, api::v1::auth::check_access_token, structs::{Channel, Member}, utils::get_auth_header, }; use ::uuid::Uuid; -use actix_web::{Error, HttpRequest, HttpResponse, get, web}; -use log::error; +use actix_web::{HttpRequest, HttpResponse, get, web}; use serde::Deserialize; #[derive(Deserialize)] @@ -24,60 +24,31 @@ pub async fn get( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let (guild_uuid, channel_uuid) = path.into_inner(); - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); - - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member { - return Ok(error); - } - - let cache_result = data.get_cache_key(format!("{}", channel_uuid)).await; + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; let channel: Channel; - if let Ok(cache_hit) = cache_result { - channel = serde_json::from_str(&cache_hit).unwrap() + if let Ok(cache_hit) = data.get_cache_key(format!("{}", channel_uuid)).await { + channel = serde_json::from_str(&cache_hit)? } else { - let channel_result = Channel::fetch_one(&data.pool, guild_uuid, channel_uuid).await; + channel = Channel::fetch_one(&mut conn, channel_uuid).await?; - if let Err(error) = channel_result { - return Ok(error); - } - - channel = channel_result.unwrap(); - - let cache_result = data + data .set_cache_key(format!("{}", channel_uuid), channel.clone(), 60) - .await; - - if let Err(error) = cache_result { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); - } + .await?; } let messages = channel - .fetch_messages(&data.pool, message_request.amount, message_request.offset) - .await; + .fetch_messages(&mut conn, message_request.amount, message_request.offset) + .await?; - if let Err(error) = messages { - return Ok(error); - } - - Ok(HttpResponse::Ok().json(messages.unwrap())) + Ok(HttpResponse::Ok().json(messages)) } diff --git a/src/api/v1/servers/uuid/channels/uuid/mod.rs b/src/api/v1/servers/uuid/channels/uuid/mod.rs index c737509..4cf6013 100644 --- a/src/api/v1/servers/uuid/channels/uuid/mod.rs +++ b/src/api/v1/servers/uuid/channels/uuid/mod.rs @@ -2,14 +2,14 @@ pub mod messages; pub mod socket; use crate::{ + error::Error, Data, api::v1::auth::check_access_token, structs::{Channel, Member}, utils::get_auth_header, }; -use ::uuid::Uuid; -use actix_web::{Error, HttpRequest, HttpResponse, delete, get, web}; -use log::error; +use uuid::Uuid; +use actix_web::{HttpRequest, HttpResponse, delete, get, web}; #[get("{uuid}/channels/{channel_uuid}")] pub async fn get( @@ -19,52 +19,27 @@ pub async fn get( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let (guild_uuid, channel_uuid) = path.into_inner(); - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member { - return Ok(error); - } - - let cache_result = data.get_cache_key(format!("{}", channel_uuid)).await; - - if let Ok(cache_hit) = cache_result { + if let Ok(cache_hit) = data.get_cache_key(format!("{}", channel_uuid)).await { return Ok(HttpResponse::Ok() .content_type("application/json") .body(cache_hit)); } - let channel_result = Channel::fetch_one(&data.pool, guild_uuid, channel_uuid).await; + let channel = Channel::fetch_one(&mut conn, channel_uuid).await?; - if let Err(error) = channel_result { - return Ok(error); - } - - let channel = channel_result.unwrap(); - - let cache_result = data + data .set_cache_key(format!("{}", channel_uuid), channel.clone(), 60) - .await; - - if let Err(error) = cache_result { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); - } + .await?; Ok(HttpResponse::Ok().json(channel)) } @@ -77,55 +52,27 @@ pub async fn delete( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let (guild_uuid, channel_uuid) = path.into_inner(); - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); - - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member { - return Ok(error); - } - - let cache_result = data.get_cache_key(format!("{}", channel_uuid)).await; + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; let channel: Channel; - if let Ok(cache_hit) = cache_result { + if let Ok(cache_hit) = data.get_cache_key(format!("{}", channel_uuid)).await { channel = serde_json::from_str(&cache_hit).unwrap(); - let result = data.del_cache_key(format!("{}", channel_uuid)).await; - - if let Err(error) = result { - error!("{}", error) - } + data.del_cache_key(format!("{}", channel_uuid)).await?; } else { - let channel_result = Channel::fetch_one(&data.pool, guild_uuid, channel_uuid).await; - - if let Err(error) = channel_result { - return Ok(error); - } - - channel = channel_result.unwrap(); + channel = Channel::fetch_one(&mut conn, channel_uuid).await?; } - let delete_result = channel.delete(&data.pool).await; - - if let Err(error) = delete_result { - return Ok(error); - } + channel.delete(&mut conn).await?; Ok(HttpResponse::Ok().finish()) } diff --git a/src/api/v1/servers/uuid/channels/uuid/socket.rs b/src/api/v1/servers/uuid/channels/uuid/socket.rs index b9b4ff7..14cb7d9 100644 --- a/src/api/v1/servers/uuid/channels/uuid/socket.rs +++ b/src/api/v1/servers/uuid/channels/uuid/socket.rs @@ -1,7 +1,6 @@ use actix_web::{Error, HttpRequest, HttpResponse, get, rt, web}; use actix_ws::AggregatedMessage; use futures_util::StreamExt as _; -use log::error; use uuid::Uuid; use crate::{ @@ -22,57 +21,30 @@ pub async fn echo( let headers = req.headers(); // Retrieve auth header - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; // Get uuids from path let (guild_uuid, channel_uuid) = path.into_inner(); + let mut conn = data.pool.get().await.map_err(|e| crate::error::Error::from(e))?; + // Authorize client using auth header - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; - - if let Err(error) = authorized { - return Ok(error); - } - - // Unwrap user uuid from authorization - let uuid = authorized.unwrap(); + let uuid = check_access_token(auth_header, &mut conn).await?; // Get server member from psql - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member { - return Ok(error); - } - - // Get cache for channel - let cache_result = data.get_cache_key(format!("{}", channel_uuid)).await; + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; let channel: Channel; // Return channel cache or result from psql as `channel` variable - if let Ok(cache_hit) = cache_result { + if let Ok(cache_hit) = data.get_cache_key(format!("{}", channel_uuid)).await { channel = serde_json::from_str(&cache_hit).unwrap() } else { - let channel_result = Channel::fetch_one(&data.pool, guild_uuid, channel_uuid).await; + channel = Channel::fetch_one(&mut conn, channel_uuid).await?; - if let Err(error) = channel_result { - return Ok(error); - } - - channel = channel_result.unwrap(); - - let cache_result = data + data .set_cache_key(format!("{}", channel_uuid), channel.clone(), 60) - .await; - - if let Err(error) = cache_result { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); - } + .await?; } let (res, mut session_1, stream) = actix_ws::handle(&req, stream)?; @@ -82,17 +54,11 @@ pub async fn echo( // aggregate continuation frames up to 1MiB .max_continuation_size(2_usize.pow(20)); - let pubsub_result = data.cache_pool.get_async_pubsub().await; - - if let Err(error) = pubsub_result { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); - } + let mut pubsub = data.cache_pool.get_async_pubsub().await.map_err(|e| crate::error::Error::from(e))?; let mut session_2 = session_1.clone(); rt::spawn(async move { - let mut pubsub = pubsub_result.unwrap(); pubsub.subscribe(channel_uuid.to_string()).await.unwrap(); while let Some(msg) = pubsub.on_message().next().await { let payload: String = msg.get_payload().unwrap(); @@ -118,7 +84,7 @@ pub async fn echo( .await .unwrap(); channel - .new_message(&data.pool, uuid, text.to_string()) + .new_message(&mut data.pool.get().await.unwrap(), uuid, text.to_string()) .await .unwrap(); } diff --git a/src/api/v1/servers/uuid/invites/mod.rs b/src/api/v1/servers/uuid/invites/mod.rs index 2a07808..badb3e0 100644 --- a/src/api/v1/servers/uuid/invites/mod.rs +++ b/src/api/v1/servers/uuid/invites/mod.rs @@ -1,8 +1,9 @@ -use actix_web::{Error, HttpRequest, HttpResponse, get, post, web}; +use actix_web::{HttpRequest, HttpResponse, get, post, web}; use serde::Deserialize; use uuid::Uuid; use crate::{ + error::Error, Data, api::v1::auth::check_access_token, structs::{Guild, Member}, @@ -22,43 +23,21 @@ pub async fn get( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let guild_uuid = path.into_inner().0; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; + let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; - if let Err(error) = member { - return Ok(error); - } + let invites = guild.get_invites(&mut conn).await?; - let guild_result = Guild::fetch_one(&data.pool, guild_uuid).await; - - if let Err(error) = guild_result { - return Ok(error); - } - - let guild = guild_result.unwrap(); - - let invites = guild.get_invites(&data.pool).await; - - if let Err(error) = invites { - return Ok(error); - } - - Ok(HttpResponse::Ok().json(invites.unwrap())) + Ok(HttpResponse::Ok().json(invites)) } #[post("{uuid}/invites")] @@ -70,45 +49,21 @@ pub async fn create( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let guild_uuid = path.into_inner().0; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); + let member = Member::fetch_one(&mut conn, uuid, guild_uuid).await?; - let member_result = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member_result { - return Ok(error); - } - - let member = member_result.unwrap(); - - let guild_result = Guild::fetch_one(&data.pool, guild_uuid).await; - - if let Err(error) = guild_result { - return Ok(error); - } - - let guild = guild_result.unwrap(); + let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; let custom_id = invite_request.as_ref().map(|ir| ir.custom_id.clone()); - let invite = guild.create_invite(&data.pool, &member, custom_id).await; + let invite = guild.create_invite(&mut conn, &member, custom_id).await?; - if let Err(error) = invite { - return Ok(error); - } - - Ok(HttpResponse::Ok().json(invite.unwrap())) + Ok(HttpResponse::Ok().json(invite)) } diff --git a/src/api/v1/servers/uuid/mod.rs b/src/api/v1/servers/uuid/mod.rs index 8f387aa..bac4004 100644 --- a/src/api/v1/servers/uuid/mod.rs +++ b/src/api/v1/servers/uuid/mod.rs @@ -1,4 +1,4 @@ -use actix_web::{Error, HttpRequest, HttpResponse, Scope, get, web}; +use actix_web::{HttpRequest, HttpResponse, Scope, get, web}; use uuid::Uuid; mod channels; @@ -6,6 +6,7 @@ mod invites; mod roles; use crate::{ + error::Error, Data, api::v1::auth::check_access_token, structs::{Guild, Member}, @@ -40,33 +41,17 @@ pub async fn res( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let guild_uuid = path.into_inner().0; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; + let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; - if let Err(error) = member { - return Ok(error); - } - - let guild = Guild::fetch_one(&data.pool, guild_uuid).await; - - if let Err(error) = guild { - return Ok(error); - } - - Ok(HttpResponse::Ok().json(guild.unwrap())) + Ok(HttpResponse::Ok().json(guild)) } diff --git a/src/api/v1/servers/uuid/roles/mod.rs b/src/api/v1/servers/uuid/roles/mod.rs index 8d22813..a2912f9 100644 --- a/src/api/v1/servers/uuid/roles/mod.rs +++ b/src/api/v1/servers/uuid/roles/mod.rs @@ -1,13 +1,14 @@ +use ::uuid::Uuid; +use actix_web::{HttpRequest, HttpResponse, get, post, web}; +use serde::Deserialize; + use crate::{ + error::Error, Data, api::v1::auth::check_access_token, structs::{Member, Role}, utils::get_auth_header, }; -use ::uuid::Uuid; -use actix_web::{Error, HttpRequest, HttpResponse, get, post, web}; -use log::error; -use serde::Deserialize; pub mod uuid; @@ -24,52 +25,27 @@ pub async fn get( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let guild_uuid = path.into_inner().0; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member { - return Ok(error); - } - - let cache_result = data.get_cache_key(format!("{}_roles", guild_uuid)).await; - - if let Ok(cache_hit) = cache_result { + if let Ok(cache_hit) = data.get_cache_key(format!("{}_roles", guild_uuid)).await { return Ok(HttpResponse::Ok() .content_type("application/json") .body(cache_hit)); } - let roles_result = Role::fetch_all(&data.pool, guild_uuid).await; + let roles = Role::fetch_all(&mut conn, guild_uuid).await?; - if let Err(error) = roles_result { - return Ok(error); - } - - let roles = roles_result.unwrap(); - - let cache_result = data + data .set_cache_key(format!("{}_roles", guild_uuid), roles.clone(), 1800) - .await; - - if let Err(error) = cache_result { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); - } + .await?; Ok(HttpResponse::Ok().json(roles)) } @@ -83,35 +59,19 @@ pub async fn create( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let guild_uuid = path.into_inner().0; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await.unwrap(); - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); - - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member { - return Ok(error); - } + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; // FIXME: Logic to check permissions, should probably be done in utils.rs - let role = Role::new(&data.pool, guild_uuid, role_info.name.clone()).await; + let role = Role::new(&mut conn, guild_uuid, role_info.name.clone()).await?; - if let Err(error) = role { - return Ok(error); - } - - Ok(HttpResponse::Ok().json(role.unwrap())) + Ok(HttpResponse::Ok().json(role)) } diff --git a/src/api/v1/servers/uuid/roles/uuid.rs b/src/api/v1/servers/uuid/roles/uuid.rs index 38bdca9..3279d16 100644 --- a/src/api/v1/servers/uuid/roles/uuid.rs +++ b/src/api/v1/servers/uuid/roles/uuid.rs @@ -1,12 +1,12 @@ use crate::{ + error::Error, Data, api::v1::auth::check_access_token, structs::{Member, Role}, utils::get_auth_header, }; use ::uuid::Uuid; -use actix_web::{Error, HttpRequest, HttpResponse, get, web}; -use log::error; +use actix_web::{HttpRequest, HttpResponse, get, web}; #[get("{uuid}/roles/{role_uuid}")] pub async fn get( @@ -16,52 +16,27 @@ pub async fn get( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let (guild_uuid, role_uuid) = path.into_inner(); - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); + Member::fetch_one(&mut conn, uuid, guild_uuid).await?; - let member = Member::fetch_one(&data.pool, uuid, guild_uuid).await; - - if let Err(error) = member { - return Ok(error); - } - - let cache_result = data.get_cache_key(format!("{}", role_uuid)).await; - - if let Ok(cache_hit) = cache_result { + if let Ok(cache_hit) = data.get_cache_key(format!("{}", role_uuid)).await { return Ok(HttpResponse::Ok() .content_type("application/json") .body(cache_hit)); } - let role_result = Role::fetch_one(&data.pool, guild_uuid, role_uuid).await; + let role = Role::fetch_one(&mut conn, role_uuid).await?; - if let Err(error) = role_result { - return Ok(error); - } - - let role = role_result.unwrap(); - - let cache_result = data + data .set_cache_key(format!("{}", role_uuid), role.clone(), 60) - .await; - - if let Err(error) = cache_result { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); - } + .await?; Ok(HttpResponse::Ok().json(role)) } -- 2.47.2 From dfe2ca9486ca04cb3b6329ae8327c4e041e279c0 Mon Sep 17 00:00:00 2001 From: Radical Date: Fri, 23 May 2025 12:56:51 +0200 Subject: [PATCH 15/17] feat: migrate to diesel and new error type in users --- src/api/v1/users/me.rs | 52 ++++++++++++++++------------------------ src/api/v1/users/mod.rs | 45 ++++++++++++++++------------------ src/api/v1/users/uuid.rs | 50 +++++++++++++------------------------- 3 files changed, 59 insertions(+), 88 deletions(-) diff --git a/src/api/v1/users/me.rs b/src/api/v1/users/me.rs index f641678..49f88ba 100644 --- a/src/api/v1/users/me.rs +++ b/src/api/v1/users/me.rs @@ -1,51 +1,41 @@ -use actix_web::{Error, HttpRequest, HttpResponse, get, web}; +use actix_web::{HttpRequest, HttpResponse, get, web}; +use diesel::{prelude::Queryable, ExpressionMethods, QueryDsl, Selectable, SelectableHelper}; +use diesel_async::RunQueryDsl; use log::error; use serde::Serialize; +use uuid::Uuid; -use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; +use crate::{error::Error, api::v1::auth::check_access_token, schema::users::{self, dsl}, utils::get_auth_header, Data}; -#[derive(Serialize)] +#[derive(Serialize, Queryable, Selectable)] +#[diesel(table_name = users)] +#[diesel(check_for_backend(diesel::pg::Pg))] struct Response { - uuid: String, + uuid: Uuid, username: String, - display_name: String, + display_name: Option, } #[get("/me")] pub async fn res(req: HttpRequest, data: web::Data) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); + let auth_header = get_auth_header(headers)?; - if let Err(error) = auth_header { - return Ok(error); - } + let mut conn = data.pool.get().await?; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let uuid = check_access_token(auth_header, &mut conn).await?; - if let Err(error) = authorized { - return Ok(error); - } + let user: Result = dsl::users + .filter(dsl::uuid.eq(uuid)) + .select(Response::as_select()) + .get_result(&mut conn) + .await; - let uuid = authorized.unwrap(); - - let row = sqlx::query_as(&format!( - "SELECT username, display_name FROM users WHERE uuid = '{}'", - uuid - )) - .fetch_one(&data.pool) - .await; - - if let Err(error) = row { + if let Err(error) = user { error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); + return Ok(HttpResponse::InternalServerError().finish()) } - let (username, display_name): (String, Option) = row.unwrap(); - - Ok(HttpResponse::Ok().json(Response { - uuid: uuid.to_string(), - username, - display_name: display_name.unwrap_or_default(), - })) + Ok(HttpResponse::Ok().json(user.unwrap())) } diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index d6eb6bd..37d884a 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -1,15 +1,19 @@ -use crate::{api::v1::auth::check_access_token, structs::StartAmountQuery, utils::get_auth_header, Data}; -use actix_web::{Error, HttpRequest, HttpResponse, Scope, get, web}; -use log::error; +use actix_web::{HttpRequest, HttpResponse, Scope, get, web}; +use diesel::{prelude::Queryable, QueryDsl, Selectable, SelectableHelper}; +use diesel_async::RunQueryDsl; use serde::Serialize; -use sqlx::prelude::FromRow; +use ::uuid::Uuid; + +use crate::{error::Error,api::v1::auth::check_access_token, schema::users::{self, dsl}, structs::StartAmountQuery, utils::get_auth_header, Data}; mod me; mod uuid; -#[derive(Serialize, FromRow)] +#[derive(Serialize, Queryable, Selectable)] +#[diesel(table_name = users)] +#[diesel(check_for_backend(diesel::pg::Pg))] struct Response { - uuid: String, + uuid: Uuid, username: String, display_name: Option, email: String, @@ -30,7 +34,7 @@ pub async fn res( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); + let auth_header = get_auth_header(headers)?; let start = request_query.start.unwrap_or(0); @@ -40,24 +44,17 @@ pub async fn res( return Ok(HttpResponse::BadRequest().finish()); } - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + check_access_token(auth_header, &mut conn).await?; - let row = sqlx::query_as("SELECT CAST(uuid AS VARCHAR), username, display_name, email FROM users ORDER BY username LIMIT $1 OFFSET $2") - .bind(amount) - .bind(start) - .fetch_all(&data.pool) - .await; + let users: Vec = dsl::users + .order_by(dsl::username) + .offset(start) + .limit(amount) + .select(Response::as_select()) + .load(&mut conn) + .await?; - if let Err(error) = row { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); - } - - let accounts: Vec = row.unwrap(); - - Ok(HttpResponse::Ok().json(accounts)) + Ok(HttpResponse::Ok().json(users)) } diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index 9edaffa..bfb0f69 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -1,15 +1,19 @@ -use actix_web::{Error, HttpRequest, HttpResponse, get, web}; +use actix_web::{HttpRequest, HttpResponse, get, web}; +use diesel::{ExpressionMethods, QueryDsl, Queryable, Selectable, SelectableHelper}; +use diesel_async::RunQueryDsl; use log::error; use serde::Serialize; use uuid::Uuid; -use crate::{Data, api::v1::auth::check_access_token, utils::get_auth_header}; +use crate::{error::Error, api::v1::auth::check_access_token, schema::users::{self, dsl}, utils::get_auth_header, Data}; -#[derive(Serialize, Clone)] +#[derive(Serialize, Queryable, Selectable, Clone)] +#[diesel(table_name = users)] +#[diesel(check_for_backend(diesel::pg::Pg))] struct Response { - uuid: String, + uuid: Uuid, username: String, - display_name: String, + display_name: Option, } #[get("/{uuid}")] @@ -22,17 +26,11 @@ pub async fn res( let uuid = path.into_inner().0; - let auth_header = get_auth_header(headers); + let auth_header = get_auth_header(headers)?; - if let Err(error) = auth_header { - return Ok(error); - } + let mut conn = data.pool.get().await?; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; - - if let Err(error) = authorized { - return Ok(error); - } + check_access_token(auth_header, &mut conn).await?; let cache_result = data.get_cache_key(uuid.to_string()).await; @@ -42,25 +40,11 @@ pub async fn res( .body(cache_hit)); } - let row = sqlx::query_as(&format!( - "SELECT username, display_name FROM users WHERE uuid = '{}'", - uuid - )) - .fetch_one(&data.pool) - .await; - - if let Err(error) = row { - error!("{}", error); - return Ok(HttpResponse::InternalServerError().finish()); - } - - let (username, display_name): (String, Option) = row.unwrap(); - - let user = Response { - uuid: uuid.to_string(), - username, - display_name: display_name.unwrap_or_default(), - }; + let user: Response = dsl::users + .filter(dsl::uuid.eq(uuid)) + .select(Response::as_select()) + .get_result(&mut conn) + .await?; let cache_result = data .set_cache_key(uuid.to_string(), user.clone(), 1800) -- 2.47.2 From 49e08af3d97170c90d24f532c501bf1d8df71b97 Mon Sep 17 00:00:00 2001 From: Radical Date: Fri, 23 May 2025 12:57:08 +0200 Subject: [PATCH 16/17] feat: migrate to diesel and new error type in invites --- src/api/v1/invites/id.rs | 63 +++++++++------------------------------- 1 file changed, 14 insertions(+), 49 deletions(-) diff --git a/src/api/v1/invites/id.rs b/src/api/v1/invites/id.rs index 2adb8d8..67f10af 100644 --- a/src/api/v1/invites/id.rs +++ b/src/api/v1/invites/id.rs @@ -1,6 +1,7 @@ -use actix_web::{Error, HttpRequest, HttpResponse, get, post, web}; +use actix_web::{HttpRequest, HttpResponse, get, post, web}; use crate::{ + error::Error, Data, api::v1::auth::check_access_token, structs::{Guild, Invite, Member}, @@ -15,29 +16,17 @@ pub async fn get( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); + let auth_header = get_auth_header(headers)?; - if let Err(error) = auth_header { - return Ok(error); - } + let mut conn = data.pool.get().await?; + + check_access_token(auth_header, &mut conn).await?; let invite_id = path.into_inner().0; - let result = Invite::fetch_one(&data.pool, invite_id).await; + let invite = Invite::fetch_one(&mut conn, invite_id).await?; - if let Err(error) = result { - return Ok(error); - } - - let invite = result.unwrap(); - - let guild_result = Guild::fetch_one(&data.pool, invite.guild_uuid).await; - - if let Err(error) = guild_result { - return Ok(error); - } - - let guild = guild_result.unwrap(); + let guild = Guild::fetch_one(&mut conn, invite.guild_uuid).await?; Ok(HttpResponse::Ok().json(guild)) } @@ -50,43 +39,19 @@ pub async fn join( ) -> Result { let headers = req.headers(); - let auth_header = get_auth_header(headers); - - if let Err(error) = auth_header { - return Ok(error); - } + let auth_header = get_auth_header(headers)?; let invite_id = path.into_inner().0; - let authorized = check_access_token(auth_header.unwrap(), &data.pool).await; + let mut conn = data.pool.get().await?; - if let Err(error) = authorized { - return Ok(error); - } + let uuid = check_access_token(auth_header, &mut conn).await?; - let uuid = authorized.unwrap(); + let invite = Invite::fetch_one(&mut conn, invite_id).await?; - let result = Invite::fetch_one(&data.pool, invite_id).await; + let guild = Guild::fetch_one(&mut conn, invite.guild_uuid).await?; - if let Err(error) = result { - return Ok(error); - } - - let invite = result.unwrap(); - - let guild_result = Guild::fetch_one(&data.pool, invite.guild_uuid).await; - - if let Err(error) = guild_result { - return Ok(error); - } - - let guild = guild_result.unwrap(); - - let member = Member::new(&data.pool, uuid, guild.uuid).await; - - if let Err(error) = member { - return Ok(error); - } + Member::new(&mut conn, uuid, guild.uuid).await?; Ok(HttpResponse::Ok().json(guild)) } -- 2.47.2 From a670b32c86e99f8074d421d031187ebc8253ffc4 Mon Sep 17 00:00:00 2001 From: Radical Date: Fri, 23 May 2025 12:57:19 +0200 Subject: [PATCH 17/17] feat: migrate to diesel and new error type in stats --- src/api/v1/stats.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/api/v1/stats.rs b/src/api/v1/stats.rs index 0ebf431..6ab8d64 100644 --- a/src/api/v1/stats.rs +++ b/src/api/v1/stats.rs @@ -1,31 +1,31 @@ use std::time::SystemTime; -use actix_web::{HttpResponse, Responder, get, web}; +use actix_web::{HttpResponse, get, web}; +use diesel::QueryDsl; +use diesel_async::RunQueryDsl; use serde::Serialize; +use crate::error::Error; use crate::Data; +use crate::schema::users::dsl::{users, uuid}; const VERSION: Option<&'static str> = option_env!("CARGO_PKG_VERSION"); #[derive(Serialize)] struct Response { - accounts: usize, + accounts: i64, uptime: u64, version: String, build_number: String, } #[get("/stats")] -pub async fn res(data: web::Data) -> impl Responder { - let accounts; - if let Ok(users) = sqlx::query("SELECT uuid FROM users") - .fetch_all(&data.pool) - .await - { - accounts = users.len(); - } else { - return HttpResponse::InternalServerError().finish(); - } +pub async fn res(data: web::Data) -> Result { + let accounts: i64 = users + .select(uuid) + .count() + .get_result(&mut data.pool.get().await?) + .await?; let response = Response { // TODO: Get number of accounts from db @@ -39,5 +39,5 @@ pub async fn res(data: web::Data) -> impl Responder { build_number: String::from("how do i implement this?"), }; - HttpResponse::Ok().json(response) + Ok(HttpResponse::Ok().json(response)) } -- 2.47.2