All: Enhance session management and security features

Using proper auth + security on login.
jocadbz
Joca 2025-12-21 21:26:02 -03:00
parent 876ac33d1b
commit ef06bf160a
Signed by: jocadbz
GPG Key ID: B1836DCE2F50BDF7
7 changed files with 331 additions and 294 deletions

View File

@ -12,7 +12,7 @@ This, of course, assumes you have a decent understanding of Go.
### Configuration Files ### Configuration Files
* **config/config.json.sample**: * **config/config.json.sample**:
This file provides a template for the main application configuration. It defines critical parameters for the application to run, such as database credentials, domain, and file storage locations. This file provides a template for the main application configuration. It defines critical parameters for the application to run, such as database credentials, domain, file storage locations, and session security.
Example content: Example content:
{ {
"domain_name": "localhost", "domain_name": "localhost",
@ -21,8 +21,13 @@ This, of course, assumes you have a decent understanding of Go.
"db_password": "threadr_password", "db_password": "threadr_password",
"db_database": "threadr_db", "db_database": "threadr_db",
"db_svr_host": "localhost:3306", "db_svr_host": "localhost:3306",
"file_storage_dir": "files" "file_storage_dir": "files",
"session_secret": "change-me-to-32-byte-random",
"session_secure": false
} }
Notes:
- `session_secret` should be a 32+ byte random value. At runtime, it is overridden by the `THREADR_SESSION_SECRET` environment variable if present (recommended for production).
- `session_secure` controls whether cookies are marked `Secure`; set to `true` in HTTPS environments.
* **config/config.json**: * **config/config.json**:
The active configuration file, copied from `config.json.sample` and modified for the specific deployment. Contains sensitive information like database passwords. The active configuration file, copied from `config.json.sample` and modified for the specific deployment. Contains sensitive information like database passwords.
@ -55,19 +60,9 @@ This directory contains the HTTP handler functions that process incoming request
* **handlers/app.go**: * **handlers/app.go**:
Defines common application-wide structures and middleware: Defines common application-wide structures and middleware:
- `PageData`: A struct holding data passed to HTML templates for rendering common elements (title, navbar state, login status, cookie banner, base paths, current URL). - `PageData`: A struct holding data passed to HTML templates for rendering common elements (title, navbar state, login status, cookie banner, base paths, current URL).
- `Config`: A struct to unmarshal application configuration from `config.json`. - `Config`: A struct to unmarshal application configuration from `config.json` (and env overrides). Fields include DB settings, domain, file storage dir, `session_secret`, and `session_secure`.
Example JSON for `Config`:
{
"domain_name": "localhost",
"threadr_dir": "/threadr",
"db_username": "threadr_user",
"db_password": "threadr_password",
"db_database": "threadr_db",
"db_svr_host": "localhost:3306",
"file_storage_dir": "files"
}
- `App`: The main application context struct, holding pointers to the database connection, session store, configuration, and templates. - `App`: The main application context struct, holding pointers to the database connection, session store, configuration, and templates.
- `SessionMW`: Middleware to retrieve or create a new Gorilla session for each request, making the session available in the request context. - `SessionMW`: Middleware to retrieve or create a new Gorilla session for each request, applying secure cookie options (HttpOnly, SameSite=Lax, Secure configurable) and attaching the session to the request context.
- `RequireLoginMW`: Middleware to enforce user authentication for specific routes, redirecting unauthenticated users to the login page. - `RequireLoginMW`: Middleware to enforce user authentication for specific routes, redirecting unauthenticated users to the login page.
* **handlers/about.go**: * **handlers/about.go**:

View File

@ -5,5 +5,7 @@
"db_password": "threadr_password", "db_password": "threadr_password",
"db_database": "threadr_db", "db_database": "threadr_db",
"db_svr_host": "localhost:3306", "db_svr_host": "localhost:3306",
"file_storage_dir": "files" "file_storage_dir": "files",
"session_secret": "change-me-to-32-byte-random",
"session_secure": false
} }

5
go.mod
View File

@ -6,11 +6,12 @@ require (
github.com/go-sql-driver/mysql v1.9.0 github.com/go-sql-driver/mysql v1.9.0
github.com/gorilla/sessions v1.4.0 github.com/gorilla/sessions v1.4.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
golang.org/x/crypto v0.45.0
golang.org/x/term v0.37.0
) )
require ( require (
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect
golang.org/x/sys v0.33.0 // indirect golang.org/x/sys v0.38.0 // indirect
golang.org/x/term v0.32.0 // indirect
) )

10
go.sum
View File

@ -10,7 +10,9 @@ github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzq
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=

View File

@ -11,68 +11,86 @@ import (
) )
type PageData struct { type PageData struct {
Title string Title string
Navbar string Navbar string
LoggedIn bool LoggedIn bool
ShowCookieBanner bool ShowCookieBanner bool
BasePath string BasePath string
StaticPath string StaticPath string
CurrentURL string CurrentURL string
} }
type Config struct { type Config struct {
DomainName string `json:"domain_name"` DomainName string `json:"domain_name"`
ThreadrDir string `json:"threadr_dir"` ThreadrDir string `json:"threadr_dir"`
DBUsername string `json:"db_username"` DBUsername string `json:"db_username"`
DBPassword string `json:"db_password"` DBPassword string `json:"db_password"`
DBDatabase string `json:"db_database"` DBDatabase string `json:"db_database"`
DBServerHost string `json:"db_svr_host"` DBServerHost string `json:"db_svr_host"`
FileStorageDir string `json:"file_storage_dir"` FileStorageDir string `json:"file_storage_dir"`
SessionSecret string `json:"session_secret"`
SessionSecure bool `json:"session_secure"`
} }
type App struct { type App struct {
DB *sql.DB DB *sql.DB
Store *sessions.CookieStore Store *sessions.CookieStore
Config *Config Config *Config
Tmpl *template.Template Tmpl *template.Template
} }
func (app *App) SessionMW(next http.HandlerFunc) http.HandlerFunc { func (app *App) SessionMW(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session, err := app.Store.Get(r, "session-name") session, err := app.Store.Get(r, "session-name")
if err != nil { if err != nil {
session = sessions.NewSession(app.Store, "session-name") session = sessions.NewSession(app.Store, "session-name")
session.Options = &sessions.Options{ }
Path: "/",
MaxAge: 86400 * 30, // 30 days
HttpOnly: true,
}
}
ctx := context.WithValue(r.Context(), "session", session) // Enforce secure cookie options on every request.
r = r.WithContext(ctx) session.Options = app.cookieOptions(r)
next(w, r) ctx := context.WithValue(r.Context(), "session", session)
r = r.WithContext(ctx)
if err := session.Save(r, w); err != nil { next(w, r)
/*
Ok, so here's the thing if err := session.Save(r, w); err != nil {
Errors coming from this function here "can" be ignored. /*
They mostly come from errors while setting cookies, so in some Ok, so here's the thing
environments this will trigger a lot, but they are harmless. Errors coming from this function here "can" be ignored.
*/ They mostly come from errors while setting cookies, so in some
log.Printf("Error saving session in SessionMW: %v", err) environments this will trigger a lot, but they are harmless.
} */
} log.Printf("Error saving session in SessionMW: %v", err)
}
}
} }
func (app *App) RequireLoginMW(next http.HandlerFunc) http.HandlerFunc { func (app *App) RequireLoginMW(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session := r.Context().Value("session").(*sessions.Session) session := r.Context().Value("session").(*sessions.Session)
if _, ok := session.Values["user_id"].(int); !ok { if _, ok := session.Values["user_id"].(int); !ok {
http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=session", http.StatusFound) http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=session", http.StatusFound)
return return
} }
next(w, r) next(w, r)
} }
}
func (app *App) cookieOptions(r *http.Request) *sessions.Options {
secure := app.Config.SessionSecure
if r.TLS != nil {
secure = true
} // I dunno what I am doing honestly
options := &sessions.Options{
Path: app.Config.ThreadrDir + "/",
MaxAge: 86400 * 30,
HttpOnly: true,
Secure: secure,
SameSite: http.SameSiteLaxMode,
}
if app.Config.DomainName != "" {
options.Domain = app.Config.DomainName
}
return options
} }

View File

@ -1,68 +1,69 @@
package handlers package handlers
import ( import (
"database/sql" "database/sql"
"log" "log"
"net/http" "net/http"
"threadr/models" "threadr/models"
"github.com/gorilla/sessions"
"github.com/gorilla/sessions"
) )
func LoginHandler(app *App) http.HandlerFunc { func LoginHandler(app *App) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
session := r.Context().Value("session").(*sessions.Session) session := r.Context().Value("session").(*sessions.Session)
if r.Method == http.MethodPost { if r.Method == http.MethodPost {
username := r.FormValue("username") username := r.FormValue("username")
password := r.FormValue("password") password := r.FormValue("password")
user, err := models.GetUserByUsername(app.DB, username) user, err := models.GetUserByUsername(app.DB, username)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
log.Printf("Error fetching user in LoginHandler: %v", err) log.Printf("Error fetching user in LoginHandler: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return return
} }
if user == nil || !models.CheckPassword(password, user.AuthenticationSalt, user.AuthenticationAlgorithm, user.AuthenticationString) { if user == nil || !models.CheckPassword(password, user.AuthenticationSalt, user.AuthenticationAlgorithm, user.AuthenticationString) {
http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=invalid", http.StatusFound) http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=invalid", http.StatusFound)
return return
} }
session.Values["user_id"] = user.ID // Regenerate session to avoid fixation
session.Values["user_ip"] = r.RemoteAddr session.Options.MaxAge = -1
session.Values["user_agent"] = r.UserAgent() _ = session.Save(r, w)
session.Options = &sessions.Options{ session = sessions.NewSession(app.Store, "session-name")
Path: "/", session.Options = app.cookieOptions(r)
MaxAge: 86400 * 30, // 30 days session.Values["user_id"] = user.ID
HttpOnly: true, session.Values["user_ip"] = r.RemoteAddr
} session.Values["user_agent"] = r.UserAgent()
if err := session.Save(r, w); err != nil { if err := session.Save(r, w); err != nil {
log.Printf("Error saving session: %v", err) log.Printf("Error saving session: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return return
} }
http.Redirect(w, r, app.Config.ThreadrDir+"/userhome/", http.StatusFound) http.Redirect(w, r, app.Config.ThreadrDir+"/userhome/", http.StatusFound)
return return
} }
data := struct { data := struct {
PageData PageData
Error string Error string
}{ }{
PageData: PageData{ PageData: PageData{
Title: "ThreadR - Login", Title: "ThreadR - Login",
Navbar: "login", Navbar: "login",
LoggedIn: false, LoggedIn: false,
BasePath: app.Config.ThreadrDir, BasePath: app.Config.ThreadrDir,
StaticPath: app.Config.ThreadrDir + "/static", StaticPath: app.Config.ThreadrDir + "/static",
CurrentURL: r.URL.Path, CurrentURL: r.URL.Path,
}, },
Error: "", Error: "",
} }
if r.URL.Query().Get("error") == "invalid" { if r.URL.Query().Get("error") == "invalid" {
data.Error = "Invalid username or password" data.Error = "Invalid username or password"
} }
if err := app.Tmpl.ExecuteTemplate(w, "login", data); err != nil { if err := app.Tmpl.ExecuteTemplate(w, "login", data); err != nil {
log.Printf("Error executing template in LoginHandler: %v", err) log.Printf("Error executing template in LoginHandler: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError) http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return return
} }
} }
} }

342
main.go
View File

@ -22,19 +22,19 @@ import (
) )
func loadConfig(filename string) (*handlers.Config, error) { func loadConfig(filename string) (*handlers.Config, error) {
file, err := os.Open(filename) file, err := os.Open(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer file.Close() defer file.Close()
var config handlers.Config var config handlers.Config
err = json.NewDecoder(file).Decode(&config) err = json.NewDecoder(file).Decode(&config)
return &config, err return &config, err
} }
func createTablesIfNotExist(db *sql.DB) error { func createTablesIfNotExist(db *sql.DB) error {
// Create boards table // Create boards table
_, err := db.Exec(` _, err := db.Exec(`
CREATE TABLE boards ( CREATE TABLE boards (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(255) NOT NULL, name VARCHAR(255) NOT NULL,
@ -46,12 +46,12 @@ func createTablesIfNotExist(db *sql.DB) error {
color_scheme VARCHAR(255), color_scheme VARCHAR(255),
type VARCHAR(20) DEFAULT 'classic' NOT NULL type VARCHAR(20) DEFAULT 'classic' NOT NULL
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating boards table: %v", err) return fmt.Errorf("error creating boards table: %v", err)
} }
// Create threads table // Create threads table
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE threads ( CREATE TABLE threads (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
board_id INT NOT NULL, board_id INT NOT NULL,
@ -62,12 +62,12 @@ func createTablesIfNotExist(db *sql.DB) error {
accepted_answer_post_id INT, accepted_answer_post_id INT,
FOREIGN KEY (board_id) REFERENCES boards(id) FOREIGN KEY (board_id) REFERENCES boards(id)
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating threads table: %v", err) return fmt.Errorf("error creating threads table: %v", err)
} }
// Create posts table // Create posts table
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE posts ( CREATE TABLE posts (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
thread_id INT NOT NULL, thread_id INT NOT NULL,
@ -81,12 +81,12 @@ func createTablesIfNotExist(db *sql.DB) error {
reply_to INT DEFAULT -1, reply_to INT DEFAULT -1,
FOREIGN KEY (thread_id) REFERENCES threads(id) FOREIGN KEY (thread_id) REFERENCES threads(id)
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating posts table: %v", err) return fmt.Errorf("error creating posts table: %v", err)
} }
// Create likes table // Create likes table
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE likes ( CREATE TABLE likes (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
post_id INT NOT NULL, post_id INT NOT NULL,
@ -95,12 +95,12 @@ func createTablesIfNotExist(db *sql.DB) error {
UNIQUE KEY unique_like (post_id, user_id), UNIQUE KEY unique_like (post_id, user_id),
FOREIGN KEY (post_id) REFERENCES posts(id) FOREIGN KEY (post_id) REFERENCES posts(id)
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating likes table: %v", err) return fmt.Errorf("error creating likes table: %v", err)
} }
// Create board_permissions table // Create board_permissions table
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE board_permissions ( CREATE TABLE board_permissions (
user_id INT NOT NULL, user_id INT NOT NULL,
board_id INT NOT NULL, board_id INT NOT NULL,
@ -108,12 +108,12 @@ func createTablesIfNotExist(db *sql.DB) error {
PRIMARY KEY (user_id, board_id), PRIMARY KEY (user_id, board_id),
FOREIGN KEY (board_id) REFERENCES boards(id) FOREIGN KEY (board_id) REFERENCES boards(id)
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating board_permissions table: %v", err) return fmt.Errorf("error creating board_permissions table: %v", err)
} }
// Create notifications table // Create notifications table
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE notifications ( CREATE TABLE notifications (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
user_id INT NOT NULL, user_id INT NOT NULL,
@ -122,12 +122,12 @@ func createTablesIfNotExist(db *sql.DB) error {
is_read BOOLEAN DEFAULT FALSE, is_read BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating notifications table: %v", err) return fmt.Errorf("error creating notifications table: %v", err)
} }
// Create reactions table // Create reactions table
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE reactions ( CREATE TABLE reactions (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
post_id INT NOT NULL, post_id INT NOT NULL,
@ -135,12 +135,12 @@ func createTablesIfNotExist(db *sql.DB) error {
emoji VARCHAR(10) NOT NULL, emoji VARCHAR(10) NOT NULL,
FOREIGN KEY (post_id) REFERENCES posts(id) FOREIGN KEY (post_id) REFERENCES posts(id)
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating reactions table: %v", err) return fmt.Errorf("error creating reactions table: %v", err)
} }
// Create reposts table // Create reposts table
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE reposts ( CREATE TABLE reposts (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
thread_id INT NOT NULL, thread_id INT NOT NULL,
@ -150,12 +150,12 @@ func createTablesIfNotExist(db *sql.DB) error {
FOREIGN KEY (thread_id) REFERENCES threads(id), FOREIGN KEY (thread_id) REFERENCES threads(id),
FOREIGN KEY (board_id) REFERENCES boards(id) FOREIGN KEY (board_id) REFERENCES boards(id)
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating reposts table: %v", err) return fmt.Errorf("error creating reposts table: %v", err)
} }
// Create news table // Create news table
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE news ( CREATE TABLE news (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
title VARCHAR(255) NOT NULL, title VARCHAR(255) NOT NULL,
@ -163,12 +163,12 @@ func createTablesIfNotExist(db *sql.DB) error {
posted_by INT NOT NULL, posted_by INT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating news table: %v", err) return fmt.Errorf("error creating news table: %v", err)
} }
// Create chat_messages table // Create chat_messages table
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE chat_messages ( CREATE TABLE chat_messages (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
board_id INT NOT NULL, board_id INT NOT NULL,
@ -178,12 +178,12 @@ func createTablesIfNotExist(db *sql.DB) error {
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (board_id) REFERENCES boards(id) ON DELETE CASCADE FOREIGN KEY (board_id) REFERENCES boards(id) ON DELETE CASCADE
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating chat_messages table: %v", err) return fmt.Errorf("error creating chat_messages table: %v", err)
} }
// Create files table (Hope this does not break anything) // Create files table (Hope this does not break anything)
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE files ( CREATE TABLE files (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
original_name VARCHAR(255) NOT NULL, original_name VARCHAR(255) NOT NULL,
@ -191,13 +191,13 @@ func createTablesIfNotExist(db *sql.DB) error {
hash_algorithm VARCHAR(50) NOT NULL, hash_algorithm VARCHAR(50) NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating files table: %v", err) return fmt.Errorf("error creating files table: %v", err)
} }
// Create users table (KEEP THIS HERE!) // Create users table (KEEP THIS HERE!)
// Otherwise SQL bitches about the foreign key. // Otherwise SQL bitches about the foreign key.
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE users ( CREATE TABLE users (
id INT AUTO_INCREMENT PRIMARY KEY, id INT AUTO_INCREMENT PRIMARY KEY,
username VARCHAR(255) NOT NULL UNIQUE, username VARCHAR(255) NOT NULL UNIQUE,
@ -213,12 +213,12 @@ func createTablesIfNotExist(db *sql.DB) error {
permissions BIGINT DEFAULT 0, permissions BIGINT DEFAULT 0,
FOREIGN KEY (pfp_file_id) REFERENCES files(id) FOREIGN KEY (pfp_file_id) REFERENCES files(id)
)`) )`)
if err != nil { if err != nil {
return fmt.Errorf("error creating users table: %v", err) return fmt.Errorf("error creating users table: %v", err)
} }
log.Println("Database tables created.") log.Println("Database tables created.")
return nil return nil
} }
func ensureAdminUser(db *sql.DB) error { func ensureAdminUser(db *sql.DB) error {
@ -284,110 +284,128 @@ func ensureAdminUser(db *sql.DB) error {
} }
func main() { func main() {
// Define command-line flag for initialization // Define command-line flag for initialization
initialize := flag.Bool("initialize", false, "Initialize database tables and admin user") initialize := flag.Bool("initialize", false, "Initialize database tables and admin user")
flag.BoolVar(initialize, "i", false, "Short for --initialize") flag.BoolVar(initialize, "i", false, "Short for --initialize")
flag.Parse() flag.Parse()
config, err := loadConfig("config/config.json") config, err := loadConfig("config/config.json")
if err != nil { if err != nil {
log.Fatal("Error loading config:", err) log.Fatal("Error loading config:", err)
} }
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s", config.DBUsername, config.DBPassword, config.DBServerHost, config.DBDatabase) // Allow environment variable override for the session secret to avoid hardcoding secrets in files.
db, err := sql.Open("mysql", dsn) if envSecret := os.Getenv("THREADR_SESSION_SECRET"); envSecret != "" {
if err != nil { config.SessionSecret = envSecret
log.Fatal("Error connecting to database:", err) }
} if len(config.SessionSecret) < 32 {
defer db.Close() log.Fatal("Session secret must be at least 32 bytes; set THREADR_SESSION_SECRET or session_secret in config")
}
// Create the file directory dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s", config.DBUsername, config.DBPassword, config.DBServerHost, config.DBDatabase)
// TODO: Wouldn't this be better suited on the initialize function? db, err := sql.Open("mysql", dsn)
// Discussion pending. if err != nil {
err = os.MkdirAll(config.FileStorageDir, 0700) log.Fatal("Error connecting to database:", err)
if err != nil { }
log.Fatal("Error creating file storage directory:", err) defer db.Close()
}
// Perform initialization if the flag is set // Create the file directory
if *initialize { // TODO: Wouldn't this be better suited on the initialize function?
log.Println("Initializing database...") // Discussion pending.
err = createTablesIfNotExist(db) err = os.MkdirAll(config.FileStorageDir, 0700)
if err != nil { if err != nil {
log.Fatal("Error creating database tables:", err) log.Fatal("Error creating file storage directory:", err)
} }
err = ensureAdminUser(db) // Perform initialization if the flag is set
if err != nil { if *initialize {
log.Fatal("Error ensuring admin user:", err) log.Println("Initializing database...")
} err = createTablesIfNotExist(db)
if err != nil {
log.Fatal("Error creating database tables:", err)
}
log.Println("Initialization completed successfully. Exiting.") err = ensureAdminUser(db)
return if err != nil {
} log.Fatal("Error ensuring admin user:", err)
}
// Normal startup (without automatic table creation) log.Println("Initialization completed successfully. Exiting.")
log.Println("Starting ThreadR server...") return
}
dir, err := os.Getwd() // Normal startup (without automatic table creation)
if err != nil { log.Println("Starting ThreadR server...")
log.Fatal("Error getting working directory:", err)
}
// Parse partial templates dir, err := os.Getwd()
tmpl := template.Must(template.ParseFiles( if err != nil {
filepath.Join(dir, "templates/partials/navbar.html"), log.Fatal("Error getting working directory:", err)
filepath.Join(dir, "templates/partials/cookie_banner.html"), }
))
// Parse page-specific templates with unique names // Parse partial templates
tmpl, err = tmpl.ParseFiles( tmpl := template.Must(template.ParseFiles(
filepath.Join(dir, "templates/pages/about.html"), filepath.Join(dir, "templates/partials/navbar.html"),
filepath.Join(dir, "templates/pages/board.html"), filepath.Join(dir, "templates/partials/cookie_banner.html"),
filepath.Join(dir, "templates/pages/boards.html"), ))
filepath.Join(dir, "templates/pages/home.html"),
filepath.Join(dir, "templates/pages/login.html"),
filepath.Join(dir, "templates/pages/news.html"),
filepath.Join(dir, "templates/pages/profile.html"),
filepath.Join(dir, "templates/pages/profile_edit.html"),
filepath.Join(dir, "templates/pages/signup.html"),
filepath.Join(dir, "templates/pages/thread.html"),
filepath.Join(dir, "templates/pages/userhome.html"),
filepath.Join(dir, "templates/pages/chat.html"),
)
if err != nil {
log.Fatal("Error parsing page templates:", err)
}
store := sessions.NewCookieStore([]byte("secret-key")) // Replace with secure key in production // Parse page-specific templates with unique names
tmpl, err = tmpl.ParseFiles(
filepath.Join(dir, "templates/pages/about.html"),
filepath.Join(dir, "templates/pages/board.html"),
filepath.Join(dir, "templates/pages/boards.html"),
filepath.Join(dir, "templates/pages/home.html"),
filepath.Join(dir, "templates/pages/login.html"),
filepath.Join(dir, "templates/pages/news.html"),
filepath.Join(dir, "templates/pages/profile.html"),
filepath.Join(dir, "templates/pages/profile_edit.html"),
filepath.Join(dir, "templates/pages/signup.html"),
filepath.Join(dir, "templates/pages/thread.html"),
filepath.Join(dir, "templates/pages/userhome.html"),
filepath.Join(dir, "templates/pages/chat.html"),
)
if err != nil {
log.Fatal("Error parsing page templates:", err)
}
app := &handlers.App{ store := sessions.NewCookieStore([]byte(config.SessionSecret))
DB: db, store.Options = &sessions.Options{
Store: store, Path: config.ThreadrDir + "/",
Config: config, MaxAge: 86400 * 30,
Tmpl: tmpl, HttpOnly: true,
} Secure: config.SessionSecure,
SameSite: http.SameSiteLaxMode,
}
if config.DomainName != "" {
store.Options.Domain = config.DomainName
}
fs := http.FileServer(http.Dir("static")) app := &handlers.App{
http.Handle(config.ThreadrDir+"/static/", http.StripPrefix(config.ThreadrDir+"/static/", fs)) DB: db,
Store: store,
Config: config,
Tmpl: tmpl,
}
http.HandleFunc(config.ThreadrDir+"/", app.SessionMW(handlers.HomeHandler(app))) fs := http.FileServer(http.Dir("static"))
http.HandleFunc(config.ThreadrDir+"/login/", app.SessionMW(handlers.LoginHandler(app))) http.Handle(config.ThreadrDir+"/static/", http.StripPrefix(config.ThreadrDir+"/static/", fs))
http.HandleFunc(config.ThreadrDir+"/logout/", app.SessionMW(handlers.LogoutHandler(app)))
http.HandleFunc(config.ThreadrDir+"/userhome/", app.SessionMW(app.RequireLoginMW(handlers.UserHomeHandler(app))))
http.HandleFunc(config.ThreadrDir+"/boards/", app.SessionMW(handlers.BoardsHandler(app)))
http.HandleFunc(config.ThreadrDir+"/board/", app.SessionMW(handlers.BoardHandler(app)))
http.HandleFunc(config.ThreadrDir+"/thread/", app.SessionMW(handlers.ThreadHandler(app)))
http.HandleFunc(config.ThreadrDir+"/about/", app.SessionMW(handlers.AboutHandler(app)))
http.HandleFunc(config.ThreadrDir+"/profile/", app.SessionMW(app.RequireLoginMW(handlers.ProfileHandler(app))))
http.HandleFunc(config.ThreadrDir+"/profile/edit/", app.SessionMW(app.RequireLoginMW(handlers.ProfileEditHandler(app))))
http.HandleFunc(config.ThreadrDir+"/like/", app.SessionMW(app.RequireLoginMW(handlers.LikeHandler(app))))
http.HandleFunc(config.ThreadrDir+"/news/", app.SessionMW(handlers.NewsHandler(app)))
http.HandleFunc(config.ThreadrDir+"/signup/", app.SessionMW(handlers.SignupHandler(app)))
http.HandleFunc(config.ThreadrDir+"/accept_cookie/", app.SessionMW(handlers.AcceptCookieHandler(app)))
http.HandleFunc(config.ThreadrDir+"/chat/", app.SessionMW(app.RequireLoginMW(handlers.ChatHandler(app))))
http.HandleFunc(config.ThreadrDir+"/file", app.SessionMW(handlers.FileHandler(app)))
log.Println("Server starting on :8080") http.HandleFunc(config.ThreadrDir+"/", app.SessionMW(handlers.HomeHandler(app)))
log.Fatal(http.ListenAndServe(":8080", nil)) http.HandleFunc(config.ThreadrDir+"/login/", app.SessionMW(handlers.LoginHandler(app)))
http.HandleFunc(config.ThreadrDir+"/logout/", app.SessionMW(handlers.LogoutHandler(app)))
http.HandleFunc(config.ThreadrDir+"/userhome/", app.SessionMW(app.RequireLoginMW(handlers.UserHomeHandler(app))))
http.HandleFunc(config.ThreadrDir+"/boards/", app.SessionMW(handlers.BoardsHandler(app)))
http.HandleFunc(config.ThreadrDir+"/board/", app.SessionMW(handlers.BoardHandler(app)))
http.HandleFunc(config.ThreadrDir+"/thread/", app.SessionMW(handlers.ThreadHandler(app)))
http.HandleFunc(config.ThreadrDir+"/about/", app.SessionMW(handlers.AboutHandler(app)))
http.HandleFunc(config.ThreadrDir+"/profile/", app.SessionMW(app.RequireLoginMW(handlers.ProfileHandler(app))))
http.HandleFunc(config.ThreadrDir+"/profile/edit/", app.SessionMW(app.RequireLoginMW(handlers.ProfileEditHandler(app))))
http.HandleFunc(config.ThreadrDir+"/like/", app.SessionMW(app.RequireLoginMW(handlers.LikeHandler(app))))
http.HandleFunc(config.ThreadrDir+"/news/", app.SessionMW(handlers.NewsHandler(app)))
http.HandleFunc(config.ThreadrDir+"/signup/", app.SessionMW(handlers.SignupHandler(app)))
http.HandleFunc(config.ThreadrDir+"/accept_cookie/", app.SessionMW(handlers.AcceptCookieHandler(app)))
http.HandleFunc(config.ThreadrDir+"/chat/", app.SessionMW(app.RequireLoginMW(handlers.ChatHandler(app))))
http.HandleFunc(config.ThreadrDir+"/file", app.SessionMW(handlers.FileHandler(app)))
log.Println("Server starting on :8080")
log.Fatal(http.ListenAndServe(":8080", nil))
} }