From ef06bf160ab08565e03ab3bc380543a3243eb26a Mon Sep 17 00:00:00 2001 From: Jocadbz Date: Sun, 21 Dec 2025 21:26:02 -0300 Subject: [PATCH] All: Enhance session management and security features Using proper auth + security on login. --- DOCUMENTATION.md | 23 +-- config/config.json.sample | 4 +- go.mod | 5 +- go.sum | 10 +- handlers/app.go | 118 +++++++------ handlers/login.go | 121 +++++++------- main.go | 344 ++++++++++++++++++++------------------ 7 files changed, 331 insertions(+), 294 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index eca44b3..dee5bd2 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -12,7 +12,7 @@ This, of course, assumes you have a decent understanding of Go. ### Configuration Files * **config/config.json.sample**: - This file provides a template for the main application configuration. It defines critical parameters for the application to run, such as database credentials, domain, and file storage locations. + This file provides a template for the main application configuration. It defines critical parameters for the application to run, such as database credentials, domain, file storage locations, and session security. Example content: { "domain_name": "localhost", @@ -21,8 +21,13 @@ This, of course, assumes you have a decent understanding of Go. "db_password": "threadr_password", "db_database": "threadr_db", "db_svr_host": "localhost:3306", - "file_storage_dir": "files" + "file_storage_dir": "files", + "session_secret": "change-me-to-32-byte-random", + "session_secure": false } + Notes: + - `session_secret` should be a 32+ byte random value. At runtime, it is overridden by the `THREADR_SESSION_SECRET` environment variable if present (recommended for production). + - `session_secure` controls whether cookies are marked `Secure`; set to `true` in HTTPS environments. * **config/config.json**: The active configuration file, copied from `config.json.sample` and modified for the specific deployment. Contains sensitive information like database passwords. @@ -55,19 +60,9 @@ This directory contains the HTTP handler functions that process incoming request * **handlers/app.go**: Defines common application-wide structures and middleware: - `PageData`: A struct holding data passed to HTML templates for rendering common elements (title, navbar state, login status, cookie banner, base paths, current URL). - - `Config`: A struct to unmarshal application configuration from `config.json`. - Example JSON for `Config`: - { - "domain_name": "localhost", - "threadr_dir": "/threadr", - "db_username": "threadr_user", - "db_password": "threadr_password", - "db_database": "threadr_db", - "db_svr_host": "localhost:3306", - "file_storage_dir": "files" - } + - `Config`: A struct to unmarshal application configuration from `config.json` (and env overrides). Fields include DB settings, domain, file storage dir, `session_secret`, and `session_secure`. - `App`: The main application context struct, holding pointers to the database connection, session store, configuration, and templates. - - `SessionMW`: Middleware to retrieve or create a new Gorilla session for each request, making the session available in the request context. + - `SessionMW`: Middleware to retrieve or create a new Gorilla session for each request, applying secure cookie options (HttpOnly, SameSite=Lax, Secure configurable) and attaching the session to the request context. - `RequireLoginMW`: Middleware to enforce user authentication for specific routes, redirecting unauthenticated users to the login page. * **handlers/about.go**: diff --git a/config/config.json.sample b/config/config.json.sample index bb96f0a..ba2aa31 100644 --- a/config/config.json.sample +++ b/config/config.json.sample @@ -5,5 +5,7 @@ "db_password": "threadr_password", "db_database": "threadr_db", "db_svr_host": "localhost:3306", - "file_storage_dir": "files" + "file_storage_dir": "files", + "session_secret": "change-me-to-32-byte-random", + "session_secure": false } diff --git a/go.mod b/go.mod index f39d814..cde1f60 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,12 @@ require ( github.com/go-sql-driver/mysql v1.9.0 github.com/gorilla/sessions v1.4.0 github.com/gorilla/websocket v1.5.0 + golang.org/x/crypto v0.45.0 + golang.org/x/term v0.37.0 ) require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/gorilla/securecookie v1.1.2 // indirect - golang.org/x/sys v0.33.0 // indirect - golang.org/x/term v0.32.0 // indirect + golang.org/x/sys v0.38.0 // indirect ) diff --git a/go.sum b/go.sum index 498ed04..a04c4c2 100644 --- a/go.sum +++ b/go.sum @@ -10,7 +10,9 @@ github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzq github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= -golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/term v0.32.0 h1:DR4lr0TjUs3epypdhTOkMmuF5CDFJ/8pOnbzMZPQ7bg= -golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= diff --git a/handlers/app.go b/handlers/app.go index 7f54ed0..69abd92 100644 --- a/handlers/app.go +++ b/handlers/app.go @@ -11,68 +11,86 @@ import ( ) type PageData struct { - Title string - Navbar string - LoggedIn bool - ShowCookieBanner bool - BasePath string - StaticPath string - CurrentURL string + Title string + Navbar string + LoggedIn bool + ShowCookieBanner bool + BasePath string + StaticPath string + CurrentURL string } type Config struct { - DomainName string `json:"domain_name"` - ThreadrDir string `json:"threadr_dir"` - DBUsername string `json:"db_username"` - DBPassword string `json:"db_password"` - DBDatabase string `json:"db_database"` - DBServerHost string `json:"db_svr_host"` - FileStorageDir string `json:"file_storage_dir"` + DomainName string `json:"domain_name"` + ThreadrDir string `json:"threadr_dir"` + DBUsername string `json:"db_username"` + DBPassword string `json:"db_password"` + DBDatabase string `json:"db_database"` + DBServerHost string `json:"db_svr_host"` + FileStorageDir string `json:"file_storage_dir"` + SessionSecret string `json:"session_secret"` + SessionSecure bool `json:"session_secure"` } type App struct { - DB *sql.DB - Store *sessions.CookieStore - Config *Config - Tmpl *template.Template + DB *sql.DB + Store *sessions.CookieStore + Config *Config + Tmpl *template.Template } func (app *App) SessionMW(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - session, err := app.Store.Get(r, "session-name") - if err != nil { - session = sessions.NewSession(app.Store, "session-name") - session.Options = &sessions.Options{ - Path: "/", - MaxAge: 86400 * 30, // 30 days - HttpOnly: true, - } - } + return func(w http.ResponseWriter, r *http.Request) { + session, err := app.Store.Get(r, "session-name") + if err != nil { + session = sessions.NewSession(app.Store, "session-name") + } - ctx := context.WithValue(r.Context(), "session", session) - r = r.WithContext(ctx) + // Enforce secure cookie options on every request. + session.Options = app.cookieOptions(r) - next(w, r) + ctx := context.WithValue(r.Context(), "session", session) + r = r.WithContext(ctx) - if err := session.Save(r, w); err != nil { - /* - Ok, so here's the thing - Errors coming from this function here "can" be ignored. - They mostly come from errors while setting cookies, so in some - environments this will trigger a lot, but they are harmless. - */ - log.Printf("Error saving session in SessionMW: %v", err) - } - } + next(w, r) + + if err := session.Save(r, w); err != nil { + /* + Ok, so here's the thing + Errors coming from this function here "can" be ignored. + They mostly come from errors while setting cookies, so in some + environments this will trigger a lot, but they are harmless. + */ + log.Printf("Error saving session in SessionMW: %v", err) + } + } } func (app *App) RequireLoginMW(next http.HandlerFunc) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - session := r.Context().Value("session").(*sessions.Session) - if _, ok := session.Values["user_id"].(int); !ok { - http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=session", http.StatusFound) - return - } - next(w, r) - } -} \ No newline at end of file + return func(w http.ResponseWriter, r *http.Request) { + session := r.Context().Value("session").(*sessions.Session) + if _, ok := session.Values["user_id"].(int); !ok { + http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=session", http.StatusFound) + return + } + next(w, r) + } +} + +func (app *App) cookieOptions(r *http.Request) *sessions.Options { + secure := app.Config.SessionSecure + if r.TLS != nil { + secure = true + } // I dunno what I am doing honestly + options := &sessions.Options{ + Path: app.Config.ThreadrDir + "/", + MaxAge: 86400 * 30, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + } + if app.Config.DomainName != "" { + options.Domain = app.Config.DomainName + } + return options +} diff --git a/handlers/login.go b/handlers/login.go index 11b19ec..9f1f258 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -1,68 +1,69 @@ package handlers import ( - "database/sql" - "log" - "net/http" - "threadr/models" - "github.com/gorilla/sessions" + "database/sql" + "log" + "net/http" + "threadr/models" + + "github.com/gorilla/sessions" ) func LoginHandler(app *App) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - session := r.Context().Value("session").(*sessions.Session) - if r.Method == http.MethodPost { - username := r.FormValue("username") - password := r.FormValue("password") - user, err := models.GetUserByUsername(app.DB, username) - if err != nil && err != sql.ErrNoRows { - log.Printf("Error fetching user in LoginHandler: %v", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - if user == nil || !models.CheckPassword(password, user.AuthenticationSalt, user.AuthenticationAlgorithm, user.AuthenticationString) { - http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=invalid", http.StatusFound) - return - } - session.Values["user_id"] = user.ID - session.Values["user_ip"] = r.RemoteAddr - session.Values["user_agent"] = r.UserAgent() - session.Options = &sessions.Options{ - Path: "/", - MaxAge: 86400 * 30, // 30 days - HttpOnly: true, - } - if err := session.Save(r, w); err != nil { - log.Printf("Error saving session: %v", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - http.Redirect(w, r, app.Config.ThreadrDir+"/userhome/", http.StatusFound) - return - } + return func(w http.ResponseWriter, r *http.Request) { + session := r.Context().Value("session").(*sessions.Session) + if r.Method == http.MethodPost { + username := r.FormValue("username") + password := r.FormValue("password") + user, err := models.GetUserByUsername(app.DB, username) + if err != nil && err != sql.ErrNoRows { + log.Printf("Error fetching user in LoginHandler: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + if user == nil || !models.CheckPassword(password, user.AuthenticationSalt, user.AuthenticationAlgorithm, user.AuthenticationString) { + http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=invalid", http.StatusFound) + return + } + // Regenerate session to avoid fixation + session.Options.MaxAge = -1 + _ = session.Save(r, w) + session = sessions.NewSession(app.Store, "session-name") + session.Options = app.cookieOptions(r) + session.Values["user_id"] = user.ID + session.Values["user_ip"] = r.RemoteAddr + session.Values["user_agent"] = r.UserAgent() + if err := session.Save(r, w); err != nil { + log.Printf("Error saving session: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + http.Redirect(w, r, app.Config.ThreadrDir+"/userhome/", http.StatusFound) + return + } - data := struct { - PageData - Error string - }{ - PageData: PageData{ - Title: "ThreadR - Login", - Navbar: "login", - LoggedIn: false, - BasePath: app.Config.ThreadrDir, - StaticPath: app.Config.ThreadrDir + "/static", - CurrentURL: r.URL.Path, - }, - Error: "", - } - if r.URL.Query().Get("error") == "invalid" { - data.Error = "Invalid username or password" - } + data := struct { + PageData + Error string + }{ + PageData: PageData{ + Title: "ThreadR - Login", + Navbar: "login", + LoggedIn: false, + BasePath: app.Config.ThreadrDir, + StaticPath: app.Config.ThreadrDir + "/static", + CurrentURL: r.URL.Path, + }, + Error: "", + } + if r.URL.Query().Get("error") == "invalid" { + data.Error = "Invalid username or password" + } - if err := app.Tmpl.ExecuteTemplate(w, "login", data); err != nil { - log.Printf("Error executing template in LoginHandler: %v", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return - } - } -} \ No newline at end of file + if err := app.Tmpl.ExecuteTemplate(w, "login", data); err != nil { + log.Printf("Error executing template in LoginHandler: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + } +} diff --git a/main.go b/main.go index 06ac7eb..5ed803c 100644 --- a/main.go +++ b/main.go @@ -22,19 +22,19 @@ import ( ) func loadConfig(filename string) (*handlers.Config, error) { - file, err := os.Open(filename) - if err != nil { - return nil, err - } - defer file.Close() - var config handlers.Config - err = json.NewDecoder(file).Decode(&config) - return &config, err + file, err := os.Open(filename) + if err != nil { + return nil, err + } + defer file.Close() + var config handlers.Config + err = json.NewDecoder(file).Decode(&config) + return &config, err } func createTablesIfNotExist(db *sql.DB) error { - // Create boards table - _, err := db.Exec(` + // Create boards table + _, err := db.Exec(` CREATE TABLE boards ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(255) NOT NULL, @@ -46,12 +46,12 @@ func createTablesIfNotExist(db *sql.DB) error { color_scheme VARCHAR(255), type VARCHAR(20) DEFAULT 'classic' NOT NULL )`) - if err != nil { - return fmt.Errorf("error creating boards table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating boards table: %v", err) + } - // Create threads table - _, err = db.Exec(` + // Create threads table + _, err = db.Exec(` CREATE TABLE threads ( id INT AUTO_INCREMENT PRIMARY KEY, board_id INT NOT NULL, @@ -62,12 +62,12 @@ func createTablesIfNotExist(db *sql.DB) error { accepted_answer_post_id INT, FOREIGN KEY (board_id) REFERENCES boards(id) )`) - if err != nil { - return fmt.Errorf("error creating threads table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating threads table: %v", err) + } - // Create posts table - _, err = db.Exec(` + // Create posts table + _, err = db.Exec(` CREATE TABLE posts ( id INT AUTO_INCREMENT PRIMARY KEY, thread_id INT NOT NULL, @@ -81,12 +81,12 @@ func createTablesIfNotExist(db *sql.DB) error { reply_to INT DEFAULT -1, FOREIGN KEY (thread_id) REFERENCES threads(id) )`) - if err != nil { - return fmt.Errorf("error creating posts table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating posts table: %v", err) + } - // Create likes table - _, err = db.Exec(` + // Create likes table + _, err = db.Exec(` CREATE TABLE likes ( id INT AUTO_INCREMENT PRIMARY KEY, post_id INT NOT NULL, @@ -95,12 +95,12 @@ func createTablesIfNotExist(db *sql.DB) error { UNIQUE KEY unique_like (post_id, user_id), FOREIGN KEY (post_id) REFERENCES posts(id) )`) - if err != nil { - return fmt.Errorf("error creating likes table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating likes table: %v", err) + } - // Create board_permissions table - _, err = db.Exec(` + // Create board_permissions table + _, err = db.Exec(` CREATE TABLE board_permissions ( user_id INT NOT NULL, board_id INT NOT NULL, @@ -108,12 +108,12 @@ func createTablesIfNotExist(db *sql.DB) error { PRIMARY KEY (user_id, board_id), FOREIGN KEY (board_id) REFERENCES boards(id) )`) - if err != nil { - return fmt.Errorf("error creating board_permissions table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating board_permissions table: %v", err) + } - // Create notifications table - _, err = db.Exec(` + // Create notifications table + _, err = db.Exec(` CREATE TABLE notifications ( id INT AUTO_INCREMENT PRIMARY KEY, user_id INT NOT NULL, @@ -122,12 +122,12 @@ func createTablesIfNotExist(db *sql.DB) error { is_read BOOLEAN DEFAULT FALSE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )`) - if err != nil { - return fmt.Errorf("error creating notifications table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating notifications table: %v", err) + } - // Create reactions table - _, err = db.Exec(` + // Create reactions table + _, err = db.Exec(` CREATE TABLE reactions ( id INT AUTO_INCREMENT PRIMARY KEY, post_id INT NOT NULL, @@ -135,12 +135,12 @@ func createTablesIfNotExist(db *sql.DB) error { emoji VARCHAR(10) NOT NULL, FOREIGN KEY (post_id) REFERENCES posts(id) )`) - if err != nil { - return fmt.Errorf("error creating reactions table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating reactions table: %v", err) + } - // Create reposts table - _, err = db.Exec(` + // Create reposts table + _, err = db.Exec(` CREATE TABLE reposts ( id INT AUTO_INCREMENT PRIMARY KEY, thread_id INT NOT NULL, @@ -150,12 +150,12 @@ func createTablesIfNotExist(db *sql.DB) error { FOREIGN KEY (thread_id) REFERENCES threads(id), FOREIGN KEY (board_id) REFERENCES boards(id) )`) - if err != nil { - return fmt.Errorf("error creating reposts table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating reposts table: %v", err) + } - // Create news table - _, err = db.Exec(` + // Create news table + _, err = db.Exec(` CREATE TABLE news ( id INT AUTO_INCREMENT PRIMARY KEY, title VARCHAR(255) NOT NULL, @@ -163,12 +163,12 @@ func createTablesIfNotExist(db *sql.DB) error { posted_by INT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )`) - if err != nil { - return fmt.Errorf("error creating news table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating news table: %v", err) + } - // Create chat_messages table - _, err = db.Exec(` + // Create chat_messages table + _, err = db.Exec(` CREATE TABLE chat_messages ( id INT AUTO_INCREMENT PRIMARY KEY, board_id INT NOT NULL, @@ -178,12 +178,12 @@ func createTablesIfNotExist(db *sql.DB) error { timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (board_id) REFERENCES boards(id) ON DELETE CASCADE )`) - if err != nil { - return fmt.Errorf("error creating chat_messages table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating chat_messages table: %v", err) + } - // Create files table (Hope this does not break anything) - _, err = db.Exec(` + // Create files table (Hope this does not break anything) + _, err = db.Exec(` CREATE TABLE files ( id INT AUTO_INCREMENT PRIMARY KEY, original_name VARCHAR(255) NOT NULL, @@ -191,13 +191,13 @@ func createTablesIfNotExist(db *sql.DB) error { hash_algorithm VARCHAR(50) NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )`) - if err != nil { - return fmt.Errorf("error creating files table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating files table: %v", err) + } - // Create users table (KEEP THIS HERE!) - // Otherwise SQL bitches about the foreign key. - _, err = db.Exec(` + // Create users table (KEEP THIS HERE!) + // Otherwise SQL bitches about the foreign key. + _, err = db.Exec(` CREATE TABLE users ( id INT AUTO_INCREMENT PRIMARY KEY, username VARCHAR(255) NOT NULL UNIQUE, @@ -213,12 +213,12 @@ func createTablesIfNotExist(db *sql.DB) error { permissions BIGINT DEFAULT 0, FOREIGN KEY (pfp_file_id) REFERENCES files(id) )`) - if err != nil { - return fmt.Errorf("error creating users table: %v", err) - } + if err != nil { + return fmt.Errorf("error creating users table: %v", err) + } - log.Println("Database tables created.") - return nil + log.Println("Database tables created.") + return nil } func ensureAdminUser(db *sql.DB) error { @@ -284,110 +284,128 @@ func ensureAdminUser(db *sql.DB) error { } func main() { - // Define command-line flag for initialization - initialize := flag.Bool("initialize", false, "Initialize database tables and admin user") - flag.BoolVar(initialize, "i", false, "Short for --initialize") - flag.Parse() + // Define command-line flag for initialization + initialize := flag.Bool("initialize", false, "Initialize database tables and admin user") + flag.BoolVar(initialize, "i", false, "Short for --initialize") + flag.Parse() - config, err := loadConfig("config/config.json") - if err != nil { - log.Fatal("Error loading config:", err) - } + config, err := loadConfig("config/config.json") + if err != nil { + log.Fatal("Error loading config:", err) + } - dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s", config.DBUsername, config.DBPassword, config.DBServerHost, config.DBDatabase) - db, err := sql.Open("mysql", dsn) - if err != nil { - log.Fatal("Error connecting to database:", err) - } - defer db.Close() + // Allow environment variable override for the session secret to avoid hardcoding secrets in files. + if envSecret := os.Getenv("THREADR_SESSION_SECRET"); envSecret != "" { + config.SessionSecret = envSecret + } + if len(config.SessionSecret) < 32 { + log.Fatal("Session secret must be at least 32 bytes; set THREADR_SESSION_SECRET or session_secret in config") + } - // Create the file directory - // TODO: Wouldn't this be better suited on the initialize function? - // Discussion pending. - err = os.MkdirAll(config.FileStorageDir, 0700) - if err != nil { - log.Fatal("Error creating file storage directory:", err) - } + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s", config.DBUsername, config.DBPassword, config.DBServerHost, config.DBDatabase) + db, err := sql.Open("mysql", dsn) + if err != nil { + log.Fatal("Error connecting to database:", err) + } + defer db.Close() - // Perform initialization if the flag is set - if *initialize { - log.Println("Initializing database...") - err = createTablesIfNotExist(db) - if err != nil { - log.Fatal("Error creating database tables:", err) - } + // Create the file directory + // TODO: Wouldn't this be better suited on the initialize function? + // Discussion pending. + err = os.MkdirAll(config.FileStorageDir, 0700) + if err != nil { + log.Fatal("Error creating file storage directory:", err) + } - err = ensureAdminUser(db) - if err != nil { - log.Fatal("Error ensuring admin user:", err) - } + // Perform initialization if the flag is set + if *initialize { + log.Println("Initializing database...") + err = createTablesIfNotExist(db) + if err != nil { + log.Fatal("Error creating database tables:", err) + } - log.Println("Initialization completed successfully. Exiting.") - return - } + err = ensureAdminUser(db) + if err != nil { + log.Fatal("Error ensuring admin user:", err) + } - // Normal startup (without automatic table creation) - log.Println("Starting ThreadR server...") + log.Println("Initialization completed successfully. Exiting.") + return + } - dir, err := os.Getwd() - if err != nil { - log.Fatal("Error getting working directory:", err) - } + // Normal startup (without automatic table creation) + log.Println("Starting ThreadR server...") - // Parse partial templates - tmpl := template.Must(template.ParseFiles( - filepath.Join(dir, "templates/partials/navbar.html"), - filepath.Join(dir, "templates/partials/cookie_banner.html"), - )) + dir, err := os.Getwd() + if err != nil { + log.Fatal("Error getting working directory:", err) + } - // Parse page-specific templates with unique names - tmpl, err = tmpl.ParseFiles( - filepath.Join(dir, "templates/pages/about.html"), - filepath.Join(dir, "templates/pages/board.html"), - filepath.Join(dir, "templates/pages/boards.html"), - filepath.Join(dir, "templates/pages/home.html"), - filepath.Join(dir, "templates/pages/login.html"), - filepath.Join(dir, "templates/pages/news.html"), - filepath.Join(dir, "templates/pages/profile.html"), - filepath.Join(dir, "templates/pages/profile_edit.html"), - filepath.Join(dir, "templates/pages/signup.html"), - filepath.Join(dir, "templates/pages/thread.html"), - filepath.Join(dir, "templates/pages/userhome.html"), - filepath.Join(dir, "templates/pages/chat.html"), - ) - if err != nil { - log.Fatal("Error parsing page templates:", err) - } + // Parse partial templates + tmpl := template.Must(template.ParseFiles( + filepath.Join(dir, "templates/partials/navbar.html"), + filepath.Join(dir, "templates/partials/cookie_banner.html"), + )) - store := sessions.NewCookieStore([]byte("secret-key")) // Replace with secure key in production + // Parse page-specific templates with unique names + tmpl, err = tmpl.ParseFiles( + filepath.Join(dir, "templates/pages/about.html"), + filepath.Join(dir, "templates/pages/board.html"), + filepath.Join(dir, "templates/pages/boards.html"), + filepath.Join(dir, "templates/pages/home.html"), + filepath.Join(dir, "templates/pages/login.html"), + filepath.Join(dir, "templates/pages/news.html"), + filepath.Join(dir, "templates/pages/profile.html"), + filepath.Join(dir, "templates/pages/profile_edit.html"), + filepath.Join(dir, "templates/pages/signup.html"), + filepath.Join(dir, "templates/pages/thread.html"), + filepath.Join(dir, "templates/pages/userhome.html"), + filepath.Join(dir, "templates/pages/chat.html"), + ) + if err != nil { + log.Fatal("Error parsing page templates:", err) + } - app := &handlers.App{ - DB: db, - Store: store, - Config: config, - Tmpl: tmpl, - } + store := sessions.NewCookieStore([]byte(config.SessionSecret)) + store.Options = &sessions.Options{ + Path: config.ThreadrDir + "/", + MaxAge: 86400 * 30, + HttpOnly: true, + Secure: config.SessionSecure, + SameSite: http.SameSiteLaxMode, + } + if config.DomainName != "" { + store.Options.Domain = config.DomainName + } - fs := http.FileServer(http.Dir("static")) - http.Handle(config.ThreadrDir+"/static/", http.StripPrefix(config.ThreadrDir+"/static/", fs)) + app := &handlers.App{ + DB: db, + Store: store, + Config: config, + Tmpl: tmpl, + } - http.HandleFunc(config.ThreadrDir+"/", app.SessionMW(handlers.HomeHandler(app))) - http.HandleFunc(config.ThreadrDir+"/login/", app.SessionMW(handlers.LoginHandler(app))) - http.HandleFunc(config.ThreadrDir+"/logout/", app.SessionMW(handlers.LogoutHandler(app))) - http.HandleFunc(config.ThreadrDir+"/userhome/", app.SessionMW(app.RequireLoginMW(handlers.UserHomeHandler(app)))) - http.HandleFunc(config.ThreadrDir+"/boards/", app.SessionMW(handlers.BoardsHandler(app))) - http.HandleFunc(config.ThreadrDir+"/board/", app.SessionMW(handlers.BoardHandler(app))) - http.HandleFunc(config.ThreadrDir+"/thread/", app.SessionMW(handlers.ThreadHandler(app))) - http.HandleFunc(config.ThreadrDir+"/about/", app.SessionMW(handlers.AboutHandler(app))) - http.HandleFunc(config.ThreadrDir+"/profile/", app.SessionMW(app.RequireLoginMW(handlers.ProfileHandler(app)))) - http.HandleFunc(config.ThreadrDir+"/profile/edit/", app.SessionMW(app.RequireLoginMW(handlers.ProfileEditHandler(app)))) - http.HandleFunc(config.ThreadrDir+"/like/", app.SessionMW(app.RequireLoginMW(handlers.LikeHandler(app)))) - http.HandleFunc(config.ThreadrDir+"/news/", app.SessionMW(handlers.NewsHandler(app))) - http.HandleFunc(config.ThreadrDir+"/signup/", app.SessionMW(handlers.SignupHandler(app))) - http.HandleFunc(config.ThreadrDir+"/accept_cookie/", app.SessionMW(handlers.AcceptCookieHandler(app))) - http.HandleFunc(config.ThreadrDir+"/chat/", app.SessionMW(app.RequireLoginMW(handlers.ChatHandler(app)))) - http.HandleFunc(config.ThreadrDir+"/file", app.SessionMW(handlers.FileHandler(app))) + fs := http.FileServer(http.Dir("static")) + http.Handle(config.ThreadrDir+"/static/", http.StripPrefix(config.ThreadrDir+"/static/", fs)) - log.Println("Server starting on :8080") - log.Fatal(http.ListenAndServe(":8080", nil)) -} \ No newline at end of file + http.HandleFunc(config.ThreadrDir+"/", app.SessionMW(handlers.HomeHandler(app))) + http.HandleFunc(config.ThreadrDir+"/login/", app.SessionMW(handlers.LoginHandler(app))) + http.HandleFunc(config.ThreadrDir+"/logout/", app.SessionMW(handlers.LogoutHandler(app))) + http.HandleFunc(config.ThreadrDir+"/userhome/", app.SessionMW(app.RequireLoginMW(handlers.UserHomeHandler(app)))) + http.HandleFunc(config.ThreadrDir+"/boards/", app.SessionMW(handlers.BoardsHandler(app))) + http.HandleFunc(config.ThreadrDir+"/board/", app.SessionMW(handlers.BoardHandler(app))) + http.HandleFunc(config.ThreadrDir+"/thread/", app.SessionMW(handlers.ThreadHandler(app))) + http.HandleFunc(config.ThreadrDir+"/about/", app.SessionMW(handlers.AboutHandler(app))) + http.HandleFunc(config.ThreadrDir+"/profile/", app.SessionMW(app.RequireLoginMW(handlers.ProfileHandler(app)))) + http.HandleFunc(config.ThreadrDir+"/profile/edit/", app.SessionMW(app.RequireLoginMW(handlers.ProfileEditHandler(app)))) + http.HandleFunc(config.ThreadrDir+"/like/", app.SessionMW(app.RequireLoginMW(handlers.LikeHandler(app)))) + http.HandleFunc(config.ThreadrDir+"/news/", app.SessionMW(handlers.NewsHandler(app))) + http.HandleFunc(config.ThreadrDir+"/signup/", app.SessionMW(handlers.SignupHandler(app))) + http.HandleFunc(config.ThreadrDir+"/accept_cookie/", app.SessionMW(handlers.AcceptCookieHandler(app))) + http.HandleFunc(config.ThreadrDir+"/chat/", app.SessionMW(app.RequireLoginMW(handlers.ChatHandler(app)))) + http.HandleFunc(config.ThreadrDir+"/file", app.SessionMW(handlers.FileHandler(app))) + + log.Println("Server starting on :8080") + log.Fatal(http.ListenAndServe(":8080", nil)) +}