1
0
Fork 0

Add Redis rate limiter

This commit is contained in:
longquan0104 2025-03-10 17:02:05 +07:00 committed by GitHub
parent c166a41c99
commit 550d96ea67
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 2268 additions and 69 deletions

View file

@ -0,0 +1,72 @@
package ratelimiter
import (
"context"
"fmt"
"time"
"github.com/mailgun/ttlmap"
"github.com/rs/zerolog"
"golang.org/x/time/rate"
)
// inMemoryRateLimiter implements the limiter interface with one token
// bucket (a *rate.Limiter) per traffic source, held in a process-local
// TTL map so that idle sources are eventually evicted.
type inMemoryRateLimiter struct {
	rate  rate.Limit // reqs/s
	burst int64
	// maxDelay is the maximum duration we're willing to wait for a bucket reservation to become effective, in nanoseconds.
	// For now it is somewhat arbitrarily set to 1/(2*rate).
	maxDelay time.Duration
	// Each rate limiter for a given source is stored in the buckets ttlmap.
	// To keep this ttlmap constrained in size,
	// each ratelimiter is "garbage collected" when it is considered expired.
	// It is considered expired after it hasn't been used for ttl seconds.
	ttl     int
	buckets *ttlmap.TtlMap // actual buckets, keyed by source.
	logger  *zerolog.Logger
}
// newInMemoryRateLimiter builds an in-memory limiter backed by a
// concurrent ttlmap holding one token bucket per source.
func newInMemoryRateLimiter(rate rate.Limit, burst int64, maxDelay time.Duration, ttl int, logger *zerolog.Logger) (*inMemoryRateLimiter, error) {
	store, err := ttlmap.NewConcurrent(maxSources)
	if err != nil {
		return nil, fmt.Errorf("creating ttlmap: %w", err)
	}

	limiter := &inMemoryRateLimiter{
		rate:     rate,
		burst:    burst,
		maxDelay: maxDelay,
		ttl:      ttl,
		logger:   logger,
		buckets:  store,
	}
	return limiter, nil
}
// Allow reserves one token from the bucket associated with source and
// returns how long the caller should wait before proceeding.
// A nil duration with a nil error means the reservation was not possible at all.
func (i *inMemoryRateLimiter) Allow(_ context.Context, source string) (*time.Duration, error) {
	// Look up (or lazily create) the bucket holding this source's limiter state.
	var bucket *rate.Limiter
	existing, found := i.buckets.Get(source)
	if found {
		bucket = existing.(*rate.Limiter)
	} else {
		bucket = rate.NewLimiter(i.rate, int(i.burst))
	}

	// Set on every call, not only on creation: refreshing the entry keeps
	// its expiry time in line with the source's most recent activity.
	if err := i.buckets.Set(source, bucket, i.ttl); err != nil {
		return nil, fmt.Errorf("setting buckets: %w", err)
	}

	reservation := bucket.Reserve()
	if !reservation.OK() {
		return nil, nil
	}

	delay := reservation.Delay()
	if delay > i.maxDelay {
		// Too long to wait: give the token back. The delay is still
		// returned so the caller can decide how to reject the request.
		reservation.Cancel()
	}
	return &delay, nil
}

View file

@ -0,0 +1,66 @@
package ratelimiter
import (
"context"
"github.com/redis/go-redis/v9"
)
// Rediser is the subset of the go-redis client API the rate limiter needs;
// it is satisfied by redis.UniversalClient (see newRedisLimiter) and covers
// the methods redis.Script uses to run and cache Lua scripts, plus Del.
// Keeping it as a small consumer-side interface lets tests substitute a mock.
type Rediser interface {
	Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
	EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
	ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
	ScriptLoad(ctx context.Context, script string) *redis.StringCmd
	Del(ctx context.Context, keys ...string) *redis.IntCmd
	EvalRO(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
	EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
}
// AllowTokenBucketRaw is the Lua source of the token-bucket check that runs
// atomically inside Redis. KEYS[1] is the bucket's hash key. ARGV is, in
// order: limit (tokens per microsecond), burst, ttl (seconds), t (current
// time in µs since epoch) and max_delay (µs) — units as supplied by
// redisLimiter.evaluateScript. The bucket state is stored as an hash with
// fields 'last' and 'tokens'; the reply is {ok, wait_duration, tokens},
// each encoded as a string.
//nolint:dupword
var AllowTokenBucketRaw = `
local key = KEYS[1]
local limit, burst, ttl, t, max_delay = tonumber(ARGV[1]), tonumber(ARGV[2]), tonumber(ARGV[3]), tonumber(ARGV[4]),
tonumber(ARGV[5])
local bucket = {
limit = limit,
burst = burst,
tokens = 0,
last = 0
}
local rl_source = redis.call('hgetall', key)
if table.maxn(rl_source) == 4 then
-- Get bucket state from redis
bucket.last = tonumber(rl_source[2])
bucket.tokens = tonumber(rl_source[4])
end
local last = bucket.last
if t < last then
last = t
end
local elapsed = t - last
local delta = bucket.limit * elapsed
local tokens = bucket.tokens + delta
tokens = math.min(tokens, bucket.burst)
tokens = tokens - 1
local wait_duration = 0
if tokens < 0 then
wait_duration = (tokens * -1) / bucket.limit
if wait_duration > max_delay then
tokens = tokens + 1
tokens = math.min(tokens, burst)
end
end
redis.call('hset', key, 'last', t, 'tokens', tokens)
redis.call('expire', key, ttl)
return {tostring(true), tostring(wait_duration),tostring(tokens)}`

// AllowTokenBucketScript is the precompiled handle for AllowTokenBucketRaw;
// go-redis runs it via EVALSHA and transparently falls back to loading the
// script when it is not cached on the server.
var AllowTokenBucketScript = redis.NewScript(AllowTokenBucketRaw)

85
pkg/middlewares/ratelimiter/rate_limiter.go Normal file → Executable file
View file

@ -8,7 +8,7 @@ import (
"net/http"
"time"
"github.com/mailgun/ttlmap"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
@ -23,24 +23,23 @@ const (
maxSources = 65536
)
type limiter interface {
Allow(ctx context.Context, token string) (*time.Duration, error)
}
// rateLimiter implements rate limiting and traffic shaping with a set of token buckets;
// one for each traffic source. The same parameters are applied to all the buckets.
type rateLimiter struct {
name string
rate rate.Limit // reqs/s
burst int64
name string
rate rate.Limit // reqs/s
// maxDelay is the maximum duration we're willing to wait for a bucket reservation to become effective, in nanoseconds.
// For now it is somewhat arbitrarily set to 1/(2*rate).
maxDelay time.Duration
// each rate limiter for a given source is stored in the buckets ttlmap.
// To keep this ttlmap constrained in size,
// each ratelimiter is "garbage collected" when it is considered expired.
// It is considered expired after it hasn't been used for ttl seconds.
ttl int
maxDelay time.Duration
sourceMatcher utils.SourceExtractor
next http.Handler
logger *zerolog.Logger
buckets *ttlmap.TtlMap // actual buckets, keyed by source.
limiter limiter
}
// New returns a rate limiter middleware.
@ -60,12 +59,7 @@ func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name
sourceMatcher, err := middlewares.GetSourceExtractor(ctxLog, config.SourceCriterion)
if err != nil {
return nil, err
}
buckets, err := ttlmap.NewConcurrent(maxSources)
if err != nil {
return nil, err
return nil, fmt.Errorf("getting source extractor: %w", err)
}
burst := config.Burst
@ -109,16 +103,27 @@ func New(ctx context.Context, next http.Handler, config dynamic.RateLimit, name
} else if rtl > 0 {
ttl += int(1 / rtl)
}
var limiter limiter
if config.Redis != nil {
limiter, err = newRedisLimiter(ctx, rate.Limit(rtl), burst, maxDelay, ttl, config, logger)
if err != nil {
return nil, fmt.Errorf("creating redis limiter: %w", err)
}
} else {
limiter, err = newInMemoryRateLimiter(rate.Limit(rtl), burst, maxDelay, ttl, logger)
if err != nil {
return nil, fmt.Errorf("creating in-memory limiter: %w", err)
}
}
return &rateLimiter{
logger: logger,
name: name,
rate: rate.Limit(rtl),
burst: burst,
maxDelay: maxDelay,
next: next,
sourceMatcher: sourceMatcher,
buckets: buckets,
ttl: ttl,
limiter: limiter,
}, nil
}
@ -141,38 +146,34 @@ func (rl *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger.Info().Msgf("ignoring token bucket amount > 1: %d", amount)
}
var bucket *rate.Limiter
if rlSource, exists := rl.buckets.Get(source); exists {
bucket = rlSource.(*rate.Limiter)
} else {
bucket = rate.NewLimiter(rl.rate, int(rl.burst))
}
// We Set even in the case where the source already exists,
// because we want to update the expiryTime every time we get the source,
// as the expiryTime is supposed to reflect the activity (or lack thereof) on that source.
if err := rl.buckets.Set(source, bucket, rl.ttl); err != nil {
logger.Error().Err(err).Msg("Could not insert/update bucket")
observability.SetStatusErrorf(req.Context(), "Could not insert/update bucket")
http.Error(rw, "could not insert/update bucket", http.StatusInternalServerError)
delay, err := rl.limiter.Allow(ctx, source)
if err != nil {
rl.logger.Error().Err(err).Msg("Could not insert/update bucket")
observability.SetStatusErrorf(ctx, "Could not insert/update bucket")
http.Error(rw, "Could not insert/update bucket", http.StatusInternalServerError)
return
}
res := bucket.Reserve()
if !res.OK() {
observability.SetStatusErrorf(req.Context(), "No bursty traffic allowed")
if delay == nil {
observability.SetStatusErrorf(ctx, "No bursty traffic allowed")
http.Error(rw, "No bursty traffic allowed", http.StatusTooManyRequests)
return
}
delay := res.Delay()
if delay > rl.maxDelay {
res.Cancel()
rl.serveDelayError(ctx, rw, delay)
if *delay > rl.maxDelay {
rl.serveDelayError(ctx, rw, *delay)
return
}
time.Sleep(delay)
select {
case <-ctx.Done():
observability.SetStatusErrorf(ctx, "Context canceled")
http.Error(rw, "context canceled", http.StatusInternalServerError)
return
case <-time.After(*delay):
}
rl.next.ServeHTTP(rw, req)
}

View file

@ -2,19 +2,25 @@ package ratelimiter
import (
"context"
"errors"
"fmt"
"math/rand"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
"github.com/mailgun/ttlmap"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/testhelpers"
"github.com/vulcand/oxy/v2/utils"
lua "github.com/yuin/gopher-lua"
"golang.org/x/time/rate"
)
@ -84,7 +90,17 @@ func TestNewRateLimiter(t *testing.T) {
RequestHeaderName: "Foo",
},
},
expectedError: "iPStrategy and RequestHeaderName are mutually exclusive",
expectedError: "getting source extractor: iPStrategy and RequestHeaderName are mutually exclusive",
},
{
desc: "Use Redis",
config: dynamic.RateLimit{
Average: 200,
Burst: 10,
Redis: &dynamic.Redis{
Endpoints: []string{"localhost:6379"},
},
},
},
}
@ -138,7 +154,7 @@ func TestNewRateLimiter(t *testing.T) {
}
}
func TestRateLimit(t *testing.T) {
func TestInMemoryRateLimit(t *testing.T) {
testCases := []struct {
desc string
config dynamic.RateLimit
@ -326,15 +342,357 @@ func TestRateLimit(t *testing.T) {
minCount := computeMinCount(wantCount)
if reqCount < minCount {
t.Fatalf("rate was slower than expected: %d requests (wanted > %d) in %v", reqCount, minCount, elapsed)
t.Fatalf("rate was slower than expected: %d requests (wanted > %d) (dropped %d) in %v", reqCount, minCount, dropped, elapsed)
}
if reqCount > maxCount {
t.Fatalf("rate was faster than expected: %d requests (wanted < %d) in %v", reqCount, maxCount, elapsed)
t.Fatalf("rate was faster than expected: %d requests (wanted < %d) (dropped %d) in %v", reqCount, maxCount, dropped, elapsed)
}
})
}
}
// TestRedisRateLimit exercises the Redis-backed limiter end to end through
// the middleware, replacing the real Redis client with an in-process mock
// (newMockRedisClient) that interprets the Lua script locally. Each case
// generates incomingLoad reqs/s for loadDuration and checks the observed
// throughput against the configured average/burst, within leeway bounds.
func TestRedisRateLimit(t *testing.T) {
	testCases := []struct {
		desc         string
		config       dynamic.RateLimit
		loadDuration time.Duration
		incomingLoad int // in reqs/s
		burst        int
	}{
		{
			desc: "Average is respected",
			config: dynamic.RateLimit{
				Average: 100,
				Burst:   1,
			},
			loadDuration: 2 * time.Second,
			incomingLoad: 400,
		},
		{
			desc: "burst allowed, no bursty traffic",
			config: dynamic.RateLimit{
				Average: 100,
				Burst:   100,
			},
			loadDuration: 2 * time.Second,
			incomingLoad: 200,
		},
		{
			desc: "burst allowed, initial burst, under capacity",
			config: dynamic.RateLimit{
				Average: 100,
				Burst:   100,
			},
			loadDuration: 2 * time.Second,
			incomingLoad: 200,
			burst:        50,
		},
		{
			desc: "burst allowed, initial burst, over capacity",
			config: dynamic.RateLimit{
				Average: 100,
				Burst:   100,
			},
			loadDuration: 2 * time.Second,
			incomingLoad: 200,
			burst:        150,
		},
		{
			desc: "burst over average, initial burst, over capacity",
			config: dynamic.RateLimit{
				Average: 100,
				Burst:   200,
			},
			loadDuration: 2 * time.Second,
			incomingLoad: 200,
			burst:        300,
		},
		{
			desc: "lower than 1/s",
			config: dynamic.RateLimit{
				// Bug on gopher-lua on parsing the string to number "5e-07" => 0.0000005
				// See https://github.com/yuin/gopher-lua/issues/491
				// Average: 5,
				Average: 1,
				Period:  ptypes.Duration(10 * time.Second),
			},
			loadDuration: 2 * time.Second,
			incomingLoad: 100,
			burst:        0,
		},
		{
			desc: "lower than 1/s, longer",
			config: dynamic.RateLimit{
				// Bug on gopher-lua on parsing the string to number "5e-07" => 0.0000005
				// See https://github.com/yuin/gopher-lua/issues/491
				// Average: 5,
				Average: 1,
				Period:  ptypes.Duration(10 * time.Second),
			},
			loadDuration: time.Minute,
			incomingLoad: 100,
			burst:        0,
		},
		{
			desc: "lower than 1/s, longer, harsher",
			config: dynamic.RateLimit{
				Average: 1,
				Period:  ptypes.Duration(time.Minute),
			},
			loadDuration: time.Minute,
			incomingLoad: 100,
			burst:        0,
		},
		{
			desc: "period below 1 second",
			config: dynamic.RateLimit{
				Average: 50,
				Period:  ptypes.Duration(500 * time.Millisecond),
			},
			loadDuration: 2 * time.Second,
			incomingLoad: 300,
			burst:        0,
		},
		// TODO Try to disambiguate when it fails if it is because of too high a load.
		// {
		// 	desc: "Zero average ==> no rate limiting",
		// 	config: dynamic.RateLimit{
		// 		Average: 0,
		// 		Burst:   1,
		// 	},
		// 	incomingLoad: 1000,
		// 	loadDuration: time.Second,
		// },
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			// randPort is only used to build a unique RemoteAddr per subtest,
			// so parallel cases do not share a rate-limiting bucket.
			// NOTE(review): rand.Int() can exceed the valid port range; that is
			// harmless here since the value is never used as a real port.
			randPort := rand.Int()
			if test.loadDuration >= time.Minute && testing.Short() {
				t.Skip("skipping test in short mode.")
			}
			t.Parallel()

			reqCount := 0
			dropped := 0
			next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				reqCount++
			})

			test.config.Redis = &dynamic.Redis{
				Endpoints: []string{"localhost:6379"},
			}
			h, err := New(context.Background(), next, test.config, "rate-limiter")
			require.NoError(t, err)

			// Swap the real Redis client for the in-process mock so the test
			// needs no running Redis server.
			l := h.(*rateLimiter)
			limiter := l.limiter.(*redisLimiter)
			limiter.client = newMockRedisClient(limiter.ttl)
			h = l

			loadPeriod := time.Duration(1e9 / test.incomingLoad)
			start := time.Now()
			end := start.Add(test.loadDuration)
			ticker := time.NewTicker(loadPeriod)
			defer ticker.Stop()
			for {
				if time.Now().After(end) {
					break
				}

				req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
				req.RemoteAddr = "127.0.0." + strconv.Itoa(randPort) + ":" + strconv.Itoa(randPort)
				w := httptest.NewRecorder()

				h.ServeHTTP(w, req)
				if w.Result().StatusCode != http.StatusOK {
					dropped++
				}
				if test.burst > 0 && reqCount < test.burst {
					// if a burst is defined we first hammer the server with test.burst requests as fast as possible
					continue
				}
				<-ticker.C
			}
			stop := time.Now()
			elapsed := stop.Sub(start)

			burst := test.config.Burst
			if burst < 1 {
				// actual default value
				burst = 1
			}

			period := time.Duration(test.config.Period)
			if period == 0 {
				period = time.Second
			}

			// Average == 0 means rate limiting is disabled: only sanity-check
			// that most requests went through and none were dropped.
			if test.config.Average == 0 {
				if reqCount < 75*test.incomingLoad/100 {
					t.Fatalf("we (arbitrarily) expect at least 75%% of the requests to go through with no rate limiting, and yet only %d/%d went through", reqCount, test.incomingLoad)
				}
				if dropped != 0 {
					t.Fatalf("no request should have been dropped if rate limiting is disabled, and yet %d were", dropped)
				}
				return
			}

			// Note that even when there is no bursty traffic,
			// we take into account the configured burst,
			// because it also helps absorbing non-bursty traffic.
			rate := float64(test.config.Average) / float64(period)
			wantCount := int(int64(rate*float64(test.loadDuration)) + burst)
			// Allow for a 2% leeway
			maxCount := wantCount * 102 / 100
			// With very high CPU loads,
			// we can expect some extra delay in addition to the rate limiting we already do,
			// so we allow for some extra leeway there.
			// Feel free to adjust wrt to the load on e.g. the CI.
			minCount := computeMinCount(wantCount)
			if reqCount < minCount {
				t.Fatalf("rate was slower than expected: %d requests (wanted > %d) (dropped %d) in %v", reqCount, minCount, dropped, elapsed)
			}
			if reqCount > maxCount {
				t.Fatalf("rate was faster than expected: %d requests (wanted < %d) (dropped %d) in %v", reqCount, maxCount, dropped, elapsed)
			}
		})
	}
}
// mockRedisClient is an in-process fake of the Rediser interface: it runs
// the rate-limiting Lua script with gopher-lua against a local ttlmap
// instead of talking to a real Redis server.
type mockRedisClient struct {
	ttl  int            // expiry (seconds) applied to every stored bucket
	keys *ttlmap.TtlMap // stands in for the Redis key space
}
// newMockRedisClient returns a mock Rediser whose stored keys expire after ttl seconds.
func newMockRedisClient(ttl int) Rediser {
	// Error deliberately ignored: NewConcurrent only fails on an invalid
	// capacity, and 65536 is a known-good constant.
	store, _ := ttlmap.NewConcurrent(65536)
	return &mockRedisClient{ttl: ttl, keys: store}
}
// EvalSha emulates Redis EVALSHA by interpreting AllowTokenBucketRaw with
// gopher-lua. Only the redis.call subcommands the script actually uses
// (hset, hgetall, expire) are implemented, backed by the ttlmap.
// The sha1 argument is ignored: the script source is interpreted directly.
func (m *mockRedisClient) EvalSha(ctx context.Context, _ string, keys []string, args ...interface{}) *redis.Cmd {
	state := lua.NewState()
	defer state.Close()

	// Expose KEYS and ARGV globals the way Redis does for EVAL scripts.
	tableKeys := state.NewTable()
	for _, key := range keys {
		tableKeys.Append(lua.LString(key))
	}
	state.SetGlobal("KEYS", tableKeys)

	tableArgv := state.NewTable()
	for _, arg := range args {
		tableArgv.Append(lua.LString(fmt.Sprint(arg)))
	}
	state.SetGlobal("ARGV", tableArgv)

	// Install a minimal "redis" module providing redis.call.
	mod := state.SetFuncs(state.NewTable(), map[string]lua.LGFunction{
		"call": func(state *lua.LState) int {
			switch state.Get(1).String() {
			case "hset":
				// The script calls: hset key 'last' <t> 'tokens' <n>.
				// Store the bucket as a flat [field, value, field, value] slice.
				key := state.Get(2).String()
				keyLast := state.Get(3).String()
				last := state.Get(4).String()
				keyTokens := state.Get(5).String()
				tokens := state.Get(6).String()
				table := []string{keyLast, last, keyTokens, tokens}
				_ = m.keys.Set(key, table, m.ttl)
			case "hgetall":
				// Return the stored [field, value, ...] slice as a Lua table;
				// an unknown key yields an empty table, as Redis would.
				key := state.Get(2).String()
				value, ok := m.keys.Get(key)
				table := state.NewTable()
				if !ok {
					state.Push(table)
				} else {
					switch v := value.(type) {
					case []string:
						if len(v) != 4 {
							break
						}
						for i := range v {
							table.Append(lua.LString(v[i]))
						}
					default:
						fmt.Printf("Unknown type: %T\n", v)
					}
					state.Push(table)
				}
			case "expire":
				// Expiry is handled by the ttlmap at Set time; nothing to do.
			default:
				return 0
			}
			return 1
		},
	})
	state.SetGlobal("redis", mod)
	state.Push(mod)

	cmd := redis.NewCmd(ctx)
	if err := state.DoString(AllowTokenBucketRaw); err != nil {
		cmd.SetErr(err)
		return cmd
	}

	// The script's return table is on the Lua stack; convert it into the
	// []interface{} shape a real go-redis Cmd would carry.
	result := state.Get(2)
	resultTable, ok := result.(*lua.LTable)
	if !ok {
		cmd.SetErr(errors.New("unexpected response type: " + result.String()))
		return cmd
	}
	var resultSlice []interface{}
	resultTable.ForEach(func(_ lua.LValue, value lua.LValue) {
		valueNbr, ok := value.(lua.LNumber)
		if !ok {
			valueStr, ok := value.(lua.LString)
			if !ok {
				cmd.SetErr(errors.New("unexpected response value type " + value.String()))
			}
			resultSlice = append(resultSlice, string(valueStr))
			return
		}
		resultSlice = append(resultSlice, int64(valueNbr))
	})
	cmd.SetVal(resultSlice)
	return cmd
}
// Eval delegates to EvalSha: the mock interprets the script text directly,
// so there is no separate script cache to distinguish the two.
func (m *mockRedisClient) Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd {
	return m.EvalSha(ctx, script, keys, args...)
}
// ScriptExists is a stub to satisfy the Rediser interface; it is not
// exercised by these tests and returns nil.
func (m *mockRedisClient) ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd {
	return nil
}
// ScriptLoad is a stub to satisfy the Rediser interface; it is not
// exercised by these tests and returns nil.
func (m *mockRedisClient) ScriptLoad(ctx context.Context, script string) *redis.StringCmd {
	return nil
}
// Del is a stub to satisfy the Rediser interface; it is not exercised by
// these tests and returns nil.
func (m *mockRedisClient) Del(ctx context.Context, keys ...string) *redis.IntCmd {
	return nil
}
// EvalRO is a stub to satisfy the Rediser interface; it is not exercised by
// these tests and returns nil.
func (m *mockRedisClient) EvalRO(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd {
	return nil
}
// EvalShaRO is a stub to satisfy the Rediser interface; it is not exercised
// by these tests and returns nil.
func (m *mockRedisClient) EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd {
	return nil
}
func computeMinCount(wantCount int) int {
if os.Getenv("CI") != "" {
return wantCount * 60 / 100

View file

@ -0,0 +1,118 @@
package ratelimiter
import (
"context"
"fmt"
"strconv"
"time"
"github.com/redis/go-redis/v9"
"github.com/rs/zerolog"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"golang.org/x/time/rate"
)
const redisPrefix = "rate:"
// redisLimiter implements the limiter interface with token buckets stored
// in Redis (one hash per source), so the limit can be shared by several
// instances. The bucket math itself runs server-side via a Lua script.
type redisLimiter struct {
	rate     rate.Limit // reqs/s
	burst    int64
	maxDelay time.Duration   // max wait accepted for a reservation
	period   ptypes.Duration // configured averaging period
	logger   *zerolog.Logger
	ttl      int     // bucket key expiry, in seconds
	client   Rediser // go-redis client (or a mock in tests)
}
// newRedisLimiter creates a limiter that keeps its token buckets in Redis,
// configured from config.Redis (endpoints, credentials, pool sizing,
// timeouts, TLS). rate/burst/maxDelay/ttl follow the same semantics as the
// in-memory limiter.
func newRedisLimiter(ctx context.Context, rate rate.Limit, burst int64, maxDelay time.Duration, ttl int, config dynamic.RateLimit, logger *zerolog.Logger) (limiter, error) {
	options := &redis.UniversalOptions{
		Addrs:          config.Redis.Endpoints,
		Username:       config.Redis.Username,
		Password:       config.Redis.Password,
		DB:             config.Redis.DB,
		PoolSize:       config.Redis.PoolSize,
		MinIdleConns:   config.Redis.MinIdleConns,
		MaxActiveConns: config.Redis.MaxActiveConns,
	}

	if config.Redis.DialTimeout != nil && *config.Redis.DialTimeout > 0 {
		options.DialTimeout = time.Duration(*config.Redis.DialTimeout)
	}

	// For read/write timeouts, a configured value <= 0 means "disable the
	// timeout", which go-redis expresses as -1.
	if config.Redis.ReadTimeout != nil {
		if *config.Redis.ReadTimeout > 0 {
			options.ReadTimeout = time.Duration(*config.Redis.ReadTimeout)
		} else {
			options.ReadTimeout = -1
		}
	}

	if config.Redis.WriteTimeout != nil {
		// Bug fix: this branch previously dereferenced and tested
		// config.Redis.ReadTimeout, which panics when only WriteTimeout is
		// set, and ignores a configured write timeout when ReadTimeout is 0.
		if *config.Redis.WriteTimeout > 0 {
			options.WriteTimeout = time.Duration(*config.Redis.WriteTimeout)
		} else {
			options.WriteTimeout = -1
		}
	}

	if config.Redis.TLS != nil {
		var err error
		options.TLSConfig, err = config.Redis.TLS.CreateTLSConfig(ctx)
		if err != nil {
			return nil, fmt.Errorf("creating TLS config: %w", err)
		}
	}

	return &redisLimiter{
		rate:     rate,
		burst:    burst,
		period:   config.Period,
		maxDelay: maxDelay,
		logger:   logger,
		ttl:      ttl,
		client:   redis.NewUniversalClient(options),
	}, nil
}
// Allow checks the shared Redis bucket for source and returns the delay the
// caller should observe. A nil duration with a nil error means the request
// was not allowed at all.
func (r *redisLimiter) Allow(ctx context.Context, source string) (*time.Duration, error) {
	allowed, delay, err := r.evaluateScript(ctx, source)
	switch {
	case err != nil:
		return nil, fmt.Errorf("evaluating script: %w", err)
	case !allowed:
		return nil, nil
	default:
		return delay, nil
	}
}
// evaluateScript runs the token-bucket Lua script for the given source key
// and decodes its {ok, wait_duration, tokens} string reply.
func (r *redisLimiter) evaluateScript(ctx context.Context, key string) (bool, *time.Duration, error) {
	// Infinite rate: always allow, no delay to apply.
	if r.rate == rate.Inf {
		return true, nil, nil
	}

	// The script works in microseconds, so the rate is converted to
	// tokens/µs and the timestamps/delays are passed in µs.
	scriptArgs := []interface{}{
		float64(r.rate / 1000000),
		r.burst,
		r.ttl,
		time.Now().UnixMicro(),
		r.maxDelay.Microseconds(),
	}

	raw, err := AllowTokenBucketScript.Run(ctx, r.client, []string{redisPrefix + key}, scriptArgs...).Result()
	if err != nil {
		return false, nil, fmt.Errorf("running script: %w", err)
	}

	reply := raw.([]interface{})

	allowed, err := strconv.ParseBool(reply[0].(string))
	if err != nil {
		return false, nil, fmt.Errorf("parsing ok value from redis rate lua script: %w", err)
	}

	waitMicros, err := strconv.ParseFloat(reply[1].(string), 64)
	if err != nil {
		return false, nil, fmt.Errorf("parsing delay value from redis rate lua script: %w", err)
	}

	delay := time.Duration(waitMicros * float64(time.Microsecond))
	return allowed, &delay, nil
}