Allow slashes in repository names

This commit is contained in:
SeanOMik 2023-04-25 15:55:12 -04:00
parent dfb91a9cd8
commit 81ddb75b77
Signed by: SeanOMik
GPG Key ID: 568F326C7EB33ACB
1 changed files with 32 additions and 4 deletions

View File

@ -8,6 +8,9 @@ mod byte_stream;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use axum::http::Request;
use axum::middleware::Next;
use axum::response::Response;
use axum::{Router, routing}; use axum::{Router, routing};
use axum::ServiceExt; use axum::ServiceExt;
use tower_layer::Layer; use tower_layer::Layer;
@ -27,7 +30,28 @@ use tower_http::trace::TraceLayer;
pub const REGISTRY_URL: &'static str = "http://localhost:3000"; // TODO: Move into configuration or something (make sure it doesn't end in /) pub const REGISTRY_URL: &'static str = "http://localhost:3000"; // TODO: Move into configuration or something (make sure it doesn't end in /)
//#[actix_web::main] /// Encode the 'name' path parameter in the url
async fn change_request_paths<B>(mut request: Request<B>, next: Next<B>) -> Response {
// Attempt to find the name using regex in the url
let regex = regex::Regex::new(r"/v2/([\w/]+)/(blobs|tags|manifests)").unwrap();
let captures = match regex.captures(request.uri().path()) {
Some(captures) => captures,
None => return next.run(request).await,
};
// Find the name in the request and encode it in the url
let name = captures.get(1).unwrap().as_str().to_string();
let encoded_name = name.replace('/', "%2F");
// Replace the name in the uri
let uri_str = request.uri().to_string().replace(&name, &encoded_name);
debug!("Rewrote request url to: '{}'", uri_str);
*request.uri_mut() = uri_str.parse().unwrap();
next.run(request).await
}
#[tokio::main] #[tokio::main]
async fn main() -> std::io::Result<()> { async fn main() -> std::io::Result<()> {
let pool = SqlitePoolOptions::new() let pool = SqlitePoolOptions::new()
@ -44,7 +68,9 @@ async fn main() -> std::io::Result<()> {
.with_max_level(Level::DEBUG) .with_max_level(Level::DEBUG)
.init(); .init();
let app = NormalizePathLayer::trim_trailing_slash().layer(Router::new() let path_middleware = axum::middleware::from_fn(change_request_paths);
let app = Router::new()
.nest("/v2", Router::new() .nest("/v2", Router::new()
.route("/", routing::get(api::version_check)) .route("/", routing::get(api::version_check))
.route("/_catalog", routing::get(api::catalog::list_repositories)) .route("/_catalog", routing::get(api::catalog::list_repositories))
@ -68,12 +94,14 @@ async fn main() -> std::io::Result<()> {
.delete(api::manifests::delete_manifest)) .delete(api::manifests::delete_manifest))
) )
.with_state(state) .with_state(state)
.layer(TraceLayer::new_for_http())); .layer(TraceLayer::new_for_http());
let layered_app = NormalizePathLayer::trim_trailing_slash().layer(path_middleware.layer(app));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
debug!("Starting http server, listening on {}", addr); debug!("Starting http server, listening on {}", addr);
axum::Server::bind(&addr) axum::Server::bind(&addr)
.serve(app.into_make_service()) .serve(layered_app.into_make_service())
.await .await
.unwrap(); .unwrap();