package middleware import ( "crypto/subtle" "errors" "net/http" "strings" "time" "github.com/labstack/echo/v4" "github.com/labstack/gommon/random" ) type ( // CSRFConfig defines the config for CSRF middleware. CSRFConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper // TokenLength is the length of the generated token. TokenLength uint8 `yaml:"token_length"` // Optional. Default value 32. // TokenLookup is a string in the form of ":" that is used // to extract token from the request. // Optional. Default value "header:X-CSRF-Token". // Possible values: // - "header:" // - "form:" // - "query:" TokenLookup string `yaml:"token_lookup"` // Context key to store generated CSRF token into context. // Optional. Default value "csrf". ContextKey string `yaml:"context_key"` // Name of the CSRF cookie. This cookie will store CSRF token. // Optional. Default value "csrf". CookieName string `yaml:"cookie_name"` // Domain of the CSRF cookie. // Optional. Default value none. CookieDomain string `yaml:"cookie_domain"` // Path of the CSRF cookie. // Optional. Default value none. CookiePath string `yaml:"cookie_path"` // Max age (in seconds) of the CSRF cookie. // Optional. Default value 86400 (24hr). CookieMaxAge int `yaml:"cookie_max_age"` // Indicates if CSRF cookie is secure. // Optional. Default value false. CookieSecure bool `yaml:"cookie_secure"` // Indicates if CSRF cookie is HTTP only. // Optional. Default value false. CookieHTTPOnly bool `yaml:"cookie_http_only"` } // csrfTokenExtractor defines a function that takes `echo.Context` and returns // either a token or an error. csrfTokenExtractor func(echo.Context) (string, error) ) var ( // DefaultCSRFConfig is the default CSRF middleware config. DefaultCSRFConfig = CSRFConfig{ Skipper: DefaultSkipper, TokenLength: 32, TokenLookup: "header:" + echo.HeaderXCSRFToken, ContextKey: "csrf", CookieName: "_csrf", CookieMaxAge: 86400, } ) // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery func CSRF() echo.MiddlewareFunc { c := DefaultCSRFConfig return CSRFWithConfig(c) } // CSRFWithConfig returns a CSRF middleware with config. // See `CSRF()`. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper } if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } if config.ContextKey == "" { config.ContextKey = DefaultCSRFConfig.ContextKey } if config.CookieName == "" { config.CookieName = DefaultCSRFConfig.CookieName } if config.CookieMaxAge == 0 { config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge } // Initialize parts := strings.Split(config.TokenLookup, ":") extractor := csrfTokenFromHeader(parts[1]) switch parts[0] { case "form": extractor = csrfTokenFromForm(parts[1]) case "query": extractor = csrfTokenFromQuery(parts[1]) } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { return next(c) } req := c.Request() k, err := c.Cookie(config.CookieName) token := "" // Generate token if err != nil { token = random.String(config.TokenLength) } else { // Reuse token token = k.Value } switch req.Method { case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace: default: // Validate token only for requests which are not defined as 'safe' by RFC7231 clientToken, err := extractor(c) if err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } if !validateCSRFToken(token, clientToken) { return echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") } } // Set CSRF cookie cookie := new(http.Cookie) cookie.Name = config.CookieName cookie.Value = token if config.CookiePath != "" { cookie.Path = config.CookiePath } if config.CookieDomain != "" { cookie.Domain = config.CookieDomain } cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) cookie.Secure = config.CookieSecure cookie.HttpOnly = config.CookieHTTPOnly c.SetCookie(cookie) // Store token in the context c.Set(config.ContextKey, token) // Protect clients from caching the response c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie) return next(c) } } } // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the // provided request header. func csrfTokenFromHeader(header string) csrfTokenExtractor { return func(c echo.Context) (string, error) { return c.Request().Header.Get(header), nil } } // csrfTokenFromForm returns a `csrfTokenExtractor` that extracts token from the // provided form parameter. func csrfTokenFromForm(param string) csrfTokenExtractor { return func(c echo.Context) (string, error) { token := c.FormValue(param) if token == "" { return "", errors.New("missing csrf token in the form parameter") } return token, nil } } // csrfTokenFromQuery returns a `csrfTokenExtractor` that extracts token from the // provided query parameter. func csrfTokenFromQuery(param string) csrfTokenExtractor { return func(c echo.Context) (string, error) { token := c.QueryParam(param) if token == "" { return "", errors.New("missing csrf token in the query string") } return token, nil } } func validateCSRFToken(token, clientToken string) bool { return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 }