Compare commits

...

12 Commits

21 changed files with 814 additions and 367 deletions


@@ -1,6 +1,5 @@
 /target
 .env
 .vscode
-test.db
 /registry
 config.toml

.gitignore (vendored): 1 line changed

@@ -3,3 +3,4 @@
 .vscode
 /registry
 config.toml
+Dockerfile

Cargo.lock (generated): 87 lines changed

@ -321,7 +321,8 @@ dependencies = [
"js-sys", "js-sys",
"num-integer", "num-integer",
"num-traits", "num-traits",
"time", "serde",
"time 0.1.45",
"wasm-bindgen", "wasm-bindgen",
"winapi", "winapi",
] ]
@ -423,6 +424,16 @@ version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cace84e55f07e7301bae1c519df89cdad8cc3cd868413d3fdbdeca9ff3db484" checksum = "9cace84e55f07e7301bae1c519df89cdad8cc3cd868413d3fdbdeca9ff3db484"
[[package]]
name = "crossbeam-channel"
version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]] [[package]]
name = "crossbeam-queue" name = "crossbeam-queue"
version = "0.3.8" version = "0.3.8"
@ -1132,6 +1143,15 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata",
]
[[package]] [[package]]
name = "matchit" name = "matchit"
version = "0.7.0" version = "0.7.0"
@ -1331,6 +1351,7 @@ dependencies = [
"tower-http", "tower-http",
"tower-layer", "tower-layer",
"tracing", "tracing",
"tracing-appender",
"tracing-log", "tracing-log",
"tracing-subscriber", "tracing-subscriber",
"uuid", "uuid",
@ -1573,6 +1594,15 @@ dependencies = [
"regex-syntax", "regex-syntax",
] ]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax",
]
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.6.28" version = "0.6.28"
@ -2093,6 +2123,33 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "time"
version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
dependencies = [
"itoa",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
[[package]]
name = "time-macros"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
dependencies = [
"time-core",
]
[[package]] [[package]]
name = "tinyvec" name = "tinyvec"
version = "1.6.0" version = "1.6.0"
@ -2262,6 +2319,17 @@ dependencies = [
"tracing-core", "tracing-core",
] ]
[[package]]
name = "tracing-appender"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e"
dependencies = [
"crossbeam-channel",
"time 0.3.23",
"tracing-subscriber",
]
[[package]] [[package]]
name = "tracing-attributes" name = "tracing-attributes"
version = "0.1.23" version = "0.1.23"
@ -2294,18 +2362,35 @@ dependencies = [
"tracing-core", "tracing-core",
] ]
[[package]]
name = "tracing-serde"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1"
dependencies = [
"serde",
"tracing-core",
]
[[package]] [[package]]
name = "tracing-subscriber" name = "tracing-subscriber"
version = "0.3.16" version = "0.3.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70" checksum = "a6176eae26dd70d0c919749377897b54a9276bd7061339665dd68777926b5a70"
dependencies = [ dependencies = [
"matchers",
"nu-ansi-term", "nu-ansi-term",
"once_cell",
"regex",
"serde",
"serde_json",
"sharded-slab", "sharded-slab",
"smallvec", "smallvec",
"thread_local", "thread_local",
"tracing",
"tracing-core", "tracing-core",
"tracing-log", "tracing-log",
"tracing-serde",
] ]
[[package]] [[package]]


@@ -7,14 +7,15 @@ edition = "2021"
 [dependencies]
 tracing = "0.1.37"
-tracing-subscriber = { version = "0.3.16", features = [ "tracing-log" ] }
+tracing-subscriber = { version = "0.3.16", features = [ "tracing-log", "json", "env-filter" ] }
 tracing-log = "0.1.3"
+tracing-appender = "0.2.2"
 uuid = { version = "1.3.1", features = [ "v4", "fast-rng" ] }
 sqlx = { version = "0.6.3", features = [ "runtime-tokio-rustls", "sqlite" ] }
 bytes = "1.4.0"
-chrono = "0.4.23"
+chrono = { version = "0.4.23", features = [ "serde" ] }
 tokio = { version = "1.21.2", features = [ "fs", "macros" ] }
 tokio-util = { version = "0.7.7", features = [ "io" ] }
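The new logging dependencies above (`tracing-appender`, plus the `json` and `env-filter` features of `tracing-subscriber`) are usually wired together along these lines. This is a minimal sketch, not code from this repository; the log directory, file name, and fallback filter string are illustrative placeholders.

```rust
use tracing_appender::rolling;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

fn main() {
    // Write JSON logs to a daily-rolled file; "logs"/"orca.log" are placeholder names.
    let file_appender = rolling::daily("logs", "orca.log");
    let (writer, _guard) = tracing_appender::non_blocking(file_appender);
    // _guard must stay alive for the whole program so buffered lines are flushed.

    tracing_subscriber::registry()
        // Honour RUST_LOG if set, otherwise fall back to "info".
        .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")))
        .with(fmt::layer().json().with_writer(writer))
        .init();

    tracing::info!("logging initialized");
}
```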


@@ -1,5 +1,7 @@
 FROM rust:alpine3.17 as builder
+ARG RELEASE_BUILD=true
 # update packages
 RUN apk update
 RUN apk add build-base openssl-dev ca-certificates
@@ -17,13 +19,13 @@ WORKDIR /app/src
 # Build dependencies only. Separate these for caches
 RUN cargo install cargo-build-deps
-RUN cargo build-deps --release
+RUN sh -c "cargo build-deps --release"
 # Build the release executable.
-RUN cargo build --release
+RUN sh -c "cargo build --release"
 # Runner stage. I tried using distroless (gcr.io/distroless/static-debian11), but the image was only ~3MBs smaller than
-# alpine. I chose to use alpine since a user can easily be added to the image.
+# alpine. I chose to use alpine since it makes it easier to exec into the container to debug things.
 FROM alpine:3.17
 ARG UNAME=orca-registry
@@ -34,6 +36,7 @@ ARG GID=1000
 RUN adduser --disabled-password --gecos "" $UNAME -s -G $GID -u $UID
 COPY --from=builder --chown=$UID:$GID /app/src/target/release/orca-registry /app/orca-registry
+# Chown everything
 RUN mkdir /data && \
     chown -R $UID:$GID /data && \
     chown -R $UID:$GID /app

Dockerfile.debug (new file): 50 lines added

@ -0,0 +1,50 @@
FROM rust:alpine3.17 as builder
ARG RELEASE_BUILD=true
# update packages
RUN apk update
RUN apk add build-base openssl-dev ca-certificates
# create root application folder
WORKDIR /app
COPY ./ /app/src
# Install rust toolchains
RUN rustup toolchain install stable
RUN rustup default stable
WORKDIR /app/src
# Build dependencies only. Separate these for caches
RUN cargo install cargo-build-deps
RUN sh -c "cargo build-deps"
# Build the release executable.
RUN sh -c "cargo build"
# Runner stage. I tried using distroless (gcr.io/distroless/static-debian11), but the image was only ~3MBs smaller than
# alpine. I chose to use alpine since it makes it easier to exec into the container to debug things.
FROM alpine:3.17
ARG UNAME=orca-registry
ARG UID=1000
ARG GID=1000
# Add user and copy the executable from the build stage.
RUN adduser --disabled-password --gecos "" $UNAME -s -G $GID -u $UID
COPY --from=builder --chown=$UID:$GID /app/src/target/debug/orca-registry /app/orca-registry
# Chown everything
RUN mkdir /data && \
chown -R $UID:$GID /data && \
chown -R $UID:$GID /app
USER $UNAME
WORKDIR /app/
EXPOSE 3000
ENTRYPOINT [ "/app/orca-registry" ]


@@ -20,7 +20,7 @@ These instructions are assuming the user is stored in the database, if you use L
 2. Create a bcrypt password hash for the new user:
 ```shell
-$ htpasswd -nB
+$ htpasswd -nB <username>
 ```
 3. Insert the new user's email, password hash into the `user_logins` table. The salt is not used, so you can put whatever there


@ -1,8 +1,17 @@
use std::{sync::Arc, collections::{HashMap, BTreeMap}, time::SystemTime}; use std::{
collections::HashMap,
sync::Arc,
time::SystemTime,
};
use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, header}, Form}; use axum::{
extract::{Query, State},
http::{header, StatusCode},
response::{IntoResponse, Response},
Form,
};
use axum_auth::AuthBasic; use axum_auth::AuthBasic;
use chrono::{DateTime, Utc, Duration}; use chrono::{DateTime, Days, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tracing::{debug, error, info, span, Level}; use tracing::{debug, error, info, span, Level};
@ -12,12 +21,19 @@ use sha2::Sha256;
use rand::Rng; use rand::Rng;
use crate::{dto::{scope::Scope, user::TokenInfo}, app_state::AppState}; use crate::{database::Database, dto::scope::Action};
use crate::database::Database; use crate::{
app_state::AppState,
dto::{
scope::{Scope, ScopeType},
user::{AuthToken, TokenInfo},
RepositoryVisibility,
},
};
use crate::auth::auth_challenge_response; use crate::auth::auth_challenge_response;
#[derive(Deserialize, Debug)] #[derive(Debug)]
pub struct TokenAuthRequest { pub struct TokenAuthRequest {
user: Option<String>, user: Option<String>,
password: Option<String>, password: Option<String>,
@ -39,43 +55,59 @@ pub struct AuthForm {
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug)]
pub struct AuthResponse { pub struct AuthResponse {
token: String, token: String,
access_token: Option<String>,
expires_in: u32, expires_in: u32,
issued_at: String, issued_at: String,
} }
/// In the returned UserToken::user, only the username is specified fn create_jwt_token(jwt_key: String, account: Option<&str>, scopes: Vec<Scope>) -> anyhow::Result<TokenInfo> {
fn create_jwt_token(account: &str) -> anyhow::Result<TokenInfo> { let key: Hmac<Sha256> = Hmac::new_from_slice(jwt_key.as_bytes())?;
let key: Hmac<Sha256> = Hmac::new_from_slice(b"some-secret")?;
let now = chrono::offset::Utc::now(); let now = chrono::offset::Utc::now();
let now_secs = now.timestamp();
// Construct the claims for the token // Expire the token in a day
let mut claims = BTreeMap::new(); let expiration = now.checked_add_days(Days::new(1)).unwrap();
claims.insert("issuer", "orca-registry__DEV");
claims.insert("subject", &account);
//claims.insert("audience", auth.service);
let not_before = format!("{}", now_secs);
let issued_at = format!("{}", now_secs);
let expiration = format!("{}", now_secs + 86400); // 1 day
claims.insert("notbefore", &not_before);
claims.insert("issuedat", &issued_at);
claims.insert("expiration", &expiration); // TODO: 20 seconds expiry for testing
let issued_at = now;
let expiration = now + Duration::seconds(20);
// Create a randomized jwtid
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let jwtid = format!("{}", rng.gen::<u64>()); let jwtid = format!("{}", rng.gen::<u64>());
claims.insert("jwtid", &jwtid);
let token_str = claims.sign_with_key(&key)?; // empty account if they are not authenticated
Ok(TokenInfo::new(token_str, expiration, issued_at)) let account = account.map(|a| a.to_string()).unwrap_or(String::new());
// Construct the claims for the token
// TODO: Verify the token!
let token = AuthToken::new(
String::from("orca-registry__DEV"),
account,
String::from("reg"),
expiration,
now.clone(),
now.clone(),
jwtid,
scopes,
);
let token_str = token.sign_with_key(&key)?;
Ok(TokenInfo::new(token_str, expiration, now))
} }
pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppState>>, Query(params): Query<HashMap<String, String>>, form: Option<Form<AuthForm>>) -> Result<Response, StatusCode> { pub async fn auth_basic_post() -> Result<Response, StatusCode> {
return Ok((
StatusCode::METHOD_NOT_ALLOWED,
[
(header::CONTENT_TYPE, "application/json"),
(header::ALLOW, "Allow: GET, HEAD, OPTIONS"),
],
"{\"detail\": \"Method \\\"POST\\\" not allowed.\"}"
).into_response());
}
pub async fn auth_basic_get(
basic_auth: Option<AuthBasic>,
state: State<Arc<AppState>>,
Query(params): Query<HashMap<String, String>>,
form: Option<Form<AuthForm>>,
) -> Result<Response, StatusCode> {
let mut auth = TokenAuthRequest { let mut auth = TokenAuthRequest {
user: None, user: None,
password: None, password: None,
@ -88,6 +120,19 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
let auth_method; let auth_method;
// Process all the scopes
if let Some(scope) = params.get("scope") {
// TODO: Handle multiple scopes
match Scope::try_from(&scope[..]) {
Ok(scope) => {
auth.scope.push(scope);
}
Err(_) => {
return Err(StatusCode::BAD_REQUEST);
}
}
}
// If BasicAuth is provided, set the fields to it // If BasicAuth is provided, set the fields to it
if let Some(AuthBasic((username, pass))) = basic_auth { if let Some(AuthBasic((username, pass))) = basic_auth {
auth.user = Some(username.clone()); auth.user = Some(username.clone());
@ -113,15 +158,90 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
debug!("Read user authentication from a Form"); debug!("Read user authentication from a Form");
auth_method = "form"; auth_method = "form";
} else {
// If no auth parameters were specified, check if the repository is public. if it is, respond with a token.
let is_public_access = {
let mut res = vec![];
for scope in auth.scope.iter() {
match scope.scope_type {
ScopeType::Repository => {
// check repository visibility
let database = &state.database;
match database.get_repository_visibility(&scope.name).await {
Ok(Some(RepositoryVisibility::Public)) => res.push(Ok(true)),
Ok(_) => res.push(Ok(false)),
Err(e) => {
error!(
"Failure to check repository visibility for {}! Err: {}",
scope.name, e
);
res.push(Err(StatusCode::INTERNAL_SERVER_ERROR));
}
}
}
_ => res.push(Ok(false)),
}
}
// merge the booleans into a single bool, respond with errors if there are any.
let res: Result<Vec<bool>, StatusCode> = res.into_iter().collect();
res?.iter().all(|b| *b)
};
if is_public_access {
for scope in auth.scope.iter_mut() {
// only retain Action::Pull
scope.actions.retain(|a| *a == Action::Pull);
}
let token = create_jwt_token(state.config.jwt_key.clone(), None, auth.scope)
.map_err(|_| {
error!("Failed to create jwt token!");
StatusCode::INTERNAL_SERVER_ERROR
})?;
let token_str = token.token;
let now_format = format!("{}", token.created_at.format("%+"));
let auth_response = AuthResponse {
token: token_str.clone(),
access_token: Some(token_str.clone()),
expires_in: 86400, // 1 day
issued_at: now_format,
};
let json_str = serde_json::to_string(&auth_response)
.map_err(|_| StatusCode::BAD_REQUEST)?;
debug!("Created anonymous token for public scopes!");
return Ok((
StatusCode::OK,
[
(header::CONTENT_TYPE, "application/json"),
(header::AUTHORIZATION, &format!("Bearer {}", token_str)),
],
json_str,
).into_response());
} else { } else {
info!("Auth failure! Auth was not provided in either AuthBasic or Form!"); info!("Auth failure! Auth was not provided in either AuthBasic or Form!");
// Maybe BAD_REQUEST should be returned? // Maybe BAD_REQUEST should be returned?
return Err(StatusCode::UNAUTHORIZED); return Err(StatusCode::UNAUTHORIZED);
} }
}
// Create logging span for the rest of this request // Create logging span for the rest of this request
let span = span!(Level::DEBUG, "auth", username = auth.user.clone(), auth_method); let span = span!(
Level::DEBUG,
"auth",
username = auth.user.clone(),
auth_method
);
let _enter = span.enter(); let _enter = span.enter();
debug!("Parsed user auth request"); debug!("Parsed user auth request");
@ -131,32 +251,27 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
if let Some(account) = params.get("account") { if let Some(account) = params.get("account") {
if let Some(user) = &auth.user { if let Some(user) = &auth.user {
if account != user { if account != user {
error!("`user` and `account` are not the same!!! (user: {}, account: {})", user, account); error!(
"`user` and `account` are not the same!!! (user: {}, account: {})",
user, account
);
return Err(StatusCode::BAD_REQUEST); return Err(StatusCode::BAD_REQUEST);
} }
} else {
auth.user = Some(account.clone());
} }
auth.account = Some(account.clone()); auth.account = Some(account.clone());
} else {
debug!("Account was not provided through params");
} }
// Get service from query string // Get service from query string
if let Some(service) = params.get("service") { if let Some(service) = params.get("service") {
auth.service = Some(service.clone()); auth.service = Some(service.clone());
} } else {
debug!("Service was not provided through params");
// Process all the scopes
if let Some(scope) = params.get("scope") {
// TODO: Handle multiple scopes
match Scope::try_from(&scope[..]) {
Ok(scope) => {
auth.scope.push(scope);
},
Err(_) => {
return Err(StatusCode::BAD_REQUEST);
}
}
} }
// Get offline token and attempt to convert it to a boolean // Get offline token and attempt to convert it to a boolean
@ -164,6 +279,8 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
if let Ok(b) = offline_token.parse::<bool>() { if let Ok(b) = offline_token.parse::<bool>() {
auth.offline_token = Some(b); auth.offline_token = Some(b);
} }
} else {
debug!("Offline Token was not provided through params");
} }
if let Some(client_id) = params.get("client_id") { if let Some(client_id) = params.get("client_id") {
@ -172,22 +289,41 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
debug!("Constructed auth request"); debug!("Constructed auth request");
if let (Some(account), Some(password)) = (&auth.account, auth.password) { if auth.user.is_none() {
debug!("User is none");
}
if auth.password.is_none() {
debug!("Password is none");
}
if let (Some(account), Some(password)) = (auth.user, auth.password) {
// Ensure that the password is correct // Ensure that the password is correct
let mut auth_driver = state.auth_checker.lock().await; let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.verify_user_login(account.clone(), password).await if !auth_driver
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? { .verify_user_login(account.clone(), password)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
{
debug!("Authentication failed, incorrect password!"); debug!("Authentication failed, incorrect password!");
// TODO: Multiple scopes
let scope = auth.scope
.first()
.and_then(|s| Some(s.clone()));
// TODO: Dont unwrap, find a way to return multiple scopes // TODO: Dont unwrap, find a way to return multiple scopes
return Ok(auth_challenge_response(&state.config, Some(auth.scope.first().unwrap().clone()))); return Ok(auth_challenge_response(
&state.config,
scope,
));
} }
drop(auth_driver); drop(auth_driver);
debug!("User password is correct"); debug!("User password is correct");
let now = SystemTime::now(); let now = SystemTime::now();
let token = create_jwt_token(account) let token = create_jwt_token(state.config.jwt_key.clone(), Some(&account), vec![])
.map_err(|_| { .map_err(|_| {
error!("Failed to create jwt token!"); error!("Failed to create jwt token!");
@ -204,33 +340,41 @@ pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppS
// Construct the auth response // Construct the auth response
let auth_response = AuthResponse { let auth_response = AuthResponse {
token: token_str.clone(), token: token_str.clone(),
expires_in: 20, access_token: Some(token_str.clone()),
expires_in: 86400, // 1 day
issued_at: now_format, issued_at: now_format,
}; };
let json_str = serde_json::to_string(&auth_response) let json_str =
.map_err(|_| StatusCode::BAD_REQUEST)?; serde_json::to_string(&auth_response).map_err(|_| StatusCode::BAD_REQUEST)?;
let database = &state.database; let database = &state.database;
database.store_user_token(token_str.clone(), account.clone(), token.expiry, token.created_at).await database
.store_user_token(
token_str.clone(),
account.clone(),
token.expiry,
token.created_at,
)
.await
.map_err(|_| { .map_err(|_| {
error!("Failed to store user token in database!"); error!("Failed to store user token in database!");
StatusCode::INTERNAL_SERVER_ERROR StatusCode::INTERNAL_SERVER_ERROR
})?; })?;
drop(database);
return Ok(( return Ok((
StatusCode::OK, StatusCode::OK,
[ [
( header::CONTENT_TYPE, "application/json" ), (header::CONTENT_TYPE, "application/json"),
( header::AUTHORIZATION, &format!("Bearer {}", token_str) ) (header::AUTHORIZATION, &format!("Bearer {}", token_str)),
], ],
json_str json_str,
).into_response()); )
.into_response());
} }
info!("Auth failure! Not enough information given to create auth token!"); info!("Auth failure! Not enough information given to create auth token!");
// If we didn't get fields required to make a token, then the client did something bad // If we didn't get fields required to make a token, then the client did something bad
Err(StatusCode::UNAUTHORIZED) Err(StatusCode::BAD_REQUEST)
} }
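A side note on the `is_public_access` block above: collecting a `Vec<Result<bool, _>>` into a `Result<Vec<bool>, _>` short-circuits on the first error and otherwise keeps every boolean, which is what lets the handler bubble up `INTERNAL_SERVER_ERROR` before folding the visibility flags. A standalone illustration:

```rust
fn main() {
    // All checks succeeded: the collect keeps the booleans and `all` folds them.
    let checks: Vec<Result<bool, &str>> = vec![Ok(true), Ok(true)];
    let collected: Result<Vec<bool>, &str> = checks.into_iter().collect();
    assert_eq!(collected.map(|v| v.iter().all(|b| *b)), Ok(true));

    // Any error wins over the booleans and is returned as-is.
    let checks: Vec<Result<bool, &str>> = vec![Ok(true), Err("db failure")];
    let collected: Result<Vec<bool>, &str> = checks.into_iter().collect();
    assert_eq!(collected, Err("db failure"));
}
```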


@ -2,9 +2,10 @@ use std::sync::Arc;
use axum::body::StreamBody; use axum::body::StreamBody;
use axum::extract::{State, Path}; use axum::extract::{State, Path};
use axum::http::{StatusCode, header, HeaderName}; use axum::http::{StatusCode, header, HeaderName, HeaderMap};
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
use tokio_util::io::ReaderStream; use tokio_util::io::ReaderStream;
use tracing::debug;
use crate::app_state::AppState; use crate::app_state::AppState;
use crate::error::AppError; use crate::error::AppError;
@ -18,6 +19,7 @@ pub async fn digest_exists_head(Path((_name, layer_digest)): Path<(String, Strin
StatusCode::OK, StatusCode::OK,
[ [
(header::CONTENT_LENGTH, size.to_string()), (header::CONTENT_LENGTH, size.to_string()),
(header::ACCEPT_RANGES, "true".to_string()),
(HeaderName::from_static("docker-content-digest"), layer_digest) (HeaderName::from_static("docker-content-digest"), layer_digest)
] ]
).into_response()); ).into_response());
@ -27,22 +29,52 @@ pub async fn digest_exists_head(Path((_name, layer_digest)): Path<(String, Strin
Ok(StatusCode::NOT_FOUND.into_response()) Ok(StatusCode::NOT_FOUND.into_response())
} }
pub async fn pull_digest_get(Path((_name, layer_digest)): Path<(String, String)>, state: State<Arc<AppState>>) -> Result<Response, AppError> { pub async fn pull_digest_get(Path((_name, layer_digest)): Path<(String, String)>, header_map: HeaderMap, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let storage = state.storage.lock().await; let storage = state.storage.lock().await;
if let Some(len) = storage.digest_length(&layer_digest).await? { if let Some(len) = storage.digest_length(&layer_digest).await? {
let stream = match storage.get_digest_stream(&layer_digest).await? { let mut stream = match storage.get_digest_stream(&layer_digest).await? {
Some(s) => s, Some(s) => s,
None => { None => {
return Ok(StatusCode::NOT_FOUND.into_response()); return Ok(StatusCode::NOT_FOUND.into_response());
} }
}; };
if let Some(range) = header_map.get(header::CONTENT_RANGE) {
let range = range.to_str().unwrap();
debug!("Range request received: {}", range);
let range = &range[6..];
let (starting, ending) = range.split_once("-").unwrap();
let (starting, ending) = (starting.parse::<i32>().unwrap(), ending.parse::<i32>().unwrap());
// recreate the ByteStream, skipping elements
stream = stream.skip_recreate(starting as usize);
// convert the `AsyncRead` into a `Stream` // convert the `AsyncRead` into a `Stream`
let stream = ReaderStream::new(stream.into_async_read()); let stream = ReaderStream::new(stream.into_async_read());
// convert the `Stream` into an `axum::body::HttpBody` // convert the `Stream` into an `axum::body::HttpBody`
let body = StreamBody::new(stream); let body = StreamBody::new(stream);
debug!("length of range request: {}", starting - ending);
Ok((
StatusCode::OK,
[
(header::CONTENT_LENGTH, (starting - ending).to_string()),
(header::RANGE, format!("bytes {}-{}/{}", starting, ending, len)),
(HeaderName::from_static("docker-content-digest"), layer_digest)
],
body
).into_response())
} else {
// convert the `AsyncRead` into a `Stream`
let stream = ReaderStream::new(stream.into_async_read());
// convert the `Stream` into an `axum::body::HttpBody`
let body = StreamBody::new(stream);
debug!("length of streamed request: {}", len);
Ok(( Ok((
StatusCode::OK, StatusCode::OK,
[ [
@ -51,6 +83,9 @@ pub async fn pull_digest_get(Path((_name, layer_digest)): Path<(String, String)>
], ],
body body
).into_response()) ).into_response())
}
} else { } else {
Ok(StatusCode::NOT_FOUND.into_response()) Ok(StatusCode::NOT_FOUND.into_response())
} }
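The new range branch above slices the header value by hand. As a general reference, a closed byte range such as `bytes=0-1023` can be parsed along these lines; the helper name is made up for illustration, and the sketch deliberately ignores open-ended (`bytes=100-`) and multi-range forms.

```rust
/// Parse a single closed HTTP byte-range value such as "bytes=0-1023".
/// Returns the inclusive (start, end) pair, so the length is end - start + 1.
fn parse_byte_range(value: &str) -> Option<(u64, u64)> {
    let spec = value.strip_prefix("bytes=")?;
    let (start, end) = spec.split_once('-')?;
    let start = start.parse::<u64>().ok()?;
    let end = end.parse::<u64>().ok()?;
    (start <= end).then_some((start, end))
}

fn main() {
    assert_eq!(parse_byte_range("bytes=0-1023"), Some((0, 1023)));
    assert_eq!(parse_byte_range("bytes=100-"), None); // open-ended form not handled here
}
```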


@ -6,13 +6,12 @@ use axum::http::{StatusCode, HeaderName, header};
use tracing::log::warn; use tracing::log::warn;
use tracing::{debug, info}; use tracing::{debug, info};
use crate::auth::access_denied_response;
use crate::app_state::AppState; use crate::app_state::AppState;
use crate::database::Database; use crate::database::Database;
use crate::dto::RepositoryVisibility; use crate::dto::RepositoryVisibility;
use crate::dto::digest::Digest; use crate::dto::digest::Digest;
use crate::dto::manifest::Manifest; use crate::dto::manifest::Manifest;
use crate::dto::user::{UserAuth, Permission}; use crate::dto::user::UserAuth;
use crate::error::AppError; use crate::error::AppError;
pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, auth: UserAuth, body: String) -> Result<Response, AppError> { pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, auth: UserAuth, body: String) -> Result<Response, AppError> {
@ -20,10 +19,13 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>
let calculated_hash = sha256::digest(body.clone()); let calculated_hash = sha256::digest(body.clone());
let calculated_digest = format!("sha256:{}", calculated_hash); let calculated_digest = format!("sha256:{}", calculated_hash);
// anonymous users wouldn't be able to get to this point, so it should be safe to unwrap.
let user = auth.user.unwrap();
let database = &state.database; let database = &state.database;
// Create the image repository and save the image manifest. This repository will be private by default // Create the image repository and save the image manifest. This repository will be private by default
database.save_repository(&name, RepositoryVisibility::Private, Some(auth.user.email), None).await?; database.save_repository(&name, RepositoryVisibility::Private, Some(user.email), None).await?;
database.save_manifest(&name, &calculated_digest, &body).await?; database.save_manifest(&name, &calculated_digest, &body).await?;
// If the reference is not a digest, then it must be a tag name. // If the reference is not a digest, then it must be a tag name.
@ -57,20 +59,7 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>
} }
} }
pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, auth: Option<UserAuth>) -> Result<Response, AppError> { pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public
if let Some(auth) = auth {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return Ok(access_denied_response(&state.config));
}
} else {
let database = &state.database;
if database.get_repository_visibility(&name).await? != Some(RepositoryVisibility::Public) {
return Ok(access_denied_response(&state.config));
}
}
let database = &state.database; let database = &state.database;
let digest = match Digest::is_digest(&reference) { let digest = match Digest::is_digest(&reference) {
true => reference.clone(), true => reference.clone(),
@ -93,6 +82,8 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>,
} }
let manifest_content = manifest_content.unwrap(); let manifest_content = manifest_content.unwrap();
debug!("Pulled manifest: {}", manifest_content);
Ok(( Ok((
StatusCode::OK, StatusCode::OK,
[ [
@ -106,21 +97,8 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>,
).into_response()) ).into_response())
} }
pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, auth: Option<UserAuth>) -> Result<Response, AppError> { pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public debug!("start of head");
if let Some(auth) = auth {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return Ok(access_denied_response(&state.config));
}
drop(auth_driver);
} else {
let database = &state.database;
if database.get_repository_visibility(&name).await? != Some(RepositoryVisibility::Public) {
return Ok(access_denied_response(&state.config));
}
}
// Get the digest from the reference path. // Get the digest from the reference path.
let database = &state.database; let database = &state.database;
let digest = match Digest::is_digest(&reference) { let digest = match Digest::is_digest(&reference) {
@ -133,6 +111,7 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
} }
} }
}; };
debug!("found digest: {}", digest);
let manifest_content = database.get_manifest(&name, &digest).await?; let manifest_content = database.get_manifest(&name, &digest).await?;
if manifest_content.is_none() { if manifest_content.is_none() {
@ -142,6 +121,8 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
} }
let manifest_content = manifest_content.unwrap(); let manifest_content = manifest_content.unwrap();
debug!("got content");
Ok(( Ok((
StatusCode::OK, StatusCode::OK,
[ [
@ -154,13 +135,7 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
).into_response()) ).into_response())
} }
pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, auth: UserAuth) -> Result<Response, AppError> { pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return Ok(access_denied_response(&state.config));
}
drop(auth_driver);
let database = &state.database; let database = &state.database;
let digest = match Digest::is_digest(&reference) { let digest = match Digest::is_digest(&reference) {
true => { true => {


@ -17,7 +17,7 @@ pub mod auth;
/// full endpoint: `/v2/` /// full endpoint: `/v2/`
pub async fn version_check(_state: State<Arc<AppState>>) -> Response { pub async fn version_check(_state: State<Arc<AppState>>) -> Response {
( (
StatusCode::UNAUTHORIZED, StatusCode::OK,
[ [
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" ), ( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" ),
] ]


@ -2,7 +2,7 @@ use std::collections::HashMap;
use std::io::ErrorKind; use std::io::ErrorKind;
use std::sync::Arc; use std::sync::Arc;
use axum::http::{StatusCode, header, HeaderName}; use axum::http::{StatusCode, header, HeaderName, HeaderMap};
use axum::extract::{Path, BodyStream, State, Query}; use axum::extract::{Path, BodyStream, State, Query};
use axum::response::{IntoResponse, Response}; use axum::response::{IntoResponse, Response};
@ -30,7 +30,7 @@ pub async fn start_upload_post(Path((name, )): Path<(String, )>) -> Result<Respo
).into_response()); ).into_response());
} }
pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, String)>, state: State<Arc<AppState>>, mut body: BodyStream) -> Result<Response, AppError> { pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String, String)>, headers: HeaderMap, state: State<Arc<AppState>>, mut body: BodyStream) -> Result<Response, AppError> {
let storage = state.storage.lock().await; let storage = state.storage.lock().await;
let current_size = storage.digest_length(&layer_uuid).await?; let current_size = storage.digest_length(&layer_uuid).await?;
@ -65,18 +65,30 @@ pub async fn chunked_upload_layer_patch(Path((name, layer_uuid)): Path<(String,
} }
}; };
let (starting, ending) = if let Some(current_size) = current_size { let ending = if let Some(current_size) = current_size {
(current_size, current_size + written_size) current_size + written_size
} else { } else {
(0, written_size) written_size
}; };
if let Some(content_length) = headers.get(header::CONTENT_LENGTH) {
let content_length = content_length.to_str().map(|cl| cl.parse::<usize>());
if let Ok(Ok(content_length)) = content_length {
debug!("Client specified a content length of {}", content_length);
if content_length != written_size {
warn!("The content length that was received from the client did not match the amount written to disk!");
}
}
}
let full_uri = format!("{}/v2/{}/blobs/uploads/{}", state.config.url(), name, layer_uuid); let full_uri = format!("{}/v2/{}/blobs/uploads/{}", state.config.url(), name, layer_uuid);
Ok(( Ok((
StatusCode::ACCEPTED, StatusCode::ACCEPTED,
[ [
(header::LOCATION, full_uri), (header::LOCATION, full_uri),
(header::RANGE, format!("{}-{}", starting, ending)), (header::RANGE, format!("0-{}", ending - 1)),
(header::CONTENT_LENGTH, "0".to_string()), (header::CONTENT_LENGTH, "0".to_string()),
(HeaderName::from_static("docker-upload-uuid"), layer_uuid) (HeaderName::from_static("docker-upload-uuid"), layer_uuid)
] ]
@ -122,7 +134,8 @@ pub async fn check_upload_status_get(Path((name, layer_uuid)): Path<(String, Str
StatusCode::CREATED, StatusCode::CREATED,
[ [
(header::LOCATION, format!("/v2/{}/blobs/uploads/{}", name, layer_uuid)), (header::LOCATION, format!("/v2/{}/blobs/uploads/{}", name, layer_uuid)),
(header::RANGE, format!("0-{}", ending)), (header::RANGE, format!("0-{}", ending - 1)),
(header::CONTENT_LENGTH, "0".to_string()),
(HeaderName::from_static("docker-upload-digest"), layer_uuid) (HeaderName::from_static("docker-upload-digest"), layer_uuid)
] ]
).into_response()) ).into_response())


@ -3,7 +3,7 @@ use ldap3::{LdapConnAsync, Ldap, Scope, SearchEntry};
use sqlx::{Pool, Sqlite}; use sqlx::{Pool, Sqlite};
use tracing::{debug, warn}; use tracing::{debug, warn};
use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource, RegistryUserType, self}, RepositoryVisibility}, database::Database}; use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource, RegistryUserType}, RepositoryVisibility}, database::Database};
use super::AuthDriver; use super::AuthDriver;
@ -100,7 +100,6 @@ impl AuthDriver for LdapAuthDriver {
}; };
database.create_user(email.clone(), display_name, LoginSource::LDAP).await?; database.create_user(email.clone(), display_name, LoginSource::LDAP).await?;
drop(database);
// Set the user registry type // Set the user registry type
let user_type = match self.is_user_admin(email.clone()).await? { let user_type = match self.is_user_admin(email.clone()).await? {


@ -1,6 +1,6 @@
pub mod ldap_driver; pub mod ldap_driver;
use std::{ops::Deref, sync::Arc}; use std::sync::Arc;
use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request, Method}, middleware::Next, response::{Response, IntoResponse}}; use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request, Method}, middleware::Next, response::{Response, IntoResponse}};
@ -82,66 +82,15 @@ where
Ok(false) Ok(false)
} }
#[derive(Clone)]
pub struct AuthToken(pub String);
impl Deref for AuthToken {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
type Rejection = (StatusCode, HeaderMap); type Rejection = (StatusCode, HeaderMap);
pub async fn require_auth<B>(State(state): State<Arc<AppState>>, mut request: Request<B>, next: Next<B>) -> Result<Response, Rejection> {
let bearer = format!("Bearer realm=\"{}/auth\"", state.config.url());
let mut failure_headers = HeaderMap::new();
failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap());
failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap());
let auth = String::from(
request.headers().get(header::AUTHORIZATION)
.ok_or((StatusCode::UNAUTHORIZED, failure_headers.clone()))?
.to_str()
.map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))?
);
let token = match auth.split_once(' ') {
Some((auth, token)) if auth == "Bearer" => token,
// This line would allow empty tokens
//_ if auth == "Bearer" => Ok(AuthToken(None)),
_ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ),
};
// If the token is not valid, return an unauthorized response
let database = &state.database;
if let Ok(Some(user)) = database.verify_user_token(token.to_string()).await {
debug!("Authenticated user through middleware: {}", user.user.username);
request.extensions_mut().insert(user);
Ok(next.run(request).await)
} else {
let bearer = format!("Bearer realm=\"{}/auth\"", state.config.url());
Ok((
StatusCode::UNAUTHORIZED,
[
( header::WWW_AUTHENTICATE, bearer ),
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
]
).into_response())
}
}
/// Creates a response with an Unauthorized (401) status code. /// Creates a response with an Unauthorized (401) status code.
/// The www-authenticate header is set to notify the client of where to authorize with. /// The www-authenticate header is set to notify the client of where to authorize with.
#[inline(always)] #[inline(always)]
pub fn auth_challenge_response(config: &Config, scope: Option<Scope>) -> Response { pub fn auth_challenge_response(config: &Config, scope: Option<Scope>) -> Response {
let bearer = match scope { let bearer = match scope {
Some(scope) => format!("Bearer realm=\"{}/auth\",scope=\"{}\"", config.url(), scope), Some(scope) => format!("Bearer realm=\"{}/token\",scope=\"{}\"", config.url(), scope),
None => format!("Bearer realm=\"{}/auth\"", config.url()) None => format!("Bearer realm=\"{}/token\"", config.url())
}; };
debug!("responding with www-authenticate header of: \"{}\"", bearer); debug!("responding with www-authenticate header of: \"{}\"", bearer);
@ -173,11 +122,19 @@ pub async fn check_auth<B>(State(state): State<Arc<AppState>>, auth: Option<User
// note: url is relative to /v2 // note: url is relative to /v2
let url = request.uri().to_string(); let url = request.uri().to_string();
if url == "/" && auth.is_none() { if url == "/" {
// if auth is none, then the client needs to authenticate
if auth.is_none() {
debug!("Responding to /v2/ with an auth challenge"); debug!("Responding to /v2/ with an auth challenge");
return Ok(auth_challenge_response(config, None)); return Ok(auth_challenge_response(config, None));
} }
debug!("user is authed");
// the client is authenticating right now
return Ok(next.run(request).await);
}
let url_split: Vec<&str> = url.split("/").skip(1).collect(); let url_split: Vec<&str> = url.split("/").skip(1).collect();
let target_name = url_split[0].replace("%2F", "/"); let target_name = url_split[0].replace("%2F", "/");
let target_type = url_split[1]; let target_type = url_split[1];
@ -216,7 +173,8 @@ pub async fn check_auth<B>(State(state): State<Arc<AppState>>, auth: Option<User
_ => None, _ => None,
}; };
match auth_checker.user_has_permission(auth.user.email.clone(), target_name.clone(), permission, vis).await { if let Some(user) = &auth.user {
match auth_checker.user_has_permission(user.email.clone(), target_name.clone(), permission, vis).await {
Ok(false) => return Ok(auth_challenge_response(config, Some(scope))), Ok(false) => return Ok(auth_challenge_response(config, Some(scope))),
Ok(true) => { }, Ok(true) => { },
Err(e) => { Err(e) => {
@ -225,6 +183,19 @@ pub async fn check_auth<B>(State(state): State<Arc<AppState>>, auth: Option<User
return Err((StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new())); return Err((StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()));
}, },
} }
} else {
// anonymous users can ONLY pull from public repos
if permission != Permission::PULL {
return Ok(access_denied_response(config));
}
// ensure the repo is public
let database = &state.database;
if let Some(RepositoryVisibility::Private) = database.get_repository_visibility(&target_name).await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()))? {
return Ok(access_denied_response(config));
}
}
} }
} else { } else {
warn!("Unhandled auth check for '{target_type}'!!"); // TODO warn!("Unhandled auth check for '{target_type}'!!"); // TODO


@ -17,7 +17,7 @@ pin_project! {
#[allow(dead_code)] #[allow(dead_code)]
impl ByteStream { impl ByteStream {
/// Create a new `ByteStream` by wrapping a `futures` stream. /// Create a new `ByteStream` by wrapping a `futures` stream.
pub fn new<S>(stream: S) -> ByteStream pub fn new<S>(stream: S) -> Self
where where
S: Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static, S: Stream<Item = Result<Bytes, std::io::Error>> + Send + 'static,
{ {
@ -27,6 +27,13 @@ impl ByteStream {
} }
} }
/// Recreate the ByteStream, skipping `n` elements
pub fn skip_recreate(mut self, n: usize) -> Self {
self.inner = Box::pin(self.inner.skip(n));
self
}
pub(crate) fn size_hint(&self) -> Option<usize> { pub(crate) fn size_hint(&self) -> Option<usize> {
self.size_hint self.size_hint
} }


@ -1,4 +1,3 @@
use anyhow::anyhow;
use figment::{Figment, providers::{Env, Toml, Format}}; use figment::{Figment, providers::{Env, Toml, Format}};
use figment_cliarg_provider::FigmentCliArgsProvider; use figment_cliarg_provider::FigmentCliArgsProvider;
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
@ -60,17 +59,58 @@ pub enum DatabaseConfig {
Sqlite(SqliteDbConfig), Sqlite(SqliteDbConfig),
} }
#[derive(Deserialize, Clone, Default)]
#[serde(rename_all = "snake_case")]
pub enum LogFormat {
Human,
#[default]
Json,
}
#[derive(Deserialize, Clone, Default)]
#[serde(rename_all = "snake_case")]
pub enum RollPeriod {
Minutely,
Hourly,
#[default]
Daily,
Never,
}
#[derive(Deserialize, Clone)]
pub struct LogConfig {
/// The minimum level of logging
#[serde(deserialize_with = "serialize_log_level", default = "default_log_level")]
pub level: Level,
/// The path of the logging file
#[serde(default = "default_log_path")]
pub path: String,
/// The format of the produced logs
#[serde(default)]
pub format: LogFormat,
/// The roll period of the file
#[serde(default)]
pub roll_period: RollPeriod,
#[serde(default)]
pub extra_logging: bool,
pub env_filter: Option<String>,
}
#[derive(Deserialize, Clone)] #[derive(Deserialize, Clone)]
pub struct Config { pub struct Config {
pub listen_address: String, pub listen_address: String,
pub listen_port: String, pub listen_port: String,
url: Option<String>, url: Option<String>,
#[serde(deserialize_with = "serialize_log_level", default = "default_log_level")] pub registry_path: String,
pub log_level: Level, #[serde(default)]
pub extra_logging: bool,
pub log: LogConfig,
pub ldap: Option<LdapConnectionConfig>, pub ldap: Option<LdapConnectionConfig>,
pub database: DatabaseConfig, pub database: DatabaseConfig,
pub storage: StorageConfig, pub storage: StorageConfig,
pub tls: Option<TlsConfig>, pub tls: Option<TlsConfig>,
#[serde(skip)]
pub jwt_key: String,
} }
#[allow(dead_code)] #[allow(dead_code)]
@ -120,6 +160,10 @@ fn default_log_level() -> Level {
Level::INFO Level::INFO
} }
fn default_log_path() -> String {
"orca.log".to_string()
}
fn serialize_log_level<'de, D>(deserializer: D) -> Result<Level, D::Error> fn serialize_log_level<'de, D>(deserializer: D) -> Result<Level, D::Error>
where D: Deserializer<'de> { where D: Deserializer<'de> {
let s = String::deserialize(deserializer)?.to_lowercase(); let s = String::deserialize(deserializer)?.to_lowercase();
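To show how the `#[serde(default)]` and `#[serde(rename_all = "snake_case")]` attributes on the new `LogConfig` and `RollPeriod` types behave, here is a trimmed-down standalone mirror. It omits the `Level` field and its custom deserializer, and it uses `serde_json` purely for brevity; the real config is read from TOML via Figment.

```rust
use serde::Deserialize;

// Not the project's types: a reduced stand-in with the same serde attributes.
#[derive(Deserialize, Debug, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
enum RollPeriod {
    Minutely,
    Hourly,
    #[default]
    Daily,
    Never,
}

#[derive(Deserialize)]
struct LogConfig {
    #[serde(default = "default_log_path")]
    path: String,
    #[serde(default)]
    roll_period: RollPeriod,
}

fn default_log_path() -> String {
    "orca.log".to_string()
}

fn main() {
    // An empty log section falls back to the declared defaults.
    let cfg: LogConfig = serde_json::from_str("{}").unwrap();
    assert_eq!(cfg.path, "orca.log");
    assert_eq!(cfg.roll_period, RollPeriod::Daily);

    // snake_case renaming maps "never" onto RollPeriod::Never.
    let cfg: LogConfig = serde_json::from_str(r#"{ "roll_period": "never" }"#).unwrap();
    assert_eq!(cfg.roll_period, RollPeriod::Never);
}
```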


@ -1,19 +1,21 @@
use async_trait::async_trait; use async_trait::async_trait;
use rand::{Rng, distributions::Alphanumeric};
use sqlx::{Sqlite, Pool}; use sqlx::{Sqlite, Pool};
use tracing::{debug, warn}; use tracing::{debug, warn};
use chrono::{DateTime, Utc, NaiveDateTime, TimeZone}; use chrono::{DateTime, Utc, NaiveDateTime};
use crate::dto::{Tag, user::{User, RepositoryPermissions, RegistryUserType, Permission, UserAuth, TokenInfo, LoginSource}, RepositoryVisibility}; use crate::dto::{Tag, user::{User, RepositoryPermissions, RegistryUserType, Permission, UserAuth, LoginSource}, RepositoryVisibility};
#[async_trait] #[async_trait]
pub trait Database { pub trait Database {
// Digest related functions // Digest related functions
/// Create the tables in the database /// Create the tables in the database
async fn create_schema(&self) -> anyhow::Result<()>; async fn create_schema(&self) -> anyhow::Result<()>;
async fn get_jwt_secret(&self) -> anyhow::Result<String>;
// Tag related functions // Tag related functions
/// Get tags associated with a repository /// Get tags associated with a repository
@ -61,20 +63,67 @@ pub trait Database {
async fn get_user_repo_permissions(&self, email: String, repository: String) -> anyhow::Result<Option<RepositoryPermissions>>; async fn get_user_repo_permissions(&self, email: String, repository: String) -> anyhow::Result<Option<RepositoryPermissions>>;
async fn get_user_registry_usertype(&self, email: String) -> anyhow::Result<Option<RegistryUserType>>; async fn get_user_registry_usertype(&self, email: String) -> anyhow::Result<Option<RegistryUserType>>;
async fn store_user_token(&self, token: String, email: String, expiry: DateTime<Utc>, created_at: DateTime<Utc>) -> anyhow::Result<()>; async fn store_user_token(&self, token: String, email: String, expiry: DateTime<Utc>, created_at: DateTime<Utc>) -> anyhow::Result<()>;
#[deprecated = "Tokens are now verified using a secret"]
async fn verify_user_token(&self, token: String) -> anyhow::Result<Option<UserAuth>>; async fn verify_user_token(&self, token: String) -> anyhow::Result<Option<UserAuth>>;
} }
#[async_trait] #[async_trait]
impl Database for Pool<Sqlite> { impl Database for Pool<Sqlite> {
async fn create_schema(&self) -> anyhow::Result<()> { async fn create_schema(&self) -> anyhow::Result<()> {
let orca_version = "0.1.0";
let schema_version = "0.0.1";
let row: Option<(u32, )> = match sqlx::query_as("SELECT COUNT(1) FROM orca WHERE \"schema_version\" = ?")
.bind(schema_version)
.fetch_one(self).await {
Ok(row) => Some(row),
Err(e) => match e {
sqlx::Error::RowNotFound => {
None
},
// ignore no such table errors
sqlx::Error::Database(b) if b.message().starts_with("no such table") => None,
_ => {
return Err(anyhow::Error::new(e));
}
}
};
sqlx::query(include_str!("schemas/schema.sql")) sqlx::query(include_str!("schemas/schema.sql"))
.execute(self).await?; .execute(self).await?;
debug!("Created database schema"); debug!("Created database schema");
if row.is_none() || row.unwrap().0 == 0 {
let jwt_sec: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(16)
.map(char::from)
.collect();
// create schema
// TODO: Check if needed
/* sqlx::query(include_str!("schemas/schema.sql"))
.execute(self).await?;
debug!("Created database schema"); */
sqlx::query("INSERT INTO orca(orca_version, schema_version, jwt_secret) VALUES (?, ?, ?)")
.bind(orca_version)
.bind(schema_version)
.bind(jwt_sec)
.execute(self).await?;
debug!("Inserted information about orca!");
}
Ok(()) Ok(())
} }
async fn get_jwt_secret(&self) -> anyhow::Result<String> {
let rows: (String, ) = sqlx::query_as("SELECT jwt_secret FROM orca WHERE id = (SELECT max(id) FROM orca)")
.fetch_one(self).await?;
Ok(rows.0)
}
async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()> { async fn link_manifest_layer(&self, manifest_digest: &str, layer_digest: &str) -> anyhow::Result<()> {
sqlx::query("INSERT INTO manifest_layers(manifest, layer_digest) VALUES (?, ?)") sqlx::query("INSERT INTO manifest_layers(manifest, layer_digest) VALUES (?, ?)")
.bind(manifest_digest) .bind(manifest_digest)
@ -369,6 +418,7 @@ impl Database for Pool<Sqlite> {
} }
async fn get_user(&self, email: String) -> anyhow::Result<Option<User>> { async fn get_user(&self, email: String) -> anyhow::Result<Option<User>> {
debug!("getting user");
let email = email.to_lowercase(); let email = email.to_lowercase();
let row: (String, u32) = match sqlx::query_as("SELECT username, login_source FROM users WHERE email = ?") let row: (String, u32) = match sqlx::query_as("SELECT username, login_source FROM users WHERE email = ?")
.bind(email.clone()) .bind(email.clone())
@ -513,50 +563,7 @@ impl Database for Pool<Sqlite> {
Ok(()) Ok(())
} }
async fn verify_user_token(&self, token: String) -> anyhow::Result<Option<UserAuth>> { async fn verify_user_token(&self, _token: String) -> anyhow::Result<Option<UserAuth>> {
let token_row: (String, i64, i64,) = match sqlx::query_as("SELECT email, expiry, created_at FROM user_tokens WHERE token = ?") panic!("ERR: Database::verify_user_token is deprecated!")
.bind(token.clone())
.fetch_one(self).await {
Ok(row) => row,
Err(e) => match e {
sqlx::Error::RowNotFound => {
return Ok(None)
},
_ => {
return Err(anyhow::Error::new(e));
}
}
};
let (email, expiry, created_at) = (token_row.0, token_row.1, token_row.2);
let user_row: (String, u32) = match sqlx::query_as("SELECT username, login_source FROM users WHERE email = ?")
.bind(email.clone())
.fetch_one(self).await {
Ok(row) => row,
Err(e) => match e {
sqlx::Error::RowNotFound => {
return Ok(None)
},
_ => {
return Err(anyhow::Error::new(e));
}
}
};
/* let user_row: (String, u32) = sqlx::query_as("SELECT email, login_source FROM users WHERE email = ?")
.bind(email.clone())
.fetch_one(self).await?; */
let (expiry, created_at) = (Utc.timestamp_millis_opt(expiry).single(), Utc.timestamp_millis_opt(created_at).single());
if let (Some(expiry), Some(created_at)) = (expiry, created_at) {
let user = User::new(user_row.0, email, LoginSource::try_from(user_row.1)?);
let token = TokenInfo::new(token, expiry, created_at);
let auth = UserAuth::new(user, token);
Ok(Some(auth))
} else {
Ok(None)
}
} }
} }
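A self-contained sketch of the new `jwt_secret` lookup, run against an in-memory SQLite database. The table is reduced to just the columns the query touches; this is not the project's real schema or connection setup, only an illustration of the `query_as` pattern used above.

```rust
use sqlx::sqlite::SqlitePoolOptions;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Throwaway in-memory database for the example.
    let pool = SqlitePoolOptions::new().connect("sqlite::memory:").await?;

    sqlx::query("CREATE TABLE orca (id INTEGER PRIMARY KEY AUTOINCREMENT, jwt_secret TEXT NOT NULL)")
        .execute(&pool)
        .await?;
    sqlx::query("INSERT INTO orca(jwt_secret) VALUES (?)")
        .bind("s3cr3t")
        .execute(&pool)
        .await?;

    // Same shape as get_jwt_secret: read the secret from the most recent row.
    let (secret,): (String,) =
        sqlx::query_as("SELECT jwt_secret FROM orca WHERE id = (SELECT max(id) FROM orca)")
            .fetch_one(&pool)
            .await?;
    assert_eq!(secret, "s3cr3t");
    Ok(())
}
```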


@ -1,3 +1,10 @@
CREATE TABLE IF NOT EXISTS orca (
id INTEGER PRIMARY KEY AUTOINCREMENT,
orca_version TEXT NOT NULL,
schema_version TEXT NOT NULL,
jwt_secret TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS projects ( CREATE TABLE IF NOT EXISTS projects (
name TEXT NOT NULL UNIQUE PRIMARY KEY, name TEXT NOT NULL UNIQUE PRIMARY KEY,
-- 0 = private, 1 = public -- 0 = private, 1 = public


@ -1,12 +1,13 @@
use anyhow::anyhow; use anyhow::anyhow;
use serde::{Deserialize, de::Visitor}; use serde::{Deserialize, Serialize};
use std::fmt; use std::fmt;
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ScopeType { pub enum ScopeType {
#[default] #[default]
Unknown, Unknown,
#[serde(rename = "repository")]
Repository, Repository,
} }
@ -19,11 +20,13 @@ impl fmt::Display for ScopeType {
} }
} }
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum Action { pub enum Action {
#[default] #[default]
None, None,
#[serde(rename = "push")]
Push, Push,
#[serde(rename = "pull")]
Pull, Pull,
} }
@ -37,18 +40,19 @@ impl fmt::Display for Action {
} }
} }
#[derive(Default, Debug, Clone)] #[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Scope { pub struct Scope {
scope_type: ScopeType, #[serde(rename = "type")]
path: String, pub scope_type: ScopeType,
actions: Vec<Action>, pub name: String,
pub actions: Vec<Action>,
} }
impl Scope { impl Scope {
pub fn new(scope_type: ScopeType, path: String, actions: &[Action]) -> Self { pub fn new(scope_type: ScopeType, path: String, actions: &[Action]) -> Self {
Self { Self {
scope_type, scope_type,
path, name: path,
actions: actions.to_vec(), actions: actions.to_vec(),
} }
} }
@ -62,7 +66,7 @@ impl fmt::Display for Scope {
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(","); .join(",");
write!(f, "{}:{}:{}", self.scope_type, self.path, actions) write!(f, "{}:{}:{}", self.scope_type, self.name, actions)
} }
} }
@ -93,7 +97,7 @@ impl TryFrom<&str> for Scope {
Ok(Scope { Ok(Scope {
scope_type, scope_type,
path: String::from(path), name: String::from(path),
actions actions
}) })
} else { } else {
@ -102,67 +106,3 @@ impl TryFrom<&str> for Scope {
} }
} }
} }
pub struct ScopeVisitor {
}
impl<'de> Visitor<'de> for ScopeVisitor {
type Value = Scope;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a Scope in the format of `repository:samalba/my-app:pull,push`.")
}
fn visit_str<E>(self, val: &str) -> Result<Self::Value, E>
where
E: serde::de::Error {
println!("Start of visit_str!");
let res = match Scope::try_from(val) {
Ok(val) => Ok(val),
Err(e) => Err(serde::de::Error::custom(format!("{}", e)))
};
res
/* let splits: Vec<&str> = val.split(":").collect();
if splits.len() == 3 {
let scope_type = match splits[0] {
"repository" => ScopeType::Repository,
_ => {
return Err(serde::de::Error::custom(format!("Invalid scope type: `{}`!", splits[0])));
}
};
let path = splits[1];
let actions: Result<Vec<Action>, E> = splits[2]
.split(",")
.map(|a| match a {
"pull" => Ok(Action::Pull),
"push" => Ok(Action::Push),
_ => Err(serde::de::Error::custom(format!("Invalid action: `{}`!", a))),
}).collect();
let actions = actions?;
Ok(Scope {
scope_type,
path: String::from(path),
actions
})
} else {
Err(serde::de::Error::custom("Malformed scope string!"))
} */
}
}
impl<'de> Deserialize<'de> for Scope {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de> {
deserializer.deserialize_str(ScopeVisitor {})
}
}
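The serde derives added to `Scope`, `ScopeType`, and `Action` exist so scopes can be embedded directly in the token's `access` claim. A standalone mirror of that JSON shape (field names follow the renames in this diff; the `samalba/my-app` path reuses the example string from the removed visitor's docs):

```rust
use serde::Serialize;

// Not the project's type: a minimal stand-in with the same serde renames,
// just to show what a serialized access entry looks like.
#[derive(Serialize)]
struct Access<'a> {
    #[serde(rename = "type")]
    scope_type: &'a str,
    name: &'a str,
    actions: Vec<&'a str>,
}

fn main() {
    let access = vec![Access {
        scope_type: "repository",
        name: "samalba/my-app",
        actions: vec!["pull", "push"],
    }];

    // Prints a JSON array like:
    // [{"type":"repository","name":"samalba/my-app","actions":["pull","push"]}]
    println!("{}", serde_json::to_string(&access).unwrap());
}
```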


@ -4,11 +4,15 @@ use async_trait::async_trait;
use axum::{http::{StatusCode, header, HeaderName, HeaderMap, request::Parts}, extract::FromRequestParts}; use axum::{http::{StatusCode, header, HeaderName, HeaderMap, request::Parts}, extract::FromRequestParts};
use bitflags::bitflags; use bitflags::bitflags;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use hmac::{Hmac, digest::KeyInit};
use jwt::VerifyWithKey;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use tracing::debug; use tracing::debug;
use crate::{app_state::AppState, database::Database}; use crate::{app_state::AppState, database::Database};
use super::RepositoryVisibility; use super::{RepositoryVisibility, scope::Scope};
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LoginSource { pub enum LoginSource {
@ -45,6 +49,50 @@ impl User {
} }
} }
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct AuthToken {
#[serde(rename = "iss")]
pub issuer: String,
#[serde(rename = "sub")]
pub subject: String,
#[serde(rename = "aud")]
pub audience: String,
#[serde(rename = "exp")]
#[serde(with = "chrono::serde::ts_seconds")]
pub expiration: DateTime<Utc>,
#[serde(rename = "nbf")]
#[serde(with = "chrono::serde::ts_seconds")]
pub not_before: DateTime<Utc>,
#[serde(rename = "iat")]
#[serde(with = "chrono::serde::ts_seconds")]
pub issued_at: DateTime<Utc>,
#[serde(rename = "jti")]
pub jwt_id: String,
pub access: Vec<Scope>,
}
impl AuthToken {
pub fn new(issuer: String, subject: String, audience: String, expiration: DateTime<Utc>, not_before: DateTime<Utc>, issued_at: DateTime<Utc>, jwt_id: String, access: Vec<Scope>) -> Self {
Self {
issuer,
subject,
audience,
expiration,
not_before,
issued_at,
jwt_id,
access
}
}
}
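Token issuance presumably lives in the /token handler added elsewhere in this change; for context, a minimal sketch of signing such a claim set with the same jwt/hmac crates imported below (all values are placeholders, not the registry's real configuration) could look like:

use chrono::{Duration, Utc};
use hmac::{Hmac, digest::KeyInit};
use jwt::SignWithKey;
use sha2::Sha256;

fn mint_example_token() -> anyhow::Result<String> {
    // Placeholder secret; the registry reads the real one from config.jwt_key.
    let key: Hmac<Sha256> = Hmac::new_from_slice(b"example-secret")?;
    let now = Utc::now();
    let claims = AuthToken::new(
        "orca-registry".to_string(),  // iss (placeholder)
        "some-user".to_string(),      // sub (placeholder)
        "registry".to_string(),       // aud (placeholder)
        now + Duration::minutes(10),  // exp
        now,                          // nbf
        now,                          // iat
        "example-jti".to_string(),    // jti (placeholder)
        vec![],                       // access scopes
    );
    // SignWithKey serializes the claims and signs them with HMAC-SHA256.
    Ok(claims.sign_with_key(&key)?)
}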
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct TokenInfo { pub struct TokenInfo {
pub token: String, pub token: String,
@ -64,12 +112,12 @@ impl TokenInfo {
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct UserAuth { pub struct UserAuth {
pub user: User, pub user: Option<User>,
pub token: TokenInfo, pub token: AuthToken,
} }
impl UserAuth { impl UserAuth {
pub fn new(user: User, token: TokenInfo) -> Self { pub fn new(user: Option<User>, token: AuthToken) -> Self {
Self { Self {
user, user,
token, token,
@ -82,13 +130,11 @@ impl FromRequestParts<Arc<AppState>> for UserAuth {
type Rejection = (StatusCode, HeaderMap); type Rejection = (StatusCode, HeaderMap);
async fn from_request_parts(parts: &mut Parts, state: &Arc<AppState>) -> Result<Self, Self::Rejection> { async fn from_request_parts(parts: &mut Parts, state: &Arc<AppState>) -> Result<Self, Self::Rejection> {
let bearer = format!("Bearer realm=\"{}/auth\"", state.config.url()); let bearer = format!("Bearer realm=\"{}/token\"", state.config.url());
let mut failure_headers = HeaderMap::new(); let mut failure_headers = HeaderMap::new();
failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap()); failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap());
failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap()); failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap());
debug!("starting UserAuth request parts");
let auth = String::from( let auth = String::from(
parts.headers parts.headers
.get(header::AUTHORIZATION) .get(header::AUTHORIZATION)
@ -97,8 +143,6 @@ impl FromRequestParts<Arc<AppState>> for UserAuth {
.map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))? .map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))?
); );
debug!("got auth header");
let token = match auth.split_once(' ') { let token = match auth.split_once(' ') {
Some((auth, token)) if auth == "Bearer" => token, Some((auth, token)) if auth == "Bearer" => token,
// This line would allow empty tokens // This line would allow empty tokens
@ -106,15 +150,36 @@ impl FromRequestParts<Arc<AppState>> for UserAuth {
_ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ), _ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ),
}; };
debug!("got token");
// If the token is not valid, return an unauthorized response // If the token is not valid, return an unauthorized response
let database = &state.database; let jwt_key: Hmac<Sha256> = Hmac::new_from_slice(state.config.jwt_key.as_bytes())
if let Ok(Some(user)) = database.verify_user_token(token.to_string()).await { .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()) )?;
debug!("Authenticated user through request extractor: {}", user.user.username);
Ok(user) match VerifyWithKey::<AuthToken>::verify_with_key(token, &jwt_key) {
Ok(token) => {
// attempt to get the user
if !token.subject.is_empty() {
let database = &state.database;
if let Ok(Some(user)) = database.get_user(token.subject.clone()).await {
return Ok(UserAuth::new(Some(user), token));
} else { } else {
debug!("failure to get user from token: {:?}", token);
}
} else {
return Ok(UserAuth::new(None, token));
}
/* let database = &state.database;
if let Ok(user) = database.get_user(token.subject.clone()).await {
return Ok(UserAuth::new(user, token));
} else {
debug!("failure to get user from token: {:?}", token);
} */
},
Err(e) => {
debug!("Failure to verify user token: '{}'", e);
}
}
debug!("Failure to verify user token, responding with auth realm"); debug!("Failure to verify user token, responding with auth realm");
Err(( Err((
@ -122,7 +187,6 @@ impl FromRequestParts<Arc<AppState>> for UserAuth {
failure_headers failure_headers
)) ))
} }
}
} }
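Since UserAuth now implements FromRequestParts, handlers can take it directly as an extractor; a hypothetical handler (name and response body are illustrative only, not part of this change):

async fn whoami(auth: UserAuth) -> String {
    match auth.user {
        // A token whose subject resolved to a database user
        Some(user) => format!("authenticated as {}", user.username),
        // A valid token with no associated user (empty subject)
        None => format!("anonymous token issued by {}", auth.token.issuer),
    }
}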
bitflags! { bitflags! {

View File

@ -8,8 +8,9 @@ mod config;
mod auth; mod auth;
mod error; mod error;
use std::fs;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::Path; use std::path::{Path, PathBuf};
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
@ -28,10 +29,13 @@ use tower_layer::Layer;
use sqlx::sqlite::{SqlitePoolOptions, SqliteConnectOptions, SqliteJournalMode}; use sqlx::sqlite::{SqlitePoolOptions, SqliteConnectOptions, SqliteJournalMode};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower_http::normalize_path::NormalizePathLayer; use tower_http::normalize_path::NormalizePathLayer;
use tracing::metadata::LevelFilter;
use tracing::{debug, info}; use tracing::{debug, info};
use app_state::AppState; use app_state::AppState;
use database::Database; use database::Database;
use tracing_subscriber::{filter, EnvFilter};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use crate::storage::StorageDriver; use crate::storage::StorageDriver;
use crate::storage::filesystem::FilesystemDriver; use crate::storage::filesystem::FilesystemDriver;
@ -67,31 +71,127 @@ async fn change_request_paths<B>(mut request: Request<B>, next: Next<B>) -> Resu
Ok(next.run(request).await) Ok(next.run(request).await)
} }
fn path_relative_to(registry_path: &str, other_path: &str) -> PathBuf {
let other = PathBuf::from(other_path);
if other.is_absolute() {
other
} else {
PathBuf::from(registry_path).join(other)
}
}
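The intent is that relative config paths resolve under registry_path while absolute paths pass through unchanged; for example (paths made up):

assert_eq!(path_relative_to("/srv/registry", "orca.db"), PathBuf::from("/srv/registry/orca.db"));
assert_eq!(path_relative_to("/srv/registry", "/var/log/orca"), PathBuf::from("/var/log/orca"));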
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
let config = Config::new() let mut config = Config::new()
.expect("Failure to parse config!"); .expect("Failure to parse config!");
tracing_subscriber::fmt() // Create registry directory if it doesn't exist
.with_max_level(config.log_level) if !Path::new(&config.registry_path).exists() {
fs::create_dir_all(&config.registry_path)?;
}
let mut logging_guards = Vec::new();
{
let logc = &config.log;
// Create log directory if it doesn't exist
let log_path = path_relative_to(&config.registry_path, &logc.path);
if !log_path.exists() {
fs::create_dir_all(&log_path)?;
}
// Get a rolling file appender depending on the config
let file_appender = match logc.roll_period {
config::RollPeriod::Minutely => tracing_appender::rolling::minutely(log_path, "orca.log"),
config::RollPeriod::Hourly => tracing_appender::rolling::hourly(log_path, "orca.log"),
config::RollPeriod::Daily => tracing_appender::rolling::daily(log_path, "orca.log"),
config::RollPeriod::Never => tracing_appender::rolling::never(log_path, "orca.log"),
};
// Create non-blocking writers for the file and stdout loggers
let (file_appender_nb, _file_guard) = tracing_appender::non_blocking(file_appender);
let (stdout_nb, _stdout_guard) = tracing_appender::non_blocking(std::io::stdout());
logging_guards.push(_file_guard);
logging_guards.push(_stdout_guard);
// TODO: Is there a way for this to be less ugly?
// Get json or text layers
let (json_a, json_b, plain_a, plain_b) = match logc.format {
config::LogFormat::Json => (
Some(
tracing_subscriber::fmt::layer()
.with_writer(file_appender_nb)
.json()
),
Some(
tracing_subscriber::fmt::layer()
.with_writer(stdout_nb)
.json()
),
None,
None
),
config::LogFormat::Human => (
None,
None,
Some(
tracing_subscriber::fmt::layer()
.with_writer(file_appender_nb)
),
Some(
tracing_subscriber::fmt::layer()
.with_writer(stdout_nb)
)
)
};
// Restrict logging to the orca_registry target unless extra_logging is enabled
let targets_filter = if logc.extra_logging {
filter::Targets::new()
.with_default(logc.level)
} else {
filter::Targets::new()
.with_target("orca_registry", logc.level)
.with_default(LevelFilter::INFO)
};
// Get env filter if specified
let env_filter = logc.env_filter.as_ref()
    .map(|filter| EnvFilter::from_str(filter).expect("invalid log env_filter in config"));
tracing_subscriber::registry()
.with(json_a)
.with(json_b)
.with(plain_a)
.with(plain_b)
.with(targets_filter)
.with(env_filter)
.init(); .init();
}
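For reference, the optional env_filter accepts standard tracing directive syntax; a hypothetical value (targets chosen purely for illustration) parses with the EnvFilter import added above:

// Illustrative only: per-target log levels in one directive string.
let filter = EnvFilter::from_str("orca_registry=debug,sqlx=warn,tower_http=info")
    .expect("invalid filter directives");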
let sqlite_config = match &config.database { let sqlite_config = match &config.database {
DatabaseConfig::Sqlite(sqlite) => sqlite, DatabaseConfig::Sqlite(sqlite) => sqlite,
}; };
// Create a database file if it doesn't exist already // Create a database file if it doesn't exist already
if !Path::new(&sqlite_config.path).exists() { let sqlite_path = path_relative_to(&config.registry_path, &sqlite_config.path);
debug!("sqlite path: {:?}", sqlite_path);
if !Path::new(&sqlite_path).exists() {
File::create(&sqlite_config.path).await?; File::create(&sqlite_path).await?;
} }
let connection_options = SqliteConnectOptions::from_str(&format!("sqlite://{}", &sqlite_config.path))? let connection_options = SqliteConnectOptions::from_str(&format!("sqlite://{}", sqlite_path.as_os_str().to_str().unwrap()))?
.journal_mode(SqliteJournalMode::Wal); .journal_mode(SqliteJournalMode::Wal);
let pool = SqlitePoolOptions::new() let pool = SqlitePoolOptions::new()
.max_connections(15) .max_connections(15)
.connect_with(connection_options).await?; .connect_with(connection_options).await?;
pool.create_schema().await?; pool.create_schema().await?;
// Load the JWT signing secret from the database into the config
config.jwt_key = pool.get_jwt_secret().await?;
let storage_driver: Mutex<Box<dyn StorageDriver>> = match &config.storage { let storage_driver: Mutex<Box<dyn StorageDriver>> = match &config.storage {
StorageConfig::Filesystem(fs) => { StorageConfig::Filesystem(fs) => {
Mutex::new(Box::new(FilesystemDriver::new(&fs.path))) Mutex::new(Box::new(FilesystemDriver::new(&fs.path)))
@ -113,15 +213,15 @@ async fn main() -> anyhow::Result<()> {
let app_addr = SocketAddr::from_str(&format!("{}:{}", config.listen_address, config.listen_port))?; let app_addr = SocketAddr::from_str(&format!("{}:{}", config.listen_address, config.listen_port))?;
let tls_config = config.tls.clone(); let tls_config = config.tls.clone();
let registry_path = config.registry_path.clone();
let state = Arc::new(AppState::new(pool, storage_driver, config, auth_driver)); let state = Arc::new(AppState::new(pool, storage_driver, config, auth_driver));
//let auth_middleware = axum::middleware::from_fn_with_state(state.clone(), auth::require_auth);
let auth_middleware = axum::middleware::from_fn_with_state(state.clone(), auth::check_auth); let auth_middleware = axum::middleware::from_fn_with_state(state.clone(), auth::check_auth);
let path_middleware = axum::middleware::from_fn(change_request_paths); let path_middleware = axum::middleware::from_fn(change_request_paths);
let app = Router::new() let app = Router::new()
.route("/auth", routing::get(api::auth::auth_basic_get) .route("/token", routing::get(api::auth::auth_basic_get)
.post(api::auth::auth_basic_get)) .post(api::auth::auth_basic_post))
.nest("/v2", Router::new() .nest("/v2", Router::new()
.route("/", routing::get(api::version_check)) .route("/", routing::get(api::version_check))
.route("/_catalog", routing::get(api::catalog::list_repositories)) .route("/_catalog", routing::get(api::catalog::list_repositories))
@ -154,7 +254,9 @@ async fn main() -> anyhow::Result<()> {
Some(tls) if tls.enable => { Some(tls) if tls.enable => {
info!("Starting https server, listening on {}", app_addr); info!("Starting https server, listening on {}", app_addr);
let config = RustlsConfig::from_pem_file(&tls.cert, &tls.key).await?; let cert_path = path_relative_to(&registry_path, &tls.cert);
let key_path = path_relative_to(&registry_path, &tls.key);
let config = RustlsConfig::from_pem_file(&cert_path, &key_path).await?;
axum_server::bind_rustls(app_addr, config) axum_server::bind_rustls(app_addr, config)
.serve(layered_app.into_make_service()) .serve(layered_app.into_make_service())