
Add rate limiter, rename maxConn into inFlightReq

Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
Co-authored-by: Jean-Baptiste Doumenjou <jb.doumenjou@gmail.com>
mpl 2019-08-26 12:20:06 +02:00 committed by Traefiker Bot
parent a8c73f7baf
commit 4ec90c5c0d
30 changed files with 1419 additions and 651 deletions
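At a glance: the old connection-count option (MaxConn) is replaced by an in-flight request limiter and a new token-bucket rate limiter. Below is a minimal sketch of the two new option structs as they are used in the hunks that follow, assuming nothing beyond the fields visible in this diff (Amount, Average, Burst, SourceCriterion, IPStrategy); the concrete values are illustrative.

package main

import (
    "fmt"

    "github.com/containous/traefik/v2/pkg/config/dynamic"
)

func main() {
    // Replaces the old dynamic.MaxConn: cap the number of simultaneous requests per source.
    inFlight := dynamic.InFlightReq{Amount: 10}

    // New token-bucket limiter: ~100 req/s on average with bursts of up to 50,
    // grouped by client IP (the rate limiter's default source criterion).
    limit := dynamic.RateLimit{
        Average: 100,
        Burst:   50,
        SourceCriterion: &dynamic.SourceCriterion{
            IPStrategy: &dynamic.IPStrategy{},
        },
    }

    fmt.Printf("%+v\n%+v\n", inFlight, limit)
}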

@@ -0,0 +1,49 @@
package middlewares
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/containous/traefik/v2/pkg/config/dynamic"
"github.com/containous/traefik/v2/pkg/log"
"github.com/vulcand/oxy/utils"
)
// GetSourceExtractor returns the SourceExtractor function corresponding to the given sourceMatcher.
// It defaults to a RemoteAddrStrategy IPStrategy if need be.
func GetSourceExtractor(ctx context.Context, sourceMatcher *dynamic.SourceCriterion) (utils.SourceExtractor, error) {
if sourceMatcher == nil ||
sourceMatcher.IPStrategy == nil &&
sourceMatcher.RequestHeaderName == "" && !sourceMatcher.RequestHost {
sourceMatcher = &dynamic.SourceCriterion{
IPStrategy: &dynamic.IPStrategy{},
}
}
logger := log.FromContext(ctx)
if sourceMatcher.IPStrategy != nil {
strategy, err := sourceMatcher.IPStrategy.Get()
if err != nil {
return nil, err
}
logger.Debug("Using IPStrategy")
return utils.ExtractorFunc(func(req *http.Request) (string, int64, error) {
return strategy.GetIP(req), 1, nil
}), nil
}
if sourceMatcher.RequestHeaderName != "" {
logger.Debug("Using RequestHeaderName")
return utils.NewExtractor(fmt.Sprintf("request.header.%s", sourceMatcher.RequestHeaderName))
}
if sourceMatcher.RequestHost {
logger.Debug("Using RequestHost")
return utils.NewExtractor("request.host")
}
return nil, errors.New("no SourceCriterion criterion defined")
}
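A short usage sketch of the new helper, in the spirit of how the inflightreq and ratelimiter packages below call it; the wrapper function, the X-Token header name and the logging are illustrative assumptions.

package example

import (
    "context"
    "log"
    "net/http"

    "github.com/containous/traefik/v2/pkg/config/dynamic"
    "github.com/containous/traefik/v2/pkg/middlewares"
)

// sourceKeyHandler (hypothetical) logs which "source" each request would be grouped under.
func sourceKeyHandler() (http.Handler, error) {
    criterion := &dynamic.SourceCriterion{RequestHeaderName: "X-Token"}

    extractor, err := middlewares.GetSourceExtractor(context.Background(), criterion)
    if err != nil {
        return nil, err
    }

    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        source, amount, err := extractor.Extract(r)
        if err != nil {
            http.Error(w, err.Error(), http.StatusInternalServerError)
            return
        }
        // source is the X-Token value here; amount is a request weight (1 for these extractors).
        log.Printf("request from source %q (amount %d)", source, amount)
        w.WriteHeader(http.StatusOK)
    }), nil
}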

@@ -0,0 +1,57 @@
package inflightreq
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/v2/pkg/config/dynamic"
"github.com/containous/traefik/v2/pkg/log"
"github.com/containous/traefik/v2/pkg/middlewares"
"github.com/containous/traefik/v2/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/connlimit"
)
const (
typeName = "InFlightReq"
)
type inFlightReq struct {
handler http.Handler
name string
}
// New creates a max request middleware.
func New(ctx context.Context, next http.Handler, config dynamic.InFlightReq, name string) (http.Handler, error) {
ctxLog := log.With(ctx, log.Str(log.MiddlewareName, name), log.Str(log.MiddlewareType, typeName))
log.FromContext(ctxLog).Debug("Creating middleware")
if config.SourceCriterion == nil ||
config.SourceCriterion.IPStrategy == nil &&
config.SourceCriterion.RequestHeaderName == "" && !config.SourceCriterion.RequestHost {
config.SourceCriterion = &dynamic.SourceCriterion{
RequestHost: true,
}
}
sourceMatcher, err := middlewares.GetSourceExtractor(ctxLog, config.SourceCriterion)
if err != nil {
return nil, fmt.Errorf("error creating requests limiter: %v", err)
}
handler, err := connlimit.New(next, sourceMatcher, config.Amount)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
return &inFlightReq{handler: handler, name: name}, nil
}
func (i *inFlightReq) GetTracingInformation() (string, ext.SpanKindEnum) {
return i.name, tracing.SpanKindNoneEnum
}
func (i *inFlightReq) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
i.handler.ServeHTTP(rw, req)
}
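A minimal wiring sketch for the renamed middleware, assuming an import path matching the package name (pkg/middlewares/inflightreq); the backend handler, the limit of 10 and the listen address are illustrative.

package main

import (
    "context"
    "log"
    "net/http"

    "github.com/containous/traefik/v2/pkg/config/dynamic"
    "github.com/containous/traefik/v2/pkg/middlewares/inflightreq"
)

func main() {
    backend := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        _, _ = w.Write([]byte("ok"))
    })

    // At most 10 requests in flight per source; with no SourceCriterion set,
    // New above falls back to grouping by request host.
    limited, err := inflightreq.New(context.Background(), backend, dynamic.InFlightReq{Amount: 10}, "in-flight-req")
    if err != nil {
        log.Fatal(err)
    }

    log.Fatal(http.ListenAndServe(":8080", limited))
}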

@@ -1,48 +0,0 @@
package maxconnection
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/v2/pkg/config/dynamic"
"github.com/containous/traefik/v2/pkg/middlewares"
"github.com/containous/traefik/v2/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/connlimit"
"github.com/vulcand/oxy/utils"
)
const (
typeName = "MaxConnection"
)
type maxConnection struct {
handler http.Handler
name string
}
// New creates a max connection middleware.
func New(ctx context.Context, next http.Handler, maxConns dynamic.MaxConn, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
handler, err := connlimit.New(next, extractFunc, maxConns.Amount)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
return &maxConnection{handler: handler, name: name}, nil
}
func (mc *maxConnection) GetTracingInformation() (string, ext.SpanKindEnum) {
return mc.name, tracing.SpanKindNoneEnum
}
func (mc *maxConnection) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
mc.handler.ServeHTTP(rw, req)
}
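For comparison with the file deleted above, a hedged before/after sketch of the construction call; the field names come from this diff (ExtractorFunc and Amount on dynamic.MaxConn, Amount and SourceCriterion on dynamic.InFlightReq), while the import path and values are illustrative.

package example

import (
    "context"
    "net/http"

    "github.com/containous/traefik/v2/pkg/config/dynamic"
    "github.com/containous/traefik/v2/pkg/middlewares/inflightreq"
)

func migrate(next http.Handler) (http.Handler, error) {
    // Before this commit (maxconnection, deleted above):
    //   maxconnection.New(context.Background(), next, dynamic.MaxConn{Amount: 10, ExtractorFunc: "request.host"}, "max-conn")
    //
    // After: the oxy extractor string is replaced by the structured SourceCriterion,
    // and grouping by request host is the default when no criterion is given.
    return inflightreq.New(context.Background(), next, dynamic.InFlightReq{Amount: 10}, "in-flight-req")
}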

@@ -1,54 +1,146 @@
// Package ratelimiter implements a rate limiting and traffic shaping middleware with a set of token buckets.
package ratelimiter
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"github.com/containous/traefik/v2/pkg/config/dynamic"
"github.com/containous/traefik/v2/pkg/log"
"github.com/containous/traefik/v2/pkg/middlewares"
"github.com/containous/traefik/v2/pkg/tracing"
"github.com/mailgun/ttlmap"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/ratelimit"
"github.com/vulcand/oxy/utils"
"golang.org/x/time/rate"
)
const (
typeName = "RateLimiterType"
typeName = "RateLimiterType"
maxSources = 65536
)
// rateLimiter implements rate limiting and traffic shaping with a set of token buckets;
// one for each traffic source. The same parameters are applied to all the buckets.
type rateLimiter struct {
handler http.Handler
name string
rate rate.Limit // reqs/s
burst int64
// maxDelay is the maximum duration we're willing to wait for a bucket reservation to become effective, in nanoseconds.
// For now it is somewhat arbitrarily set to 1/rate.
maxDelay time.Duration
sourceMatcher utils.SourceExtractor
next http.Handler
bucketsMu sync.Mutex
buckets *ttlmap.TtlMap // actual buckets, keyed by source.
}
// New creates rate limiter middleware.
// New returns a rate limiter middleware.
func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
ctxLog := log.With(ctx, log.Str(log.MiddlewareName, name), log.Str(log.MiddlewareType, typeName))
log.FromContext(ctxLog).Debug("Creating middleware")
extractFunc, err := utils.NewExtractor(config.ExtractorFunc)
if err != nil {
return nil, err
}
rateSet := ratelimit.NewRateSet()
for _, rate := range config.RateSet {
if err = rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
return nil, err
if config.SourceCriterion == nil ||
config.SourceCriterion.IPStrategy == nil &&
config.SourceCriterion.RequestHeaderName == "" && !config.SourceCriterion.RequestHost {
config.SourceCriterion = &dynamic.SourceCriterion{
IPStrategy: &dynamic.IPStrategy{},
}
}
rl, err := ratelimit.New(next, extractFunc, rateSet)
sourceMatcher, err := middlewares.GetSourceExtractor(ctxLog, config.SourceCriterion)
if err != nil {
return nil, err
}
return &rateLimiter{handler: rl, name: name}, nil
buckets, err := ttlmap.NewMap(maxSources)
if err != nil {
return nil, err
}
burst := config.Burst
if burst <= 0 {
burst = 1
}
// Logically, we should set maxDelay to ~infinity when config.Average == 0 (because it means no rate limiting),
// but since the reservation will give us a delay = 0 anyway in this case, we're good even with any maxDelay >= 0.
var maxDelay time.Duration
if config.Average != 0 {
maxDelay = time.Second / time.Duration(config.Average*2)
}
return &rateLimiter{
name: name,
rate: rate.Limit(config.Average),
burst: burst,
maxDelay: maxDelay,
next: next,
sourceMatcher: sourceMatcher,
buckets: buckets,
}, nil
}
func (r *rateLimiter) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
func (rl *rateLimiter) GetTracingInformation() (string, ext.SpanKindEnum) {
return rl.name, tracing.SpanKindNoneEnum
}
func (r *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
r.handler.ServeHTTP(rw, req)
func (rl *rateLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
logger := middlewares.GetLogger(r.Context(), rl.name, typeName)
source, amount, err := rl.sourceMatcher.Extract(r)
if err != nil {
logger.Errorf("could not extract source of request: %v", err)
http.Error(w, "could not extract source of request", http.StatusInternalServerError)
return
}
if amount != 1 {
logger.Infof("ignoring token bucket amount > 1: %d", amount)
}
rl.bucketsMu.Lock()
defer rl.bucketsMu.Unlock()
var bucket *rate.Limiter
if rlSource, exists := rl.buckets.Get(source); exists {
bucket = rlSource.(*rate.Limiter)
} else {
bucket = rate.NewLimiter(rl.rate, int(rl.burst))
if err := rl.buckets.Set(source, bucket, int(rl.maxDelay)*10+1); err != nil {
logger.Errorf("could not insert bucket: %v", err)
http.Error(w, "could not insert bucket", http.StatusInternalServerError)
return
}
}
res := bucket.Reserve()
if !res.OK() {
http.Error(w, "No bursty traffic allowed", http.StatusTooManyRequests)
return
}
delay := res.Delay()
if delay > rl.maxDelay {
res.Cancel()
rl.serveDelayError(w, r, delay)
return
}
time.Sleep(delay)
rl.next.ServeHTTP(w, r)
}
func (rl *rateLimiter) serveDelayError(w http.ResponseWriter, r *http.Request, delay time.Duration) {
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", delay.Seconds()))
w.Header().Set("X-Retry-In", delay.String())
w.WriteHeader(http.StatusTooManyRequests)
if _, err := w.Write([]byte(http.StatusText(http.StatusTooManyRequests))); err != nil {
middlewares.GetLogger(r.Context(), rl.name, typeName).Errorf("could not serve 429: %v", err)
}
}
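A standalone sketch of the token-bucket path the new ServeHTTP follows, with the maxDelay formula from New worked through for Average = 200 (the 2500µs the test below expects); it uses only golang.org/x/time/rate, and the numbers are illustrative.

package main

import (
    "fmt"
    "time"

    "golang.org/x/time/rate"
)

func main() {
    average := int64(200) // req/s, as in the "maxDelay computation" test case below
    burst := 10

    // maxDelay = 1/(2*rate): time.Second / time.Duration(average*2) = 2.5ms for average = 200.
    maxDelay := time.Second / time.Duration(average*2)
    fmt.Println(maxDelay) // 2.5ms

    bucket := rate.NewLimiter(rate.Limit(average), burst)

    res := bucket.Reserve()
    if !res.OK() {
        fmt.Println("reservation impossible, reject with 429")
        return
    }
    delay := res.Delay()
    if delay > maxDelay {
        res.Cancel()
        fmt.Println("would wait longer than maxDelay, reject with 429 and Retry-After", delay)
        return
    }
    time.Sleep(delay) // shape traffic instead of rejecting
    fmt.Println("request allowed after", delay)
}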

@@ -0,0 +1,160 @@
package ratelimiter
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/containous/traefik/v2/pkg/config/dynamic"
"github.com/containous/traefik/v2/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/utils"
)
func TestNewRateLimiter(t *testing.T) {
testCases := []struct {
desc string
config dynamic.RateLimit
expectedMaxDelay time.Duration
expectedSourceIP string
}{
{
desc: "maxDelay computation",
config: dynamic.RateLimit{
Average: 200,
Burst: 10,
},
expectedMaxDelay: 2500 * time.Microsecond,
},
{
desc: "default SourceMatcher is remote address ip strategy",
config: dynamic.RateLimit{
Average: 200,
Burst: 10,
},
expectedSourceIP: "127.0.0.1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
h, err := New(context.Background(), next, test.config, "rate-limiter")
require.NoError(t, err)
rtl, _ := h.(*rateLimiter)
if test.expectedMaxDelay != 0 {
assert.Equal(t, test.expectedMaxDelay, rtl.maxDelay)
}
if test.expectedSourceIP != "" {
extractor, ok := rtl.sourceMatcher.(utils.ExtractorFunc)
require.True(t, ok, "Not an ExtractorFunc")
req := http.Request{
RemoteAddr: fmt.Sprintf("%s:1234", test.expectedSourceIP),
}
ip, _, err := extractor(&req)
assert.NoError(t, err)
assert.Equal(t, test.expectedSourceIP, ip)
}
})
}
}
func TestRateLimit(t *testing.T) {
testCases := []struct {
desc string
config dynamic.RateLimit
reqCount int
}{
{
desc: "Average is respected",
config: dynamic.RateLimit{
Average: 100,
Burst: 1,
},
reqCount: 200,
},
{
desc: "Burst is taken into account",
config: dynamic.RateLimit{
Average: 100,
Burst: 200,
},
reqCount: 300,
},
{
desc: "Zero average ==> no rate limiting",
config: dynamic.RateLimit{
Average: 0,
Burst: 1,
},
reqCount: 100,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
reqCount := 0
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqCount++
})
h, err := New(context.Background(), next, test.config, "rate-limiter")
require.NoError(t, err)
start := time.Now()
for {
if reqCount >= test.reqCount {
break
}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.RemoteAddr = "127.0.0.1:1234"
w := httptest.NewRecorder()
h.ServeHTTP(w, req)
// TODO(mpl): predict and count the 200 VS the 429?
}
stop := time.Now()
elapsed := stop.Sub(start)
if test.config.Average == 0 {
if elapsed > time.Millisecond {
t.Fatalf("rate should not have been limited, but: %d requests in %v", reqCount, elapsed)
}
return
}
// Assume allowed burst is initially consumed in an infinitesimal period of time
var expectedDuration time.Duration
if test.config.Average != 0 {
expectedDuration = time.Duration((int64(test.reqCount)-test.config.Burst+1)/test.config.Average) * time.Second
}
// Allow for a 2% leeway
minDuration := expectedDuration * 98 / 100
maxDuration := expectedDuration * 102 / 100
if elapsed < minDuration {
t.Fatalf("rate was faster than expected: %d requests in %v", reqCount, elapsed)
}
if elapsed > maxDuration {
t.Fatalf("rate was slower than expected: %d requests in %v", reqCount, elapsed)
}
})
}
}
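To make the timing assertion concrete, here is a worked instance of the expectedDuration arithmetic for the first two cases; it only restates numbers already present in the test.

package main

import (
    "fmt"
    "time"
)

func main() {
    // "Average is respected": 200 requests, burst 1, average 100 req/s.
    fmt.Println(time.Duration((200-1+1)/100) * time.Second) // 2s

    // "Burst is taken into account": 300 requests, burst 200, average 100 req/s.
    // The integer division truncates 101/100 to 1, so the expected duration is 1s.
    fmt.Println(time.Duration((300-200+1)/100) * time.Second) // 1s
}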