Use jwt secret by verifying it in the auth middleware

This commit is contained in:
SeanOMik 2023-07-14 21:25:38 -04:00
parent 7cc19bc1cd
commit b46a7a844b
Signed by: SeanOMik
GPG Key ID: 568F326C7EB33ACB
10 changed files with 120 additions and 256 deletions

View File

@ -53,4 +53,4 @@ rand = "0.8.5"
bcrypt = "0.14.0"
bitflags = "2.2.1"
ldap3 = "0.11.1"
lazy_static = "1.4.0"
lazy_static = "1.4.0"

View File

@ -1,5 +1,5 @@
use std::{
collections::{BTreeMap, HashMap},
collections::HashMap,
sync::Arc,
time::SystemTime,
};
@ -11,7 +11,7 @@ use axum::{
Form,
};
use axum_auth::AuthBasic;
use chrono::{DateTime, Days, Duration, Utc};
use chrono::{DateTime, Days, Utc};
use serde::{Deserialize, Serialize};
use tracing::{debug, error, info, span, Level};
@ -33,7 +33,7 @@ use crate::{
use crate::auth::auth_challenge_response;
#[derive(Deserialize, Debug)]
#[derive(Debug)]
pub struct TokenAuthRequest {
user: Option<String>,
password: Option<String>,
@ -157,13 +157,13 @@ pub async fn auth_basic_get(
ScopeType::Repository => {
// check repository visibility
let database = &state.database;
match database.get_repository_visibility(&scope.path).await {
match database.get_repository_visibility(&scope.name).await {
Ok(Some(RepositoryVisibility::Public)) => res.push(Ok(true)),
Ok(_) => res.push(Ok(false)),
Err(e) => {
error!(
"Failure to check repository visibility for {}! Err: {}",
scope.path, e
scope.name, e
);
res.push(Err(StatusCode::INTERNAL_SERVER_ERROR));
@ -201,8 +201,8 @@ pub async fn auth_basic_get(
issued_at: now_format,
};
let json_str =
serde_json::to_string(&auth_response).map_err(|_| StatusCode::BAD_REQUEST)?;
let json_str = serde_json::to_string(&auth_response)
.map_err(|_| StatusCode::BAD_REQUEST)?;
debug!("Created anonymous token for public scopes!");
@ -213,8 +213,7 @@ pub async fn auth_basic_get(
(header::AUTHORIZATION, &format!("Bearer {}", token_str)),
],
json_str,
)
.into_response());
).into_response());
} else {
info!("Auth failure! Auth was not provided in either AuthBasic or Form!");
@ -279,10 +278,15 @@ pub async fn auth_basic_get(
{
debug!("Authentication failed, incorrect password!");
// TODO: Multiple scopes
let scope = auth.scope
.first()
.and_then(|s| Some(s.clone()));
// TODO: Dont unwrap, find a way to return multiple scopes
return Ok(auth_challenge_response(
&state.config,
Some(auth.scope.first().unwrap().clone()),
scope,
));
}
drop(auth_driver);

View File

@ -6,13 +6,12 @@ use axum::http::{StatusCode, HeaderName, header};
use tracing::log::warn;
use tracing::{debug, info};
use crate::auth::access_denied_response;
use crate::app_state::AppState;
use crate::database::Database;
use crate::dto::RepositoryVisibility;
use crate::dto::digest::Digest;
use crate::dto::manifest::Manifest;
use crate::dto::user::{UserAuth, Permission};
use crate::dto::user::UserAuth;
use crate::error::AppError;
pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, auth: UserAuth, body: String) -> Result<Response, AppError> {
@ -20,10 +19,13 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>
let calculated_hash = sha256::digest(body.clone());
let calculated_digest = format!("sha256:{}", calculated_hash);
// anonymous users wouldn't be able to get to this point, so it should be safe to unwrap.
let user = auth.user.unwrap();
let database = &state.database;
// Create the image repository and save the image manifest. This repository will be private by default
database.save_repository(&name, RepositoryVisibility::Private, Some(auth.user.email), None).await?;
database.save_repository(&name, RepositoryVisibility::Private, Some(user.email), None).await?;
database.save_manifest(&name, &calculated_digest, &body).await?;
// If the reference is not a digest, then it must be a tag name.
@ -57,20 +59,7 @@ pub async fn upload_manifest_put(Path((name, reference)): Path<(String, String)>
}
}
pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, auth: Option<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public
if let Some(auth) = auth {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return Ok(access_denied_response(&state.config));
}
} else {
let database = &state.database;
if database.get_repository_visibility(&name).await? != Some(RepositoryVisibility::Public) {
return Ok(access_denied_response(&state.config));
}
}
pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let database = &state.database;
let digest = match Digest::is_digest(&reference) {
true => reference.clone(),
@ -106,21 +95,8 @@ pub async fn pull_manifest_get(Path((name, reference)): Path<(String, String)>,
).into_response())
}
pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, auth: Option<UserAuth>) -> Result<Response, AppError> {
// Check if the user has permission to pull, or that the repository is public
if let Some(auth) = auth {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PULL, Some(RepositoryVisibility::Public)).await? {
return Ok(access_denied_response(&state.config));
}
drop(auth_driver);
} else {
let database = &state.database;
if database.get_repository_visibility(&name).await? != Some(RepositoryVisibility::Public) {
return Ok(access_denied_response(&state.config));
}
}
pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
debug!("start of head");
// Get the digest from the reference path.
let database = &state.database;
let digest = match Digest::is_digest(&reference) {
@ -133,6 +109,7 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
}
}
};
debug!("found digest: {}", digest);
let manifest_content = database.get_manifest(&name, &digest).await?;
if manifest_content.is_none() {
@ -142,6 +119,8 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
}
let manifest_content = manifest_content.unwrap();
debug!("got content");
Ok((
StatusCode::OK,
[
@ -154,13 +133,7 @@ pub async fn manifest_exists_head(Path((name, reference)): Path<(String, String)
).into_response())
}
pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>, auth: UserAuth) -> Result<Response, AppError> {
let mut auth_driver = state.auth_checker.lock().await;
if !auth_driver.user_has_permission(auth.user.username, name.clone(), Permission::PUSH, None).await? {
return Ok(access_denied_response(&state.config));
}
drop(auth_driver);
pub async fn delete_manifest(Path((name, reference)): Path<(String, String)>, state: State<Arc<AppState>>) -> Result<Response, AppError> {
let database = &state.database;
let digest = match Digest::is_digest(&reference) {
true => {

View File

@ -17,7 +17,7 @@ pub mod auth;
/// full endpoint: `/v2/`
pub async fn version_check(_state: State<Arc<AppState>>) -> Response {
(
StatusCode::UNAUTHORIZED,
StatusCode::OK,
[
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0" ),
]

View File

@ -3,7 +3,7 @@ use ldap3::{LdapConnAsync, Ldap, Scope, SearchEntry};
use sqlx::{Pool, Sqlite};
use tracing::{debug, warn};
use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource, RegistryUserType, self}, RepositoryVisibility}, database::Database};
use crate::{config::LdapConnectionConfig, dto::{user::{Permission, LoginSource, RegistryUserType}, RepositoryVisibility}, database::Database};
use super::AuthDriver;

View File

@ -1,6 +1,6 @@
pub mod ldap_driver;
use std::{ops::Deref, sync::Arc};
use std::sync::Arc;
use axum::{extract::State, http::{StatusCode, HeaderMap, header, HeaderName, Request, Method}, middleware::Next, response::{Response, IntoResponse}};
@ -82,59 +82,8 @@ where
Ok(false)
}
#[derive(Clone)]
pub struct AuthToken(pub String);
impl Deref for AuthToken {
type Target = String;
fn deref(&self) -> &Self::Target {
&self.0
}
}
type Rejection = (StatusCode, HeaderMap);
pub async fn require_auth<B>(State(state): State<Arc<AppState>>, mut request: Request<B>, next: Next<B>) -> Result<Response, Rejection> {
let bearer = format!("Bearer realm=\"{}/auth\"", state.config.url());
let mut failure_headers = HeaderMap::new();
failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap());
failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap());
let auth = String::from(
request.headers().get(header::AUTHORIZATION)
.ok_or((StatusCode::UNAUTHORIZED, failure_headers.clone()))?
.to_str()
.map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))?
);
let token = match auth.split_once(' ') {
Some((auth, token)) if auth == "Bearer" => token,
// This line would allow empty tokens
//_ if auth == "Bearer" => Ok(AuthToken(None)),
_ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ),
};
// If the token is not valid, return an unauthorized response
let database = &state.database;
if let Ok(Some(user)) = database.verify_user_token(token.to_string()).await {
debug!("Authenticated user through middleware: {}", user.user.username);
request.extensions_mut().insert(user);
Ok(next.run(request).await)
} else {
let bearer = format!("Bearer realm=\"{}/auth\"", state.config.url());
Ok((
StatusCode::UNAUTHORIZED,
[
( header::WWW_AUTHENTICATE, bearer ),
( HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".to_string() )
]
).into_response())
}
}
/// Creates a response with an Unauthorized (401) status code.
/// The www-authenticate header is set to notify the client of where to authorize with.
#[inline(always)]
@ -173,9 +122,17 @@ pub async fn check_auth<B>(State(state): State<Arc<AppState>>, auth: Option<User
// note: url is relative to /v2
let url = request.uri().to_string();
if url == "/" && auth.is_none() {
debug!("Responding to /v2/ with an auth challenge");
return Ok(auth_challenge_response(config, None));
if url == "/" {
// if auth is none, then the client needs to authenticate
if auth.is_none() {
debug!("Responding to /v2/ with an auth challenge");
return Ok(auth_challenge_response(config, None));
}
debug!("user is authed");
// the client is authenticating right now
return Ok(next.run(request).await);
}
let url_split: Vec<&str> = url.split("/").skip(1).collect();
@ -216,14 +173,28 @@ pub async fn check_auth<B>(State(state): State<Arc<AppState>>, auth: Option<User
_ => None,
};
match auth_checker.user_has_permission(auth.user.email.clone(), target_name.clone(), permission, vis).await {
Ok(false) => return Ok(auth_challenge_response(config, Some(scope))),
Ok(true) => { },
Err(e) => {
error!("Error when checking user permissions! {}", e);
if let Some(user) = &auth.user {
match auth_checker.user_has_permission(user.email.clone(), target_name.clone(), permission, vis).await {
Ok(false) => return Ok(auth_challenge_response(config, Some(scope))),
Ok(true) => { },
Err(e) => {
error!("Error when checking user permissions! {}", e);
return Err((StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()));
},
return Err((StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()));
},
}
} else {
// anonymous users can ONLY pull from public repos
if permission != Permission::PULL {
return Ok(access_denied_response(config));
}
// ensure the repo is public
let database = &state.database;
if let Some(RepositoryVisibility::Private) = database.get_repository_visibility(&target_name).await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()))? {
return Ok(access_denied_response(config));
}
}
}
} else {

View File

@ -1,4 +1,3 @@
use anyhow::anyhow;
use figment::{Figment, providers::{Env, Toml, Format}};
use figment_cliarg_provider::FigmentCliArgsProvider;
use serde::{Deserialize, Deserializer};

View File

@ -1,13 +1,11 @@
use async_trait::async_trait;
use hmac::{Hmac, digest::KeyInit};
use rand::{Rng, distributions::Alphanumeric};
use sha2::Sha256;
use sqlx::{Sqlite, Pool, sqlite::SqliteError};
use sqlx::{Sqlite, Pool};
use tracing::{debug, warn};
use chrono::{DateTime, Utc, NaiveDateTime, TimeZone};
use chrono::{DateTime, Utc, NaiveDateTime};
use crate::dto::{Tag, user::{User, RepositoryPermissions, RegistryUserType, Permission, UserAuth, TokenInfo, LoginSource}, RepositoryVisibility};
use crate::dto::{Tag, user::{User, RepositoryPermissions, RegistryUserType, Permission, UserAuth, LoginSource}, RepositoryVisibility};
#[async_trait]
pub trait Database {
@ -65,6 +63,7 @@ pub trait Database {
async fn get_user_repo_permissions(&self, email: String, repository: String) -> anyhow::Result<Option<RepositoryPermissions>>;
async fn get_user_registry_usertype(&self, email: String) -> anyhow::Result<Option<RegistryUserType>>;
async fn store_user_token(&self, token: String, email: String, expiry: DateTime<Utc>, created_at: DateTime<Utc>) -> anyhow::Result<()>;
#[deprecated = "Tokens are now verified using a secret"]
async fn verify_user_token(&self, token: String) -> anyhow::Result<Option<UserAuth>>;
}
@ -90,6 +89,10 @@ impl Database for Pool<Sqlite> {
}
};
sqlx::query(include_str!("schemas/schema.sql"))
.execute(self).await?;
debug!("Created database schema");
if row.is_none() || row.unwrap().0 == 0 {
let jwt_sec: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
@ -99,9 +102,9 @@ impl Database for Pool<Sqlite> {
// create schema
// TODO: Check if needed
sqlx::query(include_str!("schemas/schema.sql"))
/* sqlx::query(include_str!("schemas/schema.sql"))
.execute(self).await?;
debug!("Created database schema");
debug!("Created database schema"); */
sqlx::query("INSERT INTO orca(orca_version, schema_version, jwt_secret) VALUES (?, ?, ?)")
.bind(orca_version)
@ -415,6 +418,7 @@ impl Database for Pool<Sqlite> {
}
async fn get_user(&self, email: String) -> anyhow::Result<Option<User>> {
debug!("getting user");
let email = email.to_lowercase();
let row: (String, u32) = match sqlx::query_as("SELECT username, login_source FROM users WHERE email = ?")
.bind(email.clone())
@ -559,50 +563,7 @@ impl Database for Pool<Sqlite> {
Ok(())
}
async fn verify_user_token(&self, token: String) -> anyhow::Result<Option<UserAuth>> {
let token_row: (String, i64, i64,) = match sqlx::query_as("SELECT email, expiry, created_at FROM user_tokens WHERE token = ?")
.bind(token.clone())
.fetch_one(self).await {
Ok(row) => row,
Err(e) => match e {
sqlx::Error::RowNotFound => {
return Ok(None)
},
_ => {
return Err(anyhow::Error::new(e));
}
}
};
let (email, expiry, created_at) = (token_row.0, token_row.1, token_row.2);
let user_row: (String, u32) = match sqlx::query_as("SELECT username, login_source FROM users WHERE email = ?")
.bind(email.clone())
.fetch_one(self).await {
Ok(row) => row,
Err(e) => match e {
sqlx::Error::RowNotFound => {
return Ok(None)
},
_ => {
return Err(anyhow::Error::new(e));
}
}
};
/* let user_row: (String, u32) = sqlx::query_as("SELECT email, login_source FROM users WHERE email = ?")
.bind(email.clone())
.fetch_one(self).await?; */
let (expiry, created_at) = (Utc.timestamp_millis_opt(expiry).single(), Utc.timestamp_millis_opt(created_at).single());
if let (Some(expiry), Some(created_at)) = (expiry, created_at) {
let user = User::new(user_row.0, email, LoginSource::try_from(user_row.1)?);
let token = TokenInfo::new(token, expiry, created_at);
let auth = UserAuth::new(user, token);
Ok(Some(auth))
} else {
Ok(None)
}
async fn verify_user_token(&self, _token: String) -> anyhow::Result<Option<UserAuth>> {
panic!("ERR: Database::verify_user_token is deprecated!")
}
}

View File

@ -1,5 +1,5 @@
use anyhow::anyhow;
use serde::{Deserialize, de::Visitor, Serialize};
use serde::{Deserialize, Serialize};
use std::fmt;
@ -40,10 +40,11 @@ impl fmt::Display for Action {
}
}
#[derive(Default, Debug, Clone, Serialize, PartialEq, Eq)]
#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Scope {
#[serde(rename = "type")]
pub scope_type: ScopeType,
pub path: String,
pub name: String,
pub actions: Vec<Action>,
}
@ -51,7 +52,7 @@ impl Scope {
pub fn new(scope_type: ScopeType, path: String, actions: &[Action]) -> Self {
Self {
scope_type,
path,
name: path,
actions: actions.to_vec(),
}
}
@ -65,7 +66,7 @@ impl fmt::Display for Scope {
.collect::<Vec<String>>()
.join(",");
write!(f, "{}:{}:{}", self.scope_type, self.path, actions)
write!(f, "{}:{}:{}", self.scope_type, self.name, actions)
}
}
@ -96,7 +97,7 @@ impl TryFrom<&str> for Scope {
Ok(Scope {
scope_type,
path: String::from(path),
name: String::from(path),
actions
})
} else {
@ -104,68 +105,4 @@ impl TryFrom<&str> for Scope {
//Err(serde::de::Error::custom("Malformed scope string!"))
}
}
}
pub struct ScopeVisitor {
}
impl<'de> Visitor<'de> for ScopeVisitor {
type Value = Scope;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a Scope in the format of `repository:samalba/my-app:pull,push`.")
}
fn visit_str<E>(self, val: &str) -> Result<Self::Value, E>
where
E: serde::de::Error {
println!("Start of visit_str!");
let res = match Scope::try_from(val) {
Ok(val) => Ok(val),
Err(e) => Err(serde::de::Error::custom(format!("{}", e)))
};
res
/* let splits: Vec<&str> = val.split(":").collect();
if splits.len() == 3 {
let scope_type = match splits[0] {
"repository" => ScopeType::Repository,
_ => {
return Err(serde::de::Error::custom(format!("Invalid scope type: `{}`!", splits[0])));
}
};
let path = splits[1];
let actions: Result<Vec<Action>, E> = splits[2]
.split(",")
.map(|a| match a {
"pull" => Ok(Action::Pull),
"push" => Ok(Action::Push),
_ => Err(serde::de::Error::custom(format!("Invalid action: `{}`!", a))),
}).collect();
let actions = actions?;
Ok(Scope {
scope_type,
path: String::from(path),
actions
})
} else {
Err(serde::de::Error::custom("Malformed scope string!"))
} */
}
}
impl<'de> Deserialize<'de> for Scope {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de> {
deserializer.deserialize_str(ScopeVisitor {})
}
}

View File

@ -4,7 +4,10 @@ use async_trait::async_trait;
use axum::{http::{StatusCode, header, HeaderName, HeaderMap, request::Parts}, extract::FromRequestParts};
use bitflags::bitflags;
use chrono::{DateTime, Utc};
use hmac::{Hmac, digest::KeyInit};
use jwt::VerifyWithKey;
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use tracing::debug;
use crate::{app_state::AppState, database::Database};
@ -109,12 +112,12 @@ impl TokenInfo {
#[derive(Clone, Debug, PartialEq)]
pub struct UserAuth {
pub user: User,
pub token: TokenInfo,
pub user: Option<User>,
pub token: AuthToken,
}
impl UserAuth {
pub fn new(user: User, token: TokenInfo) -> Self {
pub fn new(user: Option<User>, token: AuthToken) -> Self {
Self {
user,
token,
@ -132,8 +135,6 @@ impl FromRequestParts<Arc<AppState>> for UserAuth {
failure_headers.append(header::WWW_AUTHENTICATE, bearer.parse().unwrap());
failure_headers.append(HeaderName::from_static("docker-distribution-api-version"), "registry/2.0".parse().unwrap());
debug!("starting UserAuth request parts");
let auth = String::from(
parts.headers
.get(header::AUTHORIZATION)
@ -142,8 +143,6 @@ impl FromRequestParts<Arc<AppState>> for UserAuth {
.map_err(|_| (StatusCode::UNAUTHORIZED, failure_headers.clone()))?
);
debug!("got auth header");
let token = match auth.split_once(' ') {
Some((auth, token)) if auth == "Bearer" => token,
// This line would allow empty tokens
@ -151,22 +150,42 @@ impl FromRequestParts<Arc<AppState>> for UserAuth {
_ => return Err( (StatusCode::UNAUTHORIZED, failure_headers) ),
};
debug!("got token");
// If the token is not valid, return an unauthorized response
let database = &state.database;
if let Ok(Some(user)) = database.verify_user_token(token.to_string()).await {
debug!("Authenticated user through request extractor: {}", user.user.username);
let jwt_key: Hmac<Sha256> = Hmac::new_from_slice(state.config.jwt_key.as_bytes())
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new()) )?;
Ok(user)
} else {
debug!("Failure to verify user token, responding with auth realm");
match VerifyWithKey::<AuthToken>::verify_with_key(token, &jwt_key) {
Ok(token) => {
// attempt to get the user
if !token.subject.is_empty() {
let database = &state.database;
if let Ok(Some(user)) = database.get_user(token.subject.clone()).await {
return Ok(UserAuth::new(Some(user), token));
} else {
debug!("failure to get user from token: {:?}", token);
}
} else {
return Ok(UserAuth::new(None, token));
}
Err((
StatusCode::UNAUTHORIZED,
failure_headers
))
/* let database = &state.database;
if let Ok(user) = database.get_user(token.subject.clone()).await {
return Ok(UserAuth::new(user, token));
} else {
debug!("failure to get user from token: {:?}", token);
} */
},
Err(e) => {
debug!("Failure to verify user token: '{}'", e);
}
}
debug!("Failure to verify user token, responding with auth realm");
Err((
StatusCode::UNAUTHORIZED,
failure_headers
))
}
}