1
0
Fork 0

Add Redis rate limiter

This commit is contained in:
longquan0104 2025-03-10 17:02:05 +07:00 committed by GitHub
parent c166a41c99
commit 550d96ea67
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 2268 additions and 69 deletions

85
pkg/middlewares/ratelimiter/rate_limiter.go Normal file → Executable file
View file

@ -8,7 +8,7 @@ import (
"net/http"
"time"
"github.com/mailgun/ttlmap"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
@ -23,24 +23,23 @@ const (
maxSources = 65536
)
// limiter abstracts the rate-limiting backend (in-memory or Redis in this
// commit). Allow reports how long the request identified by token must wait
// before proceeding; a nil duration means the request exceeds the allowed
// burst and must be rejected (the caller responds 429), a non-nil duration
// is the delay to apply before forwarding.
type limiter interface {
Allow(ctx context.Context, token string) (*time.Duration, error)
}
// rateLimiter implements rate limiting and traffic shaping with a set of token buckets;
// one for each traffic source. The same parameters are applied to all the buckets.
type rateLimiter struct {
name string
rate rate.Limit // reqs/s
burst int64
name string
rate rate.Limit // reqs/s
// maxDelay is the maximum duration we're willing to wait for a bucket reservation to become effective, in nanoseconds.
// For now it is somewhat arbitrarily set to 1/(2*rate).
maxDelay time.Duration
// each rate limiter for a given source is stored in the buckets ttlmap.
// To keep this ttlmap constrained in size,
// each ratelimiter is "garbage collected" when it is considered expired.
// It is considered expired after it hasn't been used for ttl seconds.
ttl int
maxDelay time.Duration
sourceMatcher utils.SourceExtractor
next http.Handler
logger *zerolog.Logger
buckets *ttlmap.TtlMap // actual buckets, keyed by source.
limiter limiter
}
// New returns a rate limiter middleware.
@ -60,12 +59,7 @@ func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name
sourceMatcher, err := middlewares.GetSourceExtractor(ctxLog, config.SourceCriterion)
if err != nil {
return nil, err
}
buckets, err := ttlmap.NewConcurrent(maxSources)
if err != nil {
return nil, err
return nil, fmt.Errorf("getting source extractor: %w", err)
}
burst := config.Burst
@ -109,16 +103,27 @@ func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name
} else if rtl > 0 {
ttl += int(1 / rtl)
}
var limiter limiter
if config.Redis != nil {
limiter, err = newRedisLimiter(ctx, rate.Limit(rtl), burst, maxDelay, ttl, config, logger)
if err != nil {
return nil, fmt.Errorf("creating redis limiter: %w", err)
}
} else {
limiter, err = newInMemoryRateLimiter(rate.Limit(rtl), burst, maxDelay, ttl, logger)
if err != nil {
return nil, fmt.Errorf("creating in-memory limiter: %w", err)
}
}
return &rateLimiter{
logger: logger,
name: name,
rate: rate.Limit(rtl),
burst: burst,
maxDelay: maxDelay,
next: next,
sourceMatcher: sourceMatcher,
buckets: buckets,
ttl: ttl,
limiter: limiter,
}, nil
}
@ -141,38 +146,34 @@ func (rl *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger.Info().Msgf("ignoring token bucket amount > 1: %d", amount)
}
var bucket *rate.Limiter
if rlSource, exists := rl.buckets.Get(source); exists {
bucket = rlSource.(*rate.Limiter)
} else {
bucket = rate.NewLimiter(rl.rate, int(rl.burst))
}
// We Set even in the case where the source already exists,
// because we want to update the expiryTime every time we get the source,
// as the expiryTime is supposed to reflect the activity (or lack thereof) on that source.
if err := rl.buckets.Set(source, bucket, rl.ttl); err != nil {
logger.Error().Err(err).Msg("Could not insert/update bucket")
observability.SetStatusErrorf(req.Context(), "Could not insert/update bucket")
http.Error(rw, "could not insert/update bucket", http.StatusInternalServerError)
delay, err := rl.limiter.Allow(ctx, source)
if err != nil {
rl.logger.Error().Err(err).Msg("Could not insert/update bucket")
observability.SetStatusErrorf(ctx, "Could not insert/update bucket")
http.Error(rw, "Could not insert/update bucket", http.StatusInternalServerError)
return
}
res := bucket.Reserve()
if !res.OK() {
observability.SetStatusErrorf(req.Context(), "No bursty traffic allowed")
if delay == nil {
observability.SetStatusErrorf(ctx, "No bursty traffic allowed")
http.Error(rw, "No bursty traffic allowed", http.StatusTooManyRequests)
return
}
delay := res.Delay()
if delay > rl.maxDelay {
res.Cancel()
rl.serveDelayError(ctx, rw, delay)
if *delay > rl.maxDelay {
rl.serveDelayError(ctx, rw, *delay)
return
}
time.Sleep(delay)
select {
case <-ctx.Done():
observability.SetStatusErrorf(ctx, "Context canceled")
http.Error(rw, "context canceled", http.StatusInternalServerError)
return
case <-time.After(*delay):
}
rl.next.ServeHTTP(rw, req)
}