threadr.lostcave.ddnss.de/handlers/chat.go

242 lines
6.7 KiB
Go

package handlers
import (
"encoding/json"
"log"
"net/http"
"sync"
"threadr/models"
"html/template"
"regexp"
"time"
"github.com/gorilla/sessions"
"github.com/gorilla/websocket"
)
// upgrader turns chat HTTP requests into WebSocket connections.
//
// CheckOrigin is deliberately left nil so gorilla/websocket applies its
// same-origin default: a handshake whose Origin header names a host other than
// the request's Host is rejected, while requests with no Origin header (non-
// browser clients) are allowed. The previous "allow every origin" override let
// any third-party page open a socket carrying the visitor's session cookie
// (cross-site WebSocket hijacking) and post chat messages as that user.
// NOTE(review): if a reverse proxy rewrites the Host header, either preserve it
// or supply an explicit allow-list here.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
}
// ChatHub tracks every open chat WebSocket connection and fans broadcast
// payloads out to all of them. Its Run loop services the channels below;
// clients is additionally guarded by mutex because request handlers also
// read and write it directly.
type ChatHub struct {
	// clients maps each live connection to the ID of the user that owns it
	// (0 is used as a placeholder until the owner is recorded).
	clients map[*websocket.Conn]int

	// broadcast carries already-encoded payloads to send to every client.
	broadcast chan []byte

	// register and unregister hand connections to Run for insertion/removal.
	register   chan *websocket.Conn
	unregister chan *websocket.Conn

	// mutex guards clients.
	mutex sync.Mutex
}
// NewChatHub returns a hub with an empty client set and unbuffered
// broadcast/register/unregister channels. Nothing is serviced until Run is
// started on the returned hub.
func NewChatHub() *ChatHub {
	h := new(ChatHub)
	h.clients = make(map[*websocket.Conn]int)
	h.broadcast = make(chan []byte)
	h.register = make(chan *websocket.Conn)
	h.unregister = make(chan *websocket.Conn)
	return h
}
// Run is the hub's event loop and never returns. It:
//   - adds connections received on register (keeping any user ID a handler
//     has already recorded for that connection),
//   - removes and closes connections received on unregister,
//   - writes every payload received on broadcast to all connected clients,
//     closing and dropping any client whose write fails or times out.
func (h *ChatHub) Run() {
	// chatWriteWait bounds each broadcast write. Without a deadline, a single
	// client that stops reading would block WriteMessage indefinitely while
	// h.mutex is held, stalling every later broadcast and unregistration.
	const chatWriteWait = 10 * time.Second

	for {
		select {
		case client := <-h.register:
			h.mutex.Lock()
			// A handler may store the real user ID right after sending on
			// register; only insert the placeholder for unknown connections so
			// that ID is never overwritten with 0.
			if _, known := h.clients[client]; !known {
				h.clients[client] = 0 // UserID set later
			}
			h.mutex.Unlock()
		case client := <-h.unregister:
			h.mutex.Lock()
			delete(h.clients, client)
			h.mutex.Unlock()
			client.Close()
		case message := <-h.broadcast:
			h.mutex.Lock()
			for client := range h.clients {
				// A failure to set the deadline also surfaces from WriteMessage,
				// which is where the client gets dropped.
				_ = client.SetWriteDeadline(time.Now().Add(chatWriteWait))
				if err := client.WriteMessage(websocket.TextMessage, message); err != nil {
					log.Printf("Error broadcasting message: %v", err)
					client.Close()
					delete(h.clients, client) // deleting during range is safe in Go
				}
			}
			h.mutex.Unlock()
		}
	}
}
// hub is the single package-wide ChatHub shared by every chat WebSocket
// connection served by ChatHandler.
var hub = NewChatHub()

// init starts the hub's event loop in a background goroutine so that sends on
// hub.register, hub.unregister and hub.broadcast are serviced. The goroutine
// runs for the lifetime of the process (Run never returns).
func init() {
	go hub.Run()
}
// ChatDisplayMessage is the view model for one chat message. It feeds the
// "chat" template and is JSON-encoded as the WebSocket broadcast payload.
// The fields carry no json tags, so they are encoded under their Go names;
// renaming or reordering a field changes what clients receive.
type ChatDisplayMessage struct {
	ID     int // message ID
	UserID int // author's user ID

	// Content is the message text as stored, before mention formatting.
	Content string
	// FormattedContent is Content rendered to HTML with @mentions wrapped in
	// spans; templates emit it without further escaping.
	FormattedContent template.HTML

	ReplyTo   int       // ID of the message being replied to, as stored
	Timestamp time.Time // creation time as stored
	Username  string    // author's username
	PfpURL    string    // author's profile picture URL
}
// chatMentionPattern matches an @mention: "@" followed by one or more ASCII
// letters, digits or underscores. Compiled once instead of on every call.
var chatMentionPattern = regexp.MustCompile(`@([a-zA-Z0-9_]+)`)

// formatChatMessageContent HTML-escapes raw chat text and then wraps each
// @mention in <span class="chat-message-mention">…</span>.
//
// The result is a template.HTML, which html/template inserts verbatim, so the
// escaping here is what stops user-supplied markup (e.g. "<script>") from being
// rendered to other users. Escaping before the mention pass is safe:
// HTMLEscapeString never alters "@" or the characters the pattern matches, so
// the same mentions are found as in the raw text.
// NOTE(review): assumes stored content is unescaped plain text; if the model
// layer already escapes it, this would double-escape.
func formatChatMessageContent(content string) template.HTML {
	escaped := template.HTMLEscapeString(content)
	// $0 is the whole match, e.g. "@username".
	formatted := chatMentionPattern.ReplaceAllString(escaped, `<span class="chat-message-mention">$0</span>`)
	return template.HTML(formatted)
}
// maxChatFrameBytes caps the size of a single incoming chat WebSocket frame so
// a client cannot make the server buffer arbitrarily large payloads.
const maxChatFrameBytes = 64 << 10

// ChatHandler serves the chat feature for logged-in users; anonymous users are
// redirected to the login page. The query string selects the mode:
//   - ws=true:           upgrade to a WebSocket and relay chat messages,
//   - autocomplete=true: return usernames matching ?prefix= as JSON,
//   - otherwise:         render the chat page with recent messages.
func ChatHandler(app *App) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Assumes session middleware has stored a *sessions.Session under
		// "session"; the assertion panics otherwise.
		session := r.Context().Value("session").(*sessions.Session)
		userID, ok := session.Values["user_id"].(int)
		if !ok {
			http.Redirect(w, r, app.Config.ThreadrDir+"/login/", http.StatusFound)
			return
		}

		query := r.URL.Query()
		switch {
		case query.Get("ws") == "true":
			serveChatWebSocket(app, w, r, userID)
		case query.Get("autocomplete") == "true":
			serveChatAutocomplete(app, w, query.Get("prefix"))
		default:
			renderChatPage(app, w, r)
		}
	}
}

// serveChatWebSocket upgrades the request, attaches the connection to the hub
// under userID and handles incoming frames until the client disconnects.
func serveChatWebSocket(app *App, w http.ResponseWriter, r *http.Request, userID int) {
	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("Error upgrading to WebSocket: %v", err)
		return
	}

	// Record the connection and its owner in a single critical section.
	// Sending on hub.register first would race with this write: the hub's Run
	// loop could store its placeholder ID 0 after the real ID was set.
	hub.mutex.Lock()
	hub.clients[ws] = userID
	hub.mutex.Unlock()
	defer func() {
		hub.unregister <- ws // Run removes and closes the connection
	}()

	// Frames come from an untrusted client; bound their size.
	ws.SetReadLimit(maxChatFrameBytes)
	for {
		_, msg, err := ws.ReadMessage()
		if err != nil {
			log.Printf("Error reading WebSocket message: %v", err)
			return
		}
		handleChatFrame(app, userID, msg)
	}
}

// handleChatFrame decodes one client frame. For frames of type "message" it
// saves the message as userID, reloads it with author details and broadcasts
// it to all hub clients. Failures are logged and the frame is dropped; the
// connection stays open.
func handleChatFrame(app *App, userID int, raw []byte) {
	var chatMsg struct {
		Type    string `json:"type"`
		Content string `json:"content"`
		ReplyTo int    `json:"replyTo"`
	}
	if err := json.Unmarshal(raw, &chatMsg); err != nil {
		log.Printf("Error unmarshaling message: %v", err)
		return
	}
	if chatMsg.Type != "message" {
		return
	}

	msgObj := models.ChatMessage{
		UserID:  userID,
		Content: chatMsg.Content,
		ReplyTo: chatMsg.ReplyTo,
	}
	// CreateChatMessage now also handles notifications and returns the new message ID
	newMsgID, err := models.CreateChatMessage(app.DB, msgObj)
	if err != nil {
		log.Printf("Error saving chat message: %v", err)
		return
	}
	// Reload to pick up the timestamp and user details.
	savedMsg, err := models.GetChatMessageByID(app.DB, newMsgID)
	if err != nil {
		log.Printf("Error fetching saved message: %v", err)
		return
	}
	if savedMsg == nil {
		log.Printf("Error: saved message with ID %d not found after creation", newMsgID)
		return
	}

	displayMsg := ChatDisplayMessage{
		ID:               savedMsg.ID,
		UserID:           savedMsg.UserID,
		Content:          savedMsg.Content,
		FormattedContent: formatChatMessageContent(savedMsg.Content),
		ReplyTo:          savedMsg.ReplyTo,
		Timestamp:        savedMsg.Timestamp,
		Username:         savedMsg.Username,
		PfpURL:           savedMsg.PfpURL,
	}
	response, err := json.Marshal(displayMsg)
	if err != nil {
		log.Printf("Error encoding chat message %d for broadcast: %v", savedMsg.ID, err)
		return
	}
	hub.broadcast <- response
}

// serveChatAutocomplete writes, as JSON, the usernames that
// models.GetUsernamesMatching returns for prefix.
func serveChatAutocomplete(app *App, w http.ResponseWriter, prefix string) {
	usernames, err := models.GetUsernamesMatching(app.DB, prefix)
	if err != nil {
		log.Printf("Error fetching usernames for autocomplete: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	response, err := json.Marshal(usernames)
	if err != nil {
		log.Printf("Error encoding usernames for autocomplete: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.Write(response)
}

// renderChatPage renders the "chat" template with the messages returned by
// models.GetRecentChatMessages (limit 50), reversed for display.
func renderChatPage(app *App, w http.ResponseWriter, r *http.Request) {
	rawMessages, err := models.GetRecentChatMessages(app.DB, 50)
	if err != nil {
		log.Printf("Error fetching chat messages: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	displayMessages := make([]ChatDisplayMessage, 0, len(rawMessages))
	for _, msg := range rawMessages {
		displayMessages = append(displayMessages, ChatDisplayMessage{
			ID:               msg.ID,
			UserID:           msg.UserID,
			Content:          msg.Content,
			FormattedContent: formatChatMessageContent(msg.Content),
			ReplyTo:          msg.ReplyTo,
			Timestamp:        msg.Timestamp,
			Username:         msg.Username,
			PfpURL:           msg.PfpURL,
		})
	}
	// Reverse so the page lists the oldest message first (the query is
	// presumably newest-first — confirm against models).
	for i, j := 0, len(displayMessages)-1; i < j; i, j = i+1, j-1 {
		displayMessages[i], displayMessages[j] = displayMessages[j], displayMessages[i]
	}

	// A missing cookie (error) leaves cookie nil, which shows the banner.
	cookie, _ := r.Cookie("threadr_cookie_banner")
	data := struct {
		PageData
		Messages []ChatDisplayMessage
	}{
		PageData: PageData{
			Title:            "ThreadR - Chat",
			Navbar:           "chat",
			LoggedIn:         true,
			ShowCookieBanner: cookie == nil || cookie.Value != "accepted",
			BasePath:         app.Config.ThreadrDir,
			StaticPath:       app.Config.ThreadrDir + "/static",
			CurrentURL:       r.URL.Path,
		},
		Messages: displayMessages,
	}
	if err := app.Tmpl.ExecuteTemplate(w, "chat", data); err != nil {
		log.Printf("Error executing template in ChatHandler: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}
}