package handlers import ( "encoding/json" "log" "net/http" "sync" "threadr/models" "github.com/gorilla/sessions" "github.com/gorilla/websocket" ) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, CheckOrigin: func(r *http.Request) bool { return true // Allow all origins for now; restrict in production }, } // ChatHub manages WebSocket connections and broadcasts messages type ChatHub struct { clients map[*websocket.Conn]int // Map of connections to user IDs broadcast chan []byte register chan *websocket.Conn unregister chan *websocket.Conn mutex sync.Mutex } func NewChatHub() *ChatHub { return &ChatHub{ clients: make(map[*websocket.Conn]int), broadcast: make(chan []byte), register: make(chan *websocket.Conn), unregister: make(chan *websocket.Conn), } } func (h *ChatHub) Run() { for { select { case client := <-h.register: h.mutex.Lock() h.clients[client] = 0 // UserID set later h.mutex.Unlock() case client := <-h.unregister: h.mutex.Lock() delete(h.clients, client) h.mutex.Unlock() client.Close() case message := <-h.broadcast: h.mutex.Lock() for client := range h.clients { err := client.WriteMessage(websocket.TextMessage, message) if err != nil { log.Printf("Error broadcasting message: %v", err) client.Close() delete(h.clients, client) } } h.mutex.Unlock() } } } var hub = NewChatHub() func init() { go hub.Run() } func ChatHandler(app *App) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { session := r.Context().Value("session").(*sessions.Session) userID, ok := session.Values["user_id"].(int) if !ok { http.Redirect(w, r, app.Config.ThreadrDir+"/login/", http.StatusFound) return } cookie, _ := r.Cookie("threadr_cookie_banner") if r.URL.Query().Get("ws") == "true" { // Handle WebSocket connection ws, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Printf("Error upgrading to WebSocket: %v", err) return } hub.register <- ws hub.mutex.Lock() hub.clients[ws] = userID hub.mutex.Unlock() defer func() { hub.unregister <- ws }() for { _, msg, err := ws.ReadMessage() if err != nil { log.Printf("Error reading WebSocket message: %v", err) break } var chatMsg struct { Type string `json:"type"` Content string `json:"content"` ReplyTo int `json:"replyTo"` } if err := json.Unmarshal(msg, &chatMsg); err != nil { log.Printf("Error unmarshaling message: %v", err) continue } if chatMsg.Type == "message" { msgObj := models.ChatMessage{ UserID: userID, Content: chatMsg.Content, ReplyTo: chatMsg.ReplyTo, } if err := models.CreateChatMessage(app.DB, msgObj); err != nil { log.Printf("Error saving chat message: %v", err) continue } // Fetch the saved message with timestamp and user details var msgID int app.DB.QueryRow("SELECT LAST_INSERT_ID()").Scan(&msgID) savedMsg, err := models.GetChatMessageByID(app.DB, msgID) if err != nil { log.Printf("Error fetching saved message: %v", err) continue } response, _ := json.Marshal(savedMsg) hub.broadcast <- response } } return } if r.URL.Query().Get("autocomplete") == "true" { // Handle autocomplete for mentions prefix := r.URL.Query().Get("prefix") usernames, err := models.GetUsernamesMatching(app.DB, prefix) if err != nil { log.Printf("Error fetching usernames for autocomplete: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } response, _ := json.Marshal(usernames) w.Header().Set("Content-Type", "application/json") w.Write(response) return } // Render chat page messages, err := models.GetRecentChatMessages(app.DB, 50) if err != nil { log.Printf("Error fetching chat messages: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } // Reverse messages to show oldest first for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 { messages[i], messages[j] = messages[j], messages[i] } data := struct { PageData Messages []models.ChatMessage }{ PageData: PageData{ Title: "ThreadR - Chat", Navbar: "chat", LoggedIn: true, ShowCookieBanner: cookie == nil || cookie.Value != "accepted", BasePath: app.Config.ThreadrDir, StaticPath: app.Config.ThreadrDir + "/static", CurrentURL: r.URL.Path, }, Messages: messages, } if err := app.Tmpl.ExecuteTemplate(w, "chat", data); err != nil { log.Printf("Error executing template in ChatHandler: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } } }