package handlers
import (
"encoding/json"
"html/template"
"log"
"net/http"
"strconv"
"sync"
"threadr/models"
"github.com/gorilla/sessions"
"github.com/gorilla/websocket"
)
// upgrader converts chat HTTP requests into WebSocket connections.
//
// Chat sockets are authenticated only by the session cookie, so accepting
// every Origin would let any third-party page open a socket with a visitor's
// cookies and read or post in their boards (cross-site WebSocket hijacking).
// CheckOrigin therefore only admits same-origin browser handshakes.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin:     sameOriginRequest,
}

// sameOriginRequest reports whether a WebSocket handshake comes from a page
// served by the host the client addressed.
//
// Requests without an Origin header (non-browser clients) are accepted, as in
// gorilla/websocket's default policy. The Origin is compared against both the
// Host header and X-Forwarded-Host. This keeps the check working behind a
// reverse proxy that rewrites Host. Browsers cannot set custom headers on a
// WebSocket handshake, so a hostile page cannot use this to get around the
// check.
func sameOriginRequest(r *http.Request) bool {
	origin := r.Header.Get("Origin")
	if origin == "" {
		return true
	}
	for _, host := range []string{r.Host, r.Header.Get("X-Forwarded-Host")} {
		if host == "" {
			continue
		}
		if origin == "http://"+host || origin == "https://"+host {
			return true
		}
	}
	return false
}
// Client is one WebSocket connection attached to the chat hub.
type Client struct {
	// conn is the underlying socket; the hub writes broadcast frames to it
	// and closes it when the client is removed.
	conn *websocket.Conn
	// userID is the logged-in user who opened the socket; boardID is the
	// chat board whose broadcasts this client receives.
	userID, boardID int
}
// ChatHub tracks the connected chat clients and fans out messages to them.
// Its channels are serviced by Run; build one with NewChatHub.
type ChatHub struct {
	// clients is the set of registered connections; guarded by mutex.
	clients map[*Client]bool
	// broadcast carries messages to deliver to clients on the same board.
	broadcast chan models.ChatMessage
	// register and unregister add clients to, and remove them from, clients.
	register, unregister chan *Client
	// mutex guards clients. Because of it, a ChatHub must not be copied.
	mutex sync.Mutex
}
// NewChatHub returns a hub with an empty client set and unbuffered
// register/unregister/broadcast channels. Sends on those channels block until
// Run is servicing them.
func NewChatHub() *ChatHub {
	h := new(ChatHub)
	h.clients = make(map[*Client]bool)
	h.register = make(chan *Client)
	h.unregister = make(chan *Client)
	h.broadcast = make(chan models.ChatMessage)
	return h
}
// Run is the hub's event loop. It registers and unregisters clients and
// delivers each broadcast message to every client on the message's board.
// A client whose write fails is closed and dropped. Run never returns, so
// start it in its own goroutine.
//
// NOTE(review): writes are synchronous and have no write deadline. One
// stalled client can therefore delay delivery to every board. Because the
// channels are unbuffered, it also blocks the handlers sending to the hub.
// Consider per-client buffered send queues or conn.SetWriteDeadline.
func (h *ChatHub) Run() {
	for {
		select {
		case client := <-h.register:
			h.mutex.Lock()
			h.clients[client] = true
			h.mutex.Unlock()

		case client := <-h.unregister:
			h.mutex.Lock()
			// The client may already have been dropped after a failed write;
			// only close the connection the first time.
			if _, ok := h.clients[client]; ok {
				delete(h.clients, client)
				client.conn.Close()
			}
			h.mutex.Unlock()

		case message := <-h.broadcast:
			// Encode once per message instead of once per recipient. If
			// encoding fails, send nothing rather than an empty frame.
			payload, err := json.Marshal(message)
			if err != nil {
				log.Printf("Error encoding chat message for broadcast: %v", err)
				continue
			}
			h.mutex.Lock()
			for client := range h.clients {
				if client.boardID != message.BoardID {
					continue
				}
				if err := client.conn.WriteMessage(websocket.TextMessage, payload); err != nil {
					log.Printf("Error broadcasting message: %v", err)
					// Deleting the current key while ranging over a map is safe in Go.
					client.conn.Close()
					delete(h.clients, client)
				}
			}
			h.mutex.Unlock()
		}
	}
}
var hub = NewChatHub()
func init() {
go hub.Run()
}
// ChatHandler serves the chat board named by the "id" query parameter.
//
// With ?ws=true the request is upgraded to a WebSocket that posts and receives
// messages live. Otherwise the chat page is rendered with up to 50 recent
// messages. The caller must be logged in (session "user_id"), and private
// boards additionally require models.PermViewBoard. The page and the socket
// both pass through the same checks before anything is served.
func ChatHandler(app *App) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// The session is put in the request context by middleware outside
		// this file. Guard the assertion so a missing session is reported
		// instead of panicking the handler.
		session, ok := r.Context().Value("session").(*sessions.Session)
		if !ok || session == nil {
			log.Printf("ChatHandler: no session in request context")
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		userID, ok := session.Values["user_id"].(int)
		if !ok {
			http.Redirect(w, r, app.Config.ThreadrDir+"/login/", http.StatusFound)
			return
		}
		// A missing cookie just means the banner hasn't been accepted yet;
		// cookie stays nil and the banner is shown.
		cookie, _ := r.Cookie("threadr_cookie_banner")
		boardIDStr := r.URL.Query().Get("id")
		boardID, err := strconv.Atoi(boardIDStr)
		if err != nil {
			http.Error(w, "Invalid board ID", http.StatusBadRequest)
			return
		}
		board, err := models.GetBoardByID(app.DB, boardID)
		if err != nil {
			log.Printf("Error fetching board: %v", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		if board == nil {
			http.Error(w, "Chat board not found", http.StatusNotFound)
			return
		}
		if board.Type != "chat" {
			http.Error(w, "This is not a chat board", http.StatusBadRequest)
			return
		}
		if board.Private {
			hasPerm, err := models.HasBoardPermission(app.DB, userID, boardID, models.PermViewBoard)
			if err != nil {
				log.Printf("Error checking permission: %v", err)
				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
				return
			}
			if !hasPerm {
				http.Error(w, "You do not have permission to view this chat", http.StatusForbidden)
				return
			}
		}
		currentUser, err := models.GetUserByID(app.DB, userID)
		if err != nil {
			log.Printf("Error fetching current user: %v", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		if currentUser == nil {
			http.Error(w, "User not found", http.StatusNotFound)
			return
		}
		currentUsername := currentUser.Username

		if r.URL.Query().Get("ws") == "true" {
			serveChatSocket(app, w, r, userID, boardID)
			return
		}

		messages, err := models.GetRecentChatMessages(app.DB, boardID, 50)
		if err != nil {
			log.Printf("Error fetching chat messages: %v", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
		// Reverse in place for display.
		// NOTE(review): assumes GetRecentChatMessages returns newest first — confirm.
		for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
			messages[i], messages[j] = messages[j], messages[i]
		}
		allUsernames, err := models.GetUsernamesInBoard(app.DB, boardID)
		if err != nil {
			log.Printf("Error fetching usernames for board: %v", err)
			allUsernames = []string{}
		}
		// Marshaling a []string cannot fail. json.Marshal also escapes <, >
		// and &, which keeps the result safe to embed as template.JS.
		allUsernamesJSON, _ := json.Marshal(allUsernames)
		data := struct {
			PageData
			Board           models.Board
			Messages        []models.ChatMessage
			AllUsernames    template.JS
			CurrentUsername string
		}{
			PageData: PageData{
				Title:            "ThreadR Chat - " + board.Name,
				Navbar:           "boards",
				LoggedIn:         true,
				ShowCookieBanner: cookie == nil || cookie.Value != "accepted",
				BasePath:         app.Config.ThreadrDir,
				StaticPath:       app.Config.ThreadrDir + "/static",
				CurrentURL:       r.URL.Path,
			},
			Board:           *board,
			Messages:        messages,
			AllUsernames:    template.JS(allUsernamesJSON),
			CurrentUsername: currentUsername,
		}
		if err := app.Tmpl.ExecuteTemplate(w, "chat", data); err != nil {
			log.Printf("Error executing template in ChatHandler: %v", err)
			http.Error(w, "Internal Server Error", http.StatusInternalServerError)
			return
		}
	}
}

// serveChatSocket upgrades an already-authorised chat request to a
// WebSocket and registers it with the hub. Each incoming frame of type
// "message" is stored, re-read from the database and broadcast to the
// board's clients. It returns when a read fails (including client
// disconnect) and unregisters the connection on the way out.
func serveChatSocket(app *App, w http.ResponseWriter, r *http.Request, userID, boardID int) {
	// Upper bound on one incoming frame. Larger frames make ReadMessage fail
	// and close the socket, so a client cannot make the server buffer
	// arbitrarily large payloads.
	const maxFrameBytes = 64 << 10

	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade has already replied with an HTTP error response.
		log.Printf("Error upgrading to WebSocket: %v", err)
		return
	}
	ws.SetReadLimit(maxFrameBytes)

	client := &Client{conn: ws, userID: userID, boardID: boardID}
	hub.register <- client
	defer func() {
		// The hub closes the connection if it is still registered.
		hub.unregister <- client
	}()

	for {
		_, msg, err := ws.ReadMessage()
		if err != nil {
			log.Printf("Error reading WebSocket message: %v", err)
			return
		}
		var chatMsg struct {
			Type    string `json:"type"`
			Content string `json:"content"`
			ReplyTo int    `json:"replyTo"`
		}
		if err := json.Unmarshal(msg, &chatMsg); err != nil {
			log.Printf("Error unmarshaling message: %v", err)
			continue
		}
		if chatMsg.Type != "message" {
			continue
		}
		msgObj := models.ChatMessage{
			BoardID: boardID,
			UserID:  userID,
			Content: chatMsg.Content,
			ReplyTo: chatMsg.ReplyTo,
		}
		if err := models.CreateChatMessage(app.DB, msgObj); err != nil {
			log.Printf("Error saving chat message: %v", err)
			continue
		}
		// NOTE(review): LAST_INSERT_ID() is scoped to a single DB connection.
		// If app.DB is a *sql.DB pool, this query may run on a different
		// connection than the INSERT in CreateChatMessage and return another
		// message's ID (or 0).
		// TODO: have CreateChatMessage return the inserted ID via sql.Result.LastInsertId.
		var msgID int
		if err := app.DB.QueryRow("SELECT LAST_INSERT_ID()").Scan(&msgID); err != nil {
			log.Printf("Error fetching new chat message ID: %v", err)
			continue
		}
		savedMsg, err := models.GetChatMessageByID(app.DB, msgID)
		if err != nil {
			log.Printf("Error fetching saved message: %v", err)
			continue
		}
		if savedMsg == nil {
			// The other models getters used in this file return nil for "not found".
			log.Printf("Saved chat message %d not found", msgID)
			continue
		}
		hub.broadcast <- *savedMsg
	}
}