4
0
mirror of https://github.com/cwinfo/matterbridge.git synced 2025-07-03 08:27:44 +00:00

Update vendor (#1414)

This commit is contained in:
Wim
2021-03-20 22:40:23 +01:00
committed by GitHub
parent 3a8857c8c9
commit ee5d9b43b5
187 changed files with 6746 additions and 1611 deletions

View File

@ -73,7 +73,7 @@ func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc {
auth := c.Request().Header.Get(echo.HeaderAuthorization)
l := len(basic)
if len(auth) > l+1 && strings.ToLower(auth[:l]) == basic {
if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) {
b, err := base64.StdEncoding.DecodeString(auth[l+1:])
if err != nil {
return err

View File

@ -8,6 +8,7 @@ import (
"net"
"net/http"
"strings"
"sync"
"github.com/labstack/echo/v4"
)
@ -58,6 +59,8 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
config.Level = DefaultGzipConfig.Level
}
pool := gzipCompressPool(config)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
@ -68,11 +71,13 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
res.Header().Add(echo.HeaderVary, echo.HeaderAcceptEncoding)
if strings.Contains(c.Request().Header.Get(echo.HeaderAcceptEncoding), gzipScheme) {
res.Header().Set(echo.HeaderContentEncoding, gzipScheme) // Issue #806
rw := res.Writer
w, err := gzip.NewWriterLevel(rw, config.Level)
if err != nil {
return err
i := pool.Get()
w, ok := i.(*gzip.Writer)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
rw := res.Writer
w.Reset(rw)
defer func() {
if res.Size == 0 {
if res.Header().Get(echo.HeaderContentEncoding) == gzipScheme {
@ -85,6 +90,7 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc {
w.Reset(ioutil.Discard)
}
w.Close()
pool.Put(w)
}()
grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw}
res.Writer = grw
@ -126,3 +132,15 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error {
}
return http.ErrNotSupported
}
func gzipCompressPool(config GzipConfig) sync.Pool {
return sync.Pool{
New: func() interface{} {
w, err := gzip.NewWriterLevel(ioutil.Discard, config.Level)
if err != nil {
return err
}
return w
},
}
}

View File

@ -2,6 +2,7 @@ package middleware
import (
"net/http"
"regexp"
"strconv"
"strings"
@ -18,6 +19,13 @@ type (
// Optional. Default value []string{"*"}.
AllowOrigins []string `yaml:"allow_origins"`
// AllowOriginFunc is a custom function to validate the origin. It takes the
// origin as an argument and returns true if allowed or false otherwise. If
// an error is returned, it is returned by the handler. If this option is
// set, AllowOrigins is ignored.
// Optional.
AllowOriginFunc func(origin string) (bool, error) `yaml:"allow_origin_func"`
// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
// Optional. Default value DefaultCORSConfig.AllowMethods.
@ -76,6 +84,15 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
config.AllowMethods = DefaultCORSConfig.AllowMethods
}
allowOriginPatterns := []string{}
for _, origin := range config.AllowOrigins {
pattern := regexp.QuoteMeta(origin)
pattern = strings.Replace(pattern, "\\*", ".*", -1)
pattern = strings.Replace(pattern, "\\?", ".", -1)
pattern = "^" + pattern + "$"
allowOriginPatterns = append(allowOriginPatterns, pattern)
}
allowMethods := strings.Join(config.AllowMethods, ",")
allowHeaders := strings.Join(config.AllowHeaders, ",")
exposeHeaders := strings.Join(config.ExposeHeaders, ",")
@ -92,25 +109,73 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
origin := req.Header.Get(echo.HeaderOrigin)
allowOrigin := ""
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break
preflight := req.Method == http.MethodOptions
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
// No Origin provided
if origin == "" {
if !preflight {
return next(c)
}
if o == "*" || o == origin {
allowOrigin = o
break
return c.NoContent(http.StatusNoContent)
}
if config.AllowOriginFunc != nil {
allowed, err := config.AllowOriginFunc(origin)
if err != nil {
return err
}
if matchSubdomain(origin, o) {
if allowed {
allowOrigin = origin
break
}
} else {
// Check allowed origins
for _, o := range config.AllowOrigins {
if o == "*" && config.AllowCredentials {
allowOrigin = origin
break
}
if o == "*" || o == origin {
allowOrigin = o
break
}
if matchSubdomain(origin, o) {
allowOrigin = origin
break
}
}
// Check allowed origin patterns
for _, re := range allowOriginPatterns {
if allowOrigin == "" {
didx := strings.Index(origin, "://")
if didx == -1 {
continue
}
domAuth := origin[didx+3:]
// to avoid regex cost by invalid long domain
if len(domAuth) > 253 {
break
}
if match, _ := regexp.MatchString(re, origin); match {
allowOrigin = origin
break
}
}
}
}
// Origin not allowed
if allowOrigin == "" {
if !preflight {
return next(c)
}
return c.NoContent(http.StatusNoContent)
}
// Simple request
if req.Method != http.MethodOptions {
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
if !preflight {
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)
if config.AllowCredentials {
res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true")
@ -122,7 +187,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc {
}
// Preflight request
res.Header().Add(echo.HeaderVary, echo.HeaderOrigin)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod)
res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders)
res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin)

View File

@ -57,6 +57,10 @@ type (
// Indicates if CSRF cookie is HTTP only.
// Optional. Default value false.
CookieHTTPOnly bool `yaml:"cookie_http_only"`
// Indicates SameSite mode of the CSRF cookie.
// Optional. Default value SameSiteDefaultMode.
CookieSameSite http.SameSite `yaml:"cookie_same_site"`
}
// csrfTokenExtractor defines a function that takes `echo.Context` and returns
@ -67,12 +71,13 @@ type (
var (
// DefaultCSRFConfig is the default CSRF middleware config.
DefaultCSRFConfig = CSRFConfig{
Skipper: DefaultSkipper,
TokenLength: 32,
TokenLookup: "header:" + echo.HeaderXCSRFToken,
ContextKey: "csrf",
CookieName: "_csrf",
CookieMaxAge: 86400,
Skipper: DefaultSkipper,
TokenLength: 32,
TokenLookup: "header:" + echo.HeaderXCSRFToken,
ContextKey: "csrf",
CookieName: "_csrf",
CookieMaxAge: 86400,
CookieSameSite: http.SameSiteDefaultMode,
}
)
@ -105,6 +110,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.CookieMaxAge == 0 {
config.CookieMaxAge = DefaultCSRFConfig.CookieMaxAge
}
if config.CookieSameSite == SameSiteNoneMode {
config.CookieSecure = true
}
// Initialize
parts := strings.Split(config.TokenLookup, ":")
@ -157,6 +165,9 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc {
if config.CookieDomain != "" {
cookie.Domain = config.CookieDomain
}
if config.CookieSameSite != http.SameSiteDefaultMode {
cookie.SameSite = config.CookieSameSite
}
cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
cookie.Secure = config.CookieSecure
cookie.HttpOnly = config.CookieHTTPOnly

View File

@ -0,0 +1,12 @@
// +build go1.13
package middleware
import (
"net/http"
)
const (
// SameSiteNoneMode required to be redefined for Go 1.12 support (see #1524)
SameSiteNoneMode http.SameSite = http.SameSiteNoneMode
)

View File

@ -0,0 +1,12 @@
// +build !go1.13
package middleware
import (
"net/http"
)
const (
// SameSiteNoneMode required to be redefined for Go 1.12 support (see #1524)
SameSiteNoneMode http.SameSite = 4
)

View File

@ -0,0 +1,120 @@
package middleware
import (
"bytes"
"compress/gzip"
"io"
"io/ioutil"
"net/http"
"sync"
"github.com/labstack/echo/v4"
)
type (
// DecompressConfig defines the config for Decompress middleware.
DecompressConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers
GzipDecompressPool Decompressor
}
)
//GZIPEncoding content-encoding header if set to "gzip", decompress body contents.
const GZIPEncoding string = "gzip"
// Decompressor is used to get the sync.Pool used by the middleware to get Gzip readers
type Decompressor interface {
gzipDecompressPool() sync.Pool
}
var (
//DefaultDecompressConfig defines the config for decompress middleware
DefaultDecompressConfig = DecompressConfig{
Skipper: DefaultSkipper,
GzipDecompressPool: &DefaultGzipDecompressPool{},
}
)
// DefaultGzipDecompressPool is the default implementation of Decompressor interface
type DefaultGzipDecompressPool struct {
}
func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool {
return sync.Pool{
New: func() interface{} {
// create with an empty reader (but with GZIP header)
w, err := gzip.NewWriterLevel(ioutil.Discard, gzip.BestSpeed)
if err != nil {
return err
}
b := new(bytes.Buffer)
w.Reset(b)
w.Flush()
w.Close()
r, err := gzip.NewReader(bytes.NewReader(b.Bytes()))
if err != nil {
return err
}
return r
},
}
}
//Decompress decompresses request body based if content encoding type is set to "gzip" with default config
func Decompress() echo.MiddlewareFunc {
return DecompressWithConfig(DefaultDecompressConfig)
}
//DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config
func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultGzipConfig.Skipper
}
if config.GzipDecompressPool == nil {
config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
pool := config.GzipDecompressPool.gzipDecompressPool()
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
switch c.Request().Header.Get(echo.HeaderContentEncoding) {
case GZIPEncoding:
b := c.Request().Body
i := pool.Get()
gr, ok := i.(*gzip.Reader)
if !ok {
return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error())
}
if err := gr.Reset(b); err != nil {
pool.Put(gr)
if err == io.EOF { //ignore if body is empty
return next(c)
}
return err
}
var buf bytes.Buffer
io.Copy(&buf, gr)
gr.Close()
pool.Put(gr)
b.Close() // http.Request.Body is closed by the Server, but because we are replacing it, it must be closed here
r := ioutil.NopCloser(&buf)
c.Request().Body = r
}
return next(c)
}
}
}

View File

@ -57,6 +57,7 @@ type (
// - "query:<name>"
// - "param:<name>"
// - "cookie:<name>"
// - "form:<name>"
TokenLookup string
// AuthScheme to be used in the Authorization header.
@ -86,6 +87,7 @@ const (
// Errors
var (
ErrJWTMissing = echo.NewHTTPError(http.StatusBadRequest, "missing or malformed jwt")
ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt")
)
var (
@ -166,6 +168,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
extractor = jwtFromParam(parts[1])
case "cookie":
extractor = jwtFromCookie(parts[1])
case "form":
extractor = jwtFromForm(parts[1])
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -213,8 +217,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
return config.ErrorHandlerWithContext(err, c)
}
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "invalid or expired jwt",
Code: ErrJWTInvalid.Code,
Message: ErrJWTInvalid.Message,
Internal: err,
}
}
@ -265,3 +269,14 @@ func jwtFromCookie(name string) jwtExtractor {
return cookie.Value, nil
}
}
// jwtFromForm returns a `jwtExtractor` that extracts token from the form field.
func jwtFromForm(name string) jwtExtractor {
return func(c echo.Context) (string, error) {
field := c.FormValue(name)
if field == "" {
return "", ErrJWTMissing
}
return field, nil
}
}

View File

@ -1,6 +1,8 @@
package middleware
import (
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
@ -32,6 +34,47 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer {
return strings.NewReplacer(replace...)
}
func rewriteRulesRegex(rewrite map[string]string) map[*regexp.Regexp]string {
// Initialize
rulesRegex := map[*regexp.Regexp]string{}
for k, v := range rewrite {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*?)", -1)
if strings.HasPrefix(k, `\^`) {
k = strings.Replace(k, `\^`, "^", -1)
}
k = k + "$"
rulesRegex[regexp.MustCompile(k)] = v
}
return rulesRegex
}
func rewritePath(rewriteRegex map[*regexp.Regexp]string, req *http.Request) {
for k, v := range rewriteRegex {
rawPath := req.URL.RawPath
if rawPath != "" {
// RawPath is only set when there has been escaping done. In that case Path must be deduced from rewritten RawPath
// because encoded Path could match rules that RawPath did not
if replacer := captureTokens(k, rawPath); replacer != nil {
rawPath = replacer.Replace(v)
req.URL.RawPath = rawPath
req.URL.Path, _ = url.PathUnescape(rawPath)
return // rewrite only once
}
continue
}
if replacer := captureTokens(k, req.URL.Path); replacer != nil {
req.URL.Path = replacer.Replace(v)
return // rewrite only once
}
}
}
// DefaultSkipper returns false which processes the middleware.
func DefaultSkipper(echo.Context) bool {
return false

View File

@ -8,7 +8,6 @@ import (
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"sync/atomic"
"time"
@ -37,6 +36,13 @@ type (
// "/users/*/orders/*": "/user/$1/order/$2",
Rewrite map[string]string
// RegexRewrite defines rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
RegexRewrite map[*regexp.Regexp]string
// Context key to store selected ProxyTarget into context.
// Optional. Default value "target".
ContextKey string
@ -47,8 +53,6 @@ type (
// ModifyResponse defines function to modify response from ProxyTarget.
ModifyResponse func(*http.Response) error
rewriteRegex map[*regexp.Regexp]string
}
// ProxyTarget defines the upstream target.
@ -206,12 +210,14 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
if config.Balancer == nil {
panic("echo: proxy middleware requires balancer")
}
config.rewriteRegex = map[*regexp.Regexp]string{}
// Initialize
for k, v := range config.Rewrite {
k = strings.Replace(k, "*", "(\\S*)", -1)
config.rewriteRegex[regexp.MustCompile(k)] = v
if config.Rewrite != nil {
if config.RegexRewrite == nil {
config.RegexRewrite = make(map[*regexp.Regexp]string)
}
for k, v := range rewriteRulesRegex(config.Rewrite) {
config.RegexRewrite[k] = v
}
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -225,13 +231,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc {
tgt := config.Balancer.Next(c)
c.Set(config.ContextKey, tgt)
// Rewrite
for k, v := range config.rewriteRegex {
replacer := captureTokens(k, echo.GetPath(req))
if replacer != nil {
req.URL.Path = replacer.Replace(v)
}
}
// Set rewrite path and raw path
rewritePath(config.RegexRewrite, req)
// Fix header
// Basically it's not good practice to unconditionally pass incoming x-real-ip header to upstream.

View File

@ -3,13 +3,22 @@
package middleware
import (
"context"
"fmt"
"net/http"
"net/http/httputil"
"strings"
"github.com/labstack/echo/v4"
)
// StatusCodeContextCanceled is a custom HTTP status code for situations
// where a client unexpectedly closed the connection to the server.
// As there is no standard error code for "client closed connection", but
// various well-known HTTP clients and server implement this HTTP code we use
// 499 too instead of the more problematic 5xx, which does not allow to detect this situation
const StatusCodeContextCanceled = 499
func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler {
proxy := httputil.NewSingleHostReverseProxy(tgt.URL)
proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) {
@ -17,7 +26,20 @@ func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handle
if tgt.Name != "" {
desc = fmt.Sprintf("%s(%s)", tgt.Name, tgt.URL.String())
}
c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)))
// If the client canceled the request (usually by closing the connection), we can report a
// client error (4xx) instead of a server error (5xx) to correctly identify the situation.
// The Go standard library (at of late 2020) wraps the exported, standard
// context.Canceled error with unexported garbage value requiring a substring check, see
// https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430
if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") {
httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err))
httpError.Internal = err
c.Set("_error", httpError)
} else {
httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err))
httpError.Internal = err
c.Set("_error", httpError)
}
}
proxy.Transport = config.Transport
proxy.ModifyResponse = config.ModifyResponse

View File

@ -0,0 +1,266 @@
package middleware
import (
"net/http"
"sync"
"time"
"github.com/labstack/echo/v4"
"golang.org/x/time/rate"
)
type (
// RateLimiterStore is the interface to be implemented by custom stores.
RateLimiterStore interface {
// Stores for the rate limiter have to implement the Allow method
Allow(identifier string) (bool, error)
}
)
type (
// RateLimiterConfig defines the configuration for the rate limiter
RateLimiterConfig struct {
Skipper Skipper
BeforeFunc BeforeFunc
// IdentifierExtractor uses echo.Context to extract the identifier for a visitor
IdentifierExtractor Extractor
// Store defines a store for the rate limiter
Store RateLimiterStore
// ErrorHandler provides a handler to be called when IdentifierExtractor returns an error
ErrorHandler func(context echo.Context, err error) error
// DenyHandler provides a handler to be called when RateLimiter denies access
DenyHandler func(context echo.Context, identifier string, err error) error
}
// Extractor is used to extract data from echo.Context
Extractor func(context echo.Context) (string, error)
)
// errors
var (
// ErrRateLimitExceeded denotes an error raised when rate limit is exceeded
ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded")
// ErrExtractorError denotes an error raised when extractor function is unsuccessful
ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while extracting identifier")
)
// DefaultRateLimiterConfig defines default values for RateLimiterConfig
var DefaultRateLimiterConfig = RateLimiterConfig{
Skipper: DefaultSkipper,
IdentifierExtractor: func(ctx echo.Context) (string, error) {
id := ctx.RealIP()
return id, nil
},
ErrorHandler: func(context echo.Context, err error) error {
return &echo.HTTPError{
Code: ErrExtractorError.Code,
Message: ErrExtractorError.Message,
Internal: err,
}
},
DenyHandler: func(context echo.Context, identifier string, err error) error {
return &echo.HTTPError{
Code: ErrRateLimitExceeded.Code,
Message: ErrRateLimitExceeded.Message,
Internal: err,
}
},
}
/*
RateLimiter returns a rate limiting middleware
e := echo.New()
limiterStore := middleware.NewRateLimiterMemoryStore(20)
e.GET("/rate-limited", func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}, RateLimiter(limiterStore))
*/
func RateLimiter(store RateLimiterStore) echo.MiddlewareFunc {
config := DefaultRateLimiterConfig
config.Store = store
return RateLimiterWithConfig(config)
}
/*
RateLimiterWithConfig returns a rate limiting middleware
e := echo.New()
config := middleware.RateLimiterConfig{
Skipper: DefaultSkipper,
Store: middleware.NewRateLimiterMemoryStore(
middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute}
)
IdentifierExtractor: func(ctx echo.Context) (string, error) {
id := ctx.RealIP()
return id, nil
},
ErrorHandler: func(context echo.Context, err error) error {
return context.JSON(http.StatusTooManyRequests, nil)
},
DenyHandler: func(context echo.Context, identifier string) error {
return context.JSON(http.StatusForbidden, nil)
},
}
e.GET("/rate-limited", func(c echo.Context) error {
return c.String(http.StatusOK, "test")
}, middleware.RateLimiterWithConfig(config))
*/
func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc {
if config.Skipper == nil {
config.Skipper = DefaultRateLimiterConfig.Skipper
}
if config.IdentifierExtractor == nil {
config.IdentifierExtractor = DefaultRateLimiterConfig.IdentifierExtractor
}
if config.ErrorHandler == nil {
config.ErrorHandler = DefaultRateLimiterConfig.ErrorHandler
}
if config.DenyHandler == nil {
config.DenyHandler = DefaultRateLimiterConfig.DenyHandler
}
if config.Store == nil {
panic("Store configuration must be provided")
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}
if config.BeforeFunc != nil {
config.BeforeFunc(c)
}
identifier, err := config.IdentifierExtractor(c)
if err != nil {
c.Error(config.ErrorHandler(c, err))
return nil
}
if allow, err := config.Store.Allow(identifier); !allow {
c.Error(config.DenyHandler(c, identifier, err))
return nil
}
return next(c)
}
}
}
type (
// RateLimiterMemoryStore is the built-in store implementation for RateLimiter
RateLimiterMemoryStore struct {
visitors map[string]*Visitor
mutex sync.Mutex
rate rate.Limit
burst int
expiresIn time.Duration
lastCleanup time.Time
}
// Visitor signifies a unique user's limiter details
Visitor struct {
*rate.Limiter
lastSeen time.Time
}
)
/*
NewRateLimiterMemoryStore returns an instance of RateLimiterMemoryStore with
the provided rate (as req/s). Burst and ExpiresIn will be set to default values.
Example (with 20 requests/sec):
limiterStore := middleware.NewRateLimiterMemoryStore(20)
*/
func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) {
return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{
Rate: rate,
})
}
/*
NewRateLimiterMemoryStoreWithConfig returns an instance of RateLimiterMemoryStore
with the provided configuration. Rate must be provided. Burst will be set to the value of
the configured rate if not provided or set to 0.
The build-in memory store is usually capable for modest loads. For higher loads other
store implementations should be considered.
Characteristics:
* Concurrency above 100 parallel requests may causes measurable lock contention
* A high number of different IP addresses (above 16000) may be impacted by the internally used Go map
* A high number of requests from a single IP address may cause lock contention
Example:
limiterStore := middleware.NewRateLimiterMemoryStoreWithConfig(
middleware.RateLimiterMemoryStoreConfig{Rate: 50, Burst: 200, ExpiresIn: 5 * time.Minutes},
)
*/
func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (store *RateLimiterMemoryStore) {
store = &RateLimiterMemoryStore{}
store.rate = config.Rate
store.burst = config.Burst
store.expiresIn = config.ExpiresIn
if config.ExpiresIn == 0 {
store.expiresIn = DefaultRateLimiterMemoryStoreConfig.ExpiresIn
}
if config.Burst == 0 {
store.burst = int(config.Rate)
}
store.visitors = make(map[string]*Visitor)
store.lastCleanup = now()
return
}
// RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore
type RateLimiterMemoryStoreConfig struct {
Rate rate.Limit // Rate of requests allowed to pass as req/s
Burst int // Burst additionally allows a number of requests to pass when rate limit is reached
ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up
}
// DefaultRateLimiterMemoryStoreConfig provides default configuration values for RateLimiterMemoryStore
var DefaultRateLimiterMemoryStoreConfig = RateLimiterMemoryStoreConfig{
ExpiresIn: 3 * time.Minute,
}
// Allow implements RateLimiterStore.Allow
func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) {
store.mutex.Lock()
limiter, exists := store.visitors[identifier]
if !exists {
limiter = new(Visitor)
limiter.Limiter = rate.NewLimiter(store.rate, store.burst)
store.visitors[identifier] = limiter
}
limiter.lastSeen = now()
if now().Sub(store.lastCleanup) > store.expiresIn {
store.cleanupStaleVisitors()
}
store.mutex.Unlock()
return limiter.AllowN(now(), 1), nil
}
/*
cleanupStaleVisitors helps manage the size of the visitors map by removing stale records
of users who haven't visited again after the configured expiry time has elapsed
*/
func (store *RateLimiterMemoryStore) cleanupStaleVisitors() {
for id, visitor := range store.visitors {
if now().Sub(visitor.lastSeen) > store.expiresIn {
delete(store.visitors, id)
}
}
store.lastCleanup = now()
}
/*
actual time method which is mocked in test file
*/
var now = time.Now

View File

@ -2,7 +2,6 @@ package middleware
import (
"regexp"
"strings"
"github.com/labstack/echo/v4"
)
@ -23,7 +22,12 @@ type (
// Required.
Rules map[string]string `yaml:"rules"`
rulesRegex map[*regexp.Regexp]string
// RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures
// Every capture group in the values can be retrieved by index e.g. $1, $2 and so on.
// Example:
// "^/old/[0.9]+/": "/new",
// "^/api/.+?/(.*)": "/v2/$1",
RegexRules map[*regexp.Regexp]string `yaml:"regex_rules"`
}
)
@ -47,20 +51,19 @@ func Rewrite(rules map[string]string) echo.MiddlewareFunc {
// See: `Rewrite()`.
func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
// Defaults
if config.Rules == nil {
panic("echo: rewrite middleware requires url path rewrite rules")
if config.Rules == nil && config.RegexRules == nil {
panic("echo: rewrite middleware requires url path rewrite rules or regex rules")
}
if config.Skipper == nil {
config.Skipper = DefaultBodyDumpConfig.Skipper
}
config.rulesRegex = map[*regexp.Regexp]string{}
// Initialize
for k, v := range config.Rules {
k = regexp.QuoteMeta(k)
k = strings.Replace(k, `\*`, "(.*)", -1)
k = k + "$"
config.rulesRegex[regexp.MustCompile(k)] = v
if config.RegexRules == nil {
config.RegexRules = make(map[*regexp.Regexp]string)
}
for k, v := range rewriteRulesRegex(config.Rules) {
config.RegexRules[k] = v
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
@ -70,15 +73,8 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc {
}
req := c.Request()
// Rewrite
for k, v := range config.rulesRegex {
replacer := captureTokens(k, req.URL.Path)
if replacer != nil {
req.URL.Path = replacer.Replace(v)
break
}
}
// Set rewrite path and raw path
rewritePath(config.RegexRules, req)
return next(c)
}
}

View File

@ -60,7 +60,7 @@ func AddTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFunc
// Redirect
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
}
// Forward
@ -108,7 +108,7 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
// Redirect
if config.RedirectCode != 0 {
return c.Redirect(config.RedirectCode, uri)
return c.Redirect(config.RedirectCode, sanitizeURI(uri))
}
// Forward
@ -119,3 +119,12 @@ func RemoveTrailingSlashWithConfig(config TrailingSlashConfig) echo.MiddlewareFu
}
}
}
func sanitizeURI(uri string) string {
// double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri
// we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash
if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') {
uri = "/" + strings.TrimLeft(uri, `/\`)
}
return uri
}

View File

@ -36,6 +36,12 @@ type (
// Enable directory browsing.
// Optional. Default value false.
Browse bool `yaml:"browse"`
// Enable ignoring of the base of the URL path.
// Example: when assigning a static middleware to a non root path group,
// the filesystem path is not doubled
// Optional. Default value false.
IgnoreBase bool `yaml:"ignoreBase"`
}
)
@ -161,7 +167,16 @@ func StaticWithConfig(config StaticConfig) echo.MiddlewareFunc {
if err != nil {
return
}
name := filepath.Join(config.Root, path.Clean("/"+p)) // "/"+ for security
name := filepath.Join(config.Root, filepath.Clean("/"+p)) // "/"+ for security
if config.IgnoreBase {
routePath := path.Base(strings.TrimRight(c.Path(), "/*"))
baseURLPath := path.Base(p)
if baseURLPath == routePath {
i := strings.LastIndex(name, routePath)
name = name[:i] + strings.Replace(name[i:], routePath, "", 1)
}
}
fi, err := os.Stat(name)
if err != nil {

View File

@ -0,0 +1,111 @@
// +build go1.13
package middleware
import (
"context"
"github.com/labstack/echo/v4"
"net/http"
"time"
)
type (
// TimeoutConfig defines the config for Timeout middleware.
TimeoutConfig struct {
// Skipper defines a function to skip middleware.
Skipper Skipper
// ErrorMessage is written to response on timeout in addition to http.StatusServiceUnavailable (503) status code
// It can be used to define a custom timeout error message
ErrorMessage string
// OnTimeoutRouteErrorHandler is an error handler that is executed for error that was returned from wrapped route after
// request timeouted and we already had sent the error code (503) and message response to the client.
// NB: do not write headers/body inside this handler. The response has already been sent to the client and response writer
// will not accept anything no more. If you want to know what actual route middleware timeouted use `c.Path()`
OnTimeoutRouteErrorHandler func(err error, c echo.Context)
// Timeout configures a timeout for the middleware, defaults to 0 for no timeout
// NOTE: when difference between timeout duration and handler execution time is almost the same (in range of 100microseconds)
// the result of timeout does not seem to be reliable - could respond timeout, could respond handler output
// difference over 500microseconds (0.5millisecond) response seems to be reliable
Timeout time.Duration
}
)
var (
// DefaultTimeoutConfig is the default Timeout middleware config.
DefaultTimeoutConfig = TimeoutConfig{
Skipper: DefaultSkipper,
Timeout: 0,
ErrorMessage: "",
}
)
// Timeout returns a middleware which recovers from panics anywhere in the chain
// and handles the control to the centralized HTTPErrorHandler.
func Timeout() echo.MiddlewareFunc {
return TimeoutWithConfig(DefaultTimeoutConfig)
}
// TimeoutWithConfig returns a Timeout middleware with config.
// See: `Timeout()`.
func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc {
// Defaults
if config.Skipper == nil {
config.Skipper = DefaultTimeoutConfig.Skipper
}
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) || config.Timeout == 0 {
return next(c)
}
handlerWrapper := echoHandlerFuncWrapper{
ctx: c,
handler: next,
errChan: make(chan error, 1),
errHandler: config.OnTimeoutRouteErrorHandler,
}
handler := http.TimeoutHandler(handlerWrapper, config.Timeout, config.ErrorMessage)
handler.ServeHTTP(c.Response().Writer, c.Request())
select {
case err := <-handlerWrapper.errChan:
return err
default:
return nil
}
}
}
}
type echoHandlerFuncWrapper struct {
ctx echo.Context
handler echo.HandlerFunc
errHandler func(err error, c echo.Context)
errChan chan error
}
func (t echoHandlerFuncWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// replace writer with TimeoutHandler custom one. This will guarantee that
// `writes by h to its ResponseWriter will return ErrHandlerTimeout.`
originalWriter := t.ctx.Response().Writer
t.ctx.Response().Writer = rw
err := t.handler(t.ctx)
if ctxErr := r.Context().Err(); ctxErr == context.DeadlineExceeded {
if err != nil && t.errHandler != nil {
t.errHandler(err, t.ctx)
}
return // on timeout we can not send handler error to client because `http.TimeoutHandler` has already sent headers
}
// we restore original writer only for cases we did not timeout. On timeout we have already sent response to client
// and should not anymore send additional headers/data
// so on timeout writer stays what http.TimeoutHandler uses and prevents writing headers/body
t.ctx.Response().Writer = originalWriter
if err != nil {
t.errChan <- err
}
}