diff --git a/backend/Cargo.lock b/backend/Cargo.lock index f4a7b76..06ab043 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -125,6 +125,7 @@ dependencies = [ "anyhow", "axum", "axum-test", + "lazy-regex", "serde", "serde_json", "tokio", @@ -692,6 +693,29 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "lazy-regex" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bae91019476d3ec7147de9aa291cadb6d870abf2f3015d2da73a90325ac1496" +dependencies = [ + "lazy-regex-proc_macros", + "once_cell", + "regex", +] + +[[package]] +name = "lazy-regex-proc_macros" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4de9c1e1439d8b7b3061b2d209809f447ca33241733d9a3c01eabf2dc8d94358" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn", +] + [[package]] name = "lazy_static" version = "1.5.0" diff --git a/backend/Cargo.toml b/backend/Cargo.toml index 6cb4d0f..9f4b617 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -13,6 +13,7 @@ tower-http = { version = "0.6.8", features = ["fs"] } tower-cookies = "0.11.0" tracing = "0.1.44" tracing-subscriber = { version = "0.3.23", features = ["env-filter"] } +lazy-regex = "3.6.0" [dev-dependencies] axum-test = "20.0.0" diff --git a/backend/src/ctx.rs b/backend/src/ctx.rs new file mode 100644 index 0000000..7034526 --- /dev/null +++ b/backend/src/ctx.rs @@ -0,0 +1,14 @@ +#[derive(Clone)] +pub struct Ctx { + user_id: u64, +} + +impl Ctx { + pub fn new(user_id: u64) -> Self { + Self { user_id } + } + + pub fn user_id(&self) -> u64 { + self.user_id + } +} diff --git a/backend/src/error.rs b/backend/src/error.rs index 7d5cddb..c48319b 100644 --- a/backend/src/error.rs +++ b/backend/src/error.rs @@ -3,9 +3,12 @@ use std::fmt; use axum::{http::StatusCode, response::IntoResponse}; use tracing::info; -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum LoftError { LoginFail, + AuthFailNoAuthTokenCookie, + AuthFailTokenWrongFormat, + AuthFailCtxNotInRequestExt, FileIdNotFound, } @@ -20,7 +23,10 @@ impl std::error::Error for LoftError {} impl IntoResponse for LoftError { fn into_response(self) -> axum::response::Response { match self { - Self::LoginFail => { + Self::LoginFail + | Self::AuthFailNoAuthTokenCookie + | Self::AuthFailTokenWrongFormat + | Self::AuthFailCtxNotInRequestExt => { info!("UNAUTHORIZED"); StatusCode::UNAUTHORIZED.into_response() } diff --git a/backend/src/main.rs b/backend/src/main.rs index 5c120e1..0e92399 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,3 +1,4 @@ +mod ctx; mod error; mod model; mod web; @@ -11,7 +12,12 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use crate::{ model::FileController, - web::{routes_file::routes_file, routes_health::routes_health, routes_login::routes_login}, + web::{ + mw_auth::{mw_ctx_resolver, mw_require_auth}, + routes_file::routes_file, + routes_health::routes_health, + routes_login::routes_login, + }, }; #[tokio::main] @@ -27,11 +33,18 @@ async fn main() -> Result<()> { let file_controller = FileController::new().await?; + let routes_file = + routes_file(file_controller.clone()).route_layer(middleware::from_fn(mw_require_auth)); + let app = Router::new() - .nest("/api", routes_file(file_controller)) + .nest("/api", routes_file) .merge(routes_health()) .merge(routes_login()) .layer(middleware::map_response(main_response_mapper)) + .layer(middleware::from_fn_with_state( + file_controller, + mw_ctx_resolver, + )) .layer(CookieManagerLayer::new()) .fallback_service(ServeDir::new("./")); diff --git a/backend/src/model.rs b/backend/src/model.rs index 1ad691a..ed73389 100644 --- a/backend/src/model.rs +++ b/backend/src/model.rs @@ -68,12 +68,10 @@ impl FileController { #[cfg(test)] mod tests { - use anyhow::Result; - use super::*; - async fn fc() -> Result { - Ok(FileController::new().await?) + async fn fc() -> Result { + Ok(FileController::new().await.unwrap()) } fn new_file(name: &str) -> FileToCreate { @@ -84,53 +82,48 @@ mod tests { } #[tokio::test] - async fn test_upload_and_list() -> Result<()> { - let fc = fc().await?; - fc.upload_file(new_file("a.txt")).await?; - fc.upload_file(new_file("b.txt")).await?; - let files = fc.list_files().await?; + async fn test_upload_and_list() { + let fc = fc().await.unwrap(); + fc.upload_file(new_file("a.txt")).await.unwrap(); + fc.upload_file(new_file("b.txt")).await.unwrap(); + let files = fc.list_files().await.unwrap(); assert_eq!(files.len(), 2); - Ok(()) } #[tokio::test] - async fn test_download() -> Result<()> { - let fc = fc().await?; - let uploaded = fc.upload_file(new_file("a.txt")).await?; - let downloaded = fc.download_file(uploaded.id).await?; + async fn test_download() { + let fc = fc().await.unwrap(); + let uploaded = fc.upload_file(new_file("a.txt")).await.unwrap(); + let downloaded = fc.download_file(uploaded.id).await.unwrap(); assert_eq!(downloaded.name, "a.txt"); - Ok(()) } #[tokio::test] - async fn test_download_not_found() -> Result<()> { - let fc = fc().await?; + async fn test_download_not_found() { + let fc = fc().await.unwrap(); assert!(matches!( fc.download_file(99).await, Err(LoftError::FileIdNotFound) )); - Ok(()) } #[tokio::test] - async fn test_delete() -> Result<()> { - let fc = fc().await?; - let uploaded = fc.upload_file(new_file("a.txt")).await?; - fc.delete_file(uploaded.id).await?; + async fn test_delete() { + let fc = fc().await.unwrap(); + let uploaded = fc.upload_file(new_file("a.txt")).await.unwrap(); + fc.delete_file(uploaded.id).await.unwrap(); assert!(matches!( fc.download_file(uploaded.id).await, Err(LoftError::FileIdNotFound) )); - Ok(()) } #[tokio::test] - async fn test_delete_not_found() -> Result<()> { - let fc = fc().await?; + async fn test_delete_not_found() { + let fc = fc().await.unwrap(); assert!(matches!( fc.delete_file(99).await, Err(LoftError::FileIdNotFound) )); - Ok(()) } } diff --git a/backend/src/web/mod.rs b/backend/src/web/mod.rs index 1cdb6fa..4efd868 100644 --- a/backend/src/web/mod.rs +++ b/backend/src/web/mod.rs @@ -1,3 +1,6 @@ +pub mod mw_auth; pub mod routes_file; pub mod routes_health; pub mod routes_login; + +pub const AUTH_TOKEN: &str = "auth-token"; diff --git a/backend/src/web/mw_auth.rs b/backend/src/web/mw_auth.rs new file mode 100644 index 0000000..9de7c94 --- /dev/null +++ b/backend/src/web/mw_auth.rs @@ -0,0 +1,79 @@ +use axum::{ + extract::{FromRequestParts, Request}, + middleware::Next, + response::Response, +}; +use lazy_regex::regex_captures; +use tower_cookies::{Cookie, Cookies}; + +use crate::{ctx::Ctx, error::LoftError, web::AUTH_TOKEN}; + +/// validates the cookie exists and is well-formed (3-part format) +pub async fn mw_require_auth( + ctx: Result, + req: Request, + next: Next, +) -> Result { + ctx?; + + Ok(next.run(req).await) +} + +pub async fn mw_ctx_resolver( + cookies: Cookies, + mut req: Request, + next: Next, +) -> Result { + let auth_token = cookies.get(AUTH_TOKEN).map(|c| c.value().to_string()); + let result_ctx = match auth_token + .ok_or(LoftError::AuthFailNoAuthTokenCookie) + .and_then(parse_auth_token) + { + Ok((user_id, _, _)) => { + //TODO: add validation + Ok(Ctx::new(user_id)) + } + Err(e) => Err(e), + }; + + if result_ctx.is_err() && !matches!(result_ctx, Err(LoftError::AuthFailNoAuthTokenCookie)) { + cookies.remove(Cookie::from(AUTH_TOKEN)) + } + + req.extensions_mut().insert(result_ctx); + + Ok(next.run(req).await) +} + +impl FromRequestParts for Ctx { + type Rejection = LoftError; + + // extracts user_id from the token and makes it available to handlers as an extractor + fn from_request_parts( + parts: &mut axum::http::request::Parts, + _: &S, + ) -> impl Future> + Send { + async move { + parts + .extensions + .get::>() + .ok_or(LoftError::AuthFailCtxNotInRequestExt)? + .clone() + } + } +} + +fn parse_auth_token(auth_token: String) -> Result<(u64, u64, String), LoftError> { + let (_, user_id, expiration, signature) = + regex_captures!(r"^user-(\d+)\.(\d+)\.([a-f0-9]+)$", &auth_token) + .ok_or(LoftError::AuthFailTokenWrongFormat)?; + + let user_id: u64 = user_id + .parse() + .map_err(|_| LoftError::AuthFailTokenWrongFormat)?; + let expiration: u64 = expiration + .parse() + .map_err(|_| LoftError::AuthFailTokenWrongFormat)?; + + Ok((user_id, expiration, signature.to_string())) +} diff --git a/backend/src/web/routes_file.rs b/backend/src/web/routes_file.rs index d5af8c8..aed31a4 100644 --- a/backend/src/web/routes_file.rs +++ b/backend/src/web/routes_file.rs @@ -59,21 +59,69 @@ async fn list_files( #[cfg(test)] mod tests { + use axum::{Router, middleware}; use axum_test::TestServer; use serde_json::json; + use tower_cookies::CookieManagerLayer; - use crate::{model::FileController, web::routes_file::routes_file}; + use crate::{ + model::FileController, + web::{ + mw_auth::{mw_ctx_resolver, mw_require_auth}, + routes_file::routes_file, + }, + }; + + // Cookie format: user-[user-id].[expiration].[signature] + const AUTH_COOKIE: &str = "auth-token=user-1.0123456789.a1b2c3d4e5f6"; + const BAD_AUTH_COOKIE: &str = "auth-token=user-1.0123456789"; async fn test_server() -> TestServer { - let fc = FileController::new().await.unwrap(); - TestServer::new(routes_file(fc)) + let file_controller = FileController::new().await.unwrap(); + let routes_file = + routes_file(file_controller.clone()).route_layer(middleware::from_fn(mw_require_auth)); + let app = Router::new() + .nest("/api", routes_file) + .layer(middleware::from_fn_with_state( + file_controller, + mw_ctx_resolver, + )) + .layer(CookieManagerLayer::new()); + TestServer::new(app) + } + + #[tokio::test] + async fn test_requires_auth() { + let server = test_server().await; + server.get("/api/files").await.assert_status_unauthorized(); + } + + #[tokio::test] + async fn test_requires_auth_invalid_cookie() { + let server = test_server().await; + server + .get("/api/files") + .add_header(axum::http::header::COOKIE, BAD_AUTH_COOKIE) + .await + .assert_status_unauthorized(); + } + + #[tokio::test] + async fn test_requires_auth_post() { + let server = test_server().await; + server + .post("/api/files") + .json(&json!({"name": "a.txt", "file_type": "text"})) + .await + .assert_status_unauthorized(); } #[tokio::test] async fn test_list_files_empty() { let server = test_server().await; server - .get("/files") + .get("/api/files") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) .await .assert_status_ok() .assert_json(&json!([])); @@ -83,7 +131,8 @@ mod tests { async fn test_upload_and_list_files() { let server = test_server().await; let res = server - .post("/files") + .post("/api/files") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) .json(&json!({"name": "a.txt", "file_type": "text"})) .await; res.assert_status_ok(); @@ -91,7 +140,11 @@ mod tests { assert_eq!(file["name"], "a.txt"); assert_eq!(file["id"], 0); - let list = server.get("/files").await.json::(); + let list = server + .get("/api/files") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) + .await + .json::(); assert_eq!(list.as_array().unwrap().len(), 1); } @@ -99,10 +152,14 @@ mod tests { async fn test_download_file() { let server = test_server().await; server - .post("/files") + .post("/api/files") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) .json(&json!({"name": "b.txt", "file_type": "text"})) .await; - let res = server.get("/files/0").await; + let res = server + .get("/api/files/0") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) + .await; res.assert_status_ok(); assert_eq!(res.json::()["name"], "b.txt"); } @@ -110,23 +167,40 @@ mod tests { #[tokio::test] async fn test_download_file_not_found() { let server = test_server().await; - server.get("/files/99").await.assert_status_not_found(); + server + .get("/api/files/99") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) + .await + .assert_status_not_found(); } #[tokio::test] async fn test_delete_file() { let server = test_server().await; server - .post("/files") + .post("/api/files") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) .json(&json!({"name": "c.txt", "file_type": "text"})) .await; - server.delete("/files/0").await.assert_status_ok(); - server.get("/files/0").await.assert_status_not_found(); + server + .delete("/api/files/0") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) + .await + .assert_status_ok(); + server + .get("/api/files/0") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) + .await + .assert_status_not_found(); } #[tokio::test] async fn test_delete_file_not_found() { let server = test_server().await; - server.delete("/files/99").await.assert_status_not_found(); + server + .delete("/api/files/99") + .add_header(axum::http::header::COOKIE, AUTH_COOKIE) + .await + .assert_status_not_found(); } } diff --git a/backend/src/web/routes_login.rs b/backend/src/web/routes_login.rs index 238e9c2..6c41260 100644 --- a/backend/src/web/routes_login.rs +++ b/backend/src/web/routes_login.rs @@ -6,7 +6,7 @@ use serde::Deserialize; use serde_json::{Value, json}; use tower_cookies::{Cookie, Cookies}; -use crate::error::LoftError; +use crate::{error::LoftError, web::AUTH_TOKEN}; pub fn routes_login() -> Router { Router::new() @@ -24,7 +24,7 @@ async fn login( } // FIXME: real auth-token generation-signature - cookies.add(Cookie::new("auth-token", "user-1.exp.sign")); + cookies.add(Cookie::new(AUTH_TOKEN, "user-1.exp.sign")); let body = Json(json!({ "result": {