traefik/pkg/middlewares/ratelimiter/lua.go
2025-03-10 11:02:05 +01:00

66 lines
1.8 KiB
Go

package ratelimiter
import (
"context"
"github.com/redis/go-redis/v9"
)
type Rediser interface {
Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd
ScriptLoad(ctx context.Context, script string) *redis.StringCmd
Del(ctx context.Context, keys ...string) *redis.IntCmd
EvalRO(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd
EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd
}
//nolint:dupword
var AllowTokenBucketRaw = `
local key = KEYS[1]
local limit, burst, ttl, t, max_delay = tonumber(ARGV[1]), tonumber(ARGV[2]), tonumber(ARGV[3]), tonumber(ARGV[4]),
tonumber(ARGV[5])
local bucket = {
limit = limit,
burst = burst,
tokens = 0,
last = 0
}
local rl_source = redis.call('hgetall', key)
if table.maxn(rl_source) == 4 then
-- Get bucket state from redis
bucket.last = tonumber(rl_source[2])
bucket.tokens = tonumber(rl_source[4])
end
local last = bucket.last
if t < last then
last = t
end
local elapsed = t - last
local delta = bucket.limit * elapsed
local tokens = bucket.tokens + delta
tokens = math.min(tokens, bucket.burst)
tokens = tokens - 1
local wait_duration = 0
if tokens < 0 then
wait_duration = (tokens * -1) / bucket.limit
if wait_duration > max_delay then
tokens = tokens + 1
tokens = math.min(tokens, burst)
end
end
redis.call('hset', key, 'last', t, 'tokens', tokens)
redis.call('expire', key, ttl)
return {tostring(true), tostring(wait_duration),tostring(tokens)}`
var AllowTokenBucketScript = redis.NewScript(AllowTokenBucketRaw)