From 81ddb75b7732313bbf2d59e0a98a4eb5497910ce Mon Sep 17 00:00:00 2001 From: SeanOMik Date: Tue, 25 Apr 2023 15:55:12 -0400 Subject: [PATCH] Allow slashes in repository names --- src/main.rs | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/main.rs b/src/main.rs index 635938f..d2ddadd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,9 @@ mod byte_stream; use std::net::SocketAddr; use std::sync::Arc; +use axum::http::Request; +use axum::middleware::Next; +use axum::response::Response; use axum::{Router, routing}; use axum::ServiceExt; use tower_layer::Layer; @@ -27,7 +30,28 @@ use tower_http::trace::TraceLayer; pub const REGISTRY_URL: &'static str = "http://localhost:3000"; // TODO: Move into configuration or something (make sure it doesn't end in /) -//#[actix_web::main] +/// Encode the 'name' path parameter in the url +async fn change_request_paths(mut request: Request, next: Next) -> Response { + // Attempt to find the name using regex in the url + let regex = regex::Regex::new(r"/v2/([\w/]+)/(blobs|tags|manifests)").unwrap(); + let captures = match regex.captures(request.uri().path()) { + Some(captures) => captures, + None => return next.run(request).await, + }; + + // Find the name in the request and encode it in the url + let name = captures.get(1).unwrap().as_str().to_string(); + let encoded_name = name.replace('/', "%2F"); + + // Replace the name in the uri + let uri_str = request.uri().to_string().replace(&name, &encoded_name); + debug!("Rewrote request url to: '{}'", uri_str); + + *request.uri_mut() = uri_str.parse().unwrap(); + + next.run(request).await +} + #[tokio::main] async fn main() -> std::io::Result<()> { let pool = SqlitePoolOptions::new() @@ -44,7 +68,9 @@ async fn main() -> std::io::Result<()> { .with_max_level(Level::DEBUG) .init(); - let app = NormalizePathLayer::trim_trailing_slash().layer(Router::new() + let path_middleware = axum::middleware::from_fn(change_request_paths); + + let app = Router::new() .nest("/v2", Router::new() .route("/", routing::get(api::version_check)) .route("/_catalog", routing::get(api::catalog::list_repositories)) @@ -68,12 +94,14 @@ async fn main() -> std::io::Result<()> { .delete(api::manifests::delete_manifest)) ) .with_state(state) - .layer(TraceLayer::new_for_http())); + .layer(TraceLayer::new_for_http()); + + let layered_app = NormalizePathLayer::trim_trailing_slash().layer(path_middleware.layer(app)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); debug!("Starting http server, listening on {}", addr); axum::Server::bind(&addr) - .serve(app.into_make_service()) + .serve(layered_app.into_make_service()) .await .unwrap();