diff --git a/handlers/file.go b/handlers/file.go index 93279f4..61ceca1 100644 --- a/handlers/file.go +++ b/handlers/file.go @@ -1,9 +1,7 @@ package handlers import ( - "fmt" "net/http" - "path/filepath" "strconv" "threadr/models" ) @@ -23,10 +21,21 @@ func FileHandler(app *App) http.HandlerFunc { return } - fileExt := filepath.Ext(file.OriginalName) - fileName := fmt.Sprintf("%d%s", fileID, fileExt) - filePath := filepath.Join(app.Config.FileStorageDir, fileName) + isProfileImage, err := models.IsProfileImageFile(app.DB, fileID) + if err != nil || !isProfileImage { + http.NotFound(w, r) + return + } + filePath, contentType, ok := models.ResolveStoredImagePath(app.Config.FileStorageDir, file) + if !ok { + http.NotFound(w, r) + return + } + + w.Header().Set("Content-Type", contentType) + w.Header().Set("X-Content-Type-Options", "nosniff") + w.Header().Set("Cache-Control", "private, max-age=300") http.ServeFile(w, r, filePath) } } diff --git a/handlers/profile_edit.go b/handlers/profile_edit.go index 28e4a2a..e35094f 100644 --- a/handlers/profile_edit.go +++ b/handlers/profile_edit.go @@ -1,18 +1,29 @@ package handlers import ( + "bytes" "crypto/sha256" + "errors" "fmt" + "image" + _ "image/gif" + _ "image/jpeg" + "image/png" + _ "image/png" "io" "log" + "mime/multipart" "net/http" "os" "path/filepath" + "strings" "threadr/models" "github.com/gorilla/sessions" ) +const maxProfileImageBytes = 2 << 20 + func ProfileEditHandler(app *App) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { session := r.Context().Value("session").(*sessions.Session) @@ -28,62 +39,65 @@ func ProfileEditHandler(app *App) http.HandlerFunc { return } + r.Body = http.MaxBytesReader(w, r.Body, maxProfileImageBytes+(256<<10)) + // Handle file upload file, handler, err := r.FormFile("pfp") if err == nil { defer file.Close() - // Create a hash of the file - h := sha256.New() - if _, err := io.Copy(h, file); err != nil { - log.Printf("Error hashing file: %v", err) - http.Error(w, "Failed to process file", http.StatusInternalServerError) - return - } - fileHash := fmt.Sprintf("%x", h.Sum(nil)) - - // Create file record in the database - fileRecord := models.File{ - OriginalName: handler.Filename, - Hash: fileHash, - HashAlgorithm: "sha256", - } - fileID, err := models.CreateFile(app.DB, fileRecord) + fileHash, fileID, err := saveProfileImageUpload(app, file) if err != nil { - log.Printf("Error creating file record: %v", err) - http.Error(w, "Failed to save file information", http.StatusInternalServerError) - return - } - - // Save the file to disk - fileExt := filepath.Ext(handler.Filename) - newFileName := fmt.Sprintf("%d%s", fileID, fileExt) - filePath := filepath.Join(app.Config.FileStorageDir, newFileName) - - // Reset file pointer - file.Seek(0, 0) - - dst, err := os.Create(filePath) - if err != nil { - log.Printf("Error creating file on disk: %v", err) + if errors.Is(err, errInvalidProfileImage) || errors.Is(err, errProfileImageTooLarge) { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + log.Printf("Error saving profile image: %v", err) http.Error(w, "Failed to save file", http.StatusInternalServerError) return } - defer dst.Close() - if _, err := io.Copy(dst, file); err != nil { - log.Printf("Error saving file to disk: %v", err) + // Create file record in the database + fileRecord := models.File{ + OriginalName: sanitizeOriginalFileName(handler.Filename), + Hash: fileHash, + HashAlgorithm: "sha256", + } + createdFileID, err := models.CreateFile(app.DB, fileRecord) + if err != nil { + log.Printf("Error creating file record: %v", err) + http.Error(w, "Failed to save file information", http.StatusInternalServerError) + _ = os.Remove(fileID) + return + } + + finalPath := filepath.Join(app.Config.FileStorageDir, models.ProfileImageStorageName(createdFileID)) + if err := os.Rename(fileID, finalPath); err != nil { + _ = os.Remove(fileID) + _ = models.DeleteFileByID(app.DB, createdFileID) + log.Printf("Error moving file on disk: %v", err) http.Error(w, "Failed to save file", http.StatusInternalServerError) return } // Update user's pfp_file_id - err = models.UpdateUserPfp(app.DB, userID, fileID) + err = models.UpdateUserPfp(app.DB, userID, createdFileID) if err != nil { + _ = os.Remove(finalPath) + _ = models.DeleteFileByID(app.DB, createdFileID) log.Printf("Error updating user pfp: %v", err) http.Error(w, "Failed to update profile", http.StatusInternalServerError) return } + } else if err != nil && !errors.Is(err, http.ErrMissingFile) { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) || strings.Contains(err.Error(), "request body too large") { + http.Error(w, errProfileImageTooLarge.Error(), http.StatusBadRequest) + return + } + log.Printf("Error reading upload: %v", err) + http.Error(w, "Failed to process file upload", http.StatusBadRequest) + return } // Update other profile fields @@ -134,3 +148,67 @@ func ProfileEditHandler(app *App) http.HandlerFunc { } } } + +var ( + errInvalidProfileImage = errors.New("Profile picture must be a PNG, JPEG, or GIF image") + errProfileImageTooLarge = errors.New("Profile picture must be 2 MB or smaller") +) + +func saveProfileImageUpload(app *App, file multipart.File) (string, string, error) { + limitedReader := io.LimitReader(file, maxProfileImageBytes+1) + data, err := io.ReadAll(limitedReader) + if err != nil { + return "", "", err + } + if int64(len(data)) > maxProfileImageBytes { + return "", "", errProfileImageTooLarge + } + + contentType := http.DetectContentType(data) + if !isAllowedProfileImageType(contentType) { + return "", "", errInvalidProfileImage + } + + img, _, err := image.Decode(bytes.NewReader(data)) + if err != nil { + return "", "", errInvalidProfileImage + } + + tmpFile, err := os.CreateTemp(app.Config.FileStorageDir, "pfp-*.png") + if err != nil { + return "", "", err + } + defer func() { + _ = tmpFile.Close() + }() + + hash := sha256.Sum256(data) + if err := png.Encode(tmpFile, img); err != nil { + _ = os.Remove(tmpFile.Name()) + return "", "", err + } + + if err := tmpFile.Close(); err != nil { + _ = os.Remove(tmpFile.Name()) + return "", "", err + } + + return fmt.Sprintf("%x", hash[:]), tmpFile.Name(), nil +} + +func sanitizeOriginalFileName(name string) string { + base := filepath.Base(strings.TrimSpace(name)) + if base == "." || base == string(filepath.Separator) || base == "" { + return "profile.png" + } + return base +} + +func isAllowedProfileImageType(contentType string) bool { + switch contentType { + case "image/png", "image/jpeg", "image/gif": + return true + default: + return false + } +} diff --git a/models/file.go b/models/file.go index 10fffa5..dde8b18 100644 --- a/models/file.go +++ b/models/file.go @@ -2,8 +2,14 @@ package models import ( "database/sql" + "fmt" + "os" + "path/filepath" + "strings" ) +const ProfileImageExtension = ".png" + type File struct { ID int OriginalName string @@ -25,6 +31,15 @@ func GetFileByID(db *sql.DB, id int64) (*File, error) { return file, nil } +func IsProfileImageFile(db *sql.DB, id int64) (bool, error) { + var exists bool + err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE pfp_file_id = ?)", id).Scan(&exists) + if err != nil { + return false, err + } + return exists, nil +} + func CreateFile(db *sql.DB, file File) (int64, error) { query := "INSERT INTO files (original_name, hash, hash_algorithm) VALUES (?, ?, ?)" result, err := db.Exec(query, file.OriginalName, file.Hash, file.HashAlgorithm) @@ -33,3 +48,58 @@ func CreateFile(db *sql.DB, file File) (int64, error) { } return result.LastInsertId() } + +func DeleteFileByID(db *sql.DB, id int64) error { + _, err := db.Exec("DELETE FROM files WHERE id = ?", id) + return err +} + +func ProfileImageStorageName(id int64) string { + return fmt.Sprintf("%d%s", id, ProfileImageExtension) +} + +func LegacyImageStorageName(id int64, originalName string) (string, bool) { + ext := strings.ToLower(filepath.Ext(originalName)) + if !allowedImageExtension(ext) { + return "", false + } + return fmt.Sprintf("%d%s", id, ext), true +} + +func ProfileImageContentType(fileName string) string { + switch strings.ToLower(filepath.Ext(fileName)) { + case ".jpg", ".jpeg": + return "image/jpeg" + case ".gif": + return "image/gif" + default: + return "image/png" + } +} + +func ResolveStoredImagePath(storageDir string, file *File) (string, string, bool) { + currentPath := filepath.Join(storageDir, ProfileImageStorageName(int64(file.ID))) + if _, err := os.Stat(currentPath); err == nil { + return currentPath, ProfileImageContentType(currentPath), true + } + + legacyName, ok := LegacyImageStorageName(int64(file.ID), file.OriginalName) + if !ok { + return "", "", false + } + legacyPath := filepath.Join(storageDir, legacyName) + if _, err := os.Stat(legacyPath); err == nil { + return legacyPath, ProfileImageContentType(legacyPath), true + } + + return "", "", false +} + +func allowedImageExtension(ext string) bool { + switch ext { + case ".png", ".jpg", ".jpeg", ".gif": + return true + default: + return false + } +} diff --git a/templates/pages/profile_edit.html b/templates/pages/profile_edit.html index 2d0a219..37b06ad 100644 --- a/templates/pages/profile_edit.html +++ b/templates/pages/profile_edit.html @@ -18,7 +18,8 @@
-
+
+

PNG, JPEG, or GIF only, up to 2 MB. Images are re-encoded before storage.