Generate token for auth, kind of hacky currently

This commit is contained in:
SeanOMik 2023-04-28 12:28:37 -04:00
parent 89757973a9
commit 4c768753ab
Signed by: SeanOMik
GPG Key ID: 568F326C7EB33ACB
8 changed files with 532 additions and 45 deletions

122
Cargo.lock generated
View File

@ -143,6 +143,18 @@ dependencies = [
"tower-service",
]
[[package]]
name = "axum-auth"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "620b37645b77baab8160f93421568d7b3dd25da0a160fab38eb1c4ef611f6d98"
dependencies = [
"async-trait",
"axum-core",
"base64 0.13.1",
"http",
]
[[package]]
name = "axum-core"
version = "0.3.4"
@ -440,40 +452,6 @@ dependencies = [
"subtle",
]
[[package]]
name = "docker-registry"
version = "0.1.0"
dependencies = [
"anyhow",
"argmap",
"async-stream",
"async-trait",
"axum",
"axum-macros",
"bytes",
"chrono",
"clap",
"figment",
"figment-cliarg-provider",
"futures",
"jws",
"pin-project-lite",
"regex",
"serde",
"serde_json",
"sha256",
"sqlx",
"tokio",
"tokio-util",
"tower-http",
"tower-layer",
"tracing",
"tracing-log",
"tracing-subscriber",
"uuid",
"wild",
]
[[package]]
name = "dotenvy"
version = "0.15.7"
@ -882,6 +860,21 @@ dependencies = [
"sha2 0.10.6",
]
[[package]]
name = "jwt"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6204285f77fe7d9784db3fdc449ecce1a0114927a51d5a41c4c7a292011c015f"
dependencies = [
"base64 0.13.1",
"crypto-common",
"digest 0.10.5",
"hmac",
"serde",
"serde_json",
"sha2 0.10.6",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
@ -1030,6 +1023,47 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "orca-registry"
version = "0.1.0"
dependencies = [
"anyhow",
"argmap",
"async-stream",
"async-trait",
"axum",
"axum-auth",
"axum-macros",
"bytes",
"chrono",
"clap",
"figment",
"figment-cliarg-provider",
"futures",
"hmac",
"jws",
"jwt",
"pin-project-lite",
"qstring",
"rand",
"regex",
"serde",
"serde_json",
"serde_qs",
"sha2 0.10.6",
"sha256",
"sqlx",
"tokio",
"tokio-util",
"tower-http",
"tower-layer",
"tracing",
"tracing-log",
"tracing-subscriber",
"uuid",
"wild",
]
[[package]]
name = "os_str_bytes"
version = "6.4.0"
@ -1192,6 +1226,15 @@ dependencies = [
"yansi",
]
[[package]]
name = "qstring"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d464fae65fff2680baf48019211ce37aaec0c78e9264c84a3e484717f965104e"
dependencies = [
"percent-encoding",
]
[[package]]
name = "quote"
version = "0.3.15"
@ -1373,6 +1416,17 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_qs"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0431a35568651e363364210c91983c1da5eb29404d9f0928b67d4ebcfa7d330c"
dependencies = [
"percent-encoding",
"serde",
"thiserror",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"

View File

@ -1,5 +1,5 @@
[package]
name = "docker-registry"
name = "orca-registry"
version = "0.1.0"
edition = "2021"
@ -14,7 +14,6 @@ uuid = { version = "1.3.1", features = [ "v4", "fast-rng" ] }
sqlx = { version = "0.6.3", features = [ "runtime-tokio-rustls", "sqlite" ] }
bytes = "1.4.0"
chrono = "0.4.23"
tokio = { version = "1.21.2", features = [ "fs", "macros" ] }
tokio-util = { version = "0.7.7", features = [ "io" ] }
@ -43,3 +42,10 @@ figment = { version = "0.10", features = ["toml", "env"] }
figment-cliarg-provider = { git = "https://github.com/SeanOMik/figment-cliarg-provider.git", branch = "main" }
wild = "2.1.0"
argmap = "1.1.2"
serde_qs = "0.12.0"
axum-auth = "0.4.0"
qstring = "0.7.2"
jwt = "0.16.0"
hmac = "0.12.1"
sha2 = "0.10.6"
rand = "0.8.5"

195
src/api/auth.rs Normal file
View File

@ -0,0 +1,195 @@
use std::{sync::Arc, collections::{HashMap, BTreeMap}, time::{SystemTime, UNIX_EPOCH}};
use axum::{extract::{Query, State}, response::{IntoResponse, Response}, http::{StatusCode, Request, Method, HeaderName, header}, Form};
use axum_auth::AuthBasic;
use chrono::{DateTime, Utc};
use qstring::QString;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{debug, error, info, span, Level};
use hmac::{Hmac, Mac};
use jwt::SignWithKey;
use sha2::Sha256;
use rand::Rng;
use crate::{dto::scope::Scope, app_state::AppState, query::Qs};
#[derive(Deserialize, Debug)]
pub struct TokenAuthRequest {
user: Option<String>,
password: Option<String>,
account: Option<String>,
/// The name of the service which hosts the resource.
/// I don't think this is necessary since the auth service is embedded with the registry.
pub service: Option<String>,
pub scope: Vec<Scope>,
offline_token: Option<bool>,
client_id: Option<String>,
}
#[derive(Deserialize, Debug)]
pub struct AuthForm {
username: String,
password: String,
}
#[derive(Deserialize, Serialize, Debug)]
pub struct AuthResponse {
token: String,
expires_in: u32,
issued_at: String,
}
fn create_jwt_token(account: String) -> anyhow::Result<String> {
let key: Hmac<Sha256> = Hmac::new_from_slice(b"some-secret")?;
let now = SystemTime::now();
let now_secs = now
.duration_since(UNIX_EPOCH)?
.as_secs();
// Construct the claims for the token
let mut claims = BTreeMap::new();
claims.insert("issuer", "orca-registry__DEV");
claims.insert("subject", &account);
//claims.insert("audience", auth.service);
let notbefore = format!("{}", now_secs - 10);
let issuedat = format!("{}", now_secs);
let expiration = format!("{}", now_secs + 20);
claims.insert("notbefore", &notbefore);
claims.insert("issuedat", &issuedat);
claims.insert("expiration", &expiration); // TODO: 20 seconds expiry for testing
// Create a randomized jwtid
let mut rng = rand::thread_rng();
let jwtid = format!("{}", rng.gen::<u64>());
claims.insert("jwtid", &jwtid);
Ok(claims.sign_with_key(&key)?)
}
pub async fn auth_basic_get(basic_auth: Option<AuthBasic>, state: State<Arc<AppState>>, Query(params): Query<HashMap<String, String>>, form: Option<Form<AuthForm>>) -> Response {
let mut auth = TokenAuthRequest {
user: None,
password: None,
account: None,
service: None,
scope: Vec::new(),
offline_token: None,
client_id: None,
};
let auth_method;
// If BasicAuth is provided, set the fields to it
if let Some(AuthBasic((username, pass))) = basic_auth {
auth.user = Some(username.clone());
auth.password = pass;
// I hate having to create this span here multiple times, but its the only
// way I could think of
/* let span = span!(Level::DEBUG, "auth", username = auth.user.clone());
let _enter = span.enter();
debug!("Read user authentication from an AuthBasic"); */
auth_method = "basic-auth";
}
// Username and password could be passed in forms
// If there was a way to also check if the Method was "POST", this is where
// we would do it.
else if let Some(Form(form)) = form {
auth.user = Some(form.username.clone());
auth.password = Some(form.password);
let span = span!(Level::DEBUG, "auth", username = auth.user.clone());
let _enter = span.enter();
debug!("Read user authentication from a Form");
auth_method = "form";
} else {
info!("Auth failure! Auth was not provided in either AuthBasic or Form!");
// Maybe BAD_REQUEST should be returned?
return (StatusCode::UNAUTHORIZED).into_response();
}
// Create logging span for the rest of this request
let span = span!(Level::DEBUG, "auth", username = auth.user.clone(), auth_method);
let _enter = span.enter();
debug!("Parsed user auth request");
// Get account from query string, if its specified, ensure that its the same as the user if
// that is also specified.
if let Some(account) = params.get("account") {
if let Some(user) = &auth.user {
if account != user {
error!("`user` and `account` are not the same!!! (user: {}, account: {})", user, account);
return (StatusCode::BAD_REQUEST).into_response();
}
}
auth.account = Some(account.clone());
}
// Get service from query string
if let Some(service) = params.get("service") {
auth.service = Some(service.clone());
}
// Process all the scopes
if let Some(scope) = params.get("scope") {
// TODO: Handle multiple scopes
auth.scope.push(Scope::try_from(&scope[..]).unwrap());
}
// Get offline token and attempt to convert it to a boolean
if let Some(offline_token) = params.get("offline_token") {
if let Ok(b) = offline_token.parse::<bool>() {
auth.offline_token = Some(b);
}
}
if let Some(client_id) = params.get("client_id") {
auth.client_id = Some(client_id.clone());
}
debug!("Constructed auth request");
if let Some(account) = auth.account {
let now = SystemTime::now();
let token_str = create_jwt_token(account).unwrap();
debug!("Created jwt token");
// ISO8601 time format
let now_dt: DateTime<Utc> = now.into();
let now_format = format!("{}", now_dt.format("%+"));
// Construct the auth response
let auth_response = AuthResponse {
token: token_str.clone(),
expires_in: 20,
issued_at: now_format,
};
let json_str = serde_json::to_string(&auth_response).unwrap();
return (
StatusCode::OK,
[
( header::CONTENT_TYPE, "application/json" ),
( header::AUTHORIZATION, &format!("Bearer {}", token_str) )
],
json_str
).into_response();
}
info!("Auth failure! Not enough information given to create auth token!");
// If we didn't get fields required to make a token, then the client did something bad
(StatusCode::UNAUTHORIZED).into_response()
}

View File

@ -1,17 +1,39 @@
use axum::response::IntoResponse;
use axum::http::{StatusCode, HeaderName};
use axum::extract::Query;
use axum::response::{IntoResponse, Response};
use axum::http::{StatusCode, HeaderName, header};
use tracing::debug;
use self::auth::TokenAuthRequest;
pub mod blobs;
pub mod uploads;
pub mod manifests;
pub mod tags;
pub mod catalog;
pub mod auth;
/// https://docs.docker.com/registry/spec/api/#api-version-check
/// full endpoint: `/v2/`
pub async fn version_check() -> impl IntoResponse {
(
pub async fn version_check(params: Option<Query<TokenAuthRequest>>, body: String) -> Response {
debug!("Got body: {}", body);
/* (
StatusCode::OK,
[( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" )]
)
) */
//Www-Authenticate: Bearer realm="https://auth.docker.io/token",service="registry.docker.io",scope="repository:samalba/my-app:pull,push"
let bearer = format!("Bearer realm=\"http://localhost:3000/auth\"");/* match params {
Some(Query(params)) => format!("Bearer realm=\"http://localhost:3000/token\",scope=\"{}\"", params.scope),
None => format!("Bearer realm=\"http://localhost:3000/token\""),
}; */
(
StatusCode::UNAUTHORIZED,
[
( header::WWW_AUTHENTICATE, bearer ),
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
]
).into_response()
}

View File

@ -2,6 +2,7 @@ use chrono::{DateTime, Utc};
pub mod manifest;
pub mod digest;
pub mod scope;
#[derive(Debug)]
pub struct Tag {

158
src/dto/scope.rs Normal file
View File

@ -0,0 +1,158 @@
use anyhow::anyhow;
use serde::{Deserialize, de::Visitor};
use std::fmt;
#[derive(Default, Debug)]
pub enum ScopeType {
#[default]
Unknown,
Repository,
}
impl fmt::Display for ScopeType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
ScopeType::Unknown => write!(f, ""),
ScopeType::Repository => write!(f, "repository"),
}
}
}
#[derive(Default, Debug)]
pub enum Action {
#[default]
None,
Push,
Pull,
}
impl fmt::Display for Action {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Action::None => write!(f, ""),
Action::Push => write!(f, "push"),
Action::Pull => write!(f, "pull"),
}
}
}
#[derive(Default, Debug)]
pub struct Scope {
scope_type: ScopeType,
path: String,
actions: Vec<Action>,
}
impl fmt::Display for Scope {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let actions = self.actions
.iter()
.map(|a| a.to_string())
.collect::<Vec<String>>()
.join(",");
write!(f, "{}:{}:{}", self.scope_type, self.path, actions)
}
}
impl TryFrom<&str> for Scope {
type Error = anyhow::Error;
fn try_from(val: &str) -> Result<Self, Self::Error> {
let splits: Vec<&str> = val.split(":").collect();
if splits.len() == 3 {
let scope_type = match splits[0] {
"repository" => ScopeType::Repository,
_ => {
return Err(anyhow!("Invalid scope type: `{}`!", splits[0]));
//return Err(serde::de::Error::custom(format!("Invalid scope type: `{}`!", splits[0])));
}
};
let path = splits[1];
let actions: Result<Vec<Action>, anyhow::Error> = splits[2]
.split(",")
.map(|a| match a {
"pull" => Ok(Action::Pull),
"push" => Ok(Action::Push),
_ => Err(anyhow!("Invalid action: `{}`!", a)), //Err(serde::de::Error::custom(format!("Invalid action: `{}`!", a))),
}).collect();
let actions = actions?;
Ok(Scope {
scope_type,
path: String::from(path),
actions
})
} else {
Err(anyhow!("Malformed scope string!"))
//Err(serde::de::Error::custom("Malformed scope string!"))
}
}
}
pub struct ScopeVisitor {
}
impl<'de> Visitor<'de> for ScopeVisitor {
type Value = Scope;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a Scope in the format of `repository:samalba/my-app:pull,push`.")
}
fn visit_str<E>(self, val: &str) -> Result<Self::Value, E>
where
E: serde::de::Error {
println!("Start of visit_str!");
let res = match Scope::try_from(val) {
Ok(val) => Ok(val),
Err(e) => Err(serde::de::Error::custom(format!("{}", e)))
};
res
/* let splits: Vec<&str> = val.split(":").collect();
if splits.len() == 3 {
let scope_type = match splits[0] {
"repository" => ScopeType::Repository,
_ => {
return Err(serde::de::Error::custom(format!("Invalid scope type: `{}`!", splits[0])));
}
};
let path = splits[1];
let actions: Result<Vec<Action>, E> = splits[2]
.split(",")
.map(|a| match a {
"pull" => Ok(Action::Pull),
"push" => Ok(Action::Push),
_ => Err(serde::de::Error::custom(format!("Invalid action: `{}`!", a))),
}).collect();
let actions = actions?;
Ok(Scope {
scope_type,
path: String::from(path),
actions
})
} else {
Err(serde::de::Error::custom("Malformed scope string!"))
} */
}
}
impl<'de> Deserialize<'de> for Scope {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de> {
deserializer.deserialize_str(ScopeVisitor {})
}
}

View File

@ -5,14 +5,15 @@ mod dto;
mod storage;
mod byte_stream;
mod config;
mod query;
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;
use axum::http::Request;
use axum::http::{Request, StatusCode, header, HeaderName};
use axum::middleware::Next;
use axum::response::Response;
use axum::response::{Response, IntoResponse};
use axum::{Router, routing};
use axum::ServiceExt;
use tower_layer::Layer;
@ -54,6 +55,20 @@ async fn change_request_paths<B>(mut request: Request<B>, next: Next<B>) -> Resp
next.run(request).await
}
pub async fn auth_failure() -> impl IntoResponse {
let bearer = format!("Bearer realm=\"http://localhost:3000/token\"");
(
StatusCode::UNAUTHORIZED,
[
( header::WWW_AUTHENTICATE, bearer ),
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
]
).into_response()
//StatusCode::UNAUTHORIZED
}
#[tokio::main]
async fn main() -> std::io::Result<()> {
let pool = SqlitePoolOptions::new()
@ -76,9 +91,12 @@ async fn main() -> std::io::Result<()> {
let path_middleware = axum::middleware::from_fn(change_request_paths);
let app = Router::new()
.route("/auth", routing::get(api::auth::auth_basic_get)
.post(api::auth::auth_basic_get))
.fallback(auth_failure)
.nest("/v2", Router::new()
.route("/", routing::get(api::version_check))
.route("/_catalog", routing::get(api::catalog::list_repositories))
/* .route("/_catalog", routing::get(api::catalog::list_repositories))
.route("/:name/tags/list", routing::get(api::tags::list_tags))
.nest("/:name/blobs", Router::new()
.route("/:digest", routing::get(api::blobs::pull_digest_get)
@ -96,7 +114,7 @@ async fn main() -> std::io::Result<()> {
.route("/:name/manifests/:reference", routing::get(api::manifests::pull_manifest_get)
.put(api::manifests::upload_manifest_put)
.head(api::manifests::manifest_exists_head)
.delete(api::manifests::delete_manifest))
.delete(api::manifests::delete_manifest)) */
)
.with_state(state)
.layer(TraceLayer::new_for_http());

33
src/query.rs Normal file
View File

@ -0,0 +1,33 @@
use std::ops::Deref;
use axum::extract::FromRequest;
use axum::http::{self, Request};
use async_trait::async_trait;
use serde::de::DeserializeOwned;
pub struct Qs<T>(pub T);
impl<T> Deref for Qs<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
#[async_trait]
impl<S, B, T> FromRequest<S, B> for Qs<T>
where
// these bounds are required by `async_trait`
B: Send + 'static,
S: Send + Sync,
T: DeserializeOwned
{
type Rejection = http::StatusCode;
async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
let query = req.uri().query().unwrap();
Ok(Self(serde_qs::from_str(query).unwrap()))
}
}