Rate limiting for frontends
This commit is contained in:
parent
9fba37b409
commit
d54417acfe
16 changed files with 1410 additions and 3 deletions
125
vendor/github.com/vulcand/oxy/ratelimit/bucket.go
generated
vendored
Normal file
125
vendor/github.com/vulcand/oxy/ratelimit/bucket.go
generated
vendored
Normal file
|
@ -0,0 +1,125 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/mailgun/timetools"
|
||||
)
|
||||
|
||||
const UndefinedDelay = -1
|
||||
|
||||
// rate defines token bucket parameters.
|
||||
type rate struct {
|
||||
period time.Duration
|
||||
average int64
|
||||
burst int64
|
||||
}
|
||||
|
||||
func (r *rate) String() string {
|
||||
return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst)
|
||||
}
|
||||
|
||||
// Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket)
|
||||
type tokenBucket struct {
|
||||
// The time period controlled by the bucket in nanoseconds.
|
||||
period time.Duration
|
||||
// The number of nanoseconds that takes to add one more token to the total
|
||||
// number of available tokens. It effectively caches the value that could
|
||||
// have been otherwise deduced from refillRate.
|
||||
timePerToken time.Duration
|
||||
// The maximum number of tokens that can be accumulate in the bucket.
|
||||
burst int64
|
||||
// The number of tokens available for consumption at the moment. It can
|
||||
// nether be larger then capacity.
|
||||
availableTokens int64
|
||||
// Interface that gives current time (so tests can override)
|
||||
clock timetools.TimeProvider
|
||||
// Tells when tokensAvailable was updated the last time.
|
||||
lastRefresh time.Time
|
||||
// The number of tokens consumed the last time.
|
||||
lastConsumed int64
|
||||
}
|
||||
|
||||
// newTokenBucket crates a `tokenBucket` instance for the specified `Rate`.
|
||||
func newTokenBucket(rate *rate, clock timetools.TimeProvider) *tokenBucket {
|
||||
return &tokenBucket{
|
||||
period: rate.period,
|
||||
timePerToken: time.Duration(int64(rate.period) / rate.average),
|
||||
burst: rate.burst,
|
||||
clock: clock,
|
||||
lastRefresh: clock.UtcNow(),
|
||||
availableTokens: rate.burst,
|
||||
}
|
||||
}
|
||||
|
||||
// consume makes an attempt to consume the specified number of tokens from the
|
||||
// bucket. If there are enough tokens available then `0, nil` is returned; if
|
||||
// tokens to consume is larger than the burst size, then an error is returned
|
||||
// and the delay is not defined; otherwise returned a none zero delay that tells
|
||||
// how much time the caller needs to wait until the desired number of tokens
|
||||
// will become available for consumption.
|
||||
func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) {
|
||||
tb.updateAvailableTokens()
|
||||
tb.lastConsumed = 0
|
||||
if tokens > tb.burst {
|
||||
return UndefinedDelay, fmt.Errorf("Requested tokens larger than max tokens")
|
||||
}
|
||||
if tb.availableTokens < tokens {
|
||||
return tb.timeTillAvailable(tokens), nil
|
||||
}
|
||||
tb.availableTokens -= tokens
|
||||
tb.lastConsumed = tokens
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// rollback reverts effect of the most recent consumption. If the most recent
|
||||
// `consume` resulted in an error or a burst overflow, and therefore did not
|
||||
// modify the number of available tokens, then `rollback` won't do that either.
|
||||
// It is safe to call this method multiple times, for the second and all
|
||||
// following calls have no effect.
|
||||
func (tb *tokenBucket) rollback() {
|
||||
tb.availableTokens += tb.lastConsumed
|
||||
tb.lastConsumed = 0
|
||||
}
|
||||
|
||||
// Update modifies `average` and `burst` fields of the token bucket according
|
||||
// to the provided `Rate`
|
||||
func (tb *tokenBucket) update(rate *rate) error {
|
||||
if rate.period != tb.period {
|
||||
return fmt.Errorf("Period mismatch: %v != %v", tb.period, rate.period)
|
||||
}
|
||||
tb.timePerToken = time.Duration(int64(tb.period) / rate.average)
|
||||
tb.burst = rate.burst
|
||||
if tb.availableTokens > rate.burst {
|
||||
tb.availableTokens = rate.burst
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// timeTillAvailable returns the number of nanoseconds that we need to
|
||||
// wait until the specified number of tokens becomes available for consumption.
|
||||
func (tb *tokenBucket) timeTillAvailable(tokens int64) time.Duration {
|
||||
missingTokens := tokens - tb.availableTokens
|
||||
return time.Duration(missingTokens) * tb.timePerToken
|
||||
}
|
||||
|
||||
// updateAvailableTokens updates the number of tokens available for consumption.
|
||||
// It is calculated based on the refill rate, the time passed since last refresh,
|
||||
// and is limited by the bucket capacity.
|
||||
func (tb *tokenBucket) updateAvailableTokens() {
|
||||
now := tb.clock.UtcNow()
|
||||
timePassed := now.Sub(tb.lastRefresh)
|
||||
|
||||
tokens := tb.availableTokens + int64(timePassed/tb.timePerToken)
|
||||
// If we haven't added any tokens that means that not enough time has passed,
|
||||
// in this case do not adjust last refill checkpoint, otherwise it will be
|
||||
// always moving in time in case of frequent requests that exceed the rate
|
||||
if tokens != tb.availableTokens {
|
||||
tb.lastRefresh = now
|
||||
tb.availableTokens = tokens
|
||||
}
|
||||
if tb.availableTokens > tb.burst {
|
||||
tb.availableTokens = tb.burst
|
||||
}
|
||||
}
|
108
vendor/github.com/vulcand/oxy/ratelimit/bucketset.go
generated
vendored
Normal file
108
vendor/github.com/vulcand/oxy/ratelimit/bucketset.go
generated
vendored
Normal file
|
@ -0,0 +1,108 @@
|
|||
package ratelimit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mailgun/timetools"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// TokenBucketSet represents a set of TokenBucket covering different time periods.
|
||||
type TokenBucketSet struct {
|
||||
buckets map[time.Duration]*tokenBucket
|
||||
maxPeriod time.Duration
|
||||
clock timetools.TimeProvider
|
||||
}
|
||||
|
||||
// newTokenBucketSet creates a `TokenBucketSet` from the specified `rates`.
|
||||
func NewTokenBucketSet(rates *RateSet, clock timetools.TimeProvider) *TokenBucketSet {
|
||||
tbs := new(TokenBucketSet)
|
||||
tbs.clock = clock
|
||||
// In the majority of cases we will have only one bucket.
|
||||
tbs.buckets = make(map[time.Duration]*tokenBucket, len(rates.m))
|
||||
for _, rate := range rates.m {
|
||||
newBucket := newTokenBucket(rate, clock)
|
||||
tbs.buckets[rate.period] = newBucket
|
||||
tbs.maxPeriod = maxDuration(tbs.maxPeriod, rate.period)
|
||||
}
|
||||
return tbs
|
||||
}
|
||||
|
||||
// Update brings the buckets in the set in accordance with the provided `rates`.
|
||||
func (tbs *TokenBucketSet) Update(rates *RateSet) {
|
||||
// Update existing buckets and delete those that have no corresponding spec.
|
||||
for _, bucket := range tbs.buckets {
|
||||
if rate, ok := rates.m[bucket.period]; ok {
|
||||
bucket.update(rate)
|
||||
} else {
|
||||
delete(tbs.buckets, bucket.period)
|
||||
}
|
||||
}
|
||||
// Add missing buckets.
|
||||
for _, rate := range rates.m {
|
||||
if _, ok := tbs.buckets[rate.period]; !ok {
|
||||
newBucket := newTokenBucket(rate, tbs.clock)
|
||||
tbs.buckets[rate.period] = newBucket
|
||||
}
|
||||
}
|
||||
// Identify the maximum period in the set
|
||||
tbs.maxPeriod = 0
|
||||
for _, bucket := range tbs.buckets {
|
||||
tbs.maxPeriod = maxDuration(tbs.maxPeriod, bucket.period)
|
||||
}
|
||||
}
|
||||
|
||||
func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) {
|
||||
var maxDelay time.Duration = UndefinedDelay
|
||||
var firstErr error = nil
|
||||
for _, tokenBucket := range tbs.buckets {
|
||||
// We keep calling `Consume` even after a error is returned for one of
|
||||
// buckets because that allows us to simplify the rollback procedure,
|
||||
// that is to just call `Rollback` for all buckets.
|
||||
delay, err := tokenBucket.consume(tokens)
|
||||
if firstErr == nil {
|
||||
if err != nil {
|
||||
firstErr = err
|
||||
} else {
|
||||
maxDelay = maxDuration(maxDelay, delay)
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we could not make ALL buckets consume tokens for whatever reason,
|
||||
// then rollback consumption for all of them.
|
||||
if firstErr != nil || maxDelay > 0 {
|
||||
for _, tokenBucket := range tbs.buckets {
|
||||
tokenBucket.rollback()
|
||||
}
|
||||
}
|
||||
return maxDelay, firstErr
|
||||
}
|
||||
|
||||
func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration {
|
||||
return tbs.maxPeriod
|
||||
}
|
||||
|
||||
// debugState returns string that reflects the current state of all buckets in
|
||||
// this set. It is intended to be used for debugging and testing only.
|
||||
func (tbs *TokenBucketSet) debugState() string {
|
||||
periods := sort.IntSlice(make([]int, 0, len(tbs.buckets)))
|
||||
for period := range tbs.buckets {
|
||||
periods = append(periods, int(period))
|
||||
}
|
||||
sort.Sort(periods)
|
||||
bucketRepr := make([]string, 0, len(tbs.buckets))
|
||||
for _, period := range periods {
|
||||
bucket := tbs.buckets[time.Duration(period)]
|
||||
bucketRepr = append(bucketRepr, fmt.Sprintf("{%v: %v}", bucket.period, bucket.availableTokens))
|
||||
}
|
||||
return strings.Join(bucketRepr, ", ")
|
||||
}
|
||||
|
||||
func maxDuration(x time.Duration, y time.Duration) time.Duration {
|
||||
if x > y {
|
||||
return x
|
||||
}
|
||||
return y
|
||||
}
|
248
vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go
generated
vendored
Normal file
248
vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go
generated
vendored
Normal file
|
@ -0,0 +1,248 @@
|
|||
// Tokenbucket based request rate limiter
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mailgun/timetools"
|
||||
"github.com/mailgun/ttlmap"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
const DefaultCapacity = 65536
|
||||
|
||||
// RateSet maintains a set of rates. It can contain only one rate per period at a time.
|
||||
type RateSet struct {
|
||||
m map[time.Duration]*rate
|
||||
}
|
||||
|
||||
// NewRateSet crates an empty `RateSet` instance.
|
||||
func NewRateSet() *RateSet {
|
||||
rs := new(RateSet)
|
||||
rs.m = make(map[time.Duration]*rate)
|
||||
return rs
|
||||
}
|
||||
|
||||
// Add adds a rate to the set. If there is a rate with the same period in the
|
||||
// set then the new rate overrides the old one.
|
||||
func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error {
|
||||
if period <= 0 {
|
||||
return fmt.Errorf("Invalid period: %v", period)
|
||||
}
|
||||
if average <= 0 {
|
||||
return fmt.Errorf("Invalid average: %v", average)
|
||||
}
|
||||
if burst <= 0 {
|
||||
return fmt.Errorf("Invalid burst: %v", burst)
|
||||
}
|
||||
rs.m[period] = &rate{period, average, burst}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *RateSet) String() string {
|
||||
return fmt.Sprint(rs.m)
|
||||
}
|
||||
|
||||
type RateExtractor interface {
|
||||
Extract(r *http.Request) (*RateSet, error)
|
||||
}
|
||||
|
||||
type RateExtractorFunc func(r *http.Request) (*RateSet, error)
|
||||
|
||||
func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) {
|
||||
return e(r)
|
||||
}
|
||||
|
||||
// TokenLimiter implements rate limiting middleware.
|
||||
type TokenLimiter struct {
|
||||
defaultRates *RateSet
|
||||
extract utils.SourceExtractor
|
||||
extractRates RateExtractor
|
||||
clock timetools.TimeProvider
|
||||
mutex sync.Mutex
|
||||
bucketSets *ttlmap.TtlMap
|
||||
errHandler utils.ErrorHandler
|
||||
log utils.Logger
|
||||
capacity int
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// New constructs a `TokenLimiter` middleware instance.
|
||||
func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) {
|
||||
if defaultRates == nil || len(defaultRates.m) == 0 {
|
||||
return nil, fmt.Errorf("Provide default rates")
|
||||
}
|
||||
if extract == nil {
|
||||
return nil, fmt.Errorf("Provide extract function")
|
||||
}
|
||||
tl := &TokenLimiter{
|
||||
next: next,
|
||||
defaultRates: defaultRates,
|
||||
extract: extract,
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
if err := o(tl); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
setDefaults(tl)
|
||||
bucketSets, err := ttlmap.NewMapWithProvider(tl.capacity, tl.clock)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tl.bucketSets = bucketSets
|
||||
return tl, nil
|
||||
}
|
||||
|
||||
func (tl *TokenLimiter) Wrap(next http.Handler) {
|
||||
tl.next = next
|
||||
}
|
||||
|
||||
func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
source, amount, err := tl.extract.Extract(req)
|
||||
if err != nil {
|
||||
tl.errHandler.ServeHTTP(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := tl.consumeRates(req, source, amount); err != nil {
|
||||
tl.log.Infof("limiting request %v %v, limit: %v", req.Method, req.URL, err)
|
||||
tl.errHandler.ServeHTTP(w, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
tl.next.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
func (tl *TokenLimiter) consumeRates(req *http.Request, source string, amount int64) error {
|
||||
tl.mutex.Lock()
|
||||
defer tl.mutex.Unlock()
|
||||
|
||||
effectiveRates := tl.resolveRates(req)
|
||||
bucketSetI, exists := tl.bucketSets.Get(source)
|
||||
var bucketSet *TokenBucketSet
|
||||
|
||||
if exists {
|
||||
bucketSet = bucketSetI.(*TokenBucketSet)
|
||||
bucketSet.Update(effectiveRates)
|
||||
} else {
|
||||
bucketSet = NewTokenBucketSet(effectiveRates, tl.clock)
|
||||
// We set ttl as 10 times rate period. E.g. if rate is 100 requests/second per client ip
|
||||
// the counters for this ip will expire after 10 seconds of inactivity
|
||||
tl.bucketSets.Set(source, bucketSet, int(bucketSet.maxPeriod/time.Second)*10+1)
|
||||
}
|
||||
delay, err := bucketSet.Consume(amount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if delay > 0 {
|
||||
return &MaxRateError{delay: delay}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// effectiveRates retrieves rates to be applied to the request.
|
||||
func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet {
|
||||
// If configuration mapper is not specified for this instance, then return
|
||||
// the default bucket specs.
|
||||
if tl.extractRates == nil {
|
||||
return tl.defaultRates
|
||||
}
|
||||
|
||||
rates, err := tl.extractRates.Extract(req)
|
||||
if err != nil {
|
||||
tl.log.Errorf("Failed to retrieve rates: %v", err)
|
||||
return tl.defaultRates
|
||||
}
|
||||
|
||||
// If the returned rate set is empty then used the default one.
|
||||
if len(rates.m) == 0 {
|
||||
return tl.defaultRates
|
||||
}
|
||||
|
||||
return rates
|
||||
}
|
||||
|
||||
type MaxRateError struct {
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
func (m *MaxRateError) Error() string {
|
||||
return fmt.Sprintf("max rate reached: retry-in %v", m.delay)
|
||||
}
|
||||
|
||||
type RateErrHandler struct {
|
||||
}
|
||||
|
||||
func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
|
||||
if rerr, ok := err.(*MaxRateError); ok {
|
||||
w.Header().Set("X-Retry-In", rerr.delay.String())
|
||||
w.WriteHeader(429)
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
utils.DefaultHandler.ServeHTTP(w, req, err)
|
||||
}
|
||||
|
||||
type TokenLimiterOption func(l *TokenLimiter) error
|
||||
|
||||
// Logger sets the logger that will be used by this middleware.
|
||||
func Logger(l utils.Logger) TokenLimiterOption {
|
||||
return func(cl *TokenLimiter) error {
|
||||
cl.log = l
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorHandler sets error handler of the server
|
||||
func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption {
|
||||
return func(cl *TokenLimiter) error {
|
||||
cl.errHandler = h
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func ExtractRates(e RateExtractor) TokenLimiterOption {
|
||||
return func(cl *TokenLimiter) error {
|
||||
cl.extractRates = e
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Clock(clock timetools.TimeProvider) TokenLimiterOption {
|
||||
return func(cl *TokenLimiter) error {
|
||||
cl.clock = clock
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Capacity(cap int) TokenLimiterOption {
|
||||
return func(cl *TokenLimiter) error {
|
||||
if cap <= 0 {
|
||||
return fmt.Errorf("bad capacity: %v", cap)
|
||||
}
|
||||
cl.capacity = cap
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var defaultErrHandler = &RateErrHandler{}
|
||||
|
||||
func setDefaults(tl *TokenLimiter) {
|
||||
if tl.log == nil {
|
||||
tl.log = utils.NullLogger
|
||||
}
|
||||
if tl.capacity <= 0 {
|
||||
tl.capacity = DefaultCapacity
|
||||
}
|
||||
if tl.clock == nil {
|
||||
tl.clock = &timetools.RealTime{}
|
||||
}
|
||||
if tl.errHandler == nil {
|
||||
tl.errHandler = defaultErrHandler
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue