diff --git a/cmd/files.go b/cmd/files.go index a2494d4..7af9085 100644 --- a/cmd/files.go +++ b/cmd/files.go @@ -18,11 +18,24 @@ import ( "github.com/h2non/filetype" ) -func containsCaseIntensitive(a string, b string) bool { - return strings.Contains( - strings.ToLower(a), - strings.ToLower(b), - ) +var ( + ErrNoImagesFound = fmt.Errorf("no supported image formats found") +) + +func appendPaths(paths []string, path, filter string) ([]string, error) { + absolutePath, err := filepath.Abs(path) + if err != nil { + return paths, err + } + + switch { + case filter != "" && strings.Contains(path, filter): + paths = append(paths, absolutePath) + case filter == "": + paths = append(paths, absolutePath) + } + + return paths, nil } func getFirstFile(path string) (string, error) { @@ -111,28 +124,32 @@ func checkIfImage(path string) (bool, error) { return false, nil } -func getFiles(path string) ([]string, error) { +func getFiles(path, filter string) ([]string, error) { var paths []string err := filepath.WalkDir(path, func(p string, info os.DirEntry, err error) error { + if err != nil { + return err + } + switch { - case info.IsDir() && p != path: + case !Recursive && info.IsDir() && p != path: return filepath.SkipDir - case Filter != "": - absolutePath, err := filepath.Abs(p) + case Filter != "" && !info.IsDir(): + paths, err = appendPaths(paths, p, Filter) if err != nil { return err } - - if containsCaseIntensitive(p, Filter) { - paths = append(paths, absolutePath) + case filter != "" && !info.IsDir(): + paths, err = appendPaths(paths, p, filter) + if err != nil { + return err } default: - absolutePath, err := filepath.Abs(p) + paths, err = appendPaths(paths, p, "") if err != nil { return err } - paths = append(paths, absolutePath) } return err @@ -144,64 +161,24 @@ func getFiles(path string) ([]string, error) { return paths, nil } -func getFilesRecursive(path string) ([]string, error) { - var paths []string - - err := filepath.WalkDir(path, func(p string, info os.DirEntry, err error) error { - switch { - case Filter != "" && !info.IsDir(): - absolutePath, err := filepath.Abs(p) - if err != nil { - return err - } - - if containsCaseIntensitive(p, Filter) { - paths = append(paths, absolutePath) - } - case Filter == "" && !info.IsDir(): - absolutePath, err := filepath.Abs(p) - if err != nil { - return err - } - - paths = append(paths, absolutePath) - } - - return err - }) - if err != nil { - return nil, err - } - - return paths, nil -} - -func getFileList(paths []string) ([]string, error) { +func getFileList(paths []string, filter string) ([]string, error) { fileList := []string{} for i := 0; i < len(paths); i++ { - if Recursive { - f, err := getFilesRecursive(paths[i]) - if err != nil { - return nil, err - } - fileList = append(fileList, f...) - } else { - f, err := getFiles(paths[i]) - if err != nil { - return nil, err - } - - fileList = append(fileList, f...) + f, err := getFiles(paths[i], filter) + if err != nil { + return nil, err } + + fileList = append(fileList, f...) } return fileList, nil } -func pickFile(args []string) (string, error) { - fileList, err := getFileList(args) +func pickFile(args []string, filter string) (string, error) { + fileList, err := getFileList(args, filter) if err != nil { return "", err } @@ -221,9 +198,7 @@ func pickFile(args []string) (string, error) { } } - err = errors.New("no images found") - - return "", err + return "", ErrNoImagesFound } func normalizePaths(args []string) ([]string, error) { diff --git a/cmd/version.go b/cmd/version.go index 25d8805..71339df 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -10,7 +10,7 @@ import ( "github.com/spf13/cobra" ) -var Version = "0.9.0" +var Version = "0.10.0" func init() { rootCmd.AddCommand(versionCmd) diff --git a/cmd/web.go b/cmd/web.go index c26e7b2..7231df5 100644 --- a/cmd/web.go +++ b/cmd/web.go @@ -22,6 +22,15 @@ const LOGDATE string = "2006-01-02T15:04:05.000000000-07:00" const PREFIX string = "/src" +func stripQueryParam(inUrl string) string { + u, err := url.Parse(inUrl) + if err != nil { + panic(err) + } + u.RawQuery = "" + return u.String() +} + func refererToUri(referer string) string { parts := strings.SplitAfterN(referer, "/", 4) @@ -44,8 +53,8 @@ func serveHtml(w http.ResponseWriter, r http.Request, filePath string) error { htmlBody += fileName htmlBody += ` -
- @@ -60,11 +69,13 @@ func serveHtml(w http.ResponseWriter, r http.Request, filePath string) error { } func serveStaticFile(w http.ResponseWriter, r http.Request, paths []string) error { - prefixedFilePath, err := url.QueryUnescape(r.RequestURI) + prefixedFilePath, err := url.QueryUnescape(stripQueryParam(r.URL.Path)) if err != nil { return err } + fmt.Println("Prefixed file path is " + prefixedFilePath) + filePath := strings.TrimPrefix(prefixedFilePath, PREFIX) var matchesPrefix = false @@ -75,7 +86,7 @@ func serveStaticFile(w http.ResponseWriter, r http.Request, paths []string) erro } if !matchesPrefix { if Verbose { - fmt.Printf("%v Failed to serve file outside specified path(s): %v", time.Now().Format(LOGDATE), filePath) + fmt.Printf("%v Failed to serve file outside specified path(s): %v\n", time.Now().Format(LOGDATE), filePath) } http.NotFound(w, &r) @@ -86,7 +97,7 @@ func serveStaticFile(w http.ResponseWriter, r http.Request, paths []string) erro _, err = os.Stat(filePath) if errors.Is(err, os.ErrNotExist) { if Verbose { - fmt.Printf("%v Failed to serve non-existent file: %v", time.Now().Format(LOGDATE), filePath) + fmt.Printf("%v Failed to serve non-existent file: %v\n", time.Now().Format(LOGDATE), filePath) } http.NotFound(w, &r) @@ -127,10 +138,12 @@ func serveStaticFileHandler(paths []string) http.HandlerFunc { func serveHtmlHandler(paths []string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - refererUri := refererToUri(r.Referer()) + refererUri := stripQueryParam(refererToUri(r.Referer())) + + filter := r.URL.Query().Get("filter") switch { - case r.RequestURI == "/" && Successive && refererUri != "": + case r.URL.Path == "/" && Successive && refererUri != "": query, err := url.QueryUnescape(refererUri) if err != nil { log.Fatal(err) @@ -142,8 +155,12 @@ func serveHtmlHandler(paths []string) http.HandlerFunc { } if filePath == "" { - filePath, err = pickFile(paths) - if err != nil { + filePath, err = pickFile(paths, filter) + switch { + case err != nil && err == ErrNoImagesFound: + http.NotFound(w, r) + return + case err != nil: log.Fatal(err) } @@ -153,11 +170,15 @@ func serveHtmlHandler(paths []string) http.HandlerFunc { } } - newUrl := r.URL.Host + filePath + newUrl := fmt.Sprintf("%v%v?filter=%v", r.URL.Host, filePath, filter) http.Redirect(w, r, newUrl, http.StatusSeeOther) - case r.RequestURI == "/" && Successive && refererUri == "": - filePath, err := pickFile(paths) - if err != nil { + case r.URL.Path == "/" && Successive && refererUri == "": + filePath, err := pickFile(paths, filter) + switch { + case err != nil && err == ErrNoImagesFound: + http.NotFound(w, r) + return + case err != nil: log.Fatal(err) } @@ -166,21 +187,22 @@ func serveHtmlHandler(paths []string) http.HandlerFunc { log.Fatal(err) } - newUrl := r.URL.Host + filePath + newUrl := fmt.Sprintf("%v%v?filter=%v", r.URL.Host, filePath, filter) http.Redirect(w, r, newUrl, http.StatusSeeOther) - case r.RequestURI == "/": - filePath, err := pickFile(paths) - if err != nil { + case r.URL.Path == "/": + filePath, err := pickFile(paths, filter) + switch { + case err != nil && err == ErrNoImagesFound: + http.NotFound(w, r) + return + case err != nil: log.Fatal(err) } - newUrl := r.URL.Host + filePath + newUrl := fmt.Sprintf("%v%v?filter=%v", r.URL.Host, filePath, filter) http.Redirect(w, r, newUrl, http.StatusSeeOther) default: - filePath, err := url.QueryUnescape(r.RequestURI) - if err != nil { - log.Fatal(err) - } + filePath := r.URL.Path isImage, err := checkIfImage(filePath) if err != nil {