Add Redis rate limiter
This commit is contained in:
parent
c166a41c99
commit
550d96ea67
26 changed files with 2268 additions and 69 deletions
85
pkg/middlewares/ratelimiter/rate_limiter.go
Normal file → Executable file
85
pkg/middlewares/ratelimiter/rate_limiter.go
Normal file → Executable file
|
@ -8,7 +8,7 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/mailgun/ttlmap"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||
|
@ -23,24 +23,23 @@ const (
|
|||
maxSources = 65536
|
||||
)
|
||||
|
||||
type limiter interface {
|
||||
Allow(ctx context.Context, token string) (*time.Duration, error)
|
||||
}
|
||||
|
||||
// rateLimiter implements rate limiting and traffic shaping with a set of token buckets;
|
||||
// one for each traffic source. The same parameters are applied to all the buckets.
|
||||
type rateLimiter struct {
|
||||
name string
|
||||
rate rate.Limit // reqs/s
|
||||
burst int64
|
||||
name string
|
||||
rate rate.Limit // reqs/s
|
||||
// maxDelay is the maximum duration we're willing to wait for a bucket reservation to become effective, in nanoseconds.
|
||||
// For now it is somewhat arbitrarily set to 1/(2*rate).
|
||||
maxDelay time.Duration
|
||||
// each rate limiter for a given source is stored in the buckets ttlmap.
|
||||
// To keep this ttlmap constrained in size,
|
||||
// each ratelimiter is "garbage collected" when it is considered expired.
|
||||
// It is considered expired after it hasn't been used for ttl seconds.
|
||||
ttl int
|
||||
maxDelay time.Duration
|
||||
sourceMatcher utils.SourceExtractor
|
||||
next http.Handler
|
||||
logger *zerolog.Logger
|
||||
|
||||
buckets *ttlmap.TtlMap // actual buckets, keyed by source.
|
||||
limiter limiter
|
||||
}
|
||||
|
||||
// New returns a rate limiter middleware.
|
||||
|
@ -60,12 +59,7 @@ func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name
|
|||
|
||||
sourceMatcher, err := middlewares.GetSourceExtractor(ctxLog, config.SourceCriterion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buckets, err := ttlmap.NewConcurrent(maxSources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("getting source extractor: %w", err)
|
||||
}
|
||||
|
||||
burst := config.Burst
|
||||
|
@ -109,16 +103,27 @@ func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name
|
|||
} else if rtl > 0 {
|
||||
ttl += int(1 / rtl)
|
||||
}
|
||||
var limiter limiter
|
||||
if config.Redis != nil {
|
||||
limiter, err = newRedisLimiter(ctx, rate.Limit(rtl), burst, maxDelay, ttl, config, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating redis limiter: %w", err)
|
||||
}
|
||||
} else {
|
||||
limiter, err = newInMemoryRateLimiter(rate.Limit(rtl), burst, maxDelay, ttl, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating in-memory limiter: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &rateLimiter{
|
||||
logger: logger,
|
||||
name: name,
|
||||
rate: rate.Limit(rtl),
|
||||
burst: burst,
|
||||
maxDelay: maxDelay,
|
||||
next: next,
|
||||
sourceMatcher: sourceMatcher,
|
||||
buckets: buckets,
|
||||
ttl: ttl,
|
||||
limiter: limiter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -141,38 +146,34 @@ func (rl *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
logger.Info().Msgf("ignoring token bucket amount > 1: %d", amount)
|
||||
}
|
||||
|
||||
var bucket *rate.Limiter
|
||||
if rlSource, exists := rl.buckets.Get(source); exists {
|
||||
bucket = rlSource.(*rate.Limiter)
|
||||
} else {
|
||||
bucket = rate.NewLimiter(rl.rate, int(rl.burst))
|
||||
}
|
||||
|
||||
// We Set even in the case where the source already exists,
|
||||
// because we want to update the expiryTime every time we get the source,
|
||||
// as the expiryTime is supposed to reflect the activity (or lack thereof) on that source.
|
||||
if err := rl.buckets.Set(source, bucket, rl.ttl); err != nil {
|
||||
logger.Error().Err(err).Msg("Could not insert/update bucket")
|
||||
observability.SetStatusErrorf(req.Context(), "Could not insert/update bucket")
|
||||
http.Error(rw, "could not insert/update bucket", http.StatusInternalServerError)
|
||||
delay, err := rl.limiter.Allow(ctx, source)
|
||||
if err != nil {
|
||||
rl.logger.Error().Err(err).Msg("Could not insert/update bucket")
|
||||
observability.SetStatusErrorf(ctx, "Could not insert/update bucket")
|
||||
http.Error(rw, "Could not insert/update bucket", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
res := bucket.Reserve()
|
||||
if !res.OK() {
|
||||
observability.SetStatusErrorf(req.Context(), "No bursty traffic allowed")
|
||||
if delay == nil {
|
||||
observability.SetStatusErrorf(ctx, "No bursty traffic allowed")
|
||||
http.Error(rw, "No bursty traffic allowed", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
delay := res.Delay()
|
||||
if delay > rl.maxDelay {
|
||||
res.Cancel()
|
||||
rl.serveDelayError(ctx, rw, delay)
|
||||
if *delay > rl.maxDelay {
|
||||
rl.serveDelayError(ctx, rw, *delay)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
observability.SetStatusErrorf(ctx, "Context canceled")
|
||||
http.Error(rw, "context canceled", http.StatusInternalServerError)
|
||||
return
|
||||
|
||||
case <-time.After(*delay):
|
||||
}
|
||||
|
||||
rl.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue