diff --git a/go.mod b/go.mod
index 74365a9..188efa0 100644
--- a/go.mod
+++ b/go.mod
@@ -5,6 +5,7 @@ go 1.24.0
require (
github.com/go-sql-driver/mysql v1.9.0
github.com/gorilla/sessions v1.4.0
+ github.com/gorilla/websocket v1.5.0
)
require (
diff --git a/go.sum b/go.sum
index 2b8450f..8d7d4ce 100644
--- a/go.sum
+++ b/go.sum
@@ -8,3 +8,5 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ=
github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik=
+github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
+github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
diff --git a/handlers/app.go b/handlers/app.go
index 062c4cb..ccf0060 100644
--- a/handlers/app.go
+++ b/handlers/app.go
@@ -41,14 +41,22 @@ func (app *App) SessionMW(next http.HandlerFunc) http.HandlerFunc {
session, err := app.Store.Get(r, "session-name")
if err != nil {
session = sessions.NewSession(app.Store, "session-name")
+ session.Options = &sessions.Options{
+ Path: "/",
+ MaxAge: 86400 * 30, // 30 days
+ HttpOnly: true,
+ }
}
if _, ok := session.Values["user_id"].(int); ok {
- if session.Values["user_ip"] != r.RemoteAddr || session.Values["user_agent"] != r.UserAgent() {
- session.Values = make(map[interface{}]interface{})
- session.Options.MaxAge = -1
- session.Save(r, w)
- http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=session", http.StatusFound)
- return
+ // Skip IP and User-Agent check for WebSocket connections
+ if r.URL.Query().Get("ws") != "true" {
+ if session.Values["user_ip"] != r.RemoteAddr || session.Values["user_agent"] != r.UserAgent() {
+ session.Values = make(map[interface{}]interface{})
+ session.Options.MaxAge = -1
+ session.Save(r, w)
+ http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=session", http.StatusFound)
+ return
+ }
}
ctx := context.WithValue(r.Context(), "session", session)
r = r.WithContext(ctx)
diff --git a/handlers/chat.go b/handlers/chat.go
new file mode 100644
index 0000000..ce229c1
--- /dev/null
+++ b/handlers/chat.go
@@ -0,0 +1,188 @@
+package handlers
+
+import (
+ "encoding/json"
+ "log"
+ "net/http"
+ "sync"
+ "threadr/models"
+ "github.com/gorilla/sessions"
+ "github.com/gorilla/websocket"
+)
+
+var upgrader = websocket.Upgrader{
+ ReadBufferSize: 1024,
+ WriteBufferSize: 1024,
+ CheckOrigin: func(r *http.Request) bool {
+ return true // Allow all origins for now; restrict in production
+ },
+}
+
+// ChatHub manages WebSocket connections and broadcasts messages
+type ChatHub struct {
+ clients map[*websocket.Conn]int // Map of connections to user IDs
+ broadcast chan []byte
+ register chan *websocket.Conn
+ unregister chan *websocket.Conn
+ mutex sync.Mutex
+}
+
+func NewChatHub() *ChatHub {
+ return &ChatHub{
+ clients: make(map[*websocket.Conn]int),
+ broadcast: make(chan []byte),
+ register: make(chan *websocket.Conn),
+ unregister: make(chan *websocket.Conn),
+ }
+}
+
+func (h *ChatHub) Run() {
+ for {
+ select {
+ case client := <-h.register:
+ h.mutex.Lock()
+ h.clients[client] = 0 // UserID set later
+ h.mutex.Unlock()
+ case client := <-h.unregister:
+ h.mutex.Lock()
+ delete(h.clients, client)
+ h.mutex.Unlock()
+ client.Close()
+ case message := <-h.broadcast:
+ h.mutex.Lock()
+ for client := range h.clients {
+ err := client.WriteMessage(websocket.TextMessage, message)
+ if err != nil {
+ log.Printf("Error broadcasting message: %v", err)
+ client.Close()
+ delete(h.clients, client)
+ }
+ }
+ h.mutex.Unlock()
+ }
+ }
+}
+
+var hub = NewChatHub()
+
+func init() {
+ go hub.Run()
+}
+
+func ChatHandler(app *App) http.HandlerFunc {
+ return func(w http.ResponseWriter, r *http.Request) {
+ session := r.Context().Value("session").(*sessions.Session)
+ userID, ok := session.Values["user_id"].(int)
+ if !ok {
+ http.Redirect(w, r, app.Config.ThreadrDir+"/login/", http.StatusFound)
+ return
+ }
+ cookie, _ := r.Cookie("threadr_cookie_banner")
+
+ if r.URL.Query().Get("ws") == "true" {
+ // Handle WebSocket connection
+ ws, err := upgrader.Upgrade(w, r, nil)
+ if err != nil {
+ log.Printf("Error upgrading to WebSocket: %v", err)
+ return
+ }
+ hub.register <- ws
+ hub.mutex.Lock()
+ hub.clients[ws] = userID
+ hub.mutex.Unlock()
+
+ defer func() {
+ hub.unregister <- ws
+ }()
+
+ for {
+ _, msg, err := ws.ReadMessage()
+ if err != nil {
+ log.Printf("Error reading WebSocket message: %v", err)
+ break
+ }
+ var chatMsg struct {
+ Type string `json:"type"`
+ Content string `json:"content"`
+ ReplyTo int `json:"replyTo"`
+ }
+ if err := json.Unmarshal(msg, &chatMsg); err != nil {
+ log.Printf("Error unmarshaling message: %v", err)
+ continue
+ }
+
+ if chatMsg.Type == "message" {
+ msgObj := models.ChatMessage{
+ UserID: userID,
+ Content: chatMsg.Content,
+ ReplyTo: chatMsg.ReplyTo,
+ }
+ if err := models.CreateChatMessage(app.DB, msgObj); err != nil {
+ log.Printf("Error saving chat message: %v", err)
+ continue
+ }
+ // Fetch the saved message with timestamp and user details
+ var msgID int
+ app.DB.QueryRow("SELECT LAST_INSERT_ID()").Scan(&msgID)
+ savedMsg, err := models.GetChatMessageByID(app.DB, msgID)
+ if err != nil {
+ log.Printf("Error fetching saved message: %v", err)
+ continue
+ }
+ response, _ := json.Marshal(savedMsg)
+ hub.broadcast <- response
+ }
+ }
+ return
+ }
+
+ if r.URL.Query().Get("autocomplete") == "true" {
+ // Handle autocomplete for mentions
+ prefix := r.URL.Query().Get("prefix")
+ usernames, err := models.GetUsernamesMatching(app.DB, prefix)
+ if err != nil {
+ log.Printf("Error fetching usernames for autocomplete: %v", err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ response, _ := json.Marshal(usernames)
+ w.Header().Set("Content-Type", "application/json")
+ w.Write(response)
+ return
+ }
+
+ // Render chat page
+ messages, err := models.GetRecentChatMessages(app.DB, 50)
+ if err != nil {
+ log.Printf("Error fetching chat messages: %v", err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+
+ // Reverse messages to show oldest first
+ for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
+ messages[i], messages[j] = messages[j], messages[i]
+ }
+
+ data := struct {
+ PageData
+ Messages []models.ChatMessage
+ }{
+ PageData: PageData{
+ Title: "ThreadR - Chat",
+ Navbar: "chat",
+ LoggedIn: true,
+ ShowCookieBanner: cookie == nil || cookie.Value != "accepted",
+ BasePath: app.Config.ThreadrDir,
+ StaticPath: app.Config.ThreadrDir + "/static",
+ CurrentURL: r.URL.Path,
+ },
+ Messages: messages,
+ }
+ if err := app.Tmpl.ExecuteTemplate(w, "chat", data); err != nil {
+ log.Printf("Error executing template in ChatHandler: %v", err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ }
+}
\ No newline at end of file
diff --git a/handlers/login.go b/handlers/login.go
index b337bf1..11b19ec 100644
--- a/handlers/login.go
+++ b/handlers/login.go
@@ -27,6 +27,11 @@ func LoginHandler(app *App) http.HandlerFunc {
session.Values["user_id"] = user.ID
session.Values["user_ip"] = r.RemoteAddr
session.Values["user_agent"] = r.UserAgent()
+ session.Options = &sessions.Options{
+ Path: "/",
+ MaxAge: 86400 * 30, // 30 days
+ HttpOnly: true,
+ }
if err := session.Save(r, w); err != nil {
log.Printf("Error saving session: %v", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
diff --git a/main.go b/main.go
index e72bc94..987a516 100644
--- a/main.go
+++ b/main.go
@@ -182,6 +182,19 @@ func createTablesIfNotExist(db *sql.DB) error {
return fmt.Errorf("error creating news table: %v", err)
}
+ // Create chat_messages table
+ _, err = db.Exec(`
+ CREATE TABLE IF NOT EXISTS chat_messages (
+ id INT AUTO_INCREMENT PRIMARY KEY,
+ user_id INT NOT NULL,
+ content TEXT NOT NULL,
+ reply_to INT DEFAULT -1,
+ timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
+ )`)
+ if err != nil {
+ return fmt.Errorf("error creating chat_messages table: %v", err)
+ }
+
log.Println("Database tables created or already exist")
return nil
}
@@ -281,6 +294,7 @@ func main() {
filepath.Join(dir, "templates/pages/signup.html"),
filepath.Join(dir, "templates/pages/thread.html"),
filepath.Join(dir, "templates/pages/userhome.html"),
+ filepath.Join(dir, "templates/pages/chat.html"),
)
if err != nil {
log.Fatal("Error parsing page templates:", err)
@@ -312,6 +326,7 @@ func main() {
http.HandleFunc(config.ThreadrDir+"/news/", app.SessionMW(handlers.NewsHandler(app)))
http.HandleFunc(config.ThreadrDir+"/signup/", app.SessionMW(handlers.SignupHandler(app)))
http.HandleFunc(config.ThreadrDir+"/accept_cookie/", app.SessionMW(handlers.AcceptCookieHandler(app)))
+ http.HandleFunc(config.ThreadrDir+"/chat/", app.SessionMW(app.RequireLoginMW(handlers.ChatHandler(app))))
log.Println("Server starting on :8080")
log.Fatal(http.ListenAndServe(":8080", nil))
diff --git a/models/chat.go b/models/chat.go
new file mode 100644
index 0000000..7e33565
--- /dev/null
+++ b/models/chat.go
@@ -0,0 +1,132 @@
+package models
+
+import (
+ "database/sql"
+ "time"
+)
+
+type ChatMessage struct {
+ ID int
+ UserID int
+ Content string
+ ReplyTo int // -1 if not a reply
+ Timestamp time.Time
+ Username string // For display, fetched from user
+ PfpURL string // For display, fetched from user
+ Mentions []string // List of mentioned usernames
+}
+
+func CreateChatMessage(db *sql.DB, msg ChatMessage) error {
+ query := "INSERT INTO chat_messages (user_id, content, reply_to, timestamp) VALUES (?, ?, ?, NOW())"
+ _, err := db.Exec(query, msg.UserID, msg.Content, msg.ReplyTo)
+ return err
+}
+
+func GetRecentChatMessages(db *sql.DB, limit int) ([]ChatMessage, error) {
+ query := `
+ SELECT cm.id, cm.user_id, cm.content, cm.reply_to, cm.timestamp, u.username, u.pfp_url
+ FROM chat_messages cm
+ JOIN users u ON cm.user_id = u.id
+ ORDER BY cm.timestamp DESC
+ LIMIT ?`
+ rows, err := db.Query(query, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var messages []ChatMessage
+ for rows.Next() {
+ var msg ChatMessage
+ var timestampStr string
+ var pfpURL sql.NullString
+ err := rows.Scan(&msg.ID, &msg.UserID, &msg.Content, &msg.ReplyTo, ×tampStr, &msg.Username, &pfpURL)
+ if err != nil {
+ return nil, err
+ }
+ msg.Timestamp, err = time.Parse("2006-01-02 15:04:05", timestampStr)
+ if err != nil {
+ msg.Timestamp = time.Time{}
+ }
+ if pfpURL.Valid {
+ msg.PfpURL = pfpURL.String
+ }
+ // Parse mentions from content (simple @username detection)
+ msg.Mentions = extractMentions(msg.Content)
+ messages = append(messages, msg)
+ }
+ return messages, nil
+}
+
+func GetChatMessageByID(db *sql.DB, id int) (*ChatMessage, error) {
+ query := `
+ SELECT cm.id, cm.user_id, cm.content, cm.reply_to, cm.timestamp, u.username, u.pfp_url
+ FROM chat_messages cm
+ JOIN users u ON cm.user_id = u.id
+ WHERE cm.id = ?`
+ row := db.QueryRow(query, id)
+ var msg ChatMessage
+ var timestampStr string
+ var pfpURL sql.NullString
+ err := row.Scan(&msg.ID, &msg.UserID, &msg.Content, &msg.ReplyTo, ×tampStr, &msg.Username, &pfpURL)
+ if err == sql.ErrNoRows {
+ return nil, nil
+ }
+ if err != nil {
+ return nil, err
+ }
+ msg.Timestamp, err = time.Parse("2006-01-02 15:04:05", timestampStr)
+ if err != nil {
+ msg.Timestamp = time.Time{}
+ }
+ if pfpURL.Valid {
+ msg.PfpURL = pfpURL.String
+ }
+ msg.Mentions = extractMentions(msg.Content)
+ return &msg, nil
+}
+
+func GetUsernamesMatching(db *sql.DB, prefix string) ([]string, error) {
+ query := "SELECT username FROM users WHERE username LIKE ? LIMIT 10"
+ rows, err := db.Query(query, prefix+"%")
+ if err != nil {
+ return nil, err
+ }
+ defer rows.Close()
+
+ var usernames []string
+ for rows.Next() {
+ var username string
+ if err := rows.Scan(&username); err != nil {
+ return nil, err
+ }
+ usernames = append(usernames, username)
+ }
+ return usernames, nil
+}
+
+// Simple utility to extract mentions from content
+func extractMentions(content string) []string {
+ var mentions []string
+ var currentMention string
+ inMention := false
+
+ for _, char := range content {
+ if char == '@' {
+ inMention = true
+ currentMention = "@"
+ } else if inMention && (char == ' ' || char == '\n' || char == '\t') {
+ if len(currentMention) > 1 {
+ mentions = append(mentions, currentMention)
+ }
+ inMention = false
+ currentMention = ""
+ } else if inMention {
+ currentMention += string(char)
+ }
+ }
+ if inMention && len(currentMention) > 1 {
+ mentions = append(mentions, currentMention)
+ }
+ return mentions
+}
\ No newline at end of file
diff --git a/templates/pages/chat.html b/templates/pages/chat.html
new file mode 100644
index 0000000..382947f
--- /dev/null
+++ b/templates/pages/chat.html
@@ -0,0 +1,383 @@
+{{define "chat"}}
+
+
+
+ {{.Title}}
+
+
+
+
+ {{template "navbar" .}}
+
+
+
+
+ {{range .Messages}}
+
+
+ {{if gt .ReplyTo 0}}
+
Replying to message #{{.ReplyTo}}
+ {{end}}
+
{{.Content | html}}
+
+
+ {{end}}
+
+
+
+
+
+
+
+
+ {{template "cookie_banner" .}}
+
+
+
+{{end}}
\ No newline at end of file
diff --git a/templates/partials/navbar.html b/templates/partials/navbar.html
index 0f6af1d..3d85f26 100644
--- a/templates/partials/navbar.html
+++ b/templates/partials/navbar.html
@@ -4,6 +4,7 @@
{{if .LoggedIn}}
User Home
Profile
+ Chat
Logout
{{else}}
Login