Add rate limiter, rename maxConn into inFlightReq
Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com> Co-authored-by: Jean-Baptiste Doumenjou <jb.doumenjou@gmail.com>
This commit is contained in:
parent
a8c73f7baf
commit
4ec90c5c0d
30 changed files with 1419 additions and 651 deletions
49
pkg/middlewares/extractor.go
Normal file
49
pkg/middlewares/extractor.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// GetSourceExtractor returns the SourceExtractor function corresponding to the given sourceMatcher.
|
||||
// It defaults to a RemoteAddrStrategy IPStrategy if need be.
|
||||
func GetSourceExtractor(ctx context.Context, sourceMatcher *dynamic.SourceCriterion) (utils.SourceExtractor, error) {
|
||||
if sourceMatcher == nil ||
|
||||
sourceMatcher.IPStrategy == nil &&
|
||||
sourceMatcher.RequestHeaderName == "" && !sourceMatcher.RequestHost {
|
||||
sourceMatcher = &dynamic.SourceCriterion{
|
||||
IPStrategy: &dynamic.IPStrategy{},
|
||||
}
|
||||
}
|
||||
|
||||
logger := log.FromContext(ctx)
|
||||
if sourceMatcher.IPStrategy != nil {
|
||||
strategy, err := sourceMatcher.IPStrategy.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Debug("Using IPStrategy")
|
||||
return utils.ExtractorFunc(func(req *http.Request) (string, int64, error) {
|
||||
return strategy.GetIP(req), 1, nil
|
||||
}), nil
|
||||
}
|
||||
|
||||
if sourceMatcher.RequestHeaderName != "" {
|
||||
logger.Debug("Using RequestHeaderName")
|
||||
return utils.NewExtractor(fmt.Sprintf("request.header.%s", sourceMatcher.RequestHeaderName))
|
||||
}
|
||||
|
||||
if sourceMatcher.RequestHost {
|
||||
logger.Debug("Using RequestHost")
|
||||
return utils.NewExtractor("request.host")
|
||||
}
|
||||
|
||||
return nil, errors.New("no SourceCriterion criterion defined")
|
||||
}
|
57
pkg/middlewares/inflightreq/inflight_req.go
Normal file
57
pkg/middlewares/inflightreq/inflight_req.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package inflightreq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
"github.com/containous/traefik/v2/pkg/middlewares"
|
||||
"github.com/containous/traefik/v2/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/connlimit"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "InFlightReq"
|
||||
)
|
||||
|
||||
type inFlightReq struct {
|
||||
handler http.Handler
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a max request middleware.
|
||||
func New(ctx context.Context, next http.Handler, config dynamic.InFlightReq, name string) (http.Handler, error) {
|
||||
ctxLog := log.With(ctx, log.Str(log.MiddlewareName, name), log.Str(log.MiddlewareType, typeName))
|
||||
log.FromContext(ctxLog).Debug("Creating middleware")
|
||||
|
||||
if config.SourceCriterion == nil ||
|
||||
config.SourceCriterion.IPStrategy == nil &&
|
||||
config.SourceCriterion.RequestHeaderName == "" && !config.SourceCriterion.RequestHost {
|
||||
config.SourceCriterion = &dynamic.SourceCriterion{
|
||||
RequestHost: true,
|
||||
}
|
||||
}
|
||||
|
||||
sourceMatcher, err := middlewares.GetSourceExtractor(ctxLog, config.SourceCriterion)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating requests limiter: %v", err)
|
||||
}
|
||||
|
||||
handler, err := connlimit.New(next, sourceMatcher, config.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
return &inFlightReq{handler: handler, name: name}, nil
|
||||
}
|
||||
|
||||
func (i *inFlightReq) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return i.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (i *inFlightReq) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
i.handler.ServeHTTP(rw, req)
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
package maxconnection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||
"github.com/containous/traefik/v2/pkg/middlewares"
|
||||
"github.com/containous/traefik/v2/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/connlimit"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "MaxConnection"
|
||||
)
|
||||
|
||||
type maxConnection struct {
|
||||
handler http.Handler
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a max connection middleware.
|
||||
func New(ctx context.Context, next http.Handler, maxConns dynamic.MaxConn, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
handler, err := connlimit.New(next, extractFunc, maxConns.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
return &maxConnection{handler: handler, name: name}, nil
|
||||
}
|
||||
|
||||
func (mc *maxConnection) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return mc.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (mc *maxConnection) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
mc.handler.ServeHTTP(rw, req)
|
||||
}
|
|
@ -1,54 +1,146 @@
|
|||
// Package ratelimiter implements a rate limiting and traffic shaping middleware with a set of token buckets.
|
||||
package ratelimiter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
"github.com/containous/traefik/v2/pkg/middlewares"
|
||||
"github.com/containous/traefik/v2/pkg/tracing"
|
||||
"github.com/mailgun/ttlmap"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/ratelimit"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "RateLimiterType"
|
||||
typeName = "RateLimiterType"
|
||||
maxSources = 65536
|
||||
)
|
||||
|
||||
// rateLimiter implements rate limiting and traffic shaping with a set of token buckets;
|
||||
// one for each traffic source. The same parameters are applied to all the buckets.
|
||||
type rateLimiter struct {
|
||||
handler http.Handler
|
||||
name string
|
||||
name string
|
||||
rate rate.Limit // reqs/s
|
||||
burst int64
|
||||
// maxDelay is the maximum duration we're willing to wait for a bucket reservation to become effective, in nanoseconds.
|
||||
// For now it is somewhat arbitrarily set to 1/rate.
|
||||
maxDelay time.Duration
|
||||
sourceMatcher utils.SourceExtractor
|
||||
next http.Handler
|
||||
|
||||
bucketsMu sync.Mutex
|
||||
buckets *ttlmap.TtlMap // actual buckets, keyed by source.
|
||||
}
|
||||
|
||||
// New creates rate limiter middleware.
|
||||
// New returns a rate limiter middleware.
|
||||
func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
ctxLog := log.With(ctx, log.Str(log.MiddlewareName, name), log.Str(log.MiddlewareType, typeName))
|
||||
log.FromContext(ctxLog).Debug("Creating middleware")
|
||||
|
||||
extractFunc, err := utils.NewExtractor(config.ExtractorFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rateSet := ratelimit.NewRateSet()
|
||||
for _, rate := range config.RateSet {
|
||||
if err = rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
|
||||
return nil, err
|
||||
if config.SourceCriterion == nil ||
|
||||
config.SourceCriterion.IPStrategy == nil &&
|
||||
config.SourceCriterion.RequestHeaderName == "" && !config.SourceCriterion.RequestHost {
|
||||
config.SourceCriterion = &dynamic.SourceCriterion{
|
||||
IPStrategy: &dynamic.IPStrategy{},
|
||||
}
|
||||
}
|
||||
|
||||
rl, err := ratelimit.New(next, extractFunc, rateSet)
|
||||
sourceMatcher, err := middlewares.GetSourceExtractor(ctxLog, config.SourceCriterion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rateLimiter{handler: rl, name: name}, nil
|
||||
|
||||
buckets, err := ttlmap.NewMap(maxSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
burst := config.Burst
|
||||
if burst <= 0 {
|
||||
burst = 1
|
||||
}
|
||||
|
||||
// Logically, we should set maxDelay to ~infinity when config.Average == 0 (because it means to rate limiting),
|
||||
// but since the reservation will give us a delay = 0 anyway in this case, we're good even with any maxDelay >= 0.
|
||||
var maxDelay time.Duration
|
||||
if config.Average != 0 {
|
||||
maxDelay = time.Second / time.Duration(config.Average*2)
|
||||
}
|
||||
|
||||
return &rateLimiter{
|
||||
name: name,
|
||||
rate: rate.Limit(config.Average),
|
||||
burst: burst,
|
||||
maxDelay: maxDelay,
|
||||
next: next,
|
||||
sourceMatcher: sourceMatcher,
|
||||
buckets: buckets,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *rateLimiter) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return r.name, tracing.SpanKindNoneEnum
|
||||
func (rl *rateLimiter) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return rl.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (r *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
r.handler.ServeHTTP(rw, req)
|
||||
func (rl *rateLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
logger := middlewares.GetLogger(r.Context(), rl.name, typeName)
|
||||
|
||||
source, amount, err := rl.sourceMatcher.Extract(r)
|
||||
if err != nil {
|
||||
logger.Errorf("could not extract source of request: %v", err)
|
||||
http.Error(w, "could not extract source of request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if amount != 1 {
|
||||
logger.Infof("ignoring token bucket amount > 1: %d", amount)
|
||||
}
|
||||
|
||||
rl.bucketsMu.Lock()
|
||||
defer rl.bucketsMu.Unlock()
|
||||
|
||||
var bucket *rate.Limiter
|
||||
if rlSource, exists := rl.buckets.Get(source); exists {
|
||||
bucket = rlSource.(*rate.Limiter)
|
||||
} else {
|
||||
bucket = rate.NewLimiter(rl.rate, int(rl.burst))
|
||||
if err := rl.buckets.Set(source, bucket, int(rl.maxDelay)*10+1); err != nil {
|
||||
logger.Errorf("could not insert bucket: %v", err)
|
||||
http.Error(w, "could not insert bucket", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
res := bucket.Reserve()
|
||||
if !res.OK() {
|
||||
http.Error(w, "No bursty traffic allowed", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
delay := res.Delay()
|
||||
if delay > rl.maxDelay {
|
||||
res.Cancel()
|
||||
rl.serveDelayError(w, r, delay)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(delay)
|
||||
rl.next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) serveDelayError(w http.ResponseWriter, r *http.Request, delay time.Duration) {
|
||||
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
|
||||
w.Header().Set("X-Retry-In", delay.String())
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
|
||||
if _, err := w.Write([]byte(http.StatusText(http.StatusTooManyRequests))); err != nil {
|
||||
middlewares.GetLogger(r.Context(), rl.name, typeName).Errorf("could not serve 429: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
160
pkg/middlewares/ratelimiter/rate_limiter_test.go
Normal file
160
pkg/middlewares/ratelimiter/rate_limiter_test.go
Normal file
|
@ -0,0 +1,160 @@
|
|||
package ratelimiter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/config/dynamic"
|
||||
"github.com/containous/traefik/v2/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
func TestNewRateLimiter(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config dynamic.RateLimit
|
||||
expectedMaxDelay time.Duration
|
||||
expectedSourceIP string
|
||||
}{
|
||||
{
|
||||
desc: "maxDelay computation",
|
||||
config: dynamic.RateLimit{
|
||||
Average: 200,
|
||||
Burst: 10,
|
||||
},
|
||||
expectedMaxDelay: 2500 * time.Microsecond,
|
||||
},
|
||||
{
|
||||
desc: "default SourceMatcher is remote address ip strategy",
|
||||
config: dynamic.RateLimit{
|
||||
Average: 200,
|
||||
Burst: 10,
|
||||
},
|
||||
expectedSourceIP: "127.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
h, err := New(context.Background(), next, test.config, "rate-limiter")
|
||||
require.NoError(t, err)
|
||||
|
||||
rtl, _ := h.(*rateLimiter)
|
||||
if test.expectedMaxDelay != 0 {
|
||||
assert.Equal(t, test.expectedMaxDelay, rtl.maxDelay)
|
||||
}
|
||||
|
||||
if test.expectedSourceIP != "" {
|
||||
extractor, ok := rtl.sourceMatcher.(utils.ExtractorFunc)
|
||||
require.True(t, ok, "Not an ExtractorFunc")
|
||||
|
||||
req := http.Request{
|
||||
RemoteAddr: fmt.Sprintf("%s:1234", test.expectedSourceIP),
|
||||
}
|
||||
|
||||
ip, _, err := extractor(&req)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, test.expectedSourceIP, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimit(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config dynamic.RateLimit
|
||||
reqCount int
|
||||
}{
|
||||
{
|
||||
desc: "Average is respected",
|
||||
config: dynamic.RateLimit{
|
||||
Average: 100,
|
||||
Burst: 1,
|
||||
},
|
||||
reqCount: 200,
|
||||
},
|
||||
{
|
||||
desc: "Burst is taken into account",
|
||||
config: dynamic.RateLimit{
|
||||
Average: 100,
|
||||
Burst: 200,
|
||||
},
|
||||
reqCount: 300,
|
||||
},
|
||||
{
|
||||
desc: "Zero average ==> no rate limiting",
|
||||
config: dynamic.RateLimit{
|
||||
Average: 0,
|
||||
Burst: 1,
|
||||
},
|
||||
reqCount: 100,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
reqCount := 0
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reqCount++
|
||||
})
|
||||
|
||||
h, err := New(context.Background(), next, test.config, "rate-limiter")
|
||||
require.NoError(t, err)
|
||||
|
||||
start := time.Now()
|
||||
for {
|
||||
if reqCount >= test.reqCount {
|
||||
break
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.RemoteAddr = "127.0.0.1:1234"
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(w, req)
|
||||
// TODO(mpl): predict and count the 200 VS the 429?
|
||||
}
|
||||
|
||||
stop := time.Now()
|
||||
elapsed := stop.Sub(start)
|
||||
if test.config.Average == 0 {
|
||||
if elapsed > time.Millisecond {
|
||||
t.Fatalf("rate should not have been limited, but: %d requests in %v", reqCount, elapsed)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Assume allowed burst is initially consumed in an infinitesimal period of time
|
||||
var expectedDuration time.Duration
|
||||
if test.config.Average != 0 {
|
||||
expectedDuration = time.Duration((int64(test.reqCount)-test.config.Burst+1)/test.config.Average) * time.Second
|
||||
}
|
||||
|
||||
// Allow for a 2% leeway
|
||||
minDuration := expectedDuration * 98 / 100
|
||||
maxDuration := expectedDuration * 102 / 100
|
||||
|
||||
if elapsed < minDuration {
|
||||
t.Fatalf("rate was faster than expected: %d requests in %v", reqCount, elapsed)
|
||||
}
|
||||
if elapsed > maxDuration {
|
||||
t.Fatalf("rate was slower than expected: %d requests in %v", reqCount, elapsed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue