diff --git a/docs/todo.md b/docs/todo.md index ce7a786..8897726 100644 --- a/docs/todo.md +++ b/docs/todo.md @@ -1,5 +1,9 @@ - [x] slashes in repo names -- [ ] Simple auth +- [x] Simple auth +- [x] ldap auth +- [ ] permission stuff + - [ ] Only allow users to create repositories if its the same name as their username, or if they're an admin + - [ ] Only allow users to pull from their own repositories - [ ] postgresql - [ ] prometheus metrics - [x] streaming layer bytes into providers diff --git a/src/api/auth.rs b/src/api/auth.rs index 8c1fe88..3b53e3f 100644 --- a/src/api/auth.rs +++ b/src/api/auth.rs @@ -56,9 +56,9 @@ fn create_jwt_token(account: &str) -> anyhow::Result { claims.insert("subject", &account); //claims.insert("audience", auth.service); - let not_before = format!("{}", now_secs - 10); + let not_before = format!("{}", now_secs); let issued_at = format!("{}", now_secs); - let expiration = format!("{}", now_secs + 20); + let expiration = format!("{}", now_secs + 86400); // 1 day claims.insert("notbefore", ¬_before); claims.insert("issuedat", &issued_at); claims.insert("expiration", &expiration); // TODO: 20 seconds expiry for testing @@ -75,7 +75,7 @@ fn create_jwt_token(account: &str) -> anyhow::Result { Ok(TokenInfo::new(token_str, expiration, issued_at)) } -pub async fn auth_basic_get(basic_auth: Option, state: State>, Query(params): Query>, form: Option>) -> Response { +pub async fn auth_basic_get(basic_auth: Option, state: State>, Query(params): Query>, form: Option>) -> Result { let mut auth = TokenAuthRequest { user: None, password: None, @@ -117,7 +117,7 @@ pub async fn auth_basic_get(basic_auth: Option, state: State, state: State, state: State { + auth.scope.push(scope); + }, + Err(_) => { + return Err(StatusCode::BAD_REQUEST); + } + } } // Get offline token and attempt to convert it to a boolean @@ -168,17 +175,19 @@ pub async fn auth_basic_get(basic_auth: Option, state: State, state: State, state: State>, Extension(auth): Extension) -> Response { +pub async fn digest_exists_head(Path((name, layer_digest)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Result { // Check if the user has permission to pull, or that the repository is public let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); let storage = state.storage.lock().await; - if storage.has_digest(&layer_digest).await.unwrap() { - if let Some(size) = storage.digest_length(&layer_digest).await.unwrap() { - return ( + if storage.has_digest(&layer_digest).await? { + if let Some(size) = storage.digest_length(&layer_digest).await? { + return Ok(( StatusCode::OK, [ (header::CONTENT_LENGTH, size.to_string()), (HeaderName::from_static("docker-content-digest"), layer_digest) ] - ).into_response(); + ).into_response()); } } - StatusCode::NOT_FOUND.into_response() + Ok(StatusCode::NOT_FOUND.into_response()) } -pub async fn pull_digest_get(Path((name, layer_digest)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { +pub async fn pull_digest_get(Path((name, layer_digest)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Result { // Check if the user has permission to pull, or that the repository is public let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); let storage = state.storage.lock().await; - if let Some(len) = storage.digest_length(&layer_digest).await.unwrap() { - let stream = storage.get_digest_stream(&layer_digest).await.unwrap().unwrap(); + if let Some(len) = storage.digest_length(&layer_digest).await? { + let stream = match storage.get_digest_stream(&layer_digest).await? { + Some(s) => s, + None => { + return Ok(StatusCode::NOT_FOUND.into_response()); + } + }; // convert the `AsyncRead` into a `Stream` let stream = ReaderStream::new(stream.into_async_read()); // convert the `Stream` into an `axum::body::HttpBody` let body = StreamBody::new(stream); - ( + Ok(( StatusCode::OK, [ (header::CONTENT_LENGTH, len.to_string()), (HeaderName::from_static("docker-content-digest"), layer_digest) ], body - ).into_response() + ).into_response()) } else { - StatusCode::NOT_FOUND.into_response() + Ok(StatusCode::NOT_FOUND.into_response()) } } diff --git a/src/api/catalog.rs b/src/api/catalog.rs index 42e28b9..dff2215 100644 --- a/src/api/catalog.rs +++ b/src/api/catalog.rs @@ -1,9 +1,9 @@ use std::sync::Arc; -use axum::{extract::{State, Query}, http::{StatusCode, header, HeaderMap, HeaderName}, response::IntoResponse}; +use axum::{extract::{State, Query}, http::{StatusCode, header, HeaderMap, HeaderName}, response::{IntoResponse, Response}}; use serde::{Serialize, Deserialize}; -use crate::{app_state::AppState, database::Database}; +use crate::{app_state::AppState, database::Database, error::AppError}; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -20,7 +20,7 @@ pub struct ListRepositoriesParams { last_repo: Option, } -pub async fn list_repositories(Query(params): Query, state: State>) -> impl IntoResponse { +pub async fn list_repositories(Query(params): Query, state: State>) -> Result { let mut link_header = None; // Paginate tag results if n was specified, else just pull everything. @@ -30,7 +30,7 @@ pub async fn list_repositories(Query(params): Query, sta // Convert the last param to a String, and list all the repos let last_repo = params.last_repo.and_then(|t| Some(t.to_string())); - let repos = database.list_repositories(Some(limit), last_repo).await.unwrap(); + let repos = database.list_repositories(Some(limit), last_repo).await?; // Get the new last repository for the response let last_repo = repos.last().and_then(|s| Some(s.clone())); @@ -47,7 +47,7 @@ pub async fn list_repositories(Query(params): Query, sta repos }, None => { - database.list_repositories(None, None).await.unwrap() + database.list_repositories(None, None).await? } }; @@ -55,20 +55,20 @@ pub async fn list_repositories(Query(params): Query, sta let repo_list = RepositoryList { repositories, }; - let response_body = serde_json::to_string(&repo_list).unwrap(); + let response_body = serde_json::to_string(&repo_list)?; let mut headers = HeaderMap::new(); - headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap()); - headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap()); + headers.insert(header::CONTENT_TYPE, "application/json".parse()?); + headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse()?); if let Some(link_header) = link_header { - headers.insert(header::LINK, link_header.parse().unwrap()); + headers.insert(header::LINK, link_header.parse()?); } // Construct the response, optionally adding the Link header if it was constructed. - ( + Ok(( StatusCode::OK, headers, response_body - ) + ).into_response()) } \ No newline at end of file diff --git a/src/api/manifests.rs b/src/api/manifests.rs index edb3aa6..3ebb6fe 100644 --- a/src/api/manifests.rs +++ b/src/api/manifests.rs @@ -3,22 +3,23 @@ use std::sync::Arc; use axum::Extension; use axum::extract::{Path, State}; use axum::response::{Response, IntoResponse}; -use axum::http::{StatusCode, HeaderMap, HeaderName, header}; +use axum::http::{StatusCode, HeaderName, header}; use tracing::log::warn; use tracing::{debug, info}; -use crate::auth::{unauthenticated_response, AuthDriver}; +use crate::auth::unauthenticated_response; use crate::app_state::AppState; use crate::database::Database; use crate::dto::RepositoryVisibility; use crate::dto::digest::Digest; use crate::dto::manifest::Manifest; use crate::dto::user::{UserAuth, Permission}; +use crate::error::AppError; -pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension, body: String) -> Response { +pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension, body: String) -> Result { let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); @@ -29,45 +30,45 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)> let database = &state.database; // Create the image repository and save the image manifest. This repository will be private by default - database.save_repository(&name, RepositoryVisibility::Private, None).await.unwrap(); - database.save_manifest(&name, &calculated_digest, &body).await.unwrap(); + database.save_repository(&name, RepositoryVisibility::Private, None).await?; + database.save_manifest(&name, &calculated_digest, &body).await?; // If the reference is not a digest, then it must be a tag name. if !Digest::is_digest(&reference) { - database.save_tag(&name, &reference, &calculated_digest).await.unwrap(); + database.save_tag(&name, &reference, &calculated_digest).await?; } info!("Saved manifest {}", calculated_digest); - match serde_json::from_str(&body).unwrap() { + match serde_json::from_str(&body)? { Manifest::Image(image) => { // Link the manifest to the image layer - database.link_manifest_layer(&calculated_digest, &image.config.digest).await.unwrap(); + database.link_manifest_layer(&calculated_digest, &image.config.digest).await?; debug!("Linked manifest {} to layer {}", calculated_digest, image.config.digest); for layer in image.layers { - database.link_manifest_layer(&calculated_digest, &layer.digest).await.unwrap(); + database.link_manifest_layer(&calculated_digest, &layer.digest).await?; debug!("Linked manifest {} to layer {}", calculated_digest, image.config.digest); } - ( + Ok(( StatusCode::CREATED, [ (HeaderName::from_static("docker-content-digest"), calculated_digest) ] - ).into_response() + ).into_response()) }, Manifest::List(_list) => { warn!("ManifestList request was received!"); - StatusCode::NOT_IMPLEMENTED.into_response() + Ok(StatusCode::NOT_IMPLEMENTED.into_response()) } } } -pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { +pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Result { // Check if the user has permission to pull, or that the repository is public let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); @@ -76,24 +77,24 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, true => reference.clone(), false => { debug!("Attempting to get manifest digest using tag (repository={}, reference={})", name, reference); - if let Some(tag) = database.get_tag(&name, &reference).await.unwrap() { + if let Some(tag) = database.get_tag(&name, &reference).await? { tag.manifest_digest } else { - return StatusCode::NOT_FOUND.into_response(); + return Ok(StatusCode::NOT_FOUND.into_response()); } } }; - let manifest_content = database.get_manifest(&name, &digest).await.unwrap(); + let manifest_content = database.get_manifest(&name, &digest).await?; if manifest_content.is_none() { debug!("Failed to get manifest in repo {}, for digest {}", name, digest); // The digest that was provided in the request was invalid. // NOTE: This could also mean that there's a bug and the tag pointed to an invalid manifest. - return StatusCode::NOT_FOUND.into_response(); + return Ok(StatusCode::NOT_FOUND.into_response()); } let manifest_content = manifest_content.unwrap(); - ( + Ok(( StatusCode::OK, [ (HeaderName::from_static("docker-content-digest"), digest), @@ -103,14 +104,14 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, (HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string()), ], manifest_content - ).into_response() + ).into_response()) } -pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { +pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Result { // Check if the user has permission to pull, or that the repository is public let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); @@ -119,23 +120,23 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String) let digest = match Digest::is_digest(&reference) { true => reference.clone(), false => { - if let Some(tag) = database.get_tag(&name, &reference).await.unwrap() { + if let Some(tag) = database.get_tag(&name, &reference).await? { tag.manifest_digest } else { - return StatusCode::NOT_FOUND.into_response(); + return Ok(StatusCode::NOT_FOUND.into_response()); } } }; - let manifest_content = database.get_manifest(&name, &digest).await.unwrap(); + let manifest_content = database.get_manifest(&name, &digest).await?; if manifest_content.is_none() { // The digest that was provided in the request was invalid. // NOTE: This could also mean that there's a bug and the tag pointed to an invalid manifest. - return StatusCode::NOT_FOUND.into_response(); + return Ok(StatusCode::NOT_FOUND.into_response()); } let manifest_content = manifest_content.unwrap(); - ( + Ok(( StatusCode::OK, [ (HeaderName::from_static("docker-content-digest"), digest), @@ -144,43 +145,41 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String) (HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string()), ], manifest_content - ).into_response() + ).into_response()) } -pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, headers: HeaderMap, state: State>, Extension(auth): Extension) -> Response { +pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Result { let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); - - let _authorization = headers.get("Authorization").unwrap(); // TODO: use authorization header let database = &state.database; let digest = match Digest::is_digest(&reference) { true => { // Check if the manifest exists - if database.get_manifest(&name, &reference).await.unwrap().is_none() { - return StatusCode::NOT_FOUND.into_response(); + if database.get_manifest(&name, &reference).await?.is_none() { + return Ok(StatusCode::NOT_FOUND.into_response()); } reference.clone() }, false => { - if let Some(tag) = database.get_tag(&name, &reference).await.unwrap() { + if let Some(tag) = database.get_tag(&name, &reference).await? { tag.manifest_digest } else { - return StatusCode::NOT_FOUND.into_response(); + return Ok(StatusCode::NOT_FOUND.into_response()); } } }; - database.delete_manifest(&name, &digest).await.unwrap(); + database.delete_manifest(&name, &digest).await?; - ( + Ok(( StatusCode::ACCEPTED, [ (header::CONTENT_LENGTH, "None"), ], - ).into_response() + ).into_response()) } \ No newline at end of file diff --git a/src/api/tags.rs b/src/api/tags.rs index 76e8282..6207880 100644 --- a/src/api/tags.rs +++ b/src/api/tags.rs @@ -1,9 +1,9 @@ use std::sync::Arc; -use axum::{extract::{Path, Query, State}, response::IntoResponse, http::{StatusCode, header, HeaderMap, HeaderName}}; +use axum::{extract::{Path, Query, State}, response::{IntoResponse, Response}, http::{StatusCode, header, HeaderMap, HeaderName}}; use serde::{Serialize, Deserialize}; -use crate::{app_state::AppState, database::Database}; +use crate::{app_state::AppState, database::Database, error::AppError}; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] @@ -21,7 +21,7 @@ pub struct ListRepositoriesParams { last_tag: Option, } -pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query, state: State>) -> impl IntoResponse { +pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query, state: State>) -> Result { let mut link_header = None; // Paginate tag results if n was specified, else just pull everything. @@ -31,7 +31,7 @@ pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query
  • , Query(params): Query
  • { - database.list_repository_tags(&name).await.unwrap() + database.list_repository_tags(&name).await? } }; @@ -57,21 +57,21 @@ pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query
  • , Extension(auth): Extension, state: State>) -> Response { +pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth): Extension, state: State>) -> Result { let mut auth_driver = state.auth_checker.lock().await; - if auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { + if auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? { debug!("Upload requested"); let uuid = uuid::Uuid::new_v4(); @@ -29,24 +29,24 @@ pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth) let location = format!("/v2/{}/blobs/uploads/{}", name, uuid.to_string()); debug!("Constructed upload url: {}", location); - return ( + return Ok(( StatusCode::ACCEPTED, [ (header::LOCATION, location) ] - ).into_response(); + ).into_response()); } - unauthenticated_response(&state.config) + Ok(unauthenticated_response(&state.config)) } -pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, String)>, Extension(auth): Extension, state: State>, mut body: BodyStream) -> Response { +pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, String)>, Extension(auth): Extension, state: State>, mut body: BodyStream) -> Result { let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); let storage = state.storage.lock().await; - let current_size = storage.digest_length(&layer_uuid).await.unwrap(); + let current_size = storage.digest_length(&layer_uuid).await?; let written_size = match storage.supports_streaming().await { true => { @@ -61,7 +61,7 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, }; let byte_stream = ByteStream::new(io_stream); - let len = storage.save_digest_stream(&layer_uuid, byte_stream, true).await.unwrap(); + let len = storage.save_digest_stream(&layer_uuid, byte_stream, true).await?; len }, @@ -70,11 +70,11 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, let mut bytes = BytesMut::new(); while let Some(item) = body.next().await { - bytes.extend_from_slice(&item.unwrap()); + bytes.extend_from_slice(&item?); } let bytes_len = bytes.len(); - storage.save_digest(&layer_uuid, &bytes.into(), true).await.unwrap(); + storage.save_digest(&layer_uuid, &bytes.into(), true).await?; bytes_len } }; @@ -86,7 +86,7 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, }; let full_uri = format!("{}/v2/{}/blobs/uploads/{}", state.config.get_url(), name, layer_uuid); - ( + Ok(( StatusCode::ACCEPTED, [ (header::LOCATION, full_uri), @@ -94,13 +94,13 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, (header::CONTENT_LENGTH, "0".to_string()), (HeaderName::from_static("docker-upload-uuid"), layer_uuid) ] - ).into_response() + ).into_response()) } -pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, String)>, Query(query): Query>, Extension(auth): Extension, state: State>, body: Bytes) -> Response { +pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, String)>, Query(query): Query>, Extension(auth): Extension, state: State>, body: Bytes) -> Result { let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); @@ -108,54 +108,54 @@ pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, S let storage = state.storage.lock().await; if !body.is_empty() { - storage.save_digest(&layer_uuid, &body, true).await.unwrap(); + storage.save_digest(&layer_uuid, &body, true).await?; } else { // TODO: Validate layer with all digest params } - storage.replace_digest(&layer_uuid, &digest).await.unwrap(); + storage.replace_digest(&layer_uuid, &digest).await?; debug!("Completed upload, finished uuid {} to digest {}", layer_uuid, digest); - ( + Ok(( StatusCode::CREATED, [ (header::LOCATION, format!("/v2/{}/blobs/{}", name, digest)), (header::CONTENT_LENGTH, "0".to_string()), (HeaderName::from_static("docker-upload-digest"), digest.to_owned()) ] - ).into_response() + ).into_response()) } -pub async fn cancel_upload_delete(Path((name, layer_uuid)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { +pub async fn cancel_upload_delete(Path((name, layer_uuid)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Result { let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); let storage = state.storage.lock().await; - storage.delete_digest(&layer_uuid).await.unwrap(); + storage.delete_digest(&layer_uuid).await?; // I'm not sure what this response should be, its not specified in the registry spec. - StatusCode::OK.into_response() + Ok(StatusCode::OK.into_response()) } -pub async fn check_upload_status_get(Path((name, layer_uuid)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Response { +pub async fn check_upload_status_get(Path((name, layer_uuid)): Path<(String, String)>, state: State>, Extension(auth): Extension) -> Result { let mut auth_driver = state.auth_checker.lock().await; - if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { - return unauthenticated_response(&state.config); + if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? { + return Ok(unauthenticated_response(&state.config)); } drop(auth_driver); let storage = state.storage.lock().await; - let ending = storage.digest_length(&layer_uuid).await.unwrap().unwrap_or(0); + let ending = storage.digest_length(&layer_uuid).await?.unwrap_or(0); - ( + Ok(( StatusCode::CREATED, [ (header::LOCATION, format!("/v2/{}/blobs/uploads/{}", name, layer_uuid)), (header::RANGE, format!("0-{}", ending)), (HeaderName::from_static("docker-upload-digest"), layer_uuid) ] - ).into_response() + ).into_response()) } \ No newline at end of file diff --git a/src/auth/ldap_driver.rs b/src/auth/ldap_driver.rs index d276055..155e5b1 100644 --- a/src/auth/ldap_driver.rs +++ b/src/auth/ldap_driver.rs @@ -1,11 +1,9 @@ -use std::{slice::Iter, iter::Peekable}; - use async_trait::async_trait; -use ldap3::{LdapConnAsync, Ldap, Scope, asn1::PL, ResultEntry, SearchEntry}; +use ldap3::{LdapConnAsync, Ldap, Scope, SearchEntry}; use sqlx::{Pool, Sqlite}; use tracing::{debug, warn}; -use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource}, RepositoryVisibility}, database::Database}; +use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource, RegistryUserType}, RepositoryVisibility}, database::Database}; use super::AuthDriver; @@ -37,46 +35,7 @@ impl LdapAuthDriver { Ok(()) } - /* pub async fn verify_login(&mut self, username: &str, password: &str) -> anyhow::Result { - self.bind().await?; - - let filter = self.ldap_config.user_search_filter.replace("%s", &username); - let res = self.ldap.search(&self.ldap_config.user_base_dn, Scope::Subtree, &filter, - vec!["userPassword", "uid", "cn", "mail", "displayName"]).await?; - let (entries, _res) = res.success()?; - - let entries: Vec = entries - .into_iter() - .map(|e| SearchEntry::construct(e)) - .collect(); - - if entries.is_empty() { - Ok(false) - } else if entries.len() > 1 { - warn!("Got multiple DNs for user ({}), unsure which one to use!!", username); - Ok(false) - } else { - let entry = entries.first().unwrap(); - - let res = self.ldap.simple_bind(&entry.dn, password).await?; - if res.rc == 0 { - Ok(true) - } else if res.rc == 49 { - warn!("User failed to auth (invalidCredentials, rc=49)!"); - Ok(false) - } else { - // this would fail, its just here to propagate the error down - res.success()?; - - Ok(false) - } - } - } */ -} - -#[async_trait] -impl AuthDriver for LdapAuthDriver { - async fn user_has_permission(&mut self, email: String, repository: String, permission: Permission, required_visibility: Option) -> anyhow::Result { + async fn is_user_admin(&mut self, email: String) -> anyhow::Result { self.bind().await?; // Send a request to LDAP to check if the user is an admin @@ -90,7 +49,14 @@ impl AuthDriver for LdapAuthDriver { .map(|e| SearchEntry::construct(e)) .collect(); - if entries.len() > 0 { + Ok(entries.len() > 0) + } +} + +#[async_trait] +impl AuthDriver for LdapAuthDriver { + async fn user_has_permission(&mut self, email: String, repository: String, permission: Permission, required_visibility: Option) -> anyhow::Result { + if self.is_user_admin(email.clone()).await? { Ok(true) } else { debug!("LDAP is falling back to database"); @@ -118,7 +84,7 @@ impl AuthDriver for LdapAuthDriver { warn!("Got multiple DNs for user ({}), unsure which one to use!!", email); Ok(false) } else { - let entry = entries.first().unwrap(); + let entry = entries.first().unwrap(); // there will be an entry let res = self.ldap.simple_bind(&entry.dn, &password).await?; if res.rc == 0 { @@ -126,8 +92,22 @@ impl AuthDriver for LdapAuthDriver { // Check if the user is stored in the database, if not, add it. let database = &self.database; if !database.does_user_exist(email.clone()).await? { - let display_name = entry.attrs.get(&self.ldap_config.display_name_attribute).unwrap().first().unwrap().clone(); - database.create_user(email, display_name, LoginSource::LDAP).await?; + let display_name = match entry.attrs.get(&self.ldap_config.display_name_attribute) { + // theres no way the vector would be empty + Some(display) => display.first().unwrap().clone(), + None => return Ok(false), + }; + + database.create_user(email.clone(), display_name, LoginSource::LDAP).await?; + drop(database); + + // Set the user registry type + let user_type = match self.is_user_admin(email.clone()).await? { + true => RegistryUserType::Admin, + false => RegistryUserType::Regular + }; + + self.database.set_user_registry_type(email, user_type).await?; } Ok(true) diff --git a/src/auth/mod.rs b/src/auth/mod.rs index ab5c886..8ef409e 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -1,8 +1,8 @@ pub mod ldap_driver; -use std::{collections::HashSet, ops::Deref, sync::Arc}; +use std::{ops::Deref, sync::Arc}; -use axum::{extract::{State, Path}, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}}; +use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}}; use sqlx::{Pool, Sqlite}; use tracing::debug; @@ -101,7 +101,7 @@ pub async fn require_auth(State(state): State>, mut request: Re // If the token is not valid, return an unauthorized response let database = &state.database; - if let Some(user) = database.verify_user_token(token.to_string()).await.unwrap() { + if let Ok(Some(user)) = database.verify_user_token(token.to_string()).await { debug!("Authenticated user through middleware: {}", user.user.username); request.extensions_mut().insert(user); diff --git a/src/database/mod.rs b/src/database/mod.rs index bf83aa7..974770f 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -12,47 +12,48 @@ pub trait Database { // Digest related functions /// Create the tables in the database - async fn create_schema(&self) -> sqlx::Result<()>; + async fn create_schema(&self) -> anyhow::Result<()>; // Tag related functions /// Get tags associated with a repository - async fn list_repository_tags(&self, repository: &str,) -> sqlx::Result>; - async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option) -> sqlx::Result>; + async fn list_repository_tags(&self, repository: &str,) -> anyhow::Result>; + async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option) -> anyhow::Result>; /// Get a manifest digest using the tag name. - async fn get_tag(&self, repository: &str, tag: &str) -> sqlx::Result>; + async fn get_tag(&self, repository: &str, tag: &str) -> anyhow::Result>; /// Save a tag and reference it to the manifest digest. - async fn save_tag(&self, repository: &str, tag: &str, manifest_digest: &str) -> sqlx::Result<()>; + async fn save_tag(&self, repository: &str, tag: &str, manifest_digest: &str) -> anyhow::Result<()>; /// Delete a tag. - async fn delete_tag(&self, repository: &str, tag: &str) -> sqlx::Result<()>; + async fn delete_tag(&self, repository: &str, tag: &str) -> anyhow::Result<()>; // Manifest related functions /// Get a manifest's content. - async fn get_manifest(&self, repository: &str, digest: &str) -> sqlx::Result>; + async fn get_manifest(&self, repository: &str, digest: &str) -> anyhow::Result>; /// Save a manifest's content. - async fn save_manifest(&self, repository: &str, digest: &str, content: &str) -> sqlx::Result<()>; + async fn save_manifest(&self, repository: &str, digest: &str, content: &str) -> anyhow::Result<()>; /// Delete a manifest /// Returns digests that this manifest pointed to. - async fn delete_manifest(&self, repository: &str, digest: &str) -> sqlx::Result>; - async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()>; - async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()>; + async fn delete_manifest(&self, repository: &str, digest: &str) -> anyhow::Result>; + async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()>; + async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()>; // Repository related functions - async fn has_repository(&self, repository: &str) -> sqlx::Result; + async fn has_repository(&self, repository: &str) -> anyhow::Result; async fn get_repository_visibility(&self, repository: &str) -> anyhow::Result>; /// Create a repository - async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option) -> sqlx::Result<()>; + async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option) -> anyhow::Result<()>; /// List all repositories. /// If limit is not specified, a default limit of 1000 will be returned. - async fn list_repositories(&self, limit: Option, last_repo: Option) -> sqlx::Result>; + async fn list_repositories(&self, limit: Option, last_repo: Option) -> anyhow::Result>; /// User stuff - async fn does_user_exist(&self, email: String) -> sqlx::Result; - async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> sqlx::Result; - async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> sqlx::Result<()>; + async fn does_user_exist(&self, email: String) -> anyhow::Result; + async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> anyhow::Result; + async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> anyhow::Result<()>; + async fn set_user_registry_type(&self, email: String, user_type: RegistryUserType) -> anyhow::Result<()>; async fn verify_user_login(&self, email: String, password: String) -> anyhow::Result; async fn get_user_registry_type(&self, email: String) -> anyhow::Result>; async fn get_user_repo_permissions(&self, email: String, repository: String) -> anyhow::Result>; @@ -63,7 +64,7 @@ pub trait Database { #[async_trait] impl Database for Pool { - async fn create_schema(&self) -> sqlx::Result<()> { + async fn create_schema(&self) -> anyhow::Result<()> { sqlx::query(include_str!("schemas/schema.sql")) .execute(self).await?; @@ -72,7 +73,7 @@ impl Database for Pool { Ok(()) } - async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()> { + async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()> { sqlx::query("INSERT INTO manifest_layers(manifest, layer_digest) VALUES (?, ?)") .bind(manifest_digest) .bind(layer_digest) @@ -83,7 +84,7 @@ impl Database for Pool { Ok(()) } - async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()> { + async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()> { sqlx::query("DELETE FROM manifest_layers WHERE manifest = ? AND layer_digest = ?") .bind(manifest_digest) .bind(layer_digest) @@ -94,7 +95,7 @@ impl Database for Pool { Ok(()) } - async fn list_repository_tags(&self, repository: &str,) -> sqlx::Result> { + async fn list_repository_tags(&self, repository: &str,) -> anyhow::Result> { let rows: Vec<(String, String, i64, )> = sqlx::query_as("SELECT name, image_manifest, last_updated FROM image_tags WHERE repository = ?") .bind(repository) .fetch_all(self).await?; @@ -108,7 +109,7 @@ impl Database for Pool { Ok(tags) } - async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option) -> sqlx::Result> { + async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option) -> anyhow::Result> { // Query differently depending on if `last_tag` was specified let rows: Vec<(String, String, i64, )> = match last_tag { Some(last_tag) => { @@ -135,7 +136,7 @@ impl Database for Pool { Ok(tags) } - async fn get_tag(&self, repository: &str, tag: &str) -> sqlx::Result> { + async fn get_tag(&self, repository: &str, tag: &str) -> anyhow::Result> { debug!("get tag"); let row: (String, i64, ) = match sqlx::query_as("SELECT image_manifest, last_updated FROM image_tags WHERE name = ? AND repository = ?") .bind(tag) @@ -147,7 +148,7 @@ impl Database for Pool { return Ok(None) }, _ => { - return Err(e); + return Err(anyhow::Error::new(e)); } } }; @@ -157,7 +158,7 @@ impl Database for Pool { Ok(Some(Tag::new(tag.to_string(), repository.to_string(), last_updated, row.0))) } - async fn save_tag(&self, repository: &str, tag: &str, digest: &str) -> sqlx::Result<()> { + async fn save_tag(&self, repository: &str, tag: &str, digest: &str) -> anyhow::Result<()> { sqlx::query("INSERT INTO image_tags (name, repository, image_manifest, last_updated) VALUES (?, ?, ?, ?)") .bind(tag) .bind(repository) @@ -168,7 +169,7 @@ impl Database for Pool { Ok(()) } - async fn delete_tag(&self, repository: &str, tag: &str) -> sqlx::Result<()> { + async fn delete_tag(&self, repository: &str, tag: &str) -> anyhow::Result<()> { sqlx::query("DELETE FROM image_tags WHERE 'name' = ? AND repository = ?") .bind(tag) .bind(repository) @@ -177,7 +178,7 @@ impl Database for Pool { Ok(()) } - async fn get_manifest(&self, repository: &str, digest: &str) -> sqlx::Result> { + async fn get_manifest(&self, repository: &str, digest: &str) -> anyhow::Result> { let row: (String, ) = match sqlx::query_as("SELECT content FROM image_manifests where digest = ? AND repository = ?") .bind(digest) .bind(repository) @@ -188,7 +189,7 @@ impl Database for Pool { return Ok(None) }, _ => { - return Err(e); + return Err(anyhow::Error::new(e)); } } }; @@ -196,7 +197,7 @@ impl Database for Pool { Ok(Some(row.0)) } - async fn save_manifest(&self, repository: &str, digest: &str, manifest: &str) -> sqlx::Result<()> { + async fn save_manifest(&self, repository: &str, digest: &str, manifest: &str) -> anyhow::Result<()> { sqlx::query("INSERT INTO image_manifests (digest, repository, content) VALUES (?, ?, ?)") .bind(digest) .bind(repository) @@ -206,7 +207,7 @@ impl Database for Pool { Ok(()) } - async fn delete_manifest(&self, repository: &str, digest: &str) -> sqlx::Result> { + async fn delete_manifest(&self, repository: &str, digest: &str) -> anyhow::Result> { sqlx::query("DELETE FROM image_manifests where digest = ? AND repository = ?") .bind(digest) .bind(repository) @@ -234,7 +235,7 @@ impl Database for Pool { Ok(digests) } - async fn has_repository(&self, repository: &str) -> sqlx::Result { + async fn has_repository(&self, repository: &str) -> anyhow::Result { let row: (u32, ) = match sqlx::query_as("SELECT COUNT(1) FROM repositories WHERE \"name\" = ?") .bind(repository) .fetch_one(self).await { @@ -244,7 +245,7 @@ impl Database for Pool { return Ok(false) }, _ => { - return Err(e); + return Err(anyhow::Error::new(e)); } } }; @@ -270,7 +271,7 @@ impl Database for Pool { Ok(Some(RepositoryVisibility::try_from(row.0)?)) } - async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option) -> sqlx::Result<()> { + async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option) -> anyhow::Result<()> { // ensure that the repository was not already created if self.has_repository(repository).await? { debug!("repo exists"); @@ -297,8 +298,8 @@ impl Database for Pool { Ok(()) } - //async fn list_repositories(&self) -> sqlx::Result> { - async fn list_repositories(&self, limit: Option, last_repo: Option) -> sqlx::Result> { + //async fn list_repositories(&self) -> anyhow::Result> { + async fn list_repositories(&self, limit: Option, last_repo: Option) -> anyhow::Result> { let limit = limit.unwrap_or(1000); // set default limit // Query differently depending on if `last_repo` was specified @@ -322,7 +323,7 @@ impl Database for Pool { Ok(repos) } - async fn does_user_exist(&self, email: String) -> sqlx::Result { + async fn does_user_exist(&self, email: String) -> anyhow::Result { let row: (u32, ) = match sqlx::query_as("SELECT COUNT(1) FROM users WHERE \"email\" = ?") .bind(email) .fetch_one(self).await { @@ -332,7 +333,7 @@ impl Database for Pool { return Ok(false) }, _ => { - return Err(e); + return Err(anyhow::Error::new(e)); } } }; @@ -340,7 +341,7 @@ impl Database for Pool { Ok(row.0 > 0) } - async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> sqlx::Result { + async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> anyhow::Result { let username = username.to_lowercase(); let email = email.to_lowercase(); sqlx::query("INSERT INTO users (username, email, login_source) VALUES (?, ?, ?)") @@ -352,7 +353,7 @@ impl Database for Pool { Ok(User::new(username, email, login_source)) } - async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> sqlx::Result<()> { + async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> anyhow::Result<()> { let email = email.to_lowercase(); sqlx::query("INSERT INTO user_logins (email, password_hash, password_salt) VALUES (?, ?, ?)") .bind(email.clone()) @@ -363,6 +364,16 @@ impl Database for Pool { Ok(()) } + async fn set_user_registry_type(&self, email: String, user_type: RegistryUserType) -> anyhow::Result<()> { + let email = email.to_lowercase(); + sqlx::query("INSERT INTO user_registry_permissions (email, user_type) VALUES (?, ?)") + .bind(email.clone()) + .bind(user_type as u32) + .execute(self).await?; + + Ok(()) + } + async fn verify_user_login(&self, email: String, password: String) -> anyhow::Result { let email = email.to_lowercase(); let row: (String, ) = sqlx::query_as("SELECT password_hash FROM users WHERE email = ?") @@ -410,10 +421,17 @@ impl Database for Pool { } }; - let vis = self.get_repository_visibility(&repository).await?.unwrap(); + let vis = match self.get_repository_visibility(&repository).await? { + Some(v) => v, + None => return Ok(None), + }; // Also get the user type for the registry, if its admin return admin repository permissions - let utype = self.get_user_registry_usertype(email).await?.unwrap(); // unwrap should be safe + let utype = match self.get_user_registry_usertype(email).await? { + Some(t) => t, + None => return Ok(None), + }; + if utype == RegistryUserType::Admin { Ok(Some(RepositoryPermissions::new(Permission::ADMIN.bits(), vis))) } else { @@ -479,11 +497,15 @@ impl Database for Pool { .bind(email.clone()) .fetch_one(self).await?; */ - let (expiry, created_at) = (Utc.timestamp_millis_opt(expiry).unwrap(), Utc.timestamp_millis_opt(created_at).unwrap()); - let user = User::new(email, user_row.0, LoginSource::try_from(user_row.1)?); - let token = TokenInfo::new(token, expiry, created_at); - let auth = UserAuth::new(user, token); + let (expiry, created_at) = (Utc.timestamp_millis_opt(expiry).single(), Utc.timestamp_millis_opt(created_at).single()); + if let (Some(expiry), Some(created_at)) = (expiry, created_at) { + let user = User::new(email, user_row.0, LoginSource::try_from(user_row.1)?); + let token = TokenInfo::new(token, expiry, created_at); + let auth = UserAuth::new(user, token); - Ok(Some(auth)) + Ok(Some(auth)) + } else { + Ok(None) + } } } \ No newline at end of file diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..c9fad69 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,26 @@ +use axum::{response::{IntoResponse, Response}, http::StatusCode}; + +// Make our own error that wraps `anyhow::Error`. +pub struct AppError(anyhow::Error); + +// Tell axum how to convert `AppError` into a response. +impl IntoResponse for AppError { + fn into_response(self) -> Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Something went wrong: {}", self.0), + ) + .into_response() + } +} + +// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into +// `Result<_, AppError>`. That way you don't need to do that manually. +impl From for AppError +where + E: Into, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index e2f9c6d..729a814 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,20 +5,19 @@ mod dto; mod storage; mod byte_stream; mod config; -mod query; mod auth; +mod error; use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; use auth::{AuthDriver, ldap_driver::LdapAuthDriver}; -use axum::http::{Request, StatusCode, header, HeaderName}; +use axum::http::{Request, StatusCode}; use axum::middleware::Next; -use axum::response::{Response, IntoResponse}; +use axum::response::Response; use axum::{Router, routing}; use axum::ServiceExt; -use bcrypt::Version; use tower_layer::Layer; use sqlx::sqlite::SqlitePoolOptions; @@ -29,21 +28,21 @@ use tracing::{debug, Level}; use app_state::AppState; use database::Database; -use crate::dto::user::Permission; use crate::storage::StorageDriver; use crate::storage::filesystem::FilesystemDriver; -use crate::config::{Config, LdapConnectionConfig}; +use crate::config::Config; use tower_http::trace::TraceLayer; /// Encode the 'name' path parameter in the url -async fn change_request_paths(mut request: Request, next: Next) -> Response { +async fn change_request_paths(mut request: Request, next: Next) -> Result { // Attempt to find the name using regex in the url - let regex = regex::Regex::new(r"/v2/([\w/]+)/(blobs|tags|manifests)").unwrap(); + let regex = regex::Regex::new(r"/v2/([\w/]+)/(blobs|tags|manifests)") + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; let captures = match regex.captures(request.uri().path()) { Some(captures) => captures, - None => return next.run(request).await, + None => return Ok(next.run(request).await), }; // Find the name in the request and encode it in the url @@ -54,24 +53,25 @@ async fn change_request_paths(mut request: Request, next: Next) -> Resp let uri_str = request.uri().to_string().replace(&name, &encoded_name); debug!("Rewrote request url to: '{}'", uri_str); - *request.uri_mut() = uri_str.parse().unwrap(); + *request.uri_mut() = uri_str.parse() + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - next.run(request).await + Ok(next.run(request).await) } #[tokio::main] -async fn main() -> std::io::Result<()> { +async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_max_level(Level::DEBUG) .init(); - let config = Config::new().expect("Failure to parse config!"); + let config = Config::new() + .expect("Failure to parse config!"); let pool = SqlitePoolOptions::new() .max_connections(15) - .connect("test.db").await.unwrap(); - - pool.create_schema().await.unwrap(); + .connect("test.db").await?; + pool.create_schema().await?; let storage_driver: Mutex> = Mutex::new(Box::new(FilesystemDriver::new("registry/blobs"))); @@ -79,7 +79,7 @@ async fn main() -> std::io::Result<()> { // the fallback is a database auth driver. let auth_driver: Mutex> = match config.ldap.clone() { Some(ldap) => { - let ldap_driver = LdapAuthDriver::new(ldap, pool.clone()).await.unwrap(); + let ldap_driver = LdapAuthDriver::new(ldap, pool.clone()).await?; Mutex::new(Box::new(ldap_driver)) }, None => { @@ -87,7 +87,7 @@ async fn main() -> std::io::Result<()> { } }; - let app_addr = SocketAddr::from_str(&format!("{}:{}", config.listen_address, config.listen_port)).unwrap(); + let app_addr = SocketAddr::from_str(&format!("{}:{}", config.listen_address, config.listen_port))?; let state = Arc::new(AppState::new(pool, storage_driver, config, auth_driver)); @@ -129,8 +129,7 @@ async fn main() -> std::io::Result<()> { debug!("Starting http server, listening on {}", app_addr); axum::Server::bind(&app_addr) .serve(layered_app.into_make_service()) - .await - .unwrap(); + .await?; Ok(()) } \ No newline at end of file diff --git a/src/storage/filesystem.rs b/src/storage/filesystem.rs index 920762b..1fca687 100644 --- a/src/storage/filesystem.rs +++ b/src/storage/filesystem.rs @@ -1,6 +1,6 @@ use std::{path::Path, io::ErrorKind}; -use anyhow::Context; +use anyhow::{Context, anyhow}; use async_trait::async_trait; use bytes::Bytes; use futures::StreamExt; @@ -53,7 +53,7 @@ impl StorageDriver for FilesystemDriver { len += bytes.len(); - file.write_all(&bytes).await.unwrap(); + file.write_all(&bytes).await?; } Ok(len) @@ -140,9 +140,12 @@ impl StorageDriver for FilesystemDriver { async fn replace_digest(&self, uuid: &str, digest: &str) -> anyhow::Result<()> { let path = self.get_digest_path(uuid); let path = Path::new(&path); - let parent = path.clone().parent().unwrap(); + let parent = path + .clone() + .parent() + .ok_or(anyhow!("Failure to get parent path of digest file!"))?; - fs::rename(path, format!("{}/{}", parent.as_os_str().to_str().unwrap(), digest)).await?; + fs::rename(path, format!("{}/{}", parent.display(), digest)).await?; Ok(()) }