diff --git a/Cargo.lock b/Cargo.lock index 5a119b4..fd1bdef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -71,6 +71,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" dependencies = [ "axum-core", + "axum-macros", "bytes", "futures-util", "http", @@ -115,6 +116,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -1670,6 +1682,7 @@ dependencies = [ "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -1792,6 +1805,8 @@ dependencies = [ "rand_core", "subtle", "tokio", + "tower", + "tower-http", "tracing", "tracing-subscriber", ] diff --git a/Cargo.toml b/Cargo.toml index 25be1a3..1ca2aa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,13 +8,7 @@ opt-level = "s" lto = "thin" [dependencies] -axum = { version = "0.8.4", default-features = false, features = [ - "http2", - "multipart", - "tokio", - "tower-log", - "tracing", -] } +axum = { version = "0.8.4", default-features = false, features = ["http2", "macros", "multipart", "tokio", "tower-log", "tracing"] } base64 = "0.22.1" bs58 = "0.5.1" color-eyre = "0.6.5" @@ -22,5 +16,7 @@ object_store = { version = "0.12.3", default-features = false, features = ["aws" rand_core = { version = "0.9.3", features = ["os_rng"] } subtle = { version = "2.6.1", default-features = false } tokio = { version = "1.47.1", features = ["macros", "rt", "net"] } +tower = "0.5.2" +tower-http = { version = "0.6.6", features = ["trace"] } tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } diff --git a/src/main.rs b/src/main.rs index 0b12a3a..54788f9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,8 @@ use axum::{ body::Bytes, - extract::{DefaultBodyLimit, FromRequestParts, Multipart, State}, + extract::{DefaultBodyLimit, Multipart, Request, State}, http::{header, StatusCode}, + middleware::Next, response::{Html, IntoResponse, Redirect, Response}, routing::{get, post}, Router, @@ -11,6 +12,7 @@ use color_eyre::eyre::{self, bail, Context}; use color_eyre::{eyre::OptionExt, Result}; use object_store::ObjectStore; use rand_core::TryRngCore; +use tower::ServiceBuilder; use tracing::{error, info, level_filters::LevelFilter}; use tracing_subscriber::EnvFilter; @@ -48,16 +50,23 @@ async fn main() -> Result<()> { .build() .wrap_err("failed to build client")?; + let state = Config { + username, + password, + s3_client, + }; + let app = Router::new() .route("/", get(index)) .route("/", post(upload)) - .with_state(Config { - username, - password, - s3_client, - }) - // raise limit to 100MB - .layer(DefaultBodyLimit::max(100_000_000)); + .with_state(state.clone()) + .layer( + ServiceBuilder::new() + .layer(tower_http::trace::TraceLayer::new_for_http()) + // raise limit to 100MB + .layer(DefaultBodyLimit::max(100_000_000)) + .layer(axum::middleware::from_fn_with_state(state, auth_middleware)), + ); let addr = "0.0.0.0:3050"; let listener = tokio::net::TcpListener::bind(addr) @@ -68,23 +77,11 @@ async fn main() -> Result<()> { axum::serve(listener, app).await.wrap_err("failed to serve") } -// todo: use middleware -async fn index(_: Auth) -> impl IntoResponse { +async fn index() -> impl IntoResponse { Html(include_str!("../index.html")) } -async fn upload( - auth: Auth, - State(config): State, - multipart: Multipart, -) -> Result { - if auth.username != config.username { - return Err(reject_auth("invalid username")); - } - if subtle::ConstantTimeEq::ct_ne(auth.password.as_bytes(), config.password.as_bytes()).into() { - return Err(reject_auth("invalid password")); - } - +async fn upload(State(config): State, multipart: Multipart) -> Result { let req = parse_req(multipart).await.map_err(|err| { info!(?err, "Bad request for upload"); (StatusCode::BAD_REQUEST, err.to_string()).into_response() @@ -203,9 +200,48 @@ async fn parse_req(mut multipart: Multipart) -> Result { }) } -struct Auth { - username: String, - password: String, +#[axum::debug_middleware] +async fn auth_middleware(State(config): State, request: Request, next: Next) -> Response { + match check_auth(config, request).await { + Ok(request) => next.run(request).await, + Err(err) => err, + } +} + +async fn check_auth(config: Config, request: Request) -> Result { + let Some(header) = request.headers().get(header::AUTHORIZATION) else { + return Err(reject_auth("missing authorization header")); + }; + + let header = header + .to_str() + .map_err(|_| reject_auth("authorization header is invalid UTF-8"))?; + + let Some(("Basic", value)) = header.split_once(' ') else { + return Err(reject_auth( + "invalid authorization header, missing 'Basic '", + )); + }; + + let decoded = String::from_utf8( + base64::prelude::BASE64_STANDARD + .decode(value) + .map_err(|_| reject_auth("invalid base64 value"))?, + ) + .map_err(|_| reject_auth("invalid UTF-8 after base64 decode"))?; + + let Some((username, password)) = decoded.split_once(':') else { + return Err(reject_auth("missing : between username and password")); + }; + + if username != config.username { + return Err(reject_auth("invalid username")); + } + if subtle::ConstantTimeEq::ct_ne(password.as_bytes(), config.password.as_bytes()).into() { + return Err(reject_auth("invalid password")); + } + + Ok(request) } fn reject_auth(reason: &str) -> Response { @@ -219,49 +255,3 @@ fn reject_auth(reason: &str) -> Response { ) .into_response() } - -impl FromRequestParts for Auth { - type Rejection = Response; - - async fn from_request_parts( - parts: &mut axum::http::request::Parts, - config: &Config, - ) -> Result { - let Some(header) = parts.headers.get(header::AUTHORIZATION) else { - return Err(reject_auth("missing authorization header")); - }; - - let header = header - .to_str() - .map_err(|_| reject_auth("authorization header is invalid UTF-8"))?; - - let Some(("Basic", value)) = header.split_once(' ') else { - return Err(reject_auth( - "invalid authorization header, missing 'Basic '", - )); - }; - - let decoded = String::from_utf8( - base64::prelude::BASE64_STANDARD - .decode(value) - .map_err(|_| reject_auth("invalid base64 value"))?, - ) - .map_err(|_| reject_auth("invalid UTF-8 after base64 decode"))?; - - let Some((username, password)) = decoded.split_once(':') else { - return Err(reject_auth("missing : between username and password")); - }; - - if username != config.username { - return Err(reject_auth("invalid username")); - } - if subtle::ConstantTimeEq::ct_ne(password.as_bytes(), config.password.as_bytes()).into() { - return Err(reject_auth("invalid password")); - } - - Ok(Auth { - username: username.to_owned(), - password: password.to_owned(), - }) - } -}