diff --git a/go.mod b/go.mod index 74365a9..188efa0 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24.0 require ( github.com/go-sql-driver/mysql v1.9.0 github.com/gorilla/sessions v1.4.0 + github.com/gorilla/websocket v1.5.0 ) require ( diff --git a/go.sum b/go.sum index 2b8450f..8d7d4ce 100644 --- a/go.sum +++ b/go.sum @@ -8,3 +8,5 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= +github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= +github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= diff --git a/handlers/app.go b/handlers/app.go index 062c4cb..ccf0060 100644 --- a/handlers/app.go +++ b/handlers/app.go @@ -41,14 +41,22 @@ func (app *App) SessionMW(next http.HandlerFunc) http.HandlerFunc { session, err := app.Store.Get(r, "session-name") if err != nil { session = sessions.NewSession(app.Store, "session-name") + session.Options = &sessions.Options{ + Path: "/", + MaxAge: 86400 * 30, // 30 days + HttpOnly: true, + } } if _, ok := session.Values["user_id"].(int); ok { - if session.Values["user_ip"] != r.RemoteAddr || session.Values["user_agent"] != r.UserAgent() { - session.Values = make(map[interface{}]interface{}) - session.Options.MaxAge = -1 - session.Save(r, w) - http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=session", http.StatusFound) - return + // Skip IP and User-Agent check for WebSocket connections + if r.URL.Query().Get("ws") != "true" { + if session.Values["user_ip"] != r.RemoteAddr || session.Values["user_agent"] != r.UserAgent() { + session.Values = make(map[interface{}]interface{}) + session.Options.MaxAge = -1 + session.Save(r, w) + http.Redirect(w, r, app.Config.ThreadrDir+"/login/?error=session", http.StatusFound) + return + } } ctx := context.WithValue(r.Context(), "session", session) r = r.WithContext(ctx) diff --git a/handlers/chat.go b/handlers/chat.go new file mode 100644 index 0000000..ce229c1 --- /dev/null +++ b/handlers/chat.go @@ -0,0 +1,188 @@ +package handlers + +import ( + "encoding/json" + "log" + "net/http" + "sync" + "threadr/models" + "github.com/gorilla/sessions" + "github.com/gorilla/websocket" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // Allow all origins for now; restrict in production + }, +} + +// ChatHub manages WebSocket connections and broadcasts messages +type ChatHub struct { + clients map[*websocket.Conn]int // Map of connections to user IDs + broadcast chan []byte + register chan *websocket.Conn + unregister chan *websocket.Conn + mutex sync.Mutex +} + +func NewChatHub() *ChatHub { + return &ChatHub{ + clients: make(map[*websocket.Conn]int), + broadcast: make(chan []byte), + register: make(chan *websocket.Conn), + unregister: make(chan *websocket.Conn), + } +} + +func (h *ChatHub) Run() { + for { + select { + case client := <-h.register: + h.mutex.Lock() + h.clients[client] = 0 // UserID set later + h.mutex.Unlock() + case client := <-h.unregister: + h.mutex.Lock() + delete(h.clients, client) + h.mutex.Unlock() + client.Close() + case message := <-h.broadcast: + h.mutex.Lock() + for client := range h.clients { + err := client.WriteMessage(websocket.TextMessage, message) + if err != nil { + log.Printf("Error broadcasting message: %v", err) + client.Close() + delete(h.clients, client) + } + } + h.mutex.Unlock() + } + } +} + +var hub = NewChatHub() + +func init() { + go hub.Run() +} + +func ChatHandler(app *App) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + session := r.Context().Value("session").(*sessions.Session) + userID, ok := session.Values["user_id"].(int) + if !ok { + http.Redirect(w, r, app.Config.ThreadrDir+"/login/", http.StatusFound) + return + } + cookie, _ := r.Cookie("threadr_cookie_banner") + + if r.URL.Query().Get("ws") == "true" { + // Handle WebSocket connection + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("Error upgrading to WebSocket: %v", err) + return + } + hub.register <- ws + hub.mutex.Lock() + hub.clients[ws] = userID + hub.mutex.Unlock() + + defer func() { + hub.unregister <- ws + }() + + for { + _, msg, err := ws.ReadMessage() + if err != nil { + log.Printf("Error reading WebSocket message: %v", err) + break + } + var chatMsg struct { + Type string `json:"type"` + Content string `json:"content"` + ReplyTo int `json:"replyTo"` + } + if err := json.Unmarshal(msg, &chatMsg); err != nil { + log.Printf("Error unmarshaling message: %v", err) + continue + } + + if chatMsg.Type == "message" { + msgObj := models.ChatMessage{ + UserID: userID, + Content: chatMsg.Content, + ReplyTo: chatMsg.ReplyTo, + } + if err := models.CreateChatMessage(app.DB, msgObj); err != nil { + log.Printf("Error saving chat message: %v", err) + continue + } + // Fetch the saved message with timestamp and user details + var msgID int + app.DB.QueryRow("SELECT LAST_INSERT_ID()").Scan(&msgID) + savedMsg, err := models.GetChatMessageByID(app.DB, msgID) + if err != nil { + log.Printf("Error fetching saved message: %v", err) + continue + } + response, _ := json.Marshal(savedMsg) + hub.broadcast <- response + } + } + return + } + + if r.URL.Query().Get("autocomplete") == "true" { + // Handle autocomplete for mentions + prefix := r.URL.Query().Get("prefix") + usernames, err := models.GetUsernamesMatching(app.DB, prefix) + if err != nil { + log.Printf("Error fetching usernames for autocomplete: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + response, _ := json.Marshal(usernames) + w.Header().Set("Content-Type", "application/json") + w.Write(response) + return + } + + // Render chat page + messages, err := models.GetRecentChatMessages(app.DB, 50) + if err != nil { + log.Printf("Error fetching chat messages: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Reverse messages to show oldest first + for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 { + messages[i], messages[j] = messages[j], messages[i] + } + + data := struct { + PageData + Messages []models.ChatMessage + }{ + PageData: PageData{ + Title: "ThreadR - Chat", + Navbar: "chat", + LoggedIn: true, + ShowCookieBanner: cookie == nil || cookie.Value != "accepted", + BasePath: app.Config.ThreadrDir, + StaticPath: app.Config.ThreadrDir + "/static", + CurrentURL: r.URL.Path, + }, + Messages: messages, + } + if err := app.Tmpl.ExecuteTemplate(w, "chat", data); err != nil { + log.Printf("Error executing template in ChatHandler: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + } +} \ No newline at end of file diff --git a/handlers/login.go b/handlers/login.go index b337bf1..11b19ec 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -27,6 +27,11 @@ func LoginHandler(app *App) http.HandlerFunc { session.Values["user_id"] = user.ID session.Values["user_ip"] = r.RemoteAddr session.Values["user_agent"] = r.UserAgent() + session.Options = &sessions.Options{ + Path: "/", + MaxAge: 86400 * 30, // 30 days + HttpOnly: true, + } if err := session.Save(r, w); err != nil { log.Printf("Error saving session: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) diff --git a/main.go b/main.go index e72bc94..987a516 100644 --- a/main.go +++ b/main.go @@ -182,6 +182,19 @@ func createTablesIfNotExist(db *sql.DB) error { return fmt.Errorf("error creating news table: %v", err) } + // Create chat_messages table + _, err = db.Exec(` + CREATE TABLE IF NOT EXISTS chat_messages ( + id INT AUTO_INCREMENT PRIMARY KEY, + user_id INT NOT NULL, + content TEXT NOT NULL, + reply_to INT DEFAULT -1, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP + )`) + if err != nil { + return fmt.Errorf("error creating chat_messages table: %v", err) + } + log.Println("Database tables created or already exist") return nil } @@ -281,6 +294,7 @@ func main() { filepath.Join(dir, "templates/pages/signup.html"), filepath.Join(dir, "templates/pages/thread.html"), filepath.Join(dir, "templates/pages/userhome.html"), + filepath.Join(dir, "templates/pages/chat.html"), ) if err != nil { log.Fatal("Error parsing page templates:", err) @@ -312,6 +326,7 @@ func main() { http.HandleFunc(config.ThreadrDir+"/news/", app.SessionMW(handlers.NewsHandler(app))) http.HandleFunc(config.ThreadrDir+"/signup/", app.SessionMW(handlers.SignupHandler(app))) http.HandleFunc(config.ThreadrDir+"/accept_cookie/", app.SessionMW(handlers.AcceptCookieHandler(app))) + http.HandleFunc(config.ThreadrDir+"/chat/", app.SessionMW(app.RequireLoginMW(handlers.ChatHandler(app)))) log.Println("Server starting on :8080") log.Fatal(http.ListenAndServe(":8080", nil)) diff --git a/models/chat.go b/models/chat.go new file mode 100644 index 0000000..7e33565 --- /dev/null +++ b/models/chat.go @@ -0,0 +1,132 @@ +package models + +import ( + "database/sql" + "time" +) + +type ChatMessage struct { + ID int + UserID int + Content string + ReplyTo int // -1 if not a reply + Timestamp time.Time + Username string // For display, fetched from user + PfpURL string // For display, fetched from user + Mentions []string // List of mentioned usernames +} + +func CreateChatMessage(db *sql.DB, msg ChatMessage) error { + query := "INSERT INTO chat_messages (user_id, content, reply_to, timestamp) VALUES (?, ?, ?, NOW())" + _, err := db.Exec(query, msg.UserID, msg.Content, msg.ReplyTo) + return err +} + +func GetRecentChatMessages(db *sql.DB, limit int) ([]ChatMessage, error) { + query := ` + SELECT cm.id, cm.user_id, cm.content, cm.reply_to, cm.timestamp, u.username, u.pfp_url + FROM chat_messages cm + JOIN users u ON cm.user_id = u.id + ORDER BY cm.timestamp DESC + LIMIT ?` + rows, err := db.Query(query, limit) + if err != nil { + return nil, err + } + defer rows.Close() + + var messages []ChatMessage + for rows.Next() { + var msg ChatMessage + var timestampStr string + var pfpURL sql.NullString + err := rows.Scan(&msg.ID, &msg.UserID, &msg.Content, &msg.ReplyTo, ×tampStr, &msg.Username, &pfpURL) + if err != nil { + return nil, err + } + msg.Timestamp, err = time.Parse("2006-01-02 15:04:05", timestampStr) + if err != nil { + msg.Timestamp = time.Time{} + } + if pfpURL.Valid { + msg.PfpURL = pfpURL.String + } + // Parse mentions from content (simple @username detection) + msg.Mentions = extractMentions(msg.Content) + messages = append(messages, msg) + } + return messages, nil +} + +func GetChatMessageByID(db *sql.DB, id int) (*ChatMessage, error) { + query := ` + SELECT cm.id, cm.user_id, cm.content, cm.reply_to, cm.timestamp, u.username, u.pfp_url + FROM chat_messages cm + JOIN users u ON cm.user_id = u.id + WHERE cm.id = ?` + row := db.QueryRow(query, id) + var msg ChatMessage + var timestampStr string + var pfpURL sql.NullString + err := row.Scan(&msg.ID, &msg.UserID, &msg.Content, &msg.ReplyTo, ×tampStr, &msg.Username, &pfpURL) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + msg.Timestamp, err = time.Parse("2006-01-02 15:04:05", timestampStr) + if err != nil { + msg.Timestamp = time.Time{} + } + if pfpURL.Valid { + msg.PfpURL = pfpURL.String + } + msg.Mentions = extractMentions(msg.Content) + return &msg, nil +} + +func GetUsernamesMatching(db *sql.DB, prefix string) ([]string, error) { + query := "SELECT username FROM users WHERE username LIKE ? LIMIT 10" + rows, err := db.Query(query, prefix+"%") + if err != nil { + return nil, err + } + defer rows.Close() + + var usernames []string + for rows.Next() { + var username string + if err := rows.Scan(&username); err != nil { + return nil, err + } + usernames = append(usernames, username) + } + return usernames, nil +} + +// Simple utility to extract mentions from content +func extractMentions(content string) []string { + var mentions []string + var currentMention string + inMention := false + + for _, char := range content { + if char == '@' { + inMention = true + currentMention = "@" + } else if inMention && (char == ' ' || char == '\n' || char == '\t') { + if len(currentMention) > 1 { + mentions = append(mentions, currentMention) + } + inMention = false + currentMention = "" + } else if inMention { + currentMention += string(char) + } + } + if inMention && len(currentMention) > 1 { + mentions = append(mentions, currentMention) + } + return mentions +} \ No newline at end of file diff --git a/templates/pages/chat.html b/templates/pages/chat.html new file mode 100644 index 0000000..382947f --- /dev/null +++ b/templates/pages/chat.html @@ -0,0 +1,383 @@ +{{define "chat"}} + + + + {{.Title}} + + + + + {{template "navbar" .}} +
+
+

General Chat

+
+
+
+ {{range .Messages}} +
+
+ {{if .PfpURL}} + PFP + {{else}} +
+ {{end}} + {{.Username}} + {{.Timestamp.Format "02/01/2006 15:04"}} +
+ {{if gt .ReplyTo 0}} +
Replying to message #{{.ReplyTo}}
+ {{end}} +
{{.Content | html}}
+
+ Reply +
+
+ {{end}} +
+
+ + +
+
+
+
+ {{template "cookie_banner" .}} + + + +{{end}} \ No newline at end of file diff --git a/templates/partials/navbar.html b/templates/partials/navbar.html index 0f6af1d..3d85f26 100644 --- a/templates/partials/navbar.html +++ b/templates/partials/navbar.html @@ -4,6 +4,7 @@ {{if .LoggedIn}}
  • User Home
  • Profile
  • +
  • Chat
  • Logout
  • {{else}}
  • Login