Send 'Retry-After' to comply with RFC6585.
This commit is contained in:
parent
027093a5a5
commit
8d75aba7eb
29 changed files with 435 additions and 172 deletions
9
vendor/github.com/vulcand/oxy/ratelimit/bucket.go
generated
vendored
9
vendor/github.com/vulcand/oxy/ratelimit/bucket.go
generated
vendored
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/mailgun/timetools"
|
||||
)
|
||||
|
||||
// UndefinedDelay default delay
|
||||
const UndefinedDelay = -1
|
||||
|
||||
// rate defines token bucket parameters.
|
||||
|
@ -20,7 +21,7 @@ func (r *rate) String() string {
|
|||
return fmt.Sprintf("rate(%v/%v, burst=%v)", r.average, r.period, r.burst)
|
||||
}
|
||||
|
||||
// Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket)
|
||||
// tokenBucket Implements token bucket algorithm (http://en.wikipedia.org/wiki/Token_bucket)
|
||||
type tokenBucket struct {
|
||||
// The time period controlled by the bucket in nanoseconds.
|
||||
period time.Duration
|
||||
|
@ -63,7 +64,7 @@ func (tb *tokenBucket) consume(tokens int64) (time.Duration, error) {
|
|||
tb.updateAvailableTokens()
|
||||
tb.lastConsumed = 0
|
||||
if tokens > tb.burst {
|
||||
return UndefinedDelay, fmt.Errorf("Requested tokens larger than max tokens")
|
||||
return UndefinedDelay, fmt.Errorf("requested tokens larger than max tokens")
|
||||
}
|
||||
if tb.availableTokens < tokens {
|
||||
return tb.timeTillAvailable(tokens), nil
|
||||
|
@ -83,11 +84,11 @@ func (tb *tokenBucket) rollback() {
|
|||
tb.lastConsumed = 0
|
||||
}
|
||||
|
||||
// Update modifies `average` and `burst` fields of the token bucket according
|
||||
// update modifies `average` and `burst` fields of the token bucket according
|
||||
// to the provided `Rate`
|
||||
func (tb *tokenBucket) update(rate *rate) error {
|
||||
if rate.period != tb.period {
|
||||
return fmt.Errorf("Period mismatch: %v != %v", tb.period, rate.period)
|
||||
return fmt.Errorf("period mismatch: %v != %v", tb.period, rate.period)
|
||||
}
|
||||
tb.timePerToken = time.Duration(int64(tb.period) / rate.average)
|
||||
tb.burst = rate.burst
|
||||
|
|
8
vendor/github.com/vulcand/oxy/ratelimit/bucketset.go
generated
vendored
8
vendor/github.com/vulcand/oxy/ratelimit/bucketset.go
generated
vendored
|
@ -2,11 +2,11 @@ package ratelimit
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/mailgun/timetools"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// TokenBucketSet represents a set of TokenBucket covering different time periods.
|
||||
|
@ -16,7 +16,7 @@ type TokenBucketSet struct {
|
|||
clock timetools.TimeProvider
|
||||
}
|
||||
|
||||
// newTokenBucketSet creates a `TokenBucketSet` from the specified `rates`.
|
||||
// NewTokenBucketSet creates a `TokenBucketSet` from the specified `rates`.
|
||||
func NewTokenBucketSet(rates *RateSet, clock timetools.TimeProvider) *TokenBucketSet {
|
||||
tbs := new(TokenBucketSet)
|
||||
tbs.clock = clock
|
||||
|
@ -54,9 +54,10 @@ func (tbs *TokenBucketSet) Update(rates *RateSet) {
|
|||
}
|
||||
}
|
||||
|
||||
// Consume consume tokens
|
||||
func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) {
|
||||
var maxDelay time.Duration = UndefinedDelay
|
||||
var firstErr error = nil
|
||||
var firstErr error
|
||||
for _, tokenBucket := range tbs.buckets {
|
||||
// We keep calling `Consume` even after a error is returned for one of
|
||||
// buckets because that allows us to simplify the rollback procedure,
|
||||
|
@ -80,6 +81,7 @@ func (tbs *TokenBucketSet) Consume(tokens int64) (time.Duration, error) {
|
|||
return maxDelay, firstErr
|
||||
}
|
||||
|
||||
// GetMaxPeriod returns the max period
|
||||
func (tbs *TokenBucketSet) GetMaxPeriod() time.Duration {
|
||||
return tbs.maxPeriod
|
||||
}
|
||||
|
|
49
vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go
generated
vendored
49
vendor/github.com/vulcand/oxy/ratelimit/tokenlimiter.go
generated
vendored
|
@ -1,4 +1,4 @@
|
|||
// Tokenbucket based request rate limiter
|
||||
// Package ratelimit Tokenbucket based request rate limiter
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// DefaultCapacity default capacity
|
||||
const DefaultCapacity = 65536
|
||||
|
||||
// RateSet maintains a set of rates. It can contain only one rate per period at a time.
|
||||
|
@ -31,15 +32,15 @@ func NewRateSet() *RateSet {
|
|||
// set then the new rate overrides the old one.
|
||||
func (rs *RateSet) Add(period time.Duration, average int64, burst int64) error {
|
||||
if period <= 0 {
|
||||
return fmt.Errorf("Invalid period: %v", period)
|
||||
return fmt.Errorf("invalid period: %v", period)
|
||||
}
|
||||
if average <= 0 {
|
||||
return fmt.Errorf("Invalid average: %v", average)
|
||||
return fmt.Errorf("invalid average: %v", average)
|
||||
}
|
||||
if burst <= 0 {
|
||||
return fmt.Errorf("Invalid burst: %v", burst)
|
||||
return fmt.Errorf("invalid burst: %v", burst)
|
||||
}
|
||||
rs.m[period] = &rate{period, average, burst}
|
||||
rs.m[period] = &rate{period: period, average: average, burst: burst}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -47,12 +48,15 @@ func (rs *RateSet) String() string {
|
|||
return fmt.Sprint(rs.m)
|
||||
}
|
||||
|
||||
// RateExtractor rate extractor
|
||||
type RateExtractor interface {
|
||||
Extract(r *http.Request) (*RateSet, error)
|
||||
}
|
||||
|
||||
// RateExtractorFunc rate extractor function type
|
||||
type RateExtractorFunc func(r *http.Request) (*RateSet, error)
|
||||
|
||||
// Extract extract from request
|
||||
func (e RateExtractorFunc) Extract(r *http.Request) (*RateSet, error) {
|
||||
return e(r)
|
||||
}
|
||||
|
@ -68,20 +72,24 @@ type TokenLimiter struct {
|
|||
errHandler utils.ErrorHandler
|
||||
capacity int
|
||||
next http.Handler
|
||||
|
||||
log *log.Logger
|
||||
}
|
||||
|
||||
// New constructs a `TokenLimiter` middleware instance.
|
||||
func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet, opts ...TokenLimiterOption) (*TokenLimiter, error) {
|
||||
if defaultRates == nil || len(defaultRates.m) == 0 {
|
||||
return nil, fmt.Errorf("Provide default rates")
|
||||
return nil, fmt.Errorf("provide default rates")
|
||||
}
|
||||
if extract == nil {
|
||||
return nil, fmt.Errorf("Provide extract function")
|
||||
return nil, fmt.Errorf("provide extract function")
|
||||
}
|
||||
tl := &TokenLimiter{
|
||||
next: next,
|
||||
defaultRates: defaultRates,
|
||||
extract: extract,
|
||||
|
||||
log: log.StandardLogger(),
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
|
@ -98,6 +106,17 @@ func New(next http.Handler, extract utils.SourceExtractor, defaultRates *RateSet
|
|||
return tl, nil
|
||||
}
|
||||
|
||||
// Logger defines the logger the token limiter will use.
|
||||
//
|
||||
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
|
||||
func Logger(l *log.Logger) TokenLimiterOption {
|
||||
return func(tl *TokenLimiter) error {
|
||||
tl.log = l
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap sets the next handler to be called by token limiter handler.
|
||||
func (tl *TokenLimiter) Wrap(next http.Handler) {
|
||||
tl.next = next
|
||||
}
|
||||
|
@ -110,7 +129,7 @@ func (tl *TokenLimiter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
|
||||
if err := tl.consumeRates(req, source, amount); err != nil {
|
||||
log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err)
|
||||
tl.log.Warnf("limiting request %v %v, limit: %v", req.Method, req.URL, err)
|
||||
tl.errHandler.ServeHTTP(w, req, err)
|
||||
return
|
||||
}
|
||||
|
@ -155,7 +174,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet {
|
|||
|
||||
rates, err := tl.extractRates.Extract(req)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to retrieve rates: %v", err)
|
||||
tl.log.Errorf("Failed to retrieve rates: %v", err)
|
||||
return tl.defaultRates
|
||||
}
|
||||
|
||||
|
@ -167,6 +186,7 @@ func (tl *TokenLimiter) resolveRates(req *http.Request) *RateSet {
|
|||
return rates
|
||||
}
|
||||
|
||||
// MaxRateError max rate error
|
||||
type MaxRateError struct {
|
||||
delay time.Duration
|
||||
}
|
||||
|
@ -175,19 +195,21 @@ func (m *MaxRateError) Error() string {
|
|||
return fmt.Sprintf("max rate reached: retry-in %v", m.delay)
|
||||
}
|
||||
|
||||
type RateErrHandler struct {
|
||||
}
|
||||
// RateErrHandler error handler
|
||||
type RateErrHandler struct{}
|
||||
|
||||
func (e *RateErrHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
|
||||
if rerr, ok := err.(*MaxRateError); ok {
|
||||
w.Header().Set("Retry-After", fmt.Sprintf("%.0f", rerr.delay.Seconds()))
|
||||
w.Header().Set("X-Retry-In", rerr.delay.String())
|
||||
w.WriteHeader(429)
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
w.Write([]byte(err.Error()))
|
||||
return
|
||||
}
|
||||
utils.DefaultHandler.ServeHTTP(w, req, err)
|
||||
}
|
||||
|
||||
// TokenLimiterOption token limiter option type
|
||||
type TokenLimiterOption func(l *TokenLimiter) error
|
||||
|
||||
// ErrorHandler sets error handler of the server
|
||||
|
@ -198,6 +220,7 @@ func ErrorHandler(h utils.ErrorHandler) TokenLimiterOption {
|
|||
}
|
||||
}
|
||||
|
||||
// ExtractRates sets the rate extractor
|
||||
func ExtractRates(e RateExtractor) TokenLimiterOption {
|
||||
return func(cl *TokenLimiter) error {
|
||||
cl.extractRates = e
|
||||
|
@ -205,6 +228,7 @@ func ExtractRates(e RateExtractor) TokenLimiterOption {
|
|||
}
|
||||
}
|
||||
|
||||
// Clock sets the clock
|
||||
func Clock(clock timetools.TimeProvider) TokenLimiterOption {
|
||||
return func(cl *TokenLimiter) error {
|
||||
cl.clock = clock
|
||||
|
@ -212,6 +236,7 @@ func Clock(clock timetools.TimeProvider) TokenLimiterOption {
|
|||
}
|
||||
}
|
||||
|
||||
// Capacity sets the capacity
|
||||
func Capacity(cap int) TokenLimiterOption {
|
||||
return func(cl *TokenLimiter) error {
|
||||
if cap <= 0 {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue