diff --git a/Cargo.toml b/Cargo.toml index 953a221..538c4b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -53,4 +53,4 @@ rand = "0.8.5" bcrypt = "0.14.0" bitflags = "2.2.1" ldap3 = "0.11.1" -lazy_static = "1.4.0" \ No newline at end of file +lazy_static = "1.4.0" diff --git a/src/api/auth.rs b/src/api/auth.rs index 0cd0626..f85cbaf 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -1,5 +1,5 @@ use std::{ - collections::{BTreeMap, HashMap}, + collections::HashMap, sync::Arc, time::SystemTime, }; @@ -11,7 +11,7 @@ use axum::{ Form, }; use axum_auth::AuthBasic; -use chrono::{DateTime, Days, Duration, Utc}; +use chrono::{DateTime, Days, Utc}; use serde::{Deserialize, Serialize}; use tracing::{debug, error, info, span, Level}; @@ -33,7 +33,7 @@ use crate::{ use crate::auth::auth_challenge_response; -#[derive(Deserialize, Debug)] +#[derive(Debug)] pub struct TokenAuthRequest { user: Option, password: Option, @@ -157,13 +157,13 @@ pub async fn auth_basic_get( ScopeType::Repository => { // check repository visibility let database = &state.database; - match database.get_repository_visibility(&scope.path).await { + match database.get_repository_visibility(&scope.name).await { Ok(Some(RepositoryVisibility::Public)) => res.push(Ok(true)), Ok(_) => res.push(Ok(false)), Err(e) => { error!( "Failure to check repository visibility for {}! Err: {}", - scope.path, e + scope.name, e ); res.push(Err(StatusCode::INTERNAL_SERVER_ERROR)); @@ -201,8 +201,8 @@ pub async fn auth_basic_get( issued_at: now_format, }; - let json_str = - serde_json::to_string(&auth_response).map_err(|_| StatusCode::BAD_REQUEST)?; + let json_str = serde_json::to_string(&auth_response) + .map_err(|_| StatusCode::BAD_REQUEST)?; debug!("Created anonymous token for public scopes!"); @@ -213,8 +213,7 @@ pub async fn auth_basic_get( (header::AUTHORIZATION, &format!("Bearer {}", token_str)), ], json_str, - ) - .into_response()); + ).into_response()); } else { info!("Auth failure! Auth was not provided in either AuthBasic or Form!"); @@ -279,10 +278,15 @@ pub async fn auth_basic_get( { debug!("Authentication failed, incorrect password!"); + // TODO: Multiple scopes + let scope = auth.scope + .first() + .and_then(|s| Some(s.clone())); + // TODO: Dont unwrap, find a way to return multiple scopes return Ok(auth_challenge_response( &state.config, - Some(auth.scope.first().unwrap().clone()), + scope, )); } drop(auth_driver); diff --git a/src/api/manifests.rs b/src/api/manifests.rs index f553029..15cb3a1 100644 --- a/src/api/manifests.rs +++ b/src/api/manifests.rs @@ -6,13 +6,12 @@ use axum::http::{StatusCode, HeaderName, header}; use tracing::log::warn; use tracing::{debug, info}; -use crate::auth::access_denied_response; use crate::app_state::AppState; use crate::database::Database; use crate::dto::RepositoryVisibility; use crate::dto::digest::Digest; use crate::dto::manifest::Manifest; -use crate::dto::user::{UserAuth, Permission}; +use crate::dto::user::UserAuth; use crate::error::AppError; pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State>, auth: UserAuth, body: String) -> Result { @@ -20,10 +19,13 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)> let calculated_hash = sha256::digest(body.clone()); let calculated_digest = format!("sha256:{}", calculated_hash); + // anonymous users wouldn't be able to get to this point, so it should be safe to unwrap. + let user = auth.user.unwrap(); + let database = &state.database; // Create the image repository and save the image manifest. This repository will be private by default - database.save_repository(&name, RepositoryVisibility::Private, Some(auth.user.email), None).await?; + database.save_repository(&name, RepositoryVisibility::Private, Some(user.email), None).await?; database.save_manifest(&name, &calculated_digest, &body).await?; // If the reference is not a digest, then it must be a tag name. @@ -57,20 +59,7 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)> } } -pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State>, auth: Option) -> Result { - // Check if the user has permission to pull, or that the repository is public - if let Some(auth) = auth { - let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? { - return Ok(access_denied_response(&state.config)); - } - } else { - let database = &state.database; - if database.get_repository_visibility(&name).await? != Some(RepositoryVisibility::Public) { - return Ok(access_denied_response(&state.config)); - } - } - +pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State>) -> Result { let database = &state.database; let digest = match Digest::is_digest(&reference) { true => reference.clone(), @@ -106,21 +95,8 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, ).into_response()) } -pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State>, auth: Option) -> Result { - // Check if the user has permission to pull, or that the repository is public - if let Some(auth) = auth { - let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? { - return Ok(access_denied_response(&state.config)); - } - drop(auth_driver); - } else { - let database = &state.database; - if database.get_repository_visibility(&name).await? != Some(RepositoryVisibility::Public) { - return Ok(access_denied_response(&state.config)); - } - } - +pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State>) -> Result { + debug!("start of head"); // Get the digest from the reference path. let database = &state.database; let digest = match Digest::is_digest(&reference) { @@ -133,6 +109,7 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String) } } }; + debug!("found digest: {}", digest); let manifest_content = database.get_manifest(&name, &digest).await?; if manifest_content.is_none() { @@ -142,6 +119,8 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String) } let manifest_content = manifest_content.unwrap(); + debug!("got content"); + Ok(( StatusCode::OK, [ @@ -154,13 +133,7 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String) ).into_response()) } -pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, state: State>, auth: UserAuth) -> Result { - let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? { - return Ok(access_denied_response(&state.config)); - } - drop(auth_driver); - +pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, state: State>) -> Result { let database = &state.database; let digest = match Digest::is_digest(&reference) { true => { diff --git a/src/api/mod.rs b/src/api/mod.rs index e7e0113..097efa1 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -17,7 +17,7 @@ pub mod auth; /// full endpoint: `/v2/` pub async fn version_check(_state: State>) -> Response { ( - StatusCode::UNAUTHORIZED, + StatusCode::OK, [ ( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" ), ] diff --git a/src/auth/ldap_driver.rs b/src/auth/ldap_driver.rs index b66933b..486dd26 100644 --- a/src/auth/ldap_driver.rs +++ b/src/auth/ldap_driver.rs @@ -3,7 +3,7 @@ use ldap3::{LdapConnAsync, Ldap, Scope, SearchEntry}; use sqlx::{Pool, Sqlite}; use tracing::{debug, warn}; -use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource, RegistryUserType, self}, RepositoryVisibility}, database::Database}; +use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource, RegistryUserType}, RepositoryVisibility}, database::Database}; use super::AuthDriver; diff --git a/src/auth/mod.rs b/src/auth/mod.rs index df74c65..5b96634 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,6 +1,6 @@ pub mod ldap_driver; -use std::{ops::Deref, sync::Arc}; +use std::sync::Arc; use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request, Method}, middleware::Next, response::{Response, IntoResponse}}; @@ -82,59 +82,8 @@ where Ok(false) } -#[derive(Clone)] -pub struct AuthToken(pub String); - -impl Deref for AuthToken { - type Target = String; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - type Rejection = (StatusCode, HeaderMap); -pub async fn require_auth(State(state): State>, mut request: Request, next: Next) -> Result { - let bearer = format!("Bearer realm=\"{}/auth\"", state.config.url()); - let mut failure_headers = HeaderMap::new(); - failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap()); - failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap()); - - let auth = String::from( - request.headers().get(header::AUTHORIZATION) - .ok_or((StatusCode::UNAUTHORIZED, failure_headers.clone()))? - .to_str() - .map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))? - ); - - let token = match auth.split_once(' ') { - Some((auth, token)) if auth == "Bearer" => token, - // This line would allow empty tokens - //_ if auth == "Bearer" => Ok(AuthToken(None)), - _ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ), - }; - - // If the token is not valid, return an unauthorized response - let database = &state.database; - if let Ok(Some(user)) = database.verify_user_token(token.to_string()).await { - debug!("Authenticated user through middleware: {}", user.user.username); - - request.extensions_mut().insert(user); - - Ok(next.run(request).await) - } else { - let bearer = format!("Bearer realm=\"{}/auth\"", state.config.url()); - Ok(( - StatusCode::UNAUTHORIZED, - [ - ( header::WWW_AUTHENTICATE, bearer ), - ( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() ) - ] - ).into_response()) - } -} - /// Creates a response with an Unauthorized (401) status code. /// The www-authenticate header is set to notify the client of where to authorize with. #[inline(always)] @@ -173,9 +122,17 @@ pub async fn check_auth(State(state): State>, auth: Option = url.split("/").skip(1).collect(); @@ -216,14 +173,28 @@ pub async fn check_auth(State(state): State>, auth: Option None, }; - match auth_checker.user_has_permission(auth.user.email.clone(), target_name.clone(), permission, vis).await { - Ok(false) => return Ok(auth_challenge_response(config, Some(scope))), - Ok(true) => { }, - Err(e) => { - error!("Error when checking user permissions! {}", e); + if let Some(user) = &auth.user { + match auth_checker.user_has_permission(user.email.clone(), target_name.clone(), permission, vis).await { + Ok(false) => return Ok(auth_challenge_response(config, Some(scope))), + Ok(true) => { }, + Err(e) => { + error!("Error when checking user permissions! {}", e); - return Err((StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new())); - }, + return Err((StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new())); + }, + } + } else { + // anonymous users can ONLY pull from public repos + if permission != Permission::PULL { + return Ok(access_denied_response(config)); + } + + // ensure the repo is public + let database = &state.database; + if let Some(RepositoryVisibility::Private) = database.get_repository_visibility(&target_name).await + .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()))? { + return Ok(access_denied_response(config)); + } } } } else { diff --git a/src/config.rs b/src/config.rs index 55a009d..549a8a8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,4 +1,3 @@ -use anyhow::anyhow; use figment::{Figment, providers::{Env, Toml, Format}}; use figment_cliarg_provider::FigmentCliArgsProvider; use serde::{Deserialize, Deserializer}; diff --git a/src/database/mod.rs b/src/database/mod.rs index 1cf3c4d..97c034c 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,13 +1,11 @@ use async_trait::async_trait; -use hmac::{Hmac, digest::KeyInit}; use rand::{Rng, distributions::Alphanumeric}; -use sha2::Sha256; -use sqlx::{Sqlite, Pool, sqlite::SqliteError}; +use sqlx::{Sqlite, Pool}; use tracing::{debug, warn}; -use chrono::{DateTime, Utc, NaiveDateTime, TimeZone}; +use chrono::{DateTime, Utc, NaiveDateTime}; -use crate::dto::{Tag, user::{User, RepositoryPermissions, RegistryUserType, Permission, UserAuth, TokenInfo, LoginSource}, RepositoryVisibility}; +use crate::dto::{Tag, user::{User, RepositoryPermissions, RegistryUserType, Permission, UserAuth, LoginSource}, RepositoryVisibility}; #[async_trait] pub trait Database { @@ -65,6 +63,7 @@ pub trait Database { async fn get_user_repo_permissions(&self, email: String, repository: String) -> anyhow::Result>; async fn get_user_registry_usertype(&self, email: String) -> anyhow::Result>; async fn store_user_token(&self, token: String, email: String, expiry: DateTime, created_at: DateTime) -> anyhow::Result<()>; + #[deprecated = "Tokens are now verified using a secret"] async fn verify_user_token(&self, token: String) -> anyhow::Result>; } @@ -90,6 +89,10 @@ impl Database for Pool { } }; + sqlx::query(include_str!("schemas/schema.sql")) + .execute(self).await?; + debug!("Created database schema"); + if row.is_none() || row.unwrap().0 == 0 { let jwt_sec: String = rand::thread_rng() .sample_iter(&Alphanumeric) @@ -99,9 +102,9 @@ impl Database for Pool { // create schema // TODO: Check if needed - sqlx::query(include_str!("schemas/schema.sql")) + /* sqlx::query(include_str!("schemas/schema.sql")) .execute(self).await?; - debug!("Created database schema"); + debug!("Created database schema"); */ sqlx::query("INSERT INTO orca(orca_version, schema_version, jwt_secret) VALUES (?, ?, ?)") .bind(orca_version) @@ -415,6 +418,7 @@ impl Database for Pool { } async fn get_user(&self, email: String) -> anyhow::Result> { + debug!("getting user"); let email = email.to_lowercase(); let row: (String, u32) = match sqlx::query_as("SELECT username, login_source FROM users WHERE email = ?") .bind(email.clone()) @@ -559,50 +563,7 @@ impl Database for Pool { Ok(()) } - async fn verify_user_token(&self, token: String) -> anyhow::Result> { - let token_row: (String, i64, i64,) = match sqlx::query_as("SELECT email, expiry, created_at FROM user_tokens WHERE token = ?") - .bind(token.clone()) - .fetch_one(self).await { - Ok(row) => row, - Err(e) => match e { - sqlx::Error::RowNotFound => { - return Ok(None) - }, - _ => { - return Err(anyhow::Error::new(e)); - } - } - }; - - let (email, expiry, created_at) = (token_row.0, token_row.1, token_row.2); - - let user_row: (String, u32) = match sqlx::query_as("SELECT username, login_source FROM users WHERE email = ?") - .bind(email.clone()) - .fetch_one(self).await { - Ok(row) => row, - Err(e) => match e { - sqlx::Error::RowNotFound => { - return Ok(None) - }, - _ => { - return Err(anyhow::Error::new(e)); - } - } - }; - - /* let user_row: (String, u32) = sqlx::query_as("SELECT email, login_source FROM users WHERE email = ?") - .bind(email.clone()) - .fetch_one(self).await?; */ - - let (expiry, created_at) = (Utc.timestamp_millis_opt(expiry).single(), Utc.timestamp_millis_opt(created_at).single()); - if let (Some(expiry), Some(created_at)) = (expiry, created_at) { - let user = User::new(user_row.0, email, LoginSource::try_from(user_row.1)?); - let token = TokenInfo::new(token, expiry, created_at); - let auth = UserAuth::new(user, token); - - Ok(Some(auth)) - } else { - Ok(None) - } + async fn verify_user_token(&self, _token: String) -> anyhow::Result> { + panic!("ERR: Database::verify_user_token is deprecated!") } } \ No newline at end of file diff --git a/src/dto/scope.rs b/src/dto/scope.rs index cf3cea2..e6ac9ea 100644 --- a/src/dto/scope.rs +++ b/src/dto/scope.rs @@ -1,5 +1,5 @@ use anyhow::anyhow; -use serde::{Deserialize, de::Visitor, Serialize}; +use serde::{Deserialize, Serialize}; use std::fmt; @@ -40,10 +40,11 @@ impl fmt::Display for Action { } } -#[derive(Default, Debug, Clone, Serialize, PartialEq, Eq)] +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct Scope { + #[serde(rename = "type")] pub scope_type: ScopeType, - pub path: String, + pub name: String, pub actions: Vec, } @@ -51,7 +52,7 @@ impl Scope { pub fn new(scope_type: ScopeType, path: String, actions: &[Action]) -> Self { Self { scope_type, - path, + name: path, actions: actions.to_vec(), } } @@ -65,7 +66,7 @@ impl fmt::Display for Scope { .collect::>() .join(","); - write!(f, "{}:{}:{}", self.scope_type, self.path, actions) + write!(f, "{}:{}:{}", self.scope_type, self.name, actions) } } @@ -96,7 +97,7 @@ impl TryFrom<&str> for Scope { Ok(Scope { scope_type, - path: String::from(path), + name: String::from(path), actions }) } else { @@ -104,68 +105,4 @@ impl TryFrom<&str> for Scope { //Err(serde::de::Error::custom("Malformed scope string!")) } } -} - -pub struct ScopeVisitor { - -} - -impl<'de> Visitor<'de> for ScopeVisitor { - type Value = Scope; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a Scope in the format of `repository:samalba/my-app:pull,push`.") - } - - fn visit_str(self, val: &str) -> Result - where - E: serde::de::Error { - println!("Start of visit_str!"); - - let res = match Scope::try_from(val) { - Ok(val) => Ok(val), - Err(e) => Err(serde::de::Error::custom(format!("{}", e))) - }; - - res - - - - /* let splits: Vec<&str> = val.split(":").collect(); - if splits.len() == 3 { - let scope_type = match splits[0] { - "repository" => ScopeType::Repository, - _ => { - return Err(serde::de::Error::custom(format!("Invalid scope type: `{}`!", splits[0]))); - } - }; - - let path = splits[1]; - - let actions: Result, E> = splits[2] - .split(",") - .map(|a| match a { - "pull" => Ok(Action::Pull), - "push" => Ok(Action::Push), - _ => Err(serde::de::Error::custom(format!("Invalid action: `{}`!", a))), - }).collect(); - let actions = actions?; - - Ok(Scope { - scope_type, - path: String::from(path), - actions - }) - } else { - Err(serde::de::Error::custom("Malformed scope string!")) - } */ - } -} - -impl<'de> Deserialize<'de> for Scope { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de> { - deserializer.deserialize_str(ScopeVisitor {}) - } } \ No newline at end of file diff --git a/src/dto/user.rs b/src/dto/user.rs index 92b9735..01bbaf7 100644 --- a/src/dto/user.rs +++ b/src/dto/user.rs @@ -4,7 +4,10 @@ use async_trait::async_trait; use axum::{http::{StatusCode, header, HeaderName, HeaderMap, request::Parts}, extract::FromRequestParts}; use bitflags::bitflags; use chrono::{DateTime, Utc}; +use hmac::{Hmac, digest::KeyInit}; +use jwt::VerifyWithKey; use serde::{Deserialize, Serialize}; +use sha2::Sha256; use tracing::debug; use crate::{app_state::AppState, database::Database}; @@ -109,12 +112,12 @@ impl TokenInfo { #[derive(Clone, Debug, PartialEq)] pub struct UserAuth { - pub user: User, - pub token: TokenInfo, + pub user: Option, + pub token: AuthToken, } impl UserAuth { - pub fn new(user: User, token: TokenInfo) -> Self { + pub fn new(user: Option, token: AuthToken) -> Self { Self { user, token, @@ -132,8 +135,6 @@ impl FromRequestParts> for UserAuth { failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap()); failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap()); - debug!("starting UserAuth request parts"); - let auth = String::from( parts.headers .get(header::AUTHORIZATION) @@ -142,8 +143,6 @@ impl FromRequestParts> for UserAuth { .map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))? ); - debug!("got auth header"); - let token = match auth.split_once(' ') { Some((auth, token)) if auth == "Bearer" => token, // This line would allow empty tokens @@ -151,22 +150,42 @@ impl FromRequestParts> for UserAuth { _ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ), }; - debug!("got token"); - // If the token is not valid, return an unauthorized response - let database = &state.database; - if let Ok(Some(user)) = database.verify_user_token(token.to_string()).await { - debug!("Authenticated user through request extractor: {}", user.user.username); + let jwt_key: Hmac = Hmac::new_from_slice(state.config.jwt_key.as_bytes()) + .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()) )?; - Ok(user) - } else { - debug!("Failure to verify user token, responding with auth realm"); + match VerifyWithKey::::verify_with_key(token, &jwt_key) { + Ok(token) => { + // attempt to get the user + if !token.subject.is_empty() { + let database = &state.database; + if let Ok(Some(user)) = database.get_user(token.subject.clone()).await { + return Ok(UserAuth::new(Some(user), token)); + } else { + debug!("failure to get user from token: {:?}", token); + } + } else { + return Ok(UserAuth::new(None, token)); + } - Err(( - StatusCode::UNAUTHORIZED, - failure_headers - )) + /* let database = &state.database; + if let Ok(user) = database.get_user(token.subject.clone()).await { + return Ok(UserAuth::new(user, token)); + } else { + debug!("failure to get user from token: {:?}", token); + } */ + }, + Err(e) => { + debug!("Failure to verify user token: '{}'", e); + } } + + debug!("Failure to verify user token, responding with auth realm"); + + Err(( + StatusCode::UNAUTHORIZED, + failure_headers + )) } }