remove ALL necessary unwraps

This commit is contained in:
SeanOMik 2023-05-29 23:36:30 -04:00
parent 7637faf047
commit c4b25e1415
Signed by: SeanOMik
GPG Key ID: 568F326C7EB33ACB
13 changed files with 305 additions and 256 deletions

View File

@ -1,5 +1,9 @@
- [x] slashes in repo names - [x] slashes in repo names
- [ ] Simple auth - [x] Simple auth
- [x] ldap auth
- [ ] permission stuff
- [ ] Only allow users to create repositories if its the same name as their username, or if they're an admin
- [ ] Only allow users to pull from their own repositories
- [ ] postgresql - [ ] postgresql
- [ ] prometheus metrics - [ ] prometheus metrics
- [x] streaming layer bytes into providers - [x] streaming layer bytes into providers

View File

@ -56,9 +56,9 @@ fn create_jwt_token(account: &str) -> anyhow::Result<TokenInfo> {
claims.insert("subject", &account); claims.insert("subject", &account);
//claims.insert("audience", auth.service); //claims.insert("audience", auth.service);
let not_before = format!("{}", now_secs - 10); let not_before = format!("{}", now_secs);
let issued_at = format!("{}", now_secs); let issued_at = format!("{}", now_secs);
let expiration = format!("{}", now_secs + 20); let expiration = format!("{}", now_secs + 86400); // 1 day
claims.insert("notbefore", &not_before); claims.insert("notbefore", &not_before);
claims.insert("issuedat", &issued_at); claims.insert("issuedat", &issued_at);
claims.insert("expiration", &expiration); // TODO: 20 seconds expiry for testing claims.insert("expiration", &expiration); // TODO: 20 seconds expiry for testing
@ -75,7 +75,7 @@ fn create_jwt_token(account: &str) -> anyhow::Result<TokenInfo> {
Ok(TokenInfo::new(token_str, expiration, issued_at)) Ok(TokenInfo::new(token_str, expiration, issued_at))
} }
pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppState>>, Query(params): Query<HashMap<String, String>>, form: Option<Form<AuthForm>>) -> Response { pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppState>>, Query(params): Query<HashMap<String, String>>, form: Option<Form<AuthForm>>) -> Result<Response, StatusCode> {
let mut auth = TokenAuthRequest { let mut auth = TokenAuthRequest {
user: None, user: None,
password: None, password: None,
@ -117,7 +117,7 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
info!("Auth failure! Auth was not provided in either AuthBasic or Form!"); info!("Auth failure! Auth was not provided in either AuthBasic or Form!");
// Maybe BAD_REQUEST should be returned? // Maybe BAD_REQUEST should be returned?
return (StatusCode::UNAUTHORIZED).into_response(); return Err(StatusCode::UNAUTHORIZED);
} }
// Create logging span for the rest of this request // Create logging span for the rest of this request
@ -133,7 +133,7 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
if account != user { if account != user {
error!("`user` and `account` are not the same!!! (user: {}, account: {})", user, account); error!("`user` and `account` are not the same!!! (user: {}, account: {})", user, account);
return (StatusCode::BAD_REQUEST).into_response(); return Err(StatusCode::BAD_REQUEST);
} }
} }
@ -149,7 +149,14 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
if let Some(scope) = params.get("scope") { if let Some(scope) = params.get("scope") {
// TODO: Handle multiple scopes // TODO: Handle multiple scopes
auth.scope.push(Scope::try_from(&scope[..]).unwrap()); match Scope::try_from(&scope[..]) {
Ok(scope) => {
auth.scope.push(scope);
},
Err(_) => {
return Err(StatusCode::BAD_REQUEST);
}
}
} }
// Get offline token and attempt to convert it to a boolean // Get offline token and attempt to convert it to a boolean
@ -168,17 +175,19 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
if let (Some(account), Some(password)) = (&auth.account, auth.password) { if let (Some(account), Some(password)) = (&auth.account, auth.password) {
// Ensure that the password is correct // Ensure that the password is correct
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.verify_user_login(account.clone(), password).await.unwrap() { if !auth_driver.verify_user_login(account.clone(), password).await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
debug!("Authentication failed, incorrect password!"); debug!("Authentication failed, incorrect password!");
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
debug!("User password is correct"); debug!("User password is correct");
let now = SystemTime::now(); let now = SystemTime::now();
let token = create_jwt_token(account).unwrap(); let token = create_jwt_token(account)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let token_str = token.token; let token_str = token.token;
debug!("Created jwt token"); debug!("Created jwt token");
@ -194,23 +203,25 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
issued_at: now_format, issued_at: now_format,
}; };
let json_str = serde_json::to_string(&auth_response).unwrap(); let json_str = serde_json::to_string(&auth_response)
.map_err(|_| StatusCode::BAD_REQUEST)?;
let database = &state.database; let database = &state.database;
database.store_user_token(token_str.clone(), account.clone(), token.expiry, token.created_at).await.unwrap(); database.store_user_token(token_str.clone(), account.clone(), token.expiry, token.created_at).await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
drop(database); drop(database);
return ( return Ok((
StatusCode::OK, StatusCode::OK,
[ [
( header::CONTENT_TYPE, "application/json" ), ( header::CONTENT_TYPE, "application/json" ),
( header::AUTHORIZATION, &format!("Bearer {}", token_str) ) ( header::AUTHORIZATION, &format!("Bearer {}", token_str) )
], ],
json_str json_str
).into_response(); ).into_response());
} }
info!("Auth failure! Not enough information given to create auth token!"); info!("Auth failure! Not enough information given to create auth token!");
// If we didn't get fields required to make a token, then the client did something bad // If we didn't get fields required to make a token, then the client did something bad
(StatusCode::UNAUTHORIZED).into_response() Err(StatusCode::UNAUTHORIZED)
} }

View File

@ -8,64 +8,69 @@ use axum::response::{IntoResponse, Response};
use tokio_util::io::ReaderStream; use tokio_util::io::ReaderStream;
use crate::app_state::AppState; use crate::app_state::AppState;
use crate::auth::{unauthenticated_response, AuthDriver}; use crate::auth::unauthenticated_response;
use crate::database::Database;
use crate::dto::RepositoryVisibility; use crate::dto::RepositoryVisibility;
use crate::dto::user::{Permission, RegistryUserType, UserAuth}; use crate::dto::user::{Permission, UserAuth};
use crate::error::AppError;
pub async fn digest_exists_head(Path((name, layer_digest)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response { pub async fn digest_exists_head(Path((name, layer_digest)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public // Check if the user has permission to pull, or that the repository is public
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
let storage = state.storage.lock().await; let storage = state.storage.lock().await;
if storage.has_digest(&layer_digest).await.unwrap() { if storage.has_digest(&layer_digest).await? {
if let Some(size) = storage.digest_length(&layer_digest).await.unwrap() { if let Some(size) = storage.digest_length(&layer_digest).await? {
return ( return Ok((
StatusCode::OK, StatusCode::OK,
[ [
(header::CONTENT_LENGTH, size.to_string()), (header::CONTENT_LENGTH, size.to_string()),
(HeaderName::from_static("docker-content-digest"), layer_digest) (HeaderName::from_static("docker-content-digest"), layer_digest)
] ]
).into_response(); ).into_response());
} }
} }
StatusCode::NOT_FOUND.into_response() Ok(StatusCode::NOT_FOUND.into_response())
} }
pub async fn pull_digest_get(Path((name, layer_digest)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response { pub async fn pull_digest_get(Path((name, layer_digest)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public // Check if the user has permission to pull, or that the repository is public
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
let storage = state.storage.lock().await; let storage = state.storage.lock().await;
if let Some(len) = storage.digest_length(&layer_digest).await.unwrap() { if let Some(len) = storage.digest_length(&layer_digest).await? {
let stream = storage.get_digest_stream(&layer_digest).await.unwrap().unwrap(); let stream = match storage.get_digest_stream(&layer_digest).await? {
Some(s) => s,
None => {
return Ok(StatusCode::NOT_FOUND.into_response());
}
};
// convert the `AsyncRead` into a `Stream` // convert the `AsyncRead` into a `Stream`
let stream = ReaderStream::new(stream.into_async_read()); let stream = ReaderStream::new(stream.into_async_read());
// convert the `Stream` into an `axum::body::HttpBody` // convert the `Stream` into an `axum::body::HttpBody`
let body = StreamBody::new(stream); let body = StreamBody::new(stream);
( Ok((
StatusCode::OK, StatusCode::OK,
[ [
(header::CONTENT_LENGTH, len.to_string()), (header::CONTENT_LENGTH, len.to_string()),
(HeaderName::from_static("docker-content-digest"), layer_digest) (HeaderName::from_static("docker-content-digest"), layer_digest)
], ],
body body
).into_response() ).into_response())
} else { } else {
StatusCode::NOT_FOUND.into_response() Ok(StatusCode::NOT_FOUND.into_response())
} }
} }

View File

@ -1,9 +1,9 @@
use std::sync::Arc; use std::sync::Arc;
use axum::{extract::{State, Query}, http::{StatusCode, header, HeaderMap, HeaderName}, response::IntoResponse}; use axum::{extract::{State, Query}, http::{StatusCode, header, HeaderMap, HeaderName}, response::{IntoResponse, Response}};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::{app_state::AppState, database::Database}; use crate::{app_state::AppState, database::Database, error::AppError};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -20,7 +20,7 @@ pub struct ListRepositoriesParams {
last_repo: Option<String>, last_repo: Option<String>,
} }
pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, state: State<Arc<AppState>>) -> impl IntoResponse { pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let mut link_header = None; let mut link_header = None;
// Paginate tag results if n was specified, else just pull everything. // Paginate tag results if n was specified, else just pull everything.
@ -30,7 +30,7 @@ pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, sta
// Convert the last param to a String, and list all the repos // Convert the last param to a String, and list all the repos
let last_repo = params.last_repo.and_then(|t| Some(t.to_string())); let last_repo = params.last_repo.and_then(|t| Some(t.to_string()));
let repos = database.list_repositories(Some(limit), last_repo).await.unwrap(); let repos = database.list_repositories(Some(limit), last_repo).await?;
// Get the new last repository for the response // Get the new last repository for the response
let last_repo = repos.last().and_then(|s| Some(s.clone())); let last_repo = repos.last().and_then(|s| Some(s.clone()));
@ -47,7 +47,7 @@ pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, sta
repos repos
}, },
None => { None => {
database.list_repositories(None, None).await.unwrap() database.list_repositories(None, None).await?
} }
}; };
@ -55,20 +55,20 @@ pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, sta
let repo_list = RepositoryList { let repo_list = RepositoryList {
repositories, repositories,
}; };
let response_body = serde_json::to_string(&repo_list).unwrap(); let response_body = serde_json::to_string(&repo_list)?;
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap()); headers.insert(header::CONTENT_TYPE, "application/json".parse()?);
headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap()); headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse()?);
if let Some(link_header) = link_header { if let Some(link_header) = link_header {
headers.insert(header::LINK, link_header.parse().unwrap()); headers.insert(header::LINK, link_header.parse()?);
} }
// Construct the response, optionally adding the Link header if it was constructed. // Construct the response, optionally adding the Link header if it was constructed.
( Ok((
StatusCode::OK, StatusCode::OK,
headers, headers,
response_body response_body
) ).into_response())
} }

View File

@ -3,22 +3,23 @@ use std::sync::Arc;
use axum::Extension; use axum::Extension;
use axum::extract::{Path, State}; use axum::extract::{Path, State};
use axum::response::{Response, IntoResponse}; use axum::response::{Response, IntoResponse};
use axum::http::{StatusCode, HeaderMap, HeaderName, header}; use axum::http::{StatusCode, HeaderName, header};
use tracing::log::warn; use tracing::log::warn;
use tracing::{debug, info}; use tracing::{debug, info};
use crate::auth::{unauthenticated_response, AuthDriver}; use crate::auth::unauthenticated_response;
use crate::app_state::AppState; use crate::app_state::AppState;
use crate::database::Database; use crate::database::Database;
use crate::dto::RepositoryVisibility; use crate::dto::RepositoryVisibility;
use crate::dto::digest::Digest; use crate::dto::digest::Digest;
use crate::dto::manifest::Manifest; use crate::dto::manifest::Manifest;
use crate::dto::user::{UserAuth, Permission}; use crate::dto::user::{UserAuth, Permission};
use crate::error::AppError;
pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>, body: String) -> Response { pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>, body: String) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
@ -29,45 +30,45 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>
let database = &state.database; let database = &state.database;
// Create the image repository and save the image manifest. This repository will be private by default // Create the image repository and save the image manifest. This repository will be private by default
database.save_repository(&name, RepositoryVisibility::Private, None).await.unwrap(); database.save_repository(&name, RepositoryVisibility::Private, None).await?;
database.save_manifest(&name, &calculated_digest, &body).await.unwrap(); database.save_manifest(&name, &calculated_digest, &body).await?;
// If the reference is not a digest, then it must be a tag name. // If the reference is not a digest, then it must be a tag name.
if !Digest::is_digest(&reference) { if !Digest::is_digest(&reference) {
database.save_tag(&name, &reference, &calculated_digest).await.unwrap(); database.save_tag(&name, &reference, &calculated_digest).await?;
} }
info!("Saved manifest {}", calculated_digest); info!("Saved manifest {}", calculated_digest);
match serde_json::from_str(&body).unwrap() { match serde_json::from_str(&body)? {
Manifest::Image(image) => { Manifest::Image(image) => {
// Link the manifest to the image layer // Link the manifest to the image layer
database.link_manifest_layer(&calculated_digest, &image.config.digest).await.unwrap(); database.link_manifest_layer(&calculated_digest, &image.config.digest).await?;
debug!("Linked manifest {} to layer {}", calculated_digest, image.config.digest); debug!("Linked manifest {} to layer {}", calculated_digest, image.config.digest);
for layer in image.layers { for layer in image.layers {
database.link_manifest_layer(&calculated_digest, &layer.digest).await.unwrap(); database.link_manifest_layer(&calculated_digest, &layer.digest).await?;
debug!("Linked manifest {} to layer {}", calculated_digest, image.config.digest); debug!("Linked manifest {} to layer {}", calculated_digest, image.config.digest);
} }
( Ok((
StatusCode::CREATED, StatusCode::CREATED,
[ (HeaderName::from_static("docker-content-digest"), calculated_digest) ] [ (HeaderName::from_static("docker-content-digest"), calculated_digest) ]
).into_response() ).into_response())
}, },
Manifest::List(_list) => { Manifest::List(_list) => {
warn!("ManifestList request was received!"); warn!("ManifestList request was received!");
StatusCode::NOT_IMPLEMENTED.into_response() Ok(StatusCode::NOT_IMPLEMENTED.into_response())
} }
} }
} }
pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response { pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public // Check if the user has permission to pull, or that the repository is public
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
@ -76,24 +77,24 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>,
true => reference.clone(), true => reference.clone(),
false => { false => {
debug!("Attempting to get manifest digest using tag (repository={}, reference={})", name, reference); debug!("Attempting to get manifest digest using tag (repository={}, reference={})", name, reference);
if let Some(tag) = database.get_tag(&name, &reference).await.unwrap() { if let Some(tag) = database.get_tag(&name, &reference).await? {
tag.manifest_digest tag.manifest_digest
} else { } else {
return StatusCode::NOT_FOUND.into_response(); return Ok(StatusCode::NOT_FOUND.into_response());
} }
} }
}; };
let manifest_content = database.get_manifest(&name, &digest).await.unwrap(); let manifest_content = database.get_manifest(&name, &digest).await?;
if manifest_content.is_none() { if manifest_content.is_none() {
debug!("Failed to get manifest in repo {}, for digest {}", name, digest); debug!("Failed to get manifest in repo {}, for digest {}", name, digest);
// The digest that was provided in the request was invalid. // The digest that was provided in the request was invalid.
// NOTE: This could also mean that there's a bug and the tag pointed to an invalid manifest. // NOTE: This could also mean that there's a bug and the tag pointed to an invalid manifest.
return StatusCode::NOT_FOUND.into_response(); return Ok(StatusCode::NOT_FOUND.into_response());
} }
let manifest_content = manifest_content.unwrap(); let manifest_content = manifest_content.unwrap();
( Ok((
StatusCode::OK, StatusCode::OK,
[ [
(HeaderName::from_static("docker-content-digest"), digest), (HeaderName::from_static("docker-content-digest"), digest),
@ -103,14 +104,14 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>,
(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string()), (HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string()),
], ],
manifest_content manifest_content
).into_response() ).into_response())
} }
pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response { pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public // Check if the user has permission to pull, or that the repository is public
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
@ -119,23 +120,23 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
let digest = match Digest::is_digest(&reference) { let digest = match Digest::is_digest(&reference) {
true => reference.clone(), true => reference.clone(),
false => { false => {
if let Some(tag) = database.get_tag(&name, &reference).await.unwrap() { if let Some(tag) = database.get_tag(&name, &reference).await? {
tag.manifest_digest tag.manifest_digest
} else { } else {
return StatusCode::NOT_FOUND.into_response(); return Ok(StatusCode::NOT_FOUND.into_response());
} }
} }
}; };
let manifest_content = database.get_manifest(&name, &digest).await.unwrap(); let manifest_content = database.get_manifest(&name, &digest).await?;
if manifest_content.is_none() { if manifest_content.is_none() {
// The digest that was provided in the request was invalid. // The digest that was provided in the request was invalid.
// NOTE: This could also mean that there's a bug and the tag pointed to an invalid manifest. // NOTE: This could also mean that there's a bug and the tag pointed to an invalid manifest.
return StatusCode::NOT_FOUND.into_response(); return Ok(StatusCode::NOT_FOUND.into_response());
} }
let manifest_content = manifest_content.unwrap(); let manifest_content = manifest_content.unwrap();
( Ok((
StatusCode::OK, StatusCode::OK,
[ [
(HeaderName::from_static("docker-content-digest"), digest), (HeaderName::from_static("docker-content-digest"), digest),
@ -144,43 +145,41 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string()), (HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string()),
], ],
manifest_content manifest_content
).into_response() ).into_response())
} }
pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, headers: HeaderMap, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response { pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
let _authorization = headers.get("Authorization").unwrap(); // TODO: use authorization header
let database = &state.database; let database = &state.database;
let digest = match Digest::is_digest(&reference) { let digest = match Digest::is_digest(&reference) {
true => { true => {
// Check if the manifest exists // Check if the manifest exists
if database.get_manifest(&name, &reference).await.unwrap().is_none() { if database.get_manifest(&name, &reference).await?.is_none() {
return StatusCode::NOT_FOUND.into_response(); return Ok(StatusCode::NOT_FOUND.into_response());
} }
reference.clone() reference.clone()
}, },
false => { false => {
if let Some(tag) = database.get_tag(&name, &reference).await.unwrap() { if let Some(tag) = database.get_tag(&name, &reference).await? {
tag.manifest_digest tag.manifest_digest
} else { } else {
return StatusCode::NOT_FOUND.into_response(); return Ok(StatusCode::NOT_FOUND.into_response());
} }
} }
}; };
database.delete_manifest(&name, &digest).await.unwrap(); database.delete_manifest(&name, &digest).await?;
( Ok((
StatusCode::ACCEPTED, StatusCode::ACCEPTED,
[ [
(header::CONTENT_LENGTH, "None"), (header::CONTENT_LENGTH, "None"),
], ],
).into_response() ).into_response())
} }

View File

@ -1,9 +1,9 @@
use std::sync::Arc; use std::sync::Arc;
use axum::{extract::{Path, Query, State}, response::IntoResponse, http::{StatusCode, header, HeaderMap, HeaderName}}; use axum::{extract::{Path, Query, State}, response::{IntoResponse, Response}, http::{StatusCode, header, HeaderMap, HeaderName}};
use serde::{Serialize, Deserialize}; use serde::{Serialize, Deserialize};
use crate::{app_state::AppState, database::Database}; use crate::{app_state::AppState, database::Database, error::AppError};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
@ -21,7 +21,7 @@ pub struct ListRepositoriesParams {
last_tag: Option<String>, last_tag: Option<String>,
} }
pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<ListRepositoriesParams>, state: State<Arc<AppState>>) -> impl IntoResponse { pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<ListRepositoriesParams>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let mut link_header = None; let mut link_header = None;
// Paginate tag results if n was specified, else just pull everything. // Paginate tag results if n was specified, else just pull everything.
@ -31,7 +31,7 @@ pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<Li
// Convert the last param to a String, and list all the tags // Convert the last param to a String, and list all the tags
let last_tag = params.last_tag.and_then(|t| Some(t.to_string())); let last_tag = params.last_tag.and_then(|t| Some(t.to_string()));
let tags = database.list_repository_tags_page(&name, limit, last_tag).await.unwrap(); let tags = database.list_repository_tags_page(&name, limit, last_tag).await?;
// Get the new last repository for the response // Get the new last repository for the response
let last_tag = tags.last(); let last_tag = tags.last();
@ -48,7 +48,7 @@ pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<Li
tags tags
}, },
None => { None => {
database.list_repository_tags(&name).await.unwrap() database.list_repository_tags(&name).await?
} }
}; };
@ -57,21 +57,21 @@ pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<Li
name, name,
tags: tags.into_iter().map(|t| t.name).collect(), tags: tags.into_iter().map(|t| t.name).collect(),
}; };
let response_body = serde_json::to_string(&tag_list).unwrap(); let response_body = serde_json::to_string(&tag_list)?;
// Create headers // Create headers
let mut headers = HeaderMap::new(); let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap()); headers.insert(header::CONTENT_TYPE, "application/json".parse()?);
headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap()); headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse()?);
// Add the link header if it was constructed // Add the link header if it was constructed
if let Some(link_header) = link_header { if let Some(link_header) = link_header {
headers.insert(header::LINK, link_header.parse().unwrap()); headers.insert(header::LINK, link_header.parse()?);
} }
( Ok((
StatusCode::OK, StatusCode::OK,
headers, headers,
response_body response_body
) ).into_response())
} }

View File

@ -12,15 +12,15 @@ use futures::StreamExt;
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::app_state::AppState; use crate::app_state::AppState;
use crate::auth::{unauthenticated_response, AuthDriver}; use crate::auth::unauthenticated_response;
use crate::byte_stream::ByteStream; use crate::byte_stream::ByteStream;
use crate::database::Database; use crate::dto::user::{UserAuth, Permission};
use crate::dto::user::{UserAuth, Permission, RegistryUser, RegistryUserType}; use crate::error::AppError;
/// Starting an upload /// Starting an upload
pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>) -> Response { pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { if auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
debug!("Upload requested"); debug!("Upload requested");
let uuid = uuid::Uuid::new_v4(); let uuid = uuid::Uuid::new_v4();
@ -29,24 +29,24 @@ pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth)
let location = format!("/v2/{}/blobs/uploads/{}", name, uuid.to_string()); let location = format!("/v2/{}/blobs/uploads/{}", name, uuid.to_string());
debug!("Constructed upload url: {}", location); debug!("Constructed upload url: {}", location);
return ( return Ok((
StatusCode::ACCEPTED, StatusCode::ACCEPTED,
[ (header::LOCATION, location) ] [ (header::LOCATION, location) ]
).into_response(); ).into_response());
} }
unauthenticated_response(&state.config) Ok(unauthenticated_response(&state.config))
} }
pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, String)>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>, mut body: BodyStream) -> Response { pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, String)>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>, mut body: BodyStream) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
let storage = state.storage.lock().await; let storage = state.storage.lock().await;
let current_size = storage.digest_length(&layer_uuid).await.unwrap(); let current_size = storage.digest_length(&layer_uuid).await?;
let written_size = match storage.supports_streaming().await { let written_size = match storage.supports_streaming().await {
true => { true => {
@ -61,7 +61,7 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String,
}; };
let byte_stream = ByteStream::new(io_stream); let byte_stream = ByteStream::new(io_stream);
let len = storage.save_digest_stream(&layer_uuid, byte_stream, true).await.unwrap(); let len = storage.save_digest_stream(&layer_uuid, byte_stream, true).await?;
len len
}, },
@ -70,11 +70,11 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String,
let mut bytes = BytesMut::new(); let mut bytes = BytesMut::new();
while let Some(item) = body.next().await { while let Some(item) = body.next().await {
bytes.extend_from_slice(&item.unwrap()); bytes.extend_from_slice(&item?);
} }
let bytes_len = bytes.len(); let bytes_len = bytes.len();
storage.save_digest(&layer_uuid, &bytes.into(), true).await.unwrap(); storage.save_digest(&layer_uuid, &bytes.into(), true).await?;
bytes_len bytes_len
} }
}; };
@ -86,7 +86,7 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String,
}; };
let full_uri = format!("{}/v2/{}/blobs/uploads/{}", state.config.get_url(), name, layer_uuid); let full_uri = format!("{}/v2/{}/blobs/uploads/{}", state.config.get_url(), name, layer_uuid);
( Ok((
StatusCode::ACCEPTED, StatusCode::ACCEPTED,
[ [
(header::LOCATION, full_uri), (header::LOCATION, full_uri),
@ -94,13 +94,13 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String,
(header::CONTENT_LENGTH, "0".to_string()), (header::CONTENT_LENGTH, "0".to_string()),
(HeaderName::from_static("docker-upload-uuid"), layer_uuid) (HeaderName::from_static("docker-upload-uuid"), layer_uuid)
] ]
).into_response() ).into_response())
} }
pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, String)>, Query(query): Query<HashMap<String, String>>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>, body: Bytes) -> Response { pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, String)>, Query(query): Query<HashMap<String, String>>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>, body: Bytes) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
@ -108,54 +108,54 @@ pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, S
let storage = state.storage.lock().await; let storage = state.storage.lock().await;
if !body.is_empty() { if !body.is_empty() {
storage.save_digest(&layer_uuid, &body, true).await.unwrap(); storage.save_digest(&layer_uuid, &body, true).await?;
} else { } else {
// TODO: Validate layer with all digest params // TODO: Validate layer with all digest params
} }
storage.replace_digest(&layer_uuid, &digest).await.unwrap(); storage.replace_digest(&layer_uuid, &digest).await?;
debug!("Completed upload, finished uuid {} to digest {}", layer_uuid, digest); debug!("Completed upload, finished uuid {} to digest {}", layer_uuid, digest);
( Ok((
StatusCode::CREATED, StatusCode::CREATED,
[ [
(header::LOCATION, format!("/v2/{}/blobs/{}", name, digest)), (header::LOCATION, format!("/v2/{}/blobs/{}", name, digest)),
(header::CONTENT_LENGTH, "0".to_string()), (header::CONTENT_LENGTH, "0".to_string()),
(HeaderName::from_static("docker-upload-digest"), digest.to_owned()) (HeaderName::from_static("docker-upload-digest"), digest.to_owned())
] ]
).into_response() ).into_response())
} }
pub async fn cancel_upload_delete(Path((name, layer_uuid)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response { pub async fn cancel_upload_delete(Path((name, layer_uuid)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
let storage = state.storage.lock().await; let storage = state.storage.lock().await;
storage.delete_digest(&layer_uuid).await.unwrap(); storage.delete_digest(&layer_uuid).await?;
// I'm not sure what this response should be, its not specified in the registry spec. // I'm not sure what this response should be, its not specified in the registry spec.
StatusCode::OK.into_response() Ok(StatusCode::OK.into_response())
} }
pub async fn check_upload_status_get(Path((name, layer_uuid)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response { pub async fn check_upload_status_get(Path((name, layer_uuid)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() { if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return unauthenticated_response(&state.config); return Ok(unauthenticated_response(&state.config));
} }
drop(auth_driver); drop(auth_driver);
let storage = state.storage.lock().await; let storage = state.storage.lock().await;
let ending = storage.digest_length(&layer_uuid).await.unwrap().unwrap_or(0); let ending = storage.digest_length(&layer_uuid).await?.unwrap_or(0);
( Ok((
StatusCode::CREATED, StatusCode::CREATED,
[ [
(header::LOCATION, format!("/v2/{}/blobs/uploads/{}", name, layer_uuid)), (header::LOCATION, format!("/v2/{}/blobs/uploads/{}", name, layer_uuid)),
(header::RANGE, format!("0-{}", ending)), (header::RANGE, format!("0-{}", ending)),
(HeaderName::from_static("docker-upload-digest"), layer_uuid) (HeaderName::from_static("docker-upload-digest"), layer_uuid)
] ]
).into_response() ).into_response())
} }

View File

@ -1,11 +1,9 @@
use std::{slice::Iter, iter::Peekable};
use async_trait::async_trait; use async_trait::async_trait;
use ldap3::{LdapConnAsync, Ldap, Scope, asn1::PL, ResultEntry, SearchEntry}; use ldap3::{LdapConnAsync, Ldap, Scope, SearchEntry};
use sqlx::{Pool, Sqlite}; use sqlx::{Pool, Sqlite};
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource}, RepositoryVisibility}, database::Database}; use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource, RegistryUserType}, RepositoryVisibility}, database::Database};
use super::AuthDriver; use super::AuthDriver;
@ -37,46 +35,7 @@ impl LdapAuthDriver {
Ok(()) Ok(())
} }
/* pub async fn verify_login(&mut self, username: &str, password: &str) -> anyhow::Result<bool> { async fn is_user_admin(&mut self, email: String) -> anyhow::Result<bool> {
self.bind().await?;
let filter = self.ldap_config.user_search_filter.replace("%s", &username);
let res = self.ldap.search(&self.ldap_config.user_base_dn, Scope::Subtree, &filter,
vec!["userPassword", "uid", "cn", "mail", "displayName"]).await?;
let (entries, _res) = res.success()?;
let entries: Vec<SearchEntry> = entries
.into_iter()
.map(|e| SearchEntry::construct(e))
.collect();
if entries.is_empty() {
Ok(false)
} else if entries.len() > 1 {
warn!("Got multiple DNs for user ({}), unsure which one to use!!", username);
Ok(false)
} else {
let entry = entries.first().unwrap();
let res = self.ldap.simple_bind(&entry.dn, password).await?;
if res.rc == 0 {
Ok(true)
} else if res.rc == 49 {
warn!("User failed to auth (invalidCredentials, rc=49)!");
Ok(false)
} else {
// this would fail, its just here to propagate the error down
res.success()?;
Ok(false)
}
}
} */
}
#[async_trait]
impl AuthDriver for LdapAuthDriver {
async fn user_has_permission(&mut self, email: String, repository: String, permission: Permission, required_visibility: Option<RepositoryVisibility>) -> anyhow::Result<bool> {
self.bind().await?; self.bind().await?;
// Send a request to LDAP to check if the user is an admin // Send a request to LDAP to check if the user is an admin
@ -90,7 +49,14 @@ impl AuthDriver for LdapAuthDriver {
.map(|e| SearchEntry::construct(e)) .map(|e| SearchEntry::construct(e))
.collect(); .collect();
if entries.len() > 0 { Ok(entries.len() > 0)
}
}
#[async_trait]
impl AuthDriver for LdapAuthDriver {
async fn user_has_permission(&mut self, email: String, repository: String, permission: Permission, required_visibility: Option<RepositoryVisibility>) -> anyhow::Result<bool> {
if self.is_user_admin(email.clone()).await? {
Ok(true) Ok(true)
} else { } else {
debug!("LDAP is falling back to database"); debug!("LDAP is falling back to database");
@ -118,7 +84,7 @@ impl AuthDriver for LdapAuthDriver {
warn!("Got multiple DNs for user ({}), unsure which one to use!!", email); warn!("Got multiple DNs for user ({}), unsure which one to use!!", email);
Ok(false) Ok(false)
} else { } else {
let entry = entries.first().unwrap(); let entry = entries.first().unwrap(); // there will be an entry
let res = self.ldap.simple_bind(&entry.dn, &password).await?; let res = self.ldap.simple_bind(&entry.dn, &password).await?;
if res.rc == 0 { if res.rc == 0 {
@ -126,8 +92,22 @@ impl AuthDriver for LdapAuthDriver {
// Check if the user is stored in the database, if not, add it. // Check if the user is stored in the database, if not, add it.
let database = &self.database; let database = &self.database;
if !database.does_user_exist(email.clone()).await? { if !database.does_user_exist(email.clone()).await? {
let display_name = entry.attrs.get(&self.ldap_config.display_name_attribute).unwrap().first().unwrap().clone(); let display_name = match entry.attrs.get(&self.ldap_config.display_name_attribute) {
database.create_user(email, display_name, LoginSource::LDAP).await?; // theres no way the vector would be empty
Some(display) => display.first().unwrap().clone(),
None => return Ok(false),
};
database.create_user(email.clone(), display_name, LoginSource::LDAP).await?;
drop(database);
// Set the user registry type
let user_type = match self.is_user_admin(email.clone()).await? {
true => RegistryUserType::Admin,
false => RegistryUserType::Regular
};
self.database.set_user_registry_type(email, user_type).await?;
} }
Ok(true) Ok(true)

View File

@ -1,8 +1,8 @@
pub mod ldap_driver; pub mod ldap_driver;
use std::{collections::HashSet, ops::Deref, sync::Arc}; use std::{ops::Deref, sync::Arc};
use axum::{extract::{State, Path}, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}}; use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}};
use sqlx::{Pool, Sqlite}; use sqlx::{Pool, Sqlite};
use tracing::debug; use tracing::debug;
@ -101,7 +101,7 @@ pub async fn require_auth<B>(State(state): State<Arc<AppState>>, mut request: Re
// If the token is not valid, return an unauthorized response // If the token is not valid, return an unauthorized response
let database = &state.database; let database = &state.database;
if let Some(user) = database.verify_user_token(token.to_string()).await.unwrap() { if let Ok(Some(user)) = database.verify_user_token(token.to_string()).await {
debug!("Authenticated user through middleware: {}", user.user.username); debug!("Authenticated user through middleware: {}", user.user.username);
request.extensions_mut().insert(user); request.extensions_mut().insert(user);

View File

@ -12,47 +12,48 @@ pub trait Database {
// Digest related functions // Digest related functions
/// Create the tables in the database /// Create the tables in the database
async fn create_schema(&self) -> sqlx::Result<()>; async fn create_schema(&self) -> anyhow::Result<()>;
// Tag related functions // Tag related functions
/// Get tags associated with a repository /// Get tags associated with a repository
async fn list_repository_tags(&self, repository: &str,) -> sqlx::Result<Vec<Tag>>; async fn list_repository_tags(&self, repository: &str,) -> anyhow::Result<Vec<Tag>>;
async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option<String>) -> sqlx::Result<Vec<Tag>>; async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option<String>) -> anyhow::Result<Vec<Tag>>;
/// Get a manifest digest using the tag name. /// Get a manifest digest using the tag name.
async fn get_tag(&self, repository: &str, tag: &str) -> sqlx::Result<Option<Tag>>; async fn get_tag(&self, repository: &str, tag: &str) -> anyhow::Result<Option<Tag>>;
/// Save a tag and reference it to the manifest digest. /// Save a tag and reference it to the manifest digest.
async fn save_tag(&self, repository: &str, tag: &str, manifest_digest: &str) -> sqlx::Result<()>; async fn save_tag(&self, repository: &str, tag: &str, manifest_digest: &str) -> anyhow::Result<()>;
/// Delete a tag. /// Delete a tag.
async fn delete_tag(&self, repository: &str, tag: &str) -> sqlx::Result<()>; async fn delete_tag(&self, repository: &str, tag: &str) -> anyhow::Result<()>;
// Manifest related functions // Manifest related functions
/// Get a manifest's content. /// Get a manifest's content.
async fn get_manifest(&self, repository: &str, digest: &str) -> sqlx::Result<Option<String>>; async fn get_manifest(&self, repository: &str, digest: &str) -> anyhow::Result<Option<String>>;
/// Save a manifest's content. /// Save a manifest's content.
async fn save_manifest(&self, repository: &str, digest: &str, content: &str) -> sqlx::Result<()>; async fn save_manifest(&self, repository: &str, digest: &str, content: &str) -> anyhow::Result<()>;
/// Delete a manifest /// Delete a manifest
/// Returns digests that this manifest pointed to. /// Returns digests that this manifest pointed to.
async fn delete_manifest(&self, repository: &str, digest: &str) -> sqlx::Result<Vec<String>>; async fn delete_manifest(&self, repository: &str, digest: &str) -> anyhow::Result<Vec<String>>;
async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()>; async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()>;
async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()>; async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()>;
// Repository related functions // Repository related functions
async fn has_repository(&self, repository: &str) -> sqlx::Result<bool>; async fn has_repository(&self, repository: &str) -> anyhow::Result<bool>;
async fn get_repository_visibility(&self, repository: &str) -> anyhow::Result<Option<RepositoryVisibility>>; async fn get_repository_visibility(&self, repository: &str) -> anyhow::Result<Option<RepositoryVisibility>>;
/// Create a repository /// Create a repository
async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option<String>) -> sqlx::Result<()>; async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option<String>) -> anyhow::Result<()>;
/// List all repositories. /// List all repositories.
/// If limit is not specified, a default limit of 1000 will be returned. /// If limit is not specified, a default limit of 1000 will be returned.
async fn list_repositories(&self, limit: Option<u32>, last_repo: Option<String>) -> sqlx::Result<Vec<String>>; async fn list_repositories(&self, limit: Option<u32>, last_repo: Option<String>) -> anyhow::Result<Vec<String>>;
/// User stuff /// User stuff
async fn does_user_exist(&self, email: String) -> sqlx::Result<bool>; async fn does_user_exist(&self, email: String) -> anyhow::Result<bool>;
async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> sqlx::Result<User>; async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> anyhow::Result<User>;
async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> sqlx::Result<()>; async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> anyhow::Result<()>;
async fn set_user_registry_type(&self, email: String, user_type: RegistryUserType) -> anyhow::Result<()>;
async fn verify_user_login(&self, email: String, password: String) -> anyhow::Result<bool>; async fn verify_user_login(&self, email: String, password: String) -> anyhow::Result<bool>;
async fn get_user_registry_type(&self, email: String) -> anyhow::Result<Option<RegistryUserType>>; async fn get_user_registry_type(&self, email: String) -> anyhow::Result<Option<RegistryUserType>>;
async fn get_user_repo_permissions(&self, email: String, repository: String) -> anyhow::Result<Option<RepositoryPermissions>>; async fn get_user_repo_permissions(&self, email: String, repository: String) -> anyhow::Result<Option<RepositoryPermissions>>;
@ -63,7 +64,7 @@ pub trait Database {
#[async_trait] #[async_trait]
impl Database for Pool<Sqlite> { impl Database for Pool<Sqlite> {
async fn create_schema(&self) -> sqlx::Result<()> { async fn create_schema(&self) -> anyhow::Result<()> {
sqlx::query(include_str!("schemas/schema.sql")) sqlx::query(include_str!("schemas/schema.sql"))
.execute(self).await?; .execute(self).await?;
@ -72,7 +73,7 @@ impl Database for Pool<Sqlite> {
Ok(()) Ok(())
} }
async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()> { async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()> {
sqlx::query("INSERT INTO manifest_layers(manifest, layer_digest) VALUES (?, ?)") sqlx::query("INSERT INTO manifest_layers(manifest, layer_digest) VALUES (?, ?)")
.bind(manifest_digest) .bind(manifest_digest)
.bind(layer_digest) .bind(layer_digest)
@ -83,7 +84,7 @@ impl Database for Pool<Sqlite> {
Ok(()) Ok(())
} }
async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()> { async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()> {
sqlx::query("DELETE FROM manifest_layers WHERE manifest = ? AND layer_digest = ?") sqlx::query("DELETE FROM manifest_layers WHERE manifest = ? AND layer_digest = ?")
.bind(manifest_digest) .bind(manifest_digest)
.bind(layer_digest) .bind(layer_digest)
@ -94,7 +95,7 @@ impl Database for Pool<Sqlite> {
Ok(()) Ok(())
} }
async fn list_repository_tags(&self, repository: &str,) -> sqlx::Result<Vec<Tag>> { async fn list_repository_tags(&self, repository: &str,) -> anyhow::Result<Vec<Tag>> {
let rows: Vec<(String, String, i64, )> = sqlx::query_as("SELECT name, image_manifest, last_updated FROM image_tags WHERE repository = ?") let rows: Vec<(String, String, i64, )> = sqlx::query_as("SELECT name, image_manifest, last_updated FROM image_tags WHERE repository = ?")
.bind(repository) .bind(repository)
.fetch_all(self).await?; .fetch_all(self).await?;
@ -108,7 +109,7 @@ impl Database for Pool<Sqlite> {
Ok(tags) Ok(tags)
} }
async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option<String>) -> sqlx::Result<Vec<Tag>> { async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option<String>) -> anyhow::Result<Vec<Tag>> {
// Query differently depending on if `last_tag` was specified // Query differently depending on if `last_tag` was specified
let rows: Vec<(String, String, i64, )> = match last_tag { let rows: Vec<(String, String, i64, )> = match last_tag {
Some(last_tag) => { Some(last_tag) => {
@ -135,7 +136,7 @@ impl Database for Pool<Sqlite> {
Ok(tags) Ok(tags)
} }
async fn get_tag(&self, repository: &str, tag: &str) -> sqlx::Result<Option<Tag>> { async fn get_tag(&self, repository: &str, tag: &str) -> anyhow::Result<Option<Tag>> {
debug!("get tag"); debug!("get tag");
let row: (String, i64, ) = match sqlx::query_as("SELECT image_manifest, last_updated FROM image_tags WHERE name = ? AND repository = ?") let row: (String, i64, ) = match sqlx::query_as("SELECT image_manifest, last_updated FROM image_tags WHERE name = ? AND repository = ?")
.bind(tag) .bind(tag)
@ -147,7 +148,7 @@ impl Database for Pool<Sqlite> {
return Ok(None) return Ok(None)
}, },
_ => { _ => {
return Err(e); return Err(anyhow::Error::new(e));
} }
} }
}; };
@ -157,7 +158,7 @@ impl Database for Pool<Sqlite> {
Ok(Some(Tag::new(tag.to_string(), repository.to_string(), last_updated, row.0))) Ok(Some(Tag::new(tag.to_string(), repository.to_string(), last_updated, row.0)))
} }
async fn save_tag(&self, repository: &str, tag: &str, digest: &str) -> sqlx::Result<()> { async fn save_tag(&self, repository: &str, tag: &str, digest: &str) -> anyhow::Result<()> {
sqlx::query("INSERT INTO image_tags (name, repository, image_manifest, last_updated) VALUES (?, ?, ?, ?)") sqlx::query("INSERT INTO image_tags (name, repository, image_manifest, last_updated) VALUES (?, ?, ?, ?)")
.bind(tag) .bind(tag)
.bind(repository) .bind(repository)
@ -168,7 +169,7 @@ impl Database for Pool<Sqlite> {
Ok(()) Ok(())
} }
async fn delete_tag(&self, repository: &str, tag: &str) -> sqlx::Result<()> { async fn delete_tag(&self, repository: &str, tag: &str) -> anyhow::Result<()> {
sqlx::query("DELETE FROM image_tags WHERE 'name' = ? AND repository = ?") sqlx::query("DELETE FROM image_tags WHERE 'name' = ? AND repository = ?")
.bind(tag) .bind(tag)
.bind(repository) .bind(repository)
@ -177,7 +178,7 @@ impl Database for Pool<Sqlite> {
Ok(()) Ok(())
} }
async fn get_manifest(&self, repository: &str, digest: &str) -> sqlx::Result<Option<String>> { async fn get_manifest(&self, repository: &str, digest: &str) -> anyhow::Result<Option<String>> {
let row: (String, ) = match sqlx::query_as("SELECT content FROM image_manifests where digest = ? AND repository = ?") let row: (String, ) = match sqlx::query_as("SELECT content FROM image_manifests where digest = ? AND repository = ?")
.bind(digest) .bind(digest)
.bind(repository) .bind(repository)
@ -188,7 +189,7 @@ impl Database for Pool<Sqlite> {
return Ok(None) return Ok(None)
}, },
_ => { _ => {
return Err(e); return Err(anyhow::Error::new(e));
} }
} }
}; };
@ -196,7 +197,7 @@ impl Database for Pool<Sqlite> {
Ok(Some(row.0)) Ok(Some(row.0))
} }
async fn save_manifest(&self, repository: &str, digest: &str, manifest: &str) -> sqlx::Result<()> { async fn save_manifest(&self, repository: &str, digest: &str, manifest: &str) -> anyhow::Result<()> {
sqlx::query("INSERT INTO image_manifests (digest, repository, content) VALUES (?, ?, ?)") sqlx::query("INSERT INTO image_manifests (digest, repository, content) VALUES (?, ?, ?)")
.bind(digest) .bind(digest)
.bind(repository) .bind(repository)
@ -206,7 +207,7 @@ impl Database for Pool<Sqlite> {
Ok(()) Ok(())
} }
async fn delete_manifest(&self, repository: &str, digest: &str) -> sqlx::Result<Vec<String>> { async fn delete_manifest(&self, repository: &str, digest: &str) -> anyhow::Result<Vec<String>> {
sqlx::query("DELETE FROM image_manifests where digest = ? AND repository = ?") sqlx::query("DELETE FROM image_manifests where digest = ? AND repository = ?")
.bind(digest) .bind(digest)
.bind(repository) .bind(repository)
@ -234,7 +235,7 @@ impl Database for Pool<Sqlite> {
Ok(digests) Ok(digests)
} }
async fn has_repository(&self, repository: &str) -> sqlx::Result<bool> { async fn has_repository(&self, repository: &str) -> anyhow::Result<bool> {
let row: (u32, ) = match sqlx::query_as("SELECT COUNT(1) FROM repositories WHERE \"name\" = ?") let row: (u32, ) = match sqlx::query_as("SELECT COUNT(1) FROM repositories WHERE \"name\" = ?")
.bind(repository) .bind(repository)
.fetch_one(self).await { .fetch_one(self).await {
@ -244,7 +245,7 @@ impl Database for Pool<Sqlite> {
return Ok(false) return Ok(false)
}, },
_ => { _ => {
return Err(e); return Err(anyhow::Error::new(e));
} }
} }
}; };
@ -270,7 +271,7 @@ impl Database for Pool<Sqlite> {
Ok(Some(RepositoryVisibility::try_from(row.0)?)) Ok(Some(RepositoryVisibility::try_from(row.0)?))
} }
async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option<String>) -> sqlx::Result<()> { async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option<String>) -> anyhow::Result<()> {
// ensure that the repository was not already created // ensure that the repository was not already created
if self.has_repository(repository).await? { if self.has_repository(repository).await? {
debug!("repo exists"); debug!("repo exists");
@ -297,8 +298,8 @@ impl Database for Pool<Sqlite> {
Ok(()) Ok(())
} }
//async fn list_repositories(&self) -> sqlx::Result<Vec<String>> { //async fn list_repositories(&self) -> anyhow::Result<Vec<String>> {
async fn list_repositories(&self, limit: Option<u32>, last_repo: Option<String>) -> sqlx::Result<Vec<String>> { async fn list_repositories(&self, limit: Option<u32>, last_repo: Option<String>) -> anyhow::Result<Vec<String>> {
let limit = limit.unwrap_or(1000); // set default limit let limit = limit.unwrap_or(1000); // set default limit
// Query differently depending on if `last_repo` was specified // Query differently depending on if `last_repo` was specified
@ -322,7 +323,7 @@ impl Database for Pool<Sqlite> {
Ok(repos) Ok(repos)
} }
async fn does_user_exist(&self, email: String) -> sqlx::Result<bool> { async fn does_user_exist(&self, email: String) -> anyhow::Result<bool> {
let row: (u32, ) = match sqlx::query_as("SELECT COUNT(1) FROM users WHERE \"email\" = ?") let row: (u32, ) = match sqlx::query_as("SELECT COUNT(1) FROM users WHERE \"email\" = ?")
.bind(email) .bind(email)
.fetch_one(self).await { .fetch_one(self).await {
@ -332,7 +333,7 @@ impl Database for Pool<Sqlite> {
return Ok(false) return Ok(false)
}, },
_ => { _ => {
return Err(e); return Err(anyhow::Error::new(e));
} }
} }
}; };
@ -340,7 +341,7 @@ impl Database for Pool<Sqlite> {
Ok(row.0 > 0) Ok(row.0 > 0)
} }
async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> sqlx::Result<User> { async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> anyhow::Result<User> {
let username = username.to_lowercase(); let username = username.to_lowercase();
let email = email.to_lowercase(); let email = email.to_lowercase();
sqlx::query("INSERT INTO users (username, email, login_source) VALUES (?, ?, ?)") sqlx::query("INSERT INTO users (username, email, login_source) VALUES (?, ?, ?)")
@ -352,7 +353,7 @@ impl Database for Pool<Sqlite> {
Ok(User::new(username, email, login_source)) Ok(User::new(username, email, login_source))
} }
async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> sqlx::Result<()> { async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> anyhow::Result<()> {
let email = email.to_lowercase(); let email = email.to_lowercase();
sqlx::query("INSERT INTO user_logins (email, password_hash, password_salt) VALUES (?, ?, ?)") sqlx::query("INSERT INTO user_logins (email, password_hash, password_salt) VALUES (?, ?, ?)")
.bind(email.clone()) .bind(email.clone())
@ -363,6 +364,16 @@ impl Database for Pool<Sqlite> {
Ok(()) Ok(())
} }
async fn set_user_registry_type(&self, email: String, user_type: RegistryUserType) -> anyhow::Result<()> {
let email = email.to_lowercase();
sqlx::query("INSERT INTO user_registry_permissions (email, user_type) VALUES (?, ?)")
.bind(email.clone())
.bind(user_type as u32)
.execute(self).await?;
Ok(())
}
async fn verify_user_login(&self, email: String, password: String) -> anyhow::Result<bool> { async fn verify_user_login(&self, email: String, password: String) -> anyhow::Result<bool> {
let email = email.to_lowercase(); let email = email.to_lowercase();
let row: (String, ) = sqlx::query_as("SELECT password_hash FROM users WHERE email = ?") let row: (String, ) = sqlx::query_as("SELECT password_hash FROM users WHERE email = ?")
@ -410,10 +421,17 @@ impl Database for Pool<Sqlite> {
} }
}; };
let vis = self.get_repository_visibility(&repository).await?.unwrap(); let vis = match self.get_repository_visibility(&repository).await? {
Some(v) => v,
None => return Ok(None),
};
// Also get the user type for the registry, if its admin return admin repository permissions // Also get the user type for the registry, if its admin return admin repository permissions
let utype = self.get_user_registry_usertype(email).await?.unwrap(); // unwrap should be safe let utype = match self.get_user_registry_usertype(email).await? {
Some(t) => t,
None => return Ok(None),
};
if utype == RegistryUserType::Admin { if utype == RegistryUserType::Admin {
Ok(Some(RepositoryPermissions::new(Permission::ADMIN.bits(), vis))) Ok(Some(RepositoryPermissions::new(Permission::ADMIN.bits(), vis)))
} else { } else {
@ -479,11 +497,15 @@ impl Database for Pool<Sqlite> {
.bind(email.clone()) .bind(email.clone())
.fetch_one(self).await?; */ .fetch_one(self).await?; */
let (expiry, created_at) = (Utc.timestamp_millis_opt(expiry).unwrap(), Utc.timestamp_millis_opt(created_at).unwrap()); let (expiry, created_at) = (Utc.timestamp_millis_opt(expiry).single(), Utc.timestamp_millis_opt(created_at).single());
if let (Some(expiry), Some(created_at)) = (expiry, created_at) {
let user = User::new(email, user_row.0, LoginSource::try_from(user_row.1)?); let user = User::new(email, user_row.0, LoginSource::try_from(user_row.1)?);
let token = TokenInfo::new(token, expiry, created_at); let token = TokenInfo::new(token, expiry, created_at);
let auth = UserAuth::new(user, token); let auth = UserAuth::new(user, token);
Ok(Some(auth)) Ok(Some(auth))
} else {
Ok(None)
}
} }
} }

26
src/error.rs Normal file
View File

@ -0,0 +1,26 @@
use axum::{response::{IntoResponse, Response}, http::StatusCode};
// Make our own error that wraps `anyhow::Error`.
pub struct AppError(anyhow::Error);
// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
fn into_response(self) -> Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {}", self.0),
)
.into_response()
}
}
// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}

View File

@ -5,20 +5,19 @@ mod dto;
mod storage; mod storage;
mod byte_stream; mod byte_stream;
mod config; mod config;
mod query;
mod auth; mod auth;
mod error;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use auth::{AuthDriver, ldap_driver::LdapAuthDriver}; use auth::{AuthDriver, ldap_driver::LdapAuthDriver};
use axum::http::{Request, StatusCode, header, HeaderName}; use axum::http::{Request, StatusCode};
use axum::middleware::Next; use axum::middleware::Next;
use axum::response::{Response, IntoResponse}; use axum::response::Response;
use axum::{Router, routing}; use axum::{Router, routing};
use axum::ServiceExt; use axum::ServiceExt;
use bcrypt::Version;
use tower_layer::Layer; use tower_layer::Layer;
use sqlx::sqlite::SqlitePoolOptions; use sqlx::sqlite::SqlitePoolOptions;
@ -29,21 +28,21 @@ use tracing::{debug, Level};
use app_state::AppState; use app_state::AppState;
use database::Database; use database::Database;
use crate::dto::user::Permission;
use crate::storage::StorageDriver; use crate::storage::StorageDriver;
use crate::storage::filesystem::FilesystemDriver; use crate::storage::filesystem::FilesystemDriver;
use crate::config::{Config, LdapConnectionConfig}; use crate::config::Config;
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
/// Encode the 'name' path parameter in the url /// Encode the 'name' path parameter in the url
async fn change_request_paths<B>(mut request: Request<B>, next: Next<B>) -> Response { async fn change_request_paths<B>(mut request: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
// Attempt to find the name using regex in the url // Attempt to find the name using regex in the url
let regex = regex::Regex::new(r"/v2/([\w/]+)/(blobs|tags|manifests)").unwrap(); let regex = regex::Regex::new(r"/v2/([\w/]+)/(blobs|tags|manifests)")
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let captures = match regex.captures(request.uri().path()) { let captures = match regex.captures(request.uri().path()) {
Some(captures) => captures, Some(captures) => captures,
None => return next.run(request).await, None => return Ok(next.run(request).await),
}; };
// Find the name in the request and encode it in the url // Find the name in the request and encode it in the url
@ -54,24 +53,25 @@ async fn change_request_paths<B>(mut request: Request<B>, next: Next<B>) -> Resp
let uri_str = request.uri().to_string().replace(&name, &encoded_name); let uri_str = request.uri().to_string().replace(&name, &encoded_name);
debug!("Rewrote request url to: '{}'", uri_str); debug!("Rewrote request url to: '{}'", uri_str);
*request.uri_mut() = uri_str.parse().unwrap(); *request.uri_mut() = uri_str.parse()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
next.run(request).await Ok(next.run(request).await)
} }
#[tokio::main] #[tokio::main]
async fn main() -> std::io::Result<()> { async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_max_level(Level::DEBUG) .with_max_level(Level::DEBUG)
.init(); .init();
let config = Config::new().expect("Failure to parse config!"); let config = Config::new()
.expect("Failure to parse config!");
let pool = SqlitePoolOptions::new() let pool = SqlitePoolOptions::new()
.max_connections(15) .max_connections(15)
.connect("test.db").await.unwrap(); .connect("test.db").await?;
pool.create_schema().await?;
pool.create_schema().await.unwrap();
let storage_driver: Mutex<Box<dyn StorageDriver>> = Mutex::new(Box::new(FilesystemDriver::new("registry/blobs"))); let storage_driver: Mutex<Box<dyn StorageDriver>> = Mutex::new(Box::new(FilesystemDriver::new("registry/blobs")));
@ -79,7 +79,7 @@ async fn main() -> std::io::Result<()> {
// the fallback is a database auth driver. // the fallback is a database auth driver.
let auth_driver: Mutex<Box<dyn AuthDriver>> = match config.ldap.clone() { let auth_driver: Mutex<Box<dyn AuthDriver>> = match config.ldap.clone() {
Some(ldap) => { Some(ldap) => {
let ldap_driver = LdapAuthDriver::new(ldap, pool.clone()).await.unwrap(); let ldap_driver = LdapAuthDriver::new(ldap, pool.clone()).await?;
Mutex::new(Box::new(ldap_driver)) Mutex::new(Box::new(ldap_driver))
}, },
None => { None => {
@ -87,7 +87,7 @@ async fn main() -> std::io::Result<()> {
} }
}; };
let app_addr = SocketAddr::from_str(&format!("{}:{}", config.listen_address, config.listen_port)).unwrap(); let app_addr = SocketAddr::from_str(&format!("{}:{}", config.listen_address, config.listen_port))?;
let state = Arc::new(AppState::new(pool, storage_driver, config, auth_driver)); let state = Arc::new(AppState::new(pool, storage_driver, config, auth_driver));
@ -129,8 +129,7 @@ async fn main() -> std::io::Result<()> {
debug!("Starting http server, listening on {}", app_addr); debug!("Starting http server, listening on {}", app_addr);
axum::Server::bind(&app_addr) axum::Server::bind(&app_addr)
.serve(layered_app.into_make_service()) .serve(layered_app.into_make_service())
.await .await?;
.unwrap();
Ok(()) Ok(())
} }

View File

@ -1,6 +1,6 @@
use std::{path::Path, io::ErrorKind}; use std::{path::Path, io::ErrorKind};
use anyhow::Context; use anyhow::{Context, anyhow};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes; use bytes::Bytes;
use futures::StreamExt; use futures::StreamExt;
@ -53,7 +53,7 @@ impl StorageDriver for FilesystemDriver {
len += bytes.len(); len += bytes.len();
file.write_all(&bytes).await.unwrap(); file.write_all(&bytes).await?;
} }
Ok(len) Ok(len)
@ -140,9 +140,12 @@ impl StorageDriver for FilesystemDriver {
async fn replace_digest(&self, uuid: &str, digest: &str) -> anyhow::Result<()> { async fn replace_digest(&self, uuid: &str, digest: &str) -> anyhow::Result<()> {
let path = self.get_digest_path(uuid); let path = self.get_digest_path(uuid);
let path = Path::new(&path); let path = Path::new(&path);
let parent = path.clone().parent().unwrap(); let parent = path
.clone()
.parent()
.ok_or(anyhow!("Failure to get parent path of digest file!"))?;
fs::rename(path, format!("{}/{}", parent.as_os_str().to_str().unwrap(), digest)).await?; fs::rename(path, format!("{}/{}", parent.display(), digest)).await?;
Ok(()) Ok(())
} }