From 6f22e84969bb02855f802d583bc8168cde3551c7 Mon Sep 17 00:00:00 2001 From: SeanOMik Date: Sat, 22 Jul 2023 00:20:17 -0400 Subject: [PATCH] Change /auth endpoint to /token, dont allow post for /token --- src/api/auth.rs | 24 +++++++++++++++++++----- src/auth/mod.rs | 4 ++-- src/dto/user.rs | 2 +- src/main.rs | 4 ++-- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/api/auth.rs b/src/api/auth.rs index a19b60a..745cabf 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -55,6 +55,7 @@ pub struct AuthForm { #[derive(Deserialize, Serialize, Debug)] pub struct AuthResponse { token: String, + access_token: Option, expires_in: u32, issued_at: String, } @@ -90,6 +91,17 @@ fn create_jwt_token(jwt_key: String, account: Option<&str>, scopes: Vec) Ok(TokenInfo::new(token_str, expiration, now)) } +pub async fn auth_basic_post() -> Result { + return Ok(( + StatusCode::METHOD_NOT_ALLOWED, + [ + (header::CONTENT_TYPE, "application/json"), + (header::ALLOW, "Allow: GET, HEAD, OPTIONS"), + ], + "{\"detail\": \"Method \\\"POST\\\" not allowed.\"}" + ).into_response()); +} + pub async fn auth_basic_get( basic_auth: Option, state: State>, @@ -197,6 +209,7 @@ pub async fn auth_basic_get( let auth_response = AuthResponse { token: token_str.clone(), + access_token: Some(token_str.clone()), expires_in: 86400, // 1 day issued_at: now_format, }; @@ -276,15 +289,15 @@ pub async fn auth_basic_get( debug!("Constructed auth request"); - if auth.account.is_none() { - debug!("Account is none"); + if auth.user.is_none() { + debug!("User is none"); } if auth.password.is_none() { debug!("Password is none"); } - if let (Some(account), Some(password)) = (&auth.user, auth.password) { + if let (Some(account), Some(password)) = (auth.user, auth.password) { // Ensure that the password is correct let mut auth_driver = state.auth_checker.lock().await; if !auth_driver @@ -310,7 +323,7 @@ pub async fn auth_basic_get( debug!("User password is correct"); let now = SystemTime::now(); - let token = create_jwt_token(state.config.jwt_key.clone(), Some(account), vec![]) + let token = create_jwt_token(state.config.jwt_key.clone(), Some(&account), vec![]) .map_err(|_| { error!("Failed to create jwt token!"); @@ -327,7 +340,8 @@ pub async fn auth_basic_get( // Construct the auth response let auth_response = AuthResponse { token: token_str.clone(), - expires_in: 20, + access_token: Some(token_str.clone()), + expires_in: 86400, // 1 day issued_at: now_format, }; diff --git a/src/auth/mod.rs b/src/auth/mod.rs index 5b96634..2fae3dc 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -89,8 +89,8 @@ type Rejection = (StatusCode, HeaderMap); #[inline(always)] pub fn auth_challenge_response(config: &Config, scope: Option) -> Response { let bearer = match scope { - Some(scope) => format!("Bearer realm=\"{}/auth\",scope=\"{}\"", config.url(), scope), - None => format!("Bearer realm=\"{}/auth\"", config.url()) + Some(scope) => format!("Bearer realm=\"{}/token\",scope=\"{}\"", config.url(), scope), + None => format!("Bearer realm=\"{}/token\"", config.url()) }; debug!("responding with www-authenticate header of: \"{}\"", bearer); diff --git a/src/dto/user.rs b/src/dto/user.rs index 01bbaf7..74a36de 100644 --- a/src/dto/user.rs +++ b/src/dto/user.rs @@ -130,7 +130,7 @@ impl FromRequestParts> for UserAuth { type Rejection = (StatusCode, HeaderMap); async fn from_request_parts(parts: &mut Parts, state: &Arc) -> Result { - let bearer = format!("Bearer realm=\"{}/auth\"", state.config.url()); + let bearer = format!("Bearer realm=\"{}/token\"", state.config.url()); let mut failure_headers = HeaderMap::new(); failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap()); failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap()); diff --git a/src/main.rs b/src/main.rs index caaa047..73efc2e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -123,8 +123,8 @@ async fn main() -> anyhow::Result<()> { let path_middleware = axum::middleware::from_fn(change_request_paths); let app = Router::new() - .route("/auth", routing::get(api::auth::auth_basic_get) - .post(api::auth::auth_basic_get)) + .route("/token", routing::get(api::auth::auth_basic_get) + .post(api::auth::auth_basic_post)) .nest("/v2", Router::new() .route("/", routing::get(api::version_check)) .route("/_catalog", routing::get(api::catalog::list_repositories))