diff --git a/cmd/files.go b/cmd/files.go index 498035b..f55fcba 100644 --- a/cmd/files.go +++ b/cmd/files.go @@ -5,12 +5,40 @@ Copyright © 2022 Seednode package cmd import ( + "bytes" + "errors" + "io" "math/rand" "os" "path/filepath" "time" ) +func checkIfImage(path string) (bool, error) { + magicNumber := make([]byte, 3) + + file, err := os.Open(path) + if err != nil { + return false, err + } + + _, err = io.ReadFull(file, magicNumber) + if err != nil { + return false, err + } + + switch { + case bytes.Compare(magicNumber, []byte{0xFF, 0xD8, 0xFF}) == 0: // JPG + return true, nil + case bytes.Compare(magicNumber, []byte{0x89, 0x50, 0x4E}) == 0: // PNG + return true, nil + case bytes.Compare(magicNumber, []byte{0x47, 0x49, 0x46}) == 0: // GIF + return true, nil + default: + return false, nil + } +} + func getFiles(path string) ([]string, error) { var paths []string @@ -78,13 +106,29 @@ func getFileList(args []string) ([]string, error) { return fileList, nil } -func pickFile(fileList []string) (string, string) { - rand.Seed(time.Now().UnixMicro()) +func pickFile(fileList []string) (string, string, error) { + rand.Seed(time.Now().UnixNano()) - filePath := fileList[rand.Intn(len(fileList))] - fileName := filepath.Base(filePath) + rand.Shuffle(len(fileList), func(i, j int) { fileList[i], fileList[j] = fileList[j], fileList[i] }) - return fileName, filePath + var filePath string + var fileName string + + for i := 0; i < len(fileList); i++ { + filePath = fileList[i] + fileName = filepath.Base(filePath) + isImage, err := checkIfImage(filePath) + if err != nil { + return "", "", err + } + if isImage { + return fileName, filePath, nil + } + } + + err := errors.New("no images found") + + return "", "", err } func normalizePaths(args []string) ([]string, error) { diff --git a/cmd/version.go b/cmd/version.go index aa7fa2e..50fd336 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -10,7 +10,7 @@ import ( "github.com/spf13/cobra" ) -var Version = "0.2.1" +var Version = "0.3.0" func init() { rootCmd.AddCommand(versionCmd) diff --git a/cmd/web.go b/cmd/web.go index 8645ee5..58e5268 100644 --- a/cmd/web.go +++ b/cmd/web.go @@ -15,13 +15,17 @@ import ( "strings" ) -func generatePageHtml(w http.ResponseWriter, paths []string) error { +func generatePageHtml(w http.ResponseWriter, r http.Request, paths []string) error { fileList, err := getFileList(paths) if err != nil { return err } - fileName, filePath := pickFile(fileList) + fileName, filePath, err := pickFile(fileList) + if err != nil { + http.NotFound(w, &r) + return nil + } w.Header().Add("Content-Type", "text/html") @@ -91,7 +95,7 @@ func serveStaticFile(w http.ResponseWriter, r http.Request, paths []string) erro func servePageHandler(paths []string) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if r.RequestURI == "/" { - err := generatePageHtml(w, paths) + err := generatePageHtml(w, *r, paths) if err != nil { log.Fatal(err) }