Create a middleware that requires auth

This commit is contained in:
SeanOMik 2023-04-28 14:55:05 -04:00
parent 4c768753ab
commit f70e04c52d
Signed by: SeanOMik
GPG Key ID: 568F326C7EB33ACB
5 changed files with 99 additions and 31 deletions

View File

@ -1,11 +1,9 @@
use std::{sync::Arc, collections::{HashMap, BTreeMap}, time::{SystemTime, UNIX_EPOCH}}; use std::{sync::Arc, collections::{HashMap, BTreeMap}, time::{SystemTime, UNIX_EPOCH}};
use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, Request, Method, HeaderName, header}, Form}; use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, header}, Form};
use axum_auth::AuthBasic; use axum_auth::AuthBasic;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use qstring::QString;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{debug, error, info, span, Level}; use tracing::{debug, error, info, span, Level};
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
@ -14,7 +12,7 @@ use sha2::Sha256;
use rand::Rng; use rand::Rng;
use crate::{dto::scope::Scope, app_state::AppState, query::Qs}; use crate::{dto::scope::Scope, app_state::AppState};
#[derive(Deserialize, Debug)] #[derive(Deserialize, Debug)]
pub struct TokenAuthRequest { pub struct TokenAuthRequest {
@ -179,6 +177,9 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
let json_str = serde_json::to_string(&auth_response).unwrap(); let json_str = serde_json::to_string(&auth_response).unwrap();
let mut auth_storage = state.auth_storage.lock().await;
auth_storage.valid_tokens.insert(token_str.clone());
return ( return (
StatusCode::OK, StatusCode::OK,
[ [

View File

@ -1,9 +1,11 @@
use axum::extract::Query; use std::sync::Arc;
use axum::response::{IntoResponse, Response};
use axum::http::{StatusCode, HeaderName, header};
use tracing::debug;
use self::auth::TokenAuthRequest; use axum::Extension;
use axum::extract::State;
use axum::response::{IntoResponse, Response};
use axum::http::{StatusCode, HeaderName};
use crate::app_state::AppState;
pub mod blobs; pub mod blobs;
pub mod uploads; pub mod uploads;
@ -12,28 +14,13 @@ pub mod tags;
pub mod catalog; pub mod catalog;
pub mod auth; pub mod auth;
use crate::auth_storage::AuthToken;
/// https://docs.docker.com/registry/spec/api/#api-version-check /// https://docs.docker.com/registry/spec/api/#api-version-check
/// full endpoint: `/v2/` /// full endpoint: `/v2/`
pub async fn version_check(params: Option<Query<TokenAuthRequest>>, body: String) -> Response { pub async fn version_check(Extension(AuthToken(_token)): Extension<AuthToken>, _state: State<Arc<AppState>>) -> Response {
debug!("Got body: {}", body); (
/* (
StatusCode::OK, StatusCode::OK,
[( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" )] [( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" )]
) */
//Www-Authenticate: Bearer realm="https://auth.docker.io/token",service="registry.docker.io",scope="repository:samalba/my-app:pull,push"
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");/* match params {
Some(Query(params)) => format!("Bearer realm=\"http://localhost:3000/token\",scope=\"{}\"", params.scope),
None => format!("Bearer realm=\"http://localhost:3000/token\""),
}; */
(
StatusCode::UNAUTHORIZED,
[
( header::WWW_AUTHENTICATE, bearer ),
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
]
).into_response() ).into_response()
} }

View File

@ -1,5 +1,6 @@
use sqlx::{Sqlite, Pool}; use sqlx::{Sqlite, Pool};
use crate::auth_storage::MemoryAuthStorage;
use crate::storage::StorageDriver; use crate::storage::StorageDriver;
use crate::config::Config; use crate::config::Config;
@ -9,6 +10,7 @@ pub struct AppState {
pub database: Pool<Sqlite>, pub database: Pool<Sqlite>,
pub storage: Mutex<Box<dyn StorageDriver>>, pub storage: Mutex<Box<dyn StorageDriver>>,
pub config: Config, pub config: Config,
pub auth_storage: Mutex<MemoryAuthStorage>,
} }
impl AppState { impl AppState {
@ -18,6 +20,7 @@ impl AppState {
database, database,
storage, storage,
config, config,
auth_storage: Mutex::new(MemoryAuthStorage::new()),
} }
} }
} }

75
src/auth_storage.rs Normal file
View File

@ -0,0 +1,75 @@
use std::{collections::HashSet, ops::Deref, sync::Arc};
use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}};
use tracing::debug;
use crate::app_state::AppState;
/// Temporary struct for storing auth information in memory.
pub struct MemoryAuthStorage {
pub valid_tokens: HashSet<String>,
}
impl MemoryAuthStorage {
pub fn new() -> Self {
Self {
valid_tokens: HashSet::new(),
}
}
}
#[derive(Clone)]
pub struct AuthToken(pub String);
impl Deref for AuthToken {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
type Rejection = (StatusCode, HeaderMap);
pub async fn require_auth<B>(State(state): State<Arc<AppState>>, mut request: Request<B>, next: Next<B>) -> Result<Response, Rejection> {
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");
let mut failure_headers = HeaderMap::new();
failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap());
failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap());
let auth = String::from(
request.headers().get(header::AUTHORIZATION)
.ok_or((StatusCode::UNAUTHORIZED, failure_headers.clone()))?
.to_str()
.map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))?
); // TODO: Don't unwrap
let token = match auth.split_once(' ') {
Some((auth, token)) if auth == "Bearer" => token,
// This line would allow empty tokens
//_ if auth == "Bearer" => Ok(AuthToken(None)),
_ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ),
};
// If the token is not valid, return an unauthorized response
let auth_storage = state.auth_storage.lock().await;
if !auth_storage.valid_tokens.contains(token) {
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");
return Ok((
StatusCode::UNAUTHORIZED,
[
( header::WWW_AUTHENTICATE, bearer ),
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
]
).into_response());
} else {
debug!("Client successfully authenticated!");
}
drop(auth_storage);
request.extensions_mut().insert(AuthToken(String::from(token)));
Ok(next.run(request).await)
}

View File

@ -6,6 +6,7 @@ mod storage;
mod byte_stream; mod byte_stream;
mod config; mod config;
mod query; mod query;
mod auth_storage;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::str::FromStr; use std::str::FromStr;
@ -88,15 +89,15 @@ async fn main() -> std::io::Result<()> {
.with_max_level(Level::DEBUG) .with_max_level(Level::DEBUG)
.init(); .init();
let auth_middleware = axum::middleware::from_fn_with_state(state.clone(), auth_storage::require_auth);
let path_middleware = axum::middleware::from_fn(change_request_paths); let path_middleware = axum::middleware::from_fn(change_request_paths);
let app = Router::new() let app = Router::new()
.route("/auth", routing::get(api::auth::auth_basic_get) .route("/auth", routing::get(api::auth::auth_basic_get)
.post(api::auth::auth_basic_get)) .post(api::auth::auth_basic_get))
.fallback(auth_failure)
.nest("/v2", Router::new() .nest("/v2", Router::new()
.route("/", routing::get(api::version_check)) .route("/", routing::get(api::version_check))
/* .route("/_catalog", routing::get(api::catalog::list_repositories)) .route("/_catalog", routing::get(api::catalog::list_repositories))
.route("/:name/tags/list", routing::get(api::tags::list_tags)) .route("/:name/tags/list", routing::get(api::tags::list_tags))
.nest("/:name/blobs", Router::new() .nest("/:name/blobs", Router::new()
.route("/:digest", routing::get(api::blobs::pull_digest_get) .route("/:digest", routing::get(api::blobs::pull_digest_get)
@ -114,7 +115,8 @@ async fn main() -> std::io::Result<()> {
.route("/:name/manifests/:reference", routing::get(api::manifests::pull_manifest_get) .route("/:name/manifests/:reference", routing::get(api::manifests::pull_manifest_get)
.put(api::manifests::upload_manifest_put) .put(api::manifests::upload_manifest_put)
.head(api::manifests::manifest_exists_head) .head(api::manifests::manifest_exists_head)
.delete(api::manifests::delete_manifest)) */ .delete(api::manifests::delete_manifest))
.layer(auth_middleware) // require auth for ALL v2 routes
) )
.with_state(state) .with_state(state)
.layer(TraceLayer::new_for_http()); .layer(TraceLayer::new_for_http());