remove ALL necessary unwraps

This commit is contained in:
SeanOMik 2023-05-29 23:36:30 -04:00
parent 7637faf047
commit c4b25e1415
Signed by: SeanOMik
GPG Key ID: 568F326C7EB33ACB
13 changed files with 305 additions and 256 deletions

View File

@ -1,5 +1,9 @@
- [x] slashes in repo names
- [ ] Simple auth
- [x] Simple auth
- [x] ldap auth
- [ ] permission stuff
- [ ] Only allow users to create repositories if its the same name as their username, or if they're an admin
- [ ] Only allow users to pull from their own repositories
- [ ] postgresql
- [ ] prometheus metrics
- [x] streaming layer bytes into providers

View File

@ -56,9 +56,9 @@ fn create_jwt_token(account: &str) -> anyhow::Result<TokenInfo> {
claims.insert("subject", &account);
//claims.insert("audience", auth.service);
let not_before = format!("{}", now_secs - 10);
let not_before = format!("{}", now_secs);
let issued_at = format!("{}", now_secs);
let expiration = format!("{}", now_secs + 20);
let expiration = format!("{}", now_secs + 86400); // 1 day
claims.insert("notbefore", &not_before);
claims.insert("issuedat", &issued_at);
claims.insert("expiration", &expiration); // TODO: 20 seconds expiry for testing
@ -75,7 +75,7 @@ fn create_jwt_token(account: &str) -> anyhow::Result<TokenInfo> {
Ok(TokenInfo::new(token_str, expiration, issued_at))
}
pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppState>>, Query(params): Query<HashMap<String, String>>, form: Option<Form<AuthForm>>) -> Response {
pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppState>>, Query(params): Query<HashMap<String, String>>, form: Option<Form<AuthForm>>) -> Result<Response, StatusCode> {
let mut auth = TokenAuthRequest {
user: None,
password: None,
@ -117,7 +117,7 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
info!("Auth failure! Auth was not provided in either AuthBasic or Form!");
// Maybe BAD_REQUEST should be returned?
return (StatusCode::UNAUTHORIZED).into_response();
return Err(StatusCode::UNAUTHORIZED);
}
// Create logging span for the rest of this request
@ -133,7 +133,7 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
if account != user {
error!("`user` and `account` are not the same!!! (user: {}, account: {})", user, account);
return (StatusCode::BAD_REQUEST).into_response();
return Err(StatusCode::BAD_REQUEST);
}
}
@ -149,7 +149,14 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
if let Some(scope) = params.get("scope") {
// TODO: Handle multiple scopes
auth.scope.push(Scope::try_from(&scope[..]).unwrap());
match Scope::try_from(&scope[..]) {
Ok(scope) => {
auth.scope.push(scope);
},
Err(_) => {
return Err(StatusCode::BAD_REQUEST);
}
}
}
// Get offline token and attempt to convert it to a boolean
@ -168,17 +175,19 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
if let (Some(account), Some(password)) = (&auth.account, auth.password) {
// Ensure that the password is correct
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.verify_user_login(account.clone(), password).await.unwrap() {
if !auth_driver.verify_user_login(account.clone(), password).await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? {
debug!("Authentication failed, incorrect password!");
return unauthenticated_response(&state.config);
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
debug!("User password is correct");
let now = SystemTime::now();
let token = create_jwt_token(account).unwrap();
let token = create_jwt_token(account)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let token_str = token.token;
debug!("Created jwt token");
@ -194,23 +203,25 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
issued_at: now_format,
};
let json_str = serde_json::to_string(&auth_response).unwrap();
let json_str = serde_json::to_string(&auth_response)
.map_err(|_| StatusCode::BAD_REQUEST)?;
let database = &state.database;
database.store_user_token(token_str.clone(), account.clone(), token.expiry, token.created_at).await.unwrap();
database.store_user_token(token_str.clone(), account.clone(), token.expiry, token.created_at).await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
drop(database);
return (
return Ok((
StatusCode::OK,
[
( header::CONTENT_TYPE, "application/json" ),
( header::AUTHORIZATION, &format!("Bearer {}", token_str) )
],
json_str
).into_response();
).into_response());
}
info!("Auth failure! Not enough information given to create auth token!");
// If we didn't get fields required to make a token, then the client did something bad
(StatusCode::UNAUTHORIZED).into_response()
Err(StatusCode::UNAUTHORIZED)
}

View File

@ -8,64 +8,69 @@ use axum::response::{IntoResponse, Response};
use tokio_util::io::ReaderStream;
use crate::app_state::AppState;
use crate::auth::{unauthenticated_response, AuthDriver};
use crate::database::Database;
use crate::auth::unauthenticated_response;
use crate::dto::RepositoryVisibility;
use crate::dto::user::{Permission, RegistryUserType, UserAuth};
use crate::dto::user::{Permission, UserAuth};
use crate::error::AppError;
pub async fn digest_exists_head(Path((name, layer_digest)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response {
pub async fn digest_exists_head(Path((name, layer_digest)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
let storage = state.storage.lock().await;
if storage.has_digest(&layer_digest).await.unwrap() {
if let Some(size) = storage.digest_length(&layer_digest).await.unwrap() {
return (
if storage.has_digest(&layer_digest).await? {
if let Some(size) = storage.digest_length(&layer_digest).await? {
return Ok((
StatusCode::OK,
[
(header::CONTENT_LENGTH, size.to_string()),
(HeaderName::from_static("docker-content-digest"), layer_digest)
]
).into_response();
).into_response());
}
}
StatusCode::NOT_FOUND.into_response()
Ok(StatusCode::NOT_FOUND.into_response())
}
pub async fn pull_digest_get(Path((name, layer_digest)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response {
pub async fn pull_digest_get(Path((name, layer_digest)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
let storage = state.storage.lock().await;
if let Some(len) = storage.digest_length(&layer_digest).await.unwrap() {
let stream = storage.get_digest_stream(&layer_digest).await.unwrap().unwrap();
if let Some(len) = storage.digest_length(&layer_digest).await? {
let stream = match storage.get_digest_stream(&layer_digest).await? {
Some(s) => s,
None => {
return Ok(StatusCode::NOT_FOUND.into_response());
}
};
// convert the `AsyncRead` into a `Stream`
let stream = ReaderStream::new(stream.into_async_read());
// convert the `Stream` into an `axum::body::HttpBody`
let body = StreamBody::new(stream);
(
Ok((
StatusCode::OK,
[
(header::CONTENT_LENGTH, len.to_string()),
(HeaderName::from_static("docker-content-digest"), layer_digest)
],
body
).into_response()
).into_response())
} else {
StatusCode::NOT_FOUND.into_response()
Ok(StatusCode::NOT_FOUND.into_response())
}
}

View File

@ -1,9 +1,9 @@
use std::sync::Arc;
use axum::{extract::{State, Query}, http::{StatusCode, header, HeaderMap, HeaderName}, response::IntoResponse};
use axum::{extract::{State, Query}, http::{StatusCode, header, HeaderMap, HeaderName}, response::{IntoResponse, Response}};
use serde::{Serialize, Deserialize};
use crate::{app_state::AppState, database::Database};
use crate::{app_state::AppState, database::Database, error::AppError};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@ -20,7 +20,7 @@ pub struct ListRepositoriesParams {
last_repo: Option<String>,
}
pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, state: State<Arc<AppState>>) -> impl IntoResponse {
pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let mut link_header = None;
// Paginate tag results if n was specified, else just pull everything.
@ -30,7 +30,7 @@ pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, sta
// Convert the last param to a String, and list all the repos
let last_repo = params.last_repo.and_then(|t| Some(t.to_string()));
let repos = database.list_repositories(Some(limit), last_repo).await.unwrap();
let repos = database.list_repositories(Some(limit), last_repo).await?;
// Get the new last repository for the response
let last_repo = repos.last().and_then(|s| Some(s.clone()));
@ -47,7 +47,7 @@ pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, sta
repos
},
None => {
database.list_repositories(None, None).await.unwrap()
database.list_repositories(None, None).await?
}
};
@ -55,20 +55,20 @@ pub async fn list_repositories(Query(params): Query<ListRepositoriesParams>, sta
let repo_list = RepositoryList {
repositories,
};
let response_body = serde_json::to_string(&repo_list).unwrap();
let response_body = serde_json::to_string(&repo_list)?;
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap());
headers.insert(header::CONTENT_TYPE, "application/json".parse()?);
headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse()?);
if let Some(link_header) = link_header {
headers.insert(header::LINK, link_header.parse().unwrap());
headers.insert(header::LINK, link_header.parse()?);
}
// Construct the response, optionally adding the Link header if it was constructed.
(
Ok((
StatusCode::OK,
headers,
response_body
)
).into_response())
}

View File

@ -3,22 +3,23 @@ use std::sync::Arc;
use axum::Extension;
use axum::extract::{Path, State};
use axum::response::{Response, IntoResponse};
use axum::http::{StatusCode, HeaderMap, HeaderName, header};
use axum::http::{StatusCode, HeaderName, header};
use tracing::log::warn;
use tracing::{debug, info};
use crate::auth::{unauthenticated_response, AuthDriver};
use crate::auth::unauthenticated_response;
use crate::app_state::AppState;
use crate::database::Database;
use crate::dto::RepositoryVisibility;
use crate::dto::digest::Digest;
use crate::dto::manifest::Manifest;
use crate::dto::user::{UserAuth, Permission};
use crate::error::AppError;
pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>, body: String) -> Response {
pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>, body: String) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
@ -29,45 +30,45 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>
let database = &state.database;
// Create the image repository and save the image manifest. This repository will be private by default
database.save_repository(&name, RepositoryVisibility::Private, None).await.unwrap();
database.save_manifest(&name, &calculated_digest, &body).await.unwrap();
database.save_repository(&name, RepositoryVisibility::Private, None).await?;
database.save_manifest(&name, &calculated_digest, &body).await?;
// If the reference is not a digest, then it must be a tag name.
if !Digest::is_digest(&reference) {
database.save_tag(&name, &reference, &calculated_digest).await.unwrap();
database.save_tag(&name, &reference, &calculated_digest).await?;
}
info!("Saved manifest {}", calculated_digest);
match serde_json::from_str(&body).unwrap() {
match serde_json::from_str(&body)? {
Manifest::Image(image) => {
// Link the manifest to the image layer
database.link_manifest_layer(&calculated_digest, &image.config.digest).await.unwrap();
database.link_manifest_layer(&calculated_digest, &image.config.digest).await?;
debug!("Linked manifest {} to layer {}", calculated_digest, image.config.digest);
for layer in image.layers {
database.link_manifest_layer(&calculated_digest, &layer.digest).await.unwrap();
database.link_manifest_layer(&calculated_digest, &layer.digest).await?;
debug!("Linked manifest {} to layer {}", calculated_digest, image.config.digest);
}
(
Ok((
StatusCode::CREATED,
[ (HeaderName::from_static("docker-content-digest"), calculated_digest) ]
).into_response()
).into_response())
},
Manifest::List(_list) => {
warn!("ManifestList request was received!");
StatusCode::NOT_IMPLEMENTED.into_response()
Ok(StatusCode::NOT_IMPLEMENTED.into_response())
}
}
}
pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response {
pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
@ -76,24 +77,24 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>,
true => reference.clone(),
false => {
debug!("Attempting to get manifest digest using tag (repository={}, reference={})", name, reference);
if let Some(tag) = database.get_tag(&name, &reference).await.unwrap() {
if let Some(tag) = database.get_tag(&name, &reference).await? {
tag.manifest_digest
} else {
return StatusCode::NOT_FOUND.into_response();
return Ok(StatusCode::NOT_FOUND.into_response());
}
}
};
let manifest_content = database.get_manifest(&name, &digest).await.unwrap();
let manifest_content = database.get_manifest(&name, &digest).await?;
if manifest_content.is_none() {
debug!("Failed to get manifest in repo {}, for digest {}", name, digest);
// The digest that was provided in the request was invalid.
// NOTE: This could also mean that there's a bug and the tag pointed to an invalid manifest.
return StatusCode::NOT_FOUND.into_response();
return Ok(StatusCode::NOT_FOUND.into_response());
}
let manifest_content = manifest_content.unwrap();
(
Ok((
StatusCode::OK,
[
(HeaderName::from_static("docker-content-digest"), digest),
@ -103,14 +104,14 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>,
(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string()),
],
manifest_content
).into_response()
).into_response())
}
pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response {
pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
@ -119,23 +120,23 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
let digest = match Digest::is_digest(&reference) {
true => reference.clone(),
false => {
if let Some(tag) = database.get_tag(&name, &reference).await.unwrap() {
if let Some(tag) = database.get_tag(&name, &reference).await? {
tag.manifest_digest
} else {
return StatusCode::NOT_FOUND.into_response();
return Ok(StatusCode::NOT_FOUND.into_response());
}
}
};
let manifest_content = database.get_manifest(&name, &digest).await.unwrap();
let manifest_content = database.get_manifest(&name, &digest).await?;
if manifest_content.is_none() {
// The digest that was provided in the request was invalid.
// NOTE: This could also mean that there's a bug and the tag pointed to an invalid manifest.
return StatusCode::NOT_FOUND.into_response();
return Ok(StatusCode::NOT_FOUND.into_response());
}
let manifest_content = manifest_content.unwrap();
(
Ok((
StatusCode::OK,
[
(HeaderName::from_static("docker-content-digest"), digest),
@ -144,43 +145,41 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string()),
],
manifest_content
).into_response()
).into_response())
}
pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, headers: HeaderMap, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response {
pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
let _authorization = headers.get("Authorization").unwrap(); // TODO: use authorization header
let database = &state.database;
let digest = match Digest::is_digest(&reference) {
true => {
// Check if the manifest exists
if database.get_manifest(&name, &reference).await.unwrap().is_none() {
return StatusCode::NOT_FOUND.into_response();
if database.get_manifest(&name, &reference).await?.is_none() {
return Ok(StatusCode::NOT_FOUND.into_response());
}
reference.clone()
},
false => {
if let Some(tag) = database.get_tag(&name, &reference).await.unwrap() {
if let Some(tag) = database.get_tag(&name, &reference).await? {
tag.manifest_digest
} else {
return StatusCode::NOT_FOUND.into_response();
return Ok(StatusCode::NOT_FOUND.into_response());
}
}
};
database.delete_manifest(&name, &digest).await.unwrap();
database.delete_manifest(&name, &digest).await?;
(
Ok((
StatusCode::ACCEPTED,
[
(header::CONTENT_LENGTH, "None"),
],
).into_response()
).into_response())
}

View File

@ -1,9 +1,9 @@
use std::sync::Arc;
use axum::{extract::{Path, Query, State}, response::IntoResponse, http::{StatusCode, header, HeaderMap, HeaderName}};
use axum::{extract::{Path, Query, State}, response::{IntoResponse, Response}, http::{StatusCode, header, HeaderMap, HeaderName}};
use serde::{Serialize, Deserialize};
use crate::{app_state::AppState, database::Database};
use crate::{app_state::AppState, database::Database, error::AppError};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
@ -21,7 +21,7 @@ pub struct ListRepositoriesParams {
last_tag: Option<String>,
}
pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<ListRepositoriesParams>, state: State<Arc<AppState>>) -> impl IntoResponse {
pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<ListRepositoriesParams>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let mut link_header = None;
// Paginate tag results if n was specified, else just pull everything.
@ -31,7 +31,7 @@ pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<Li
// Convert the last param to a String, and list all the tags
let last_tag = params.last_tag.and_then(|t| Some(t.to_string()));
let tags = database.list_repository_tags_page(&name, limit, last_tag).await.unwrap();
let tags = database.list_repository_tags_page(&name, limit, last_tag).await?;
// Get the new last repository for the response
let last_tag = tags.last();
@ -48,7 +48,7 @@ pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<Li
tags
},
None => {
database.list_repository_tags(&name).await.unwrap()
database.list_repository_tags(&name).await?
}
};
@ -57,21 +57,21 @@ pub async fn list_tags(Path((name, )): Path<(String, )>, Query(params): Query<Li
name,
tags: tags.into_iter().map(|t| t.name).collect(),
};
let response_body = serde_json::to_string(&tag_list).unwrap();
let response_body = serde_json::to_string(&tag_list)?;
// Create headers
let mut headers = HeaderMap::new();
headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap());
headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap());
headers.insert(header::CONTENT_TYPE, "application/json".parse()?);
headers.insert(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse()?);
// Add the link header if it was constructed
if let Some(link_header) = link_header {
headers.insert(header::LINK, link_header.parse().unwrap());
headers.insert(header::LINK, link_header.parse()?);
}
(
Ok((
StatusCode::OK,
headers,
response_body
)
).into_response())
}

View File

@ -12,15 +12,15 @@ use futures::StreamExt;
use tracing::{debug, warn};
use crate::app_state::AppState;
use crate::auth::{unauthenticated_response, AuthDriver};
use crate::auth::unauthenticated_response;
use crate::byte_stream::ByteStream;
use crate::database::Database;
use crate::dto::user::{UserAuth, Permission, RegistryUser, RegistryUserType};
use crate::dto::user::{UserAuth, Permission};
use crate::error::AppError;
/// Starting an upload
pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>) -> Response {
pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await;
if auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() {
if auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
debug!("Upload requested");
let uuid = uuid::Uuid::new_v4();
@ -29,24 +29,24 @@ pub async fn start_upload_post(Path((name, )): Path<(String, )>, Extension(auth)
let location = format!("/v2/{}/blobs/uploads/{}", name, uuid.to_string());
debug!("Constructed upload url: {}", location);
return (
return Ok((
StatusCode::ACCEPTED,
[ (header::LOCATION, location) ]
).into_response();
).into_response());
}
unauthenticated_response(&state.config)
Ok(unauthenticated_response(&state.config))
}
pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, String)>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>, mut body: BodyStream) -> Response {
pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, String)>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>, mut body: BodyStream) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
let storage = state.storage.lock().await;
let current_size = storage.digest_length(&layer_uuid).await.unwrap();
let current_size = storage.digest_length(&layer_uuid).await?;
let written_size = match storage.supports_streaming().await {
true => {
@ -61,7 +61,7 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String,
};
let byte_stream = ByteStream::new(io_stream);
let len = storage.save_digest_stream(&layer_uuid, byte_stream, true).await.unwrap();
let len = storage.save_digest_stream(&layer_uuid, byte_stream, true).await?;
len
},
@ -70,11 +70,11 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String,
let mut bytes = BytesMut::new();
while let Some(item) = body.next().await {
bytes.extend_from_slice(&item.unwrap());
bytes.extend_from_slice(&item?);
}
let bytes_len = bytes.len();
storage.save_digest(&layer_uuid, &bytes.into(), true).await.unwrap();
storage.save_digest(&layer_uuid, &bytes.into(), true).await?;
bytes_len
}
};
@ -86,7 +86,7 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String,
};
let full_uri = format!("{}/v2/{}/blobs/uploads/{}", state.config.get_url(), name, layer_uuid);
(
Ok((
StatusCode::ACCEPTED,
[
(header::LOCATION, full_uri),
@ -94,13 +94,13 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String,
(header::CONTENT_LENGTH, "0".to_string()),
(HeaderName::from_static("docker-upload-uuid"), layer_uuid)
]
).into_response()
).into_response())
}
pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, String)>, Query(query): Query<HashMap<String, String>>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>, body: Bytes) -> Response {
pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, String)>, Query(query): Query<HashMap<String, String>>, Extension(auth): Extension<UserAuth>, state: State<Arc<AppState>>, body: Bytes) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
@ -108,54 +108,54 @@ pub async fn finish_chunked_upload_put(Path((name, layer_uuid)): Path<(String, S
let storage = state.storage.lock().await;
if !body.is_empty() {
storage.save_digest(&layer_uuid, &body, true).await.unwrap();
storage.save_digest(&layer_uuid, &body, true).await?;
} else {
// TODO: Validate layer with all digest params
}
storage.replace_digest(&layer_uuid, &digest).await.unwrap();
storage.replace_digest(&layer_uuid, &digest).await?;
debug!("Completed upload, finished uuid {} to digest {}", layer_uuid, digest);
(
Ok((
StatusCode::CREATED,
[
(header::LOCATION, format!("/v2/{}/blobs/{}", name, digest)),
(header::CONTENT_LENGTH, "0".to_string()),
(HeaderName::from_static("docker-upload-digest"), digest.to_owned())
]
).into_response()
).into_response())
}
pub async fn cancel_upload_delete(Path((name, layer_uuid)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response {
pub async fn cancel_upload_delete(Path((name, layer_uuid)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
let storage = state.storage.lock().await;
storage.delete_digest(&layer_uuid).await.unwrap();
storage.delete_digest(&layer_uuid).await?;
// I'm not sure what this response should be, its not specified in the registry spec.
StatusCode::OK.into_response()
Ok(StatusCode::OK.into_response())
}
pub async fn check_upload_status_get(Path((name, layer_uuid)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Response {
pub async fn check_upload_status_get(Path((name, layer_uuid)): Path<(String, String)>, state: State<Arc<AppState>>, Extension(auth): Extension<UserAuth>) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await.unwrap() {
return unauthenticated_response(&state.config);
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return Ok(unauthenticated_response(&state.config));
}
drop(auth_driver);
let storage = state.storage.lock().await;
let ending = storage.digest_length(&layer_uuid).await.unwrap().unwrap_or(0);
let ending = storage.digest_length(&layer_uuid).await?.unwrap_or(0);
(
Ok((
StatusCode::CREATED,
[
(header::LOCATION, format!("/v2/{}/blobs/uploads/{}", name, layer_uuid)),
(header::RANGE, format!("0-{}", ending)),
(HeaderName::from_static("docker-upload-digest"), layer_uuid)
]
).into_response()
).into_response())
}

View File

@ -1,11 +1,9 @@
use std::{slice::Iter, iter::Peekable};
use async_trait::async_trait;
use ldap3::{LdapConnAsync, Ldap, Scope, asn1::PL, ResultEntry, SearchEntry};
use ldap3::{LdapConnAsync, Ldap, Scope, SearchEntry};
use sqlx::{Pool, Sqlite};
use tracing::{debug, warn};
use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource}, RepositoryVisibility}, database::Database};
use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource, RegistryUserType}, RepositoryVisibility}, database::Database};
use super::AuthDriver;
@ -37,46 +35,7 @@ impl LdapAuthDriver {
Ok(())
}
/* pub async fn verify_login(&mut self, username: &str, password: &str) -> anyhow::Result<bool> {
self.bind().await?;
let filter = self.ldap_config.user_search_filter.replace("%s", &username);
let res = self.ldap.search(&self.ldap_config.user_base_dn, Scope::Subtree, &filter,
vec!["userPassword", "uid", "cn", "mail", "displayName"]).await?;
let (entries, _res) = res.success()?;
let entries: Vec<SearchEntry> = entries
.into_iter()
.map(|e| SearchEntry::construct(e))
.collect();
if entries.is_empty() {
Ok(false)
} else if entries.len() > 1 {
warn!("Got multiple DNs for user ({}), unsure which one to use!!", username);
Ok(false)
} else {
let entry = entries.first().unwrap();
let res = self.ldap.simple_bind(&entry.dn, password).await?;
if res.rc == 0 {
Ok(true)
} else if res.rc == 49 {
warn!("User failed to auth (invalidCredentials, rc=49)!");
Ok(false)
} else {
// this would fail, its just here to propagate the error down
res.success()?;
Ok(false)
}
}
} */
}
#[async_trait]
impl AuthDriver for LdapAuthDriver {
async fn user_has_permission(&mut self, email: String, repository: String, permission: Permission, required_visibility: Option<RepositoryVisibility>) -> anyhow::Result<bool> {
async fn is_user_admin(&mut self, email: String) -> anyhow::Result<bool> {
self.bind().await?;
// Send a request to LDAP to check if the user is an admin
@ -90,7 +49,14 @@ impl AuthDriver for LdapAuthDriver {
.map(|e| SearchEntry::construct(e))
.collect();
if entries.len() > 0 {
Ok(entries.len() > 0)
}
}
#[async_trait]
impl AuthDriver for LdapAuthDriver {
async fn user_has_permission(&mut self, email: String, repository: String, permission: Permission, required_visibility: Option<RepositoryVisibility>) -> anyhow::Result<bool> {
if self.is_user_admin(email.clone()).await? {
Ok(true)
} else {
debug!("LDAP is falling back to database");
@ -118,7 +84,7 @@ impl AuthDriver for LdapAuthDriver {
warn!("Got multiple DNs for user ({}), unsure which one to use!!", email);
Ok(false)
} else {
let entry = entries.first().unwrap();
let entry = entries.first().unwrap(); // there will be an entry
let res = self.ldap.simple_bind(&entry.dn, &password).await?;
if res.rc == 0 {
@ -126,8 +92,22 @@ impl AuthDriver for LdapAuthDriver {
// Check if the user is stored in the database, if not, add it.
let database = &self.database;
if !database.does_user_exist(email.clone()).await? {
let display_name = entry.attrs.get(&self.ldap_config.display_name_attribute).unwrap().first().unwrap().clone();
database.create_user(email, display_name, LoginSource::LDAP).await?;
let display_name = match entry.attrs.get(&self.ldap_config.display_name_attribute) {
// theres no way the vector would be empty
Some(display) => display.first().unwrap().clone(),
None => return Ok(false),
};
database.create_user(email.clone(), display_name, LoginSource::LDAP).await?;
drop(database);
// Set the user registry type
let user_type = match self.is_user_admin(email.clone()).await? {
true => RegistryUserType::Admin,
false => RegistryUserType::Regular
};
self.database.set_user_registry_type(email, user_type).await?;
}
Ok(true)

View File

@ -1,8 +1,8 @@
pub mod ldap_driver;
use std::{collections::HashSet, ops::Deref, sync::Arc};
use std::{ops::Deref, sync::Arc};
use axum::{extract::{State, Path}, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}};
use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request}, middleware::Next, response::{Response, IntoResponse}};
use sqlx::{Pool, Sqlite};
use tracing::debug;
@ -101,7 +101,7 @@ pub async fn require_auth<B>(State(state): State<Arc<AppState>>, mut request: Re
// If the token is not valid, return an unauthorized response
let database = &state.database;
if let Some(user) = database.verify_user_token(token.to_string()).await.unwrap() {
if let Ok(Some(user)) = database.verify_user_token(token.to_string()).await {
debug!("Authenticated user through middleware: {}", user.user.username);
request.extensions_mut().insert(user);

View File

@ -12,47 +12,48 @@ pub trait Database {
// Digest related functions
/// Create the tables in the database
async fn create_schema(&self) -> sqlx::Result<()>;
async fn create_schema(&self) -> anyhow::Result<()>;
// Tag related functions
/// Get tags associated with a repository
async fn list_repository_tags(&self, repository: &str,) -> sqlx::Result<Vec<Tag>>;
async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option<String>) -> sqlx::Result<Vec<Tag>>;
async fn list_repository_tags(&self, repository: &str,) -> anyhow::Result<Vec<Tag>>;
async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option<String>) -> anyhow::Result<Vec<Tag>>;
/// Get a manifest digest using the tag name.
async fn get_tag(&self, repository: &str, tag: &str) -> sqlx::Result<Option<Tag>>;
async fn get_tag(&self, repository: &str, tag: &str) -> anyhow::Result<Option<Tag>>;
/// Save a tag and reference it to the manifest digest.
async fn save_tag(&self, repository: &str, tag: &str, manifest_digest: &str) -> sqlx::Result<()>;
async fn save_tag(&self, repository: &str, tag: &str, manifest_digest: &str) -> anyhow::Result<()>;
/// Delete a tag.
async fn delete_tag(&self, repository: &str, tag: &str) -> sqlx::Result<()>;
async fn delete_tag(&self, repository: &str, tag: &str) -> anyhow::Result<()>;
// Manifest related functions
/// Get a manifest's content.
async fn get_manifest(&self, repository: &str, digest: &str) -> sqlx::Result<Option<String>>;
async fn get_manifest(&self, repository: &str, digest: &str) -> anyhow::Result<Option<String>>;
/// Save a manifest's content.
async fn save_manifest(&self, repository: &str, digest: &str, content: &str) -> sqlx::Result<()>;
async fn save_manifest(&self, repository: &str, digest: &str, content: &str) -> anyhow::Result<()>;
/// Delete a manifest
/// Returns digests that this manifest pointed to.
async fn delete_manifest(&self, repository: &str, digest: &str) -> sqlx::Result<Vec<String>>;
async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()>;
async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()>;
async fn delete_manifest(&self, repository: &str, digest: &str) -> anyhow::Result<Vec<String>>;
async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()>;
async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()>;
// Repository related functions
async fn has_repository(&self, repository: &str) -> sqlx::Result<bool>;
async fn has_repository(&self, repository: &str) -> anyhow::Result<bool>;
async fn get_repository_visibility(&self, repository: &str) -> anyhow::Result<Option<RepositoryVisibility>>;
/// Create a repository
async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option<String>) -> sqlx::Result<()>;
async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option<String>) -> anyhow::Result<()>;
/// List all repositories.
/// If limit is not specified, a default limit of 1000 will be returned.
async fn list_repositories(&self, limit: Option<u32>, last_repo: Option<String>) -> sqlx::Result<Vec<String>>;
async fn list_repositories(&self, limit: Option<u32>, last_repo: Option<String>) -> anyhow::Result<Vec<String>>;
/// User stuff
async fn does_user_exist(&self, email: String) -> sqlx::Result<bool>;
async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> sqlx::Result<User>;
async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> sqlx::Result<()>;
async fn does_user_exist(&self, email: String) -> anyhow::Result<bool>;
async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> anyhow::Result<User>;
async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> anyhow::Result<()>;
async fn set_user_registry_type(&self, email: String, user_type: RegistryUserType) -> anyhow::Result<()>;
async fn verify_user_login(&self, email: String, password: String) -> anyhow::Result<bool>;
async fn get_user_registry_type(&self, email: String) -> anyhow::Result<Option<RegistryUserType>>;
async fn get_user_repo_permissions(&self, email: String, repository: String) -> anyhow::Result<Option<RepositoryPermissions>>;
@ -63,7 +64,7 @@ pub trait Database {
#[async_trait]
impl Database for Pool<Sqlite> {
async fn create_schema(&self) -> sqlx::Result<()> {
async fn create_schema(&self) -> anyhow::Result<()> {
sqlx::query(include_str!("schemas/schema.sql"))
.execute(self).await?;
@ -72,7 +73,7 @@ impl Database for Pool<Sqlite> {
Ok(())
}
async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()> {
async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()> {
sqlx::query("INSERT INTO manifest_layers(manifest, layer_digest) VALUES (?, ?)")
.bind(manifest_digest)
.bind(layer_digest)
@ -83,7 +84,7 @@ impl Database for Pool<Sqlite> {
Ok(())
}
async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> sqlx::Result<()> {
async fn unlink_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()> {
sqlx::query("DELETE FROM manifest_layers WHERE manifest = ? AND layer_digest = ?")
.bind(manifest_digest)
.bind(layer_digest)
@ -94,7 +95,7 @@ impl Database for Pool<Sqlite> {
Ok(())
}
async fn list_repository_tags(&self, repository: &str,) -> sqlx::Result<Vec<Tag>> {
async fn list_repository_tags(&self, repository: &str,) -> anyhow::Result<Vec<Tag>> {
let rows: Vec<(String, String, i64, )> = sqlx::query_as("SELECT name, image_manifest, last_updated FROM image_tags WHERE repository = ?")
.bind(repository)
.fetch_all(self).await?;
@ -108,7 +109,7 @@ impl Database for Pool<Sqlite> {
Ok(tags)
}
async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option<String>) -> sqlx::Result<Vec<Tag>> {
async fn list_repository_tags_page(&self, repository: &str, limit: u32, last_tag: Option<String>) -> anyhow::Result<Vec<Tag>> {
// Query differently depending on if `last_tag` was specified
let rows: Vec<(String, String, i64, )> = match last_tag {
Some(last_tag) => {
@ -135,7 +136,7 @@ impl Database for Pool<Sqlite> {
Ok(tags)
}
async fn get_tag(&self, repository: &str, tag: &str) -> sqlx::Result<Option<Tag>> {
async fn get_tag(&self, repository: &str, tag: &str) -> anyhow::Result<Option<Tag>> {
debug!("get tag");
let row: (String, i64, ) = match sqlx::query_as("SELECT image_manifest, last_updated FROM image_tags WHERE name = ? AND repository = ?")
.bind(tag)
@ -147,7 +148,7 @@ impl Database for Pool<Sqlite> {
return Ok(None)
},
_ => {
return Err(e);
return Err(anyhow::Error::new(e));
}
}
};
@ -157,7 +158,7 @@ impl Database for Pool<Sqlite> {
Ok(Some(Tag::new(tag.to_string(), repository.to_string(), last_updated, row.0)))
}
async fn save_tag(&self, repository: &str, tag: &str, digest: &str) -> sqlx::Result<()> {
async fn save_tag(&self, repository: &str, tag: &str, digest: &str) -> anyhow::Result<()> {
sqlx::query("INSERT INTO image_tags (name, repository, image_manifest, last_updated) VALUES (?, ?, ?, ?)")
.bind(tag)
.bind(repository)
@ -168,7 +169,7 @@ impl Database for Pool<Sqlite> {
Ok(())
}
async fn delete_tag(&self, repository: &str, tag: &str) -> sqlx::Result<()> {
async fn delete_tag(&self, repository: &str, tag: &str) -> anyhow::Result<()> {
sqlx::query("DELETE FROM image_tags WHERE 'name' = ? AND repository = ?")
.bind(tag)
.bind(repository)
@ -177,7 +178,7 @@ impl Database for Pool<Sqlite> {
Ok(())
}
async fn get_manifest(&self, repository: &str, digest: &str) -> sqlx::Result<Option<String>> {
async fn get_manifest(&self, repository: &str, digest: &str) -> anyhow::Result<Option<String>> {
let row: (String, ) = match sqlx::query_as("SELECT content FROM image_manifests where digest = ? AND repository = ?")
.bind(digest)
.bind(repository)
@ -188,7 +189,7 @@ impl Database for Pool<Sqlite> {
return Ok(None)
},
_ => {
return Err(e);
return Err(anyhow::Error::new(e));
}
}
};
@ -196,7 +197,7 @@ impl Database for Pool<Sqlite> {
Ok(Some(row.0))
}
async fn save_manifest(&self, repository: &str, digest: &str, manifest: &str) -> sqlx::Result<()> {
async fn save_manifest(&self, repository: &str, digest: &str, manifest: &str) -> anyhow::Result<()> {
sqlx::query("INSERT INTO image_manifests (digest, repository, content) VALUES (?, ?, ?)")
.bind(digest)
.bind(repository)
@ -206,7 +207,7 @@ impl Database for Pool<Sqlite> {
Ok(())
}
async fn delete_manifest(&self, repository: &str, digest: &str) -> sqlx::Result<Vec<String>> {
async fn delete_manifest(&self, repository: &str, digest: &str) -> anyhow::Result<Vec<String>> {
sqlx::query("DELETE FROM image_manifests where digest = ? AND repository = ?")
.bind(digest)
.bind(repository)
@ -234,7 +235,7 @@ impl Database for Pool<Sqlite> {
Ok(digests)
}
async fn has_repository(&self, repository: &str) -> sqlx::Result<bool> {
async fn has_repository(&self, repository: &str) -> anyhow::Result<bool> {
let row: (u32, ) = match sqlx::query_as("SELECT COUNT(1) FROM repositories WHERE \"name\" = ?")
.bind(repository)
.fetch_one(self).await {
@ -244,7 +245,7 @@ impl Database for Pool<Sqlite> {
return Ok(false)
},
_ => {
return Err(e);
return Err(anyhow::Error::new(e));
}
}
};
@ -270,7 +271,7 @@ impl Database for Pool<Sqlite> {
Ok(Some(RepositoryVisibility::try_from(row.0)?))
}
async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option<String>) -> sqlx::Result<()> {
async fn save_repository(&self, repository: &str, visibility: RepositoryVisibility, owning_project: Option<String>) -> anyhow::Result<()> {
// ensure that the repository was not already created
if self.has_repository(repository).await? {
debug!("repo exists");
@ -297,8 +298,8 @@ impl Database for Pool<Sqlite> {
Ok(())
}
//async fn list_repositories(&self) -> sqlx::Result<Vec<String>> {
async fn list_repositories(&self, limit: Option<u32>, last_repo: Option<String>) -> sqlx::Result<Vec<String>> {
//async fn list_repositories(&self) -> anyhow::Result<Vec<String>> {
async fn list_repositories(&self, limit: Option<u32>, last_repo: Option<String>) -> anyhow::Result<Vec<String>> {
let limit = limit.unwrap_or(1000); // set default limit
// Query differently depending on if `last_repo` was specified
@ -322,7 +323,7 @@ impl Database for Pool<Sqlite> {
Ok(repos)
}
async fn does_user_exist(&self, email: String) -> sqlx::Result<bool> {
async fn does_user_exist(&self, email: String) -> anyhow::Result<bool> {
let row: (u32, ) = match sqlx::query_as("SELECT COUNT(1) FROM users WHERE \"email\" = ?")
.bind(email)
.fetch_one(self).await {
@ -332,7 +333,7 @@ impl Database for Pool<Sqlite> {
return Ok(false)
},
_ => {
return Err(e);
return Err(anyhow::Error::new(e));
}
}
};
@ -340,7 +341,7 @@ impl Database for Pool<Sqlite> {
Ok(row.0 > 0)
}
async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> sqlx::Result<User> {
async fn create_user(&self, email: String, username: String, login_source: LoginSource) -> anyhow::Result<User> {
let username = username.to_lowercase();
let email = email.to_lowercase();
sqlx::query("INSERT INTO users (username, email, login_source) VALUES (?, ?, ?)")
@ -352,7 +353,7 @@ impl Database for Pool<Sqlite> {
Ok(User::new(username, email, login_source))
}
async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> sqlx::Result<()> {
async fn add_user_auth(&self, email: String, password_hash: String, password_salt: String) -> anyhow::Result<()> {
let email = email.to_lowercase();
sqlx::query("INSERT INTO user_logins (email, password_hash, password_salt) VALUES (?, ?, ?)")
.bind(email.clone())
@ -363,6 +364,16 @@ impl Database for Pool<Sqlite> {
Ok(())
}
async fn set_user_registry_type(&self, email: String, user_type: RegistryUserType) -> anyhow::Result<()> {
let email = email.to_lowercase();
sqlx::query("INSERT INTO user_registry_permissions (email, user_type) VALUES (?, ?)")
.bind(email.clone())
.bind(user_type as u32)
.execute(self).await?;
Ok(())
}
async fn verify_user_login(&self, email: String, password: String) -> anyhow::Result<bool> {
let email = email.to_lowercase();
let row: (String, ) = sqlx::query_as("SELECT password_hash FROM users WHERE email = ?")
@ -410,10 +421,17 @@ impl Database for Pool<Sqlite> {
}
};
let vis = self.get_repository_visibility(&repository).await?.unwrap();
let vis = match self.get_repository_visibility(&repository).await? {
Some(v) => v,
None => return Ok(None),
};
// Also get the user type for the registry, if its admin return admin repository permissions
let utype = self.get_user_registry_usertype(email).await?.unwrap(); // unwrap should be safe
let utype = match self.get_user_registry_usertype(email).await? {
Some(t) => t,
None => return Ok(None),
};
if utype == RegistryUserType::Admin {
Ok(Some(RepositoryPermissions::new(Permission::ADMIN.bits(), vis)))
} else {
@ -479,11 +497,15 @@ impl Database for Pool<Sqlite> {
.bind(email.clone())
.fetch_one(self).await?; */
let (expiry, created_at) = (Utc.timestamp_millis_opt(expiry).unwrap(), Utc.timestamp_millis_opt(created_at).unwrap());
let (expiry, created_at) = (Utc.timestamp_millis_opt(expiry).single(), Utc.timestamp_millis_opt(created_at).single());
if let (Some(expiry), Some(created_at)) = (expiry, created_at) {
let user = User::new(email, user_row.0, LoginSource::try_from(user_row.1)?);
let token = TokenInfo::new(token, expiry, created_at);
let auth = UserAuth::new(user, token);
Ok(Some(auth))
} else {
Ok(None)
}
}
}

26
src/error.rs Normal file
View File

@ -0,0 +1,26 @@
use axum::{response::{IntoResponse, Response}, http::StatusCode};
// Make our own error that wraps `anyhow::Error`.
pub struct AppError(anyhow::Error);
// Tell axum how to convert `AppError` into a response.
impl IntoResponse for AppError {
fn into_response(self) -> Response {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Something went wrong: {}", self.0),
)
.into_response()
}
}
// This enables using `?` on functions that return `Result<_, anyhow::Error>` to turn them into
// `Result<_, AppError>`. That way you don't need to do that manually.
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}

View File

@ -5,20 +5,19 @@ mod dto;
mod storage;
mod byte_stream;
mod config;
mod query;
mod auth;
mod error;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
use auth::{AuthDriver, ldap_driver::LdapAuthDriver};
use axum::http::{Request, StatusCode, header, HeaderName};
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{Response, IntoResponse};
use axum::response::Response;
use axum::{Router, routing};
use axum::ServiceExt;
use bcrypt::Version;
use tower_layer::Layer;
use sqlx::sqlite::SqlitePoolOptions;
@ -29,21 +28,21 @@ use tracing::{debug, Level};
use app_state::AppState;
use database::Database;
use crate::dto::user::Permission;
use crate::storage::StorageDriver;
use crate::storage::filesystem::FilesystemDriver;
use crate::config::{Config, LdapConnectionConfig};
use crate::config::Config;
use tower_http::trace::TraceLayer;
/// Encode the 'name' path parameter in the url
async fn change_request_paths<B>(mut request: Request<B>, next: Next<B>) -> Response {
async fn change_request_paths<B>(mut request: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
// Attempt to find the name using regex in the url
let regex = regex::Regex::new(r"/v2/([\w/]+)/(blobs|tags|manifests)").unwrap();
let regex = regex::Regex::new(r"/v2/([\w/]+)/(blobs|tags|manifests)")
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let captures = match regex.captures(request.uri().path()) {
Some(captures) => captures,
None => return next.run(request).await,
None => return Ok(next.run(request).await),
};
// Find the name in the request and encode it in the url
@ -54,24 +53,25 @@ async fn change_request_paths<B>(mut request: Request<B>, next: Next<B>) -> Resp
let uri_str = request.uri().to_string().replace(&name, &encoded_name);
debug!("Rewrote request url to: '{}'", uri_str);
*request.uri_mut() = uri_str.parse().unwrap();
*request.uri_mut() = uri_str.parse()
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
next.run(request).await
Ok(next.run(request).await)
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_max_level(Level::DEBUG)
.init();
let config = Config::new().expect("Failure to parse config!");
let config = Config::new()
.expect("Failure to parse config!");
let pool = SqlitePoolOptions::new()
.max_connections(15)
.connect("test.db").await.unwrap();
pool.create_schema().await.unwrap();
.connect("test.db").await?;
pool.create_schema().await?;
let storage_driver: Mutex<Box<dyn StorageDriver>> = Mutex::new(Box::new(FilesystemDriver::new("registry/blobs")));
@ -79,7 +79,7 @@ async fn main() -> std::io::Result<()> {
// the fallback is a database auth driver.
let auth_driver: Mutex<Box<dyn AuthDriver>> = match config.ldap.clone() {
Some(ldap) => {
let ldap_driver = LdapAuthDriver::new(ldap, pool.clone()).await.unwrap();
let ldap_driver = LdapAuthDriver::new(ldap, pool.clone()).await?;
Mutex::new(Box::new(ldap_driver))
},
None => {
@ -87,7 +87,7 @@ async fn main() -> std::io::Result<()> {
}
};
let app_addr = SocketAddr::from_str(&format!("{}:{}", config.listen_address, config.listen_port)).unwrap();
let app_addr = SocketAddr::from_str(&format!("{}:{}", config.listen_address, config.listen_port))?;
let state = Arc::new(AppState::new(pool, storage_driver, config, auth_driver));
@ -129,8 +129,7 @@ async fn main() -> std::io::Result<()> {
debug!("Starting http server, listening on {}", app_addr);
axum::Server::bind(&app_addr)
.serve(layered_app.into_make_service())
.await
.unwrap();
.await?;
Ok(())
}

View File

@ -1,6 +1,6 @@
use std::{path::Path, io::ErrorKind};
use anyhow::Context;
use anyhow::{Context, anyhow};
use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt;
@ -53,7 +53,7 @@ impl StorageDriver for FilesystemDriver {
len += bytes.len();
file.write_all(&bytes).await.unwrap();
file.write_all(&bytes).await?;
}
Ok(len)
@ -140,9 +140,12 @@ impl StorageDriver for FilesystemDriver {
async fn replace_digest(&self, uuid: &str, digest: &str) -> anyhow::Result<()> {
let path = self.get_digest_path(uuid);
let path = Path::new(&path);
let parent = path.clone().parent().unwrap();
let parent = path
.clone()
.parent()
.ok_or(anyhow!("Failure to get parent path of digest file!"))?;
fs::rename(path, format!("{}/{}", parent.as_os_str().to_str().unwrap(), digest)).await?;
fs::rename(path, format!("{}/{}", parent.display(), digest)).await?;
Ok(())
}