188 lines
4.8 KiB
Go
188 lines
4.8 KiB
Go
package handlers
|
|
|
|
import (
|
|
"encoding/json"
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
"threadr/models"
|
|
"github.com/gorilla/sessions"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true // Allow all origins for now; restrict in production
|
|
},
|
|
}
|
|
|
|
// ChatHub manages WebSocket connections and broadcasts messages
|
|
type ChatHub struct {
|
|
clients map[*websocket.Conn]int // Map of connections to user IDs
|
|
broadcast chan []byte
|
|
register chan *websocket.Conn
|
|
unregister chan *websocket.Conn
|
|
mutex sync.Mutex
|
|
}
|
|
|
|
func NewChatHub() *ChatHub {
|
|
return &ChatHub{
|
|
clients: make(map[*websocket.Conn]int),
|
|
broadcast: make(chan []byte),
|
|
register: make(chan *websocket.Conn),
|
|
unregister: make(chan *websocket.Conn),
|
|
}
|
|
}
|
|
|
|
func (h *ChatHub) Run() {
|
|
for {
|
|
select {
|
|
case client := <-h.register:
|
|
h.mutex.Lock()
|
|
h.clients[client] = 0 // UserID set later
|
|
h.mutex.Unlock()
|
|
case client := <-h.unregister:
|
|
h.mutex.Lock()
|
|
delete(h.clients, client)
|
|
h.mutex.Unlock()
|
|
client.Close()
|
|
case message := <-h.broadcast:
|
|
h.mutex.Lock()
|
|
for client := range h.clients {
|
|
err := client.WriteMessage(websocket.TextMessage, message)
|
|
if err != nil {
|
|
log.Printf("Error broadcasting message: %v", err)
|
|
client.Close()
|
|
delete(h.clients, client)
|
|
}
|
|
}
|
|
h.mutex.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
var hub = NewChatHub()
|
|
|
|
func init() {
|
|
go hub.Run()
|
|
}
|
|
|
|
func ChatHandler(app *App) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
session := r.Context().Value("session").(*sessions.Session)
|
|
userID, ok := session.Values["user_id"].(int)
|
|
if !ok {
|
|
http.Redirect(w, r, app.Config.ThreadrDir+"/login/", http.StatusFound)
|
|
return
|
|
}
|
|
cookie, _ := r.Cookie("threadr_cookie_banner")
|
|
|
|
if r.URL.Query().Get("ws") == "true" {
|
|
// Handle WebSocket connection
|
|
ws, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Printf("Error upgrading to WebSocket: %v", err)
|
|
return
|
|
}
|
|
hub.register <- ws
|
|
hub.mutex.Lock()
|
|
hub.clients[ws] = userID
|
|
hub.mutex.Unlock()
|
|
|
|
defer func() {
|
|
hub.unregister <- ws
|
|
}()
|
|
|
|
for {
|
|
_, msg, err := ws.ReadMessage()
|
|
if err != nil {
|
|
log.Printf("Error reading WebSocket message: %v", err)
|
|
break
|
|
}
|
|
var chatMsg struct {
|
|
Type string `json:"type"`
|
|
Content string `json:"content"`
|
|
ReplyTo int `json:"replyTo"`
|
|
}
|
|
if err := json.Unmarshal(msg, &chatMsg); err != nil {
|
|
log.Printf("Error unmarshaling message: %v", err)
|
|
continue
|
|
}
|
|
|
|
if chatMsg.Type == "message" {
|
|
msgObj := models.ChatMessage{
|
|
UserID: userID,
|
|
Content: chatMsg.Content,
|
|
ReplyTo: chatMsg.ReplyTo,
|
|
}
|
|
if err := models.CreateChatMessage(app.DB, msgObj); err != nil {
|
|
log.Printf("Error saving chat message: %v", err)
|
|
continue
|
|
}
|
|
// Fetch the saved message with timestamp and user details
|
|
var msgID int
|
|
app.DB.QueryRow("SELECT LAST_INSERT_ID()").Scan(&msgID)
|
|
savedMsg, err := models.GetChatMessageByID(app.DB, msgID)
|
|
if err != nil {
|
|
log.Printf("Error fetching saved message: %v", err)
|
|
continue
|
|
}
|
|
response, _ := json.Marshal(savedMsg)
|
|
hub.broadcast <- response
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
if r.URL.Query().Get("autocomplete") == "true" {
|
|
// Handle autocomplete for mentions
|
|
prefix := r.URL.Query().Get("prefix")
|
|
usernames, err := models.GetUsernamesMatching(app.DB, prefix)
|
|
if err != nil {
|
|
log.Printf("Error fetching usernames for autocomplete: %v", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
response, _ := json.Marshal(usernames)
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write(response)
|
|
return
|
|
}
|
|
|
|
// Render chat page
|
|
messages, err := models.GetRecentChatMessages(app.DB, 50)
|
|
if err != nil {
|
|
log.Printf("Error fetching chat messages: %v", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Reverse messages to show oldest first
|
|
for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 {
|
|
messages[i], messages[j] = messages[j], messages[i]
|
|
}
|
|
|
|
data := struct {
|
|
PageData
|
|
Messages []models.ChatMessage
|
|
}{
|
|
PageData: PageData{
|
|
Title: "ThreadR - Chat",
|
|
Navbar: "chat",
|
|
LoggedIn: true,
|
|
ShowCookieBanner: cookie == nil || cookie.Value != "accepted",
|
|
BasePath: app.Config.ThreadrDir,
|
|
StaticPath: app.Config.ThreadrDir + "/static",
|
|
CurrentURL: r.URL.Path,
|
|
},
|
|
Messages: messages,
|
|
}
|
|
if err := app.Tmpl.ExecuteTemplate(w, "chat", data); err != nil {
|
|
log.Printf("Error executing template in ChatHandler: %v", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
} |