diff --git a/main.go b/main.go index ae040ab..7ba5fc4 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,6 @@ import ( "math/rand" "net/http" "os" - "path/filepath" "strconv" "time" @@ -33,78 +32,26 @@ func init() { } func main() { - flag.StringVar(&port, "port", "3000", "Port to run the server on") - flag.Parse() - - // Initialize Viper for configuration management - viper.SetConfigName("config") - viper.SetConfigType("yaml") - viper.AddConfigPath(".") - viper.AutomaticEnv() // Allow environment variables to override config settings - - // Read the configuration file (config.yaml in this example) + viper.SetConfigName("config") // name of config file (without extension) + viper.SetConfigType("yaml") // REQUIRED if the config file does not have the extension in the name + viper.AddConfigPath(".") // path to look for the config file in if err := viper.ReadInConfig(); err != nil { log.Fatalf("Error reading configuration file: %v", err) } - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - images := getImages(r) // Call getImages with the request parameter - io.WriteString(w, "
") - for i, image := range images { - io.WriteString(w, ``) - io.WriteString(w, ``) - io.WriteString(w, ``) - } - io.WriteString(w, "
") - }) + port := viper.GetString("server.port") + flag.Parse() - http.HandleFunc("/api/id", func(w http.ResponseWriter, r *http.Request) { - id := sanitizeInput(r.URL.Query().Get("id")) - if id == "" { - http.Error(w, "Missing id", http.StatusBadRequest) - return - } - i, err := strconv.Atoi(id) - if err != nil || i < 0 || i >= len(images) { - http.Error(w, "Invalid id", http.StatusBadRequest) - return - } - http.ServeFile(w, r, images[i]) - }) + if err := viper.ReadInConfig(); err != nil { + log.Fatalf("Error reading configuration file: %v", err) + } - http.HandleFunc("/api/list", func(w http.ResponseWriter, r *http.Request) { - // Create a slice to store image information - imageList := []map[string]string{} + images = getImages() - for _, image := range images { - imageInfo := map[string]string{ - "image": image, - "url": "/" + filepath.Base(image), - } - imageList = append(imageList, imageInfo) - } - - // Convert the slice to JSON - jsonData, err := json.Marshal(imageList) - if err != nil { - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - - // Set the content type to JSON - w.Header().Set("Content-Type", "application/json") - - // Write the JSON response - w.Write(jsonData) - }) - - http.HandleFunc("/api/random", func(w http.ResponseWriter, r *http.Request) { - // Reseed the random number generator to make it truly random on each request - rand.Seed(time.Now().UnixNano()) - - i := rand.Intn(len(images)) - http.ServeFile(w, r, images[i]) - }) + http.HandleFunc("/", homeHandler) + http.HandleFunc("/api/id", idHandler) + http.HandleFunc("/api/list", listHandler) + http.HandleFunc("/api/random", randomHandler) http.HandleFunc("/api/", func(w http.ResponseWriter, r *http.Request) { http.Error(w, "Invalid API path", http.StatusNotFound) @@ -114,17 +61,15 @@ func main() { log.Fatal(http.ListenAndServe(":"+port, nil)) } -func getImages(r *http.Request) []string { +func getImages() []string { files, err := os.ReadDir("images/") if err != nil { logger.WithError(err).Fatal("Error reading images directory") } var images []string - serverAddress := r.Host // Get the server address from the request for _, file := range files { - imagePath := "http://" + serverAddress + "/images/" + file.Name() - images = append(images, imagePath) - logger.Info("Loaded image:", imagePath) + images = append(images, file.Name()) + logger.Info("Loaded image:", file.Name()) } return images } @@ -132,3 +77,51 @@ func getImages(r *http.Request) []string { func sanitizeInput(input string) string { return template.HTMLEscapeString(input) } + +func homeHandler(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "
") + for i := range images { + io.WriteString(w, ``) + io.WriteString(w, ``) + io.WriteString(w, ``) + } + io.WriteString(w, "
") +} + +func idHandler(w http.ResponseWriter, r *http.Request) { + id := sanitizeInput(r.URL.Query().Get("id")) + if id == "" { + http.Error(w, "Missing id", http.StatusBadRequest) + return + } + i, err := strconv.Atoi(id) + if err != nil || i < 0 || i >= len(images) { + http.Error(w, "Invalid id", http.StatusBadRequest) + return + } + http.ServeFile(w, r, "images/"+images[i]) +} + +func listHandler(w http.ResponseWriter, r *http.Request) { + imageList := []map[string]string{} + for i := range images { + imageInfo := map[string]string{ + "id": strconv.Itoa(i), + "url": "/api/id?id=" + strconv.Itoa(i), + } + imageList = append(imageList, imageInfo) + } + jsonData, err := json.Marshal(imageList) + if err != nil { + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Write(jsonData) +} + +func randomHandler(w http.ResponseWriter, r *http.Request) { + rand.Seed(time.Now().UnixNano()) + i := rand.Intn(len(images)) + http.Redirect(w, r, "/api/id?id="+strconv.Itoa(i), http.StatusSeeOther) +}