diff --git a/src/api/auth.rs b/src/api/auth.rs index 4ed39d6..9ac7847 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -15,6 +15,8 @@ use rand::Rng; use crate::{dto::{scope::Scope, user::{UserAuth, TokenInfo}}, app_state::AppState}; use crate::database::Database; +use crate::auth_storage::{unauthenticated_response, AuthDriver}; + #[derive(Deserialize, Debug)] pub struct TokenAuthRequest { user: Option, @@ -165,13 +167,14 @@ pub async fn auth_basic_get(basic_auth: Option, state: State, state: State, state: State>, Extension(auth): Extension) -> Response { // Check if the user has permission to pull, or that the repository is public - let database = &state.database; - /* if !does_user_have_permission(database, auth.user.username, name.clone(), Permission::PULL).await.unwrap() - && !database.get_repository_visibility(&name).await.unwrap() - .and_then(|v| Some(v == RepositoryVisibility::Public)) - .unwrap_or_else(|| false) { - - return get_unauthenticated_response(&state.config); - } */ - if !does_user_have_repository_permission(database, auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { + return unauthenticated_response(&state.config); } - drop(database); + drop(auth_driver); let storage = state.storage.lock().await; @@ -47,11 +40,11 @@ pub async fn digest_exists_head(Path((name, layer_digest)): Path<(String, String pub async fn pull_digest_get(Path((name, layer_digest)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { // Check if the user has permission to pull, or that the repository is public - let database = &state.database; - if !does_user_have_repository_permission(database, auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { + return unauthenticated_response(&state.config); } - drop(database); + drop(auth_driver); let storage = state.storage.lock().await; diff --git a/src/api/manifests.rs b/src/api/manifests.rs index 22b15aa..ab716c8 100644 --- a/src/api/manifests.rs +++ b/src/api/manifests.rs @@ -7,7 +7,7 @@ use axum::http::{StatusCode, HeaderMap, HeaderName, header}; use tracing::log::warn; use tracing::{debug, info}; -use crate::auth_storage::{does_user_have_permission, get_unauthenticated_response, does_user_have_repository_permission}; +use crate::auth_storage::{unauthenticated_response, AuthDriver}; use crate::app_state::AppState; use crate::database::Database; use crate::dto::RepositoryVisibility; @@ -16,9 +16,11 @@ use crate::dto::manifest::Manifest; use crate::dto::user::{UserAuth, Permission}; pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension, body: String) -> Response { - if !does_user_have_permission(&state.database, auth.user.username, name.clone(), Permission::PUSH).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { + return unauthenticated_response(&state.config); } + drop(auth_driver); // Calculate the sha256 digest for the manifest. let calculated_hash = sha256::digest(body.clone()); @@ -63,11 +65,11 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)> pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { // Check if the user has permission to pull, or that the repository is public - let database = &state.database; - if !does_user_have_repository_permission(database, auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { + return unauthenticated_response(&state.config); } - drop(database); + drop(auth_driver); let database = &state.database; let digest = match Digest::is_digest(&reference) { @@ -106,11 +108,11 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { // Check if the user has permission to pull, or that the repository is public - let database = &state.database; - if !does_user_have_repository_permission(database, auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { + return unauthenticated_response(&state.config); } - drop(database); + drop(auth_driver); // Get the digest from the reference path. let database = &state.database; @@ -146,9 +148,11 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String) } pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, headers: HeaderMap, state: State>, Extension(auth): Extension) -> Response { - if !does_user_have_permission(&state.database, auth.user.username, name.clone(), Permission::PUSH).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { + return unauthenticated_response(&state.config); } + drop(auth_driver); let _authorization = headers.get("Authorization").unwrap(); // TODO: use authorization header diff --git a/src/api/uploads.rs b/src/api/uploads.rs index 38c582d..30e1a04 100644 --- a/src/api/uploads.rs +++ b/src/api/uploads.rs @@ -12,14 +12,15 @@ use futures::StreamExt; use tracing::{debug, warn}; use crate::app_state::AppState; -use crate::auth_storage::{does_user_have_permission, get_unauthenticated_response}; +use crate::auth_storage::{unauthenticated_response, AuthDriver}; use crate::byte_stream::ByteStream; use crate::database::Database; use crate::dto::user::{UserAuth, Permission, RegistryUser, RegistryUserType}; /// Starting an upload pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth): Extension, state: State>) -> Response { - if does_user_have_permission(&state.database, auth.user.username, name.clone(), Permission::PUSH).await.unwrap() { + let auth_driver = state.auth_checker.lock().await; + if auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { debug!("Upload requested"); let uuid = uuid::Uuid::new_v4(); @@ -34,13 +35,15 @@ pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth) ).into_response(); } - get_unauthenticated_response(&state.config) + unauthenticated_response(&state.config) } pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, String)>, Extension(auth): Extension, state: State>, mut body: BodyStream) -> Response { - if !does_user_have_permission(&state.database, auth.user.username, name.clone(), Permission::PUSH).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { + return unauthenticated_response(&state.config); } + drop(auth_driver); let storage = state.storage.lock().await; let current_size = storage.digest_length(&layer_uuid).await.unwrap(); @@ -95,9 +98,11 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, } pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, String)>, Query(query): Query>, Extension(auth): Extension, state: State>, body: Bytes) -> Response { - if !does_user_have_permission(&state.database, auth.user.username, name.clone(), Permission::PUSH).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { + return unauthenticated_response(&state.config); } + drop(auth_driver); let digest = query.get("digest").unwrap(); @@ -122,9 +127,11 @@ pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, S } pub async fn cancel_upload_delete(Path((name, layer_uuid)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { - if !does_user_have_permission(&state.database, auth.user.username, name.clone(), Permission::PUSH).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { + return unauthenticated_response(&state.config); } + drop(auth_driver); let storage = state.storage.lock().await; storage.delete_digest(&layer_uuid).await.unwrap(); @@ -134,9 +141,11 @@ pub async fn cancel_upload_delete(Path((name, layer_uuid)): Path<(String, String } pub async fn check_upload_status_get(Path((name, layer_uuid)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { - if !does_user_have_permission(&state.database, auth.user.username, name.clone(), Permission::PUSH).await.unwrap() { - return get_unauthenticated_response(&state.config); + let auth_driver = state.auth_checker.lock().await; + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { + return unauthenticated_response(&state.config); } + drop(auth_driver); let storage = state.storage.lock().await; let ending = storage.digest_length(&layer_uuid).await.unwrap().unwrap_or(0); diff --git a/src/app_state.rs b/src/app_state.rs index 9496963..82577e2 100644 --- a/src/app_state.rs +++ b/src/app_state.rs @@ -1,6 +1,6 @@ use sqlx::{Sqlite, Pool}; -use crate::auth_storage::MemoryAuthStorage; +use crate::auth_storage::AuthDriver; use crate::storage::StorageDriver; use crate::config::Config; @@ -10,17 +10,17 @@ pub struct AppState { pub database: Pool, pub storage: Mutex>, pub config: Config, - pub auth_storage: Mutex, + pub auth_checker: Mutex>, } impl AppState { - pub fn new(database: Pool, storage: Mutex>, config: Config) -> Self + pub fn new(database: Pool, storage: Mutex>, config: Config, auth_checker: Mutex>) -> Self { Self { database, storage, config, - auth_storage: Mutex::new(MemoryAuthStorage::new()), + auth_checker, } } } \ No newline at end of file diff --git a/src/auth_storage.rs b/src/auth_storage.rs index 6db6f14..d8efd77 100644 --- a/src/auth_storage.rs +++ b/src/auth_storage.rs @@ -2,13 +2,66 @@ use std::{collections::HashSet, ops::Deref, sync::Arc}; use axum::{extract::{State, Path}, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}}; +use sqlx::{Pool, Sqlite}; use tracing::debug; use crate::{app_state::AppState, dto::{user::{Permission, RegistryUserType}, RepositoryVisibility}, config::Config}; use crate::database::Database; +use async_trait::async_trait; + +#[async_trait] +pub trait AuthDriver: Send + Sync { + /// Checks if a user has permission to do something in a repository. + /// + /// * `username`: Name of the user. + /// * `repository`: Name of the repository. + /// * `permissions`: Permission to check for. + /// * `required_visibility`: Specified if there is a specific visibility of the repository that will give the user permission. + async fn user_has_permission(&self, username: String, repository: String, permission: Permission, required_visibility: Option) -> anyhow::Result; + async fn verify_user_login(&self, username: String, password: String) -> anyhow::Result; +} + +#[async_trait] +impl AuthDriver for Pool { + async fn user_has_permission(&self, username: String, repository: String, permission: Permission, required_visibility: Option) -> anyhow::Result { + let allowed_to = { + match self.get_user_registry_type(username.clone()).await? { + Some(RegistryUserType::Admin) => true, + _ => { + if let Some(perms) = self.get_user_repo_permissions(username, repository.clone()).await? { + if perms.has_permission(permission) { + return Ok(true); + } + } + + if let Some(vis) = required_visibility { + if let Some(repo_vis) = self.get_repository_visibility(&repository).await? { + if vis == repo_vis { + return Ok(true); + } + } + } + + false + } + /* match database.get_user_repo_permissions(username, repository).await.unwrap() { + Some(perms) => if perms.has_permission(permission) { true } else { false }, + _ => false, + } */ + } + }; + + Ok(allowed_to) + } + + async fn verify_user_login(&self, username: String, password: String) -> anyhow::Result { + Database::verify_user_login(self, username, password).await + } +} + /// Temporary struct for storing auth information in memory. -pub struct MemoryAuthStorage { +/* pub struct MemoryAuthStorage { pub valid_tokens: HashSet, } @@ -18,7 +71,7 @@ impl MemoryAuthStorage { valid_tokens: HashSet::new(), } } -} +} */ #[derive(Clone)] pub struct AuthToken(pub String); @@ -73,49 +126,8 @@ pub async fn require_auth(State(state): State>, mut request: Re } } -pub async fn does_user_have_permission(database: &impl Database, username: String, repository: String, permission: Permission) -> anyhow::Result { - does_user_have_repository_permission(database, username, repository, permission, None).await -} - -/// Checks if a user has permission to do something in a repository. -/// -/// * `database`: Database connection. -/// * `username`: Name of the user. -/// * `repository`: Name of the repository. -/// * `permissions`: Permission to check for. -/// * `required_visibility`: Specified if there is a specific visibility of the repository that will give the user permission. -pub async fn does_user_have_repository_permission(database: &impl Database, username: String, repository: String, permission: Permission, required_visibility: Option) -> anyhow::Result { - let allowed_to = { - match database.get_user_registry_type(username.clone()).await? { - Some(RegistryUserType::Admin) => true, - _ => { - if let Some(perms) = database.get_user_repo_permissions(username, repository.clone()).await? { - if perms.has_permission(permission) { - return Ok(true); - } - } - - if let Some(vis) = required_visibility { - if let Some(repo_vis) = database.get_repository_visibility(&repository).await? { - if vis == repo_vis { - return Ok(true); - } - } - } - - false - } - /* match database.get_user_repo_permissions(username, repository).await.unwrap() { - Some(perms) => if perms.has_permission(permission) { true } else { false }, - _ => false, - } */ - } - }; - - Ok(allowed_to) -} - -pub fn get_unauthenticated_response(config: &Config) -> Response { +#[inline(always)] +pub fn unauthenticated_response(config: &Config) -> Response { let bearer = format!("Bearer realm=\"{}/auth\"", config.get_url()); ( StatusCode::UNAUTHORIZED, diff --git a/src/database/schemas/schema.sql b/src/database/schemas/schema.sql index cd69708..84b10c9 100644 --- a/src/database/schemas/schema.sql +++ b/src/database/schemas/schema.sql @@ -62,4 +62,7 @@ CREATE TABLE IF NOT EXISTS user_tokens ( username TEXT NOT NULL, expiry BIGINT NOT NULL, created_at BIGINT NOT NULL -); \ No newline at end of file +); + +-- create admin user +INSERT OR IGNORE INTO users (username, email, password_hash, password_salt) VALUES ('admin', 'admin@example.com', '$2b$12$x5ECk0jUmOSfBWxW52wsyOmFxNZkwc2J9FH225if4eBnQYUvYLYYq', 'x5ECk0jUmOSfBWxW52wsyO'); \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index e4270db..bf317fa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,7 @@ use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; +use auth_storage::AuthDriver; use axum::http::{Request, StatusCode, header, HeaderName}; use axum::middleware::Next; use axum::response::{Response, IntoResponse}; @@ -66,11 +67,12 @@ async fn main() -> std::io::Result<()> { pool.create_schema().await.unwrap(); let storage_driver: Mutex> = Mutex::new(Box::new(FilesystemDriver::new("registry/blobs"))); + let auth_driver: Mutex> = Mutex::new(Box::new(pool.clone())); let config = Config::new().expect("Failure to parse config!"); let app_addr = SocketAddr::from_str(&format!("{}:{}", config.listen_address, config.listen_port)).unwrap(); - let state = Arc::new(AppState::new(pool, storage_driver, config)); + let state = Arc::new(AppState::new(pool, storage_driver, config, auth_driver)); tracing_subscriber::fmt() .with_max_level(Level::DEBUG)