diff --git a/Cargo.lock b/Cargo.lock index 4948682..5ca73c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -143,6 +143,18 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-auth" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "620b37645b77baab8160f93421568d7b3dd25da0a160fab38eb1c4ef611f6d98" +dependencies = [ + "async-trait", + "axum-core", + "base64 0.13.1", + "http", +] + [[package]] name = "axum-core" version = "0.3.4" @@ -440,40 +452,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "docker-registry" -version = "0.1.0" -dependencies = [ - "anyhow", - "argmap", - "async-stream", - "async-trait", - "axum", - "axum-macros", - "bytes", - "chrono", - "clap", - "figment", - "figment-cliarg-provider", - "futures", - "jws", - "pin-project-lite", - "regex", - "serde", - "serde_json", - "sha256", - "sqlx", - "tokio", - "tokio-util", - "tower-http", - "tower-layer", - "tracing", - "tracing-log", - "tracing-subscriber", - "uuid", - "wild", -] - [[package]] name = "dotenvy" version = "0.15.7" @@ -882,6 +860,21 @@ dependencies = [ "sha2 0.10.6", ] +[[package]] +name = "jwt" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6204285f77fe7d9784db3fdc449ecce1a0114927a51d5a41c4c7a292011c015f" +dependencies = [ + "base64 0.13.1", + "crypto-common", + "digest 0.10.5", + "hmac", + "serde", + "serde_json", + "sha2 0.10.6", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -1030,6 +1023,47 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" +[[package]] +name = "orca-registry" +version = "0.1.0" +dependencies = [ + "anyhow", + "argmap", + "async-stream", + "async-trait", + "axum", + "axum-auth", + "axum-macros", + "bytes", + "chrono", + "clap", + "figment", + "figment-cliarg-provider", + "futures", + "hmac", + "jws", + "jwt", + "pin-project-lite", + "qstring", + "rand", + "regex", + "serde", + "serde_json", + "serde_qs", + "sha2 0.10.6", + "sha256", + "sqlx", + "tokio", + "tokio-util", + "tower-http", + "tower-layer", + "tracing", + "tracing-log", + "tracing-subscriber", + "uuid", + "wild", +] + [[package]] name = "os_str_bytes" version = "6.4.0" @@ -1192,6 +1226,15 @@ dependencies = [ "yansi", ] +[[package]] +name = "qstring" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d464fae65fff2680baf48019211ce37aaec0c78e9264c84a3e484717f965104e" +dependencies = [ + "percent-encoding", +] + [[package]] name = "quote" version = "0.3.15" @@ -1373,6 +1416,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_qs" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0431a35568651e363364210c91983c1da5eb29404d9f0928b67d4ebcfa7d330c" +dependencies = [ + "percent-encoding", + "serde", + "thiserror", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" diff --git a/Cargo.toml b/Cargo.toml index 4196336..8eb738b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "docker-registry" +name = "orca-registry" version = "0.1.0" edition = "2021" @@ -14,7 +14,6 @@ uuid = { version = "1.3.1", features = [ "v4", "fast-rng" ] } sqlx = { version = "0.6.3", features = [ "runtime-tokio-rustls", "sqlite" ] } bytes = "1.4.0" - chrono = "0.4.23" tokio = { version = "1.21.2", features = [ "fs", "macros" ] } tokio-util = { version = "0.7.7", features = [ "io" ] } @@ -43,3 +42,10 @@ figment = { version = "0.10", features = ["toml", "env"] } figment-cliarg-provider = { git = "https://github.com/SeanOMik/figment-cliarg-provider.git", branch = "main" } wild = "2.1.0" argmap = "1.1.2" +serde_qs = "0.12.0" +axum-auth = "0.4.0" +qstring = "0.7.2" +jwt = "0.16.0" +hmac = "0.12.1" +sha2 = "0.10.6" +rand = "0.8.5" diff --git a/src/api/auth.rs b/src/api/auth.rs new file mode 100644 index 0000000..67d2deb --- /dev/null +++ b/src/api/auth.rs @@ -0,0 +1,195 @@ +use std::{sync::Arc, collections::{HashMap, BTreeMap}, time::{SystemTime, UNIX_EPOCH}}; + +use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, Request, Method, HeaderName, header}, Form}; +use axum_auth::AuthBasic; +use chrono::{DateTime, Utc}; +use qstring::QString; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use tracing::{debug, error, info, span, Level}; + +use hmac::{Hmac, Mac}; +use jwt::SignWithKey; +use sha2::Sha256; + +use rand::Rng; + +use crate::{dto::scope::Scope, app_state::AppState, query::Qs}; + +#[derive(Deserialize, Debug)] +pub struct TokenAuthRequest { + user: Option, + password: Option, + account: Option, + /// The name of the service which hosts the resource. + /// I don't think this is necessary since the auth service is embedded with the registry. + pub service: Option, + pub scope: Vec, + offline_token: Option, + client_id: Option, +} + +#[derive(Deserialize, Debug)] +pub struct AuthForm { + username: String, + password: String, +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct AuthResponse { + token: String, + expires_in: u32, + issued_at: String, +} + +fn create_jwt_token(account: String) -> anyhow::Result { + let key: Hmac = Hmac::new_from_slice(b"some-secret")?; + + let now = SystemTime::now(); + let now_secs = now + .duration_since(UNIX_EPOCH)? + .as_secs(); + + // Construct the claims for the token + let mut claims = BTreeMap::new(); + claims.insert("issuer", "orca-registry__DEV"); + claims.insert("subject", &account); + //claims.insert("audience", auth.service); + + let notbefore = format!("{}", now_secs - 10); + let issuedat = format!("{}", now_secs); + let expiration = format!("{}", now_secs + 20); + claims.insert("notbefore", ¬before); + claims.insert("issuedat", &issuedat); + claims.insert("expiration", &expiration); // TODO: 20 seconds expiry for testing + + // Create a randomized jwtid + let mut rng = rand::thread_rng(); + let jwtid = format!("{}", rng.gen::()); + claims.insert("jwtid", &jwtid); + + Ok(claims.sign_with_key(&key)?) +} + +pub async fn auth_basic_get(basic_auth: Option, state: State>, Query(params): Query>, form: Option>) -> Response { + let mut auth = TokenAuthRequest { + user: None, + password: None, + account: None, + service: None, + scope: Vec::new(), + offline_token: None, + client_id: None, + }; + + let auth_method; + + // If BasicAuth is provided, set the fields to it + if let Some(AuthBasic((username, pass))) = basic_auth { + auth.user = Some(username.clone()); + auth.password = pass; + + // I hate having to create this span here multiple times, but its the only + // way I could think of + /* let span = span!(Level::DEBUG, "auth", username = auth.user.clone()); + let _enter = span.enter(); + debug!("Read user authentication from an AuthBasic"); */ + + auth_method = "basic-auth"; + } + // Username and password could be passed in forms + // If there was a way to also check if the Method was "POST", this is where + // we would do it. + else if let Some(Form(form)) = form { + auth.user = Some(form.username.clone()); + auth.password = Some(form.password); + + let span = span!(Level::DEBUG, "auth", username = auth.user.clone()); + let _enter = span.enter(); + debug!("Read user authentication from a Form"); + + auth_method = "form"; + } else { + info!("Auth failure! Auth was not provided in either AuthBasic or Form!"); + + // Maybe BAD_REQUEST should be returned? + return (StatusCode::UNAUTHORIZED).into_response(); + } + + // Create logging span for the rest of this request + let span = span!(Level::DEBUG, "auth", username = auth.user.clone(), auth_method); + let _enter = span.enter(); + + debug!("Parsed user auth request"); + + // Get account from query string, if its specified, ensure that its the same as the user if + // that is also specified. + if let Some(account) = params.get("account") { + if let Some(user) = &auth.user { + if account != user { + error!("`user` and `account` are not the same!!! (user: {}, account: {})", user, account); + + return (StatusCode::BAD_REQUEST).into_response(); + } + } + + auth.account = Some(account.clone()); + } + + // Get service from query string + if let Some(service) = params.get("service") { + auth.service = Some(service.clone()); + } + + // Process all the scopes + if let Some(scope) = params.get("scope") { + // TODO: Handle multiple scopes + auth.scope.push(Scope::try_from(&scope[..]).unwrap()); + } + + // Get offline token and attempt to convert it to a boolean + if let Some(offline_token) = params.get("offline_token") { + if let Ok(b) = offline_token.parse::() { + auth.offline_token = Some(b); + } + } + + if let Some(client_id) = params.get("client_id") { + auth.client_id = Some(client_id.clone()); + } + + debug!("Constructed auth request"); + + if let Some(account) = auth.account { + let now = SystemTime::now(); + let token_str = create_jwt_token(account).unwrap(); + + debug!("Created jwt token"); + + // ISO8601 time format + let now_dt: DateTime = now.into(); + let now_format = format!("{}", now_dt.format("%+")); + + // Construct the auth response + let auth_response = AuthResponse { + token: token_str.clone(), + expires_in: 20, + issued_at: now_format, + }; + + let json_str = serde_json::to_string(&auth_response).unwrap(); + + return ( + StatusCode::OK, + [ + ( header::CONTENT_TYPE, "application/json" ), + ( header::AUTHORIZATION, &format!("Bearer {}", token_str) ) + ], + json_str + ).into_response(); + } + + info!("Auth failure! Not enough information given to create auth token!"); + // If we didn't get fields required to make a token, then the client did something bad + (StatusCode::UNAUTHORIZED).into_response() +} \ No newline at end of file diff --git a/src/api/mod.rs b/src/api/mod.rs index 50ead62..82df8cd 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,17 +1,39 @@ -use axum::response::IntoResponse; -use axum::http::{StatusCode, HeaderName}; +use axum::extract::Query; +use axum::response::{IntoResponse, Response}; +use axum::http::{StatusCode, HeaderName, header}; +use tracing::debug; + +use self::auth::TokenAuthRequest; pub mod blobs; pub mod uploads; pub mod manifests; pub mod tags; pub mod catalog; +pub mod auth; /// https://docs.docker.com/registry/spec/api/#api-version-check /// full endpoint: `/v2/` -pub async fn version_check() -> impl IntoResponse { - ( +pub async fn version_check(params: Option>, body: String) -> Response { + debug!("Got body: {}", body); + + /* ( StatusCode::OK, [( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" )] - ) + ) */ + + //Www-Authenticate: Bearer realm="https://auth.docker.io/token",service="registry.docker.io",scope="repository:samalba/my-app:pull,push" + + let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");/* match params { + Some(Query(params)) => format!("Bearer realm=\"http://localhost:3000/token\",scope=\"{}\"", params.scope), + None => format!("Bearer realm=\"http://localhost:3000/token\""), + }; */ + + ( + StatusCode::UNAUTHORIZED, + [ + ( header::WWW_AUTHENTICATE, bearer ), + ( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() ) + ] + ).into_response() } \ No newline at end of file diff --git a/src/dto/mod.rs b/src/dto/mod.rs index 443e5b5..d12df3d 100644 --- a/src/dto/mod.rs +++ b/src/dto/mod.rs @@ -2,6 +2,7 @@ use chrono::{DateTime, Utc}; pub mod manifest; pub mod digest; +pub mod scope; #[derive(Debug)] pub struct Tag { diff --git a/src/dto/scope.rs b/src/dto/scope.rs new file mode 100644 index 0000000..10c338b --- /dev/null +++ b/src/dto/scope.rs @@ -0,0 +1,158 @@ +use anyhow::anyhow; +use serde::{Deserialize, de::Visitor}; + +use std::fmt; + +#[derive(Default, Debug)] +pub enum ScopeType { + #[default] + Unknown, + Repository, +} + +impl fmt::Display for ScopeType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + ScopeType::Unknown => write!(f, ""), + ScopeType::Repository => write!(f, "repository"), + } + } +} + +#[derive(Default, Debug)] +pub enum Action { + #[default] + None, + Push, + Pull, +} + +impl fmt::Display for Action { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Action::None => write!(f, ""), + Action::Push => write!(f, "push"), + Action::Pull => write!(f, "pull"), + } + } +} + +#[derive(Default, Debug)] +pub struct Scope { + scope_type: ScopeType, + path: String, + actions: Vec, +} + +impl fmt::Display for Scope { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let actions = self.actions + .iter() + .map(|a| a.to_string()) + .collect::>() + .join(","); + + write!(f, "{}:{}:{}", self.scope_type, self.path, actions) + } +} + +impl TryFrom<&str> for Scope { + type Error = anyhow::Error; + + fn try_from(val: &str) -> Result { + let splits: Vec<&str> = val.split(":").collect(); + if splits.len() == 3 { + let scope_type = match splits[0] { + "repository" => ScopeType::Repository, + _ => { + return Err(anyhow!("Invalid scope type: `{}`!", splits[0])); + //return Err(serde::de::Error::custom(format!("Invalid scope type: `{}`!", splits[0]))); + } + }; + + let path = splits[1]; + + let actions: Result, anyhow::Error> = splits[2] + .split(",") + .map(|a| match a { + "pull" => Ok(Action::Pull), + "push" => Ok(Action::Push), + _ => Err(anyhow!("Invalid action: `{}`!", a)), //Err(serde::de::Error::custom(format!("Invalid action: `{}`!", a))), + }).collect(); + let actions = actions?; + + Ok(Scope { + scope_type, + path: String::from(path), + actions + }) + } else { + Err(anyhow!("Malformed scope string!")) + //Err(serde::de::Error::custom("Malformed scope string!")) + } + } +} + +pub struct ScopeVisitor { + +} + +impl<'de> Visitor<'de> for ScopeVisitor { + type Value = Scope; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a Scope in the format of `repository:samalba/my-app:pull,push`.") + } + + fn visit_str(self, val: &str) -> Result + where + E: serde::de::Error { + println!("Start of visit_str!"); + + let res = match Scope::try_from(val) { + Ok(val) => Ok(val), + Err(e) => Err(serde::de::Error::custom(format!("{}", e))) + }; + + res + + + + /* let splits: Vec<&str> = val.split(":").collect(); + if splits.len() == 3 { + let scope_type = match splits[0] { + "repository" => ScopeType::Repository, + _ => { + return Err(serde::de::Error::custom(format!("Invalid scope type: `{}`!", splits[0]))); + } + }; + + let path = splits[1]; + + let actions: Result, E> = splits[2] + .split(",") + .map(|a| match a { + "pull" => Ok(Action::Pull), + "push" => Ok(Action::Push), + _ => Err(serde::de::Error::custom(format!("Invalid action: `{}`!", a))), + }).collect(); + let actions = actions?; + + Ok(Scope { + scope_type, + path: String::from(path), + actions + }) + } else { + Err(serde::de::Error::custom("Malformed scope string!")) + } */ + } +} + +impl<'de> Deserialize<'de> for Scope { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de> { + deserializer.deserialize_str(ScopeVisitor {}) + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index f6389e7..ae49dd4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,14 +5,15 @@ mod dto; mod storage; mod byte_stream; mod config; +mod query; use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; -use axum::http::Request; +use axum::http::{Request, StatusCode, header, HeaderName}; use axum::middleware::Next; -use axum::response::Response; +use axum::response::{Response, IntoResponse}; use axum::{Router, routing}; use axum::ServiceExt; use tower_layer::Layer; @@ -54,6 +55,20 @@ async fn change_request_paths(mut request: Request, next: Next) -> Resp next.run(request).await } +pub async fn auth_failure() -> impl IntoResponse { + let bearer = format!("Bearer realm=\"http://localhost:3000/token\""); + + ( + StatusCode::UNAUTHORIZED, + + [ + ( header::WWW_AUTHENTICATE, bearer ), + ( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() ) + ] + ).into_response() + //StatusCode::UNAUTHORIZED +} + #[tokio::main] async fn main() -> std::io::Result<()> { let pool = SqlitePoolOptions::new() @@ -76,9 +91,12 @@ async fn main() -> std::io::Result<()> { let path_middleware = axum::middleware::from_fn(change_request_paths); let app = Router::new() + .route("/auth", routing::get(api::auth::auth_basic_get) + .post(api::auth::auth_basic_get)) + .fallback(auth_failure) .nest("/v2", Router::new() .route("/", routing::get(api::version_check)) - .route("/_catalog", routing::get(api::catalog::list_repositories)) + /* .route("/_catalog", routing::get(api::catalog::list_repositories)) .route("/:name/tags/list", routing::get(api::tags::list_tags)) .nest("/:name/blobs", Router::new() .route("/:digest", routing::get(api::blobs::pull_digest_get) @@ -96,7 +114,7 @@ async fn main() -> std::io::Result<()> { .route("/:name/manifests/:reference", routing::get(api::manifests::pull_manifest_get) .put(api::manifests::upload_manifest_put) .head(api::manifests::manifest_exists_head) - .delete(api::manifests::delete_manifest)) + .delete(api::manifests::delete_manifest)) */ ) .with_state(state) .layer(TraceLayer::new_for_http()); diff --git a/src/query.rs b/src/query.rs new file mode 100644 index 0000000..c56e2ff --- /dev/null +++ b/src/query.rs @@ -0,0 +1,33 @@ +use std::ops::Deref; + +use axum::extract::FromRequest; +use axum::http::{self, Request}; + +use async_trait::async_trait; +use serde::de::DeserializeOwned; + +pub struct Qs(pub T); + +impl Deref for Qs { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[async_trait] +impl FromRequest for Qs +where + // these bounds are required by `async_trait` + B: Send + 'static, + S: Send + Sync, + T: DeserializeOwned +{ + type Rejection = http::StatusCode; + + async fn from_request(req: Request, _state: &S) -> Result { + let query = req.uri().query().unwrap(); + Ok(Self(serde_qs::from_str(query).unwrap())) + } +} \ No newline at end of file