From 8656115dc9b3f0a6527782f1484d828438521bc7 Mon Sep 17 00:00:00 2001 From: JustTemmie <47639983+JustTemmie@users.noreply.github.com> Date: Mon, 14 Jul 2025 00:36:15 +0200 Subject: [PATCH 01/19] feat: start implementing device name generation in the backend --- Cargo.toml | 1 + src/api/v1/auth/login.rs | 12 +- src/api/v1/auth/mod.rs | 1 + src/api/v1/auth/refresh.rs | 11 +- src/api/v1/auth/register.rs | 25 +- src/main.rs | 1 + src/word_list.rs | 1005 +++++++++++++++++++++++++++++++++++ 7 files changed, 1034 insertions(+), 22 deletions(-) create mode 100644 src/word_list.rs diff --git a/Cargo.toml b/Cargo.toml index c1c71bc..43e5ea5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,5 +60,6 @@ regex = "1.11" random-string = "1.1" lettre = { version = "0.11", features = ["tokio1", "tokio1-native-tls"] } chrono = { version = "0.4.41", features = ["serde"] } +rand = "0.9.1" diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 2faaeb4..cef4726 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -7,10 +7,7 @@ use diesel_async::RunQueryDsl; use serde::Deserialize; use crate::{ - Data, - error::Error, - schema::*, - utils::{PASSWORD_REGEX, generate_token, new_refresh_token_cookie, user_uuid_from_identifier}, + error::Error, schema::*, utils::{generate_token, new_refresh_token_cookie, user_uuid_from_identifier, PASSWORD_REGEX}, generate_device_name::generate_device_name, Data }; use super::Response; @@ -19,7 +16,6 @@ use super::Response; struct LoginInformation { username: String, password: String, - device_name: String, } #[post("/login")] @@ -63,12 +59,14 @@ pub async fn response( use refresh_tokens::dsl as rdsl; + let device_name = generate_device_name(); + insert_into(refresh_tokens::table) .values(( rdsl::token.eq(&refresh_token), rdsl::uuid.eq(uuid), rdsl::created_at.eq(current_time), - rdsl::device_name.eq(&login_information.device_name), + rdsl::device_name.eq(&device_name), )) .execute(&mut conn) .await?; @@ -87,5 +85,5 @@ pub async fn response( Ok(HttpResponse::Ok() .cookie(new_refresh_token_cookie(&data.config, refresh_token)) - .json(Response { access_token })) + .json(Response { access_token, device_name })) } diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index 0e6b006..947e7aa 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -20,6 +20,7 @@ mod verify_email; #[derive(Serialize)] struct Response { access_token: String, + device_name: String, } pub fn web() -> Scope { diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index abd9a34..e9a444b 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -77,6 +77,15 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result()?; + let device_name: String; + + // fix me tomorrow + // let devices: Vec = dsl::refresh_tokens + // .filter(dsl::uuid.eq(uuid)) + // .select(Device::as_select()) + // .get_results(&mut conn) + // .await?; + update(access_tokens::table) .filter(dsl::refresh_token.eq(&refresh_token)) .set(( @@ -88,7 +97,7 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result String { + let adjective_index = rand::rng().random_range(0..ADJECTIVES_LENGTH-1); + let animal_index = rand::rng().random_range(0..ANIMALS_LENGTH-1); + + return [ADJECTIVES[adjective_index], ANIMALS[animal_index]].join(" ") +} + +const ANIMALS_LENGTH: usize = 223; +const ADJECTIVES_LENGTH: usize = 765; + +const ANIMALS: [&'static str; ANIMALS_LENGTH] = [ + "Aardvark", + "Albatross", + "Alligator", + "Alpaca", + "Ant", + "Anteater", + "Antelope", + "Ape", + "Armadillo", + "Donkey", + "Baboon", + "Badger", + "Barracuda", + "Bat", + "Bear", + "Beaver", + "Bee", + "Bison", + "Boar", + "Buffalo", + "Butterfly", + "Camel", + "Capybara", + "Caribou", + "Cassowary", + "Cat", + "Caterpillar", + "Cattle", + "Chamois", + "Cheetah", + "Chicken", + "Chimpanzee", + "Chinchilla", + "Chough", + "Clam", + "Cobra", + "Cockroach", + "Cod", + "Cormorant", + "Coyote", + "Crab", + "Crane", + "Crocodile", + "Crow", + "Curlew", + "Deer", + "Dinosaur", + "Dog", + "Dogfish", + "Dolphin", + "Dotterel", + "Dove", + "Dragonfly", + "Duck", + "Dugong", + "Dunlin", + "Eagle", + "Echidna", + "Eel", + "Eland", + "Elephant", + "Elk", + "Emu", + "Falcon", + "Ferret", + "Finch", + "Fish", + "Flamingo", + "Fly", + "Fox", + "Frog", + "Gaur", + "Gazelle", + "Gerbil", + "Giraffe", + "Gnat", + "Gnu", + "Goat", + "Goldfinch", + "Goldfish", + "Goose", + "Gorilla", + "Goshawk", + "Grasshopper", + "Grouse", + "Guanaco", + "Gull", + "Hamster", + "Hare", + "Hawk", + "Hedgehog", + "Heron", + "Herring", + "Hippopotamus", + "Hornet", + "Horse", + "Hummingbird", + "Hyena", + "Ibex", + "Ibis", + "Jackal", + "Jaguar", + "Jay", + "Jellyfish", + "Kangaroo", + "Kingfisher", + "Koala", + "Kookabura", + "Kouprey", + "Kudu", + "Lapwing", + "Lark", + "Lemur", + "Leopard", + "Lion", + "Llama", + "Lobster", + "Locust", + "Loris", + "Louse", + "Lyrebird", + "Magpie", + "Mallard", + "Manatee", + "Mandrill", + "Mantis", + "Marten", + "Meerkat", + "Mink", + "Mole", + "Mongoose", + "Monkey", + "Moose", + "Mosquito", + "Mouse", + "Mule", + "Narwhal", + "Newt", + "Nightingale", + "Octopus", + "Okapi", + "Opossum", + "Oryx", + "Ostrich", + "Otter", + "Owl", + "Oyster", + "Panther", + "Parrot", + "Partridge", + "Peafowl", + "Pelican", + "Penguin", + "Pheasant", + "Pig", + "Pigeon", + "Pony", + "Porcupine", + "Porpoise", + "Quail", + "Quelea", + "Quetzal", + "Rabbit", + "Raccoon", + "Rail", + "Ram", + "Rat", + "Raven", + "Red deer", + "Red panda", + "Reindeer", + "Rhinoceros", + "Rook", + "Salamander", + "Salmon", + "Sand Dollar", + "Sandpiper", + "Sardine", + "Scorpion", + "Seahorse", + "Seal", + "Shark", + "Sheep", + "Shrew", + "Skunk", + "Snail", + "Snake", + "Sparrow", + "Spider", + "Spoonbill", + "Squid", + "Squirrel", + "Starling", + "Stingray", + "Stinkbug", + "Stork", + "Swallow", + "Swan", + "Tapir", + "Tarsier", + "Termite", + "Tiger", + "Toad", + "Trout", + "Turkey", + "Turtle", + "Viper", + "Vulture", + "Wallaby", + "Walrus", + "Wasp", + "Weasel", + "Whale", + "Wildcat", + "Wolf", + "Wolverine", + "Wombat", + "Woodcock", + "Woodpecker", + "Worm", + "Wren", + "Yak", + "Zebra", +]; + +const ADJECTIVES: [&'static str; ADJECTIVES_LENGTH] = [ + "other", + "such", + "first", + "many", + "new", + "more", + "same", + "own", + "good", + "different", + "great", + "long", + "high", + "social", + "little", + "much", + "important", + "small", + "most", + "large", + "old", + "few", + "general", + "second", + "public", + "last", + "several", + "early", + "certain", + "economic", + "least", + "common", + "present", + "next", + "local", + "best", + "particular", + "young", + "various", + "necessary", + "whole", + "only", + "true", + "able", + "major", + "full", + "low", + "available", + "real", + "similar", + "total", + "special", + "less", + "short", + "specific", + "single", + "self", + "national", + "individual", + "clear", + "personal", + "higher", + "better", + "third", + "natural", + "greater", + "open", + "difficult", + "current", + "further", + "main", + "physical", + "foreign", + "lower", + "strong", + "private", + "likely", + "international", + "significant", + "late", + "basic", + "hard", + "modern", + "simple", + "normal", + "sure", + "central", + "original", + "effective", + "following", + "direct", + "final", + "cultural", + "big", + "recent", + "complete", + "financial", + "positive", + "primary", + "appropriate", + "legal", + "european", + "equal", + "larger", + "average", + "historical", + "critical", + "wide", + "traditional", + "additional", + "active", + "complex", + "former", + "independent", + "entire", + "actual", + "close", + "constant", + "previous", + "easy", + "serious", + "potential", + "fine", + "industrial", + "subject", + "future", + "internal", + "initial", + "well", + "essential", + "dark", + "popular", + "successful", + "standard", + "year", + "past", + "ready", + "professional", + "wrong", + "very", + "proper", + "separate", + "heavy", + "civil", + "responsible", + "considerable", + "light", + "cold", + "above", + "older", + "practical", + "external", + "sufficient", + "interesting", + "upper", + "scientific", + "key", + "annual", + "limited", + "smaller", + "southern", + "earlier", + "commercial", + "powerful", + "later", + "like", + "clinical", + "ancient", + "educational", + "typical", + "technical", + "environmental", + "formal", + "aware", + "beautiful", + "variable", + "obvious", + "secondary", + "enough", + "urban", + "regular", + "relevant", + "greatest", + "spiritual", + "time", + "double", + "happy", + "term", + "multiple", + "dependent", + "correct", + "northern", + "middle", + "rural", + "official", + "fundamental", + "numerous", + "overall", + "usual", + "native", + "regional", + "highest", + "north", + "agricultural", + "literary", + "broad", + "perfect", + "experimental", + "fourth", + "global", + "ordinary", + "related", + "apparent", + "daily", + "principal", + "contemporary", + "severe", + "reasonable", + "subsequent", + "worth", + "longer", + "emotional", + "intellectual", + "unique", + "pure", + "familiar", + "american", + "solid", + "brief", + "famous", + "fresh", + "day", + "corresponding", + "characteristic", + "maximum", + "detailed", + "outside", + "theoretical", + "fair", + "opposite", + "capable", + "visual", + "interested", + "joint", + "adequate", + "based", + "substantial", + "unable", + "structural", + "soft", + "false", + "largest", + "inner", + "mean", + "extensive", + "excellent", + "rapid", + "absolute", + "consistent", + "continuous", + "administrative", + "strange", + "willing", + "alternative", + "slow", + "distinct", + "safe", + "permanent", + "front", + "corporate", + "academic", + "thin", + "nineteenth", + "universal", + "functional", + "unknown", + "careful", + "narrow", + "evident", + "sound", + "classical", + "minor", + "weak", + "suitable", + "chief", + "extreme", + "yellow", + "warm", + "mixed", + "flat", + "huge", + "vast", + "stable", + "valuable", + "rare", + "visible", + "sensitive", + "mechanical", + "state", + "radical", + "extra", + "superior", + "conventional", + "thick", + "dominant", + "post", + "collective", + "younger", + "efficient", + "linear", + "organic", + "oral", + "century", + "creative", + "vertical", + "dynamic", + "empty", + "minimum", + "cognitive", + "logical", + "afraid", + "equivalent", + "quick", + "near", + "concrete", + "mass", + "acute", + "sharp", + "easier", + "quiet", + "adult", + "accurate", + "ideal", + "partial", + "bright", + "identical", + "conservative", + "magnetic", + "frequent", + "electronic", + "fixed", + "square", + "cross", + "clean", + "back", + "organizational", + "constitutional", + "genetic", + "ultimate", + "secret", + "vital", + "dramatic", + "objective", + "round", + "alive", + "straight", + "unusual", + "rational", + "electric", + "mutual", + "class", + "competitive", + "revolutionary", + "statistical", + "random", + "musical", + "crucial", + "racial", + "sudden", + "acid", + "content", + "temporary", + "line", + "remarkable", + "exact", + "valid", + "helpful", + "nice", + "comprehensive", + "united", + "level", + "fifth", + "nervous", + "expensive", + "prominent", + "healthy", + "liquid", + "institutional", + "silent", + "sweet", + "strategic", + "molecular", + "comparative", + "called", + "electrical", + "raw", + "acceptable", + "scale", + "violent", + "all", + "desirable", + "tall", + "steady", + "wonderful", + "sub", + "distant", + "progressive", + "enormous", + "horizontal", + "and", + "intense", + "smooth", + "applicable", + "over", + "animal", + "abstract", + "wise", + "worst", + "gold", + "precise", + "legislative", + "remote", + "technological", + "outer", + "uniform", + "slight", + "attractive", + "evil", + "tiny", + "royal", + "angry", + "advanced", + "friendly", + "dear", + "busy", + "spatial", + "rough", + "primitive", + "judicial", + "systematic", + "lateral", + "sorry", + "plain", + "off", + "comfortable", + "definite", + "massive", + "firm", + "widespread", + "prior", + "twentieth", + "mathematical", + "verbal", + "marginal", + "excessive", + "stronger", + "gross", + "world", + "productive", + "wider", + "glad", + "linguistic", + "patient", + "symbolic", + "earliest", + "plastic", + "type", + "prime", + "eighteenth", + "blind", + "neutral", + "guilty", + "hand", + "extraordinary", + "metal", + "surprising", + "fellow", + "york", + "grand", + "thermal", + "artificial", + "five", + "lowest", + "genuine", + "dimensional", + "optical", + "unlikely", + "developmental", + "reliable", + "executive", + "comparable", + "satisfactory", + "golden", + "diverse", + "preliminary", + "wooden", + "noble", + "part", + "striking", + "cool", + "classic", + "elderly", + "four", + "temporal", + "indirect", + "romantic", + "intermediate", + "differential", + "passive", + "life", + "voluntary", + "out", + "adjacent", + "behavioral", + "exclusive", + "closed", + "inherent", + "inevitable", + "complicated", + "quantitative", + "respective", + "artistic", + "probable", + "anxious", + "informal", + "strict", + "fiscal", + "ideological", + "profound", + "extended", + "eternal", + "known", + "infinite", + "proud", + "honest", + "peculiar", + "absent", + "pleasant", + "optimal", + "renal", + "static", + "outstanding", + "presidential", + "digital", + "integrated", + "legitimate", + "curious", + "aggressive", + "deeper", + "elementary", + "history", + "surgical", + "occasional", + "flexible", + "convenient", + "solar", + "atomic", + "isolated", + "latest", + "sad", + "conceptual", + "underlying", + "everyday", + "cost", + "intensive", + "odd", + "subjective", + "mid", + "worthy", + "pale", + "meaningful", + "therapeutic", + "making", + "circular", + "realistic", + "multi", + "child", + "sophisticated", + "down", + "leading", + "intelligent", + "governmental", + "numerical", + "minimal", + "diagnostic", + "indigenous", + "aesthetic", + "distinctive", + "operational", + "sole", + "material", + "fast", + "bitter", + "broader", + "brilliant", + "peripheral", + "rigid", + "automatic", + "lesser", + "routine", + "favorable", + "cooperative", + "cardiac", + "arbitrary", + "loose", + "favorite", + "subtle", + "uncertain", + "hostile", + "monthly", + "naval", + "physiological", + "historic", + "developed", + "skilled", + "anterior", + "pro", + "gentle", + "loud", + "pulmonary", + "innocent", + "provincial", + "mild", + "page", + "specialized", + "bare", + "excess", + "inter", + "shaped", + "theological", + "sensory", + "the", + "stress", + "novel", + "working", + "shorter", + "secular", + "geographical", + "intimate", + "liable", + "selective", + "influential", + "modest", + "successive", + "continued", + "water", + "expert", + "municipal", + "marine", + "thirty", + "adverse", + "wacky", + "closer", + "virtual", + "peaceful", + "mobile", + "sixth", + "immune", + "coastal", + "representative", + "lead", + "forward", + "faithful", + "crystal", + "protective", + "elaborate", + "tremendous", + "welcoming", + "abnormal", + "grateful", + "proportional", + "dual", + "operative", + "precious", + "sympathetic", + "accessible", + "lovely", + "spinal", + "even", + "marked", + "observed", + "point", + "mature", + "competent", + "residential", + "impressive", + "unexpected", + "nearby", + "unnecessary", + "generous", + "cerebral", + "unpublished", + "delicate", + "analytical", + "tropical", + "statutory", + "cell", + "weekly", + "end", + "online", + "beneficial", + "aged", + "tough", + "eager", + "ongoing", + "silver", + "persistent", + "calm", + "nearest", + "hidden", + "magic", + "pretty", + "wealthy", + "exciting", + "decisive", + "confident", + "invisible", + "notable", + "medium", + "manual", + "select", + "thorough", + "causal", + "giant", + "bigger", + "pink", + "improved", + "immense", + "hour", + "intact", + "grade", + "dense", + "hungry", + "biggest", + "abundant", + "handsome", + "retail", + "insufficient", + "irregular", + "intrinsic", + "residual", + "follow", + "fluid", + "mysterious", + "descriptive", + "elastic", + "destructive", + "architectural", + "synthetic", + "continental", + "evolutionary", + "lucky", + "bold", + "funny", + "peak", + "smallest", + "reluctant", + "suspicious", + "smart", + "mighty", + "brave", + "humble", + "vocal", + "obscure", + "innovative", +]; \ No newline at end of file -- 2.47.3 From e7bc53f8588527680505e9dcddaaf3167409aefc Mon Sep 17 00:00:00 2001 From: JustTemmie <47639983+JustTemmie@users.noreply.github.com> Date: Mon, 14 Jul 2025 01:02:03 +0200 Subject: [PATCH 02/19] feat: try reading the device name from the table --- src/api/v1/auth/refresh.rs | 27 +++++++------------ src/{word_list.rs => generate_device_name.rs} | 0 2 files changed, 9 insertions(+), 18 deletions(-) rename src/{word_list.rs => generate_device_name.rs} (100%) diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index e9a444b..a89efbc 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -5,13 +5,10 @@ use log::error; use std::time::{SystemTime, UNIX_EPOCH}; use crate::{ - Data, - error::Error, - schema::{ + error::Error, schema::{ access_tokens::{self, dsl}, - refresh_tokens::{self, dsl as rdsl}, - }, - utils::{generate_token, new_refresh_token_cookie}, + refresh_tokens::{self, device_name, dsl as rdsl}, + }, utils::{generate_token, new_refresh_token_cookie}, Data }; use super::Response; @@ -53,6 +50,7 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result 1987200 { let new_refresh_token = generate_token::<32>()?; @@ -63,11 +61,13 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result(&mut conn) .await { - Ok(_) => { + Ok(device_name) => { refresh_token = new_refresh_token; + existing_device_name = device_name.to_string(); } Err(error) => { error!("{error}"); @@ -77,15 +77,6 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result()?; - let device_name: String; - - // fix me tomorrow - // let devices: Vec = dsl::refresh_tokens - // .filter(dsl::uuid.eq(uuid)) - // .select(Device::as_select()) - // .get_results(&mut conn) - // .await?; - update(access_tokens::table) .filter(dsl::refresh_token.eq(&refresh_token)) .set(( @@ -97,7 +88,7 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result Date: Tue, 15 Jul 2025 02:30:07 +0200 Subject: [PATCH 03/19] fix: increase length of refresh token field --- .../2025-07-15-002434_increase_device_name_length/down.sql | 2 ++ migrations/2025-07-15-002434_increase_device_name_length/up.sql | 2 ++ 2 files changed, 4 insertions(+) create mode 100644 migrations/2025-07-15-002434_increase_device_name_length/down.sql create mode 100644 migrations/2025-07-15-002434_increase_device_name_length/up.sql diff --git a/migrations/2025-07-15-002434_increase_device_name_length/down.sql b/migrations/2025-07-15-002434_increase_device_name_length/down.sql new file mode 100644 index 0000000..4fe6628 --- /dev/null +++ b/migrations/2025-07-15-002434_increase_device_name_length/down.sql @@ -0,0 +1,2 @@ +-- This file should undo anything in `up.sql` +ALTER TABLE refresh_tokens ALTER COLUMN device_name TYPE varchar(16); \ No newline at end of file diff --git a/migrations/2025-07-15-002434_increase_device_name_length/up.sql b/migrations/2025-07-15-002434_increase_device_name_length/up.sql new file mode 100644 index 0000000..9d44298 --- /dev/null +++ b/migrations/2025-07-15-002434_increase_device_name_length/up.sql @@ -0,0 +1,2 @@ +-- Your SQL goes here +ALTER TABLE refresh_tokens ALTER COLUMN device_name TYPE varchar(64); \ No newline at end of file -- 2.47.3 From fc061738fa74d1d5ead22ec034d7ac01e8119f3b Mon Sep 17 00:00:00 2001 From: JustTemmie <47639983+JustTemmie@users.noreply.github.com> Date: Tue, 15 Jul 2025 02:42:53 +0200 Subject: [PATCH 04/19] feat: finish adding device name to login, register, and refresh endpoints --- src/api/v1/auth/login.rs | 2 +- src/api/v1/auth/refresh.rs | 10 +++++----- src/api/v1/auth/register.rs | 4 ++-- src/main.rs | 2 +- src/schema.rs | 2 +- src/utils.rs | 14 +++++++++----- src/{generate_device_name.rs => wordlist.rs} | 16 ++-------------- 7 files changed, 21 insertions(+), 29 deletions(-) rename src/{generate_device_name.rs => wordlist.rs} (96%) diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index cef4726..b2f3180 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -7,7 +7,7 @@ use diesel_async::RunQueryDsl; use serde::Deserialize; use crate::{ - error::Error, schema::*, utils::{generate_token, new_refresh_token_cookie, user_uuid_from_identifier, PASSWORD_REGEX}, generate_device_name::generate_device_name, Data + error::Error, schema::*, utils::{generate_device_name, generate_token, new_refresh_token_cookie, user_uuid_from_identifier, PASSWORD_REGEX}, Data }; use super::Response; diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index a89efbc..90728de 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -7,7 +7,7 @@ use std::time::{SystemTime, UNIX_EPOCH}; use crate::{ error::Error, schema::{ access_tokens::{self, dsl}, - refresh_tokens::{self, device_name, dsl as rdsl}, + refresh_tokens::{self, dsl as rdsl}, }, utils::{generate_token, new_refresh_token_cookie}, Data }; @@ -50,7 +50,7 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result 1987200 { let new_refresh_token = generate_token::<32>()?; @@ -65,9 +65,9 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result(&mut conn) .await { - Ok(device_name) => { + Ok(existing_device_name) => { refresh_token = new_refresh_token; - existing_device_name = device_name.to_string(); + device_name = existing_device_name; } Err(error) => { error!("{error}"); @@ -88,7 +88,7 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result Varchar, uuid -> Uuid, created_at -> Int8, - #[max_length = 16] + #[max_length = 64] device_name -> Varchar, } } diff --git a/src/utils.rs b/src/utils.rs index 072143f..d058db3 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,4 +1,5 @@ use std::sync::LazyLock; +use rand::Rng; use actix_web::{ cookie::{Cookie, SameSite, time::Duration}, @@ -16,11 +17,7 @@ use serde::Serialize; use uuid::Uuid; use crate::{ - Conn, Data, - config::Config, - error::Error, - objects::{HasIsAbove, HasUuid}, - schema::users, + config::Config, error::Error, objects::{HasIsAbove, HasUuid}, schema::users, wordlist::{ADJECTIVES, ANIMALS}, Conn, Data }; pub static EMAIL_REGEX: LazyLock = LazyLock::new(|| { @@ -282,3 +279,10 @@ impl Data { .await } } + +pub fn generate_device_name() -> String { + let adjective_index = rand::rng().random_range(0..ADJECTIVES.len()-1); + let animal_index = rand::rng().random_range(0..ANIMALS.len()-1); + + return [ADJECTIVES[adjective_index], ANIMALS[animal_index]].join(" ") +} \ No newline at end of file diff --git a/src/generate_device_name.rs b/src/wordlist.rs similarity index 96% rename from src/generate_device_name.rs rename to src/wordlist.rs index 1992f95..3ca3c3c 100644 --- a/src/generate_device_name.rs +++ b/src/wordlist.rs @@ -1,16 +1,4 @@ -use rand::Rng; - -pub fn generate_device_name() -> String { - let adjective_index = rand::rng().random_range(0..ADJECTIVES_LENGTH-1); - let animal_index = rand::rng().random_range(0..ANIMALS_LENGTH-1); - - return [ADJECTIVES[adjective_index], ANIMALS[animal_index]].join(" ") -} - -const ANIMALS_LENGTH: usize = 223; -const ADJECTIVES_LENGTH: usize = 765; - -const ANIMALS: [&'static str; ANIMALS_LENGTH] = [ +pub const ANIMALS: [&'static str; 223] = [ "Aardvark", "Albatross", "Alligator", @@ -236,7 +224,7 @@ const ANIMALS: [&'static str; ANIMALS_LENGTH] = [ "Zebra", ]; -const ADJECTIVES: [&'static str; ADJECTIVES_LENGTH] = [ +pub const ADJECTIVES: [&'static str; 765] = [ "other", "such", "first", -- 2.47.3 From 324137ce8bc9eeadb08c3b149a7f122d7aa2e68f Mon Sep 17 00:00:00 2001 From: Radical Date: Wed, 16 Jul 2025 16:36:22 +0200 Subject: [PATCH 05/19] refactor: rewrite entire codebase in axum instead of actix Replaces actix with axum for web, allows us to use socket.io and gives us access to the tower ecosystem of middleware breaks compatibility with our current websocket implementation, needs to be reimplemented for socket.io --- Cargo.toml | 17 ++-- src/api/mod.rs | 15 +-- src/api/v1/auth/devices.rs | 30 +++--- src/api/v1/auth/login.rs | 56 +++++++---- src/api/v1/auth/logout.rs | 56 ++++++++--- src/api/v1/auth/mod.rs | 42 ++++---- src/api/v1/auth/refresh.rs | 107 +++++++++++++++----- src/api/v1/auth/register.rs | 110 ++++++++++++--------- src/api/v1/auth/reset_password.rs | 49 ++++++---- src/api/v1/auth/revoke.rs | 37 +++---- src/api/v1/auth/verify_email.rs | 68 ++++++------- src/api/v1/channels/mod.rs | 28 ++++-- src/api/v1/channels/uuid/messages.rs | 48 ++++----- src/api/v1/channels/uuid/mod.rs | 110 ++++++++++----------- src/api/v1/guilds/mod.rs | 74 +++++++------- src/api/v1/guilds/uuid/channels.rs | 100 +++++++++---------- src/api/v1/guilds/uuid/icon.rs | 62 ------------ src/api/v1/guilds/uuid/invites/mod.rs | 69 +++++++------ src/api/v1/guilds/uuid/members.rs | 45 +++++---- src/api/v1/guilds/uuid/mod.rs | 121 ++++++++++++++++------- src/api/v1/guilds/uuid/roles/mod.rs | 76 +++++++-------- src/api/v1/guilds/uuid/roles/uuid.rs | 52 +++++----- src/api/v1/invites/id.rs | 54 ++++++----- src/api/v1/invites/mod.rs | 15 ++- src/api/v1/me/friends/mod.rs | 57 ++++++----- src/api/v1/me/friends/uuid.rs | 41 ++++---- src/api/v1/me/guilds.rs | 32 +++--- src/api/v1/me/mod.rs | 134 ++++++++++++++------------ src/api/v1/mod.rs | 24 +++-- src/api/v1/stats.rs | 21 ++-- src/api/v1/users/mod.rs | 49 ++++++---- src/api/v1/users/uuid.rs | 42 ++++---- src/api/versions.rs | 7 +- src/error.rs | 70 +++++++++----- src/main.rs | 84 ++++++++-------- src/objects/channel.rs | 67 +++++++------ src/objects/email_token.rs | 28 +++--- src/objects/guild.rs | 6 +- src/objects/me.rs | 84 +++++++++------- src/objects/member.rs | 38 +++++--- src/objects/message.rs | 6 +- src/objects/mod.rs | 31 ++++++ src/objects/password_reset_token.rs | 77 ++++++++------- src/objects/role.rs | 17 +++- src/objects/user.rs | 17 ++-- src/socket.rs | 26 +++++ src/utils.rs | 111 +++++---------------- 47 files changed, 1381 insertions(+), 1129 deletions(-) delete mode 100644 src/api/v1/guilds/uuid/icon.rs create mode 100644 src/socket.rs diff --git a/Cargo.toml b/Cargo.toml index c1c71bc..aef435f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ thiserror = "2.0.12" # CLI clap = { version = "4.5", features = ["derive"] } log = "0.4" -simple_logger = "5.0.0" # async futures = "0.3" @@ -30,19 +29,21 @@ futures-util = "0.3.31" # Data (de)serialization serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -toml = "0.8" +toml = "0.9" +bytes = "1.10.1" +rmpv = { version = "1.3.0", features = ["with-serde"] } # File Storage bindet = "0.3.2" bunny-api-tokio = { version = "0.4", features = ["edge_storage"], default-features = false } # Web Server -actix-web = "4.11" -actix-cors = "0.7.1" -actix-ws = "0.3.0" -actix-multipart = "0.7.2" +axum = { version = "0.8.4", features = ["macros", "multipart"] } +tower-http = { version = "0.6.6", features = ["cors"] } +axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"] } +socketioxide = { version = "0.17.2", features = ["state"] } url = { version = "2.5", features = ["serde"] } -tokio-tungstenite = { version = "0.27", features = ["native-tls", "url"] } +time = "0.3.41" # Database uuid = { version = "1.17", features = ["serde", "v7"] } @@ -60,5 +61,5 @@ regex = "1.11" random-string = "1.1" lettre = { version = "0.11", features = ["tokio1", "tokio1-native-tls"] } chrono = { version = "0.4.41", features = ["serde"] } - +tracing-subscriber = "0.3.19" diff --git a/src/api/mod.rs b/src/api/mod.rs index 6d83e02..e4c3f2e 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,13 +1,16 @@ //! `/api` Contains the entire API -use actix_web::Scope; -use actix_web::web; +use std::sync::Arc; + +use axum::{Router, routing::get}; + +use crate::AppState; mod v1; mod versions; -pub fn web(path: &str) -> Scope { - web::scope(path.trim_end_matches('/')) - .service(v1::web()) - .service(versions::get) +pub fn router() -> Router> { + Router::new() + .route("/versions", get(versions::versions)) + .nest("/v1", v1::router()) } diff --git a/src/api/v1/auth/devices.rs b/src/api/v1/auth/devices.rs index 532ad00..a3c12d1 100644 --- a/src/api/v1/auth/devices.rs +++ b/src/api/v1/auth/devices.rs @@ -1,16 +1,21 @@ //! `/api/v1/auth/devices` Returns list of logged in devices -use actix_web::{HttpRequest, HttpResponse, get, web}; +use std::sync::Arc; + +use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use diesel::{ExpressionMethods, QueryDsl, Queryable, Selectable, SelectableHelper}; use diesel_async::RunQueryDsl; use serde::Serialize; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, schema::refresh_tokens::{self, dsl}, - utils::get_auth_header, }; #[derive(Serialize, Selectable, Queryable)] @@ -18,7 +23,7 @@ use crate::{ #[diesel(check_for_backend(diesel::pg::Pg))] struct Device { device_name: String, - created_at: i64 + created_at: i64, } /// `GET /api/v1/auth/devices` Returns list of logged in devices @@ -35,18 +40,13 @@ struct Device { /// /// ]); /// ``` -#[get("/devices")] pub async fn get( - req: HttpRequest, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; let devices: Vec = dsl::refresh_tokens .filter(dsl::uuid.eq(uuid)) @@ -54,5 +54,5 @@ pub async fn get( .get_results(&mut conn) .await?; - Ok(HttpResponse::Ok().json(devices)) + Ok((StatusCode::OK, Json(devices))) } diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 2faaeb4..2391fdf 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -1,39 +1,47 @@ -use std::time::{SystemTime, UNIX_EPOCH}; +use std::{ + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; -use actix_web::{HttpResponse, post, web}; use argon2::{PasswordHash, PasswordVerifier}; +use axum::{ + Json, + extract::State, + http::{HeaderValue, StatusCode}, + response::IntoResponse, +}; use diesel::{ExpressionMethods, QueryDsl, dsl::insert_into}; use diesel_async::RunQueryDsl; use serde::Deserialize; use crate::{ - Data, + AppState, error::Error, schema::*, - utils::{PASSWORD_REGEX, generate_token, new_refresh_token_cookie, user_uuid_from_identifier}, + utils::{ + PASSWORD_REGEX, generate_token, new_access_token_cookie, new_refresh_token_cookie, + user_uuid_from_identifier, + }, }; -use super::Response; - #[derive(Deserialize)] -struct LoginInformation { +pub struct LoginInformation { username: String, password: String, device_name: String, } -#[post("/login")] pub async fn response( - login_information: web::Json, - data: web::Data, -) -> Result { + State(app_state): State>, + Json(login_information): Json, +) -> Result { if !PASSWORD_REGEX.is_match(&login_information.password) { - return Ok(HttpResponse::Forbidden().json(r#"{ "password_hashed": false }"#)); + return Err(Error::BadRequest("Bad password".to_string())); } use users::dsl; - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; let uuid = user_uuid_from_identifier(&mut conn, &login_information.username).await?; @@ -46,7 +54,7 @@ pub async fn response( let parsed_hash = PasswordHash::new(&database_password) .map_err(|e| Error::PasswordHashError(e.to_string()))?; - if data + if app_state .argon2 .verify_password(login_information.password.as_bytes(), &parsed_hash) .is_err() @@ -85,7 +93,21 @@ pub async fn response( .execute(&mut conn) .await?; - Ok(HttpResponse::Ok() - .cookie(new_refresh_token_cookie(&data.config, refresh_token)) - .json(Response { access_token })) + let mut response = StatusCode::OK.into_response(); + + response.headers_mut().insert( + "Set-Cookie", + HeaderValue::from_str( + &new_refresh_token_cookie(&app_state.config, refresh_token).to_string(), + )?, + ); + + response.headers_mut().insert( + "Set-Cookie2", + HeaderValue::from_str( + &new_access_token_cookie(&app_state.config, access_token).to_string(), + )?, + ); + + Ok(response) } diff --git a/src/api/v1/auth/logout.rs b/src/api/v1/auth/logout.rs index b805d91..6e5e98d 100644 --- a/src/api/v1/auth/logout.rs +++ b/src/api/v1/auth/logout.rs @@ -1,9 +1,16 @@ -use actix_web::{HttpRequest, HttpResponse, get, web}; +use std::sync::Arc; + +use axum::{ + extract::State, + http::{HeaderValue, StatusCode}, + response::IntoResponse, +}; +use axum_extra::extract::CookieJar; use diesel::{ExpressionMethods, delete}; use diesel_async::RunQueryDsl; use crate::{ - Data, + AppState, error::Error, schema::refresh_tokens::{self, dsl}, }; @@ -20,28 +27,49 @@ use crate::{ /// /// 401 Unauthorized (no refresh token found) /// -#[get("/logout")] -pub async fn res(req: HttpRequest, data: web::Data) -> Result { - let mut refresh_token_cookie = req.cookie("refresh_token").ok_or(Error::Unauthorized( - "request has no refresh token".to_string(), - ))?; +pub async fn res( + State(app_state): State>, + jar: CookieJar, +) -> Result { + let mut refresh_token_cookie = jar + .get("refresh_token") + .ok_or(Error::Unauthorized( + "request has no refresh token".to_string(), + ))? + .to_owned(); - let refresh_token = String::from(refresh_token_cookie.value()); + let access_token_cookie = jar.get("access_token"); - let mut conn = data.pool.get().await?; + let refresh_token = String::from(refresh_token_cookie.value_trimmed()); + + let mut conn = app_state.pool.get().await?; let deleted = delete(refresh_tokens::table) .filter(dsl::token.eq(refresh_token)) .execute(&mut conn) .await?; - refresh_token_cookie.make_removal(); + let mut response; if deleted == 0 { - return Ok(HttpResponse::NotFound() - .cookie(refresh_token_cookie) - .finish()); + response = StatusCode::NOT_FOUND.into_response(); + } else { + response = StatusCode::OK.into_response(); } - Ok(HttpResponse::Ok().cookie(refresh_token_cookie).finish()) + refresh_token_cookie.make_removal(); + response.headers_mut().append( + "Set-Cookie", + HeaderValue::from_str(&refresh_token_cookie.to_string())?, + ); + + if let Some(cookie) = access_token_cookie { + let mut cookie = cookie.clone(); + cookie.make_removal(); + response + .headers_mut() + .append("Set-Cookie2", HeaderValue::from_str(&cookie.to_string())?); + } + + Ok(response) } diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index 0e6b006..88be220 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -1,12 +1,17 @@ -use std::time::{SystemTime, UNIX_EPOCH}; +use std::{ + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; -use actix_web::{Scope, web}; +use axum::{ + Router, + routing::{delete, get, post}, +}; use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; -use serde::Serialize; use uuid::Uuid; -use crate::{Conn, error::Error, schema::access_tokens::dsl}; +use crate::{AppState, Conn, error::Error, schema::access_tokens::dsl}; mod devices; mod login; @@ -17,23 +22,18 @@ mod reset_password; mod revoke; mod verify_email; -#[derive(Serialize)] -struct Response { - access_token: String, -} - -pub fn web() -> Scope { - web::scope("/auth") - .service(register::res) - .service(login::response) - .service(logout::res) - .service(refresh::res) - .service(revoke::res) - .service(verify_email::get) - .service(verify_email::post) - .service(reset_password::get) - .service(reset_password::post) - .service(devices::get) +pub fn router() -> Router> { + Router::new() + .route("/register", post(register::post)) + .route("/login", post(login::response)) + .route("/logout", delete(logout::res)) + .route("/refresh", post(refresh::post)) + .route("/revoke", post(revoke::post)) + .route("/verify-email", get(verify_email::get)) + .route("/verify-email", post(verify_email::post)) + .route("/reset-password", get(reset_password::get)) + .route("/reset-password", post(reset_password::post)) + .route("/devices", get(devices::get)) } pub async fn check_access_token(access_token: &str, conn: &mut Conn) -> Result { diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index abd9a34..2a7e611 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -1,32 +1,45 @@ -use actix_web::{HttpRequest, HttpResponse, post, web}; +use axum::{ + extract::State, + http::{HeaderValue, StatusCode}, + response::IntoResponse, +}; +use axum_extra::extract::CookieJar; use diesel::{ExpressionMethods, QueryDsl, delete, update}; use diesel_async::RunQueryDsl; use log::error; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::{ + sync::Arc, + time::{SystemTime, UNIX_EPOCH}, +}; use crate::{ - Data, + AppState, error::Error, schema::{ access_tokens::{self, dsl}, refresh_tokens::{self, dsl as rdsl}, }, - utils::{generate_token, new_refresh_token_cookie}, + utils::{generate_token, new_access_token_cookie, new_refresh_token_cookie}, }; -use super::Response; +pub async fn post( + State(app_state): State>, + jar: CookieJar, +) -> Result { + let mut refresh_token_cookie = jar + .get("refresh_token") + .ok_or(Error::Unauthorized( + "request has no refresh token".to_string(), + ))? + .to_owned(); -#[post("/refresh")] -pub async fn res(req: HttpRequest, data: web::Data) -> Result { - let mut refresh_token_cookie = req.cookie("refresh_token").ok_or(Error::Unauthorized( - "request has no refresh token".to_string(), - ))?; + let access_token_cookie = jar.get("access_token"); - let mut refresh_token = String::from(refresh_token_cookie.value()); + let refresh_token = String::from(refresh_token_cookie.value_trimmed()); let current_time = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64; - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; if let Ok(created_at) = rdsl::refresh_tokens .filter(rdsl::token.eq(&refresh_token)) @@ -45,15 +58,29 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result 1987200 { let new_refresh_token = generate_token::<32>()?; @@ -67,7 +94,13 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result { - refresh_token = new_refresh_token; + response.headers_mut().append( + "Set-Cookie", + HeaderValue::from_str( + &new_refresh_token_cookie(&app_state.config, new_refresh_token) + .to_string(), + )?, + ); } Err(error) => { error!("{error}"); @@ -86,14 +119,40 @@ pub async fn res(req: HttpRequest, data: web::Data) -> Result, - data: web::Data, -) -> Result { - if !data.config.instance.registration { +pub async fn post( + State(app_state): State>, + Json(account_information): Json, +) -> Result { + if !app_state.config.instance.registration { return Err(Error::Forbidden( "registration is disabled on this instance".to_string(), )); @@ -78,36 +77,48 @@ pub async fn res( let uuid = Uuid::now_v7(); if !EMAIL_REGEX.is_match(&account_information.email) { - return Ok(HttpResponse::Forbidden().json(ResponseError { - email_valid: false, - ..Default::default() - })); + return Ok(( + StatusCode::FORBIDDEN, + Json(ResponseError { + email_valid: false, + ..Default::default() + }), + ) + .into_response()); } if !USERNAME_REGEX.is_match(&account_information.identifier) || account_information.identifier.len() < 3 || account_information.identifier.len() > 32 { - return Ok(HttpResponse::Forbidden().json(ResponseError { - gorb_id_valid: false, - ..Default::default() - })); + return Ok(( + StatusCode::FORBIDDEN, + Json(ResponseError { + gorb_id_valid: false, + ..Default::default() + }), + ) + .into_response()); } if !PASSWORD_REGEX.is_match(&account_information.password) { - return Ok(HttpResponse::Forbidden().json(ResponseError { - password_hashed: false, - ..Default::default() - })); + return Ok(( + StatusCode::FORBIDDEN, + Json(ResponseError { + password_strength: false, + ..Default::default() + }), + ) + .into_response()); } let salt = SaltString::generate(&mut OsRng); - if let Ok(hashed_password) = data + if let Ok(hashed_password) = app_state .argon2 .hash_password(account_information.password.as_bytes(), &salt) { - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; // TODO: Check security of this implementation insert_into(users::table) @@ -145,14 +156,27 @@ pub async fn res( .execute(&mut conn) .await?; - if let Some(initial_guild) = data.config.instance.initial_guild { - Member::new(&data, uuid, initial_guild).await?; + if let Some(initial_guild) = app_state.config.instance.initial_guild { + Member::new(&app_state, uuid, initial_guild).await?; } - return Ok(HttpResponse::Ok() - .cookie(new_refresh_token_cookie(&data.config, refresh_token)) - .json(Response { access_token })); + let mut response = StatusCode::OK.into_response(); + + response.headers_mut().append( + "Set-Cookie", + HeaderValue::from_str( + &new_refresh_token_cookie(&app_state.config, refresh_token).to_string(), + )?, + ); + response.headers_mut().append( + "Set-Cookie2", + HeaderValue::from_str( + &new_access_token_cookie(&app_state.config, access_token).to_string(), + )?, + ); + + return Ok(response); } - Ok(HttpResponse::InternalServerError().finish()) + Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()) } diff --git a/src/api/v1/auth/reset_password.rs b/src/api/v1/auth/reset_password.rs index 9a4497f..bac465c 100644 --- a/src/api/v1/auth/reset_password.rs +++ b/src/api/v1/auth/reset_password.rs @@ -1,13 +1,20 @@ //! `/api/v1/auth/reset-password` Endpoints for resetting user password -use actix_web::{HttpResponse, get, post, web}; +use std::sync::Arc; + +use axum::{ + Json, + extract::{Query, State}, + http::StatusCode, + response::IntoResponse, +}; use chrono::{Duration, Utc}; use serde::Deserialize; -use crate::{Data, error::Error, objects::PasswordResetToken}; +use crate::{AppState, error::Error, objects::PasswordResetToken}; #[derive(Deserialize)] -struct Query { +pub struct QueryParams { identifier: String, } @@ -20,17 +27,22 @@ struct Query { /// /// ### Responses /// 200 Email sent +/// /// 429 Too Many Requests +/// /// 404 Not found +/// /// 400 Bad request /// -#[get("/reset-password")] -pub async fn get(query: web::Query, data: web::Data) -> Result { +pub async fn get( + State(app_state): State>, + query: Query, +) -> Result { if let Ok(password_reset_token) = - PasswordResetToken::get_with_identifier(&data, query.identifier.clone()).await + PasswordResetToken::get_with_identifier(&app_state, query.identifier.clone()).await { if Utc::now().signed_duration_since(password_reset_token.created_at) > Duration::hours(1) { - password_reset_token.delete(&data).await?; + password_reset_token.delete(&app_state).await?; } else { return Err(Error::TooManyRequests( "Please allow 1 hour before sending a new email".to_string(), @@ -38,13 +50,13 @@ pub async fn get(query: web::Query, data: web::Data) -> Result, - data: web::Data, -) -> Result { - let password_reset_token = PasswordResetToken::get(&data, reset_password.token.clone()).await?; + State(app_state): State>, + reset_password: Json, +) -> Result { + let password_reset_token = + PasswordResetToken::get(&app_state, reset_password.token.clone()).await?; password_reset_token - .set_password(&data, reset_password.password.clone()) + .set_password(&app_state, reset_password.password.clone()) .await?; - Ok(HttpResponse::Ok().finish()) + Ok(StatusCode::OK) } diff --git a/src/api/v1/auth/revoke.rs b/src/api/v1/auth/revoke.rs index 2e95884..50aa6d2 100644 --- a/src/api/v1/auth/revoke.rs +++ b/src/api/v1/auth/revoke.rs @@ -1,38 +1,39 @@ -use actix_web::{HttpRequest, HttpResponse, post, web}; +use std::sync::Arc; + use argon2::{PasswordHash, PasswordVerifier}; +use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; +use axum_extra::{ + TypedHeader, + headers::authorization::{Authorization, Bearer}, +}; use diesel::{ExpressionMethods, QueryDsl, delete}; use diesel_async::RunQueryDsl; use serde::Deserialize; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, schema::refresh_tokens::{self, dsl as rdsl}, schema::users::dsl as udsl, - utils::get_auth_header, }; #[derive(Deserialize)] -struct RevokeRequest { +pub struct RevokeRequest { password: String, device_name: String, } // TODO: Should maybe be a delete request? -#[post("/revoke")] -pub async fn res( - req: HttpRequest, - revoke_request: web::Json, - data: web::Data, -) -> Result { - let headers = req.headers(); +#[axum::debug_handler] +pub async fn post( + State(app_state): State>, + TypedHeader(auth): TypedHeader>, + Json(revoke_request): Json, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; let database_password: String = udsl::users .filter(udsl::uuid.eq(uuid)) @@ -43,7 +44,7 @@ pub async fn res( let hashed_password = PasswordHash::new(&database_password) .map_err(|e| Error::PasswordHashError(e.to_string()))?; - if data + if app_state .argon2 .verify_password(revoke_request.password.as_bytes(), &hashed_password) .is_err() @@ -59,5 +60,5 @@ pub async fn res( .execute(&mut conn) .await?; - Ok(HttpResponse::Ok().finish()) + Ok(StatusCode::OK) } diff --git a/src/api/v1/auth/verify_email.rs b/src/api/v1/auth/verify_email.rs index 6b895aa..28aa1ab 100644 --- a/src/api/v1/auth/verify_email.rs +++ b/src/api/v1/auth/verify_email.rs @@ -1,19 +1,28 @@ //! `/api/v1/auth/verify-email` Endpoints for verifying user emails -use actix_web::{HttpRequest, HttpResponse, get, post, web}; +use std::sync::Arc; + +use axum::{ + extract::{Query, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use chrono::{Duration, Utc}; use serde::Deserialize; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{EmailToken, Me}, - utils::get_auth_header, }; #[derive(Deserialize)] -struct Query { +pub struct QueryParams { token: String, } @@ -35,37 +44,32 @@ struct Query { /// /// 401 Unauthorized /// -#[get("/verify-email")] pub async fn get( - req: HttpRequest, - query: web::Query, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Query(query): Query, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; let me = Me::get(&mut conn, uuid).await?; if me.email_verified { - return Ok(HttpResponse::NoContent().finish()); + return Ok(StatusCode::NO_CONTENT); } - let email_token = EmailToken::get(&data, me.uuid).await?; + let email_token = EmailToken::get(&app_state, me.uuid).await?; if query.token != email_token.token { - return Ok(HttpResponse::Unauthorized().finish()); + return Ok(StatusCode::UNAUTHORIZED); } me.verify_email(&mut conn).await?; - email_token.delete(&data).await?; + email_token.delete(&app_state).await?; - Ok(HttpResponse::Ok().finish()) + Ok(StatusCode::OK) } /// `POST /api/v1/auth/verify-email` Sends user verification email @@ -81,25 +85,23 @@ pub async fn get( /// /// 401 Unauthorized /// -#[post("/verify-email")] -pub async fn post(req: HttpRequest, data: web::Data) -> Result { - let headers = req.headers(); +pub async fn post( + State(app_state): State>, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; let me = Me::get(&mut conn, uuid).await?; if me.email_verified { - return Ok(HttpResponse::NoContent().finish()); + return Ok(StatusCode::NO_CONTENT); } - if let Ok(email_token) = EmailToken::get(&data, me.uuid).await { + if let Ok(email_token) = EmailToken::get(&app_state, me.uuid).await { if Utc::now().signed_duration_since(email_token.created_at) > Duration::hours(1) { - email_token.delete(&data).await?; + email_token.delete(&app_state).await?; } else { return Err(Error::TooManyRequests( "Please allow 1 hour before sending a new email".to_string(), @@ -107,7 +109,7 @@ pub async fn post(req: HttpRequest, data: web::Data) -> Result Scope { - web::scope("/channels") - .service(uuid::get) - .service(uuid::delete) - .service(uuid::patch) - .service(uuid::messages::get) - .service(uuid::socket::ws) +pub fn router() -> Router> { + //let (layer, io) = SocketIo::new_layer(); + + //io.ns("/{uuid}/socket", uuid::socket::ws); + + Router::new() + .route("/{uuid}", get(uuid::get)) + .route("/{uuid}", delete(uuid::delete)) + .route("/{uuid}", patch(uuid::patch)) + .route("/{uuid}/messages", get(uuid::messages::get)) + //.layer(layer) } diff --git a/src/api/v1/channels/uuid/messages.rs b/src/api/v1/channels/uuid/messages.rs index 9fdea0b..8c12ee0 100644 --- a/src/api/v1/channels/uuid/messages.rs +++ b/src/api/v1/channels/uuid/messages.rs @@ -1,18 +1,29 @@ //! `/api/v1/channels/{uuid}/messages` Endpoints related to channel messages +use std::sync::Arc; + use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Channel, Member}, - utils::{get_auth_header, global_checks}, + utils::global_checks, }; use ::uuid::Uuid; -use actix_web::{HttpRequest, HttpResponse, get, web}; +use axum::{ + Json, + extract::{Path, Query, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use serde::Deserialize; #[derive(Deserialize)] -struct MessageRequest { +pub struct MessageRequest { amount: i64, offset: i64, } @@ -47,32 +58,25 @@ struct MessageRequest { /// }); /// ``` /// -#[get("/{uuid}/messages")] pub async fn get( - req: HttpRequest, - path: web::Path<(Uuid,)>, - message_request: web::Query, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(channel_uuid): Path, + Query(message_request): Query, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let channel_uuid = path.into_inner().0; + global_checks(&app_state, uuid).await?; - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; - - let channel = Channel::fetch_one(&data, channel_uuid).await?; + let channel = Channel::fetch_one(&app_state, channel_uuid).await?; Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; let messages = channel - .fetch_messages(&data, message_request.amount, message_request.offset) + .fetch_messages(&app_state, message_request.amount, message_request.offset) .await?; - Ok(HttpResponse::Ok().json(messages)) + Ok((StatusCode::OK, Json(messages))) } diff --git a/src/api/v1/channels/uuid/mod.rs b/src/api/v1/channels/uuid/mod.rs index fff2ef0..3ce91c3 100644 --- a/src/api/v1/channels/uuid/mod.rs +++ b/src/api/v1/channels/uuid/mod.rs @@ -1,77 +1,74 @@ //! `/api/v1/channels/{uuid}` Channel specific endpoints pub mod messages; -pub mod socket; +//pub mod socket; + +use std::sync::Arc; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Channel, Member, Permissions}, - utils::{get_auth_header, global_checks}, + utils::global_checks, +}; +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; -use actix_web::{HttpRequest, HttpResponse, delete, get, patch, web}; use serde::Deserialize; use uuid::Uuid; -#[get("/{uuid}")] pub async fn get( - req: HttpRequest, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(channel_uuid): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let channel_uuid = path.into_inner().0; + global_checks(&app_state, uuid).await?; - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; - - let channel = Channel::fetch_one(&data, channel_uuid).await?; + let channel = Channel::fetch_one(&app_state, channel_uuid).await?; Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; - Ok(HttpResponse::Ok().json(channel)) + Ok((StatusCode::OK, Json(channel))) } -#[delete("/{uuid}")] pub async fn delete( - req: HttpRequest, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(channel_uuid): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let channel_uuid = path.into_inner().0; + global_checks(&app_state, uuid).await?; - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; - - let channel = Channel::fetch_one(&data, channel_uuid).await?; + let channel = Channel::fetch_one(&app_state, channel_uuid).await?; let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; member - .check_permission(&data, Permissions::ManageChannel) + .check_permission(&app_state, Permissions::ManageChannel) .await?; - channel.delete(&data).await?; + channel.delete(&app_state).await?; - Ok(HttpResponse::Ok().finish()) + Ok(StatusCode::OK) } #[derive(Deserialize)] -struct NewInfo { +pub struct NewInfo { name: Option, description: Option, is_above: Option, @@ -108,48 +105,41 @@ struct NewInfo { /// }); /// ``` /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps -#[patch("/{uuid}")] pub async fn patch( - req: HttpRequest, - path: web::Path<(Uuid,)>, - new_info: web::Json, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(channel_uuid): Path, + TypedHeader(auth): TypedHeader>, + Json(new_info): Json, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let channel_uuid = path.into_inner().0; + global_checks(&app_state, uuid).await?; - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; - - let mut channel = Channel::fetch_one(&data, channel_uuid).await?; + let mut channel = Channel::fetch_one(&app_state, channel_uuid).await?; let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; member - .check_permission(&data, Permissions::ManageChannel) + .check_permission(&app_state, Permissions::ManageChannel) .await?; if let Some(new_name) = &new_info.name { - channel.set_name(&data, new_name.to_string()).await?; + channel.set_name(&app_state, new_name.to_string()).await?; } if let Some(new_description) = &new_info.description { channel - .set_description(&data, new_description.to_string()) + .set_description(&app_state, new_description.to_string()) .await?; } if let Some(new_is_above) = &new_info.is_above { channel - .set_description(&data, new_is_above.to_string()) + .set_description(&app_state, new_is_above.to_string()) .await?; } - Ok(HttpResponse::Ok().json(channel)) + Ok((StatusCode::OK, Json(channel))) } diff --git a/src/api/v1/guilds/mod.rs b/src/api/v1/guilds/mod.rs index ada5dc8..18a117f 100644 --- a/src/api/v1/guilds/mod.rs +++ b/src/api/v1/guilds/mod.rs @@ -1,28 +1,40 @@ //! `/api/v1/guilds` Guild related endpoints -use actix_web::{HttpRequest, HttpResponse, Scope, get, post, web}; +use std::sync::Arc; + +use axum::{ + Json, Router, + extract::State, + http::StatusCode, + response::IntoResponse, + routing::{get, post}, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use serde::Deserialize; mod uuid; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Guild, StartAmountQuery}, - utils::{get_auth_header, global_checks}, + utils::global_checks, }; #[derive(Deserialize)] -struct GuildInfo { +pub struct GuildInfo { name: String, } -pub fn web() -> Scope { - web::scope("/guilds") - .service(post) - .service(get) - .service(uuid::web()) +pub fn router() -> Router> { + Router::new() + .route("/", post(new)) + .route("/", get(get_guilds)) + .nest("/{uuid}", uuid::router()) } /// `POST /api/v1/guilds` Creates a new guild @@ -49,23 +61,18 @@ pub fn web() -> Scope { /// }); /// ``` /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps -#[post("")] -pub async fn post( - req: HttpRequest, - guild_info: web::Json, - data: web::Data, -) -> Result { - let headers = req.headers(); +pub async fn new( + State(app_state): State>, + TypedHeader(auth): TypedHeader>, + Json(guild_info): Json, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; let guild = Guild::new(&mut conn, guild_info.name.clone(), uuid).await?; - Ok(HttpResponse::Ok().json(guild)) + Ok((StatusCode::OK, Json(guild))) } /// `GET /api/v1/servers` Fetches all guilds @@ -115,25 +122,20 @@ pub async fn post( /// ]); /// ``` /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps -#[get("")] -pub async fn get( - req: HttpRequest, - request_query: web::Query, - data: web::Data, -) -> Result { - let headers = req.headers(); - - let auth_header = get_auth_header(headers)?; - +pub async fn get_guilds( + State(app_state): State>, + TypedHeader(auth): TypedHeader>, + Json(request_query): Json, +) -> Result { let start = request_query.start.unwrap_or(0); let amount = request_query.amount.unwrap_or(10); - let uuid = check_access_token(auth_header, &mut data.pool.get().await?).await?; + let uuid = check_access_token(auth.token(), &mut app_state.pool.get().await?).await?; - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; - let guilds = Guild::fetch_amount(&data.pool, start, amount).await?; + let guilds = Guild::fetch_amount(&app_state.pool, start, amount).await?; - Ok(HttpResponse::Ok().json(guilds)) + Ok((StatusCode::OK, Json(guilds))) } diff --git a/src/api/v1/guilds/uuid/channels.rs b/src/api/v1/guilds/uuid/channels.rs index b9f91cb..0104566 100644 --- a/src/api/v1/guilds/uuid/channels.rs +++ b/src/api/v1/guilds/uuid/channels.rs @@ -1,92 +1,92 @@ +use std::sync::Arc; + +use ::uuid::Uuid; +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; +use serde::Deserialize; + use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Channel, Member, Permissions}, - utils::{get_auth_header, global_checks, order_by_is_above}, + utils::{global_checks, order_by_is_above}, }; -use ::uuid::Uuid; -use actix_web::{HttpRequest, HttpResponse, get, post, web}; -use serde::Deserialize; #[derive(Deserialize)] -struct ChannelInfo { +pub struct ChannelInfo { name: String, description: Option, } -#[get("{uuid}/channels")] pub async fn get( - req: HttpRequest, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(guild_uuid): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let guild_uuid = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; - if let Ok(cache_hit) = data.get_cache_key(format!("{guild_uuid}_channels")).await { - return Ok(HttpResponse::Ok() - .content_type("application/json") - .body(cache_hit)); + if let Ok(cache_hit) = app_state + .get_cache_key(format!("{guild_uuid}_channels")) + .await + { + return Ok((StatusCode::OK, Json(cache_hit)).into_response()); } - let channels = Channel::fetch_all(&data.pool, guild_uuid).await?; + let channels = Channel::fetch_all(&app_state.pool, guild_uuid).await?; let channels_ordered = order_by_is_above(channels).await?; - data.set_cache_key( - format!("{guild_uuid}_channels"), - channels_ordered.clone(), - 1800, - ) - .await?; + app_state + .set_cache_key( + format!("{guild_uuid}_channels"), + channels_ordered.clone(), + 1800, + ) + .await?; - Ok(HttpResponse::Ok().json(channels_ordered)) + Ok((StatusCode::OK, Json(channels_ordered)).into_response()) } -#[post("{uuid}/channels")] pub async fn create( - req: HttpRequest, - channel_info: web::Json, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(guild_uuid): Path, + TypedHeader(auth): TypedHeader>, + Json(channel_info): Json, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let guild_uuid = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&data, Permissions::ManageChannel) + .check_permission(&app_state, Permissions::ManageChannel) .await?; let channel = Channel::new( - data.clone(), + &app_state, guild_uuid, channel_info.name.clone(), channel_info.description.clone(), ) .await?; - Ok(HttpResponse::Ok().json(channel)) + Ok((StatusCode::OK, Json(channel))) } diff --git a/src/api/v1/guilds/uuid/icon.rs b/src/api/v1/guilds/uuid/icon.rs deleted file mode 100644 index 600ccba..0000000 --- a/src/api/v1/guilds/uuid/icon.rs +++ /dev/null @@ -1,62 +0,0 @@ -//! `/api/v1/guilds/{uuid}/icon` icon related endpoints, will probably be replaced by a multipart post to above endpoint - -use actix_web::{HttpRequest, HttpResponse, put, web}; -use futures_util::StreamExt as _; -use uuid::Uuid; - -use crate::{ - Data, - api::v1::auth::check_access_token, - error::Error, - objects::{Guild, Member, Permissions}, - utils::{get_auth_header, global_checks}, -}; - -/// `PUT /api/v1/guilds/{uuid}/icon` Icon upload -/// -/// requires auth: no -/// -/// put request expects a file and nothing else -#[put("{uuid}/icon")] -pub async fn upload( - req: HttpRequest, - path: web::Path<(Uuid,)>, - mut payload: web::Payload, - data: web::Data, -) -> Result { - let headers = req.headers(); - - let auth_header = get_auth_header(headers)?; - - let guild_uuid = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; - - let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; - - member - .check_permission(&data, Permissions::ManageGuild) - .await?; - - let mut guild = Guild::fetch_one(&mut conn, guild_uuid).await?; - - let mut bytes = web::BytesMut::new(); - while let Some(item) = payload.next().await { - bytes.extend_from_slice(&item?); - } - - guild - .set_icon( - &data.bunny_storage, - &mut conn, - data.config.bunny.cdn_url.clone(), - bytes, - ) - .await?; - - Ok(HttpResponse::Ok().finish()) -} diff --git a/src/api/v1/guilds/uuid/invites/mod.rs b/src/api/v1/guilds/uuid/invites/mod.rs index f1c62bc..7703cf7 100644 --- a/src/api/v1/guilds/uuid/invites/mod.rs +++ b/src/api/v1/guilds/uuid/invites/mod.rs @@ -1,37 +1,41 @@ -use actix_web::{HttpRequest, HttpResponse, get, post, web}; +use std::sync::Arc; + +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use serde::Deserialize; use uuid::Uuid; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Guild, Member, Permissions}, - utils::{get_auth_header, global_checks}, + utils::global_checks, }; #[derive(Deserialize)] -struct InviteRequest { +pub struct InviteRequest { custom_id: Option, } -#[get("{uuid}/invites")] pub async fn get( - req: HttpRequest, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(guild_uuid): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let guild_uuid = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; @@ -39,32 +43,25 @@ pub async fn get( let invites = guild.get_invites(&mut conn).await?; - Ok(HttpResponse::Ok().json(invites)) + Ok((StatusCode::OK, Json(invites))) } -#[post("{uuid}/invites")] pub async fn create( - req: HttpRequest, - path: web::Path<(Uuid,)>, - invite_request: web::Json, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(guild_uuid): Path, + TypedHeader(auth): TypedHeader>, + Json(invite_request): Json, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let guild_uuid = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&data, Permissions::CreateInvite) + .check_permission(&app_state, Permissions::CreateInvite) .await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -73,5 +70,5 @@ pub async fn create( .create_invite(&mut conn, uuid, invite_request.custom_id.clone()) .await?; - Ok(HttpResponse::Ok().json(invite)) + Ok((StatusCode::OK, Json(invite))) } diff --git a/src/api/v1/guilds/uuid/members.rs b/src/api/v1/guilds/uuid/members.rs index 0afc2c5..bd2f853 100644 --- a/src/api/v1/guilds/uuid/members.rs +++ b/src/api/v1/guilds/uuid/members.rs @@ -1,36 +1,41 @@ +use std::sync::Arc; + +use ::uuid::Uuid; +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; + use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Me, Member}, - utils::{get_auth_header, global_checks}, + utils::global_checks, }; -use ::uuid::Uuid; -use actix_web::{HttpRequest, HttpResponse, get, web}; -#[get("{uuid}/members")] pub async fn get( - req: HttpRequest, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(guild_uuid): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let guild_uuid = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; let me = Me::get(&mut conn, uuid).await?; - let members = Member::fetch_all(&data, &me, guild_uuid).await?; + let members = Member::fetch_all(&app_state, &me, guild_uuid).await?; - Ok(HttpResponse::Ok().json(members)) + Ok((StatusCode::OK, Json(members))) } diff --git a/src/api/v1/guilds/uuid/mod.rs b/src/api/v1/guilds/uuid/mod.rs index 4c88d7a..0a27123 100644 --- a/src/api/v1/guilds/uuid/mod.rs +++ b/src/api/v1/guilds/uuid/mod.rs @@ -1,40 +1,51 @@ //! `/api/v1/guilds/{uuid}` Specific server endpoints -use actix_web::{HttpRequest, HttpResponse, Scope, get, web}; +use std::sync::Arc; + +use axum::{ + Json, Router, + extract::{Multipart, Path, State}, + http::StatusCode, + response::IntoResponse, + routing::{get, patch, post}, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; +use bytes::Bytes; use uuid::Uuid; mod channels; -mod icon; mod invites; mod members; mod roles; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, - objects::{Guild, Member}, - utils::{get_auth_header, global_checks}, + objects::{Guild, Member, Permissions}, + utils::global_checks, }; -pub fn web() -> Scope { - web::scope("") +pub fn router() -> Router> { + Router::new() // Servers - .service(get) + .route("/", get(get_guild)) + .route("/", patch(edit)) // Channels - .service(channels::get) - .service(channels::create) + .route("/channels", get(channels::get)) + .route("/channels", post(channels::create)) // Roles - .service(roles::get) - .service(roles::create) - .service(roles::uuid::get) + .route("/roles", get(roles::get)) + .route("/roles", post(roles::create)) + .route("/roles/{role_uuid}", get(roles::uuid::get)) // Invites - .service(invites::get) - .service(invites::create) - // Icon - .service(icon::upload) + .route("/invites", get(invites::get)) + .route("/invites", post(invites::create)) // Members - .service(members::get) + .route("/members", get(members::get)) } /// `GET /api/v1/guilds/{uuid}` DESCRIPTION @@ -70,27 +81,69 @@ pub fn web() -> Scope { /// "member_count": 20 /// }); /// ``` -#[get("/{uuid}")] -pub async fn get( - req: HttpRequest, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); +pub async fn get_guild( + State(app_state): State>, + Path(guild_uuid): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let guild_uuid = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; - Ok(HttpResponse::Ok().json(guild)) + Ok((StatusCode::OK, Json(guild))) +} + +/// `PATCH /api/v1/guilds/{uuid}` change guild settings +/// +/// requires auth: yes +pub async fn edit( + State(app_state): State>, + Path(guild_uuid): Path, + TypedHeader(auth): TypedHeader>, + mut multipart: Multipart, +) -> Result { + let mut conn = app_state.pool.get().await?; + + let uuid = check_access_token(auth.token(), &mut conn).await?; + + global_checks(&app_state, uuid).await?; + + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; + + member + .check_permission(&app_state, Permissions::ManageGuild) + .await?; + + let mut guild = Guild::fetch_one(&mut conn, guild_uuid).await?; + + let mut icon: Option = None; + + while let Some(field) = multipart.next_field().await.unwrap() { + let name = field + .name() + .ok_or(Error::BadRequest("Field has no name".to_string()))?; + + if name == "icon" { + icon = Some(field.bytes().await?); + } + } + + if let Some(icon) = icon { + guild + .set_icon( + &app_state.bunny_storage, + &mut conn, + app_state.config.bunny.cdn_url.clone(), + icon, + ) + .await?; + } + + Ok(StatusCode::OK) } diff --git a/src/api/v1/guilds/uuid/roles/mod.rs b/src/api/v1/guilds/uuid/roles/mod.rs index 0fcc5b3..12960c2 100644 --- a/src/api/v1/guilds/uuid/roles/mod.rs +++ b/src/api/v1/guilds/uuid/roles/mod.rs @@ -1,82 +1,78 @@ +use std::sync::Arc; + use ::uuid::Uuid; -use actix_web::{HttpRequest, HttpResponse, get, post, web}; +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use serde::Deserialize; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Member, Permissions, Role}, - utils::{get_auth_header, global_checks, order_by_is_above}, + utils::{global_checks, order_by_is_above}, }; pub mod uuid; #[derive(Deserialize)] -struct RoleInfo { +pub struct RoleInfo { name: String, } -#[get("{uuid}/roles")] pub async fn get( - req: HttpRequest, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(guild_uuid): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; - - let guild_uuid = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; - if let Ok(cache_hit) = data.get_cache_key(format!("{guild_uuid}_roles")).await { - return Ok(HttpResponse::Ok() - .content_type("application/json") - .body(cache_hit)); + if let Ok(cache_hit) = app_state.get_cache_key(format!("{guild_uuid}_roles")).await { + return Ok((StatusCode::OK, Json(cache_hit)).into_response()); } let roles = Role::fetch_all(&mut conn, guild_uuid).await?; let roles_ordered = order_by_is_above(roles).await?; - data.set_cache_key(format!("{guild_uuid}_roles"), roles_ordered.clone(), 1800) + app_state + .set_cache_key(format!("{guild_uuid}_roles"), roles_ordered.clone(), 1800) .await?; - Ok(HttpResponse::Ok().json(roles_ordered)) + Ok((StatusCode::OK, Json(roles_ordered)).into_response()) } -#[post("{uuid}/roles")] pub async fn create( - req: HttpRequest, - role_info: web::Json, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(guild_uuid): Path, + TypedHeader(auth): TypedHeader>, + Json(role_info): Json, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let guild_uuid = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member - .check_permission(&data, Permissions::ManageRole) + .check_permission(&app_state, Permissions::ManageRole) .await?; let role = Role::new(&mut conn, guild_uuid, role_info.name.clone()).await?; - Ok(HttpResponse::Ok().json(role)) + Ok((StatusCode::OK, Json(role)).into_response()) } diff --git a/src/api/v1/guilds/uuid/roles/uuid.rs b/src/api/v1/guilds/uuid/roles/uuid.rs index bd747d8..a62a5b4 100644 --- a/src/api/v1/guilds/uuid/roles/uuid.rs +++ b/src/api/v1/guilds/uuid/roles/uuid.rs @@ -1,43 +1,47 @@ +use std::sync::Arc; + +use ::uuid::Uuid; +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; + use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Member, Role}, - utils::{get_auth_header, global_checks}, + utils::global_checks, }; -use ::uuid::Uuid; -use actix_web::{HttpRequest, HttpResponse, get, web}; -#[get("{uuid}/roles/{role_uuid}")] pub async fn get( - req: HttpRequest, - path: web::Path<(Uuid, Uuid)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path((guild_uuid, role_uuid)): Path<(Uuid, Uuid)>, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let (guild_uuid, role_uuid) = path.into_inner(); - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; - if let Ok(cache_hit) = data.get_cache_key(format!("{role_uuid}")).await { - return Ok(HttpResponse::Ok() - .content_type("application/json") - .body(cache_hit)); + if let Ok(cache_hit) = app_state.get_cache_key(format!("{role_uuid}")).await { + return Ok((StatusCode::OK, Json(cache_hit)).into_response()); } let role = Role::fetch_one(&mut conn, role_uuid).await?; - data.set_cache_key(format!("{role_uuid}"), role.clone(), 60) + app_state + .set_cache_key(format!("{role_uuid}"), role.clone(), 60) .await?; - Ok(HttpResponse::Ok().json(role)) + Ok((StatusCode::OK, Json(role)).into_response()) } diff --git a/src/api/v1/invites/id.rs b/src/api/v1/invites/id.rs index 22e2868..b832557 100644 --- a/src/api/v1/invites/id.rs +++ b/src/api/v1/invites/id.rs @@ -1,49 +1,53 @@ -use actix_web::{HttpRequest, HttpResponse, get, post, web}; +use std::sync::Arc; + +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Guild, Invite, Member}, - utils::{get_auth_header, global_checks}, + utils::global_checks, }; -#[get("{id}")] -pub async fn get(path: web::Path<(String,)>, data: web::Data) -> Result { - let mut conn = data.pool.get().await?; - - let invite_id = path.into_inner().0; +pub async fn get( + State(app_state): State>, + Path(invite_id): Path, +) -> Result { + let mut conn = app_state.pool.get().await?; let invite = Invite::fetch_one(&mut conn, invite_id).await?; let guild = Guild::fetch_one(&mut conn, invite.guild_uuid).await?; - Ok(HttpResponse::Ok().json(guild)) + Ok((StatusCode::OK, Json(guild))) } -#[post("{id}")] pub async fn join( - req: HttpRequest, - path: web::Path<(String,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(invite_id): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let invite_id = path.into_inner().0; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let invite = Invite::fetch_one(&mut conn, invite_id).await?; let guild = Guild::fetch_one(&mut conn, invite.guild_uuid).await?; - Member::new(&data, uuid, guild.uuid).await?; + Member::new(&app_state, uuid, guild.uuid).await?; - Ok(HttpResponse::Ok().json(guild)) + Ok((StatusCode::OK, Json(guild))) } diff --git a/src/api/v1/invites/mod.rs b/src/api/v1/invites/mod.rs index 3714a83..50fb707 100644 --- a/src/api/v1/invites/mod.rs +++ b/src/api/v1/invites/mod.rs @@ -1,7 +1,16 @@ -use actix_web::{Scope, web}; +use std::sync::Arc; + +use axum::{ + Router, + routing::{get, post}, +}; + +use crate::AppState; mod id; -pub fn web() -> Scope { - web::scope("/invites").service(id::get).service(id::join) +pub fn router() -> Router> { + Router::new() + .route("/{id}", get(id::get)) + .route("/{id}", post(id::join)) } diff --git a/src/api/v1/me/friends/mod.rs b/src/api/v1/me/friends/mod.rs index 8de0a5d..8a7851c 100644 --- a/src/api/v1/me/friends/mod.rs +++ b/src/api/v1/me/friends/mod.rs @@ -1,38 +1,42 @@ -use actix_web::{HttpRequest, HttpResponse, get, post, web}; +use std::sync::Arc; + +use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use serde::Deserialize; pub mod uuid; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::Me, - utils::{get_auth_header, global_checks, user_uuid_from_username} + utils::{global_checks, user_uuid_from_username}, }; /// Returns a list of users that are your friends -#[get("/friends")] -pub async fn get(req: HttpRequest, data: web::Data) -> Result { - let headers = req.headers(); +pub async fn get( + State(app_state): State>, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let me = Me::get(&mut conn, uuid).await?; - let friends = me.get_friends(&data).await?; + let friends = me.get_friends(&app_state).await?; - Ok(HttpResponse::Ok().json(friends)) + Ok((StatusCode::OK, Json(friends))) } #[derive(Deserialize)] -struct UserReq { +pub struct UserReq { username: String, } @@ -55,26 +59,21 @@ struct UserReq { /// /// 400 Bad Request (usually means users are already friends) /// -#[post("/friends")] pub async fn post( - req: HttpRequest, - json: web::Json, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + TypedHeader(auth): TypedHeader>, + Json(user_request): Json, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let me = Me::get(&mut conn, uuid).await?; - let target_uuid = user_uuid_from_username(&mut conn, &json.username).await?; + let target_uuid = user_uuid_from_username(&mut conn, &user_request.username).await?; me.add_friend(&mut conn, target_uuid).await?; - Ok(HttpResponse::Ok().finish()) + Ok(StatusCode::OK) } diff --git a/src/api/v1/me/friends/uuid.rs b/src/api/v1/me/friends/uuid.rs index 34bfeff..8d40f26 100644 --- a/src/api/v1/me/friends/uuid.rs +++ b/src/api/v1/me/friends/uuid.rs @@ -1,33 +1,34 @@ -use actix_web::{HttpRequest, HttpResponse, delete, web}; +use std::sync::Arc; + +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use uuid::Uuid; use crate::{ - Data, - api::v1::auth::check_access_token, - error::Error, - objects::Me, - utils::{get_auth_header, global_checks}, + AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, }; -#[delete("/friends/{uuid}")] pub async fn delete( - req: HttpRequest, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(friend_uuid): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let me = Me::get(&mut conn, uuid).await?; - me.remove_friend(&mut conn, path.0).await?; + me.remove_friend(&mut conn, friend_uuid).await?; - Ok(HttpResponse::Ok().finish()) + Ok(StatusCode::OK) } diff --git a/src/api/v1/me/guilds.rs b/src/api/v1/me/guilds.rs index 71cfca4..adfe845 100644 --- a/src/api/v1/me/guilds.rs +++ b/src/api/v1/me/guilds.rs @@ -1,13 +1,15 @@ //! `/api/v1/me/guilds` Contains endpoint related to guild memberships -use actix_web::{HttpRequest, HttpResponse, get, web}; +use std::sync::Arc; + +use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use crate::{ - Data, - api::v1::auth::check_access_token, - error::Error, - objects::Me, - utils::{get_auth_header, global_checks}, + AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, }; /// `GET /api/v1/me/guilds` Returns all guild memberships in a list @@ -55,21 +57,19 @@ use crate::{ /// ]); /// ``` /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps -#[get("/guilds")] -pub async fn get(req: HttpRequest, data: web::Data) -> Result { - let headers = req.headers(); +pub async fn get( + State(app_state): State>, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let me = Me::get(&mut conn, uuid).await?; let memberships = me.fetch_memberships(&mut conn).await?; - Ok(HttpResponse::Ok().json(memberships)) + Ok((StatusCode::OK, Json(memberships))) } diff --git a/src/api/v1/me/mod.rs b/src/api/v1/me/mod.rs index f667ca4..e9680bc 100644 --- a/src/api/v1/me/mod.rs +++ b/src/api/v1/me/mod.rs @@ -1,108 +1,120 @@ -use actix_multipart::form::{MultipartForm, json::Json as MpJson, tempfile::TempFile}; -use actix_web::{HttpRequest, HttpResponse, Scope, get, patch, web}; +use std::sync::Arc; + +use axum::{ + Json, Router, + extract::{DefaultBodyLimit, Multipart, State}, + http::StatusCode, + response::IntoResponse, + routing::{delete, get, patch, post}, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; +use bytes::Bytes; use serde::Deserialize; use crate::{ - Data, - api::v1::auth::check_access_token, - error::Error, - objects::Me, - utils::{get_auth_header, global_checks}, + AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, }; mod friends; mod guilds; -pub fn web() -> Scope { - web::scope("/me") - .service(get) - .service(update) - .service(guilds::get) - .service(friends::get) - .service(friends::post) - .service(friends::uuid::delete) +pub fn router() -> Router> { + Router::new() + .route("/", get(get_me)) + .route( + "/", + patch(update).layer(DefaultBodyLimit::max( + 100 * 1024 * 1024, /* limit is in bytes */ + )), + ) + .route("/guilds", get(guilds::get)) + .route("/friends", get(friends::get)) + .route("/friends", post(friends::post)) + .route("/friends/{uuid}", delete(friends::uuid::delete)) } -#[get("")] -pub async fn get(req: HttpRequest, data: web::Data) -> Result { - let headers = req.headers(); +pub async fn get_me( + State(app_state): State>, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; let me = Me::get(&mut conn, uuid).await?; - Ok(HttpResponse::Ok().json(me)) + Ok((StatusCode::OK, Json(me))) } -#[derive(Debug, Deserialize, Clone)] +#[derive(Default, Debug, Deserialize, Clone)] struct NewInfo { username: Option, display_name: Option, - //password: Option, will probably be handled through a reset password link email: Option, pronouns: Option, about: Option, } -#[derive(Debug, MultipartForm)] -struct UploadForm { - #[multipart(limit = "100MB")] - avatar: Option, - json: MpJson, -} - -#[patch("")] pub async fn update( - req: HttpRequest, - MultipartForm(form): MultipartForm, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + TypedHeader(auth): TypedHeader>, + mut multipart: Multipart, +) -> Result { + let mut conn = app_state.pool.get().await?; - let auth_header = get_auth_header(headers)?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let mut conn = data.pool.get().await?; + let mut json_raw: Option = None; + let mut avatar: Option = None; - let uuid = check_access_token(auth_header, &mut conn).await?; + while let Some(field) = multipart.next_field().await.unwrap() { + let name = field + .name() + .ok_or(Error::BadRequest("Field has no name".to_string()))?; - if form.avatar.is_some() || form.json.username.is_some() || form.json.display_name.is_some() { - global_checks(&data, uuid).await?; + if name == "avatar" { + avatar = Some(field.bytes().await?); + } else if name == "json" { + json_raw = Some(serde_json::from_str(&field.text().await?)?) + } + } + + let json = json_raw.unwrap_or_default(); + + if avatar.is_some() || json.username.is_some() || json.display_name.is_some() { + global_checks(&app_state, uuid).await?; } let mut me = Me::get(&mut conn, uuid).await?; - if let Some(avatar) = form.avatar { - let bytes = tokio::fs::read(avatar.file).await?; - - let byte_slice: &[u8] = &bytes; - - me.set_avatar(&data, data.config.bunny.cdn_url.clone(), byte_slice.into()) + if let Some(avatar) = avatar { + me.set_avatar(&app_state, app_state.config.bunny.cdn_url.clone(), avatar) .await?; } - if let Some(username) = &form.json.username { - me.set_username(&data, username.clone()).await?; + if let Some(username) = &json.username { + me.set_username(&app_state, username.clone()).await?; } - if let Some(display_name) = &form.json.display_name { - me.set_display_name(&data, display_name.clone()).await?; + if let Some(display_name) = &json.display_name { + me.set_display_name(&app_state, display_name.clone()) + .await?; } - if let Some(email) = &form.json.email { - me.set_email(&data, email.clone()).await?; + if let Some(email) = &json.email { + me.set_email(&app_state, email.clone()).await?; } - if let Some(pronouns) = &form.json.pronouns { - me.set_pronouns(&data, pronouns.clone()).await?; + if let Some(pronouns) = &json.pronouns { + me.set_pronouns(&app_state, pronouns.clone()).await?; } - if let Some(about) = &form.json.about { - me.set_about(&data, about.clone()).await?; + if let Some(about) = &json.about { + me.set_about(&app_state, about.clone()).await?; } - Ok(HttpResponse::Ok().finish()) + Ok(StatusCode::OK) } diff --git a/src/api/v1/mod.rs b/src/api/v1/mod.rs index 6c2df0b..4e8654b 100644 --- a/src/api/v1/mod.rs +++ b/src/api/v1/mod.rs @@ -1,6 +1,10 @@ //! `/api/v1` Contains version 1 of the api -use actix_web::{Scope, web}; +use std::sync::Arc; + +use axum::{routing::get, Router}; + +use crate::AppState; mod auth; mod channels; @@ -10,13 +14,13 @@ mod me; mod stats; mod users; -pub fn web() -> Scope { - web::scope("/v1") - .service(stats::res) - .service(auth::web()) - .service(users::web()) - .service(channels::web()) - .service(guilds::web()) - .service(invites::web()) - .service(me::web()) +pub fn router() -> Router> { + Router::new() + .route("/stats", get(stats::res)) + .nest("/auth", auth::router()) + .nest("/users", users::router()) + .nest("/channels", channels::router()) + .nest("/guilds", guilds::router()) + .nest("/invites", invites::router()) + .nest("/me", me::router()) } diff --git a/src/api/v1/stats.rs b/src/api/v1/stats.rs index 760ec71..17c5df6 100644 --- a/src/api/v1/stats.rs +++ b/src/api/v1/stats.rs @@ -1,13 +1,17 @@ //! `/api/v1/stats` Returns stats about the server +use std::sync::Arc; use std::time::SystemTime; -use actix_web::{HttpResponse, get, web}; +use axum::Json; +use axum::extract::State; +use axum::http::StatusCode; +use axum::response::IntoResponse; use diesel::QueryDsl; use diesel_async::RunQueryDsl; use serde::Serialize; -use crate::Data; +use crate::AppState; use crate::error::Error; use crate::schema::users::dsl::{users, uuid}; @@ -39,27 +43,26 @@ struct Response { /// "build_number": "39d01bb" /// }); /// ``` -#[get("/stats")] -pub async fn res(data: web::Data) -> Result { +pub async fn res(State(app_state): State>) -> Result { let accounts: i64 = users .select(uuid) .count() - .get_result(&mut data.pool.get().await?) + .get_result(&mut app_state.pool.get().await?) .await?; let response = Response { // TODO: Get number of accounts from db accounts, uptime: SystemTime::now() - .duration_since(data.start_time) + .duration_since(app_state.start_time) .expect("Seriously why dont you have time??") .as_secs(), version: String::from(VERSION.unwrap_or("UNKNOWN")), - registration_enabled: data.config.instance.registration, - email_verification_required: data.config.instance.require_email_verification, + registration_enabled: app_state.config.instance.registration, + email_verification_required: app_state.config.instance.require_email_verification, // TODO: Get build number from git hash or remove this from the spec build_number: String::from(GIT_SHORT_HASH), }; - Ok(HttpResponse::Ok().json(response)) + Ok((StatusCode::OK, Json(response))) } diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index 334fd5f..f0d09c5 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -1,19 +1,33 @@ //! `/api/v1/users` Contains endpoints related to all users -use actix_web::{HttpRequest, HttpResponse, Scope, get, web}; +use std::sync::Arc; + +use axum::{ + Json, Router, + extract::{Query, State}, + http::StatusCode, + response::IntoResponse, + routing::get, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{StartAmountQuery, User}, - utils::{get_auth_header, global_checks}, + utils::global_checks, }; mod uuid; -pub fn web() -> Scope { - web::scope("/users").service(get).service(uuid::get) +pub fn router() -> Router> { + Router::new() + .route("/", get(users)) + .route("/{uuid}", get(uuid::get)) } /// `GET /api/v1/users` Returns all users on this instance @@ -46,31 +60,26 @@ pub fn web() -> Scope { /// ]); /// ``` /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps -#[get("")] -pub async fn get( - req: HttpRequest, - request_query: web::Query, - data: web::Data, -) -> Result { - let headers = req.headers(); - - let auth_header = get_auth_header(headers)?; - +pub async fn users( + State(app_state): State>, + Query(request_query): Query, + TypedHeader(auth): TypedHeader>, +) -> Result { let start = request_query.start.unwrap_or(0); let amount = request_query.amount.unwrap_or(10); if amount > 100 { - return Ok(HttpResponse::BadRequest().finish()); + return Ok(StatusCode::BAD_REQUEST.into_response()); } - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; - let uuid = check_access_token(auth_header, &mut conn).await?; + let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let users = User::fetch_amount(&mut conn, start, amount).await?; - Ok(HttpResponse::Ok().json(users)) + Ok((StatusCode::OK, Json(users)).into_response()) } diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index 5d36b75..1b7d43b 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -1,14 +1,25 @@ //! `/api/v1/users/{uuid}` Specific user endpoints -use actix_web::{HttpRequest, HttpResponse, get, web}; +use std::sync::Arc; + +use axum::{ + Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, +}; use uuid::Uuid; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, error::Error, objects::{Me, User}, - utils::{get_auth_header, global_checks}, + utils::global_checks, }; /// `GET /api/v1/users/{uuid}` Returns user with the given UUID @@ -27,27 +38,20 @@ use crate::{ /// }); /// ``` /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps -#[get("/{uuid}")] pub async fn get( - req: HttpRequest, - path: web::Path<(Uuid,)>, - data: web::Data, -) -> Result { - let headers = req.headers(); + State(app_state): State>, + Path(user_uuid): Path, + TypedHeader(auth): TypedHeader>, +) -> Result { + let mut conn = app_state.pool.get().await?; - let user_uuid = path.into_inner().0; + let uuid = check_access_token(auth.token(), &mut conn).await?; - let auth_header = get_auth_header(headers)?; - - let mut conn = data.pool.get().await?; - - let uuid = check_access_token(auth_header, &mut conn).await?; - - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; let me = Me::get(&mut conn, uuid).await?; - let user = User::fetch_one_with_friendship(&data, &me, user_uuid).await?; + let user = User::fetch_one_with_friendship(&app_state, &me, user_uuid).await?; - Ok(HttpResponse::Ok().json(user)) + Ok((StatusCode::OK, Json(user))) } diff --git a/src/api/versions.rs b/src/api/versions.rs index 0c3e106..3c9576b 100644 --- a/src/api/versions.rs +++ b/src/api/versions.rs @@ -1,5 +1,5 @@ //! `/api/v1/versions` Returns info about api versions -use actix_web::{HttpResponse, Responder, get}; +use axum::{Json, http::StatusCode, response::IntoResponse}; use serde::Serialize; #[derive(Serialize)] @@ -24,13 +24,12 @@ struct UnstableFeatures; /// ] /// }); /// ``` -#[get("/versions")] -pub async fn get() -> impl Responder { +pub async fn versions() -> impl IntoResponse { let response = Response { unstable_features: UnstableFeatures, // TODO: Find a way to dynamically update this possibly? versions: vec![String::from("1")], }; - HttpResponse::Ok().json(response) + (StatusCode::OK, Json(response)) } diff --git a/src/error.rs b/src/error.rs index 35b533d..1b8f27c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,16 @@ use std::{io, time::SystemTimeError}; -use actix_web::{ - HttpResponse, - error::{PayloadError, ResponseError}, +use axum::{ + Json, + extract::{ + multipart::MultipartError, + rejection::{JsonRejection, QueryRejection}, + }, http::{ StatusCode, - header::{ContentType, ToStrError}, + header::{InvalidHeaderValue, ToStrError}, }, + response::IntoResponse, }; use bunny_api_tokio::error::Error as BunnyError; use deadpool::managed::{BuildError, PoolError}; @@ -54,9 +58,13 @@ pub enum Error { #[error(transparent)] UrlParseError(#[from] url::ParseError), #[error(transparent)] - PayloadError(#[from] PayloadError), + JsonRejection(#[from] JsonRejection), #[error(transparent)] - WsClosed(#[from] actix_ws::Closed), + QueryRejection(#[from] QueryRejection), + #[error(transparent)] + MultipartError(#[from] MultipartError), + #[error(transparent)] + InvalidHeaderValue(#[from] InvalidHeaderValue), #[error(transparent)] EmailError(#[from] EmailError), #[error(transparent)] @@ -77,26 +85,40 @@ pub enum Error { InternalServerError(String), } -impl ResponseError for Error { - fn error_response(&self) -> HttpResponse { +impl IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + let error = match self { + Error::SqlError(DieselError::NotFound) => { + (StatusCode::NOT_FOUND, Json(WebError::new(self.to_string()))) + } + Error::BunnyError(BunnyError::NotFound(_)) => { + (StatusCode::NOT_FOUND, Json(WebError::new(self.to_string()))) + } + Error::BadRequest(_) => ( + StatusCode::BAD_REQUEST, + Json(WebError::new(self.to_string())), + ), + Error::Unauthorized(_) => ( + StatusCode::UNAUTHORIZED, + Json(WebError::new(self.to_string())), + ), + Error::Forbidden(_) => (StatusCode::FORBIDDEN, Json(WebError::new(self.to_string()))), + Error::TooManyRequests(_) => ( + StatusCode::TOO_MANY_REQUESTS, + Json(WebError::new(self.to_string())), + ), + _ => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(WebError::new(self.to_string())), + ), + }; + + let (code, _) = error; + debug!("{self:?}"); - error!("{}: {}", self.status_code(), self); + error!("{code}: {self}"); - HttpResponse::build(self.status_code()) - .insert_header(ContentType::json()) - .json(WebError::new(self.to_string())) - } - - fn status_code(&self) -> StatusCode { - match *self { - Error::SqlError(DieselError::NotFound) => StatusCode::NOT_FOUND, - Error::BunnyError(BunnyError::NotFound(_)) => StatusCode::NOT_FOUND, - Error::BadRequest(_) => StatusCode::BAD_REQUEST, - Error::Unauthorized(_) => StatusCode::UNAUTHORIZED, - Error::Forbidden(_) => StatusCode::FORBIDDEN, - Error::TooManyRequests(_) => StatusCode::TOO_MANY_REQUESTS, - _ => StatusCode::INTERNAL_SERVER_ERROR, - } + error.into_response() } } diff --git a/src/main.rs b/src/main.rs index 248289a..6bb2be3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,13 @@ -use actix_cors::Cors; -use actix_web::{App, HttpServer, web}; use argon2::Argon2; +use axum::Router; use clap::Parser; use diesel_async::pooled_connection::AsyncDieselConnectionManager; use diesel_async::pooled_connection::deadpool::Pool; use error::Error; use objects::MailClient; -use simple_logger::SimpleLogger; -use std::time::SystemTime; +use socketioxide::SocketIo; +use std::{sync::Arc, time::SystemTime}; +use tower_http::cors::{Any, CorsLayer}; mod config; use config::{Config, ConfigBuilder}; use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations}; @@ -22,6 +22,7 @@ pub mod error; pub mod objects; pub mod schema; pub mod utils; +mod socket; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] @@ -31,7 +32,7 @@ struct Args { } #[derive(Clone)] -pub struct Data { +pub struct AppState { pub pool: deadpool::managed::Pool< AsyncDieselConnectionManager, Conn, @@ -46,12 +47,14 @@ pub struct Data { #[tokio::main] async fn main() -> Result<(), Error> { - SimpleLogger::new() - .with_level(log::LevelFilter::Info) - .with_colors(true) - .env() - .init() - .unwrap(); + tracing_subscriber::fmt::init(); + + //SimpleLogger::new() + // .with_level(log::LevelFilter::Info) + // .with_colors(true) + // .env() + // .init() + // .unwrap(); let args = Args::parse(); let config = ConfigBuilder::load(args.config).await?.build(); @@ -112,7 +115,7 @@ async fn main() -> Result<(), Error> { ) */ - let data = Data { + let app_state = Arc::new(AppState { pool, cache_pool, config, @@ -121,42 +124,31 @@ async fn main() -> Result<(), Error> { start_time: SystemTime::now(), bunny_storage, mail_client, - }; + }); - HttpServer::new(move || { - // Set CORS headers - let cors = Cors::default() - /* - Set Allowed-Control-Allow-Origin header to whatever - the request's Origin header is. Must be done like this - rather than setting it to "*" due to CORS not allowing - sending of credentials (cookies) with wildcard origin. - */ - .allowed_origin_fn(|_origin, _req_head| true) - /* - Allows any request method in CORS preflight requests. - This will be restricted to only ones actually in use later. - */ - .allow_any_method() - /* - Allows any header(s) in request in CORS preflight requests. - This wll be restricted to only ones actually in use later. - */ - .allow_any_header() - /* - Allows browser to include cookies in requests. - This is needed for receiving the secure HttpOnly refresh_token cookie. - */ - .supports_credentials(); + let cors = CorsLayer::new() + // Allow any origin (equivalent to allowed_origin_fn returning true) + .allow_origin(Any) + // Allow any method + .allow_methods(Any) + // Allow any headers + .allow_headers(Any); - App::new() - .app_data(web::Data::new(data.clone())) - .wrap(cors) - .service(api::web(data.config.web.backend_url.path())) - }) - .bind((web.ip, web.port))? - .run() - .await?; + let (socket_io, io) = SocketIo::builder().with_state(app_state.clone()).build_layer(); + + io.ns("/", socket::on_connect); + + // build our application with a route + let app = Router::new() + // `GET /` goes to `root` + .nest(web.backend_url.path(), api::router()) + .with_state(app_state) + .layer(cors) + .layer(socket_io); + + // run our app with hyper, listening globally on port 3000 + let listener = tokio::net::TcpListener::bind(web.ip + ":" + &web.port.to_string()).await?; + axum::serve(listener, app).await?; Ok(()) } diff --git a/src/objects/channel.rs b/src/objects/channel.rs index 1192c69..3b34ac6 100644 --- a/src/objects/channel.rs +++ b/src/objects/channel.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - Conn, Data, + AppState, Conn, error::Error, schema::{channel_permissions, channels, messages}, utils::{CHANNEL_REGEX, order_by_is_above}, @@ -105,12 +105,12 @@ impl Channel { futures::future::try_join_all(channel_futures).await } - pub async fn fetch_one(data: &Data, channel_uuid: Uuid) -> Result { - if let Ok(cache_hit) = data.get_cache_key(channel_uuid.to_string()).await { + pub async fn fetch_one(app_state: &AppState, channel_uuid: Uuid) -> Result { + if let Ok(cache_hit) = app_state.get_cache_key(channel_uuid.to_string()).await { return Ok(serde_json::from_str(&cache_hit)?); } - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; use channels::dsl; let channel_builder: ChannelBuilder = dsl::channels @@ -121,14 +121,15 @@ impl Channel { let channel = channel_builder.build(&mut conn).await?; - data.set_cache_key(channel_uuid.to_string(), channel.clone(), 60) + app_state + .set_cache_key(channel_uuid.to_string(), channel.clone(), 60) .await?; Ok(channel) } pub async fn new( - data: actix_web::web::Data, + app_state: &AppState, guild_uuid: Uuid, name: String, description: Option, @@ -137,11 +138,11 @@ impl Channel { return Err(Error::BadRequest("Channel name is invalid".to_string())); } - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; let channel_uuid = Uuid::now_v7(); - let channels = Self::fetch_all(&data.pool, guild_uuid).await?; + let channels = Self::fetch_all(&app_state.pool, guild_uuid).await?; let channels_ordered = order_by_is_above(channels).await?; @@ -179,22 +180,25 @@ impl Channel { permissions: vec![], }; - data.set_cache_key(channel_uuid.to_string(), channel.clone(), 1800) + app_state + .set_cache_key(channel_uuid.to_string(), channel.clone(), 1800) .await?; - if data + if app_state .get_cache_key(format!("{guild_uuid}_channels")) .await .is_ok() { - data.del_cache_key(format!("{guild_uuid}_channels")).await?; + app_state + .del_cache_key(format!("{guild_uuid}_channels")) + .await?; } Ok(channel) } - pub async fn delete(self, data: &Data) -> Result<(), Error> { - let mut conn = data.pool.get().await?; + pub async fn delete(self, app_state: &AppState) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; use channels::dsl; match update(channels::table) @@ -224,16 +228,17 @@ impl Channel { Err(e) => Err(e), }?; - if data.get_cache_key(self.uuid.to_string()).await.is_ok() { - data.del_cache_key(self.uuid.to_string()).await?; + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await?; } - if data + if app_state .get_cache_key(format!("{}_channels", self.guild_uuid)) .await .is_ok() { - data.del_cache_key(format!("{}_channels", self.guild_uuid)) + app_state + .del_cache_key(format!("{}_channels", self.guild_uuid)) .await?; } @@ -242,11 +247,11 @@ impl Channel { pub async fn fetch_messages( &self, - data: &Data, + app_state: &AppState, amount: i64, offset: i64, ) -> Result, Error> { - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; use messages::dsl; let messages: Vec = load_or_empty( @@ -260,14 +265,14 @@ impl Channel { .await, )?; - let message_futures = messages.iter().map(async move |b| b.build(data).await); + let message_futures = messages.iter().map(async move |b| b.build(app_state).await); futures::future::try_join_all(message_futures).await } pub async fn new_message( &self, - data: &Data, + app_state: &AppState, user_uuid: Uuid, message: String, reply_to: Option, @@ -282,22 +287,22 @@ impl Channel { reply_to, }; - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; insert_into(messages::table) .values(message.clone()) .execute(&mut conn) .await?; - message.build(data).await + message.build(app_state).await } - pub async fn set_name(&mut self, data: &Data, new_name: String) -> Result<(), Error> { + pub async fn set_name(&mut self, app_state: &AppState, new_name: String) -> Result<(), Error> { if !CHANNEL_REGEX.is_match(&new_name) { return Err(Error::BadRequest("Channel name is invalid".to_string())); } - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; use channels::dsl; update(channels::table) @@ -313,10 +318,10 @@ impl Channel { pub async fn set_description( &mut self, - data: &Data, + app_state: &AppState, new_description: String, ) -> Result<(), Error> { - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; use channels::dsl; update(channels::table) @@ -330,8 +335,12 @@ impl Channel { Ok(()) } - pub async fn move_channel(&mut self, data: &Data, new_is_above: Uuid) -> Result<(), Error> { - let mut conn = data.pool.get().await?; + pub async fn move_channel( + &mut self, + app_state: &AppState, + new_is_above: Uuid, + ) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; use channels::dsl; let old_above_uuid: Option = match dsl::channels diff --git a/src/objects/email_token.rs b/src/objects/email_token.rs index bfd1ef5..64d2fdb 100644 --- a/src/objects/email_token.rs +++ b/src/objects/email_token.rs @@ -3,7 +3,7 @@ use lettre::message::MultiPart; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{Data, error::Error, utils::generate_token}; +use crate::{AppState, error::Error, utils::generate_token}; use super::Me; @@ -15,9 +15,9 @@ pub struct EmailToken { } impl EmailToken { - pub async fn get(data: &Data, user_uuid: Uuid) -> Result { + pub async fn get(app_state: &AppState, user_uuid: Uuid) -> Result { let email_token = serde_json::from_str( - &data + &app_state .get_cache_key(format!("{user_uuid}_email_verify")) .await?, )?; @@ -26,7 +26,7 @@ impl EmailToken { } #[allow(clippy::new_ret_no_self)] - pub async fn new(data: &Data, me: Me) -> Result<(), Error> { + pub async fn new(app_state: &AppState, me: Me) -> Result<(), Error> { let token = generate_token::<32>()?; let email_token = EmailToken { @@ -36,30 +36,32 @@ impl EmailToken { created_at: Utc::now(), }; - data.set_cache_key(format!("{}_email_verify", me.uuid), email_token, 86400) + app_state + .set_cache_key(format!("{}_email_verify", me.uuid), email_token, 86400) .await?; - let mut verify_endpoint = data.config.web.frontend_url.join("verify-email")?; + let mut verify_endpoint = app_state.config.web.frontend_url.join("verify-email")?; verify_endpoint.set_query(Some(&format!("token={token}"))); - let email = data + let email = app_state .mail_client .message_builder() .to(me.email.parse()?) - .subject(format!("{} E-mail Verification", data.config.instance.name)) + .subject(format!("{} E-mail Verification", app_state.config.instance.name)) .multipart(MultiPart::alternative_plain_html( - format!("Verify your {} account\n\nHello, {}!\nThanks for creating a new account on Gorb.\nThe final step to create your account is to verify your email address by visiting the page, within 24 hours.\n\n{}\n\nIf you didn't ask to verify this address, you can safely ignore this email\n\nThanks, The gorb team.", data.config.instance.name, me.username, verify_endpoint), - format!(r#"

Verify your {} Account

Hello, {}!

Thanks for creating a new account on Gorb.

The final step to create your account is to verify your email address by clicking the button below, within 24 hours.

VERIFY ACCOUNT

If you didn't ask to verify this address, you can safely ignore this email.

"#, data.config.instance.name, me.username, verify_endpoint) + format!("Verify your {} account\n\nHello, {}!\nThanks for creating a new account on Gorb.\nThe final step to create your account is to verify your email address by visiting the page, within 24 hours.\n\n{}\n\nIf you didn't ask to verify this address, you can safely ignore this email\n\nThanks, The gorb team.", app_state.config.instance.name, me.username, verify_endpoint), + format!(r#"

Verify your {} Account

Hello, {}!

Thanks for creating a new account on Gorb.

The final step to create your account is to verify your email address by clicking the button below, within 24 hours.

VERIFY ACCOUNT

If you didn't ask to verify this address, you can safely ignore this email.

"#, app_state.config.instance.name, me.username, verify_endpoint) ))?; - data.mail_client.send_mail(email).await?; + app_state.mail_client.send_mail(email).await?; Ok(()) } - pub async fn delete(&self, data: &Data) -> Result<(), Error> { - data.del_cache_key(format!("{}_email_verify", self.user_uuid)) + pub async fn delete(&self, app_state: &AppState) -> Result<(), Error> { + app_state + .del_cache_key(format!("{}_email_verify", self.user_uuid)) .await?; Ok(()) diff --git a/src/objects/guild.rs b/src/objects/guild.rs index aa01f54..e27e129 100644 --- a/src/objects/guild.rs +++ b/src/objects/guild.rs @@ -1,4 +1,4 @@ -use actix_web::web::BytesMut; +use axum::body::Bytes; use diesel::{ ExpressionMethods, Insertable, QueryDsl, Queryable, Selectable, SelectableHelper, insert_into, update, @@ -191,7 +191,7 @@ impl Guild { bunny_storage: &bunny_api_tokio::EdgeStorageClient, conn: &mut Conn, cdn_url: Url, - icon: BytesMut, + icon: Bytes, ) -> Result<(), Error> { let icon_clone = icon.clone(); let image_type = task::spawn_blocking(move || image_check(icon_clone)).await??; @@ -204,7 +204,7 @@ impl Guild { let path = format!("icons/{}/{}.{}", self.uuid, Uuid::now_v7(), image_type); - bunny_storage.upload(path.clone(), icon.into()).await?; + bunny_storage.upload(path.clone(), icon).await?; let icon_url = cdn_url.join(&path)?; diff --git a/src/objects/me.rs b/src/objects/me.rs index 37951ab..3b51da4 100644 --- a/src/objects/me.rs +++ b/src/objects/me.rs @@ -1,4 +1,4 @@ -use actix_web::web::BytesMut; +use axum::body::Bytes; use diesel::{ ExpressionMethods, QueryDsl, Queryable, Selectable, SelectableHelper, delete, insert_into, update, @@ -10,7 +10,7 @@ use url::Url; use uuid::Uuid; use crate::{ - Conn, Data, + AppState, Conn, error::Error, objects::{Friend, FriendRequest, User}, schema::{friend_requests, friends, guild_members, guilds, users}, @@ -75,28 +75,26 @@ impl Me { pub async fn set_avatar( &mut self, - data: &Data, + app_state: &AppState, cdn_url: Url, - avatar: BytesMut, + avatar: Bytes, ) -> Result<(), Error> { let avatar_clone = avatar.clone(); let image_type = task::spawn_blocking(move || image_check(avatar_clone)).await??; - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; if let Some(avatar) = &self.avatar { let avatar_url: Url = avatar.parse()?; let relative_url = avatar_url.path().trim_start_matches('/'); - data.bunny_storage.delete(relative_url).await?; + app_state.bunny_storage.delete(relative_url).await?; } let path = format!("avatar/{}/{}.{}", self.uuid, Uuid::now_v7(), image_type); - data.bunny_storage - .upload(path.clone(), avatar.into()) - .await?; + app_state.bunny_storage.upload(path.clone(), avatar).await?; let avatar_url = cdn_url.join(&path)?; @@ -107,8 +105,8 @@ impl Me { .execute(&mut conn) .await?; - if data.get_cache_key(self.uuid.to_string()).await.is_ok() { - data.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } self.avatar = Some(avatar_url.to_string()); @@ -127,7 +125,11 @@ impl Me { Ok(()) } - pub async fn set_username(&mut self, data: &Data, new_username: String) -> Result<(), Error> { + pub async fn set_username( + &mut self, + app_state: &AppState, + new_username: String, + ) -> Result<(), Error> { if !USERNAME_REGEX.is_match(&new_username) || new_username.len() < 3 || new_username.len() > 32 @@ -135,7 +137,7 @@ impl Me { return Err(Error::BadRequest("Invalid username".to_string())); } - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; use users::dsl; update(users::table) @@ -144,8 +146,8 @@ impl Me { .execute(&mut conn) .await?; - if data.get_cache_key(self.uuid.to_string()).await.is_ok() { - data.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } self.username = new_username; @@ -155,10 +157,10 @@ impl Me { pub async fn set_display_name( &mut self, - data: &Data, + app_state: &AppState, new_display_name: String, ) -> Result<(), Error> { - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; let new_display_name_option = if new_display_name.is_empty() { None @@ -173,8 +175,8 @@ impl Me { .execute(&mut conn) .await?; - if data.get_cache_key(self.uuid.to_string()).await.is_ok() { - data.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } self.display_name = new_display_name_option; @@ -182,12 +184,16 @@ impl Me { Ok(()) } - pub async fn set_email(&mut self, data: &Data, new_email: String) -> Result<(), Error> { + pub async fn set_email( + &mut self, + app_state: &AppState, + new_email: String, + ) -> Result<(), Error> { if !EMAIL_REGEX.is_match(&new_email) { return Err(Error::BadRequest("Invalid username".to_string())); } - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; use users::dsl; update(users::table) @@ -199,8 +205,8 @@ impl Me { .execute(&mut conn) .await?; - if data.get_cache_key(self.uuid.to_string()).await.is_ok() { - data.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } self.email = new_email; @@ -208,8 +214,12 @@ impl Me { Ok(()) } - pub async fn set_pronouns(&mut self, data: &Data, new_pronouns: String) -> Result<(), Error> { - let mut conn = data.pool.get().await?; + pub async fn set_pronouns( + &mut self, + app_state: &AppState, + new_pronouns: String, + ) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; use users::dsl; update(users::table) @@ -218,15 +228,19 @@ impl Me { .execute(&mut conn) .await?; - if data.get_cache_key(self.uuid.to_string()).await.is_ok() { - data.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } Ok(()) } - pub async fn set_about(&mut self, data: &Data, new_about: String) -> Result<(), Error> { - let mut conn = data.pool.get().await?; + pub async fn set_about( + &mut self, + app_state: &AppState, + new_about: String, + ) -> Result<(), Error> { + let mut conn = app_state.pool.get().await?; use users::dsl; update(users::table) @@ -235,8 +249,8 @@ impl Me { .execute(&mut conn) .await?; - if data.get_cache_key(self.uuid.to_string()).await.is_ok() { - data.del_cache_key(self.uuid.to_string()).await? + if app_state.get_cache_key(self.uuid.to_string()).await.is_ok() { + app_state.del_cache_key(self.uuid.to_string()).await? } Ok(()) @@ -352,10 +366,10 @@ impl Me { Ok(()) } - pub async fn get_friends(&self, data: &Data) -> Result, Error> { + pub async fn get_friends(&self, app_state: &AppState) -> Result, Error> { use friends::dsl; - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; let friends1 = load_or_empty( dsl::friends @@ -374,13 +388,13 @@ impl Me { )?; let friend_futures = friends1.iter().map(async move |friend| { - User::fetch_one_with_friendship(data, self, friend.uuid2).await + User::fetch_one_with_friendship(app_state, self, friend.uuid2).await }); let mut friends = futures::future::try_join_all(friend_futures).await?; let friend_futures = friends2.iter().map(async move |friend| { - User::fetch_one_with_friendship(data, self, friend.uuid1).await + User::fetch_one_with_friendship(app_state, self, friend.uuid1).await }); friends.append(&mut futures::future::try_join_all(friend_futures).await?); diff --git a/src/objects/member.rs b/src/objects/member.rs index c2a71d9..50b76b0 100644 --- a/src/objects/member.rs +++ b/src/objects/member.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - Conn, Data, + AppState, Conn, error::Error, objects::{Me, Permissions, Role}, schema::guild_members, @@ -26,13 +26,13 @@ pub struct MemberBuilder { } impl MemberBuilder { - pub async fn build(&self, data: &Data, me: Option<&Me>) -> Result { + pub async fn build(&self, app_state: &AppState, me: Option<&Me>) -> Result { let user; if let Some(me) = me { - user = User::fetch_one_with_friendship(data, me, self.user_uuid).await?; + user = User::fetch_one_with_friendship(app_state, me, self.user_uuid).await?; } else { - user = User::fetch_one(data, self.user_uuid).await?; + user = User::fetch_one(app_state, self.user_uuid).await?; } Ok(Member { @@ -47,11 +47,11 @@ impl MemberBuilder { pub async fn check_permission( &self, - data: &Data, + app_state: &AppState, permission: Permissions, ) -> Result<(), Error> { if !self.is_owner { - let roles = Role::fetch_from_member(data, self.uuid).await?; + let roles = Role::fetch_from_member(app_state, self.uuid).await?; let allowed = roles.iter().any(|r| r.permissions & permission as i64 != 0); if !allowed { return Err(Error::Forbidden("Not allowed".to_string())); @@ -101,12 +101,12 @@ impl Member { } pub async fn fetch_one( - data: &Data, + app_state: &AppState, me: &Me, user_uuid: Uuid, guild_uuid: Uuid, ) -> Result { - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; use guild_members::dsl; let member: MemberBuilder = dsl::guild_members @@ -116,11 +116,15 @@ impl Member { .get_result(&mut conn) .await?; - member.build(data, Some(me)).await + member.build(app_state, Some(me)).await } - pub async fn fetch_all(data: &Data, me: &Me, guild_uuid: Uuid) -> Result, Error> { - let mut conn = data.pool.get().await?; + pub async fn fetch_all( + app_state: &AppState, + me: &Me, + guild_uuid: Uuid, + ) -> Result, Error> { + let mut conn = app_state.pool.get().await?; use guild_members::dsl; let member_builders: Vec = load_or_empty( @@ -134,14 +138,18 @@ impl Member { let mut members = vec![]; for builder in member_builders { - members.push(builder.build(&data, Some(me)).await?); + members.push(builder.build(app_state, Some(me)).await?); } Ok(members) } - pub async fn new(data: &Data, user_uuid: Uuid, guild_uuid: Uuid) -> Result { - let mut conn = data.pool.get().await?; + pub async fn new( + app_state: &AppState, + user_uuid: Uuid, + guild_uuid: Uuid, + ) -> Result { + let mut conn = app_state.pool.get().await?; let member_uuid = Uuid::now_v7(); @@ -158,6 +166,6 @@ impl Member { .execute(&mut conn) .await?; - member.build(data, None).await + member.build(app_state, None).await } } diff --git a/src/objects/message.rs b/src/objects/message.rs index a887541..caff969 100644 --- a/src/objects/message.rs +++ b/src/objects/message.rs @@ -2,7 +2,7 @@ use diesel::{Insertable, Queryable, Selectable}; use serde::Serialize; use uuid::Uuid; -use crate::{Data, error::Error, schema::messages}; +use crate::{AppState, error::Error, schema::messages}; use super::User; @@ -18,8 +18,8 @@ pub struct MessageBuilder { } impl MessageBuilder { - pub async fn build(&self, data: &Data) -> Result { - let user = User::fetch_one(data, self.user_uuid).await?; + pub async fn build(&self, app_state: &AppState) -> Result { + let user = User::fetch_one(app_state, self.user_uuid).await?; Ok(Message { uuid: self.uuid, diff --git a/src/objects/mod.rs b/src/objects/mod.rs index 9974410..4af16d8 100644 --- a/src/objects/mod.rs +++ b/src/objects/mod.rs @@ -42,6 +42,37 @@ pub trait HasUuid { pub trait HasIsAbove { fn is_above(&self) -> Option<&Uuid>; } +/* +pub trait Cookies { + fn cookies(&self) -> CookieJar; + fn cookie>(&self, cookie: T) -> Option; +} + +impl Cookies for Request { + fn cookies(&self) -> CookieJar { + let cookies = self.headers() + .get(axum::http::header::COOKIE) + .and_then(|value| value.to_str().ok()) + .map(|s| Cookie::split_parse(s.to_string())) + .and_then(|c| c.collect::, cookie::ParseError>>().ok()) + .unwrap_or(vec![]); + + let mut cookie_jar = CookieJar::new(); + + for cookie in cookies { + cookie_jar.add(cookie) + } + + cookie_jar + } + + fn cookie>(&self, cookie: T) -> Option { + self.cookies() + .get(cookie.as_ref()) + .and_then(|c| Some(c.to_owned())) + } +} +*/ fn load_or_empty( query_result: Result, diesel::result::Error>, diff --git a/src/objects/password_reset_token.rs b/src/objects/password_reset_token.rs index 7f714ef..04ff43c 100644 --- a/src/objects/password_reset_token.rs +++ b/src/objects/password_reset_token.rs @@ -10,7 +10,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - Data, + AppState, error::Error, schema::users, utils::{PASSWORD_REGEX, generate_token, global_checks, user_uuid_from_identifier}, @@ -24,10 +24,11 @@ pub struct PasswordResetToken { } impl PasswordResetToken { - pub async fn get(data: &Data, token: String) -> Result { - let user_uuid: Uuid = serde_json::from_str(&data.get_cache_key(token.to_string()).await?)?; + pub async fn get(app_state: &AppState, token: String) -> Result { + let user_uuid: Uuid = + serde_json::from_str(&app_state.get_cache_key(token.to_string()).await?)?; let password_reset_token = serde_json::from_str( - &data + &app_state .get_cache_key(format!("{user_uuid}_password_reset")) .await?, )?; @@ -36,15 +37,15 @@ impl PasswordResetToken { } pub async fn get_with_identifier( - data: &Data, + app_state: &AppState, identifier: String, ) -> Result { - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; let user_uuid = user_uuid_from_identifier(&mut conn, &identifier).await?; let password_reset_token = serde_json::from_str( - &data + &app_state .get_cache_key(format!("{user_uuid}_password_reset")) .await?, )?; @@ -53,14 +54,14 @@ impl PasswordResetToken { } #[allow(clippy::new_ret_no_self)] - pub async fn new(data: &Data, identifier: String) -> Result<(), Error> { + pub async fn new(app_state: &AppState, identifier: String) -> Result<(), Error> { let token = generate_token::<32>()?; - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; let user_uuid = user_uuid_from_identifier(&mut conn, &identifier).await?; - global_checks(data, user_uuid).await?; + global_checks(app_state, user_uuid).await?; use users::dsl as udsl; let (username, email_address): (String, String) = udsl::users @@ -75,34 +76,37 @@ impl PasswordResetToken { created_at: Utc::now(), }; - data.set_cache_key( - format!("{user_uuid}_password_reset"), - password_reset_token, - 86400, - ) - .await?; - data.set_cache_key(token.clone(), user_uuid, 86400).await?; + app_state + .set_cache_key( + format!("{user_uuid}_password_reset"), + password_reset_token, + 86400, + ) + .await?; + app_state + .set_cache_key(token.clone(), user_uuid, 86400) + .await?; - let mut reset_endpoint = data.config.web.frontend_url.join("reset-password")?; + let mut reset_endpoint = app_state.config.web.frontend_url.join("reset-password")?; reset_endpoint.set_query(Some(&format!("token={token}"))); - let email = data + let email = app_state .mail_client .message_builder() .to(email_address.parse()?) - .subject(format!("{} Password Reset", data.config.instance.name)) + .subject(format!("{} Password Reset", app_state.config.instance.name)) .multipart(MultiPart::alternative_plain_html( - format!("{} Password Reset\n\nHello, {}!\nSomeone requested a password reset for your Gorb account.\nClick the button below within 24 hours to reset your password.\n\n{}\n\nIf you didn't request a password reset, don't worry, your account is safe and you can safely ignore this email.\n\nThanks, The gorb team.", data.config.instance.name, username, reset_endpoint), - format!(r#"

{} Password Reset

Hello, {}!

Someone requested a password reset for your Gorb account.

Click the button below within 24 hours to reset your password.

RESET PASSWORD

If you didn't request a password reset, don't worry, your account is safe and you can safely ignore this email.

"#, data.config.instance.name, username, reset_endpoint) + format!("{} Password Reset\n\nHello, {}!\nSomeone requested a password reset for your Gorb account.\nClick the button below within 24 hours to reset your password.\n\n{}\n\nIf you didn't request a password reset, don't worry, your account is safe and you can safely ignore this email.\n\nThanks, The gorb team.", app_state.config.instance.name, username, reset_endpoint), + format!(r#"

{} Password Reset

Hello, {}!

Someone requested a password reset for your Gorb account.

Click the button below within 24 hours to reset your password.

RESET PASSWORD

If you didn't request a password reset, don't worry, your account is safe and you can safely ignore this email.

"#, app_state.config.instance.name, username, reset_endpoint) ))?; - data.mail_client.send_mail(email).await?; + app_state.mail_client.send_mail(email).await?; Ok(()) } - pub async fn set_password(&self, data: &Data, password: String) -> Result<(), Error> { + pub async fn set_password(&self, app_state: &AppState, password: String) -> Result<(), Error> { if !PASSWORD_REGEX.is_match(&password) { return Err(Error::BadRequest( "Please provide a valid password".to_string(), @@ -111,12 +115,12 @@ impl PasswordResetToken { let salt = SaltString::generate(&mut OsRng); - let hashed_password = data + let hashed_password = app_state .argon2 .hash_password(password.as_bytes(), &salt) .map_err(|e| Error::PasswordHashError(e.to_string()))?; - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; use users::dsl; update(users::table) @@ -131,27 +135,28 @@ impl PasswordResetToken { .get_result(&mut conn) .await?; - let login_page = data.config.web.frontend_url.join("login")?; + let login_page = app_state.config.web.frontend_url.join("login")?; - let email = data + let email = app_state .mail_client .message_builder() .to(email_address.parse()?) - .subject(format!("Your {} Password has been Reset", data.config.instance.name)) + .subject(format!("Your {} Password has been Reset", app_state.config.instance.name)) .multipart(MultiPart::alternative_plain_html( - format!("{} Password Reset Confirmation\n\nHello, {}!\nYour password has been successfully reset for your Gorb account.\nIf you did not initiate this change, please click the link below to reset your password immediately.\n\n{}\n\nThanks, The gorb team.", data.config.instance.name, username, login_page), - format!(r#"

{} Password Reset Confirmation

Hello, {}!

Your password has been successfully reset for your Gorb account.

If you did not initiate this change, please click the button below to reset your password immediately.

RESET PASSWORD
"#, data.config.instance.name, username, login_page) + format!("{} Password Reset Confirmation\n\nHello, {}!\nYour password has been successfully reset for your Gorb account.\nIf you did not initiate this change, please click the link below to reset your password immediately.\n\n{}\n\nThanks, The gorb team.", app_state.config.instance.name, username, login_page), + format!(r#"

{} Password Reset Confirmation

Hello, {}!

Your password has been successfully reset for your Gorb account.

If you did not initiate this change, please click the button below to reset your password immediately.

RESET PASSWORD
"#, app_state.config.instance.name, username, login_page) ))?; - data.mail_client.send_mail(email).await?; + app_state.mail_client.send_mail(email).await?; - self.delete(data).await + self.delete(app_state).await } - pub async fn delete(&self, data: &Data) -> Result<(), Error> { - data.del_cache_key(format!("{}_password_reset", &self.user_uuid)) + pub async fn delete(&self, app_state: &AppState) -> Result<(), Error> { + app_state + .del_cache_key(format!("{}_password_reset", &self.user_uuid)) .await?; - data.del_cache_key(self.token.to_string()).await?; + app_state.del_cache_key(self.token.to_string()).await?; Ok(()) } diff --git a/src/objects/role.rs b/src/objects/role.rs index 68e9c27..ea70686 100644 --- a/src/objects/role.rs +++ b/src/objects/role.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize}; use uuid::Uuid; use crate::{ - Conn, Data, + AppState, Conn, error::Error, schema::{role_members, roles}, utils::order_by_is_above, @@ -74,12 +74,18 @@ impl Role { Ok(roles) } - pub async fn fetch_from_member(data: &Data, member_uuid: Uuid) -> Result, Error> { - if let Ok(roles) = data.get_cache_key(format!("{member_uuid}_roles")).await { + pub async fn fetch_from_member( + app_state: &AppState, + member_uuid: Uuid, + ) -> Result, Error> { + if let Ok(roles) = app_state + .get_cache_key(format!("{member_uuid}_roles")) + .await + { return Ok(serde_json::from_str(&roles)?); } - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; use role_members::dsl; let role_memberships: Vec = load_or_empty( @@ -96,7 +102,8 @@ impl Role { roles.push(membership.fetch_role(&mut conn).await?); } - data.set_cache_key(format!("{member_uuid}_roles"), roles.clone(), 300) + app_state + .set_cache_key(format!("{member_uuid}_roles"), roles.clone(), 300) .await?; Ok(roles) diff --git a/src/objects/user.rs b/src/objects/user.rs index 8e42351..c1f164d 100644 --- a/src/objects/user.rs +++ b/src/objects/user.rs @@ -4,7 +4,7 @@ use diesel_async::RunQueryDsl; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use crate::{Conn, Data, error::Error, objects::Me, schema::users}; +use crate::{AppState, Conn, error::Error, objects::Me, schema::users}; use super::load_or_empty; @@ -46,10 +46,10 @@ pub struct User { } impl User { - pub async fn fetch_one(data: &Data, user_uuid: Uuid) -> Result { - let mut conn = data.pool.get().await?; + pub async fn fetch_one(app_state: &AppState, user_uuid: Uuid) -> Result { + let mut conn = app_state.pool.get().await?; - if let Ok(cache_hit) = data.get_cache_key(user_uuid.to_string()).await { + if let Ok(cache_hit) = app_state.get_cache_key(user_uuid.to_string()).await { return Ok(serde_json::from_str(&cache_hit)?); } @@ -62,20 +62,21 @@ impl User { let user = user_builder.build(); - data.set_cache_key(user_uuid.to_string(), user.clone(), 1800) + app_state + .set_cache_key(user_uuid.to_string(), user.clone(), 1800) .await?; Ok(user) } pub async fn fetch_one_with_friendship( - data: &Data, + app_state: &AppState, me: &Me, user_uuid: Uuid, ) -> Result { - let mut conn = data.pool.get().await?; + let mut conn = app_state.pool.get().await?; - let mut user = Self::fetch_one(data, user_uuid).await?; + let mut user = Self::fetch_one(app_state, user_uuid).await?; if let Some(friend) = me.friends_with(&mut conn, user_uuid).await? { user.friends_since = Some(friend.accepted_at); diff --git a/src/socket.rs b/src/socket.rs new file mode 100644 index 0000000..e00a7c0 --- /dev/null +++ b/src/socket.rs @@ -0,0 +1,26 @@ +use std::sync::Arc; + +use log::info; +use rmpv::Value; +use socketioxide::{ + extract::{AckSender, Data, SocketRef, State}, +}; + +use crate::AppState; + +pub async fn on_connect(State(app_state): State>, socket: SocketRef, Data(data): Data) { + socket.emit("auth", &data).ok(); + + socket.on("message", async |socket: SocketRef, Data::(data)| { + info!("{}", data); + socket.emit("message-back", &data).ok(); + }); + + socket.on( + "message-with-ack", + async |Data::(data), ack: AckSender| { + info!("{}", data); + ack.send(&data).ok(); + }, + ); +} diff --git a/src/utils.rs b/src/utils.rs index 072143f..6083188 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,10 +1,7 @@ use std::sync::LazyLock; -use actix_web::{ - cookie::{Cookie, SameSite, time::Duration}, - http::header::HeaderMap, - web::BytesMut, -}; +use axum::body::Bytes; +use axum_extra::extract::cookie::{Cookie, SameSite}; use bindet::FileType; use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; @@ -13,10 +10,11 @@ use hex::encode; use redis::RedisError; use regex::Regex; use serde::Serialize; +use time::Duration; use uuid::Uuid; use crate::{ - Conn, Data, + AppState, Conn, config::Config, error::Error, objects::{HasIsAbove, HasUuid}, @@ -33,86 +31,26 @@ pub static USERNAME_REGEX: LazyLock = pub static CHANNEL_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"^[a-z0-9_.-]+$").unwrap()); -// Password is expected to be hashed using SHA3-384 pub static PASSWORD_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"[0-9a-f]{96}").unwrap()); -pub fn get_auth_header(headers: &HeaderMap) -> Result<&str, Error> { - let auth_token = headers.get(actix_web::http::header::AUTHORIZATION); - - if auth_token.is_none() { - return Err(Error::Unauthorized( - "No authorization header provided".to_string(), - )); - } - - let auth_raw = auth_token.unwrap().to_str()?; - - let mut auth = auth_raw.split_whitespace(); - - let auth_type = auth.next(); - - let auth_value = auth.next(); - - if auth_type.is_none() { - return Err(Error::BadRequest( - "Authorization header is empty".to_string(), - )); - } else if auth_type.is_some_and(|at| at != "Bearer") { - return Err(Error::BadRequest( - "Only token auth is supported".to_string(), - )); - } - - if auth_value.is_none() { - return Err(Error::BadRequest("No token provided".to_string())); - } - - Ok(auth_value.unwrap()) -} - -pub fn get_ws_protocol_header(headers: &HeaderMap) -> Result<&str, Error> { - let auth_token = headers.get(actix_web::http::header::SEC_WEBSOCKET_PROTOCOL); - - if auth_token.is_none() { - return Err(Error::Unauthorized( - "No authorization header provided".to_string(), - )); - } - - let auth_raw = auth_token.unwrap().to_str()?; - - let mut auth = auth_raw.split_whitespace(); - - let response_proto = auth.next(); - - let auth_value = auth.next(); - - if response_proto.is_none() { - return Err(Error::BadRequest( - "Sec-WebSocket-Protocol header is empty".to_string(), - )); - } else if response_proto.is_some_and(|rp| rp != "Authorization,") { - return Err(Error::BadRequest( - "First protocol should be Authorization".to_string(), - )); - } - - if auth_value.is_none() { - return Err(Error::BadRequest("No token provided".to_string())); - } - - Ok(auth_value.unwrap()) -} - -pub fn new_refresh_token_cookie(config: &Config, refresh_token: String) -> Cookie<'static> { - Cookie::build("refresh_token", refresh_token) +pub fn new_refresh_token_cookie(config: &Config, refresh_token: String) -> Cookie { + Cookie::build(("refresh_token", refresh_token)) .http_only(true) .secure(true) .same_site(SameSite::None) - //.domain(config.web.backend_url.domain().unwrap().to_string()) .path(config.web.backend_url.path().to_string()) .max_age(Duration::days(30)) - .finish() + .build() +} + +pub fn new_access_token_cookie(config: &Config, access_token: String) -> Cookie { + Cookie::build(("access_token", access_token)) + .http_only(false) + .secure(true) + .same_site(SameSite::None) + .path(config.web.backend_url.path().to_string()) + .max_age(Duration::hours(1)) + .build() } pub fn generate_token() -> Result { @@ -121,7 +59,7 @@ pub fn generate_token() -> Result { Ok(encode(buf)) } -pub fn image_check(icon: BytesMut) -> Result { +pub fn image_check(icon: Bytes) -> Result { let buf = std::io::Cursor::new(icon); let detect = bindet::detect(buf).map_err(|e| e.kind()); @@ -168,10 +106,7 @@ pub async fn user_uuid_from_identifier( } } -pub async fn user_uuid_from_username( - conn: &mut Conn, - username: &String, -) -> Result { +pub async fn user_uuid_from_username(conn: &mut Conn, username: &String) -> Result { if USERNAME_REGEX.is_match(username) { use users::dsl; let user_uuid = dsl::users @@ -188,9 +123,9 @@ pub async fn user_uuid_from_username( } } -pub async fn global_checks(data: &Data, user_uuid: Uuid) -> Result<(), Error> { - if data.config.instance.require_email_verification { - let mut conn = data.pool.get().await?; +pub async fn global_checks(app_state: &AppState, user_uuid: Uuid) -> Result<(), Error> { + if app_state.config.instance.require_email_verification { + let mut conn = app_state.pool.get().await?; use users::dsl; let email_verified: bool = dsl::users @@ -234,7 +169,7 @@ where Ok(ordered) } -impl Data { +impl AppState { pub async fn set_cache_key( &self, key: String, -- 2.47.3 From 1946080716a72f733ac1aa424cbc1c869eb21429 Mon Sep 17 00:00:00 2001 From: Radical Date: Thu, 17 Jul 2025 16:07:09 +0200 Subject: [PATCH 06/19] ci: remove parentheses from name --- .woodpecker/build-and-publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.woodpecker/build-and-publish.yml b/.woodpecker/build-and-publish.yml index 9836311..4e263a9 100644 --- a/.woodpecker/build-and-publish.yml +++ b/.woodpecker/build-and-publish.yml @@ -36,7 +36,7 @@ steps: - branch: main event: push - - name: container-build-and-publish (staging) + - name: container-build-and-publish-staging image: docker commands: - docker login --username radical --password $PASSWORD git.gorb.app -- 2.47.3 From 9a0ebf2b2fce0d624945fff518560cf65695fd1e Mon Sep 17 00:00:00 2001 From: Radical Date: Thu, 17 Jul 2025 16:48:34 +0200 Subject: [PATCH 07/19] fix: use merge instead of nesting --- src/api/mod.rs | 6 +++--- src/main.rs | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index e4c3f2e..a00d1e5 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,8 +9,8 @@ use crate::AppState; mod v1; mod versions; -pub fn router() -> Router> { +pub fn router(path: &str) -> Router> { Router::new() - .route("/versions", get(versions::versions)) - .nest("/v1", v1::router()) + .route(&format!("{path}/versions"), get(versions::versions)) + .nest(&format!("{path}/v1"), v1::router()) } diff --git a/src/main.rs b/src/main.rs index 6bb2be3..15bae09 100644 --- a/src/main.rs +++ b/src/main.rs @@ -141,7 +141,7 @@ async fn main() -> Result<(), Error> { // build our application with a route let app = Router::new() // `GET /` goes to `root` - .nest(web.backend_url.path(), api::router()) + .merge(api::router(web.backend_url.path().trim_end_matches("/"))) .with_state(app_state) .layer(cors) .layer(socket_io); -- 2.47.3 From 8f53c9f718e6be24075427a95e42e6e17ca441f6 Mon Sep 17 00:00:00 2001 From: Radical Date: Thu, 17 Jul 2025 21:34:35 +0200 Subject: [PATCH 08/19] fix: try to fix up cors Login still not working, unsure of where failure point is --- src/main.rs | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/main.rs b/src/main.rs index 15bae09..9624d18 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ use error::Error; use objects::MailClient; use socketioxide::SocketIo; use std::{sync::Arc, time::SystemTime}; -use tower_http::cors::{Any, CorsLayer}; +use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer}; mod config; use config::{Config, ConfigBuilder}; use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations}; @@ -128,11 +128,21 @@ async fn main() -> Result<(), Error> { let cors = CorsLayer::new() // Allow any origin (equivalent to allowed_origin_fn returning true) - .allow_origin(Any) + .allow_origin(AllowOrigin::predicate(|_origin, _request_head| { + true + })) // Allow any method - .allow_methods(Any) + .allow_methods(AllowMethods::mirror_request()) // Allow any headers - .allow_headers(Any); + .allow_headers(AllowHeaders::mirror_request()) + /* + vec![ + "content-type".parse().unwrap(), + "authorization".parse().unwrap(), + ] + */ + // Allow credentials + .allow_credentials(true); let (socket_io, io) = SocketIo::builder().with_state(app_state.clone()).build_layer(); -- 2.47.3 From d67a7ce0ca8afbd02a8768a63e24dff84ed762f2 Mon Sep 17 00:00:00 2001 From: Radical Date: Fri, 18 Jul 2025 12:00:28 +0200 Subject: [PATCH 09/19] fix: try explicitly setting methods and headers --- src/main.rs | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/main.rs b/src/main.rs index 9624d18..baf4a61 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use argon2::Argon2; -use axum::Router; +use axum::{http::header, Router}; use clap::Parser; use diesel_async::pooled_connection::AsyncDieselConnectionManager; use diesel_async::pooled_connection::deadpool::Pool; @@ -131,16 +131,22 @@ async fn main() -> Result<(), Error> { .allow_origin(AllowOrigin::predicate(|_origin, _request_head| { true })) - // Allow any method - .allow_methods(AllowMethods::mirror_request()) - // Allow any headers - .allow_headers(AllowHeaders::mirror_request()) - /* - vec![ - "content-type".parse().unwrap(), - "authorization".parse().unwrap(), - ] - */ + .allow_methods(AllowMethods::list([ + "GET".parse().unwrap(), + "POST".parse().unwrap(), + "PUT".parse().unwrap(), + "PATCH".parse().unwrap(), + "DELETE".parse().unwrap(), + "OPTIONS".parse().unwrap(), + ])) + .allow_headers(AllowHeaders::list([ + header::AUTHORIZATION, + header::CONTENT_TYPE, + header::ORIGIN, + header::ACCEPT, + header::COOKIE, + "x-requested-with".parse().unwrap(), + ])) // Allow credentials .allow_credentials(true); -- 2.47.3 From 2fbf41ba8cdaefee91ef252de04bb2aefdff2dce Mon Sep 17 00:00:00 2001 From: Radical Date: Sat, 19 Jul 2025 19:10:36 +0200 Subject: [PATCH 10/19] fix: use .append() and not Set-Cookie2 web dev is too confusing.. --- src/api/v1/auth/login.rs | 6 +++--- src/api/v1/auth/logout.rs | 2 +- src/api/v1/auth/refresh.rs | 26 +++++++++----------------- src/api/v1/auth/register.rs | 2 +- 4 files changed, 14 insertions(+), 22 deletions(-) diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 2391fdf..7779564 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -95,15 +95,15 @@ pub async fn response( let mut response = StatusCode::OK.into_response(); - response.headers_mut().insert( + response.headers_mut().append( "Set-Cookie", HeaderValue::from_str( &new_refresh_token_cookie(&app_state.config, refresh_token).to_string(), )?, ); - response.headers_mut().insert( - "Set-Cookie2", + response.headers_mut().append( + "Set-Cookie", HeaderValue::from_str( &new_access_token_cookie(&app_state.config, access_token).to_string(), )?, diff --git a/src/api/v1/auth/logout.rs b/src/api/v1/auth/logout.rs index 6e5e98d..906afcc 100644 --- a/src/api/v1/auth/logout.rs +++ b/src/api/v1/auth/logout.rs @@ -68,7 +68,7 @@ pub async fn res( cookie.make_removal(); response .headers_mut() - .append("Set-Cookie2", HeaderValue::from_str(&cookie.to_string())?); + .append("Set-Cookie", HeaderValue::from_str(&cookie.to_string())?); } Ok(response) diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index 2a7e611..b104a8e 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -71,7 +71,7 @@ pub async fn post( cookie.make_removal(); response .headers_mut() - .append("Set-Cookie2", HeaderValue::from_str(&cookie.to_string())?); + .append("Set-Cookie", HeaderValue::from_str(&cookie.to_string())?); } return Ok(response); @@ -119,21 +119,13 @@ pub async fn post( .execute(&mut conn) .await?; - if response.headers().get("Set-Cookie").is_some() { - response.headers_mut().append( - "Set-Cookie2", - HeaderValue::from_str( - &new_access_token_cookie(&app_state.config, access_token).to_string(), - )?, - ); - } else { - response.headers_mut().append( - "Set-Cookie", - HeaderValue::from_str( - &new_access_token_cookie(&app_state.config, access_token).to_string(), - )?, - ); - } + + response.headers_mut().append( + "Set-Cookie", + HeaderValue::from_str( + &new_access_token_cookie(&app_state.config, access_token).to_string(), + )?, + ); return Ok(response); } @@ -151,7 +143,7 @@ pub async fn post( cookie.make_removal(); response .headers_mut() - .append("Set-Cookie2", HeaderValue::from_str(&cookie.to_string())?); + .append("Set-Cookie", HeaderValue::from_str(&cookie.to_string())?); } Ok(response) diff --git a/src/api/v1/auth/register.rs b/src/api/v1/auth/register.rs index 06b63ca..237f1e0 100644 --- a/src/api/v1/auth/register.rs +++ b/src/api/v1/auth/register.rs @@ -169,7 +169,7 @@ pub async fn post( )?, ); response.headers_mut().append( - "Set-Cookie2", + "Set-Cookie", HeaderValue::from_str( &new_access_token_cookie(&app_state.config, access_token).to_string(), )?, -- 2.47.3 From 252b9a3dc652fd28d5bbb8a4359121433f637adc Mon Sep 17 00:00:00 2001 From: Radical Date: Sat, 19 Jul 2025 23:03:23 +0200 Subject: [PATCH 11/19] fix: add more cors shit can someone please just make cors disappear? god i hate this shit. --- src/main.rs | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/main.rs b/src/main.rs index baf4a61..73110c4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use argon2::Argon2; -use axum::{http::header, Router}; +use axum::{http::{header, Method}, Router}; use clap::Parser; use diesel_async::pooled_connection::AsyncDieselConnectionManager; use diesel_async::pooled_connection::deadpool::Pool; @@ -131,22 +131,28 @@ async fn main() -> Result<(), Error> { .allow_origin(AllowOrigin::predicate(|_origin, _request_head| { true })) - .allow_methods(AllowMethods::list([ - "GET".parse().unwrap(), - "POST".parse().unwrap(), - "PUT".parse().unwrap(), - "PATCH".parse().unwrap(), - "DELETE".parse().unwrap(), - "OPTIONS".parse().unwrap(), - ])) - .allow_headers(AllowHeaders::list([ + .allow_methods(vec![ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + Method::HEAD, + Method::OPTIONS, + Method::CONNECT, + Method::PATCH, + Method::TRACE, + ]) + .allow_headers(vec![ + header::ACCEPT, + header::ACCEPT_LANGUAGE, header::AUTHORIZATION, + header::CONTENT_LANGUAGE, header::CONTENT_TYPE, header::ORIGIN, header::ACCEPT, header::COOKIE, "x-requested-with".parse().unwrap(), - ])) + ]) // Allow credentials .allow_credentials(true); -- 2.47.3 From d2fec66ddbcc3a5739aec630322c7ee2e019d7e3 Mon Sep 17 00:00:00 2001 From: Radical Date: Sat, 19 Jul 2025 23:20:16 +0200 Subject: [PATCH 12/19] fix: try not setting path on access token --- src/api/v1/auth/login.rs | 2 +- src/api/v1/auth/refresh.rs | 2 +- src/api/v1/auth/register.rs | 2 +- src/utils.rs | 3 +-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 7779564..995e299 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -105,7 +105,7 @@ pub async fn response( response.headers_mut().append( "Set-Cookie", HeaderValue::from_str( - &new_access_token_cookie(&app_state.config, access_token).to_string(), + &new_access_token_cookie(access_token).to_string(), )?, ); diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index b104a8e..d6bc3a9 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -123,7 +123,7 @@ pub async fn post( response.headers_mut().append( "Set-Cookie", HeaderValue::from_str( - &new_access_token_cookie(&app_state.config, access_token).to_string(), + &new_access_token_cookie(access_token).to_string(), )?, ); diff --git a/src/api/v1/auth/register.rs b/src/api/v1/auth/register.rs index 237f1e0..9f05b04 100644 --- a/src/api/v1/auth/register.rs +++ b/src/api/v1/auth/register.rs @@ -171,7 +171,7 @@ pub async fn post( response.headers_mut().append( "Set-Cookie", HeaderValue::from_str( - &new_access_token_cookie(&app_state.config, access_token).to_string(), + &new_access_token_cookie(access_token).to_string(), )?, ); diff --git a/src/utils.rs b/src/utils.rs index 6083188..7cda5b3 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -43,12 +43,11 @@ pub fn new_refresh_token_cookie(config: &Config, refresh_token: String) -> Cooki .build() } -pub fn new_access_token_cookie(config: &Config, access_token: String) -> Cookie { +pub fn new_access_token_cookie<'cookie>(access_token: String) -> Cookie<'cookie> { Cookie::build(("access_token", access_token)) .http_only(false) .secure(true) .same_site(SameSite::None) - .path(config.web.backend_url.path().to_string()) .max_age(Duration::hours(1)) .build() } -- 2.47.3 From 9bf435b5350cf179318fba400d51337a62a62954 Mon Sep 17 00:00:00 2001 From: Radical Date: Sat, 19 Jul 2025 23:39:56 +0200 Subject: [PATCH 13/19] fix: revert changes to access_token made during refactor --- src/api/v1/auth/login.rs | 12 +++-------- src/api/v1/auth/mod.rs | 8 +++++++ src/api/v1/auth/refresh.rs | 42 +++++++++---------------------------- src/api/v1/auth/register.rs | 11 +++------- src/main.rs | 2 +- src/utils.rs | 9 -------- 6 files changed, 25 insertions(+), 59 deletions(-) diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 995e299..d5cba95 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -14,12 +14,13 @@ use diesel::{ExpressionMethods, QueryDsl, dsl::insert_into}; use diesel_async::RunQueryDsl; use serde::Deserialize; +use super::Response; use crate::{ AppState, error::Error, schema::*, utils::{ - PASSWORD_REGEX, generate_token, new_access_token_cookie, new_refresh_token_cookie, + PASSWORD_REGEX, generate_token, new_refresh_token_cookie, user_uuid_from_identifier, }, }; @@ -93,7 +94,7 @@ pub async fn response( .execute(&mut conn) .await?; - let mut response = StatusCode::OK.into_response(); + let mut response = (StatusCode::OK, Json(Response { access_token })).into_response(); response.headers_mut().append( "Set-Cookie", @@ -102,12 +103,5 @@ pub async fn response( )?, ); - response.headers_mut().append( - "Set-Cookie", - HeaderValue::from_str( - &new_access_token_cookie(access_token).to_string(), - )?, - ); - Ok(response) } diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index 88be220..59d7a8e 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -9,6 +9,7 @@ use axum::{ }; use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; +use serde::Serialize; use uuid::Uuid; use crate::{AppState, Conn, error::Error, schema::access_tokens::dsl}; @@ -22,6 +23,13 @@ mod reset_password; mod revoke; mod verify_email; + +#[derive(Serialize)] +pub struct Response { + access_token: String, +} + + pub fn router() -> Router> { Router::new() .route("/register", post(register::post)) diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index d6bc3a9..e9709ed 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -1,7 +1,7 @@ use axum::{ extract::State, http::{HeaderValue, StatusCode}, - response::IntoResponse, + response::IntoResponse, Json, }; use axum_extra::extract::CookieJar; use diesel::{ExpressionMethods, QueryDsl, delete, update}; @@ -12,6 +12,7 @@ use std::{ time::{SystemTime, UNIX_EPOCH}, }; +use super::Response; use crate::{ AppState, error::Error, @@ -19,7 +20,7 @@ use crate::{ access_tokens::{self, dsl}, refresh_tokens::{self, dsl as rdsl}, }, - utils::{generate_token, new_access_token_cookie, new_refresh_token_cookie}, + utils::{generate_token, new_refresh_token_cookie}, }; pub async fn post( @@ -33,9 +34,7 @@ pub async fn post( ))? .to_owned(); - let access_token_cookie = jar.get("access_token"); - - let refresh_token = String::from(refresh_token_cookie.value_trimmed()); + let mut refresh_token = String::from(refresh_token_cookie.value_trimmed()); let current_time = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64; @@ -66,21 +65,11 @@ pub async fn post( HeaderValue::from_str(&refresh_token_cookie.to_string())?, ); - if let Some(cookie) = access_token_cookie { - let mut cookie = cookie.clone(); - cookie.make_removal(); - response - .headers_mut() - .append("Set-Cookie", HeaderValue::from_str(&cookie.to_string())?); - } - return Ok(response); } let current_time = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64; - let mut response = StatusCode::OK.into_response(); - if lifetime > 1987200 { let new_refresh_token = generate_token::<32>()?; @@ -94,13 +83,7 @@ pub async fn post( .await { Ok(_) => { - response.headers_mut().append( - "Set-Cookie", - HeaderValue::from_str( - &new_refresh_token_cookie(&app_state.config, new_refresh_token) - .to_string(), - )?, - ); + refresh_token = new_refresh_token; } Err(error) => { error!("{error}"); @@ -119,13 +102,16 @@ pub async fn post( .execute(&mut conn) .await?; - + let mut response = (StatusCode::OK, Json(Response { access_token })).into_response(); + + // TODO: Dont set this when refresh token is unchanged response.headers_mut().append( "Set-Cookie", HeaderValue::from_str( - &new_access_token_cookie(access_token).to_string(), + &new_refresh_token_cookie(&app_state.config, refresh_token).to_string(), )?, ); + return Ok(response); } @@ -138,13 +124,5 @@ pub async fn post( HeaderValue::from_str(&refresh_token_cookie.to_string())?, ); - if let Some(cookie) = access_token_cookie { - let mut cookie = cookie.clone(); - cookie.make_removal(); - response - .headers_mut() - .append("Set-Cookie", HeaderValue::from_str(&cookie.to_string())?); - } - Ok(response) } diff --git a/src/api/v1/auth/register.rs b/src/api/v1/auth/register.rs index 9f05b04..c190821 100644 --- a/src/api/v1/auth/register.rs +++ b/src/api/v1/auth/register.rs @@ -18,6 +18,7 @@ use diesel_async::RunQueryDsl; use serde::{Deserialize, Serialize}; use uuid::Uuid; +use super::Response; use crate::{ AppState, error::Error, @@ -28,7 +29,7 @@ use crate::{ users::{self, dsl as udsl}, }, utils::{ - EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX, generate_token, new_access_token_cookie, + EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX, generate_token, new_refresh_token_cookie, }, }; @@ -160,7 +161,7 @@ pub async fn post( Member::new(&app_state, uuid, initial_guild).await?; } - let mut response = StatusCode::OK.into_response(); + let mut response = (StatusCode::OK, Json(Response {access_token})).into_response(); response.headers_mut().append( "Set-Cookie", @@ -168,12 +169,6 @@ pub async fn post( &new_refresh_token_cookie(&app_state.config, refresh_token).to_string(), )?, ); - response.headers_mut().append( - "Set-Cookie", - HeaderValue::from_str( - &new_access_token_cookie(access_token).to_string(), - )?, - ); return Ok(response); } diff --git a/src/main.rs b/src/main.rs index 73110c4..ab37924 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ use error::Error; use objects::MailClient; use socketioxide::SocketIo; use std::{sync::Arc, time::SystemTime}; -use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin, CorsLayer}; +use tower_http::cors::{AllowOrigin, CorsLayer}; mod config; use config::{Config, ConfigBuilder}; use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations}; diff --git a/src/utils.rs b/src/utils.rs index 7cda5b3..0f986a2 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -43,15 +43,6 @@ pub fn new_refresh_token_cookie(config: &Config, refresh_token: String) -> Cooki .build() } -pub fn new_access_token_cookie<'cookie>(access_token: String) -> Cookie<'cookie> { - Cookie::build(("access_token", access_token)) - .http_only(false) - .secure(true) - .same_site(SameSite::None) - .max_age(Duration::hours(1)) - .build() -} - pub fn generate_token() -> Result { let mut buf = [0u8; N]; fill(&mut buf)?; -- 2.47.3 From dada230e08426590c13e1edb1215f4331f9a42fe Mon Sep 17 00:00:00 2001 From: Radical Date: Sun, 20 Jul 2025 13:04:08 +0200 Subject: [PATCH 14/19] fix: remove the rest of the leftover code from access_token cookies --- src/api/v1/auth/logout.rs | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/api/v1/auth/logout.rs b/src/api/v1/auth/logout.rs index 906afcc..977d452 100644 --- a/src/api/v1/auth/logout.rs +++ b/src/api/v1/auth/logout.rs @@ -38,8 +38,6 @@ pub async fn res( ))? .to_owned(); - let access_token_cookie = jar.get("access_token"); - let refresh_token = String::from(refresh_token_cookie.value_trimmed()); let mut conn = app_state.pool.get().await?; @@ -63,13 +61,5 @@ pub async fn res( HeaderValue::from_str(&refresh_token_cookie.to_string())?, ); - if let Some(cookie) = access_token_cookie { - let mut cookie = cookie.clone(); - cookie.make_removal(); - response - .headers_mut() - .append("Set-Cookie", HeaderValue::from_str(&cookie.to_string())?); - } - Ok(response) } -- 2.47.3 From 1ad88725bd4f09d45483cc0f7cc575f0f3f30410 Mon Sep 17 00:00:00 2001 From: Radical Date: Sun, 20 Jul 2025 14:12:57 +0200 Subject: [PATCH 15/19] feat: use custom middleware for authorization --- src/api/mod.rs | 6 ++--- src/api/v1/auth/devices.rs | 20 ++++----------- src/api/v1/auth/mod.rs | 35 ++++++++++++++++++++------ src/api/v1/auth/revoke.rs | 17 +++---------- src/api/v1/auth/verify_email.rs | 24 +++++------------- src/api/v1/channels/mod.rs | 5 ---- src/api/v1/channels/uuid/messages.rs | 23 +++-------------- src/api/v1/channels/uuid/mod.rs | 36 +++++++-------------------- src/api/v1/guilds/mod.rs | 30 +++++----------------- src/api/v1/guilds/uuid/channels.rs | 31 +++++------------------ src/api/v1/guilds/uuid/invites/mod.rs | 27 ++++++-------------- src/api/v1/guilds/uuid/members.rs | 19 ++++---------- src/api/v1/guilds/uuid/mod.rs | 28 ++++++--------------- src/api/v1/guilds/uuid/roles/mod.rs | 25 ++++++------------- src/api/v1/guilds/uuid/roles/uuid.rs | 19 ++++---------- src/api/v1/invites/id.rs | 20 +++++---------- src/api/v1/me/friends/mod.rs | 29 ++++++--------------- src/api/v1/me/friends/uuid.rs | 16 ++++-------- src/api/v1/me/guilds.rs | 17 +++++-------- src/api/v1/me/mod.rs | 29 ++++++--------------- src/api/v1/mod.rs | 16 +++++++----- src/api/v1/users/mod.rs | 25 ++++--------------- src/api/v1/users/uuid.rs | 23 +++-------------- src/main.rs | 2 +- 24 files changed, 157 insertions(+), 365 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index a00d1e5..988ee45 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -2,15 +2,15 @@ use std::sync::Arc; -use axum::{Router, routing::get}; +use axum::{routing::get, Router}; use crate::AppState; mod v1; mod versions; -pub fn router(path: &str) -> Router> { +pub fn router(path: &str, app_state: Arc) -> Router> { Router::new() .route(&format!("{path}/versions"), get(versions::versions)) - .nest(&format!("{path}/v1"), v1::router()) + .nest(&format!("{path}/v1"), v1::router(app_state)) } diff --git a/src/api/v1/auth/devices.rs b/src/api/v1/auth/devices.rs index a3c12d1..336a52f 100644 --- a/src/api/v1/auth/devices.rs +++ b/src/api/v1/auth/devices.rs @@ -2,20 +2,14 @@ use std::sync::Arc; -use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, -}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; use diesel::{ExpressionMethods, QueryDsl, Queryable, Selectable, SelectableHelper}; use diesel_async::RunQueryDsl; use serde::Serialize; +use uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - schema::refresh_tokens::{self, dsl}, + api::v1::auth::CurrentUser, error::Error, schema::refresh_tokens::{self, dsl}, AppState }; #[derive(Serialize, Selectable, Queryable)] @@ -42,16 +36,12 @@ struct Device { /// ``` pub async fn get( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - let devices: Vec = dsl::refresh_tokens .filter(dsl::uuid.eq(uuid)) .select(Device::as_select()) - .get_results(&mut conn) + .get_results(&mut app_state.pool.get().await?) .await?; Ok((StatusCode::OK, Json(devices))) diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index 59d7a8e..899d6d2 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -4,9 +4,9 @@ use std::{ }; use axum::{ - Router, - routing::{delete, get, post}, + extract::{Request, State}, middleware::{from_fn_with_state, Next}, response::IntoResponse, routing::{delete, get, post}, Router }; +use axum_extra::{headers::{authorization::Bearer, Authorization}, TypedHeader}; use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; use serde::Serialize; @@ -30,18 +30,22 @@ pub struct Response { } -pub fn router() -> Router> { +pub fn router(app_state: Arc) -> Router> { + let router_with_auth = Router::new() + .route("/verify-email", get(verify_email::get)) + .route("/verify-email", post(verify_email::post)) + .route("/revoke", post(revoke::post)) + .route("/devices", get(devices::get)) + .layer(from_fn_with_state(app_state, CurrentUser::check_auth_layer)); + Router::new() .route("/register", post(register::post)) .route("/login", post(login::response)) .route("/logout", delete(logout::res)) .route("/refresh", post(refresh::post)) - .route("/revoke", post(revoke::post)) - .route("/verify-email", get(verify_email::get)) - .route("/verify-email", post(verify_email::post)) .route("/reset-password", get(reset_password::get)) .route("/reset-password", post(reset_password::post)) - .route("/devices", get(devices::get)) + .merge(router_with_auth) } pub async fn check_access_token(access_token: &str, conn: &mut Conn) -> Result { @@ -68,3 +72,20 @@ pub async fn check_access_token(access_token: &str, conn: &mut Conn) -> Result(pub Uuid); + +impl CurrentUser { + pub async fn check_auth_layer( + State(app_state): State>, + TypedHeader(auth): TypedHeader>, + mut req: Request, + next: Next + ) -> Result { + let current_user = CurrentUser(check_access_token(auth.token(), &mut app_state.pool.get().await?).await?); + + req.extensions_mut().insert(current_user); + Ok(next.run(req).await) + } +} diff --git a/src/api/v1/auth/revoke.rs b/src/api/v1/auth/revoke.rs index 50aa6d2..b59172e 100644 --- a/src/api/v1/auth/revoke.rs +++ b/src/api/v1/auth/revoke.rs @@ -1,21 +1,14 @@ use std::sync::Arc; use argon2::{PasswordHash, PasswordVerifier}; -use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; -use axum_extra::{ - TypedHeader, - headers::authorization::{Authorization, Bearer}, -}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; use diesel::{ExpressionMethods, QueryDsl, delete}; use diesel_async::RunQueryDsl; use serde::Deserialize; +use uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - schema::refresh_tokens::{self, dsl as rdsl}, - schema::users::dsl as udsl, + api::v1::auth::CurrentUser, error::Error, schema::{refresh_tokens::{self, dsl as rdsl}, users::dsl as udsl}, AppState }; #[derive(Deserialize)] @@ -28,13 +21,11 @@ pub struct RevokeRequest { #[axum::debug_handler] pub async fn post( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(revoke_request): Json, ) -> Result { let mut conn = app_state.pool.get().await?; - let uuid = check_access_token(auth.token(), &mut conn).await?; - let database_password: String = udsl::users .filter(udsl::uuid.eq(uuid)) .select(udsl::password) diff --git a/src/api/v1/auth/verify_email.rs b/src/api/v1/auth/verify_email.rs index 28aa1ab..1270966 100644 --- a/src/api/v1/auth/verify_email.rs +++ b/src/api/v1/auth/verify_email.rs @@ -5,20 +5,14 @@ use std::sync::Arc; use axum::{ extract::{Query, State}, http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + response::IntoResponse, Extension, }; use chrono::{Duration, Utc}; use serde::Deserialize; +use uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{EmailToken, Me}, + api::v1::auth::CurrentUser, error::Error, objects::{EmailToken, Me}, AppState }; #[derive(Deserialize)] @@ -47,12 +41,10 @@ pub struct QueryParams { pub async fn get( State(app_state): State>, Query(query): Query, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension> ) -> Result { let mut conn = app_state.pool.get().await?; - let uuid = check_access_token(auth.token(), &mut conn).await?; - let me = Me::get(&mut conn, uuid).await?; if me.email_verified { @@ -87,13 +79,9 @@ pub async fn get( /// pub async fn post( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension> ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - - let me = Me::get(&mut conn, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; if me.email_verified { return Ok(StatusCode::NO_CONTENT); diff --git a/src/api/v1/channels/mod.rs b/src/api/v1/channels/mod.rs index dc82b86..24b62f7 100644 --- a/src/api/v1/channels/mod.rs +++ b/src/api/v1/channels/mod.rs @@ -11,14 +11,9 @@ use crate::AppState; mod uuid; pub fn router() -> Router> { - //let (layer, io) = SocketIo::new_layer(); - - //io.ns("/{uuid}/socket", uuid::socket::ws); - Router::new() .route("/{uuid}", get(uuid::get)) .route("/{uuid}", delete(uuid::delete)) .route("/{uuid}", patch(uuid::patch)) .route("/{uuid}/messages", get(uuid::messages::get)) - //.layer(layer) } diff --git a/src/api/v1/channels/uuid/messages.rs b/src/api/v1/channels/uuid/messages.rs index 8c12ee0..0297bbc 100644 --- a/src/api/v1/channels/uuid/messages.rs +++ b/src/api/v1/channels/uuid/messages.rs @@ -3,22 +3,11 @@ use std::sync::Arc; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{Channel, Member}, - utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member}, utils::global_checks, AppState }; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, Query, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use serde::Deserialize; @@ -62,17 +51,13 @@ pub async fn get( State(app_state): State>, Path(channel_uuid): Path, Query(message_request): Query, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; let messages = channel .fetch_messages(&app_state, message_request.amount, message_request.offset) diff --git a/src/api/v1/channels/uuid/mod.rs b/src/api/v1/channels/uuid/mod.rs index 3ce91c3..c1560f0 100644 --- a/src/api/v1/channels/uuid/mod.rs +++ b/src/api/v1/channels/uuid/mod.rs @@ -7,38 +7,28 @@ use std::sync::Arc; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member, Permissions}, utils::global_checks, }; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; + use serde::Deserialize; use uuid::Uuid; pub async fn get( State(app_state): State>, Path(channel_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; Ok((StatusCode::OK, Json(channel))) } @@ -46,17 +36,13 @@ pub async fn get( pub async fn delete( State(app_state): State>, Path(channel_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) @@ -108,18 +94,14 @@ pub struct NewInfo { pub async fn patch( State(app_state): State>, Path(channel_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(new_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; let mut channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let member = Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; + let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) diff --git a/src/api/v1/guilds/mod.rs b/src/api/v1/guilds/mod.rs index 18a117f..dbee589 100644 --- a/src/api/v1/guilds/mod.rs +++ b/src/api/v1/guilds/mod.rs @@ -3,26 +3,15 @@ use std::sync::Arc; use axum::{ - Json, Router, - extract::State, - http::StatusCode, - response::IntoResponse, - routing::{get, post}, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::State, http::StatusCode, response::IntoResponse, routing::{get, post}, Extension, Json, Router }; use serde::Deserialize; +use ::uuid::Uuid; mod uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{Guild, StartAmountQuery}, - utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::{Guild, StartAmountQuery}, utils::global_checks, AppState }; #[derive(Deserialize)] @@ -63,14 +52,10 @@ pub fn router() -> Router> { /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps pub async fn new( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(guild_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - - let guild = Guild::new(&mut conn, guild_info.name.clone(), uuid).await?; + let guild = Guild::new(&mut app_state.pool.get().await?, guild_info.name.clone(), uuid).await?; Ok((StatusCode::OK, Json(guild))) } @@ -124,15 +109,12 @@ pub async fn new( /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps pub async fn get_guilds( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(request_query): Json, ) -> Result { let start = request_query.start.unwrap_or(0); - let amount = request_query.amount.unwrap_or(10); - let uuid = check_access_token(auth.token(), &mut app_state.pool.get().await?).await?; - global_checks(&app_state, uuid).await?; let guilds = Guild::fetch_amount(&app_state.pool, start, amount).await?; diff --git a/src/api/v1/guilds/uuid/channels.rs b/src/api/v1/guilds/uuid/channels.rs index 0104566..a28aa6c 100644 --- a/src/api/v1/guilds/uuid/channels.rs +++ b/src/api/v1/guilds/uuid/channels.rs @@ -2,23 +2,12 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use serde::Deserialize; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{Channel, Member, Permissions}, - utils::{global_checks, order_by_is_above}, + api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member, Permissions}, utils::{global_checks, order_by_is_above}, AppState }; #[derive(Deserialize)] @@ -30,15 +19,11 @@ pub struct ChannelInfo { pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - Member::check_membership(&mut conn, uuid, guild_uuid).await?; + Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; if let Ok(cache_hit) = app_state .get_cache_key(format!("{guild_uuid}_channels")) @@ -65,16 +50,12 @@ pub async fn get( pub async fn create( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(channel_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; + let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) diff --git a/src/api/v1/guilds/uuid/invites/mod.rs b/src/api/v1/guilds/uuid/invites/mod.rs index 7703cf7..2070452 100644 --- a/src/api/v1/guilds/uuid/invites/mod.rs +++ b/src/api/v1/guilds/uuid/invites/mod.rs @@ -1,21 +1,14 @@ use std::sync::Arc; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use serde::Deserialize; use uuid::Uuid; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Guild, Member, Permissions}, utils::global_checks, @@ -29,14 +22,12 @@ pub struct InviteRequest { pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -49,15 +40,13 @@ pub async fn get( pub async fn create( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(invite_request): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member diff --git a/src/api/v1/guilds/uuid/members.rs b/src/api/v1/guilds/uuid/members.rs index bd2f853..6c8b980 100644 --- a/src/api/v1/guilds/uuid/members.rs +++ b/src/api/v1/guilds/uuid/members.rs @@ -2,19 +2,12 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Me, Member}, utils::global_checks, @@ -23,14 +16,12 @@ use crate::{ pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let me = Me::get(&mut conn, uuid).await?; diff --git a/src/api/v1/guilds/uuid/mod.rs b/src/api/v1/guilds/uuid/mod.rs index 0a27123..c5a809f 100644 --- a/src/api/v1/guilds/uuid/mod.rs +++ b/src/api/v1/guilds/uuid/mod.rs @@ -3,15 +3,7 @@ use std::sync::Arc; use axum::{ - Json, Router, - extract::{Multipart, Path, State}, - http::StatusCode, - response::IntoResponse, - routing::{get, patch, post}, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Multipart, Path, State}, http::StatusCode, response::IntoResponse, routing::{get, patch, post}, Extension, Json, Router }; use bytes::Bytes; use uuid::Uuid; @@ -23,7 +15,7 @@ mod roles; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Guild, Member, Permissions}, utils::global_checks, @@ -84,14 +76,12 @@ pub fn router() -> Router> { pub async fn get_guild( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; let guild = Guild::fetch_one(&mut conn, guild_uuid).await?; @@ -105,15 +95,13 @@ pub async fn get_guild( pub async fn edit( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, mut multipart: Multipart, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member diff --git a/src/api/v1/guilds/uuid/roles/mod.rs b/src/api/v1/guilds/uuid/roles/mod.rs index 12960c2..5331143 100644 --- a/src/api/v1/guilds/uuid/roles/mod.rs +++ b/src/api/v1/guilds/uuid/roles/mod.rs @@ -2,20 +2,13 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use serde::Deserialize; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Member, Permissions, Role}, utils::{global_checks, order_by_is_above}, @@ -31,11 +24,11 @@ pub struct RoleInfo { pub async fn get( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; + global_checks(&app_state, uuid).await?; - let uuid = check_access_token(auth.token(), &mut conn).await?; + let mut conn = app_state.pool.get().await?; Member::check_membership(&mut conn, uuid, guild_uuid).await?; @@ -57,15 +50,13 @@ pub async fn get( pub async fn create( State(app_state): State>, Path(guild_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(role_info): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let member = Member::check_membership(&mut conn, uuid, guild_uuid).await?; member diff --git a/src/api/v1/guilds/uuid/roles/uuid.rs b/src/api/v1/guilds/uuid/roles/uuid.rs index a62a5b4..91300bf 100644 --- a/src/api/v1/guilds/uuid/roles/uuid.rs +++ b/src/api/v1/guilds/uuid/roles/uuid.rs @@ -2,19 +2,12 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Member, Role}, utils::global_checks, @@ -23,14 +16,12 @@ use crate::{ pub async fn get( State(app_state): State>, Path((guild_uuid, role_uuid)): Path<(Uuid, Uuid)>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + Member::check_membership(&mut conn, uuid, guild_uuid).await?; if let Ok(cache_hit) = app_state.get_cache_key(format!("{role_uuid}")).await { diff --git a/src/api/v1/invites/id.rs b/src/api/v1/invites/id.rs index b832557..c752177 100644 --- a/src/api/v1/invites/id.rs +++ b/src/api/v1/invites/id.rs @@ -1,19 +1,13 @@ use std::sync::Arc; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; +use uuid::Uuid; use crate::{ AppState, - api::v1::auth::check_access_token, + api::v1::auth::CurrentUser, error::Error, objects::{Guild, Invite, Member}, utils::global_checks, @@ -35,14 +29,12 @@ pub async fn get( pub async fn join( State(app_state): State>, Path(invite_id): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let invite = Invite::fetch_one(&mut conn, invite_id).await?; let guild = Guild::fetch_one(&mut conn, invite.guild_uuid).await?; diff --git a/src/api/v1/me/friends/mod.rs b/src/api/v1/me/friends/mod.rs index 8a7851c..63284a8 100644 --- a/src/api/v1/me/friends/mod.rs +++ b/src/api/v1/me/friends/mod.rs @@ -1,34 +1,23 @@ use std::sync::Arc; -use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, -}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; use serde::Deserialize; +use ::uuid::Uuid; pub mod uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::Me, - utils::{global_checks, user_uuid_from_username}, + api::v1::auth::CurrentUser, error::Error, objects::Me, utils::{global_checks, user_uuid_from_username}, AppState }; /// Returns a list of users that are your friends pub async fn get( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - let me = Me::get(&mut conn, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; let friends = me.get_friends(&app_state).await?; @@ -61,15 +50,13 @@ pub struct UserReq { /// pub async fn post( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, Json(user_request): Json, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let me = Me::get(&mut conn, uuid).await?; let target_uuid = user_uuid_from_username(&mut conn, &user_request.username).await?; diff --git a/src/api/v1/me/friends/uuid.rs b/src/api/v1/me/friends/uuid.rs index 8d40f26..5a32386 100644 --- a/src/api/v1/me/friends/uuid.rs +++ b/src/api/v1/me/friends/uuid.rs @@ -3,29 +3,23 @@ use std::sync::Arc; use axum::{ extract::{Path, State}, http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + response::IntoResponse, Extension, }; use uuid::Uuid; use crate::{ - AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, + AppState, api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, }; pub async fn delete( State(app_state): State>, Path(friend_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let me = Me::get(&mut conn, uuid).await?; me.remove_friend(&mut conn, friend_uuid).await?; diff --git a/src/api/v1/me/guilds.rs b/src/api/v1/me/guilds.rs index adfe845..a2d2111 100644 --- a/src/api/v1/me/guilds.rs +++ b/src/api/v1/me/guilds.rs @@ -2,14 +2,11 @@ use std::sync::Arc; -use axum::{Json, extract::State, http::StatusCode, response::IntoResponse}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, -}; +use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; +use uuid::Uuid; use crate::{ - AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, + AppState, api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, }; /// `GET /api/v1/me/guilds` Returns all guild memberships in a list @@ -59,14 +56,12 @@ use crate::{ /// NOTE: UUIDs in this response are made using `uuidgen`, UUIDs made by the actual backend will be UUIDv7 and have extractable timestamps pub async fn get( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; + let mut conn = app_state.pool.get().await?; + let me = Me::get(&mut conn, uuid).await?; let memberships = me.fetch_memberships(&mut conn).await?; diff --git a/src/api/v1/me/mod.rs b/src/api/v1/me/mod.rs index e9680bc..ce577d4 100644 --- a/src/api/v1/me/mod.rs +++ b/src/api/v1/me/mod.rs @@ -1,21 +1,14 @@ use std::sync::Arc; use axum::{ - Json, Router, - extract::{DefaultBodyLimit, Multipart, State}, - http::StatusCode, - response::IntoResponse, - routing::{delete, get, patch, post}, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{DefaultBodyLimit, Multipart, State}, http::StatusCode, response::IntoResponse, routing::{delete, get, patch, post}, Extension, Json, Router }; use bytes::Bytes; use serde::Deserialize; +use uuid::Uuid; use crate::{ - AppState, api::v1::auth::check_access_token, error::Error, objects::Me, utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, AppState }; mod friends; @@ -38,13 +31,9 @@ pub fn router() -> Router> { pub async fn get_me( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - - let me = Me::get(&mut conn, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; Ok((StatusCode::OK, Json(me))) } @@ -60,13 +49,9 @@ struct NewInfo { pub async fn update( State(app_state): State>, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, mut multipart: Multipart, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - let mut json_raw: Option = None; let mut avatar: Option = None; @@ -88,7 +73,7 @@ pub async fn update( global_checks(&app_state, uuid).await?; } - let mut me = Me::get(&mut conn, uuid).await?; + let mut me = Me::get(&mut app_state.pool.get().await?, uuid).await?; if let Some(avatar) = avatar { me.set_avatar(&app_state, app_state.config.bunny.cdn_url.clone(), avatar) diff --git a/src/api/v1/mod.rs b/src/api/v1/mod.rs index 4e8654b..f3e4305 100644 --- a/src/api/v1/mod.rs +++ b/src/api/v1/mod.rs @@ -2,9 +2,9 @@ use std::sync::Arc; -use axum::{routing::get, Router}; +use axum::{middleware::from_fn_with_state, routing::get, Router}; -use crate::AppState; +use crate::{api::v1::auth::CurrentUser, AppState}; mod auth; mod channels; @@ -14,13 +14,17 @@ mod me; mod stats; mod users; -pub fn router() -> Router> { - Router::new() - .route("/stats", get(stats::res)) - .nest("/auth", auth::router()) +pub fn router(app_state: Arc) -> Router> { + let router_with_auth = Router::new() .nest("/users", users::router()) .nest("/channels", channels::router()) .nest("/guilds", guilds::router()) .nest("/invites", invites::router()) .nest("/me", me::router()) + .layer(from_fn_with_state(app_state.clone(), CurrentUser::check_auth_layer)); + + Router::new() + .route("/stats", get(stats::res)) + .nest("/auth", auth::router(app_state)) + .merge(router_with_auth) } diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index f0d09c5..82f2125 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -3,23 +3,12 @@ use std::sync::Arc; use axum::{ - Json, Router, - extract::{Query, State}, - http::StatusCode, - response::IntoResponse, - routing::get, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Query, State}, http::StatusCode, response::IntoResponse, routing::get, Extension, Json, Router }; +use ::uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{StartAmountQuery, User}, - utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::{StartAmountQuery, User}, utils::global_checks, AppState }; mod uuid; @@ -63,7 +52,7 @@ pub fn router() -> Router> { pub async fn users( State(app_state): State>, Query(request_query): Query, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { let start = request_query.start.unwrap_or(0); @@ -73,13 +62,9 @@ pub async fn users( return Ok(StatusCode::BAD_REQUEST.into_response()); } - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - let users = User::fetch_amount(&mut conn, start, amount).await?; + let users = User::fetch_amount(&mut app_state.pool.get().await?, start, amount).await?; Ok((StatusCode::OK, Json(users)).into_response()) } diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index 1b7d43b..2bdcfac 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -3,23 +3,12 @@ use std::sync::Arc; use axum::{ - Json, - extract::{Path, State}, - http::StatusCode, - response::IntoResponse, -}; -use axum_extra::{ - TypedHeader, - headers::{Authorization, authorization::Bearer}, + extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json }; use uuid::Uuid; use crate::{ - AppState, - api::v1::auth::check_access_token, - error::Error, - objects::{Me, User}, - utils::global_checks, + api::v1::auth::CurrentUser, error::Error, objects::{Me, User}, utils::global_checks, AppState }; /// `GET /api/v1/users/{uuid}` Returns user with the given UUID @@ -41,15 +30,11 @@ use crate::{ pub async fn get( State(app_state): State>, Path(user_uuid): Path, - TypedHeader(auth): TypedHeader>, + Extension(CurrentUser(uuid)): Extension>, ) -> Result { - let mut conn = app_state.pool.get().await?; - - let uuid = check_access_token(auth.token(), &mut conn).await?; - global_checks(&app_state, uuid).await?; - let me = Me::get(&mut conn, uuid).await?; + let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; let user = User::fetch_one_with_friendship(&app_state, &me, user_uuid).await?; diff --git a/src/main.rs b/src/main.rs index ab37924..8e6effc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -163,7 +163,7 @@ async fn main() -> Result<(), Error> { // build our application with a route let app = Router::new() // `GET /` goes to `root` - .merge(api::router(web.backend_url.path().trim_end_matches("/"))) + .merge(api::router(web.backend_url.path().trim_end_matches("/"), app_state.clone())) .with_state(app_state) .layer(cors) .layer(socket_io); -- 2.47.3 From a602c2624f27f846a527b453f2caa0f822af29d2 Mon Sep 17 00:00:00 2001 From: Radical Date: Sun, 20 Jul 2025 16:30:46 +0200 Subject: [PATCH 16/19] style: cargo fmt & clippy fixes --- src/api/mod.rs | 2 +- src/api/v1/auth/devices.rs | 7 +++++-- src/api/v1/auth/login.rs | 13 ++++++++++--- src/api/v1/auth/mod.rs | 18 +++++++++++------ src/api/v1/auth/refresh.rs | 16 +++++++++++---- src/api/v1/auth/register.rs | 21 +++++++++++++------- src/api/v1/auth/revoke.rs | 10 ++++++++-- src/api/v1/auth/verify_email.rs | 12 ++++++++---- src/api/v1/channels/uuid/messages.rs | 11 +++++++++-- src/api/v1/channels/uuid/mod.rs | 13 ++++++++++--- src/api/v1/guilds/mod.rs | 21 ++++++++++++++++---- src/api/v1/guilds/uuid/channels.rs | 14 +++++++++++--- src/api/v1/guilds/uuid/invites/mod.rs | 5 ++++- src/api/v1/guilds/uuid/members.rs | 5 ++++- src/api/v1/guilds/uuid/mod.rs | 6 +++++- src/api/v1/guilds/uuid/roles/mod.rs | 5 ++++- src/api/v1/guilds/uuid/roles/uuid.rs | 5 ++++- src/api/v1/invites/id.rs | 5 ++++- src/api/v1/me/friends/mod.rs | 10 +++++++--- src/api/v1/me/friends/uuid.rs | 3 ++- src/api/v1/me/guilds.rs | 2 +- src/api/v1/me/mod.rs | 8 ++++++-- src/api/v1/mod.rs | 9 ++++++--- src/api/v1/users/mod.rs | 16 +++++++++++---- src/api/v1/users/uuid.rs | 11 +++++++++-- src/main.rs | 28 ++++++++++++++++----------- src/socket.rs | 14 ++++++++------ src/utils.rs | 6 +++--- src/wordlist.rs | 6 +++--- 29 files changed, 216 insertions(+), 86 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index 988ee45..5aaa8a5 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; -use axum::{routing::get, Router}; +use axum::{Router, routing::get}; use crate::AppState; diff --git a/src/api/v1/auth/devices.rs b/src/api/v1/auth/devices.rs index 336a52f..35fe957 100644 --- a/src/api/v1/auth/devices.rs +++ b/src/api/v1/auth/devices.rs @@ -2,14 +2,17 @@ use std::sync::Arc; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; +use axum::{Extension, Json, extract::State, http::StatusCode, response::IntoResponse}; use diesel::{ExpressionMethods, QueryDsl, Queryable, Selectable, SelectableHelper}; use diesel_async::RunQueryDsl; use serde::Serialize; use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, schema::refresh_tokens::{self, dsl}, AppState + AppState, + api::v1::auth::CurrentUser, + error::Error, + schema::refresh_tokens::{self, dsl}, }; #[derive(Serialize, Selectable, Queryable)] diff --git a/src/api/v1/auth/login.rs b/src/api/v1/auth/login.rs index 61cb6a0..22cc838 100644 --- a/src/api/v1/auth/login.rs +++ b/src/api/v1/auth/login.rs @@ -20,8 +20,8 @@ use crate::{ error::Error, schema::*, utils::{ - PASSWORD_REGEX, generate_token, new_refresh_token_cookie, - user_uuid_from_identifier, generate_device_name + PASSWORD_REGEX, generate_device_name, generate_token, new_refresh_token_cookie, + user_uuid_from_identifier, }, }; @@ -95,7 +95,14 @@ pub async fn response( .execute(&mut conn) .await?; - let mut response = (StatusCode::OK, Json(Response { access_token, device_name })).into_response(); + let mut response = ( + StatusCode::OK, + Json(Response { + access_token, + device_name, + }), + ) + .into_response(); response.headers_mut().append( "Set-Cookie", diff --git a/src/api/v1/auth/mod.rs b/src/api/v1/auth/mod.rs index c579899..9a72f11 100644 --- a/src/api/v1/auth/mod.rs +++ b/src/api/v1/auth/mod.rs @@ -4,9 +4,16 @@ use std::{ }; use axum::{ - extract::{Request, State}, middleware::{from_fn_with_state, Next}, response::IntoResponse, routing::{delete, get, post}, Router + Router, + extract::{Request, State}, + middleware::{Next, from_fn_with_state}, + response::IntoResponse, + routing::{delete, get, post}, +}; +use axum_extra::{ + TypedHeader, + headers::{Authorization, authorization::Bearer}, }; -use axum_extra::{headers::{authorization::Bearer, Authorization}, TypedHeader}; use diesel::{ExpressionMethods, QueryDsl}; use diesel_async::RunQueryDsl; use serde::Serialize; @@ -23,14 +30,12 @@ mod reset_password; mod revoke; mod verify_email; - #[derive(Serialize)] pub struct Response { access_token: String, device_name: String, } - pub fn router(app_state: Arc) -> Router> { let router_with_auth = Router::new() .route("/verify-email", get(verify_email::get)) @@ -82,9 +87,10 @@ impl CurrentUser { State(app_state): State>, TypedHeader(auth): TypedHeader>, mut req: Request, - next: Next + next: Next, ) -> Result { - let current_user = CurrentUser(check_access_token(auth.token(), &mut app_state.pool.get().await?).await?); + let current_user = + CurrentUser(check_access_token(auth.token(), &mut app_state.pool.get().await?).await?); req.extensions_mut().insert(current_user); Ok(next.run(req).await) diff --git a/src/api/v1/auth/refresh.rs b/src/api/v1/auth/refresh.rs index ee4f7ae..4b96226 100644 --- a/src/api/v1/auth/refresh.rs +++ b/src/api/v1/auth/refresh.rs @@ -1,7 +1,8 @@ use axum::{ + Json, extract::State, http::{HeaderValue, StatusCode}, - response::IntoResponse, Json, + response::IntoResponse, }; use axum_extra::extract::CookieJar; use diesel::{ExpressionMethods, QueryDsl, delete, update}; @@ -19,7 +20,8 @@ use crate::{ schema::{ access_tokens::{self, dsl}, refresh_tokens::{self, dsl as rdsl}, - }, utils::{generate_token, new_refresh_token_cookie} + }, + utils::{generate_token, new_refresh_token_cookie}, }; pub async fn post( @@ -104,7 +106,14 @@ pub async fn post( .execute(&mut conn) .await?; - let mut response = (StatusCode::OK, Json(Response { access_token, device_name })).into_response(); + let mut response = ( + StatusCode::OK, + Json(Response { + access_token, + device_name, + }), + ) + .into_response(); // TODO: Dont set this when refresh token is unchanged response.headers_mut().append( @@ -113,7 +122,6 @@ pub async fn post( &new_refresh_token_cookie(&app_state.config, refresh_token).to_string(), )?, ); - return Ok(response); } diff --git a/src/api/v1/auth/register.rs b/src/api/v1/auth/register.rs index f2520bf..807fab8 100644 --- a/src/api/v1/auth/register.rs +++ b/src/api/v1/auth/register.rs @@ -29,8 +29,8 @@ use crate::{ users::{self, dsl as udsl}, }, utils::{ - EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX, generate_token, - new_refresh_token_cookie, generate_device_name + EMAIL_REGEX, PASSWORD_REGEX, USERNAME_REGEX, generate_device_name, generate_token, + new_refresh_token_cookie, }, }; @@ -137,11 +137,11 @@ pub async fn post( let current_time = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64; let device_name = generate_device_name(); - + insert_into(refresh_tokens::table) - .values(( - rdsl::token.eq(&refresh_token), - rdsl::uuid.eq(uuid), + .values(( + rdsl::token.eq(&refresh_token), + rdsl::uuid.eq(uuid), rdsl::created_at.eq(current_time), rdsl::device_name.eq(&device_name), )) @@ -162,7 +162,14 @@ pub async fn post( Member::new(&app_state, uuid, initial_guild).await?; } - let mut response = (StatusCode::OK, Json(Response {access_token, device_name})).into_response(); + let mut response = ( + StatusCode::OK, + Json(Response { + access_token, + device_name, + }), + ) + .into_response(); response.headers_mut().append( "Set-Cookie", diff --git a/src/api/v1/auth/revoke.rs b/src/api/v1/auth/revoke.rs index b59172e..dd87ec3 100644 --- a/src/api/v1/auth/revoke.rs +++ b/src/api/v1/auth/revoke.rs @@ -1,14 +1,20 @@ use std::sync::Arc; use argon2::{PasswordHash, PasswordVerifier}; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; +use axum::{Extension, Json, extract::State, http::StatusCode, response::IntoResponse}; use diesel::{ExpressionMethods, QueryDsl, delete}; use diesel_async::RunQueryDsl; use serde::Deserialize; use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, schema::{refresh_tokens::{self, dsl as rdsl}, users::dsl as udsl}, AppState + AppState, + api::v1::auth::CurrentUser, + error::Error, + schema::{ + refresh_tokens::{self, dsl as rdsl}, + users::dsl as udsl, + }, }; #[derive(Deserialize)] diff --git a/src/api/v1/auth/verify_email.rs b/src/api/v1/auth/verify_email.rs index 1270966..0801768 100644 --- a/src/api/v1/auth/verify_email.rs +++ b/src/api/v1/auth/verify_email.rs @@ -3,16 +3,20 @@ use std::sync::Arc; use axum::{ + Extension, extract::{Query, State}, http::StatusCode, - response::IntoResponse, Extension, + response::IntoResponse, }; use chrono::{Duration, Utc}; use serde::Deserialize; use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{EmailToken, Me}, AppState + AppState, + api::v1::auth::CurrentUser, + error::Error, + objects::{EmailToken, Me}, }; #[derive(Deserialize)] @@ -41,7 +45,7 @@ pub struct QueryParams { pub async fn get( State(app_state): State>, Query(query): Query, - Extension(CurrentUser(uuid)): Extension> + Extension(CurrentUser(uuid)): Extension>, ) -> Result { let mut conn = app_state.pool.get().await?; @@ -79,7 +83,7 @@ pub async fn get( /// pub async fn post( State(app_state): State>, - Extension(CurrentUser(uuid)): Extension> + Extension(CurrentUser(uuid)): Extension>, ) -> Result { let me = Me::get(&mut app_state.pool.get().await?, uuid).await?; diff --git a/src/api/v1/channels/uuid/messages.rs b/src/api/v1/channels/uuid/messages.rs index 0297bbc..b8f0ad6 100644 --- a/src/api/v1/channels/uuid/messages.rs +++ b/src/api/v1/channels/uuid/messages.rs @@ -3,11 +3,18 @@ use std::sync::Arc; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member}, utils::global_checks, AppState + AppState, + api::v1::auth::CurrentUser, + error::Error, + objects::{Channel, Member}, + utils::global_checks, }; use ::uuid::Uuid; use axum::{ - extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, Extension, Json + Extension, Json, + extract::{Path, Query, State}, + http::StatusCode, + response::IntoResponse, }; use serde::Deserialize; diff --git a/src/api/v1/channels/uuid/mod.rs b/src/api/v1/channels/uuid/mod.rs index c1560f0..5c88a29 100644 --- a/src/api/v1/channels/uuid/mod.rs +++ b/src/api/v1/channels/uuid/mod.rs @@ -13,7 +13,10 @@ use crate::{ utils::global_checks, }; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Extension, Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, }; use serde::Deserialize; @@ -42,7 +45,9 @@ pub async fn delete( let channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; + let member = + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid) + .await?; member .check_permission(&app_state, Permissions::ManageChannel) @@ -101,7 +106,9 @@ pub async fn patch( let mut channel = Channel::fetch_one(&app_state, channel_uuid).await?; - let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid).await?; + let member = + Member::check_membership(&mut app_state.pool.get().await?, uuid, channel.guild_uuid) + .await?; member .check_permission(&app_state, Permissions::ManageChannel) diff --git a/src/api/v1/guilds/mod.rs b/src/api/v1/guilds/mod.rs index dbee589..8118522 100644 --- a/src/api/v1/guilds/mod.rs +++ b/src/api/v1/guilds/mod.rs @@ -2,16 +2,24 @@ use std::sync::Arc; +use ::uuid::Uuid; use axum::{ - extract::State, http::StatusCode, response::IntoResponse, routing::{get, post}, Extension, Json, Router + Extension, Json, Router, + extract::State, + http::StatusCode, + response::IntoResponse, + routing::{get, post}, }; use serde::Deserialize; -use ::uuid::Uuid; mod uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{Guild, StartAmountQuery}, utils::global_checks, AppState + AppState, + api::v1::auth::CurrentUser, + error::Error, + objects::{Guild, StartAmountQuery}, + utils::global_checks, }; #[derive(Deserialize)] @@ -55,7 +63,12 @@ pub async fn new( Extension(CurrentUser(uuid)): Extension>, Json(guild_info): Json, ) -> Result { - let guild = Guild::new(&mut app_state.pool.get().await?, guild_info.name.clone(), uuid).await?; + let guild = Guild::new( + &mut app_state.pool.get().await?, + guild_info.name.clone(), + uuid, + ) + .await?; Ok((StatusCode::OK, Json(guild))) } diff --git a/src/api/v1/guilds/uuid/channels.rs b/src/api/v1/guilds/uuid/channels.rs index a28aa6c..836982d 100644 --- a/src/api/v1/guilds/uuid/channels.rs +++ b/src/api/v1/guilds/uuid/channels.rs @@ -2,12 +2,19 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Extension, Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, }; use serde::Deserialize; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{Channel, Member, Permissions}, utils::{global_checks, order_by_is_above}, AppState + AppState, + api::v1::auth::CurrentUser, + error::Error, + objects::{Channel, Member, Permissions}, + utils::{global_checks, order_by_is_above}, }; #[derive(Deserialize)] @@ -55,7 +62,8 @@ pub async fn create( ) -> Result { global_checks(&app_state, uuid).await?; - let member = Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; + let member = + Member::check_membership(&mut app_state.pool.get().await?, uuid, guild_uuid).await?; member .check_permission(&app_state, Permissions::ManageChannel) diff --git a/src/api/v1/guilds/uuid/invites/mod.rs b/src/api/v1/guilds/uuid/invites/mod.rs index 2070452..649fc16 100644 --- a/src/api/v1/guilds/uuid/invites/mod.rs +++ b/src/api/v1/guilds/uuid/invites/mod.rs @@ -1,7 +1,10 @@ use std::sync::Arc; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Extension, Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, }; use serde::Deserialize; use uuid::Uuid; diff --git a/src/api/v1/guilds/uuid/members.rs b/src/api/v1/guilds/uuid/members.rs index 6c8b980..3ae10f7 100644 --- a/src/api/v1/guilds/uuid/members.rs +++ b/src/api/v1/guilds/uuid/members.rs @@ -2,7 +2,10 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Extension, Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, }; use crate::{ diff --git a/src/api/v1/guilds/uuid/mod.rs b/src/api/v1/guilds/uuid/mod.rs index c5a809f..52f0b64 100644 --- a/src/api/v1/guilds/uuid/mod.rs +++ b/src/api/v1/guilds/uuid/mod.rs @@ -3,7 +3,11 @@ use std::sync::Arc; use axum::{ - extract::{Multipart, Path, State}, http::StatusCode, response::IntoResponse, routing::{get, patch, post}, Extension, Json, Router + Extension, Json, Router, + extract::{Multipart, Path, State}, + http::StatusCode, + response::IntoResponse, + routing::{get, patch, post}, }; use bytes::Bytes; use uuid::Uuid; diff --git a/src/api/v1/guilds/uuid/roles/mod.rs b/src/api/v1/guilds/uuid/roles/mod.rs index 5331143..820ef0d 100644 --- a/src/api/v1/guilds/uuid/roles/mod.rs +++ b/src/api/v1/guilds/uuid/roles/mod.rs @@ -2,7 +2,10 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Extension, Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, }; use serde::Deserialize; diff --git a/src/api/v1/guilds/uuid/roles/uuid.rs b/src/api/v1/guilds/uuid/roles/uuid.rs index 91300bf..06193a1 100644 --- a/src/api/v1/guilds/uuid/roles/uuid.rs +++ b/src/api/v1/guilds/uuid/roles/uuid.rs @@ -2,7 +2,10 @@ use std::sync::Arc; use ::uuid::Uuid; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Extension, Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, }; use crate::{ diff --git a/src/api/v1/invites/id.rs b/src/api/v1/invites/id.rs index c752177..72ceea4 100644 --- a/src/api/v1/invites/id.rs +++ b/src/api/v1/invites/id.rs @@ -1,7 +1,10 @@ use std::sync::Arc; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Extension, Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, }; use uuid::Uuid; diff --git a/src/api/v1/me/friends/mod.rs b/src/api/v1/me/friends/mod.rs index 63284a8..a56f8d4 100644 --- a/src/api/v1/me/friends/mod.rs +++ b/src/api/v1/me/friends/mod.rs @@ -1,13 +1,17 @@ use std::sync::Arc; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; -use serde::Deserialize; use ::uuid::Uuid; +use axum::{Extension, Json, extract::State, http::StatusCode, response::IntoResponse}; +use serde::Deserialize; pub mod uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::Me, utils::{global_checks, user_uuid_from_username}, AppState + AppState, + api::v1::auth::CurrentUser, + error::Error, + objects::Me, + utils::{global_checks, user_uuid_from_username}, }; /// Returns a list of users that are your friends diff --git a/src/api/v1/me/friends/uuid.rs b/src/api/v1/me/friends/uuid.rs index 5a32386..5367435 100644 --- a/src/api/v1/me/friends/uuid.rs +++ b/src/api/v1/me/friends/uuid.rs @@ -1,9 +1,10 @@ use std::sync::Arc; use axum::{ + Extension, extract::{Path, State}, http::StatusCode, - response::IntoResponse, Extension, + response::IntoResponse, }; use uuid::Uuid; diff --git a/src/api/v1/me/guilds.rs b/src/api/v1/me/guilds.rs index a2d2111..88dfad9 100644 --- a/src/api/v1/me/guilds.rs +++ b/src/api/v1/me/guilds.rs @@ -2,7 +2,7 @@ use std::sync::Arc; -use axum::{extract::State, http::StatusCode, response::IntoResponse, Extension, Json}; +use axum::{Extension, Json, extract::State, http::StatusCode, response::IntoResponse}; use uuid::Uuid; use crate::{ diff --git a/src/api/v1/me/mod.rs b/src/api/v1/me/mod.rs index ce577d4..e167d14 100644 --- a/src/api/v1/me/mod.rs +++ b/src/api/v1/me/mod.rs @@ -1,14 +1,18 @@ use std::sync::Arc; use axum::{ - extract::{DefaultBodyLimit, Multipart, State}, http::StatusCode, response::IntoResponse, routing::{delete, get, patch, post}, Extension, Json, Router + Extension, Json, Router, + extract::{DefaultBodyLimit, Multipart, State}, + http::StatusCode, + response::IntoResponse, + routing::{delete, get, patch, post}, }; use bytes::Bytes; use serde::Deserialize; use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, AppState + AppState, api::v1::auth::CurrentUser, error::Error, objects::Me, utils::global_checks, }; mod friends; diff --git a/src/api/v1/mod.rs b/src/api/v1/mod.rs index f3e4305..5ca9558 100644 --- a/src/api/v1/mod.rs +++ b/src/api/v1/mod.rs @@ -2,9 +2,9 @@ use std::sync::Arc; -use axum::{middleware::from_fn_with_state, routing::get, Router}; +use axum::{Router, middleware::from_fn_with_state, routing::get}; -use crate::{api::v1::auth::CurrentUser, AppState}; +use crate::{AppState, api::v1::auth::CurrentUser}; mod auth; mod channels; @@ -21,7 +21,10 @@ pub fn router(app_state: Arc) -> Router> { .nest("/guilds", guilds::router()) .nest("/invites", invites::router()) .nest("/me", me::router()) - .layer(from_fn_with_state(app_state.clone(), CurrentUser::check_auth_layer)); + .layer(from_fn_with_state( + app_state.clone(), + CurrentUser::check_auth_layer, + )); Router::new() .route("/stats", get(stats::res)) diff --git a/src/api/v1/users/mod.rs b/src/api/v1/users/mod.rs index 82f2125..a4b93ce 100644 --- a/src/api/v1/users/mod.rs +++ b/src/api/v1/users/mod.rs @@ -2,13 +2,21 @@ use std::sync::Arc; -use axum::{ - extract::{Query, State}, http::StatusCode, response::IntoResponse, routing::get, Extension, Json, Router -}; use ::uuid::Uuid; +use axum::{ + Extension, Json, Router, + extract::{Query, State}, + http::StatusCode, + response::IntoResponse, + routing::get, +}; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{StartAmountQuery, User}, utils::global_checks, AppState + AppState, + api::v1::auth::CurrentUser, + error::Error, + objects::{StartAmountQuery, User}, + utils::global_checks, }; mod uuid; diff --git a/src/api/v1/users/uuid.rs b/src/api/v1/users/uuid.rs index 2bdcfac..cee6df0 100644 --- a/src/api/v1/users/uuid.rs +++ b/src/api/v1/users/uuid.rs @@ -3,12 +3,19 @@ use std::sync::Arc; use axum::{ - extract::{Path, State}, http::StatusCode, response::IntoResponse, Extension, Json + Extension, Json, + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, }; use uuid::Uuid; use crate::{ - api::v1::auth::CurrentUser, error::Error, objects::{Me, User}, utils::global_checks, AppState + AppState, + api::v1::auth::CurrentUser, + error::Error, + objects::{Me, User}, + utils::global_checks, }; /// `GET /api/v1/users/{uuid}` Returns user with the given UUID diff --git a/src/main.rs b/src/main.rs index ffbfa4e..e42c8dc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,29 +1,32 @@ use argon2::Argon2; -use axum::{http::{header, Method}, Router}; +use axum::{ + Router, + http::{Method, header}, +}; use clap::Parser; +use config::{Config, ConfigBuilder}; use diesel_async::pooled_connection::AsyncDieselConnectionManager; use diesel_async::pooled_connection::deadpool::Pool; +use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations}; use error::Error; use objects::MailClient; use socketioxide::SocketIo; use std::{sync::Arc, time::SystemTime}; use tower_http::cors::{AllowOrigin, CorsLayer}; -use config::{Config, ConfigBuilder}; -use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations}; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!(); type Conn = deadpool::managed::Object>; -mod config; -mod wordlist; mod api; +mod config; pub mod error; pub mod objects; pub mod schema; -pub mod utils; mod socket; +pub mod utils; +mod wordlist; #[derive(Parser, Debug)] #[command(version, about, long_about = None)] @@ -129,9 +132,7 @@ async fn main() -> Result<(), Error> { let cors = CorsLayer::new() // Allow any origin (equivalent to allowed_origin_fn returning true) - .allow_origin(AllowOrigin::predicate(|_origin, _request_head| { - true - })) + .allow_origin(AllowOrigin::predicate(|_origin, _request_head| true)) .allow_methods(vec![ Method::GET, Method::POST, @@ -157,14 +158,19 @@ async fn main() -> Result<(), Error> { // Allow credentials .allow_credentials(true); - let (socket_io, io) = SocketIo::builder().with_state(app_state.clone()).build_layer(); + let (socket_io, io) = SocketIo::builder() + .with_state(app_state.clone()) + .build_layer(); io.ns("/", socket::on_connect); // build our application with a route let app = Router::new() // `GET /` goes to `root` - .merge(api::router(web.backend_url.path().trim_end_matches("/"), app_state.clone())) + .merge(api::router( + web.backend_url.path().trim_end_matches("/"), + app_state.clone(), + )) .with_state(app_state) .layer(cors) .layer(socket_io); diff --git a/src/socket.rs b/src/socket.rs index e00a7c0..3fcae32 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -2,24 +2,26 @@ use std::sync::Arc; use log::info; use rmpv::Value; -use socketioxide::{ - extract::{AckSender, Data, SocketRef, State}, -}; +use socketioxide::extract::{AckSender, Data, SocketRef, State}; use crate::AppState; -pub async fn on_connect(State(app_state): State>, socket: SocketRef, Data(data): Data) { +pub async fn on_connect( + State(_app_state): State>, + socket: SocketRef, + Data(data): Data, +) { socket.emit("auth", &data).ok(); socket.on("message", async |socket: SocketRef, Data::(data)| { - info!("{}", data); + info!("{data}"); socket.emit("message-back", &data).ok(); }); socket.on( "message-with-ack", async |Data::(data), ack: AckSender| { - info!("{}", data); + info!("{data}"); ack.send(&data).ok(); }, ); diff --git a/src/utils.rs b/src/utils.rs index ac8e343..e1df906 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -1,5 +1,5 @@ +use rand::seq::IndexedRandom; use std::sync::LazyLock; -use rand::{seq::IndexedRandom}; use axum::body::Bytes; use axum_extra::extract::cookie::{Cookie, SameSite}; @@ -20,7 +20,7 @@ use crate::{ error::Error, objects::{HasIsAbove, HasUuid}, schema::users, - wordlist::{ADJECTIVES, ANIMALS} + wordlist::{ADJECTIVES, ANIMALS}, }; pub static EMAIL_REGEX: LazyLock = LazyLock::new(|| { @@ -216,5 +216,5 @@ pub fn generate_device_name() -> String { let adjective = ADJECTIVES.choose(&mut rng).unwrap(); let animal = ANIMALS.choose(&mut rng).unwrap(); - return [*adjective, *animal].join(" ") + [*adjective, *animal].join(" ") } diff --git a/src/wordlist.rs b/src/wordlist.rs index 0c17723..1227c1f 100644 --- a/src/wordlist.rs +++ b/src/wordlist.rs @@ -1,4 +1,4 @@ -pub const ANIMALS: [&'static str; 223] = [ +pub const ANIMALS: [&str; 223] = [ "Aardvark", "Albatross", "Alligator", @@ -224,7 +224,7 @@ pub const ANIMALS: [&'static str; 223] = [ "Zebra", ]; -pub const ADJECTIVES: [&'static str; 765] = [ +pub const ADJECTIVES: [&str; 765] = [ "Other", "Such", "First", @@ -990,4 +990,4 @@ pub const ADJECTIVES: [&'static str; 765] = [ "Vocal", "Obscure", "Innovative", -]; \ No newline at end of file +]; -- 2.47.3 From 2fb7e7781f42c419c0d9f31de04b5a784e4c11f7 Mon Sep 17 00:00:00 2001 From: Radical Date: Sun, 20 Jul 2025 18:11:08 +0200 Subject: [PATCH 17/19] feat: reimplement old websocket --- Cargo.toml | 2 +- src/api/v1/channels/mod.rs | 3 +- src/api/v1/channels/uuid/mod.rs | 2 +- src/api/v1/channels/uuid/socket.rs | 147 ++++++++++++++++------------- src/api/v1/mod.rs | 4 +- src/error.rs | 3 + 6 files changed, 91 insertions(+), 70 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2b9962c..e0c83bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,7 @@ bindet = "0.3.2" bunny-api-tokio = { version = "0.4", features = ["edge_storage"], default-features = false } # Web Server -axum = { version = "0.8.4", features = ["macros", "multipart"] } +axum = { version = "0.8.4", features = ["macros", "multipart", "ws"] } tower-http = { version = "0.6.6", features = ["cors"] } axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"] } socketioxide = { version = "0.17.2", features = ["state"] } diff --git a/src/api/v1/channels/mod.rs b/src/api/v1/channels/mod.rs index 24b62f7..cc033af 100644 --- a/src/api/v1/channels/mod.rs +++ b/src/api/v1/channels/mod.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use axum::{ Router, - routing::{delete, get, patch}, + routing::{any, delete, get, patch}, }; //use socketioxide::SocketIo; @@ -15,5 +15,6 @@ pub fn router() -> Router> { .route("/{uuid}", get(uuid::get)) .route("/{uuid}", delete(uuid::delete)) .route("/{uuid}", patch(uuid::patch)) + .route("/{uuid}/socket", any(uuid::socket::ws)) .route("/{uuid}/messages", get(uuid::messages::get)) } diff --git a/src/api/v1/channels/uuid/mod.rs b/src/api/v1/channels/uuid/mod.rs index 5c88a29..373742e 100644 --- a/src/api/v1/channels/uuid/mod.rs +++ b/src/api/v1/channels/uuid/mod.rs @@ -1,7 +1,7 @@ //! `/api/v1/channels/{uuid}` Channel specific endpoints pub mod messages; -//pub mod socket; +pub mod socket; use std::sync::Arc; diff --git a/src/api/v1/channels/uuid/socket.rs b/src/api/v1/channels/uuid/socket.rs index 7233f39..46a7334 100644 --- a/src/api/v1/channels/uuid/socket.rs +++ b/src/api/v1/channels/uuid/socket.rs @@ -1,18 +1,21 @@ -use actix_web::{ - Error, HttpRequest, HttpResponse, get, - http::header::{HeaderValue, SEC_WEBSOCKET_PROTOCOL}, - rt, web, +use std::sync::Arc; + +use axum::{ + extract::{Path, State, WebSocketUpgrade, ws::Message}, + http::HeaderMap, + response::IntoResponse, }; -use actix_ws::AggregatedMessage; +use futures::SinkExt; use futures_util::StreamExt as _; use serde::Deserialize; use uuid::Uuid; use crate::{ - Data, + AppState, api::v1::auth::check_access_token, + error::Error, objects::{Channel, Member}, - utils::{get_ws_protocol_header, global_checks}, + utils::global_checks, }; #[derive(Deserialize)] @@ -21,100 +24,114 @@ struct MessageBody { reply_to: Option, } -#[get("/{uuid}/socket")] pub async fn ws( - req: HttpRequest, - path: web::Path<(Uuid,)>, - stream: web::Payload, - data: web::Data, -) -> Result { - // Get all headers - let headers = req.headers(); - + ws: WebSocketUpgrade, + State(app_state): State>, + Path(channel_uuid): Path, + headers: HeaderMap, +) -> Result { // Retrieve auth header - let auth_header = get_ws_protocol_header(headers)?; + let auth_token = headers.get(axum::http::header::SEC_WEBSOCKET_PROTOCOL); - // Get uuid from path - let channel_uuid = path.into_inner().0; + if auth_token.is_none() { + return Err(Error::Unauthorized( + "No authorization header provided".to_string(), + )); + } - let mut conn = data.pool.get().await.map_err(crate::error::Error::from)?; + let auth_raw = auth_token.unwrap().to_str()?; + + let mut auth = auth_raw.split_whitespace(); + + let response_proto = auth.next(); + + let auth_value = auth.next(); + + if response_proto.is_none() { + return Err(Error::BadRequest( + "Sec-WebSocket-Protocol header is empty".to_string(), + )); + } else if response_proto.is_some_and(|rp| rp != "Authorization,") { + return Err(Error::BadRequest( + "First protocol should be Authorization".to_string(), + )); + } + + if auth_value.is_none() { + return Err(Error::BadRequest("No token provided".to_string())); + } + + let auth_header = auth_value.unwrap(); + + let mut conn = app_state + .pool + .get() + .await + .map_err(crate::error::Error::from)?; // Authorize client using auth header let uuid = check_access_token(auth_header, &mut conn).await?; - global_checks(&data, uuid).await?; + global_checks(&app_state, uuid).await?; - let channel = Channel::fetch_one(&data, channel_uuid).await?; + let channel = Channel::fetch_one(&app_state, channel_uuid).await?; Member::check_membership(&mut conn, uuid, channel.guild_uuid).await?; - let (mut res, mut session_1, stream) = actix_ws::handle(&req, stream)?; - - let mut stream = stream - .aggregate_continuations() - // aggregate continuation frames up to 1MiB - .max_continuation_size(2_usize.pow(20)); - - let mut pubsub = data + let mut pubsub = app_state .cache_pool .get_async_pubsub() .await .map_err(crate::error::Error::from)?; - let mut session_2 = session_1.clone(); + let mut res = ws.on_upgrade(async move |socket| { + let (mut sender, mut receiver) = socket.split(); - rt::spawn(async move { - pubsub.subscribe(channel_uuid.to_string()).await?; - while let Some(msg) = pubsub.on_message().next().await { - let payload: String = msg.get_payload()?; - session_1.text(payload).await?; - } + tokio::spawn(async move { + pubsub.subscribe(channel_uuid.to_string()).await?; + while let Some(msg) = pubsub.on_message().next().await { + let payload: String = msg.get_payload()?; + sender.send(payload.into()).await?; + } - Ok::<(), crate::error::Error>(()) - }); - - // start task but don't wait for it - rt::spawn(async move { - // receive messages from websocket - while let Some(msg) = stream.next().await { - match msg { - Ok(AggregatedMessage::Text(text)) => { - let mut conn = data.cache_pool.get_multiplexed_tokio_connection().await?; + Ok::<(), crate::error::Error>(()) + }); + tokio::spawn(async move { + while let Some(msg) = receiver.next().await { + if let Ok(Message::Text(text)) = msg { let message_body: MessageBody = serde_json::from_str(&text)?; let message = channel - .new_message(&data, uuid, message_body.message, message_body.reply_to) + .new_message( + &app_state, + uuid, + message_body.message, + message_body.reply_to, + ) .await?; redis::cmd("PUBLISH") .arg(&[channel_uuid.to_string(), serde_json::to_string(&message)?]) - .exec_async(&mut conn) + .exec_async( + &mut app_state + .cache_pool + .get_multiplexed_tokio_connection() + .await?, + ) .await?; } - - Ok(AggregatedMessage::Binary(bin)) => { - // echo binary message - session_2.binary(bin).await?; - } - - Ok(AggregatedMessage::Ping(msg)) => { - // respond to PING frame with PONG frame - session_2.pong(&msg).await?; - } - - _ => {} } - } - Ok::<(), crate::error::Error>(()) + Ok::<(), crate::error::Error>(()) + }); }); let headers = res.headers_mut(); headers.append( - SEC_WEBSOCKET_PROTOCOL, - HeaderValue::from_str("Authorization")?, + axum::http::header::SEC_WEBSOCKET_PROTOCOL, + "Authorization".parse()?, ); // respond immediately with response connected to WS session diff --git a/src/api/v1/mod.rs b/src/api/v1/mod.rs index 5ca9558..860944c 100644 --- a/src/api/v1/mod.rs +++ b/src/api/v1/mod.rs @@ -17,7 +17,6 @@ mod users; pub fn router(app_state: Arc) -> Router> { let router_with_auth = Router::new() .nest("/users", users::router()) - .nest("/channels", channels::router()) .nest("/guilds", guilds::router()) .nest("/invites", invites::router()) .nest("/me", me::router()) @@ -28,6 +27,7 @@ pub fn router(app_state: Arc) -> Router> { Router::new() .route("/stats", get(stats::res)) - .nest("/auth", auth::router(app_state)) + .nest("/auth", auth::router(app_state.clone())) + .nest("/channels", channels::router(app_state)) .merge(router_with_auth) } diff --git a/src/error.rs b/src/error.rs index 1b8f27c..d6f7a12 100644 --- a/src/error.rs +++ b/src/error.rs @@ -83,6 +83,9 @@ pub enum Error { TooManyRequests(String), #[error("{0}")] InternalServerError(String), + // TODO: remove when doing socket.io + #[error(transparent)] + AxumError(#[from] axum::Error), } impl IntoResponse for Error { -- 2.47.3 From 8ec1610b2e8e08a9af76cc086b544e8fc68a6501 Mon Sep 17 00:00:00 2001 From: Radical Date: Sun, 20 Jul 2025 18:11:31 +0200 Subject: [PATCH 18/19] feat: remove dependency on socket.io Keeping stuff commented so we can revisit, currently just need a working version --- Cargo.toml | 4 ++-- src/api/v1/auth/revoke.rs | 1 - src/api/v1/channels/mod.rs | 15 +++++++++------ src/main.rs | 17 +++++------------ 4 files changed, 16 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e0c83bb..3decea6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,10 +38,10 @@ bindet = "0.3.2" bunny-api-tokio = { version = "0.4", features = ["edge_storage"], default-features = false } # Web Server -axum = { version = "0.8.4", features = ["macros", "multipart", "ws"] } +axum = { version = "0.8.4", features = ["multipart", "ws"] } tower-http = { version = "0.6.6", features = ["cors"] } axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"] } -socketioxide = { version = "0.17.2", features = ["state"] } +#socketioxide = { version = "0.17.2", features = ["state"] } url = { version = "2.5", features = ["serde"] } time = "0.3.41" diff --git a/src/api/v1/auth/revoke.rs b/src/api/v1/auth/revoke.rs index dd87ec3..90b96ae 100644 --- a/src/api/v1/auth/revoke.rs +++ b/src/api/v1/auth/revoke.rs @@ -24,7 +24,6 @@ pub struct RevokeRequest { } // TODO: Should maybe be a delete request? -#[axum::debug_handler] pub async fn post( State(app_state): State>, Extension(CurrentUser(uuid)): Extension>, diff --git a/src/api/v1/channels/mod.rs b/src/api/v1/channels/mod.rs index cc033af..41d029a 100644 --- a/src/api/v1/channels/mod.rs +++ b/src/api/v1/channels/mod.rs @@ -1,20 +1,23 @@ use std::sync::Arc; use axum::{ - Router, - routing::{any, delete, get, patch}, + middleware::from_fn_with_state, routing::{any, delete, get, patch}, Router }; //use socketioxide::SocketIo; -use crate::AppState; +use crate::{api::v1::auth::CurrentUser, AppState}; mod uuid; -pub fn router() -> Router> { - Router::new() +pub fn router(app_state: Arc) -> Router> { + let router_with_auth = Router::new() .route("/{uuid}", get(uuid::get)) .route("/{uuid}", delete(uuid::delete)) .route("/{uuid}", patch(uuid::patch)) - .route("/{uuid}/socket", any(uuid::socket::ws)) .route("/{uuid}/messages", get(uuid::messages::get)) + .layer(from_fn_with_state(app_state, CurrentUser::check_auth_layer)); + + Router::new() + .route("/{uuid}/socket", any(uuid::socket::ws)) + .merge(router_with_auth) } diff --git a/src/main.rs b/src/main.rs index e42c8dc..13e661d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,7 +10,6 @@ use diesel_async::pooled_connection::deadpool::Pool; use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations}; use error::Error; use objects::MailClient; -use socketioxide::SocketIo; use std::{sync::Arc, time::SystemTime}; use tower_http::cors::{AllowOrigin, CorsLayer}; @@ -24,7 +23,7 @@ mod config; pub mod error; pub mod objects; pub mod schema; -mod socket; +//mod socket; pub mod utils; mod wordlist; @@ -53,12 +52,6 @@ pub struct AppState { async fn main() -> Result<(), Error> { tracing_subscriber::fmt::init(); - //SimpleLogger::new() - // .with_level(log::LevelFilter::Info) - // .with_colors(true) - // .env() - // .init() - // .unwrap(); let args = Args::parse(); let config = ConfigBuilder::load(args.config).await?.build(); @@ -158,12 +151,12 @@ async fn main() -> Result<(), Error> { // Allow credentials .allow_credentials(true); - let (socket_io, io) = SocketIo::builder() + /*let (socket_io, io) = SocketIo::builder() .with_state(app_state.clone()) .build_layer(); io.ns("/", socket::on_connect); - + */ // build our application with a route let app = Router::new() // `GET /` goes to `root` @@ -172,8 +165,8 @@ async fn main() -> Result<(), Error> { app_state.clone(), )) .with_state(app_state) - .layer(cors) - .layer(socket_io); + //.layer(socket_io) + .layer(cors); // run our app with hyper, listening globally on port 3000 let listener = tokio::net::TcpListener::bind(web.ip + ":" + &web.port.to_string()).await?; -- 2.47.3 From 1c07957c4e77422a472dea3538e66499dc37ab9f Mon Sep 17 00:00:00 2001 From: Radical Date: Sun, 20 Jul 2025 18:45:50 +0200 Subject: [PATCH 19/19] refactor: small dependency optimizations --- Cargo.toml | 6 +----- src/api/v1/channels/uuid/socket.rs | 3 +-- src/objects/channel.rs | 4 ++-- src/objects/guild.rs | 2 +- src/objects/me.rs | 4 ++-- 5 files changed, 7 insertions(+), 12 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3decea6..cdbcc0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,6 @@ clap = { version = "4.5", features = ["derive"] } log = "0.4" # async -futures = "0.3" tokio = { version = "1.46", features = ["full"] } futures-util = "0.3.31" @@ -31,7 +30,6 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" toml = "0.9" bytes = "1.10.1" -rmpv = { version = "1.3.0", features = ["with-serde"] } # File Storage bindet = "0.3.2" @@ -39,8 +37,8 @@ bunny-api-tokio = { version = "0.4", features = ["edge_storage"], default-featur # Web Server axum = { version = "0.8.4", features = ["multipart", "ws"] } -tower-http = { version = "0.6.6", features = ["cors"] } axum-extra = { version = "0.10.1", features = ["cookie", "typed-header"] } +tower-http = { version = "0.6.6", features = ["cors"] } #socketioxide = { version = "0.17.2", features = ["state"] } url = { version = "2.5", features = ["serde"] } time = "0.3.41" @@ -63,5 +61,3 @@ lettre = { version = "0.11", features = ["tokio1", "tokio1-native-tls"] } chrono = { version = "0.4.41", features = ["serde"] } tracing-subscriber = "0.3.19" rand = "0.9.1" - - diff --git a/src/api/v1/channels/uuid/socket.rs b/src/api/v1/channels/uuid/socket.rs index 46a7334..dd020e3 100644 --- a/src/api/v1/channels/uuid/socket.rs +++ b/src/api/v1/channels/uuid/socket.rs @@ -5,8 +5,7 @@ use axum::{ http::HeaderMap, response::IntoResponse, }; -use futures::SinkExt; -use futures_util::StreamExt as _; +use futures_util::{SinkExt, StreamExt}; use serde::Deserialize; use uuid::Uuid; diff --git a/src/objects/channel.rs b/src/objects/channel.rs index 3b34ac6..cacb153 100644 --- a/src/objects/channel.rs +++ b/src/objects/channel.rs @@ -102,7 +102,7 @@ impl Channel { c.clone().build(&mut conn).await }); - futures::future::try_join_all(channel_futures).await + futures_util::future::try_join_all(channel_futures).await } pub async fn fetch_one(app_state: &AppState, channel_uuid: Uuid) -> Result { @@ -267,7 +267,7 @@ impl Channel { let message_futures = messages.iter().map(async move |b| b.build(app_state).await); - futures::future::try_join_all(message_futures).await + futures_util::future::try_join_all(message_futures).await } pub async fn new_message( diff --git a/src/objects/guild.rs b/src/objects/guild.rs index e27e129..9514e49 100644 --- a/src/objects/guild.rs +++ b/src/objects/guild.rs @@ -96,7 +96,7 @@ impl Guild { }); // Execute all futures concurrently and collect results - futures::future::try_join_all(guild_futures).await + futures_util::future::try_join_all(guild_futures).await } pub async fn new(conn: &mut Conn, name: String, owner_uuid: Uuid) -> Result { diff --git a/src/objects/me.rs b/src/objects/me.rs index 3b51da4..a0b399d 100644 --- a/src/objects/me.rs +++ b/src/objects/me.rs @@ -391,13 +391,13 @@ impl Me { User::fetch_one_with_friendship(app_state, self, friend.uuid2).await }); - let mut friends = futures::future::try_join_all(friend_futures).await?; + let mut friends = futures_util::future::try_join_all(friend_futures).await?; let friend_futures = friends2.iter().map(async move |friend| { User::fetch_one_with_friendship(app_state, self, friend.uuid1).await }); - friends.append(&mut futures::future::try_join_all(friend_futures).await?); + friends.append(&mut futures_util::future::try_join_all(friend_futures).await?); Ok(friends) } -- 2.47.3