Files
2025-04-09 01:00:12 +01:00

379 lines
12 KiB
Go

package csrf
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"slices"
"github.com/gorilla/securecookie"
)
// tokenLength is the length, in bytes, of the real (unmasked) CSRF token that
// is generated and persisted in the session store.
const tokenLength = 32

// Context/session keys and prefixes used internally by the middleware.
const (
	tokenKey     string = "gorilla.csrf.Token" // #nosec G101
	formKey      string = "gorilla.csrf.Form"  // #nosec G101
	errorKey     string = "gorilla.csrf.Error"
	skipCheckKey string = "gorilla.csrf.Skip"
	cookieName   string = "_gorilla_csrf"
	errorPrefix  string = "gorilla/csrf: "
)

// contextKey is an unexported type for context keys defined by this package,
// so they cannot collide with string keys set by other packages.
type contextKey string

// PlaintextHTTPContextKey is the context key used to store whether the request
// is being served via plaintext HTTP. This is used to signal to the middleware
// that strict Referer checking should not be enforced as is done for HTTPS by
// default.
const PlaintextHTTPContextKey contextKey = "plaintext"

var (
	// fieldName is the default name of the form field that carries the token.
	fieldName = tokenKey

	// defaultAge is the default MaxAge for cookies, in seconds (12 hours).
	defaultAge = 3600 * 12

	// headerName is the default HTTP request header inspected for the token.
	headerName = "X-CSRF-Token"

	// safeMethods are the methods RFC 7231 section 4.2.1 defines as "safe"
	// (read-only). Requests using them are not subject to Origin, Referer or
	// token validation. (Section 4.2.2 defines the broader idempotent set,
	// which also contains PUT and DELETE and must NOT be exempted.)
	safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace}
)

// TemplateTag provides a default template tag - e.g. {{ .csrfField }} - for use
// with the TemplateField function.
var TemplateTag = "csrfField"
// Sentinel errors attached to a rejected request before the configured error
// handler is invoked.
var (
	// ErrNoToken is returned if no CSRF token is supplied in the request.
	ErrNoToken = errors.New("CSRF token not found in request")

	// ErrBadToken is returned if the CSRF token in the request does not match
	// the token in the session, or cannot be read or is otherwise malformed.
	ErrBadToken = errors.New("CSRF token invalid")

	// ErrBadOrigin is returned when the Origin header is present but cannot be
	// parsed, or names neither the request's own origin nor a trusted origin.
	ErrBadOrigin = errors.New("origin invalid")

	// ErrNoReferer is returned when an HTTPS request without an Origin header
	// supplies an empty (or unparsable) Referer header.
	ErrNoReferer = errors.New("referer not supplied")

	// ErrBadReferer is returned when the Referer header of an HTTPS request
	// uses the cleartext "http" scheme, or names a host that is neither the
	// request host nor a trusted origin.
	ErrBadReferer = errors.New("referer invalid")
)
// SameSiteMode allows a server to define a cookie attribute making it
// impossible for the browser to send this cookie along with cross-site
// requests. The main goal is to mitigate the risk of cross-origin information
// leakage, and provide some protection against cross-site request forgery
// attacks.
//
// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details.
type SameSiteMode int

// SameSite options. Values start at 1, so the zero value of SameSiteMode
// matches none of them.
const (
	// SameSiteDefaultMode sets the `SameSite` cookie attribute, which is
	// invalid in some older browsers due to changes in the SameSite spec.
	// These browsers will not send the cookie to the server.
	// csrf uses SameSiteLaxMode (SameSite=Lax) as the default as of v1.7.0+.
	SameSiteDefaultMode SameSiteMode = 1

	// SameSiteLaxMode requests the SameSite=Lax cookie attribute.
	SameSiteLaxMode SameSiteMode = 2

	// SameSiteStrictMode requests the SameSite=Strict cookie attribute.
	SameSiteStrictMode SameSiteMode = 3

	// SameSiteNoneMode requests the SameSite=None cookie attribute.
	SameSiteNoneMode SameSiteMode = 4
)
// csrf is the http.Handler produced by Protect. It validates requests before
// delegating to the wrapped handler.
type csrf struct {
	// h is the wrapped handler invoked once the request is allowed through.
	h http.Handler
	// sc authenticates the token cookie; Protect creates it from authKey
	// when it is nil.
	sc *securecookie.SecureCookie
	// st loads and saves the real (unmasked) token; Protect defaults it to a
	// cookieStore when it is nil.
	st store
	// opts holds the configurable settings.
	opts options
}

// options contains the optional settings for the CSRF middleware.
type options struct {
	// MaxAge is the token lifetime in seconds, applied to both the
	// securecookie instance and the default cookie store. Protect replaces a
	// negative value with defaultAge (12 hours).
	MaxAge int
	// Domain and Path are passed through to the default cookieStore.
	Domain string
	Path   string
	// Note that the function and field names match the case of the associated
	// http.Cookie field instead of the "correct" HTTPOnly name that golint suggests.
	// HttpOnly, Secure and SameSite are passed through to the default cookieStore.
	HttpOnly bool
	Secure   bool
	SameSite SameSiteMode
	// RequestHeader names the request header carrying the token; Protect
	// defaults an empty value to headerName.
	RequestHeader string
	// FieldName names the form field carrying the token; Protect defaults an
	// empty value to fieldName. It is also stored in the request context.
	FieldName string
	// ErrorHandler serves every rejected request; Protect defaults a nil value
	// to unauthorizedHandler.
	ErrorHandler http.Handler
	// CookieName names the token cookie; Protect defaults an empty value to
	// cookieName.
	CookieName string
	// TrustedOrigins lists hosts (compared against URL.Host of the Origin and
	// Referer headers) accepted in addition to the request's own origin.
	TrustedOrigins []string
}
// Protect is HTTP middleware that provides Cross-Site Request Forgery
// protection.
//
// It securely generates a masked (unique-per-request) token that
// can be embedded in the HTTP response (e.g. form field or HTTP header).
// The original (unmasked) token is stored in the session, which is inaccessible
// by an attacker (provided you are using HTTPS). Subsequent requests are
// expected to include this token, which is compared against the session token.
// Requests that do not provide a matching token are served with a HTTP 403
// 'Forbidden' error response.
//
// authKey authenticates (HMAC-signs) the token cookie. It should be 32 bytes
// long and persist across application restarts, otherwise previously issued
// tokens stop validating. After the options are applied, a nil ErrorHandler,
// a negative MaxAge and an empty FieldName, CookieName or RequestHeader are
// replaced with the package defaults.
//
// Example:
//
//	package main
//
//	import (
//		"html/template"
//		"net/http"
//
//		"github.com/gorilla/csrf"
//		"github.com/gorilla/mux"
//	)
//
//	var t = template.Must(template.New("signup_form.tmpl").Parse(form))
//
//	func main() {
//		r := mux.NewRouter()
//
//		r.HandleFunc("/signup", GetSignupForm)
//		// POST requests without a valid token will return a HTTP 403 Forbidden.
//		r.HandleFunc("/signup/post", PostSignupForm)
//
//		// Add the middleware to your router.
//		http.ListenAndServe(":8000",
//			// Note that the authentication key provided should be 32 bytes
//			// long and persist across application restarts.
//			csrf.Protect([]byte("32-byte-long-auth-key"))(r))
//	}
//
//	func GetSignupForm(w http.ResponseWriter, r *http.Request) {
//		// signup_form.tmpl just needs a {{ .csrfField }} template tag for
//		// csrf.TemplateField to inject the CSRF token into. Easy!
//		t.ExecuteTemplate(w, "signup_form.tmpl", map[string]interface{}{
//			csrf.TemplateTag: csrf.TemplateField(r),
//		})
//		// We could also retrieve the token directly from csrf.Token(r) and
//		// set it in the response header - w.Header().Set("X-CSRF-Token", token)
//		// This is useful if you're sending JSON to clients or a front-end
//		// JavaScript framework.
//	}
func Protect(authKey []byte, opts ...Option) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		cs := parseOptions(h, opts...)

		// Set the defaults if no options have been specified
		if cs.opts.ErrorHandler == nil {
			cs.opts.ErrorHandler = http.HandlerFunc(unauthorizedHandler)
		}

		if cs.opts.MaxAge < 0 {
			// Default of 12 hours
			cs.opts.MaxAge = defaultAge
		}

		if cs.opts.FieldName == "" {
			cs.opts.FieldName = fieldName
		}

		if cs.opts.CookieName == "" {
			cs.opts.CookieName = cookieName
		}

		if cs.opts.RequestHeader == "" {
			cs.opts.RequestHeader = headerName
		}

		// Create an authenticated (signed, not encrypted: the block key is nil)
		// securecookie instance.
		if cs.sc == nil {
			cs.sc = securecookie.New(authKey, nil)
			// Use JSON serialization (faster than one-off gob encoding)
			cs.sc.SetSerializer(securecookie.JSONEncoder{})
			// Set the MaxAge of the underlying securecookie.
			cs.sc.MaxAge(cs.opts.MaxAge)
		}

		if cs.st == nil {
			// Default to the cookieStore, configured from the (now defaulted)
			// cookie options and sharing the securecookie instance above.
			cs.st = &cookieStore{
				name:     cs.opts.CookieName,
				maxAge:   cs.opts.MaxAge,
				secure:   cs.opts.Secure,
				httpOnly: cs.opts.HttpOnly,
				sameSite: cs.opts.SameSite,
				path:     cs.opts.Path,
				domain:   cs.opts.Domain,
				sc:       cs.sc,
			}
		}

		return cs
	}
}
// ServeHTTP implements http.Handler for the csrf type.
//
// For every request it loads the real token from the store (generating and
// saving a fresh one if it is missing, fails to load, or has the wrong
// length), then stores a masked copy of it and the configured form field name
// in the request context. Requests using a safe method are passed straight to
// the wrapped handler. All other requests must pass, in order:
//   - an Origin check (when an Origin header is present),
//   - a Referer check (when Origin is absent and the request has not been
//     marked as plaintext HTTP via PlaintextHTTPRequest),
//   - a comparison of the submitted token against the real token.
//
// On any failure the reason is attached to the request with envError and
// opts.ErrorHandler is invoked instead of the wrapped handler. If contextGet
// reports skipCheckKey as the bool true, all processing is bypassed.
func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// Skip the check if directed to. This should always be a bool.
	if val, err := contextGet(r, skipCheckKey); err == nil {
		if skip, ok := val.(bool); ok {
			if skip {
				cs.h.ServeHTTP(w, r)
				return
			}
		}
	}

	// Retrieve the token from the session.
	// An error represents either a cookie that failed HMAC validation
	// or that doesn't exist.
	realToken, err := cs.st.Get(r)
	if err != nil || len(realToken) != tokenLength {
		// If there was an error retrieving the token, the token doesn't exist
		// yet, or it's the wrong length, generate a new token.
		// Note that the new token will (correctly) fail validation downstream
		// as it will no longer match the request token.
		realToken, err = generateRandomBytes(tokenLength)
		if err != nil {
			r = envError(r, err)
			cs.opts.ErrorHandler.ServeHTTP(w, r)
			return
		}

		// Save the new (real) token in the session store.
		err = cs.st.Save(realToken, w)
		if err != nil {
			r = envError(r, err)
			cs.opts.ErrorHandler.ServeHTTP(w, r)
			return
		}
	}

	// Save the masked token to the request context. Masking happens on every
	// request, so handlers only ever see a masked token, never the real one.
	r = contextSave(r, tokenKey, mask(realToken, r))
	// Save the field name to the request context
	r = contextSave(r, formKey, cs.opts.FieldName)

	// HTTP methods not defined as "safe" under RFC 7231 section 4.2.1 require
	// inspection.
	if !contains(safeMethods, r.Method) {
		// A missing or non-bool context value is treated as "not plaintext".
		var isPlaintext bool
		val := r.Context().Value(PlaintextHTTPContextKey)
		if val != nil {
			isPlaintext, _ = val.(bool)
		}

		// take a copy of the request URL to avoid mutating the original
		// attached to the request.
		// set the scheme & host based on the request context as these are not
		// populated by default for server requests
		// ref: https://pkg.go.dev/net/http#Request
		requestURL := *r.URL // shallow clone

		requestURL.Scheme = "https"
		if isPlaintext {
			requestURL.Scheme = "http"
		}
		if requestURL.Host == "" {
			requestURL.Host = r.Host
		}

		// If we have an Origin header, it must either match the request's own
		// origin or name a host in our allowlist. An unparsable Origin is
		// rejected outright.
		origin := r.Header.Get("Origin")
		if origin != "" {
			parsedOrigin, err := url.Parse(origin)
			if err != nil {
				r = envError(r, ErrBadOrigin)
				cs.opts.ErrorHandler.ServeHTTP(w, r)
				return
			}
			if !sameOrigin(&requestURL, parsedOrigin) && !slices.Contains(cs.opts.TrustedOrigins, parsedOrigin.Host) {
				r = envError(r, ErrBadOrigin)
				cs.opts.ErrorHandler.ServeHTTP(w, r)
				return
			}
		}

		// If we are serving via TLS and have no Origin header, prevent against
		// CSRF via HTTP machine in the middle attacks by enforcing strict
		// Referer origin checks. Consider an attacker who performs a
		// successful HTTP Machine-in-the-Middle attack and uses this to inject
		// a form and cause submission to our origin. We strictly disallow
		// cleartext HTTP origins and evaluate the domain against an allowlist.
		if origin == "" && !isPlaintext {
			// Fetch the Referer value. Call the error handler if it's empty or
			// otherwise fails to parse. (An absent Referer parses without error
			// to an empty URL, hence the String() check.)
			referer, err := url.Parse(r.Referer())
			if err != nil || referer.String() == "" {
				r = envError(r, ErrNoReferer)
				cs.opts.ErrorHandler.ServeHTTP(w, r)
				return
			}

			// disallow cleartext HTTP referers when serving via TLS
			if referer.Scheme == "http" {
				r = envError(r, ErrBadReferer)
				cs.opts.ErrorHandler.ServeHTTP(w, r)
				return
			}

			// If the request is being served via TLS and the Referer is not the
			// same origin, check the domain against our allowlist. We only
			// check when we have host information from the referer.
			if referer.Host != "" && referer.Host != r.Host && !slices.Contains(cs.opts.TrustedOrigins, referer.Host) {
				r = envError(r, ErrBadReferer)
				cs.opts.ErrorHandler.ServeHTTP(w, r)
				return
			}
		}

		// Retrieve the combined (pad + masked) token from the request. An error
		// means the token was present but unreadable; a nil token means none
		// was supplied.
		maskedToken, err := cs.requestToken(r)
		if err != nil {
			r = envError(r, ErrBadToken)
			cs.opts.ErrorHandler.ServeHTTP(w, r)
			return
		}

		if maskedToken == nil {
			r = envError(r, ErrNoToken)
			cs.opts.ErrorHandler.ServeHTTP(w, r)
			return
		}

		// ... and unmask it.
		requestToken := unmask(maskedToken)

		// Compare the request token against the real token
		if !compareTokens(requestToken, realToken) {
			r = envError(r, ErrBadToken)
			cs.opts.ErrorHandler.ServeHTTP(w, r)
			return
		}

	}

	// Set the Vary: Cookie header to protect clients from caching the response.
	// Only reached on the success path; rejected requests return above.
	w.Header().Add("Vary", "Cookie")

	// Call the wrapped handler/router on success.
	cs.h.ServeHTTP(w, r)
	// Clear the request context after the handler has completed.
	contextClear(r)
}
// PlaintextHTTPRequest marks r as being served over plaintext HTTP.
//
// It returns a shallow copy of r (via http.Request.WithContext) whose context
// maps PlaintextHTTPContextKey to true; r itself is left unchanged. The CSRF
// middleware reads this flag to use the "http" scheme when comparing origins,
// and to skip the strict Referer checks it otherwise applies to HTTPS requests.
func PlaintextHTTPRequest(r *http.Request) *http.Request {
	return r.WithContext(context.WithValue(r.Context(), PlaintextHTTPContextKey, true))
}
// unauthorizedHandler is the default ErrorHandler installed by Protect. It
// replies with HTTP 403 Forbidden and a plain-text body of the form
// "Forbidden - <reason>", where <reason> is FailureReason(r).
func unauthorizedHandler(w http.ResponseWriter, r *http.Request) {
	const status = http.StatusForbidden
	body := fmt.Sprintf("%s - %s", http.StatusText(status), FailureReason(r))
	http.Error(w, body, status)
}