diff --git a/handlers/file.go b/handlers/file.go
index 93279f4..61ceca1 100644
--- a/handlers/file.go
+++ b/handlers/file.go
@@ -1,9 +1,7 @@
package handlers
import (
- "fmt"
"net/http"
- "path/filepath"
"strconv"
"threadr/models"
)
@@ -23,10 +21,21 @@ func FileHandler(app *App) http.HandlerFunc {
return
}
- fileExt := filepath.Ext(file.OriginalName)
- fileName := fmt.Sprintf("%d%s", fileID, fileExt)
- filePath := filepath.Join(app.Config.FileStorageDir, fileName)
+ isProfileImage, err := models.IsProfileImageFile(app.DB, fileID)
+ if err != nil || !isProfileImage {
+ http.NotFound(w, r)
+ return
+ }
+ filePath, contentType, ok := models.ResolveStoredImagePath(app.Config.FileStorageDir, file)
+ if !ok {
+ http.NotFound(w, r)
+ return
+ }
+
+ w.Header().Set("Content-Type", contentType)
+ w.Header().Set("X-Content-Type-Options", "nosniff")
+ w.Header().Set("Cache-Control", "private, max-age=300")
http.ServeFile(w, r, filePath)
}
}
diff --git a/handlers/profile_edit.go b/handlers/profile_edit.go
index 28e4a2a..e35094f 100644
--- a/handlers/profile_edit.go
+++ b/handlers/profile_edit.go
@@ -1,18 +1,29 @@
package handlers
import (
+ "bytes"
"crypto/sha256"
+ "errors"
"fmt"
+ "image"
+ _ "image/gif"
+ _ "image/jpeg"
+ "image/png"
+ _ "image/png"
"io"
"log"
+ "mime/multipart"
"net/http"
"os"
"path/filepath"
+ "strings"
"threadr/models"
"github.com/gorilla/sessions"
)
+const maxProfileImageBytes = 2 << 20
+
func ProfileEditHandler(app *App) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
session := r.Context().Value("session").(*sessions.Session)
@@ -28,62 +39,65 @@ func ProfileEditHandler(app *App) http.HandlerFunc {
return
}
+ r.Body = http.MaxBytesReader(w, r.Body, maxProfileImageBytes+(256<<10))
+
// Handle file upload
file, handler, err := r.FormFile("pfp")
if err == nil {
defer file.Close()
- // Create a hash of the file
- h := sha256.New()
- if _, err := io.Copy(h, file); err != nil {
- log.Printf("Error hashing file: %v", err)
- http.Error(w, "Failed to process file", http.StatusInternalServerError)
- return
- }
- fileHash := fmt.Sprintf("%x", h.Sum(nil))
-
- // Create file record in the database
- fileRecord := models.File{
- OriginalName: handler.Filename,
- Hash: fileHash,
- HashAlgorithm: "sha256",
- }
- fileID, err := models.CreateFile(app.DB, fileRecord)
+ fileHash, fileID, err := saveProfileImageUpload(app, file)
if err != nil {
- log.Printf("Error creating file record: %v", err)
- http.Error(w, "Failed to save file information", http.StatusInternalServerError)
- return
- }
-
- // Save the file to disk
- fileExt := filepath.Ext(handler.Filename)
- newFileName := fmt.Sprintf("%d%s", fileID, fileExt)
- filePath := filepath.Join(app.Config.FileStorageDir, newFileName)
-
- // Reset file pointer
- file.Seek(0, 0)
-
- dst, err := os.Create(filePath)
- if err != nil {
- log.Printf("Error creating file on disk: %v", err)
+ if errors.Is(err, errInvalidProfileImage) || errors.Is(err, errProfileImageTooLarge) {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ log.Printf("Error saving profile image: %v", err)
http.Error(w, "Failed to save file", http.StatusInternalServerError)
return
}
- defer dst.Close()
- if _, err := io.Copy(dst, file); err != nil {
- log.Printf("Error saving file to disk: %v", err)
+ // Create file record in the database
+ fileRecord := models.File{
+ OriginalName: sanitizeOriginalFileName(handler.Filename),
+ Hash: fileHash,
+ HashAlgorithm: "sha256",
+ }
+ createdFileID, err := models.CreateFile(app.DB, fileRecord)
+ if err != nil {
+ log.Printf("Error creating file record: %v", err)
+ http.Error(w, "Failed to save file information", http.StatusInternalServerError)
+ _ = os.Remove(fileID)
+ return
+ }
+
+ finalPath := filepath.Join(app.Config.FileStorageDir, models.ProfileImageStorageName(createdFileID))
+ if err := os.Rename(fileID, finalPath); err != nil {
+ _ = os.Remove(fileID)
+ _ = models.DeleteFileByID(app.DB, createdFileID)
+ log.Printf("Error moving file on disk: %v", err)
http.Error(w, "Failed to save file", http.StatusInternalServerError)
return
}
// Update user's pfp_file_id
- err = models.UpdateUserPfp(app.DB, userID, fileID)
+ err = models.UpdateUserPfp(app.DB, userID, createdFileID)
if err != nil {
+ _ = os.Remove(finalPath)
+ _ = models.DeleteFileByID(app.DB, createdFileID)
log.Printf("Error updating user pfp: %v", err)
http.Error(w, "Failed to update profile", http.StatusInternalServerError)
return
}
+ } else if err != nil && !errors.Is(err, http.ErrMissingFile) {
+ var maxBytesErr *http.MaxBytesError
+ if errors.As(err, &maxBytesErr) || strings.Contains(err.Error(), "request body too large") {
+ http.Error(w, errProfileImageTooLarge.Error(), http.StatusBadRequest)
+ return
+ }
+ log.Printf("Error reading upload: %v", err)
+ http.Error(w, "Failed to process file upload", http.StatusBadRequest)
+ return
}
// Update other profile fields
@@ -134,3 +148,67 @@ func ProfileEditHandler(app *App) http.HandlerFunc {
}
}
}
+
+var (
+ errInvalidProfileImage = errors.New("Profile picture must be a PNG, JPEG, or GIF image")
+ errProfileImageTooLarge = errors.New("Profile picture must be 2 MB or smaller")
+)
+
+func saveProfileImageUpload(app *App, file multipart.File) (string, string, error) {
+ limitedReader := io.LimitReader(file, maxProfileImageBytes+1)
+ data, err := io.ReadAll(limitedReader)
+ if err != nil {
+ return "", "", err
+ }
+ if int64(len(data)) > maxProfileImageBytes {
+ return "", "", errProfileImageTooLarge
+ }
+
+ contentType := http.DetectContentType(data)
+ if !isAllowedProfileImageType(contentType) {
+ return "", "", errInvalidProfileImage
+ }
+
+ img, _, err := image.Decode(bytes.NewReader(data))
+ if err != nil {
+ return "", "", errInvalidProfileImage
+ }
+
+ tmpFile, err := os.CreateTemp(app.Config.FileStorageDir, "pfp-*.png")
+ if err != nil {
+ return "", "", err
+ }
+ defer func() {
+ _ = tmpFile.Close()
+ }()
+
+ hash := sha256.Sum256(data)
+ if err := png.Encode(tmpFile, img); err != nil {
+ _ = os.Remove(tmpFile.Name())
+ return "", "", err
+ }
+
+ if err := tmpFile.Close(); err != nil {
+ _ = os.Remove(tmpFile.Name())
+ return "", "", err
+ }
+
+ return fmt.Sprintf("%x", hash[:]), tmpFile.Name(), nil
+}
+
+func sanitizeOriginalFileName(name string) string {
+ base := filepath.Base(strings.TrimSpace(name))
+ if base == "." || base == string(filepath.Separator) || base == "" {
+ return "profile.png"
+ }
+ return base
+}
+
+func isAllowedProfileImageType(contentType string) bool {
+ switch contentType {
+ case "image/png", "image/jpeg", "image/gif":
+ return true
+ default:
+ return false
+ }
+}
diff --git a/models/file.go b/models/file.go
index 10fffa5..dde8b18 100644
--- a/models/file.go
+++ b/models/file.go
@@ -2,8 +2,14 @@ package models
import (
"database/sql"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
)
+const ProfileImageExtension = ".png"
+
type File struct {
ID int
OriginalName string
@@ -25,6 +31,15 @@ func GetFileByID(db *sql.DB, id int64) (*File, error) {
return file, nil
}
+func IsProfileImageFile(db *sql.DB, id int64) (bool, error) {
+ var exists bool
+ err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE pfp_file_id = ?)", id).Scan(&exists)
+ if err != nil {
+ return false, err
+ }
+ return exists, nil
+}
+
func CreateFile(db *sql.DB, file File) (int64, error) {
query := "INSERT INTO files (original_name, hash, hash_algorithm) VALUES (?, ?, ?)"
result, err := db.Exec(query, file.OriginalName, file.Hash, file.HashAlgorithm)
@@ -33,3 +48,58 @@ func CreateFile(db *sql.DB, file File) (int64, error) {
}
return result.LastInsertId()
}
+
+func DeleteFileByID(db *sql.DB, id int64) error {
+ _, err := db.Exec("DELETE FROM files WHERE id = ?", id)
+ return err
+}
+
+func ProfileImageStorageName(id int64) string {
+ return fmt.Sprintf("%d%s", id, ProfileImageExtension)
+}
+
+func LegacyImageStorageName(id int64, originalName string) (string, bool) {
+ ext := strings.ToLower(filepath.Ext(originalName))
+ if !allowedImageExtension(ext) {
+ return "", false
+ }
+ return fmt.Sprintf("%d%s", id, ext), true
+}
+
+func ProfileImageContentType(fileName string) string {
+ switch strings.ToLower(filepath.Ext(fileName)) {
+ case ".jpg", ".jpeg":
+ return "image/jpeg"
+ case ".gif":
+ return "image/gif"
+ default:
+ return "image/png"
+ }
+}
+
+func ResolveStoredImagePath(storageDir string, file *File) (string, string, bool) {
+ currentPath := filepath.Join(storageDir, ProfileImageStorageName(int64(file.ID)))
+ if _, err := os.Stat(currentPath); err == nil {
+ return currentPath, ProfileImageContentType(currentPath), true
+ }
+
+ legacyName, ok := LegacyImageStorageName(int64(file.ID), file.OriginalName)
+ if !ok {
+ return "", "", false
+ }
+ legacyPath := filepath.Join(storageDir, legacyName)
+ if _, err := os.Stat(legacyPath); err == nil {
+ return legacyPath, ProfileImageContentType(legacyPath), true
+ }
+
+ return "", "", false
+}
+
+func allowedImageExtension(ext string) bool {
+ switch ext {
+ case ".png", ".jpg", ".jpeg", ".gif":
+ return true
+ default:
+ return false
+ }
+}
diff --git a/templates/pages/profile_edit.html b/templates/pages/profile_edit.html
index 2d0a219..37b06ad 100644
--- a/templates/pages/profile_edit.html
+++ b/templates/pages/profile_edit.html
@@ -18,7 +18,8 @@
-
+
+
PNG, JPEG, or GIF only, up to 2 MB. Images are re-encoded before storage.