From f70e04c52d7be1cbfbc382309a10d94ed25e069e Mon Sep 17 00:00:00 2001 From: SeanOMik Date: Fri, 28 Apr 2023 14:55:05 -0400 Subject: [PATCH] Create a middleware that requires auth --- src/api/auth.rs | 9 +++--- src/api/mod.rs | 35 +++++++-------------- src/app_state.rs | 3 ++ src/auth_storage.rs | 75 +++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 8 +++-- 5 files changed, 99 insertions(+), 31 deletions(-) create mode 100644 src/auth_storage.rs diff --git a/src/api/auth.rs b/src/api/auth.rs index 67d2deb..dc85b51 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -1,11 +1,9 @@ use std::{sync::Arc, collections::{HashMap, BTreeMap}, time::{SystemTime, UNIX_EPOCH}}; -use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, Request, Method, HeaderName, header}, Form}; +use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, header}, Form}; use axum_auth::AuthBasic; use chrono::{DateTime, Utc}; -use qstring::QString; use serde::{Deserialize, Serialize}; -use serde_json::json; use tracing::{debug, error, info, span, Level}; use hmac::{Hmac, Mac}; @@ -14,7 +12,7 @@ use sha2::Sha256; use rand::Rng; -use crate::{dto::scope::Scope, app_state::AppState, query::Qs}; +use crate::{dto::scope::Scope, app_state::AppState}; #[derive(Deserialize, Debug)] pub struct TokenAuthRequest { @@ -179,6 +177,9 @@ pub async fn auth_basic_get(basic_auth: Option, state: State>, body: String) -> Response { - debug!("Got body: {}", body); - - /* ( +pub async fn version_check(Extension(AuthToken(_token)): Extension, _state: State>) -> Response { + ( StatusCode::OK, [( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" )] - ) */ - - //Www-Authenticate: Bearer realm="https://auth.docker.io/token",service="registry.docker.io",scope="repository:samalba/my-app:pull,push" - - let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");/* match params { - Some(Query(params)) => format!("Bearer realm=\"http://localhost:3000/token\",scope=\"{}\"", params.scope), - None => format!("Bearer realm=\"http://localhost:3000/token\""), - }; */ - - ( - StatusCode::UNAUTHORIZED, - [ - ( header::WWW_AUTHENTICATE, bearer ), - ( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() ) - ] ).into_response() } \ No newline at end of file diff --git a/src/app_state.rs b/src/app_state.rs index 15bbe66..9496963 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -1,5 +1,6 @@ use sqlx::{Sqlite, Pool}; +use crate::auth_storage::MemoryAuthStorage; use crate::storage::StorageDriver; use crate::config::Config; @@ -9,6 +10,7 @@ pub struct AppState { pub database: Pool, pub storage: Mutex>, pub config: Config, + pub auth_storage: Mutex, } impl AppState { @@ -18,6 +20,7 @@ impl AppState { database, storage, config, + auth_storage: Mutex::new(MemoryAuthStorage::new()), } } } \ No newline at end of file diff --git a/src/auth_storage.rs b/src/auth_storage.rs new file mode 100644 index 0000000..b06aa04 --- /dev/null +++ b/src/auth_storage.rs @@ -0,0 +1,75 @@ +use std::{collections::HashSet, ops::Deref, sync::Arc}; + +use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}}; + +use tracing::debug; + +use crate::app_state::AppState; + +/// Temporary struct for storing auth information in memory. +pub struct MemoryAuthStorage { + pub valid_tokens: HashSet, +} + +impl MemoryAuthStorage { + pub fn new() -> Self { + Self { + valid_tokens: HashSet::new(), + } + } +} + +#[derive(Clone)] +pub struct AuthToken(pub String); + +impl Deref for AuthToken { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +type Rejection = (StatusCode, HeaderMap); + +pub async fn require_auth(State(state): State>, mut request: Request, next: Next) -> Result { + let bearer = format!("Bearer realm=\"http://localhost:3000/auth\""); + let mut failure_headers = HeaderMap::new(); + failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap()); + failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap()); + + let auth = String::from( + request.headers().get(header::AUTHORIZATION) + .ok_or((StatusCode::UNAUTHORIZED, failure_headers.clone()))? + .to_str() + .map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))? + ); // TODO: Don't unwrap + + let token = match auth.split_once(' ') { + Some((auth, token)) if auth == "Bearer" => token, + // This line would allow empty tokens + //_ if auth == "Bearer" => Ok(AuthToken(None)), + _ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ), + }; + + // If the token is not valid, return an unauthorized response + let auth_storage = state.auth_storage.lock().await; + if !auth_storage.valid_tokens.contains(token) { + let bearer = format!("Bearer realm=\"http://localhost:3000/auth\""); + return Ok(( + StatusCode::UNAUTHORIZED, + [ + ( header::WWW_AUTHENTICATE, bearer ), + ( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() ) + ] + ).into_response()); + + } else { + debug!("Client successfully authenticated!"); + } + drop(auth_storage); + + request.extensions_mut().insert(AuthToken(String::from(token))); + + Ok(next.run(request).await) +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index ae49dd4..07682c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ mod storage; mod byte_stream; mod config; mod query; +mod auth_storage; use std::net::SocketAddr; use std::str::FromStr; @@ -88,15 +89,15 @@ async fn main() -> std::io::Result<()> { .with_max_level(Level::DEBUG) .init(); + let auth_middleware = axum::middleware::from_fn_with_state(state.clone(), auth_storage::require_auth); let path_middleware = axum::middleware::from_fn(change_request_paths); let app = Router::new() .route("/auth", routing::get(api::auth::auth_basic_get) .post(api::auth::auth_basic_get)) - .fallback(auth_failure) .nest("/v2", Router::new() .route("/", routing::get(api::version_check)) - /* .route("/_catalog", routing::get(api::catalog::list_repositories)) + .route("/_catalog", routing::get(api::catalog::list_repositories)) .route("/:name/tags/list", routing::get(api::tags::list_tags)) .nest("/:name/blobs", Router::new() .route("/:digest", routing::get(api::blobs::pull_digest_get) @@ -114,7 +115,8 @@ async fn main() -> std::io::Result<()> { .route("/:name/manifests/:reference", routing::get(api::manifests::pull_manifest_get) .put(api::manifests::upload_manifest_put) .head(api::manifests::manifest_exists_head) - .delete(api::manifests::delete_manifest)) */ + .delete(api::manifests::delete_manifest)) + .layer(auth_middleware) // require auth for ALL v2 routes ) .with_state(state) .layer(TraceLayer::new_for_http());