56 lines
1.2 KiB
Go
56 lines
1.2 KiB
Go
package handlers
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"net/http"
|
|
|
|
"github.com/gorilla/sessions"
|
|
)
|
|
|
|
// csrfSessionKey is the key in session.Values under which the per-session
// CSRF token is stored.
const (
	csrfSessionKey = "csrf_token"
)
|
|
|
|
func (app *App) ensureCSRFToken(session *sessions.Session) (string, error) {
|
|
if token, ok := session.Values[csrfSessionKey].(string); ok && token != "" {
|
|
return token, nil
|
|
}
|
|
|
|
raw := make([]byte, 32)
|
|
if _, err := rand.Read(raw); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
token := base64.RawURLEncoding.EncodeToString(raw)
|
|
session.Values[csrfSessionKey] = token
|
|
return token, nil
|
|
}
|
|
|
|
func (app *App) csrfToken(session *sessions.Session) string {
|
|
token, err := app.ensureCSRFToken(session)
|
|
if err != nil {
|
|
return ""
|
|
}
|
|
return token
|
|
}
|
|
|
|
func (app *App) validateCSRFToken(r *http.Request, session *sessions.Session) bool {
|
|
expected, ok := session.Values[csrfSessionKey].(string)
|
|
if !ok || expected == "" {
|
|
return false
|
|
}
|
|
|
|
provided := r.Header.Get("X-CSRF-Token")
|
|
if provided == "" {
|
|
provided = r.FormValue("csrf_token")
|
|
}
|
|
if provided == "" {
|
|
provided = r.URL.Query().Get("csrf_token")
|
|
}
|
|
if len(provided) != len(expected) {
|
|
return false
|
|
}
|
|
|
|
return subtle.ConstantTimeCompare([]byte(provided), []byte(expected)) == 1
|
|
}
|