Create a middleware that requires auth
This commit is contained in:
parent
4c768753ab
commit
f70e04c52d
|
@ -1,11 +1,9 @@
|
||||||
use std::{sync::Arc, collections::{HashMap, BTreeMap}, time::{SystemTime, UNIX_EPOCH}};
|
use std::{sync::Arc, collections::{HashMap, BTreeMap}, time::{SystemTime, UNIX_EPOCH}};
|
||||||
|
|
||||||
use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, Request, Method, HeaderName, header}, Form};
|
use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, header}, Form};
|
||||||
use axum_auth::AuthBasic;
|
use axum_auth::AuthBasic;
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use qstring::QString;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::json;
|
|
||||||
use tracing::{debug, error, info, span, Level};
|
use tracing::{debug, error, info, span, Level};
|
||||||
|
|
||||||
use hmac::{Hmac, Mac};
|
use hmac::{Hmac, Mac};
|
||||||
|
@ -14,7 +12,7 @@ use sha2::Sha256;
|
||||||
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
|
||||||
use crate::{dto::scope::Scope, app_state::AppState, query::Qs};
|
use crate::{dto::scope::Scope, app_state::AppState};
|
||||||
|
|
||||||
#[derive(Deserialize, Debug)]
|
#[derive(Deserialize, Debug)]
|
||||||
pub struct TokenAuthRequest {
|
pub struct TokenAuthRequest {
|
||||||
|
@ -179,6 +177,9 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
|
||||||
|
|
||||||
let json_str = serde_json::to_string(&auth_response).unwrap();
|
let json_str = serde_json::to_string(&auth_response).unwrap();
|
||||||
|
|
||||||
|
let mut auth_storage = state.auth_storage.lock().await;
|
||||||
|
auth_storage.valid_tokens.insert(token_str.clone());
|
||||||
|
|
||||||
return (
|
return (
|
||||||
StatusCode::OK,
|
StatusCode::OK,
|
||||||
[
|
[
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
use axum::extract::Query;
|
use std::sync::Arc;
|
||||||
use axum::response::{IntoResponse, Response};
|
|
||||||
use axum::http::{StatusCode, HeaderName, header};
|
|
||||||
use tracing::debug;
|
|
||||||
|
|
||||||
use self::auth::TokenAuthRequest;
|
use axum::Extension;
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use axum::http::{StatusCode, HeaderName};
|
||||||
|
|
||||||
|
use crate::app_state::AppState;
|
||||||
|
|
||||||
pub mod blobs;
|
pub mod blobs;
|
||||||
pub mod uploads;
|
pub mod uploads;
|
||||||
|
@ -12,28 +14,13 @@ pub mod tags;
|
||||||
pub mod catalog;
|
pub mod catalog;
|
||||||
pub mod auth;
|
pub mod auth;
|
||||||
|
|
||||||
|
use crate::auth_storage::AuthToken;
|
||||||
|
|
||||||
/// https://docs.docker.com/registry/spec/api/#api-version-check
|
/// https://docs.docker.com/registry/spec/api/#api-version-check
|
||||||
/// full endpoint: `/v2/`
|
/// full endpoint: `/v2/`
|
||||||
pub async fn version_check(params: Option<Query<TokenAuthRequest>>, body: String) -> Response {
|
pub async fn version_check(Extension(AuthToken(_token)): Extension<AuthToken>, _state: State<Arc<AppState>>) -> Response {
|
||||||
debug!("Got body: {}", body);
|
(
|
||||||
|
|
||||||
/* (
|
|
||||||
StatusCode::OK,
|
StatusCode::OK,
|
||||||
[( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" )]
|
[( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" )]
|
||||||
) */
|
|
||||||
|
|
||||||
//Www-Authenticate: Bearer realm="https://auth.docker.io/token",service="registry.docker.io",scope="repository:samalba/my-app:pull,push"
|
|
||||||
|
|
||||||
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");/* match params {
|
|
||||||
Some(Query(params)) => format!("Bearer realm=\"http://localhost:3000/token\",scope=\"{}\"", params.scope),
|
|
||||||
None => format!("Bearer realm=\"http://localhost:3000/token\""),
|
|
||||||
}; */
|
|
||||||
|
|
||||||
(
|
|
||||||
StatusCode::UNAUTHORIZED,
|
|
||||||
[
|
|
||||||
( header::WWW_AUTHENTICATE, bearer ),
|
|
||||||
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
|
|
||||||
]
|
|
||||||
).into_response()
|
).into_response()
|
||||||
}
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
use sqlx::{Sqlite, Pool};
|
use sqlx::{Sqlite, Pool};
|
||||||
|
|
||||||
|
use crate::auth_storage::MemoryAuthStorage;
|
||||||
use crate::storage::StorageDriver;
|
use crate::storage::StorageDriver;
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
|
||||||
|
@ -9,6 +10,7 @@ pub struct AppState {
|
||||||
pub database: Pool<Sqlite>,
|
pub database: Pool<Sqlite>,
|
||||||
pub storage: Mutex<Box<dyn StorageDriver>>,
|
pub storage: Mutex<Box<dyn StorageDriver>>,
|
||||||
pub config: Config,
|
pub config: Config,
|
||||||
|
pub auth_storage: Mutex<MemoryAuthStorage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
|
@ -18,6 +20,7 @@ impl AppState {
|
||||||
database,
|
database,
|
||||||
storage,
|
storage,
|
||||||
config,
|
config,
|
||||||
|
auth_storage: Mutex::new(MemoryAuthStorage::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
use std::{collections::HashSet, ops::Deref, sync::Arc};
|
||||||
|
|
||||||
|
use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}};
|
||||||
|
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
use crate::app_state::AppState;
|
||||||
|
|
||||||
|
/// Temporary struct for storing auth information in memory.
|
||||||
|
pub struct MemoryAuthStorage {
|
||||||
|
pub valid_tokens: HashSet<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryAuthStorage {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
valid_tokens: HashSet::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AuthToken(pub String);
|
||||||
|
|
||||||
|
impl Deref for AuthToken {
|
||||||
|
type Target = String;
|
||||||
|
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Rejection = (StatusCode, HeaderMap);
|
||||||
|
|
||||||
|
pub async fn require_auth<B>(State(state): State<Arc<AppState>>, mut request: Request<B>, next: Next<B>) -> Result<Response, Rejection> {
|
||||||
|
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");
|
||||||
|
let mut failure_headers = HeaderMap::new();
|
||||||
|
failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap());
|
||||||
|
failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap());
|
||||||
|
|
||||||
|
let auth = String::from(
|
||||||
|
request.headers().get(header::AUTHORIZATION)
|
||||||
|
.ok_or((StatusCode::UNAUTHORIZED, failure_headers.clone()))?
|
||||||
|
.to_str()
|
||||||
|
.map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))?
|
||||||
|
); // TODO: Don't unwrap
|
||||||
|
|
||||||
|
let token = match auth.split_once(' ') {
|
||||||
|
Some((auth, token)) if auth == "Bearer" => token,
|
||||||
|
// This line would allow empty tokens
|
||||||
|
//_ if auth == "Bearer" => Ok(AuthToken(None)),
|
||||||
|
_ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ),
|
||||||
|
};
|
||||||
|
|
||||||
|
// If the token is not valid, return an unauthorized response
|
||||||
|
let auth_storage = state.auth_storage.lock().await;
|
||||||
|
if !auth_storage.valid_tokens.contains(token) {
|
||||||
|
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");
|
||||||
|
return Ok((
|
||||||
|
StatusCode::UNAUTHORIZED,
|
||||||
|
[
|
||||||
|
( header::WWW_AUTHENTICATE, bearer ),
|
||||||
|
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
|
||||||
|
]
|
||||||
|
).into_response());
|
||||||
|
|
||||||
|
} else {
|
||||||
|
debug!("Client successfully authenticated!");
|
||||||
|
}
|
||||||
|
drop(auth_storage);
|
||||||
|
|
||||||
|
request.extensions_mut().insert(AuthToken(String::from(token)));
|
||||||
|
|
||||||
|
Ok(next.run(request).await)
|
||||||
|
}
|
|
@ -6,6 +6,7 @@ mod storage;
|
||||||
mod byte_stream;
|
mod byte_stream;
|
||||||
mod config;
|
mod config;
|
||||||
mod query;
|
mod query;
|
||||||
|
mod auth_storage;
|
||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
|
@ -88,15 +89,15 @@ async fn main() -> std::io::Result<()> {
|
||||||
.with_max_level(Level::DEBUG)
|
.with_max_level(Level::DEBUG)
|
||||||
.init();
|
.init();
|
||||||
|
|
||||||
|
let auth_middleware = axum::middleware::from_fn_with_state(state.clone(), auth_storage::require_auth);
|
||||||
let path_middleware = axum::middleware::from_fn(change_request_paths);
|
let path_middleware = axum::middleware::from_fn(change_request_paths);
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/auth", routing::get(api::auth::auth_basic_get)
|
.route("/auth", routing::get(api::auth::auth_basic_get)
|
||||||
.post(api::auth::auth_basic_get))
|
.post(api::auth::auth_basic_get))
|
||||||
.fallback(auth_failure)
|
|
||||||
.nest("/v2", Router::new()
|
.nest("/v2", Router::new()
|
||||||
.route("/", routing::get(api::version_check))
|
.route("/", routing::get(api::version_check))
|
||||||
/* .route("/_catalog", routing::get(api::catalog::list_repositories))
|
.route("/_catalog", routing::get(api::catalog::list_repositories))
|
||||||
.route("/:name/tags/list", routing::get(api::tags::list_tags))
|
.route("/:name/tags/list", routing::get(api::tags::list_tags))
|
||||||
.nest("/:name/blobs", Router::new()
|
.nest("/:name/blobs", Router::new()
|
||||||
.route("/:digest", routing::get(api::blobs::pull_digest_get)
|
.route("/:digest", routing::get(api::blobs::pull_digest_get)
|
||||||
|
@ -114,7 +115,8 @@ async fn main() -> std::io::Result<()> {
|
||||||
.route("/:name/manifests/:reference", routing::get(api::manifests::pull_manifest_get)
|
.route("/:name/manifests/:reference", routing::get(api::manifests::pull_manifest_get)
|
||||||
.put(api::manifests::upload_manifest_put)
|
.put(api::manifests::upload_manifest_put)
|
||||||
.head(api::manifests::manifest_exists_head)
|
.head(api::manifests::manifest_exists_head)
|
||||||
.delete(api::manifests::delete_manifest)) */
|
.delete(api::manifests::delete_manifest))
|
||||||
|
.layer(auth_middleware) // require auth for ALL v2 routes
|
||||||
)
|
)
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
.layer(TraceLayer::new_for_http());
|
.layer(TraceLayer::new_for_http());
|
||||||
|
|
Loading…
Reference in New Issue