package ratelimiter import ( "context" "github.com/redis/go-redis/v9" ) type Rediser interface { Eval(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd ScriptExists(ctx context.Context, hashes ...string) *redis.BoolSliceCmd ScriptLoad(ctx context.Context, script string) *redis.StringCmd Del(ctx context.Context, keys ...string) *redis.IntCmd EvalRO(ctx context.Context, script string, keys []string, args ...interface{}) *redis.Cmd EvalShaRO(ctx context.Context, sha1 string, keys []string, args ...interface{}) *redis.Cmd } //nolint:dupword var AllowTokenBucketRaw = ` local key = KEYS[1] local limit, burst, ttl, t, max_delay = tonumber(ARGV[1]), tonumber(ARGV[2]), tonumber(ARGV[3]), tonumber(ARGV[4]), tonumber(ARGV[5]) local bucket = { limit = limit, burst = burst, tokens = 0, last = 0 } local rl_source = redis.call('hgetall', key) if table.maxn(rl_source) == 4 then -- Get bucket state from redis bucket.last = tonumber(rl_source[2]) bucket.tokens = tonumber(rl_source[4]) end local last = bucket.last if t < last then last = t end local elapsed = t - last local delta = bucket.limit * elapsed local tokens = bucket.tokens + delta tokens = math.min(tokens, bucket.burst) tokens = tokens - 1 local wait_duration = 0 if tokens < 0 then wait_duration = (tokens * -1) / bucket.limit if wait_duration > max_delay then tokens = tokens + 1 tokens = math.min(tokens, burst) end end redis.call('hset', key, 'last', t, 'tokens', tokens) redis.call('expire', key, ttl) return {tostring(true), tostring(wait_duration),tostring(tokens)}` var AllowTokenBucketScript = redis.NewScript(AllowTokenBucketRaw)