Create a middleware that requires auth
This commit is contained in:
parent
4c768753ab
commit
f70e04c52d
|
@ -1,11 +1,9 @@
|
|||
use std::{sync::Arc, collections::{HashMap, BTreeMap}, time::{SystemTime, UNIX_EPOCH}};
|
||||
|
||||
use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, Request, Method, HeaderName, header}, Form};
|
||||
use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, header}, Form};
|
||||
use axum_auth::AuthBasic;
|
||||
use chrono::{DateTime, Utc};
|
||||
use qstring::QString;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tracing::{debug, error, info, span, Level};
|
||||
|
||||
use hmac::{Hmac, Mac};
|
||||
|
@ -14,7 +12,7 @@ use sha2::Sha256;
|
|||
|
||||
use rand::Rng;
|
||||
|
||||
use crate::{dto::scope::Scope, app_state::AppState, query::Qs};
|
||||
use crate::{dto::scope::Scope, app_state::AppState};
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct TokenAuthRequest {
|
||||
|
@ -179,6 +177,9 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
|
|||
|
||||
let json_str = serde_json::to_string(&auth_response).unwrap();
|
||||
|
||||
let mut auth_storage = state.auth_storage.lock().await;
|
||||
auth_storage.valid_tokens.insert(token_str.clone());
|
||||
|
||||
return (
|
||||
StatusCode::OK,
|
||||
[
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
use axum::extract::Query;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::http::{StatusCode, HeaderName, header};
|
||||
use tracing::debug;
|
||||
use std::sync::Arc;
|
||||
|
||||
use self::auth::TokenAuthRequest;
|
||||
use axum::Extension;
|
||||
use axum::extract::State;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::http::{StatusCode, HeaderName};
|
||||
|
||||
use crate::app_state::AppState;
|
||||
|
||||
pub mod blobs;
|
||||
pub mod uploads;
|
||||
|
@ -12,28 +14,13 @@ pub mod tags;
|
|||
pub mod catalog;
|
||||
pub mod auth;
|
||||
|
||||
use crate::auth_storage::AuthToken;
|
||||
|
||||
/// https://docs.docker.com/registry/spec/api/#api-version-check
|
||||
/// full endpoint: `/v2/`
|
||||
pub async fn version_check(params: Option<Query<TokenAuthRequest>>, body: String) -> Response {
|
||||
debug!("Got body: {}", body);
|
||||
|
||||
/* (
|
||||
pub async fn version_check(Extension(AuthToken(_token)): Extension<AuthToken>, _state: State<Arc<AppState>>) -> Response {
|
||||
(
|
||||
StatusCode::OK,
|
||||
[( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" )]
|
||||
) */
|
||||
|
||||
//Www-Authenticate: Bearer realm="https://auth.docker.io/token",service="registry.docker.io",scope="repository:samalba/my-app:pull,push"
|
||||
|
||||
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");/* match params {
|
||||
Some(Query(params)) => format!("Bearer realm=\"http://localhost:3000/token\",scope=\"{}\"", params.scope),
|
||||
None => format!("Bearer realm=\"http://localhost:3000/token\""),
|
||||
}; */
|
||||
|
||||
(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
[
|
||||
( header::WWW_AUTHENTICATE, bearer ),
|
||||
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
|
||||
]
|
||||
).into_response()
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
use sqlx::{Sqlite, Pool};
|
||||
|
||||
use crate::auth_storage::MemoryAuthStorage;
|
||||
use crate::storage::StorageDriver;
|
||||
use crate::config::Config;
|
||||
|
||||
|
@ -9,6 +10,7 @@ pub struct AppState {
|
|||
pub database: Pool<Sqlite>,
|
||||
pub storage: Mutex<Box<dyn StorageDriver>>,
|
||||
pub config: Config,
|
||||
pub auth_storage: Mutex<MemoryAuthStorage>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
|
@ -18,6 +20,7 @@ impl AppState {
|
|||
database,
|
||||
storage,
|
||||
config,
|
||||
auth_storage: Mutex::new(MemoryAuthStorage::new()),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
use std::{collections::HashSet, ops::Deref, sync::Arc};
|
||||
|
||||
use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}};
|
||||
|
||||
use tracing::debug;
|
||||
|
||||
use crate::app_state::AppState;
|
||||
|
||||
/// Temporary struct for storing auth information in memory.
|
||||
pub struct MemoryAuthStorage {
|
||||
pub valid_tokens: HashSet<String>,
|
||||
}
|
||||
|
||||
impl MemoryAuthStorage {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
valid_tokens: HashSet::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthToken(pub String);
|
||||
|
||||
impl Deref for AuthToken {
|
||||
type Target = String;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
type Rejection = (StatusCode, HeaderMap);
|
||||
|
||||
pub async fn require_auth<B>(State(state): State<Arc<AppState>>, mut request: Request<B>, next: Next<B>) -> Result<Response, Rejection> {
|
||||
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");
|
||||
let mut failure_headers = HeaderMap::new();
|
||||
failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap());
|
||||
failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap());
|
||||
|
||||
let auth = String::from(
|
||||
request.headers().get(header::AUTHORIZATION)
|
||||
.ok_or((StatusCode::UNAUTHORIZED, failure_headers.clone()))?
|
||||
.to_str()
|
||||
.map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))?
|
||||
); // TODO: Don't unwrap
|
||||
|
||||
let token = match auth.split_once(' ') {
|
||||
Some((auth, token)) if auth == "Bearer" => token,
|
||||
// This line would allow empty tokens
|
||||
//_ if auth == "Bearer" => Ok(AuthToken(None)),
|
||||
_ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ),
|
||||
};
|
||||
|
||||
// If the token is not valid, return an unauthorized response
|
||||
let auth_storage = state.auth_storage.lock().await;
|
||||
if !auth_storage.valid_tokens.contains(token) {
|
||||
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");
|
||||
return Ok((
|
||||
StatusCode::UNAUTHORIZED,
|
||||
[
|
||||
( header::WWW_AUTHENTICATE, bearer ),
|
||||
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
|
||||
]
|
||||
).into_response());
|
||||
|
||||
} else {
|
||||
debug!("Client successfully authenticated!");
|
||||
}
|
||||
drop(auth_storage);
|
||||
|
||||
request.extensions_mut().insert(AuthToken(String::from(token)));
|
||||
|
||||
Ok(next.run(request).await)
|
||||
}
|
|
@ -6,6 +6,7 @@ mod storage;
|
|||
mod byte_stream;
|
||||
mod config;
|
||||
mod query;
|
||||
mod auth_storage;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::str::FromStr;
|
||||
|
@ -88,15 +89,15 @@ async fn main() -> std::io::Result<()> {
|
|||
.with_max_level(Level::DEBUG)
|
||||
.init();
|
||||
|
||||
let auth_middleware = axum::middleware::from_fn_with_state(state.clone(), auth_storage::require_auth);
|
||||
let path_middleware = axum::middleware::from_fn(change_request_paths);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/auth", routing::get(api::auth::auth_basic_get)
|
||||
.post(api::auth::auth_basic_get))
|
||||
.fallback(auth_failure)
|
||||
.nest("/v2", Router::new()
|
||||
.route("/", routing::get(api::version_check))
|
||||
/* .route("/_catalog", routing::get(api::catalog::list_repositories))
|
||||
.route("/_catalog", routing::get(api::catalog::list_repositories))
|
||||
.route("/:name/tags/list", routing::get(api::tags::list_tags))
|
||||
.nest("/:name/blobs", Router::new()
|
||||
.route("/:digest", routing::get(api::blobs::pull_digest_get)
|
||||
|
@ -114,7 +115,8 @@ async fn main() -> std::io::Result<()> {
|
|||
.route("/:name/manifests/:reference", routing::get(api::manifests::pull_manifest_get)
|
||||
.put(api::manifests::upload_manifest_put)
|
||||
.head(api::manifests::manifest_exists_head)
|
||||
.delete(api::manifests::delete_manifest)) */
|
||||
.delete(api::manifests::delete_manifest))
|
||||
.layer(auth_middleware) // require auth for ALL v2 routes
|
||||
)
|
||||
.with_state(state)
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
|
Loading…
Reference in New Issue