diff --git a/cmd/files.go b/cmd/files.go index fcff148..348e335 100644 --- a/cmd/files.go +++ b/cmd/files.go @@ -74,19 +74,19 @@ func getFilesRecursive(path string) ([]string, error) { return paths, nil } -func getFileList(args []string) ([]string, error) { +func getFileList(paths []string) ([]string, error) { fileList := []string{} - for i := 0; i < len(args); i++ { + for i := 0; i < len(paths); i++ { if Recursive { - f, err := getFilesRecursive(args[i]) + f, err := getFilesRecursive(paths[i]) if err != nil { return nil, err } fileList = append(fileList, f...) } else { - f, err := getFiles(args[i]) + f, err := getFiles(paths[i]) if err != nil { return nil, err } @@ -98,7 +98,12 @@ func getFileList(args []string) ([]string, error) { return fileList, nil } -func pickFile(fileList []string) (string, string, error) { +func pickFile(args []string) (string, error) { + fileList, err := getFileList(args) + if err != nil { + return "", err + } + rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(fileList), func(i, j int) { fileList[i], fileList[j] = fileList[j], fileList[i] }) @@ -107,17 +112,16 @@ func pickFile(fileList []string) (string, string, error) { filePath := fileList[i] isImage, err := checkIfImage(filePath) if err != nil { - return "", "", err + return "", err } if isImage { - fileName := filepath.Base(filePath) - return fileName, filePath, nil + return filePath, nil } } - err := errors.New("no images found") + err = errors.New("no images found") - return "", "", err + return "", err } func normalizePaths(args []string) ([]string, error) { diff --git a/cmd/version.go b/cmd/version.go index 4f868ec..0d0e08b 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -10,7 +10,7 @@ import ( "github.com/spf13/cobra" ) -var Version = "0.5.2" +var Version = "0.6.0" func init() { rootCmd.AddCommand(versionCmd) diff --git a/cmd/web.go b/cmd/web.go index 6ae004b..9ed5d5d 100644 --- a/cmd/web.go +++ b/cmd/web.go @@ -12,6 +12,7 @@ import ( "net/http" "net/url" "os" + "path/filepath" "strconv" "strings" "time" @@ -19,17 +20,10 @@ import ( const LOGDATE string = "2006-01-02T15:04:05.000000000-07:00" -func generatePageHtml(w http.ResponseWriter, r http.Request, paths []string) error { - fileList, err := getFileList(paths) - if err != nil { - return err - } +const PREFIX string = "/src" - fileName, filePath, err := pickFile(fileList) - if err != nil { - http.NotFound(w, &r) - return nil - } +func generatePageHtml(w http.ResponseWriter, r http.Request, filePath string) error { + fileName := filepath.Base(filePath) w.Header().Add("Content-Type", "text/html") @@ -42,12 +36,12 @@ func generatePageHtml(w http.ResponseWriter, r http.Request, paths []string) err ` - _, err = io.WriteString(w, htmlBody) + _, err := io.WriteString(w, htmlBody) if err != nil { return err } @@ -58,11 +52,13 @@ func generatePageHtml(w http.ResponseWriter, r http.Request, paths []string) err func serveStaticFile(w http.ResponseWriter, r http.Request, paths []string) error { request := r.RequestURI - filePath, err := url.QueryUnescape(request) + prefixedFilePath, err := url.QueryUnescape(request) if err != nil { return err } + filePath := strings.TrimPrefix(prefixedFilePath, PREFIX) + var matchesPrefix = false for i := 0; i < len(paths); i++ { @@ -115,18 +111,43 @@ func serveStaticFile(w http.ResponseWriter, r http.Request, paths []string) erro return nil } +func serveStaticFileHandler(paths []string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + err := serveStaticFile(w, *r, paths) + if err != nil { + log.Fatal(err) + } + } +} + func servePageHandler(paths []string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.RequestURI == "/" { - err := generatePageHtml(w, *r, paths) + filePath, err := pickFile(paths) if err != nil { log.Fatal(err) } + + newUrl := r.URL.Host + filePath + http.Redirect(w, r, newUrl, http.StatusSeeOther) } else { - err := serveStaticFile(w, *r, paths) + filePath, err := url.QueryUnescape(r.RequestURI) if err != nil { log.Fatal(err) } + + isImage, err := checkIfImage(filePath) + if err != nil { + fmt.Println(err) + http.NotFound(w, r) + } + + if isImage { + err := generatePageHtml(w, *r, filePath) + if err != nil { + log.Fatal(err) + } + } } } } @@ -139,7 +160,11 @@ func ServePage(args []string) { log.Fatal(err) } + for _, i := range paths { + fmt.Println("Paths: " + i) + } http.HandleFunc("/", servePageHandler(paths)) + http.Handle(PREFIX+"/", http.StripPrefix(PREFIX, serveStaticFileHandler(paths))) http.HandleFunc("/favicon.ico", doNothing) port := strconv.Itoa(Port)