Harden profile image uploads.
parent
7a5b0f8ca5
commit
8ff0b7f2c2
|
|
@ -1,9 +1,7 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"threadr/models"
|
||||
)
|
||||
|
|
@ -23,10 +21,21 @@ func FileHandler(app *App) http.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
fileExt := filepath.Ext(file.OriginalName)
|
||||
fileName := fmt.Sprintf("%d%s", fileID, fileExt)
|
||||
filePath := filepath.Join(app.Config.FileStorageDir, fileName)
|
||||
isProfileImage, err := models.IsProfileImageFile(app.DB, fileID)
|
||||
if err != nil || !isProfileImage {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
filePath, contentType, ok := models.ResolveStoredImagePath(app.Config.FileStorageDir, file)
|
||||
if !ok {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", contentType)
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("Cache-Control", "private, max-age=300")
|
||||
http.ServeFile(w, r, filePath)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,18 +1,29 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/gif"
|
||||
_ "image/jpeg"
|
||||
"image/png"
|
||||
_ "image/png"
|
||||
"io"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"threadr/models"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
const maxProfileImageBytes = 2 << 20
|
||||
|
||||
func ProfileEditHandler(app *App) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
session := r.Context().Value("session").(*sessions.Session)
|
||||
|
|
@ -28,62 +39,65 @@ func ProfileEditHandler(app *App) http.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxProfileImageBytes+(256<<10))
|
||||
|
||||
// Handle file upload
|
||||
file, handler, err := r.FormFile("pfp")
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
|
||||
// Create a hash of the file
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, file); err != nil {
|
||||
log.Printf("Error hashing file: %v", err)
|
||||
http.Error(w, "Failed to process file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
fileHash := fmt.Sprintf("%x", h.Sum(nil))
|
||||
|
||||
// Create file record in the database
|
||||
fileRecord := models.File{
|
||||
OriginalName: handler.Filename,
|
||||
Hash: fileHash,
|
||||
HashAlgorithm: "sha256",
|
||||
}
|
||||
fileID, err := models.CreateFile(app.DB, fileRecord)
|
||||
fileHash, fileID, err := saveProfileImageUpload(app, file)
|
||||
if err != nil {
|
||||
log.Printf("Error creating file record: %v", err)
|
||||
http.Error(w, "Failed to save file information", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Save the file to disk
|
||||
fileExt := filepath.Ext(handler.Filename)
|
||||
newFileName := fmt.Sprintf("%d%s", fileID, fileExt)
|
||||
filePath := filepath.Join(app.Config.FileStorageDir, newFileName)
|
||||
|
||||
// Reset file pointer
|
||||
file.Seek(0, 0)
|
||||
|
||||
dst, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
log.Printf("Error creating file on disk: %v", err)
|
||||
if errors.Is(err, errInvalidProfileImage) || errors.Is(err, errProfileImageTooLarge) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
log.Printf("Error saving profile image: %v", err)
|
||||
http.Error(w, "Failed to save file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
if _, err := io.Copy(dst, file); err != nil {
|
||||
log.Printf("Error saving file to disk: %v", err)
|
||||
// Create file record in the database
|
||||
fileRecord := models.File{
|
||||
OriginalName: sanitizeOriginalFileName(handler.Filename),
|
||||
Hash: fileHash,
|
||||
HashAlgorithm: "sha256",
|
||||
}
|
||||
createdFileID, err := models.CreateFile(app.DB, fileRecord)
|
||||
if err != nil {
|
||||
log.Printf("Error creating file record: %v", err)
|
||||
http.Error(w, "Failed to save file information", http.StatusInternalServerError)
|
||||
_ = os.Remove(fileID)
|
||||
return
|
||||
}
|
||||
|
||||
finalPath := filepath.Join(app.Config.FileStorageDir, models.ProfileImageStorageName(createdFileID))
|
||||
if err := os.Rename(fileID, finalPath); err != nil {
|
||||
_ = os.Remove(fileID)
|
||||
_ = models.DeleteFileByID(app.DB, createdFileID)
|
||||
log.Printf("Error moving file on disk: %v", err)
|
||||
http.Error(w, "Failed to save file", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Update user's pfp_file_id
|
||||
err = models.UpdateUserPfp(app.DB, userID, fileID)
|
||||
err = models.UpdateUserPfp(app.DB, userID, createdFileID)
|
||||
if err != nil {
|
||||
_ = os.Remove(finalPath)
|
||||
_ = models.DeleteFileByID(app.DB, createdFileID)
|
||||
log.Printf("Error updating user pfp: %v", err)
|
||||
http.Error(w, "Failed to update profile", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else if err != nil && !errors.Is(err, http.ErrMissingFile) {
|
||||
var maxBytesErr *http.MaxBytesError
|
||||
if errors.As(err, &maxBytesErr) || strings.Contains(err.Error(), "request body too large") {
|
||||
http.Error(w, errProfileImageTooLarge.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
log.Printf("Error reading upload: %v", err)
|
||||
http.Error(w, "Failed to process file upload", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Update other profile fields
|
||||
|
|
@ -134,3 +148,67 @@ func ProfileEditHandler(app *App) http.HandlerFunc {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
errInvalidProfileImage = errors.New("Profile picture must be a PNG, JPEG, or GIF image")
|
||||
errProfileImageTooLarge = errors.New("Profile picture must be 2 MB or smaller")
|
||||
)
|
||||
|
||||
func saveProfileImageUpload(app *App, file multipart.File) (string, string, error) {
|
||||
limitedReader := io.LimitReader(file, maxProfileImageBytes+1)
|
||||
data, err := io.ReadAll(limitedReader)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if int64(len(data)) > maxProfileImageBytes {
|
||||
return "", "", errProfileImageTooLarge
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(data)
|
||||
if !isAllowedProfileImageType(contentType) {
|
||||
return "", "", errInvalidProfileImage
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return "", "", errInvalidProfileImage
|
||||
}
|
||||
|
||||
tmpFile, err := os.CreateTemp(app.Config.FileStorageDir, "pfp-*.png")
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer func() {
|
||||
_ = tmpFile.Close()
|
||||
}()
|
||||
|
||||
hash := sha256.Sum256(data)
|
||||
if err := png.Encode(tmpFile, img); err != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
_ = os.Remove(tmpFile.Name())
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", hash[:]), tmpFile.Name(), nil
|
||||
}
|
||||
|
||||
func sanitizeOriginalFileName(name string) string {
|
||||
base := filepath.Base(strings.TrimSpace(name))
|
||||
if base == "." || base == string(filepath.Separator) || base == "" {
|
||||
return "profile.png"
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func isAllowedProfileImageType(contentType string) bool {
|
||||
switch contentType {
|
||||
case "image/png", "image/jpeg", "image/gif":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,8 +2,14 @@ package models
|
|||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const ProfileImageExtension = ".png"
|
||||
|
||||
type File struct {
|
||||
ID int
|
||||
OriginalName string
|
||||
|
|
@ -25,6 +31,15 @@ func GetFileByID(db *sql.DB, id int64) (*File, error) {
|
|||
return file, nil
|
||||
}
|
||||
|
||||
func IsProfileImageFile(db *sql.DB, id int64) (bool, error) {
|
||||
var exists bool
|
||||
err := db.QueryRow("SELECT EXISTS(SELECT 1 FROM users WHERE pfp_file_id = ?)", id).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
func CreateFile(db *sql.DB, file File) (int64, error) {
|
||||
query := "INSERT INTO files (original_name, hash, hash_algorithm) VALUES (?, ?, ?)"
|
||||
result, err := db.Exec(query, file.OriginalName, file.Hash, file.HashAlgorithm)
|
||||
|
|
@ -33,3 +48,58 @@ func CreateFile(db *sql.DB, file File) (int64, error) {
|
|||
}
|
||||
return result.LastInsertId()
|
||||
}
|
||||
|
||||
func DeleteFileByID(db *sql.DB, id int64) error {
|
||||
_, err := db.Exec("DELETE FROM files WHERE id = ?", id)
|
||||
return err
|
||||
}
|
||||
|
||||
func ProfileImageStorageName(id int64) string {
|
||||
return fmt.Sprintf("%d%s", id, ProfileImageExtension)
|
||||
}
|
||||
|
||||
func LegacyImageStorageName(id int64, originalName string) (string, bool) {
|
||||
ext := strings.ToLower(filepath.Ext(originalName))
|
||||
if !allowedImageExtension(ext) {
|
||||
return "", false
|
||||
}
|
||||
return fmt.Sprintf("%d%s", id, ext), true
|
||||
}
|
||||
|
||||
func ProfileImageContentType(fileName string) string {
|
||||
switch strings.ToLower(filepath.Ext(fileName)) {
|
||||
case ".jpg", ".jpeg":
|
||||
return "image/jpeg"
|
||||
case ".gif":
|
||||
return "image/gif"
|
||||
default:
|
||||
return "image/png"
|
||||
}
|
||||
}
|
||||
|
||||
func ResolveStoredImagePath(storageDir string, file *File) (string, string, bool) {
|
||||
currentPath := filepath.Join(storageDir, ProfileImageStorageName(int64(file.ID)))
|
||||
if _, err := os.Stat(currentPath); err == nil {
|
||||
return currentPath, ProfileImageContentType(currentPath), true
|
||||
}
|
||||
|
||||
legacyName, ok := LegacyImageStorageName(int64(file.ID), file.OriginalName)
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
legacyPath := filepath.Join(storageDir, legacyName)
|
||||
if _, err := os.Stat(legacyPath); err == nil {
|
||||
return legacyPath, ProfileImageContentType(legacyPath), true
|
||||
}
|
||||
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func allowedImageExtension(ext string) bool {
|
||||
switch ext {
|
||||
case ".png", ".jpg", ".jpeg", ".gif":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,8 @@
|
|||
<label for="display_name">Display Name:</label>
|
||||
<input type="text" id="display_name" name="display_name" value="{{.User.DisplayName}}" maxlength="255"><br>
|
||||
<label for="pfp">Profile Picture:</label>
|
||||
<input type="file" id="pfp" name="pfp" accept="image/*"><br>
|
||||
<input type="file" id="pfp" name="pfp" accept="image/png,image/jpeg,image/gif"><br>
|
||||
<p style="margin-top: 0.25em; font-size: 0.9em; opacity: 0.8;">PNG, JPEG, or GIF only, up to 2 MB. Images are re-encoded before storage.</p>
|
||||
<label for="bio">Bio:</label>
|
||||
<textarea id="bio" name="bio" maxlength="500">{{.User.Bio}}</textarea><br>
|
||||
<input type="submit" value="Save">
|
||||
|
|
|
|||
Loading…
Reference in New Issue