package handlers

import (
	"encoding/json"
	"html/template"
	"log"
	"net/http"
	"strconv"
	"sync"
	"time"

	"threadr/models"

	"github.com/gorilla/sessions"
	"github.com/gorilla/websocket"
)

// chatWriteTimeout bounds how long the hub waits on a single client write
// before giving up on that client. Without it, one stalled peer would block
// the hub loop and, through the unbuffered broadcast channel, every sender.
const chatWriteTimeout = 10 * time.Second

// upgrader promotes chat HTTP requests to WebSocket connections.
//
// NOTE(review): CheckOrigin accepts every origin. ChatHandler identifies the
// user from the session (presumably cookie-backed), so any third-party page
// could open a socket as a logged-in visitor and post as them (cross-site
// WebSocket hijacking). Before production, restrict it — e.g. leave
// CheckOrigin nil so gorilla/websocket applies its same-origin default.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(r *http.Request) bool {
		return true // Allow all origins for now; restrict in production
	},
}

// Client is one WebSocket connection subscribed to a single chat board.
type Client struct {
	conn    *websocket.Conn
	userID  int
	boardID int
}

// ChatHub tracks connected chat clients and fans out new messages to every
// client subscribed to the message's board. Values sent on register,
// unregister and broadcast are consumed by Run.
type ChatHub struct {
	clients    map[*Client]bool
	broadcast  chan models.ChatMessage
	register   chan *Client
	unregister chan *Client
	mutex      sync.Mutex
}

// NewChatHub returns an empty hub with unbuffered event channels. Run must
// be started for sends on those channels to make progress.
func NewChatHub() *ChatHub {
	return &ChatHub{
		clients:    make(map[*Client]bool),
		broadcast:  make(chan models.ChatMessage),
		register:   make(chan *Client),
		unregister: make(chan *Client),
	}
}

// Run processes register, unregister and broadcast events forever. Within
// this file it is the only code that writes to client connections, which
// keeps to gorilla/websocket's one-concurrent-writer rule.
func (h *ChatHub) Run() {
	for {
		select {
		case client := <-h.register:
			h.mutex.Lock()
			h.clients[client] = true
			h.mutex.Unlock()
		case client := <-h.unregister:
			h.mutex.Lock()
			// The client may already have been dropped by a failed broadcast.
			if _, ok := h.clients[client]; ok {
				delete(h.clients, client)
				client.conn.Close()
			}
			h.mutex.Unlock()
		case message := <-h.broadcast:
			h.fanOut(message)
		}
	}
}

// fanOut sends message to every client subscribed to message.BoardID,
// closing and dropping any client whose write fails or times out.
func (h *ChatHub) fanOut(message models.ChatMessage) {
	// Encode once: the payload is identical for every recipient.
	response, err := json.Marshal(message)
	if err != nil {
		log.Printf("Error encoding chat message: %v", err)
		return
	}

	h.mutex.Lock()
	defer h.mutex.Unlock()
	for client := range h.clients {
		if client.boardID != message.BoardID {
			continue
		}
		// Error ignored: if the deadline cannot be set the connection is
		// broken, and the write below reports it.
		_ = client.conn.SetWriteDeadline(time.Now().Add(chatWriteTimeout))
		if err := client.conn.WriteMessage(websocket.TextMessage, response); err != nil {
			log.Printf("Error broadcasting message: %v", err)
			client.conn.Close()
			delete(h.clients, client) // deleting during range is safe in Go
		}
	}
}

// hub is the process-wide chat hub shared by every ChatHandler.
var hub = NewChatHub()

func init() {
	go hub.Run()
}

// ChatHandler serves the chat board named by the "id" query parameter.
// Users without a user_id in their session are redirected to the login page,
// and private boards additionally require models.PermViewBoard. With
// ws=true in the query the request is upgraded to a WebSocket for posting and
// receiving live messages; otherwise the "chat" template is rendered with the
// 50 most recent messages and the usernames from models.GetUsernamesInBoard.
func ChatHandler(app *App) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		session := r.Context().Value("session").(*sessions.Session)
		userID, ok := session.Values["user_id"].(int)
		if !ok {
			http.Redirect(w, r, app.Config.ThreadrDir+"/login/", http.StatusFound)
			return
		}
		// A missing cookie yields nil, which the banner check below handles.
		cookie, _ := r.Cookie("threadr_cookie_banner")

		boardID, err := strconv.Atoi(r.URL.Query().Get("id"))
		if err != nil {
			http.Error(w, "Invalid board ID", http.StatusBadRequest)
			return
		}
		board, err := models.GetBoardByID(app.DB, boardID)
		if err != nil {
			log.Printf("Error fetching board: %v", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		if board == nil {
			http.Error(w, "Chat board not found", http.StatusNotFound)
			return
		}
		if board.Type != "chat" {
			http.Error(w, "This is not a chat board", http.StatusBadRequest)
			return
		}
		if board.Private {
			hasPerm, err := models.HasBoardPermission(app.DB, userID, boardID, models.PermViewBoard)
			if err != nil {
				log.Printf("Error checking permission: %v", err)
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
			if !hasPerm {
				http.Error(w, "You do not have permission to view this chat", http.StatusForbidden)
				return
			}
		}

		if r.URL.Query().Get("ws") == "true" {
			serveChatSocket(app, w, r, userID, boardID)
			return
		}

		messages, err := models.GetRecentChatMessages(app.DB, boardID, 50)
		if err != nil {
			log.Printf("Error fetching chat messages: %v", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		// Reverse in place. Presumably GetRecentChatMessages returns newest
		// first and the page wants chronological order — confirm.
		for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
			messages[i], messages[j] = messages[j], messages[i]
		}

		allUsernames, err := models.GetUsernamesInBoard(app.DB, boardID)
		if err != nil {
			// Best effort: the page still renders without the username list.
			log.Printf("Error fetching usernames for board: %v", err)
			allUsernames = []string{}
		}
		// Marshaling a []string cannot fail. encoding/json escapes <, > and &
		// by default, so the result is safe to embed in a script as template.JS.
		allUsernamesJSON, _ := json.Marshal(allUsernames)

		data := struct {
			PageData
			Board        models.Board
			Messages     []models.ChatMessage
			AllUsernames template.JS
		}{
			PageData: PageData{
				Title:            "ThreadR Chat - " + board.Name,
				Navbar:           "boards",
				LoggedIn:         true,
				ShowCookieBanner: cookie == nil || cookie.Value != "accepted",
				BasePath:         app.Config.ThreadrDir,
				StaticPath:       app.Config.ThreadrDir + "/static",
				CurrentURL:       r.URL.Path,
			},
			Board:        *board,
			Messages:     messages,
			AllUsernames: template.JS(allUsernamesJSON),
		}
		if err := app.Tmpl.ExecuteTemplate(w, "chat", data); err != nil {
			log.Printf("Error executing template in ChatHandler: %v", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
	}
}

// serveChatSocket upgrades the request to a WebSocket, registers it with the
// hub, and persists and broadcasts every "message" frame the client sends
// until a read fails (including a normal close). The caller must already have
// authorized userID for boardID.
func serveChatSocket(app *App, w http.ResponseWriter, r *http.Request, userID, boardID int) {
	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("Error upgrading to WebSocket: %v", err)
		return
	}
	client := &Client{conn: ws, userID: userID, boardID: boardID}
	hub.register <- client
	// Unregistering makes the hub close the connection if it is still tracked.
	defer func() { hub.unregister <- client }()

	for {
		_, msg, err := ws.ReadMessage()
		if err != nil {
			log.Printf("Error reading WebSocket message: %v", err)
			return
		}
		var chatMsg struct {
			Type    string `json:"type"`
			Content string `json:"content"`
			ReplyTo int    `json:"replyTo"`
		}
		if err := json.Unmarshal(msg, &chatMsg); err != nil {
			log.Printf("Error unmarshaling message: %v", err)
			continue
		}
		if chatMsg.Type != "message" {
			continue
		}
		if saved, ok := saveChatMessage(app, boardID, userID, chatMsg.Content, chatMsg.ReplyTo); ok {
			hub.broadcast <- saved
		}
	}
}

// saveChatMessage stores a new chat message, then reads it back so the
// broadcast carries the database-populated fields. It logs and returns false
// when the message cannot be stored or reliably re-read.
func saveChatMessage(app *App, boardID, userID int, content string, replyTo int) (models.ChatMessage, bool) {
	msgObj := models.ChatMessage{
		BoardID: boardID,
		UserID:  userID,
		Content: content,
		ReplyTo: replyTo,
	}
	if err := models.CreateChatMessage(app.DB, msgObj); err != nil {
		log.Printf("Error saving chat message: %v", err)
		return models.ChatMessage{}, false
	}

	// NOTE(review): MySQL's LAST_INSERT_ID() is per connection. If app.DB is
	// a pooled *sql.DB, this query may run on a different connection than the
	// INSERT and return 0 or an unrelated ID. The robust fix is for
	// CreateChatMessage to return sql.Result.LastInsertId(); the checks below
	// only prevent a panic or a mis-attributed broadcast.
	var msgID int
	if err := app.DB.QueryRow("SELECT LAST_INSERT_ID()").Scan(&msgID); err != nil {
		log.Printf("Error fetching last insert ID: %v", err)
		return models.ChatMessage{}, false
	}
	savedMsg, err := models.GetChatMessageByID(app.DB, msgID)
	if err != nil {
		log.Printf("Error fetching saved message: %v", err)
		return models.ChatMessage{}, false
	}
	// Guard against a not-found (nil) result and against an ID that belongs to
	// someone else's message, either of which the connection race can produce.
	if savedMsg == nil || savedMsg.BoardID != boardID || savedMsg.UserID != userID {
		log.Printf("Saved chat message %d not found or not owned by user %d on board %d", msgID, userID, boardID)
		return models.ChatMessage{}, false
	}
	return *savedMsg, true
}