Move code to pkg
This commit is contained in:
parent
bd4c822670
commit
f1b085fa36
465 changed files with 656 additions and 680 deletions
57
pkg/server/aggregator.go
Normal file
57
pkg/server/aggregator.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/server/internal"
|
||||
"github.com/containous/traefik/pkg/tls"
|
||||
)
|
||||
|
||||
func mergeConfiguration(configurations config.Configurations) config.Configuration {
|
||||
conf := config.Configuration{
|
||||
HTTP: &config.HTTPConfiguration{
|
||||
Routers: make(map[string]*config.Router),
|
||||
Middlewares: make(map[string]*config.Middleware),
|
||||
Services: make(map[string]*config.Service),
|
||||
},
|
||||
TCP: &config.TCPConfiguration{
|
||||
Routers: make(map[string]*config.TCPRouter),
|
||||
Services: make(map[string]*config.TCPService),
|
||||
},
|
||||
TLSOptions: make(map[string]tls.TLS),
|
||||
TLSStores: make(map[string]tls.Store),
|
||||
}
|
||||
|
||||
for provider, configuration := range configurations {
|
||||
if configuration.HTTP != nil {
|
||||
for routerName, router := range configuration.HTTP.Routers {
|
||||
conf.HTTP.Routers[internal.MakeQualifiedName(provider, routerName)] = router
|
||||
}
|
||||
for middlewareName, middleware := range configuration.HTTP.Middlewares {
|
||||
conf.HTTP.Middlewares[internal.MakeQualifiedName(provider, middlewareName)] = middleware
|
||||
}
|
||||
for serviceName, service := range configuration.HTTP.Services {
|
||||
conf.HTTP.Services[internal.MakeQualifiedName(provider, serviceName)] = service
|
||||
}
|
||||
}
|
||||
|
||||
if configuration.TCP != nil {
|
||||
for routerName, router := range configuration.TCP.Routers {
|
||||
conf.TCP.Routers[internal.MakeQualifiedName(provider, routerName)] = router
|
||||
}
|
||||
for serviceName, service := range configuration.TCP.Services {
|
||||
conf.TCP.Services[internal.MakeQualifiedName(provider, serviceName)] = service
|
||||
}
|
||||
}
|
||||
conf.TLS = append(conf.TLS, configuration.TLS...)
|
||||
|
||||
for key, store := range configuration.TLSStores {
|
||||
conf.TLSStores[key] = store
|
||||
}
|
||||
|
||||
for key, config := range configuration.TLSOptions {
|
||||
conf.TLSOptions[key] = config
|
||||
}
|
||||
}
|
||||
|
||||
return conf
|
||||
}
|
110
pkg/server/aggregator_test.go
Normal file
110
pkg/server/aggregator_test.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAggregator(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
given config.Configurations
|
||||
expected *config.HTTPConfiguration
|
||||
}{
|
||||
{
|
||||
desc: "Nil returns an empty configuration",
|
||||
given: nil,
|
||||
expected: &config.HTTPConfiguration{
|
||||
Routers: make(map[string]*config.Router),
|
||||
Middlewares: make(map[string]*config.Middleware),
|
||||
Services: make(map[string]*config.Service),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Returns fully qualified elements from a mono-provider configuration map",
|
||||
given: config.Configurations{
|
||||
"provider-1": &config.Configuration{
|
||||
HTTP: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"service-1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"provider-1.router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"provider-1.middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"provider-1.service-1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Returns fully qualified elements from a multi-provider configuration map",
|
||||
given: config.Configurations{
|
||||
"provider-1": &config.Configuration{
|
||||
HTTP: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"service-1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider-2": &config.Configuration{
|
||||
HTTP: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"service-1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"provider-1.router-1": {},
|
||||
"provider-2.router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"provider-1.middleware-1": {},
|
||||
"provider-2.middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"provider-1.service-1": {},
|
||||
"provider-2.service-1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := mergeConfiguration(test.given)
|
||||
assert.Equal(t, test.expected, actual.HTTP)
|
||||
})
|
||||
}
|
||||
}
|
57
pkg/server/cookie/cookie.go
Normal file
57
pkg/server/cookie/cookie.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package cookie
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
)
|
||||
|
||||
const cookieNameLength = 6
|
||||
|
||||
// GetName of a cookie
|
||||
func GetName(cookieName string, backendName string) string {
|
||||
if len(cookieName) != 0 {
|
||||
return sanitizeName(cookieName)
|
||||
}
|
||||
|
||||
return GenerateName(backendName)
|
||||
}
|
||||
|
||||
// GenerateName Generate a hashed name
|
||||
func GenerateName(backendName string) string {
|
||||
data := []byte("_TRAEFIK_BACKEND_" + backendName)
|
||||
|
||||
hash := sha1.New()
|
||||
_, err := hash.Write(data)
|
||||
if err != nil {
|
||||
// Impossible case
|
||||
log.Errorf("Fail to create cookie name: %v", err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("_%x", hash.Sum(nil))[:cookieNameLength]
|
||||
}
|
||||
|
||||
// sanitizeName According to [RFC 2616](https://www.ietf.org/rfc/rfc2616.txt) section 2.2
|
||||
func sanitizeName(backend string) string {
|
||||
sanitizer := func(r rune) rune {
|
||||
switch r {
|
||||
case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '`', '|', '~':
|
||||
return r
|
||||
}
|
||||
|
||||
switch {
|
||||
case 'a' <= r && r <= 'z':
|
||||
fallthrough
|
||||
case 'A' <= r && r <= 'Z':
|
||||
fallthrough
|
||||
case '0' <= r && r <= '9':
|
||||
return r
|
||||
default:
|
||||
return '_'
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Map(sanitizer, backend)
|
||||
}
|
83
pkg/server/cookie/cookie_test.go
Normal file
83
pkg/server/cookie/cookie_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package cookie
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetName(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
cookieName string
|
||||
backendName string
|
||||
expectedCookieName string
|
||||
}{
|
||||
{
|
||||
desc: "with backend name, without cookie name",
|
||||
cookieName: "",
|
||||
backendName: "/my/BACKEND-v1.0~rc1",
|
||||
expectedCookieName: "_5f7bc",
|
||||
},
|
||||
{
|
||||
desc: "without backend name, with cookie name",
|
||||
cookieName: "/my/BACKEND-v1.0~rc1",
|
||||
backendName: "",
|
||||
expectedCookieName: "_my_BACKEND-v1.0~rc1",
|
||||
},
|
||||
{
|
||||
desc: "with backend name, with cookie name",
|
||||
cookieName: "containous",
|
||||
backendName: "treafik",
|
||||
expectedCookieName: "containous",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cookieName := GetName(test.cookieName, test.backendName)
|
||||
|
||||
assert.Equal(t, test.expectedCookieName, cookieName)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sanitizeName(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
srcName string
|
||||
expectedName string
|
||||
}{
|
||||
{
|
||||
desc: "with /",
|
||||
srcName: "/my/BACKEND-v1.0~rc1",
|
||||
expectedName: "_my_BACKEND-v1.0~rc1",
|
||||
},
|
||||
{
|
||||
desc: "some chars",
|
||||
srcName: "!#$%&'()*+-./:<=>?@[]^_`{|}~",
|
||||
expectedName: "!#$%&'__*+-._________^_`_|_~",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cookieName := sanitizeName(test.srcName)
|
||||
|
||||
assert.Equal(t, test.expectedName, cookieName, "Cookie name")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateName(t *testing.T) {
|
||||
cookieName := GenerateName("containous")
|
||||
|
||||
assert.Len(t, "_8a7bc", 6)
|
||||
assert.Equal(t, "_8a7bc", cookieName)
|
||||
}
|
45
pkg/server/internal/provider.go
Normal file
45
pkg/server/internal/provider.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
providerKey contextKey = iota
|
||||
)
|
||||
|
||||
// AddProviderInContext Adds the provider name in the context
|
||||
func AddProviderInContext(ctx context.Context, elementName string) context.Context {
|
||||
parts := strings.Split(elementName, ".")
|
||||
if len(parts) == 1 {
|
||||
log.FromContext(ctx).Debugf("Could not find a provider for %s.", elementName)
|
||||
return ctx
|
||||
}
|
||||
|
||||
if name, ok := ctx.Value(providerKey).(string); ok && name == parts[0] {
|
||||
return ctx
|
||||
}
|
||||
|
||||
return context.WithValue(ctx, providerKey, parts[0])
|
||||
}
|
||||
|
||||
// GetQualifiedName Gets the fully qualified name.
|
||||
func GetQualifiedName(ctx context.Context, elementName string) string {
|
||||
parts := strings.Split(elementName, ".")
|
||||
if len(parts) == 1 {
|
||||
if providerName, ok := ctx.Value(providerKey).(string); ok {
|
||||
return MakeQualifiedName(providerName, parts[0])
|
||||
}
|
||||
}
|
||||
return elementName
|
||||
}
|
||||
|
||||
// MakeQualifiedName Creates a qualified name for an element
|
||||
func MakeQualifiedName(providerName string, elementName string) string {
|
||||
return providerName + "." + elementName
|
||||
}
|
343
pkg/server/middleware/middlewares.go
Normal file
343
pkg/server/middleware/middlewares.go
Normal file
|
@ -0,0 +1,343 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares/addprefix"
|
||||
"github.com/containous/traefik/pkg/middlewares/auth"
|
||||
"github.com/containous/traefik/pkg/middlewares/buffering"
|
||||
"github.com/containous/traefik/pkg/middlewares/chain"
|
||||
"github.com/containous/traefik/pkg/middlewares/circuitbreaker"
|
||||
"github.com/containous/traefik/pkg/middlewares/compress"
|
||||
"github.com/containous/traefik/pkg/middlewares/customerrors"
|
||||
"github.com/containous/traefik/pkg/middlewares/headers"
|
||||
"github.com/containous/traefik/pkg/middlewares/ipwhitelist"
|
||||
"github.com/containous/traefik/pkg/middlewares/maxconnection"
|
||||
"github.com/containous/traefik/pkg/middlewares/passtlsclientcert"
|
||||
"github.com/containous/traefik/pkg/middlewares/ratelimiter"
|
||||
"github.com/containous/traefik/pkg/middlewares/redirect"
|
||||
"github.com/containous/traefik/pkg/middlewares/replacepath"
|
||||
"github.com/containous/traefik/pkg/middlewares/replacepathregex"
|
||||
"github.com/containous/traefik/pkg/middlewares/retry"
|
||||
"github.com/containous/traefik/pkg/middlewares/stripprefix"
|
||||
"github.com/containous/traefik/pkg/middlewares/stripprefixregex"
|
||||
"github.com/containous/traefik/pkg/middlewares/tracing"
|
||||
"github.com/containous/traefik/pkg/server/internal"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type middlewareStackType int
|
||||
|
||||
const (
|
||||
middlewareStackKey middlewareStackType = iota
|
||||
)
|
||||
|
||||
// Builder the middleware builder
|
||||
type Builder struct {
|
||||
configs map[string]*config.Middleware
|
||||
serviceBuilder serviceBuilder
|
||||
}
|
||||
|
||||
type serviceBuilder interface {
|
||||
BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
||||
}
|
||||
|
||||
// NewBuilder creates a new Builder
|
||||
func NewBuilder(configs map[string]*config.Middleware, serviceBuilder serviceBuilder) *Builder {
|
||||
return &Builder{configs: configs, serviceBuilder: serviceBuilder}
|
||||
}
|
||||
|
||||
// BuildChain creates a middleware chain
|
||||
func (b *Builder) BuildChain(ctx context.Context, middlewares []string) *alice.Chain {
|
||||
chain := alice.New()
|
||||
for _, name := range middlewares {
|
||||
middlewareName := internal.GetQualifiedName(ctx, name)
|
||||
|
||||
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
||||
constructorContext := internal.AddProviderInContext(ctx, middlewareName)
|
||||
if _, ok := b.configs[middlewareName]; !ok {
|
||||
return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
|
||||
}
|
||||
|
||||
var err error
|
||||
if constructorContext, err = checkRecursivity(constructorContext, middlewareName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
constructor, err := b.buildConstructor(constructorContext, middlewareName, *b.configs[middlewareName])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error during instanciation of %s: %v", middlewareName, err)
|
||||
}
|
||||
return constructor(next)
|
||||
})
|
||||
}
|
||||
return &chain
|
||||
}
|
||||
|
||||
func checkRecursivity(ctx context.Context, middlewareName string) (context.Context, error) {
|
||||
currentStack, ok := ctx.Value(middlewareStackKey).([]string)
|
||||
if !ok {
|
||||
currentStack = []string{}
|
||||
}
|
||||
if inSlice(middlewareName, currentStack) {
|
||||
return ctx, fmt.Errorf("could not instantiate middleware %s: recursion detected in %s", middlewareName, strings.Join(append(currentStack, middlewareName), "->"))
|
||||
}
|
||||
return context.WithValue(ctx, middlewareStackKey, append(currentStack, middlewareName)), nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildConstructor(ctx context.Context, middlewareName string, config config.Middleware) (alice.Constructor, error) {
|
||||
var middleware alice.Constructor
|
||||
badConf := errors.New("cannot create middleware %q: multi-types middleware not supported, consider declaring two different pieces of middleware instead")
|
||||
|
||||
// AddPrefix
|
||||
if config.AddPrefix != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return addprefix.New(ctx, next, *config.AddPrefix, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// BasicAuth
|
||||
if config.BasicAuth != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return auth.NewBasic(ctx, next, *config.BasicAuth, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Buffering
|
||||
if config.Buffering != nil && config.MaxConn.Amount != 0 {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return buffering.New(ctx, next, *config.Buffering, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Chain
|
||||
if config.Chain != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return chain.New(ctx, next, *config.Chain, b, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker
|
||||
if config.CircuitBreaker != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return circuitbreaker.New(ctx, next, *config.CircuitBreaker, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Compress
|
||||
if config.Compress != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return compress.New(ctx, next, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// CustomErrors
|
||||
if config.Errors != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return customerrors.New(ctx, next, *config.Errors, b.serviceBuilder, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// DigestAuth
|
||||
if config.DigestAuth != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return auth.NewDigest(ctx, next, *config.DigestAuth, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardAuth
|
||||
if config.ForwardAuth != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return auth.NewForward(ctx, next, *config.ForwardAuth, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Headers
|
||||
if config.Headers != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return headers.New(ctx, next, *config.Headers, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// IPWhiteList
|
||||
if config.IPWhiteList != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return ipwhitelist.New(ctx, next, *config.IPWhiteList, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// MaxConn
|
||||
if config.MaxConn != nil && config.MaxConn.Amount != 0 {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return maxconnection.New(ctx, next, *config.MaxConn, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// PassTLSClientCert
|
||||
if config.PassTLSClientCert != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return passtlsclientcert.New(ctx, next, *config.PassTLSClientCert, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit
|
||||
if config.RateLimit != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return ratelimiter.New(ctx, next, *config.RateLimit, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// RedirectRegex
|
||||
if config.RedirectRegex != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return redirect.NewRedirectRegex(ctx, next, *config.RedirectRegex, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// RedirectScheme
|
||||
if config.RedirectScheme != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return redirect.NewRedirectScheme(ctx, next, *config.RedirectScheme, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// ReplacePath
|
||||
if config.ReplacePath != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return replacepath.New(ctx, next, *config.ReplacePath, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// ReplacePathRegex
|
||||
if config.ReplacePathRegex != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return replacepathregex.New(ctx, next, *config.ReplacePathRegex, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Retry
|
||||
if config.Retry != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
// FIXME missing metrics / accessLog
|
||||
return retry.New(ctx, next, *config.Retry, retry.Listeners{}, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// StripPrefix
|
||||
if config.StripPrefix != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return stripprefix.New(ctx, next, *config.StripPrefix, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// StripPrefixRegex
|
||||
if config.StripPrefixRegex != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return stripprefixregex.New(ctx, next, *config.StripPrefixRegex, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
if middleware == nil {
|
||||
return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
|
||||
}
|
||||
|
||||
return tracing.Wrap(ctx, middleware), nil
|
||||
}
|
||||
|
||||
func inSlice(element string, stack []string) bool {
|
||||
for _, value := range stack {
|
||||
if value == element {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
391
pkg/server/middleware/middlewares_test.go
Normal file
391
pkg/server/middleware/middlewares_test.go
Normal file
|
@ -0,0 +1,391 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/server/internal"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuilder_buildConstructorCircuitBreaker(t *testing.T) {
|
||||
testConfig := map[string]*config.Middleware{
|
||||
"empty": {
|
||||
CircuitBreaker: &config.CircuitBreaker{
|
||||
Expression: "",
|
||||
},
|
||||
},
|
||||
"foo": {
|
||||
CircuitBreaker: &config.CircuitBreaker{
|
||||
Expression: "NetworkErrorRatio() > 0.5",
|
||||
},
|
||||
},
|
||||
}
|
||||
middlewaresBuilder := NewBuilder(testConfig, nil)
|
||||
|
||||
emptyHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
middlewareID string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
desc: "Should fail at creating a circuit breaker with an empty expression",
|
||||
expectedError: true,
|
||||
middlewareID: "empty",
|
||||
}, {
|
||||
desc: "Should create a circuit breaker with a valid expression",
|
||||
expectedError: false,
|
||||
middlewareID: "foo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
constructor, err := middlewaresBuilder.buildConstructor(context.Background(), test.middlewareID, *testConfig[test.middlewareID])
|
||||
require.NoError(t, err)
|
||||
|
||||
middleware, err2 := constructor(emptyHandler)
|
||||
|
||||
if test.expectedError {
|
||||
require.Error(t, err2)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, middleware)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuilder_BuildChainNilConfig(t *testing.T) {
|
||||
testConfig := map[string]*config.Middleware{
|
||||
"empty": {},
|
||||
}
|
||||
middlewaresBuilder := NewBuilder(testConfig, nil)
|
||||
|
||||
chain := middlewaresBuilder.BuildChain(context.Background(), []string{"empty"})
|
||||
_, err := chain.Then(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBuilder_BuildChainNonExistentChain(t *testing.T) {
|
||||
testConfig := map[string]*config.Middleware{
|
||||
"foobar": {},
|
||||
}
|
||||
middlewaresBuilder := NewBuilder(testConfig, nil)
|
||||
|
||||
chain := middlewaresBuilder.BuildChain(context.Background(), []string{"empty"})
|
||||
_, err := chain.Then(nil)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBuilder_buildConstructorAddPrefix(t *testing.T) {
|
||||
testConfig := map[string]*config.Middleware{
|
||||
"empty": {
|
||||
AddPrefix: &config.AddPrefix{
|
||||
Prefix: "",
|
||||
},
|
||||
},
|
||||
"foo": {
|
||||
AddPrefix: &config.AddPrefix{
|
||||
Prefix: "foo/",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
middlewaresBuilder := NewBuilder(testConfig, nil)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
middlewareID string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
desc: "Should not create an empty AddPrefix middleware when given an empty prefix",
|
||||
middlewareID: "empty",
|
||||
expectedError: true,
|
||||
}, {
|
||||
desc: "Should create an AddPrefix middleware when given a valid configuration",
|
||||
middlewareID: "foo",
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
constructor, err := middlewaresBuilder.buildConstructor(context.Background(), test.middlewareID, *testConfig[test.middlewareID])
|
||||
require.NoError(t, err)
|
||||
|
||||
middleware, err2 := constructor(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}))
|
||||
|
||||
if test.expectedError {
|
||||
require.Error(t, err2)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, middleware)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuild_BuildChainWithContext(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
buildChain []string
|
||||
configuration map[string]*config.Middleware
|
||||
expected map[string]string
|
||||
contextProvider string
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
desc: "Simple middleware",
|
||||
buildChain: []string{"middleware-1"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"middleware-1": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"middleware-1": "value-middleware-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{"middleware-1": "value-middleware-1"},
|
||||
},
|
||||
{
|
||||
desc: "Middleware that references a chain",
|
||||
buildChain: []string{"middleware-chain-1"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"middleware-1": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"middleware-1": "value-middleware-1"},
|
||||
},
|
||||
},
|
||||
"middleware-chain-1": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"middleware-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{"middleware-1": "value-middleware-1"},
|
||||
},
|
||||
{
|
||||
desc: "Should prefix the middlewareName with the provider in the context",
|
||||
buildChain: []string{"middleware-1"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"provider-1.middleware-1": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"provider-1.middleware-1": "value-middleware-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{"provider-1.middleware-1": "value-middleware-1"},
|
||||
contextProvider: "provider-1",
|
||||
},
|
||||
{
|
||||
desc: "Should not prefix a qualified middlewareName with the provider in the context",
|
||||
buildChain: []string{"provider-1.middleware-1"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"provider-1.middleware-1": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"provider-1.middleware-1": "value-middleware-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{"provider-1.middleware-1": "value-middleware-1"},
|
||||
contextProvider: "provider-1",
|
||||
},
|
||||
{
|
||||
desc: "Should be context aware if a chain references another middleware",
|
||||
buildChain: []string{"provider-1.middleware-chain-1"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"provider-1.middleware-1": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"middleware-1": "value-middleware-1"},
|
||||
},
|
||||
},
|
||||
"provider-1.middleware-chain-1": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"middleware-1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{"middleware-1": "value-middleware-1"},
|
||||
},
|
||||
{
|
||||
desc: "Should handle nested chains with different context",
|
||||
buildChain: []string{"provider-1.middleware-chain-1", "middleware-chain-1"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"provider-1.middleware-1": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"middleware-1": "value-middleware-1"},
|
||||
},
|
||||
},
|
||||
"provider-1.middleware-2": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"middleware-2": "value-middleware-2"},
|
||||
},
|
||||
},
|
||||
"provider-1.middleware-chain-1": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"middleware-1"},
|
||||
},
|
||||
},
|
||||
"provider-1.middleware-chain-2": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"middleware-2"},
|
||||
},
|
||||
},
|
||||
"provider-2.middleware-chain-1": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"provider-1.middleware-2", "provider-1.middleware-chain-2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]string{"middleware-1": "value-middleware-1", "middleware-2": "value-middleware-2"},
|
||||
contextProvider: "provider-2",
|
||||
},
|
||||
{
|
||||
desc: "Detects recursion in Middleware chain",
|
||||
buildChain: []string{"m1"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"ok": {
|
||||
Retry: &config.Retry{},
|
||||
},
|
||||
"m1": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"m2"},
|
||||
},
|
||||
},
|
||||
"m2": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"ok", "m3"},
|
||||
},
|
||||
},
|
||||
"m3": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"m1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: errors.New("could not instantiate middleware m1: recursion detected in m1->m2->m3->m1"),
|
||||
},
|
||||
{
|
||||
desc: "Detects recursion in Middleware chain",
|
||||
buildChain: []string{"provider.m1"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"provider2.ok": {
|
||||
Retry: &config.Retry{},
|
||||
},
|
||||
"provider.m1": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"provider2.m2"},
|
||||
},
|
||||
},
|
||||
"provider2.m2": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"ok", "provider.m3"},
|
||||
},
|
||||
},
|
||||
"provider.m3": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"m1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: errors.New("could not instantiate middleware provider.m1: recursion detected in provider.m1->provider2.m2->provider.m3->provider.m1"),
|
||||
},
|
||||
{
|
||||
buildChain: []string{"ok", "m0"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"ok": {
|
||||
Retry: &config.Retry{},
|
||||
},
|
||||
"m0": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"m0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: errors.New("could not instantiate middleware m0: recursion detected in m0->m0"),
|
||||
},
|
||||
{
|
||||
desc: "Detects MiddlewareChain that references a Chain that references a Chain with a missing middleware",
|
||||
buildChain: []string{"m0"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"m0": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"m1"},
|
||||
},
|
||||
},
|
||||
"m1": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"m2"},
|
||||
},
|
||||
},
|
||||
"m2": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"m3"},
|
||||
},
|
||||
},
|
||||
"m3": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"m2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: errors.New("could not instantiate middleware m2: recursion detected in m0->m1->m2->m3->m2"),
|
||||
},
|
||||
{
|
||||
desc: "--",
|
||||
buildChain: []string{"m0"},
|
||||
configuration: map[string]*config.Middleware{
|
||||
"m0": {
|
||||
Chain: &config.Chain{
|
||||
Middlewares: []string{"m0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedError: errors.New("could not instantiate middleware m0: recursion detected in m0->m0"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
if len(test.contextProvider) > 0 {
|
||||
ctx = internal.AddProviderInContext(ctx, test.contextProvider+".foobar")
|
||||
}
|
||||
|
||||
builder := NewBuilder(test.configuration, nil)
|
||||
|
||||
result := builder.BuildChain(ctx, test.buildChain)
|
||||
|
||||
handlers, err := result.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }))
|
||||
if test.expectedError != nil {
|
||||
require.NotNil(t, err)
|
||||
require.Equal(t, test.expectedError.Error(), err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
recorder := httptest.NewRecorder()
|
||||
request, _ := http.NewRequest(http.MethodGet, "http://foo/", nil)
|
||||
handlers.ServeHTTP(recorder, request)
|
||||
|
||||
for key, value := range test.expected {
|
||||
assert.Equal(t, value, request.Header.Get(key))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
100
pkg/server/roundtripper.go
Normal file
100
pkg/server/roundtripper.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
"github.com/containous/traefik/pkg/config/static"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
traefiktls "github.com/containous/traefik/pkg/tls"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type h2cTransportWrapper struct {
|
||||
*http2.Transport
|
||||
}
|
||||
|
||||
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.URL.Scheme = "http"
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// createHTTPTransport creates an http.Transport configured with the Transport configuration settings.
|
||||
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
|
||||
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost
|
||||
// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing
|
||||
// behavior and backwards compatibility issues.
|
||||
func createHTTPTransport(transportConfiguration *static.ServersTransport) (*http.Transport, error) {
|
||||
if transportConfiguration == nil {
|
||||
return nil, errors.New("no transport configuration given")
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: configuration.DefaultDialTimeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}
|
||||
|
||||
if transportConfiguration.ForwardingTimeouts != nil {
|
||||
dialer.Timeout = time.Duration(transportConfiguration.ForwardingTimeouts.DialTimeout)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialer.DialContext,
|
||||
MaxIdleConnsPerHost: transportConfiguration.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
transport.RegisterProtocol("h2c", &h2cTransportWrapper{
|
||||
Transport: &http2.Transport{
|
||||
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
return net.Dial(netw, addr)
|
||||
},
|
||||
AllowHTTP: true,
|
||||
},
|
||||
})
|
||||
|
||||
if transportConfiguration.ForwardingTimeouts != nil {
|
||||
transport.ResponseHeaderTimeout = time.Duration(transportConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
if transportConfiguration.InsecureSkipVerify {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
|
||||
if len(transportConfiguration.RootCAs) > 0 {
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: createRootCACertPool(transportConfiguration.RootCAs),
|
||||
}
|
||||
}
|
||||
|
||||
err := http2.ConfigureTransport(transport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
func createRootCACertPool(rootCAs traefiktls.FilesOrContents) *x509.CertPool {
|
||||
roots := x509.NewCertPool()
|
||||
|
||||
for _, cert := range rootCAs {
|
||||
certContent, err := cert.Read()
|
||||
if err != nil {
|
||||
log.WithoutContext().Error("Error while read RootCAs", err)
|
||||
continue
|
||||
}
|
||||
roots.AppendCertsFromPEM(certContent)
|
||||
}
|
||||
|
||||
return roots
|
||||
}
|
109
pkg/server/router/route_appender_aggregator.go
Normal file
109
pkg/server/router/route_appender_aggregator.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/pkg/api"
|
||||
"github.com/containous/traefik/pkg/config/static"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/metrics"
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
)
|
||||
|
||||
// chainBuilder The contract of the middleware builder
|
||||
type chainBuilder interface {
|
||||
BuildChain(ctx context.Context, middlewares []string) *alice.Chain
|
||||
}
|
||||
|
||||
// NewRouteAppenderAggregator Creates a new RouteAppenderAggregator
|
||||
func NewRouteAppenderAggregator(ctx context.Context, chainBuilder chainBuilder, conf static.Configuration, entryPointName string, currentConfiguration *safe.Safe) *RouteAppenderAggregator {
|
||||
aggregator := &RouteAppenderAggregator{}
|
||||
|
||||
if conf.Providers != nil && conf.Providers.Rest != nil {
|
||||
aggregator.AddAppender(conf.Providers.Rest)
|
||||
}
|
||||
|
||||
if conf.API != nil && conf.API.EntryPoint == entryPointName {
|
||||
chain := chainBuilder.BuildChain(ctx, conf.API.Middlewares)
|
||||
aggregator.AddAppender(&WithMiddleware{
|
||||
appender: api.Handler{
|
||||
EntryPoint: conf.API.EntryPoint,
|
||||
Dashboard: conf.API.Dashboard,
|
||||
Statistics: conf.API.Statistics,
|
||||
DashboardAssets: conf.API.DashboardAssets,
|
||||
CurrentConfigurations: currentConfiguration,
|
||||
Debug: conf.Global.Debug,
|
||||
},
|
||||
routerMiddlewares: chain,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
if conf.Ping != nil && conf.Ping.EntryPoint == entryPointName {
|
||||
chain := chainBuilder.BuildChain(ctx, conf.Ping.Middlewares)
|
||||
aggregator.AddAppender(&WithMiddleware{
|
||||
appender: conf.Ping,
|
||||
routerMiddlewares: chain,
|
||||
})
|
||||
}
|
||||
|
||||
if conf.Metrics != nil && conf.Metrics.Prometheus != nil && conf.Metrics.Prometheus.EntryPoint == entryPointName {
|
||||
chain := chainBuilder.BuildChain(ctx, conf.Metrics.Prometheus.Middlewares)
|
||||
aggregator.AddAppender(&WithMiddleware{
|
||||
appender: metrics.PrometheusHandler{},
|
||||
routerMiddlewares: chain,
|
||||
})
|
||||
}
|
||||
|
||||
return aggregator
|
||||
}
|
||||
|
||||
// RouteAppenderAggregator RouteAppender that aggregate other RouteAppender
|
||||
type RouteAppenderAggregator struct {
|
||||
appenders []types.RouteAppender
|
||||
}
|
||||
|
||||
// Append Adds routes to the router
|
||||
func (r *RouteAppenderAggregator) Append(systemRouter *mux.Router) {
|
||||
for _, router := range r.appenders {
|
||||
router.Append(systemRouter)
|
||||
}
|
||||
}
|
||||
|
||||
// AddAppender adds a router in the aggregator
|
||||
func (r *RouteAppenderAggregator) AddAppender(router types.RouteAppender) {
|
||||
r.appenders = append(r.appenders, router)
|
||||
}
|
||||
|
||||
// WithMiddleware router with internal middleware
|
||||
type WithMiddleware struct {
|
||||
appender types.RouteAppender
|
||||
routerMiddlewares *alice.Chain
|
||||
}
|
||||
|
||||
// Append Adds routes to the router
|
||||
func (wm *WithMiddleware) Append(systemRouter *mux.Router) {
|
||||
realRouter := systemRouter.PathPrefix("/").Subrouter()
|
||||
|
||||
wm.appender.Append(realRouter)
|
||||
|
||||
if err := realRouter.Walk(wrapRoute(wm.routerMiddlewares)); err != nil {
|
||||
log.WithoutContext().Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// wrapRoute with middlewares
|
||||
func wrapRoute(middlewares *alice.Chain) func(*mux.Route, *mux.Router, []*mux.Route) error {
|
||||
return func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
|
||||
handler, err := middlewares.Then(route.GetHandler())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
route.Handler(handler)
|
||||
return nil
|
||||
}
|
||||
}
|
114
pkg/server/router/route_appender_aggregator_test.go
Normal file
114
pkg/server/router/route_appender_aggregator_test.go
Normal file
|
@ -0,0 +1,114 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/pkg/config/static"
|
||||
"github.com/containous/traefik/pkg/ping"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type ChainBuilderMock struct {
|
||||
middles map[string]alice.Constructor
|
||||
}
|
||||
|
||||
func (c *ChainBuilderMock) BuildChain(ctx context.Context, middles []string) *alice.Chain {
|
||||
chain := alice.New()
|
||||
|
||||
for _, mName := range middles {
|
||||
if constructor, ok := c.middles[mName]; ok {
|
||||
chain = chain.Append(constructor)
|
||||
}
|
||||
}
|
||||
|
||||
return &chain
|
||||
}
|
||||
|
||||
func TestNewRouteAppenderAggregator(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
staticConf static.Configuration
|
||||
middles map[string]alice.Constructor
|
||||
expected map[string]int
|
||||
}{
|
||||
{
|
||||
desc: "API with auth, ping without auth",
|
||||
staticConf: static.Configuration{
|
||||
Global: &static.Global{},
|
||||
API: &static.API{
|
||||
EntryPoint: "traefik",
|
||||
Middlewares: []string{"dumb"},
|
||||
},
|
||||
Ping: &ping.Handler{
|
||||
EntryPoint: "traefik",
|
||||
},
|
||||
EntryPoints: static.EntryPoints{
|
||||
"traefik": {},
|
||||
},
|
||||
},
|
||||
middles: map[string]alice.Constructor{
|
||||
"dumb": func(_ http.Handler) (http.Handler, error) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}), nil
|
||||
},
|
||||
},
|
||||
expected: map[string]int{
|
||||
"/wrong": http.StatusBadGateway,
|
||||
"/ping": http.StatusOK,
|
||||
// "/.well-known/acme-challenge/token": http.StatusNotFound, // FIXME
|
||||
"/api/providers": http.StatusUnauthorized,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Wrong entrypoint name",
|
||||
staticConf: static.Configuration{
|
||||
Global: &static.Global{},
|
||||
API: &static.API{
|
||||
EntryPoint: "no",
|
||||
},
|
||||
EntryPoints: static.EntryPoints{
|
||||
"traefik": {},
|
||||
},
|
||||
},
|
||||
expected: map[string]int{
|
||||
"/api/providers": http.StatusBadGateway,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chainBuilder := &ChainBuilderMock{middles: test.middles}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
router := NewRouteAppenderAggregator(ctx, chainBuilder, test.staticConf, "traefik", nil)
|
||||
|
||||
internalMuxRouter := mux.NewRouter()
|
||||
router.Append(internalMuxRouter)
|
||||
|
||||
internalMuxRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
})
|
||||
|
||||
actual := make(map[string]int)
|
||||
for calledURL := range test.expected {
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, calledURL, nil)
|
||||
internalMuxRouter.ServeHTTP(recorder, request)
|
||||
actual[calledURL] = recorder.Code
|
||||
}
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
38
pkg/server/router/route_appender_factory.go
Normal file
38
pkg/server/router/route_appender_factory.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/containous/traefik/pkg/config/static"
|
||||
"github.com/containous/traefik/pkg/provider/acme"
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
"github.com/containous/traefik/pkg/server/middleware"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
)
|
||||
|
||||
// NewRouteAppenderFactory Creates a new RouteAppenderFactory
|
||||
func NewRouteAppenderFactory(staticConfiguration static.Configuration, entryPointName string, acmeProvider *acme.Provider) *RouteAppenderFactory {
|
||||
return &RouteAppenderFactory{
|
||||
staticConfiguration: staticConfiguration,
|
||||
entryPointName: entryPointName,
|
||||
acmeProvider: acmeProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// RouteAppenderFactory A factory of RouteAppender
|
||||
type RouteAppenderFactory struct {
|
||||
staticConfiguration static.Configuration
|
||||
entryPointName string
|
||||
acmeProvider *acme.Provider
|
||||
}
|
||||
|
||||
// NewAppender Creates a new RouteAppender
|
||||
func (r *RouteAppenderFactory) NewAppender(ctx context.Context, middlewaresBuilder *middleware.Builder, currentConfiguration *safe.Safe) types.RouteAppender {
|
||||
aggregator := NewRouteAppenderAggregator(ctx, middlewaresBuilder, r.staticConfiguration, r.entryPointName, currentConfiguration)
|
||||
|
||||
if r.acmeProvider != nil && r.acmeProvider.HTTPChallenge != nil && r.acmeProvider.HTTPChallenge.EntryPoint == r.entryPointName {
|
||||
aggregator.AddAppender(r.acmeProvider)
|
||||
}
|
||||
|
||||
return aggregator
|
||||
}
|
195
pkg/server/router/router.go
Normal file
195
pkg/server/router/router.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/middlewares/accesslog"
|
||||
"github.com/containous/traefik/pkg/middlewares/recovery"
|
||||
"github.com/containous/traefik/pkg/middlewares/tracing"
|
||||
"github.com/containous/traefik/pkg/responsemodifiers"
|
||||
"github.com/containous/traefik/pkg/rules"
|
||||
"github.com/containous/traefik/pkg/server/internal"
|
||||
"github.com/containous/traefik/pkg/server/middleware"
|
||||
"github.com/containous/traefik/pkg/server/service"
|
||||
)
|
||||
|
||||
const (
|
||||
recoveryMiddlewareName = "traefik-internal-recovery"
|
||||
)
|
||||
|
||||
// NewManager Creates a new Manager
|
||||
func NewManager(routers map[string]*config.Router,
|
||||
serviceManager *service.Manager, middlewaresBuilder *middleware.Builder, modifierBuilder *responsemodifiers.Builder,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
routerHandlers: make(map[string]http.Handler),
|
||||
configs: routers,
|
||||
serviceManager: serviceManager,
|
||||
middlewaresBuilder: middlewaresBuilder,
|
||||
modifierBuilder: modifierBuilder,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager A route/router manager
|
||||
type Manager struct {
|
||||
routerHandlers map[string]http.Handler
|
||||
configs map[string]*config.Router
|
||||
serviceManager *service.Manager
|
||||
middlewaresBuilder *middleware.Builder
|
||||
modifierBuilder *responsemodifiers.Builder
|
||||
}
|
||||
|
||||
// BuildHandlers Builds handler for all entry points
|
||||
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string, tls bool) map[string]http.Handler {
|
||||
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints, tls)
|
||||
|
||||
entryPointHandlers := make(map[string]http.Handler)
|
||||
for entryPointName, routers := range entryPointsRouters {
|
||||
entryPointName := entryPointName
|
||||
ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName))
|
||||
|
||||
handler, err := m.buildEntryPointHandler(ctx, routers)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
continue
|
||||
}
|
||||
|
||||
handlerWithAccessLog, err := alice.New(func(next http.Handler) (http.Handler, error) {
|
||||
return accesslog.NewFieldHandler(next, log.EntryPointName, entryPointName, accesslog.AddOriginFields), nil
|
||||
}).Then(handler)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
entryPointHandlers[entryPointName] = handler
|
||||
} else {
|
||||
entryPointHandlers[entryPointName] = handlerWithAccessLog
|
||||
}
|
||||
}
|
||||
|
||||
m.serviceManager.LaunchHealthCheck()
|
||||
|
||||
return entryPointHandlers
|
||||
}
|
||||
|
||||
func contains(entryPoints []string, entryPointName string) bool {
|
||||
for _, name := range entryPoints {
|
||||
if name == entryPointName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string, tls bool) map[string]map[string]*config.Router {
|
||||
entryPointsRouters := make(map[string]map[string]*config.Router)
|
||||
|
||||
for rtName, rt := range m.configs {
|
||||
if (tls && rt.TLS == nil) || (!tls && rt.TLS != nil) {
|
||||
continue
|
||||
}
|
||||
|
||||
eps := rt.EntryPoints
|
||||
if len(eps) == 0 {
|
||||
eps = entryPoints
|
||||
}
|
||||
for _, entryPointName := range eps {
|
||||
if !contains(entryPoints, entryPointName) {
|
||||
log.FromContext(log.With(ctx, log.Str(log.EntryPointName, entryPointName))).
|
||||
Errorf("entryPoint %q doesn't exist", entryPointName)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := entryPointsRouters[entryPointName]; !ok {
|
||||
entryPointsRouters[entryPointName] = make(map[string]*config.Router)
|
||||
}
|
||||
|
||||
entryPointsRouters[entryPointName][rtName] = rt
|
||||
}
|
||||
}
|
||||
|
||||
return entryPointsRouters
|
||||
}
|
||||
|
||||
func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*config.Router) (http.Handler, error) {
|
||||
router, err := rules.NewRouter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for routerName, routerConfig := range configs {
|
||||
ctxRouter := log.With(ctx, log.Str(log.RouterName, routerName))
|
||||
logger := log.FromContext(ctxRouter)
|
||||
|
||||
ctxRouter = internal.AddProviderInContext(ctxRouter, routerName)
|
||||
|
||||
handler, err := m.buildRouterHandler(ctxRouter, routerName)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = router.AddRoute(routerConfig.Rule, routerConfig.Priority, handler)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
router.SortRoutes()
|
||||
|
||||
chain := alice.New()
|
||||
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
||||
return recovery.New(ctx, next, recoveryMiddlewareName)
|
||||
})
|
||||
|
||||
return chain.Then(router)
|
||||
}
|
||||
|
||||
func (m *Manager) buildRouterHandler(ctx context.Context, routerName string) (http.Handler, error) {
|
||||
if handler, ok := m.routerHandlers[routerName]; ok {
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
configRouter, ok := m.configs[routerName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no configuration for %s", routerName)
|
||||
}
|
||||
|
||||
handler, err := m.buildHTTPHandler(ctx, configRouter, routerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handlerWithAccessLog, err := alice.New(func(next http.Handler) (http.Handler, error) {
|
||||
return accesslog.NewFieldHandler(next, accesslog.RouterName, routerName, nil), nil
|
||||
}).Then(handler)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
m.routerHandlers[routerName] = handler
|
||||
} else {
|
||||
m.routerHandlers[routerName] = handlerWithAccessLog
|
||||
}
|
||||
|
||||
return m.routerHandlers[routerName], nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildHTTPHandler(ctx context.Context, router *config.Router, routerName string) (http.Handler, error) {
|
||||
rm := m.modifierBuilder.Build(ctx, router.Middlewares)
|
||||
|
||||
sHandler, err := m.serviceManager.BuildHTTP(ctx, router.Service, rm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mHandler := m.middlewaresBuilder.BuildChain(ctx, router.Middlewares)
|
||||
|
||||
tHandler := func(next http.Handler) (http.Handler, error) {
|
||||
return tracing.NewForwarder(ctx, routerName, router.Service, next), nil
|
||||
}
|
||||
|
||||
return alice.New().Extend(*mHandler).Append(tHandler).Then(sHandler)
|
||||
}
|
531
pkg/server/router/router_test.go
Normal file
531
pkg/server/router/router_test.go
Normal file
|
@ -0,0 +1,531 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares/accesslog"
|
||||
"github.com/containous/traefik/pkg/middlewares/requestdecorator"
|
||||
"github.com/containous/traefik/pkg/responsemodifiers"
|
||||
"github.com/containous/traefik/pkg/server/middleware"
|
||||
"github.com/containous/traefik/pkg/server/service"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRouterManager_Get(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
type ExpectedResult struct {
|
||||
StatusCode int
|
||||
RequestHeaders map[string]string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
routersConfig map[string]*config.Router
|
||||
serviceConfig map[string]*config.Service
|
||||
middlewaresConfig map[string]*config.Middleware
|
||||
entryPoints []string
|
||||
expected ExpectedResult
|
||||
}{
|
||||
{
|
||||
desc: "no middleware",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{StatusCode: http.StatusOK},
|
||||
},
|
||||
{
|
||||
desc: "no load balancer",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{StatusCode: http.StatusNotFound},
|
||||
},
|
||||
{
|
||||
desc: "no middleware, default entry point",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{StatusCode: http.StatusOK},
|
||||
},
|
||||
{
|
||||
desc: "no middleware, no matching",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`bar.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{StatusCode: http.StatusNotFound},
|
||||
},
|
||||
{
|
||||
desc: "middleware: headers > auth",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Middlewares: []string{"headers-middle", "auth-middle"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
middlewaresConfig: map[string]*config.Middleware{
|
||||
"auth-middle": {
|
||||
BasicAuth: &config.BasicAuth{
|
||||
Users: []string{"toto:titi"},
|
||||
},
|
||||
},
|
||||
"headers-middle": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RequestHeaders: map[string]string{
|
||||
"X-Apero": "beer",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "middleware: auth > header",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Middlewares: []string{"auth-middle", "headers-middle"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
middlewaresConfig: map[string]*config.Middleware{
|
||||
"auth-middle": {
|
||||
BasicAuth: &config.BasicAuth{
|
||||
Users: []string{"toto:titi"},
|
||||
},
|
||||
},
|
||||
"headers-middle": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RequestHeaders: map[string]string{
|
||||
"X-Apero": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no middleware with provider name",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"provider-1.foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"provider-1.foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{StatusCode: http.StatusOK},
|
||||
},
|
||||
{
|
||||
desc: "no middleware with specified provider name",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"provider-1.foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "provider-2.foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"provider-2.foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{StatusCode: http.StatusOK},
|
||||
},
|
||||
{
|
||||
desc: "middleware: chain with provider name",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"provider-1.foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Middlewares: []string{"provider-2.chain-middle", "headers-middle"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"provider-1.foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
middlewaresConfig: map[string]*config.Middleware{
|
||||
"provider-2.chain-middle": {
|
||||
Chain: &config.Chain{Middlewares: []string{"auth-middle"}},
|
||||
},
|
||||
"provider-2.auth-middle": {
|
||||
BasicAuth: &config.BasicAuth{
|
||||
Users: []string{"toto:titi"},
|
||||
},
|
||||
},
|
||||
"provider-1.headers-middle": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RequestHeaders: map[string]string{
|
||||
"X-Apero": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceManager := service.NewManager(test.serviceConfig, http.DefaultTransport)
|
||||
middlewaresBuilder := middleware.NewBuilder(test.middlewaresConfig, serviceManager)
|
||||
responseModifierFactory := responsemodifiers.NewBuilder(test.middlewaresConfig)
|
||||
|
||||
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
reqHost := requestdecorator.New(nil)
|
||||
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
|
||||
|
||||
assert.Equal(t, test.expected.StatusCode, w.Code)
|
||||
|
||||
for key, value := range test.expected.RequestHeaders {
|
||||
assert.Equal(t, value, req.Header.Get(key))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLog(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
routersConfig map[string]*config.Router
|
||||
serviceConfig map[string]*config.Service
|
||||
middlewaresConfig map[string]*config.Middleware
|
||||
entryPoints []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "apply routerName in accesslog (first match)",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
"bar": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`bar.foo`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: "foo",
|
||||
},
|
||||
{
|
||||
desc: "apply routerName in accesslog (second match)",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`bar.foo`)",
|
||||
},
|
||||
"bar": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`)",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
serviceManager := service.NewManager(test.serviceConfig, http.DefaultTransport)
|
||||
middlewaresBuilder := middleware.NewBuilder(test.middlewaresConfig, serviceManager)
|
||||
responseModifierFactory := responsemodifiers.NewBuilder(test.middlewaresConfig)
|
||||
|
||||
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
accesslogger, err := accesslog.NewHandler(&types.AccessLog{
|
||||
Format: "json",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
reqHost := requestdecorator.New(nil)
|
||||
|
||||
accesslogger.ServeHTTP(w, req, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
|
||||
|
||||
data := accesslog.GetLogData(req)
|
||||
require.NotNil(t, data)
|
||||
|
||||
assert.Equal(t, test.expected, data.Core[accesslog.RouterName])
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type staticTransport struct {
|
||||
res *http.Response
|
||||
}
|
||||
|
||||
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return t.res, nil
|
||||
}
|
||||
|
||||
func BenchmarkRouterServe(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
res := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
routersConfig := map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`) && Path(`/`)",
|
||||
},
|
||||
}
|
||||
serviceConfig := map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
}
|
||||
entryPoints := []string{"web"}
|
||||
|
||||
serviceManager := service.NewManager(serviceConfig, &staticTransport{res})
|
||||
middlewaresBuilder := middleware.NewBuilder(map[string]*config.Middleware{}, serviceManager)
|
||||
responseModifierFactory := responsemodifiers.NewBuilder(map[string]*config.Middleware{})
|
||||
|
||||
routerManager := NewManager(routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlers := routerManager.BuildHandlers(context.Background(), entryPoints, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
reqHost := requestdecorator.New(nil)
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
|
||||
}
|
||||
|
||||
}
|
||||
func BenchmarkService(b *testing.B) {
|
||||
res := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
|
||||
serviceConfig := map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: "tchouck",
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
serviceManager := service.NewManager(serviceConfig, &staticTransport{res})
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
handler, _ := serviceManager.BuildHTTP(context.Background(), "foo-service", nil)
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
}
|
139
pkg/server/router/tcp/router.go
Normal file
139
pkg/server/router/tcp/router.go
Normal file
|
@ -0,0 +1,139 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/rules"
|
||||
"github.com/containous/traefik/pkg/server/internal"
|
||||
tcpservice "github.com/containous/traefik/pkg/server/service/tcp"
|
||||
"github.com/containous/traefik/pkg/tcp"
|
||||
)
|
||||
|
||||
// NewManager Creates a new Manager
|
||||
func NewManager(routers map[string]*config.TCPRouter,
|
||||
serviceManager *tcpservice.Manager,
|
||||
httpHandlers map[string]http.Handler,
|
||||
httpsHandlers map[string]http.Handler,
|
||||
tlsConfig *tls.Config,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
configs: routers,
|
||||
serviceManager: serviceManager,
|
||||
httpHandlers: httpHandlers,
|
||||
httpsHandlers: httpsHandlers,
|
||||
tlsConfig: tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager is a route/router manager
|
||||
type Manager struct {
|
||||
configs map[string]*config.TCPRouter
|
||||
serviceManager *tcpservice.Manager
|
||||
httpHandlers map[string]http.Handler
|
||||
httpsHandlers map[string]http.Handler
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// BuildHandlers builds the handlers for the given entrypoints
|
||||
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string) map[string]*tcp.Router {
|
||||
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints)
|
||||
|
||||
entryPointHandlers := make(map[string]*tcp.Router)
|
||||
for _, entryPointName := range entryPoints {
|
||||
entryPointName := entryPointName
|
||||
|
||||
routers := entryPointsRouters[entryPointName]
|
||||
|
||||
ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName))
|
||||
|
||||
handler, err := m.buildEntryPointHandler(ctx, routers, m.httpHandlers[entryPointName], m.httpsHandlers[entryPointName])
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
continue
|
||||
}
|
||||
entryPointHandlers[entryPointName] = handler
|
||||
}
|
||||
return entryPointHandlers
|
||||
}
|
||||
|
||||
func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*config.TCPRouter, handlerHTTP http.Handler, handlerHTTPS http.Handler) (*tcp.Router, error) {
|
||||
router := &tcp.Router{}
|
||||
|
||||
router.HTTPHandler(handlerHTTP)
|
||||
router.HTTPSHandler(handlerHTTPS, m.tlsConfig)
|
||||
for routerName, routerConfig := range configs {
|
||||
ctxRouter := log.With(ctx, log.Str(log.RouterName, routerName))
|
||||
logger := log.FromContext(ctxRouter)
|
||||
|
||||
ctxRouter = internal.AddProviderInContext(ctxRouter, routerName)
|
||||
|
||||
handler, err := m.serviceManager.BuildTCP(ctxRouter, routerConfig.Service)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
|
||||
domains, err := rules.ParseHostSNI(routerConfig.Rule)
|
||||
if err != nil {
|
||||
log.WithoutContext().Debugf("Unknown rule %s", routerConfig.Rule)
|
||||
continue
|
||||
}
|
||||
for _, domain := range domains {
|
||||
log.WithoutContext().Debugf("Add route %s on TCP", domain)
|
||||
switch {
|
||||
case routerConfig.TLS != nil:
|
||||
if routerConfig.TLS.Passthrough {
|
||||
router.AddRoute(domain, handler)
|
||||
} else {
|
||||
router.AddRouteTLS(domain, handler, m.tlsConfig)
|
||||
|
||||
}
|
||||
case domain == "*":
|
||||
router.AddCatchAllNoTLS(handler)
|
||||
default:
|
||||
logger.Warn("TCP Router ignored, cannot specify a Host rule without TLS")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return router, nil
|
||||
}
|
||||
|
||||
func contains(entryPoints []string, entryPointName string) bool {
|
||||
for _, name := range entryPoints {
|
||||
if name == entryPointName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string) map[string]map[string]*config.TCPRouter {
|
||||
entryPointsRouters := make(map[string]map[string]*config.TCPRouter)
|
||||
|
||||
for rtName, rt := range m.configs {
|
||||
eps := rt.EntryPoints
|
||||
if len(eps) == 0 {
|
||||
eps = entryPoints
|
||||
}
|
||||
for _, entryPointName := range eps {
|
||||
if !contains(entryPoints, entryPointName) {
|
||||
log.FromContext(log.With(ctx, log.Str(log.EntryPointName, entryPointName))).
|
||||
Errorf("entryPoint %q doesn't exist", entryPointName)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := entryPointsRouters[entryPointName]; !ok {
|
||||
entryPointsRouters[entryPointName] = make(map[string]*config.TCPRouter)
|
||||
}
|
||||
|
||||
entryPointsRouters[entryPointName][rtName] = rt
|
||||
}
|
||||
}
|
||||
|
||||
return entryPointsRouters
|
||||
}
|
307
pkg/server/server.go
Normal file
307
pkg/server/server.go
Normal file
|
@ -0,0 +1,307 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/config/static"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/metrics"
|
||||
"github.com/containous/traefik/pkg/middlewares/accesslog"
|
||||
"github.com/containous/traefik/pkg/middlewares/requestdecorator"
|
||||
"github.com/containous/traefik/pkg/provider"
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
"github.com/containous/traefik/pkg/server/middleware"
|
||||
"github.com/containous/traefik/pkg/tls"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/containous/traefik/pkg/tracing/datadog"
|
||||
"github.com/containous/traefik/pkg/tracing/instana"
|
||||
"github.com/containous/traefik/pkg/tracing/jaeger"
|
||||
"github.com/containous/traefik/pkg/tracing/zipkin"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
)
|
||||
|
||||
// Server is the reverse-proxy/load-balancer engine
|
||||
type Server struct {
|
||||
entryPointsTCP TCPEntryPoints
|
||||
configurationChan chan config.Message
|
||||
configurationValidatedChan chan config.Message
|
||||
signals chan os.Signal
|
||||
stopChan chan bool
|
||||
currentConfigurations safe.Safe
|
||||
providerConfigUpdateMap map[string]chan config.Message
|
||||
accessLoggerMiddleware *accesslog.Handler
|
||||
tracer *tracing.Tracing
|
||||
routinesPool *safe.Pool
|
||||
defaultRoundTripper http.RoundTripper
|
||||
metricsRegistry metrics.Registry
|
||||
provider provider.Provider
|
||||
configurationListeners []func(config.Configuration)
|
||||
requestDecorator *requestdecorator.RequestDecorator
|
||||
providersThrottleDuration time.Duration
|
||||
tlsManager *tls.Manager
|
||||
}
|
||||
|
||||
// RouteAppenderFactory the route appender factory interface
|
||||
type RouteAppenderFactory interface {
|
||||
NewAppender(ctx context.Context, middlewaresBuilder *middleware.Builder, currentConfigurations *safe.Safe) types.RouteAppender
|
||||
}
|
||||
|
||||
func setupTracing(conf *static.Tracing) tracing.TrackingBackend {
|
||||
switch conf.Backend {
|
||||
case jaeger.Name:
|
||||
return conf.Jaeger
|
||||
case zipkin.Name:
|
||||
return conf.Zipkin
|
||||
case datadog.Name:
|
||||
return conf.DataDog
|
||||
case instana.Name:
|
||||
return conf.Instana
|
||||
default:
|
||||
log.WithoutContext().Warnf("Could not initialize tracing: unknown tracer %q", conf.Backend)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer returns an initialized Server.
|
||||
func NewServer(staticConfiguration static.Configuration, provider provider.Provider, entryPoints TCPEntryPoints, tlsManager *tls.Manager) *Server {
|
||||
server := &Server{}
|
||||
|
||||
server.provider = provider
|
||||
server.entryPointsTCP = entryPoints
|
||||
server.configurationChan = make(chan config.Message, 100)
|
||||
server.configurationValidatedChan = make(chan config.Message, 100)
|
||||
server.signals = make(chan os.Signal, 1)
|
||||
server.stopChan = make(chan bool, 1)
|
||||
server.configureSignals()
|
||||
currentConfigurations := make(config.Configurations)
|
||||
server.currentConfigurations.Set(currentConfigurations)
|
||||
server.providerConfigUpdateMap = make(map[string]chan config.Message)
|
||||
server.tlsManager = tlsManager
|
||||
|
||||
if staticConfiguration.Providers != nil {
|
||||
server.providersThrottleDuration = time.Duration(staticConfiguration.Providers.ProvidersThrottleDuration)
|
||||
}
|
||||
|
||||
transport, err := createHTTPTransport(staticConfiguration.ServersTransport)
|
||||
if err != nil {
|
||||
log.WithoutContext().Errorf("Could not configure HTTP Transport, fallbacking on default transport: %v", err)
|
||||
server.defaultRoundTripper = http.DefaultTransport
|
||||
} else {
|
||||
server.defaultRoundTripper = transport
|
||||
}
|
||||
|
||||
server.routinesPool = safe.NewPool(context.Background())
|
||||
|
||||
if staticConfiguration.Tracing != nil {
|
||||
trackingBackend := setupTracing(staticConfiguration.Tracing)
|
||||
var err error
|
||||
server.tracer, err = tracing.NewTracing(staticConfiguration.Tracing.ServiceName, staticConfiguration.Tracing.SpanNameLimit, trackingBackend)
|
||||
if err != nil {
|
||||
log.WithoutContext().Warnf("Unable to create tracer: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
server.requestDecorator = requestdecorator.New(staticConfiguration.HostResolver)
|
||||
|
||||
server.metricsRegistry = registerMetricClients(staticConfiguration.Metrics)
|
||||
|
||||
if staticConfiguration.AccessLog != nil {
|
||||
var err error
|
||||
server.accessLoggerMiddleware, err = accesslog.NewHandler(staticConfiguration.AccessLog)
|
||||
if err != nil {
|
||||
log.WithoutContext().Warnf("Unable to create access logger : %v", err)
|
||||
}
|
||||
}
|
||||
return server
|
||||
}
|
||||
|
||||
// Start starts the server and Stop/Close it when context is Done
|
||||
func (s *Server) Start(ctx context.Context) {
|
||||
go func() {
|
||||
defer s.Close()
|
||||
<-ctx.Done()
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Info("I have to go...")
|
||||
logger.Info("Stopping server gracefully")
|
||||
s.Stop()
|
||||
}()
|
||||
|
||||
s.startTCPServers()
|
||||
s.routinesPool.Go(func(stop chan bool) {
|
||||
s.listenProviders(stop)
|
||||
})
|
||||
s.routinesPool.Go(func(stop chan bool) {
|
||||
s.listenConfigurations(stop)
|
||||
})
|
||||
s.startProvider()
|
||||
s.routinesPool.Go(func(stop chan bool) {
|
||||
s.listenSignals(stop)
|
||||
})
|
||||
}
|
||||
|
||||
// Wait blocks until server is shutted down.
|
||||
func (s *Server) Wait() {
|
||||
<-s.stopChan
|
||||
}
|
||||
|
||||
// Stop stops the server
|
||||
func (s *Server) Stop() {
|
||||
defer log.WithoutContext().Info("Server stopped")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for epn, ep := range s.entryPointsTCP {
|
||||
wg.Add(1)
|
||||
go func(entryPointName string, entryPoint *TCPEntryPoint) {
|
||||
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
|
||||
defer wg.Done()
|
||||
|
||||
entryPoint.Shutdown(ctx)
|
||||
|
||||
log.FromContext(ctx).Debugf("Entry point %s closed", entryPointName)
|
||||
}(epn, ep)
|
||||
}
|
||||
wg.Wait()
|
||||
s.stopChan <- true
|
||||
}
|
||||
|
||||
// Close destroys the server
|
||||
func (s *Server) Close() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
go func(ctx context.Context) {
|
||||
<-ctx.Done()
|
||||
if ctx.Err() == context.Canceled {
|
||||
return
|
||||
} else if ctx.Err() == context.DeadlineExceeded {
|
||||
panic("Timeout while stopping traefik, killing instance ✝")
|
||||
}
|
||||
}(ctx)
|
||||
|
||||
stopMetricsClients()
|
||||
s.routinesPool.Cleanup()
|
||||
close(s.configurationChan)
|
||||
close(s.configurationValidatedChan)
|
||||
signal.Stop(s.signals)
|
||||
close(s.signals)
|
||||
close(s.stopChan)
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
if err := s.accessLoggerMiddleware.Close(); err != nil {
|
||||
log.WithoutContext().Errorf("Could not close the access log file: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.tracer != nil {
|
||||
s.tracer.Close()
|
||||
}
|
||||
|
||||
cancel()
|
||||
}
|
||||
|
||||
func (s *Server) startTCPServers() {
|
||||
// Use an empty configuration in order to initialize the default handlers with internal routes
|
||||
routers := s.loadConfigurationTCP(config.Configurations{})
|
||||
for entryPointName, router := range routers {
|
||||
s.entryPointsTCP[entryPointName].switchRouter(router)
|
||||
}
|
||||
|
||||
for entryPointName, serverEntryPoint := range s.entryPointsTCP {
|
||||
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
|
||||
go serverEntryPoint.startTCP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) listenProviders(stop chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case configMsg, ok := <-s.configurationChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if configMsg.Configuration != nil {
|
||||
s.preLoadConfiguration(configMsg)
|
||||
} else {
|
||||
log.Debugf("Received nil configuration from provider %q, skipping.", configMsg.ProviderName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddListener adds a new listener function used when new configuration is provided
|
||||
func (s *Server) AddListener(listener func(config.Configuration)) {
|
||||
if s.configurationListeners == nil {
|
||||
s.configurationListeners = make([]func(config.Configuration), 0)
|
||||
}
|
||||
s.configurationListeners = append(s.configurationListeners, listener)
|
||||
}
|
||||
|
||||
func (s *Server) startProvider() {
|
||||
jsonConf, err := json.Marshal(s.provider)
|
||||
if err != nil {
|
||||
log.WithoutContext().Debugf("Unable to marshal provider configuration %T: %v", s.provider, err)
|
||||
}
|
||||
|
||||
log.WithoutContext().Infof("Starting provider %T %s", s.provider, jsonConf)
|
||||
currentProvider := s.provider
|
||||
|
||||
safe.Go(func() {
|
||||
err := currentProvider.Provide(s.configurationChan, s.routinesPool)
|
||||
if err != nil {
|
||||
log.WithoutContext().Errorf("Error starting provider %T: %s", s.provider, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func registerMetricClients(metricsConfig *types.Metrics) metrics.Registry {
|
||||
if metricsConfig == nil {
|
||||
return metrics.NewVoidRegistry()
|
||||
}
|
||||
|
||||
var registries []metrics.Registry
|
||||
|
||||
if metricsConfig.Prometheus != nil {
|
||||
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "prometheus"))
|
||||
prometheusRegister := metrics.RegisterPrometheus(ctx, metricsConfig.Prometheus)
|
||||
if prometheusRegister != nil {
|
||||
registries = append(registries, prometheusRegister)
|
||||
log.FromContext(ctx).Debug("Configured Prometheus metrics")
|
||||
}
|
||||
}
|
||||
|
||||
if metricsConfig.Datadog != nil {
|
||||
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "datadog"))
|
||||
registries = append(registries, metrics.RegisterDatadog(ctx, metricsConfig.Datadog))
|
||||
log.FromContext(ctx).Debugf("Configured DataDog metrics: pushing to %s once every %s",
|
||||
metricsConfig.Datadog.Address, metricsConfig.Datadog.PushInterval)
|
||||
}
|
||||
|
||||
if metricsConfig.StatsD != nil {
|
||||
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "statsd"))
|
||||
registries = append(registries, metrics.RegisterStatsd(ctx, metricsConfig.StatsD))
|
||||
log.FromContext(ctx).Debugf("Configured StatsD metrics: pushing to %s once every %s",
|
||||
metricsConfig.StatsD.Address, metricsConfig.StatsD.PushInterval)
|
||||
}
|
||||
|
||||
if metricsConfig.InfluxDB != nil {
|
||||
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "influxdb"))
|
||||
registries = append(registries, metrics.RegisterInfluxDB(ctx, metricsConfig.InfluxDB))
|
||||
log.FromContext(ctx).Debugf("Configured InfluxDB metrics: pushing to %s once every %s",
|
||||
metricsConfig.InfluxDB.Address, metricsConfig.InfluxDB.PushInterval)
|
||||
}
|
||||
|
||||
return metrics.NewMultiRegistry(registries)
|
||||
}
|
||||
|
||||
func stopMetricsClients() {
|
||||
metrics.StopDatadog()
|
||||
metrics.StopStatsd()
|
||||
metrics.StopInfluxDB()
|
||||
}
|
272
pkg/server/server_configuration.go
Normal file
272
pkg/server/server_configuration.go
Normal file
|
@ -0,0 +1,272 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/middlewares/accesslog"
|
||||
"github.com/containous/traefik/pkg/middlewares/requestdecorator"
|
||||
"github.com/containous/traefik/pkg/middlewares/tracing"
|
||||
"github.com/containous/traefik/pkg/responsemodifiers"
|
||||
"github.com/containous/traefik/pkg/server/middleware"
|
||||
"github.com/containous/traefik/pkg/server/router"
|
||||
routertcp "github.com/containous/traefik/pkg/server/router/tcp"
|
||||
"github.com/containous/traefik/pkg/server/service"
|
||||
"github.com/containous/traefik/pkg/server/service/tcp"
|
||||
tcpCore "github.com/containous/traefik/pkg/tcp"
|
||||
"github.com/eapache/channels"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// loadConfiguration manages dynamically routers, middlewares, servers and TLS configurations
|
||||
func (s *Server) loadConfiguration(configMsg config.Message) {
|
||||
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
|
||||
|
||||
// Copy configurations to new map so we don't change current if LoadConfig fails
|
||||
newConfigurations := make(config.Configurations)
|
||||
for k, v := range currentConfigurations {
|
||||
newConfigurations[k] = v
|
||||
}
|
||||
newConfigurations[configMsg.ProviderName] = configMsg.Configuration
|
||||
|
||||
s.metricsRegistry.ConfigReloadsCounter().Add(1)
|
||||
|
||||
handlersTCP := s.loadConfigurationTCP(newConfigurations)
|
||||
for entryPointName, router := range handlersTCP {
|
||||
s.entryPointsTCP[entryPointName].switchRouter(router)
|
||||
}
|
||||
|
||||
s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
|
||||
|
||||
s.currentConfigurations.Set(newConfigurations)
|
||||
|
||||
for _, listener := range s.configurationListeners {
|
||||
listener(*configMsg.Configuration)
|
||||
}
|
||||
|
||||
s.postLoadConfiguration()
|
||||
}
|
||||
|
||||
// loadConfigurationTCP returns a new gorilla.mux Route from the specified global configuration and the dynamic
|
||||
// provider configurations.
|
||||
func (s *Server) loadConfigurationTCP(configurations config.Configurations) map[string]*tcpCore.Router {
|
||||
ctx := context.TODO()
|
||||
|
||||
var entryPoints []string
|
||||
for entryPointName := range s.entryPointsTCP {
|
||||
entryPoints = append(entryPoints, entryPointName)
|
||||
}
|
||||
|
||||
conf := mergeConfiguration(configurations)
|
||||
|
||||
s.tlsManager.UpdateConfigs(conf.TLSStores, conf.TLSOptions, conf.TLS)
|
||||
|
||||
handlersNonTLS, handlersTLS := s.createHTTPHandlers(ctx, *conf.HTTP, entryPoints)
|
||||
|
||||
routersTCP := s.createTCPRouters(ctx, conf.TCP, entryPoints, handlersNonTLS, handlersTLS, s.tlsManager.Get("default", "default"))
|
||||
|
||||
return routersTCP
|
||||
}
|
||||
|
||||
func (s *Server) createTCPRouters(ctx context.Context, configuration *config.TCPConfiguration, entryPoints []string, handlers map[string]http.Handler, handlersTLS map[string]http.Handler, tlsConfig *tls.Config) map[string]*tcpCore.Router {
|
||||
if configuration == nil {
|
||||
return make(map[string]*tcpCore.Router)
|
||||
}
|
||||
|
||||
serviceManager := tcp.NewManager(configuration.Services)
|
||||
routerManager := routertcp.NewManager(configuration.Routers, serviceManager, handlers, handlersTLS, tlsConfig)
|
||||
|
||||
return routerManager.BuildHandlers(ctx, entryPoints)
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) createHTTPHandlers(ctx context.Context, configuration config.HTTPConfiguration, entryPoints []string) (map[string]http.Handler, map[string]http.Handler) {
|
||||
serviceManager := service.NewManager(configuration.Services, s.defaultRoundTripper)
|
||||
middlewaresBuilder := middleware.NewBuilder(configuration.Middlewares, serviceManager)
|
||||
responseModifierFactory := responsemodifiers.NewBuilder(configuration.Middlewares)
|
||||
|
||||
routerManager := router.NewManager(configuration.Routers, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlersNonTLS := routerManager.BuildHandlers(ctx, entryPoints, false)
|
||||
handlersTLS := routerManager.BuildHandlers(ctx, entryPoints, true)
|
||||
|
||||
routerHandlers := make(map[string]http.Handler)
|
||||
|
||||
for _, entryPointName := range entryPoints {
|
||||
internalMuxRouter := mux.NewRouter().
|
||||
SkipClean(true)
|
||||
|
||||
ctx = log.With(ctx, log.Str(log.EntryPointName, entryPointName))
|
||||
|
||||
factory := s.entryPointsTCP[entryPointName].RouteAppenderFactory
|
||||
if factory != nil {
|
||||
// FIXME remove currentConfigurations
|
||||
appender := factory.NewAppender(ctx, middlewaresBuilder, &s.currentConfigurations)
|
||||
appender.Append(internalMuxRouter)
|
||||
}
|
||||
|
||||
if h, ok := handlersNonTLS[entryPointName]; ok {
|
||||
internalMuxRouter.NotFoundHandler = h
|
||||
} else {
|
||||
internalMuxRouter.NotFoundHandler = buildDefaultHTTPRouter()
|
||||
}
|
||||
|
||||
routerHandlers[entryPointName] = internalMuxRouter
|
||||
|
||||
chain := alice.New()
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
chain = chain.Append(accesslog.WrapHandler(s.accessLoggerMiddleware))
|
||||
}
|
||||
|
||||
if s.tracer != nil {
|
||||
chain = chain.Append(tracing.WrapEntryPointHandler(ctx, s.tracer, entryPointName))
|
||||
}
|
||||
|
||||
chain = chain.Append(requestdecorator.WrapHandler(s.requestDecorator))
|
||||
|
||||
handler, err := chain.Then(internalMuxRouter.NotFoundHandler)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
continue
|
||||
}
|
||||
internalMuxRouter.NotFoundHandler = handler
|
||||
|
||||
handlerTLS, ok := handlersTLS[entryPointName]
|
||||
if ok {
|
||||
handlerTLSWithMiddlewares, err := chain.Then(handlerTLS)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
continue
|
||||
}
|
||||
handlersTLS[entryPointName] = handlerTLSWithMiddlewares
|
||||
}
|
||||
}
|
||||
|
||||
return routerHandlers, handlersTLS
|
||||
}
|
||||
|
||||
func isEmptyConfiguration(conf *config.Configuration) bool {
|
||||
if conf == nil {
|
||||
return true
|
||||
}
|
||||
if conf.TCP == nil {
|
||||
conf.TCP = &config.TCPConfiguration{}
|
||||
}
|
||||
if conf.HTTP == nil {
|
||||
conf.HTTP = &config.HTTPConfiguration{}
|
||||
}
|
||||
|
||||
return conf.HTTP.Routers == nil &&
|
||||
conf.HTTP.Services == nil &&
|
||||
conf.HTTP.Middlewares == nil &&
|
||||
conf.TLS == nil &&
|
||||
conf.TCP.Routers == nil &&
|
||||
conf.TCP.Services == nil
|
||||
}
|
||||
|
||||
func (s *Server) preLoadConfiguration(configMsg config.Message) {
|
||||
s.defaultConfigurationValues(configMsg.Configuration.HTTP)
|
||||
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
|
||||
|
||||
logger := log.WithoutContext().WithField(log.ProviderName, configMsg.ProviderName)
|
||||
if log.GetLevel() == logrus.DebugLevel {
|
||||
jsonConf, _ := json.Marshal(configMsg.Configuration)
|
||||
logger.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
|
||||
}
|
||||
|
||||
if isEmptyConfiguration(configMsg.Configuration) {
|
||||
logger.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) {
|
||||
logger.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
|
||||
return
|
||||
}
|
||||
|
||||
providerConfigUpdateCh, ok := s.providerConfigUpdateMap[configMsg.ProviderName]
|
||||
if !ok {
|
||||
providerConfigUpdateCh = make(chan config.Message)
|
||||
s.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh
|
||||
s.routinesPool.Go(func(stop chan bool) {
|
||||
s.throttleProviderConfigReload(s.providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop)
|
||||
})
|
||||
}
|
||||
|
||||
providerConfigUpdateCh <- configMsg
|
||||
}
|
||||
|
||||
func (s *Server) defaultConfigurationValues(configuration *config.HTTPConfiguration) {
|
||||
// FIXME create a config hook
|
||||
}
|
||||
|
||||
func (s *Server) listenConfigurations(stop chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case configMsg, ok := <-s.configurationValidatedChan:
|
||||
if !ok || configMsg.Configuration == nil {
|
||||
return
|
||||
}
|
||||
s.loadConfiguration(configMsg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// throttleProviderConfigReload throttles the configuration reload speed for a single provider.
|
||||
// It will immediately publish a new configuration and then only publish the next configuration after the throttle duration.
|
||||
// Note that in the case it receives N new configs in the timeframe of the throttle duration after publishing,
|
||||
// it will publish the last of the newly received configurations.
|
||||
func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- config.Message, in <-chan config.Message, stop chan bool) {
|
||||
ring := channels.NewRingChannel(1)
|
||||
defer ring.Close()
|
||||
|
||||
s.routinesPool.Go(func(stop chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case nextConfig := <-ring.Out():
|
||||
if config, ok := nextConfig.(config.Message); ok {
|
||||
publish <- config
|
||||
time.Sleep(throttle)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case nextConfig := <-in:
|
||||
ring.In() <- nextConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) postLoadConfiguration() {
|
||||
// FIXME metrics
|
||||
// if s.metricsRegistry.IsEnabled() {
|
||||
// activeConfig := s.currentConfigurations.Get().(config.Configurations)
|
||||
// metrics.OnConfigurationUpdate(activeConfig)
|
||||
// }
|
||||
|
||||
}
|
||||
|
||||
func buildDefaultHTTPRouter() *mux.Router {
|
||||
rt := mux.NewRouter()
|
||||
rt.NotFoundHandler = http.HandlerFunc(http.NotFound)
|
||||
rt.SkipClean(true)
|
||||
return rt
|
||||
}
|
121
pkg/server/server_configuration_test.go
Normal file
121
pkg/server/server_configuration_test.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/config/static"
|
||||
th "github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestReuseService(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
entryPoints := TCPEntryPoints{
|
||||
"http": &TCPEntryPoint{},
|
||||
}
|
||||
|
||||
staticConfig := static.Configuration{}
|
||||
|
||||
dynamicConfigs := th.BuildConfiguration(
|
||||
th.WithRouters(
|
||||
th.WithRouter("foo",
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule("Path(`/ok`)")),
|
||||
th.WithRouter("foo2",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRule("Path(`/unauthorized`)"),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRouterMiddlewares("basicauth")),
|
||||
),
|
||||
th.WithMiddlewares(th.WithMiddleware("basicauth",
|
||||
th.WithBasicAuth(&config.BasicAuth{Users: []string{"foo:bar"}}),
|
||||
)),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr"),
|
||||
th.WithServers(th.WithServer(testServer.URL))),
|
||||
),
|
||||
)
|
||||
|
||||
srv := NewServer(staticConfig, nil, entryPoints, nil)
|
||||
|
||||
entrypointsHandlers, _ := srv.createHTTPHandlers(context.Background(), *dynamicConfigs, []string{"http"})
|
||||
|
||||
// Test that the /ok path returns a status 200.
|
||||
responseRecorderOk := &httptest.ResponseRecorder{}
|
||||
requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil)
|
||||
entrypointsHandlers["http"].ServeHTTP(responseRecorderOk, requestOk)
|
||||
|
||||
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
|
||||
|
||||
// Test that the /unauthorized path returns a 401 because of
|
||||
// the basic authentication defined on the frontend.
|
||||
responseRecorderUnauthorized := &httptest.ResponseRecorder{}
|
||||
requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil)
|
||||
entrypointsHandlers["http"].ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code")
|
||||
}
|
||||
|
||||
func TestThrottleProviderConfigReload(t *testing.T) {
|
||||
throttleDuration := 30 * time.Millisecond
|
||||
publishConfig := make(chan config.Message)
|
||||
providerConfig := make(chan config.Message)
|
||||
stop := make(chan bool)
|
||||
defer func() {
|
||||
stop <- true
|
||||
}()
|
||||
|
||||
staticConfiguration := static.Configuration{}
|
||||
server := NewServer(staticConfiguration, nil, nil, nil)
|
||||
|
||||
go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop)
|
||||
|
||||
publishedConfigCount := 0
|
||||
stopConsumeConfigs := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-stopConsumeConfigs:
|
||||
return
|
||||
case <-publishConfig:
|
||||
publishedConfigCount++
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// publish 5 new configs, one new config each 10 milliseconds
|
||||
for i := 0; i < 5; i++ {
|
||||
providerConfig <- config.Message{}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// after 50 milliseconds 5 new configs were published
|
||||
// with a throttle duration of 30 milliseconds this means, we should have received 2 new configs
|
||||
assert.Equal(t, 2, publishedConfigCount, "times configs were published")
|
||||
|
||||
stopConsumeConfigs <- true
|
||||
|
||||
select {
|
||||
case <-publishConfig:
|
||||
// There should be exactly one more message that we receive after ~60 milliseconds since the start of the test.
|
||||
select {
|
||||
case <-publishConfig:
|
||||
t.Error("extra config publication found")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Last config was not published in time")
|
||||
}
|
||||
}
|
385
pkg/server/server_entrypoint_tcp.go
Normal file
385
pkg/server/server_entrypoint_tcp.go
Normal file
|
@ -0,0 +1,385 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-proxyproto"
|
||||
"github.com/containous/traefik/pkg/config/static"
|
||||
"github.com/containous/traefik/pkg/h2c"
|
||||
"github.com/containous/traefik/pkg/ip"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/middlewares/forwardedheaders"
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
"github.com/containous/traefik/pkg/tcp"
|
||||
)
|
||||
|
||||
type httpForwarder struct {
|
||||
net.Listener
|
||||
connChan chan net.Conn
|
||||
}
|
||||
|
||||
func newHTTPForwarder(ln net.Listener) *httpForwarder {
|
||||
return &httpForwarder{
|
||||
Listener: ln,
|
||||
connChan: make(chan net.Conn),
|
||||
}
|
||||
}
|
||||
|
||||
// ServeTCP uses the connection to serve it later in "Accept"
|
||||
func (h *httpForwarder) ServeTCP(conn net.Conn) {
|
||||
h.connChan <- conn
|
||||
}
|
||||
|
||||
// Accept retrieves a served connection in ServeTCP
|
||||
func (h *httpForwarder) Accept() (net.Conn, error) {
|
||||
conn := <-h.connChan
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// TCPEntryPoints holds a map of TCPEntryPoint (the entrypoint names being the keys)
|
||||
type TCPEntryPoints map[string]*TCPEntryPoint
|
||||
|
||||
// TCPEntryPoint is the TCP server
|
||||
type TCPEntryPoint struct {
|
||||
listener net.Listener
|
||||
switcher *tcp.HandlerSwitcher
|
||||
RouteAppenderFactory RouteAppenderFactory
|
||||
transportConfiguration *static.EntryPointsTransport
|
||||
tracker *connectionTracker
|
||||
httpServer *httpServer
|
||||
httpsServer *httpServer
|
||||
}
|
||||
|
||||
// NewTCPEntryPoint creates a new TCPEntryPoint
|
||||
func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*TCPEntryPoint, error) {
|
||||
tracker := newConnectionTracker()
|
||||
|
||||
listener, err := buildListener(ctx, configuration)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error preparing server: %v", err)
|
||||
}
|
||||
|
||||
router := &tcp.Router{}
|
||||
|
||||
httpServer, err := createHTTPServer(listener, configuration, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error preparing httpServer: %v", err)
|
||||
}
|
||||
|
||||
router.HTTPForwarder(httpServer.Forwarder)
|
||||
|
||||
httpsServer, err := createHTTPServer(listener, configuration, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error preparing httpsServer: %v", err)
|
||||
}
|
||||
|
||||
router.HTTPSForwarder(httpsServer.Forwarder)
|
||||
|
||||
tcpSwitcher := &tcp.HandlerSwitcher{}
|
||||
tcpSwitcher.Switch(router)
|
||||
|
||||
return &TCPEntryPoint{
|
||||
listener: listener,
|
||||
switcher: tcpSwitcher,
|
||||
transportConfiguration: configuration.Transport,
|
||||
tracker: tracker,
|
||||
httpServer: httpServer,
|
||||
httpsServer: httpsServer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *TCPEntryPoint) startTCP(ctx context.Context) {
|
||||
log.FromContext(ctx).Debugf("Start TCP Server")
|
||||
|
||||
for {
|
||||
conn, err := e.listener.Accept()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
safe.Go(func() {
|
||||
e.switcher.ServeTCP(newTrackedConnection(conn, e.tracker))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown stops the TCP connections
|
||||
func (e *TCPEntryPoint) Shutdown(ctx context.Context) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
reqAcceptGraceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.RequestAcceptGraceTimeout)
|
||||
if reqAcceptGraceTimeOut > 0 {
|
||||
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
|
||||
time.Sleep(reqAcceptGraceTimeOut)
|
||||
}
|
||||
|
||||
graceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.GraceTimeOut)
|
||||
ctx, cancel := context.WithTimeout(ctx, graceTimeOut)
|
||||
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
if e.httpServer.Server != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := e.httpServer.Server.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.Debugf("Wait server shutdown is overdue to: %s", err)
|
||||
err = e.httpServer.Server.Close()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if e.httpsServer.Server != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := e.httpsServer.Server.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.Debugf("Wait server shutdown is overdue to: %s", err)
|
||||
err = e.httpsServer.Server.Close()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if e.tracker != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := e.tracker.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.Debugf("Wait hijack connection is overdue to: %s", err)
|
||||
e.tracker.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
cancel()
|
||||
}
|
||||
|
||||
func (e *TCPEntryPoint) switchRouter(router *tcp.Router) {
|
||||
router.HTTPForwarder(e.httpServer.Forwarder)
|
||||
router.HTTPSForwarder(e.httpsServer.Forwarder)
|
||||
|
||||
httpHandler := router.GetHTTPHandler()
|
||||
httpsHandler := router.GetHTTPSHandler()
|
||||
if httpHandler == nil {
|
||||
httpHandler = buildDefaultHTTPRouter()
|
||||
}
|
||||
if httpsHandler == nil {
|
||||
httpsHandler = buildDefaultHTTPRouter()
|
||||
}
|
||||
|
||||
e.httpServer.Switcher.UpdateHandler(httpHandler)
|
||||
e.httpsServer.Switcher.UpdateHandler(httpsHandler)
|
||||
e.switcher.Switch(router)
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections.
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = tc.SetKeepAlive(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
func buildProxyProtocolListener(ctx context.Context, entryPoint *static.EntryPoint, listener net.Listener) (net.Listener, error) {
|
||||
var sourceCheck func(addr net.Addr) (bool, error)
|
||||
if entryPoint.ProxyProtocol.Insecure {
|
||||
sourceCheck = func(_ net.Addr) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
} else {
|
||||
checker, err := ip.NewChecker(entryPoint.ProxyProtocol.TrustedIPs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sourceCheck = func(addr net.Addr) (bool, error) {
|
||||
ipAddr, ok := addr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("type error %v", addr)
|
||||
}
|
||||
|
||||
return checker.ContainsIP(ipAddr.IP), nil
|
||||
}
|
||||
}
|
||||
|
||||
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
|
||||
|
||||
return &proxyproto.Listener{
|
||||
Listener: listener,
|
||||
SourceCheck: sourceCheck,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildListener(ctx context.Context, entryPoint *static.EntryPoint) (net.Listener, error) {
|
||||
listener, err := net.Listen("tcp", entryPoint.Address)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening listener: %v", err)
|
||||
}
|
||||
|
||||
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
|
||||
|
||||
if entryPoint.ProxyProtocol != nil {
|
||||
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
|
||||
}
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func newConnectionTracker() *connectionTracker {
|
||||
return &connectionTracker{
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
type connectionTracker struct {
|
||||
conns map[net.Conn]struct{}
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// AddConnection add a connection in the tracked connections list
|
||||
func (c *connectionTracker) AddConnection(conn net.Conn) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
c.conns[conn] = struct{}{}
|
||||
}
|
||||
|
||||
// RemoveConnection remove a connection from the tracked connections list
|
||||
func (c *connectionTracker) RemoveConnection(conn net.Conn) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
delete(c.conns, conn)
|
||||
}
|
||||
|
||||
// Shutdown wait for the connection closing
|
||||
func (c *connectionTracker) Shutdown(ctx context.Context) error {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
c.lock.RLock()
|
||||
if len(c.conns) == 0 {
|
||||
return nil
|
||||
}
|
||||
c.lock.RUnlock()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close close all the connections in the tracked connections list
|
||||
func (c *connectionTracker) Close() {
|
||||
for conn := range c.conns {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.WithoutContext().Errorf("Error while closing connection: %v", err)
|
||||
}
|
||||
c.RemoveConnection(conn)
|
||||
}
|
||||
}
|
||||
|
||||
type stoppableServer interface {
|
||||
Shutdown(context.Context) error
|
||||
Close() error
|
||||
Serve(listener net.Listener) error
|
||||
}
|
||||
|
||||
type httpServer struct {
|
||||
Server stoppableServer
|
||||
Forwarder *httpForwarder
|
||||
Switcher *middlewares.HTTPHandlerSwitcher
|
||||
}
|
||||
|
||||
func createHTTPServer(ln net.Listener, configuration *static.EntryPoint, withH2c bool) (*httpServer, error) {
|
||||
httpSwitcher := middlewares.NewHandlerSwitcher(buildDefaultHTTPRouter())
|
||||
handler, err := forwardedheaders.NewXForwarded(
|
||||
configuration.ForwardedHeaders.Insecure,
|
||||
configuration.ForwardedHeaders.TrustedIPs,
|
||||
httpSwitcher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serverHTTP stoppableServer
|
||||
|
||||
if withH2c {
|
||||
serverHTTP = &h2c.Server{
|
||||
Server: &http.Server{
|
||||
Handler: handler,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
serverHTTP = &http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
listener := newHTTPForwarder(ln)
|
||||
go func() {
|
||||
err := serverHTTP.Serve(listener)
|
||||
if err != nil {
|
||||
log.Errorf("Error while starting server: %v", err)
|
||||
}
|
||||
}()
|
||||
return &httpServer{
|
||||
Server: serverHTTP,
|
||||
Forwarder: listener,
|
||||
Switcher: httpSwitcher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTrackedConnection(conn net.Conn, tracker *connectionTracker) *trackedConnection {
|
||||
tracker.AddConnection(conn)
|
||||
return &trackedConnection{
|
||||
Conn: conn,
|
||||
tracker: tracker,
|
||||
}
|
||||
}
|
||||
|
||||
type trackedConnection struct {
|
||||
tracker *connectionTracker
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (t *trackedConnection) Close() error {
|
||||
t.tracker.RemoveConnection(t.Conn)
|
||||
return t.Conn.Close()
|
||||
}
|
142
pkg/server/server_entrypoint_tcp_test.go
Normal file
142
pkg/server/server_entrypoint_tcp_test.go
Normal file
|
@ -0,0 +1,142 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/pkg/config/static"
|
||||
"github.com/containous/traefik/pkg/tcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestShutdownHTTP(t *testing.T) {
|
||||
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
||||
Address: ":0",
|
||||
Transport: &static.EntryPointsTransport{
|
||||
LifeCycle: &static.LifeCycle{
|
||||
RequestAcceptGraceTimeout: 0,
|
||||
GraceTimeOut: parse.Duration(5 * time.Second),
|
||||
},
|
||||
},
|
||||
ForwardedHeaders: &static.ForwardedHeaders{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.startTCP(context.Background())
|
||||
|
||||
router := &tcp.Router{}
|
||||
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
time.Sleep(1 * time.Second)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
entryPoint.switchRouter(router)
|
||||
|
||||
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.Shutdown(context.Background())
|
||||
|
||||
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = request.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestShutdownHTTPHijacked(t *testing.T) {
|
||||
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
||||
Address: ":0",
|
||||
Transport: &static.EntryPointsTransport{
|
||||
LifeCycle: &static.LifeCycle{
|
||||
RequestAcceptGraceTimeout: 0,
|
||||
GraceTimeOut: parse.Duration(5 * time.Second),
|
||||
},
|
||||
},
|
||||
ForwardedHeaders: &static.ForwardedHeaders{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.startTCP(context.Background())
|
||||
|
||||
router := &tcp.Router{}
|
||||
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
conn, _, err := rw.(http.Hijacker).Hijack()
|
||||
require.NoError(t, err)
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
resp := http.Response{StatusCode: http.StatusOK}
|
||||
err = resp.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
}))
|
||||
entryPoint.switchRouter(router)
|
||||
|
||||
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.Shutdown(context.Background())
|
||||
|
||||
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = request.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestShutdownTCPConn(t *testing.T) {
|
||||
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
||||
Address: ":0",
|
||||
Transport: &static.EntryPointsTransport{
|
||||
LifeCycle: &static.LifeCycle{
|
||||
RequestAcceptGraceTimeout: 0,
|
||||
GraceTimeOut: parse.Duration(5 * time.Second),
|
||||
},
|
||||
},
|
||||
ForwardedHeaders: &static.ForwardedHeaders{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.startTCP(context.Background())
|
||||
|
||||
router := &tcp.Router{}
|
||||
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn net.Conn) {
|
||||
_, err := http.ReadRequest(bufio.NewReader(conn))
|
||||
require.NoError(t, err)
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
resp := http.Response{StatusCode: http.StatusOK}
|
||||
err = resp.Write(conn)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
entryPoint.switchRouter(router)
|
||||
|
||||
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.Shutdown(context.Background())
|
||||
|
||||
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = request.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
}
|
37
pkg/server/server_signals.go
Normal file
37
pkg/server/server_signals.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
// +build !windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
)
|
||||
|
||||
func (s *Server) configureSignals() {
|
||||
signal.Notify(s.signals, syscall.SIGUSR1)
|
||||
}
|
||||
|
||||
func (s *Server) listenSignals(stop chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case sig := <-s.signals:
|
||||
if sig == syscall.SIGUSR1 {
|
||||
log.WithoutContext().Infof("Closing and re-opening log files for rotation: %+v", sig)
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
if err := s.accessLoggerMiddleware.Rotate(); err != nil {
|
||||
log.WithoutContext().Errorf("Error rotating access log: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := log.RotateFile(); err != nil {
|
||||
log.WithoutContext().Errorf("Error rotating traefik log: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
7
pkg/server/server_signals_windows.go
Normal file
7
pkg/server/server_signals_windows.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
// +build windows
|
||||
|
||||
package server
|
||||
|
||||
func (s *Server) configureSignals() {}
|
||||
|
||||
func (s *Server) listenSignals(stop chan bool) {}
|
271
pkg/server/server_test.go
Normal file
271
pkg/server/server_test.go
Normal file
|
@ -0,0 +1,271 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/config/static"
|
||||
th "github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestListenProvidersSkipsEmptyConfigs(t *testing.T) {
|
||||
server, stop, invokeStopChan := setupListenProvider(10 * time.Millisecond)
|
||||
defer invokeStopChan()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case <-server.configurationValidatedChan:
|
||||
t.Error("An empty configuration was published but it should not")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
server.configurationChan <- config.Message{ProviderName: "kubernetes"}
|
||||
|
||||
// give some time so that the configuration can be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestListenProvidersSkipsSameConfigurationForProvider(t *testing.T) {
|
||||
server, stop, invokeStopChan := setupListenProvider(10 * time.Millisecond)
|
||||
defer invokeStopChan()
|
||||
|
||||
publishedConfigCount := 0
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case conf := <-server.configurationValidatedChan:
|
||||
// set the current configuration
|
||||
// this is usually done in the processing part of the published configuration
|
||||
// so we have to emulate the behavior here
|
||||
currentConfigurations := server.currentConfigurations.Get().(config.Configurations)
|
||||
currentConfigurations[conf.ProviderName] = conf.Configuration
|
||||
server.currentConfigurations.Set(currentConfigurations)
|
||||
|
||||
publishedConfigCount++
|
||||
if publishedConfigCount > 1 {
|
||||
t.Error("Same configuration should not be published multiple times")
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
conf := &config.Configuration{}
|
||||
conf.HTTP = th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo")),
|
||||
th.WithLoadBalancerServices(th.WithService("bar")),
|
||||
)
|
||||
|
||||
// provide a configuration
|
||||
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
|
||||
|
||||
// give some time so that the configuration can be processed
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// provide the same configuration a second time
|
||||
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
|
||||
|
||||
// give some time so that the configuration can be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestListenProvidersPublishesConfigForEachProvider(t *testing.T) {
|
||||
server, stop, invokeStopChan := setupListenProvider(10 * time.Millisecond)
|
||||
defer invokeStopChan()
|
||||
|
||||
publishedProviderConfigCount := map[string]int{}
|
||||
publishedConfigCount := 0
|
||||
consumePublishedConfigsDone := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case newConfig := <-server.configurationValidatedChan:
|
||||
publishedProviderConfigCount[newConfig.ProviderName]++
|
||||
publishedConfigCount++
|
||||
if publishedConfigCount == 2 {
|
||||
consumePublishedConfigsDone <- true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conf := &config.Configuration{}
|
||||
conf.HTTP = th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo")),
|
||||
th.WithLoadBalancerServices(th.WithService("bar")),
|
||||
)
|
||||
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
|
||||
server.configurationChan <- config.Message{ProviderName: "marathon", Configuration: conf}
|
||||
|
||||
select {
|
||||
case <-consumePublishedConfigsDone:
|
||||
if val := publishedProviderConfigCount["kubernetes"]; val != 1 {
|
||||
t.Errorf("Got %d configuration publication(s) for provider %q, want 1", val, "kubernetes")
|
||||
}
|
||||
if val := publishedProviderConfigCount["marathon"]; val != 1 {
|
||||
t.Errorf("Got %d configuration publication(s) for provider %q, want 1", val, "marathon")
|
||||
}
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Errorf("Published configurations were not consumed in time")
|
||||
}
|
||||
}
|
||||
|
||||
// setupListenProvider configures the Server and starts listenProviders
|
||||
func setupListenProvider(throttleDuration time.Duration) (server *Server, stop chan bool, invokeStopChan func()) {
|
||||
stop = make(chan bool)
|
||||
invokeStopChan = func() {
|
||||
stop <- true
|
||||
}
|
||||
|
||||
staticConfiguration := static.Configuration{
|
||||
Providers: &static.Providers{
|
||||
ProvidersThrottleDuration: parse.Duration(throttleDuration),
|
||||
},
|
||||
}
|
||||
|
||||
server = NewServer(staticConfiguration, nil, nil, nil)
|
||||
go server.listenProviders(stop)
|
||||
|
||||
return server, stop, invokeStopChan
|
||||
}
|
||||
|
||||
func TestServerResponseEmptyBackend(t *testing.T) {
|
||||
const requestPath = "/path"
|
||||
const routeRule = "Path(`" + requestPath + "`)"
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config func(testServerURL string) *config.HTTPConfiguration
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
desc: "Ok",
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr"),
|
||||
th.WithServers(th.WithServer(testServerURL))),
|
||||
),
|
||||
)
|
||||
},
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "No Frontend",
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration()
|
||||
},
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Drr",
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("drr")),
|
||||
),
|
||||
)
|
||||
},
|
||||
expectedStatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Drr Sticky",
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("drr"), th.WithStickiness("test")),
|
||||
),
|
||||
)
|
||||
},
|
||||
expectedStatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Wrr",
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr")),
|
||||
),
|
||||
)
|
||||
},
|
||||
expectedStatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Wrr Sticky",
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr"), th.WithStickiness("test")),
|
||||
),
|
||||
)
|
||||
},
|
||||
expectedStatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
globalConfig := static.Configuration{}
|
||||
entryPointsConfig := TCPEntryPoints{
|
||||
"http": &TCPEntryPoint{},
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPointsConfig, nil)
|
||||
entryPoints, _ := srv.createHTTPHandlers(context.Background(), *test.config(testServer.URL), []string{"http"})
|
||||
|
||||
responseRecorder := &httptest.ResponseRecorder{}
|
||||
request := httptest.NewRequest(http.MethodGet, testServer.URL+requestPath, nil)
|
||||
|
||||
entryPoints["http"].ServeHTTP(responseRecorder, request)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, responseRecorder.Result().StatusCode, "status code")
|
||||
})
|
||||
}
|
||||
}
|
27
pkg/server/service/bufferpool.go
Normal file
27
pkg/server/service/bufferpool.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package service
|
||||
|
||||
import "sync"
|
||||
|
||||
const bufferPoolSize = 32 * 1024
|
||||
|
||||
func newBufferPool() *bufferPool {
|
||||
return &bufferPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, bufferPoolSize)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type bufferPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func (b *bufferPool) Get() []byte {
|
||||
return b.pool.Get().([]byte)
|
||||
}
|
||||
|
||||
func (b *bufferPool) Put(bytes []byte) {
|
||||
b.pool.Put(bytes)
|
||||
}
|
100
pkg/server/service/proxy.go
Normal file
100
pkg/server/service/proxy.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
)
|
||||
|
||||
// StatusClientClosedRequest non-standard HTTP status code for client disconnection
|
||||
const StatusClientClosedRequest = 499
|
||||
|
||||
// StatusClientClosedRequestText non-standard HTTP status for client disconnection
|
||||
const StatusClientClosedRequestText = "Client Closed Request"
|
||||
|
||||
func buildProxy(passHostHeader bool, responseForwarding *config.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
var flushInterval parse.Duration
|
||||
if responseForwarding != nil {
|
||||
err := flushInterval.Set(responseForwarding.FlushInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating flush interval: %v", err)
|
||||
}
|
||||
}
|
||||
if flushInterval == 0 {
|
||||
flushInterval = parse.Duration(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Director: func(outReq *http.Request) {
|
||||
u := outReq.URL
|
||||
if outReq.RequestURI != "" {
|
||||
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
|
||||
if err == nil {
|
||||
u = parsedURL
|
||||
}
|
||||
}
|
||||
|
||||
outReq.URL.Path = u.Path
|
||||
outReq.URL.RawPath = u.RawPath
|
||||
outReq.URL.RawQuery = u.RawQuery
|
||||
outReq.RequestURI = "" // Outgoing request should not have RequestURI
|
||||
|
||||
outReq.Proto = "HTTP/1.1"
|
||||
outReq.ProtoMajor = 1
|
||||
outReq.ProtoMinor = 1
|
||||
|
||||
// Do not pass client Host header unless optsetter PassHostHeader is set.
|
||||
if !passHostHeader {
|
||||
outReq.Host = outReq.URL.Host
|
||||
}
|
||||
|
||||
},
|
||||
Transport: defaultRoundTripper,
|
||||
FlushInterval: time.Duration(flushInterval),
|
||||
ModifyResponse: responseModifier,
|
||||
BufferPool: bufferPool,
|
||||
ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) {
|
||||
statusCode := http.StatusInternalServerError
|
||||
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
statusCode = http.StatusBadGateway
|
||||
case err == context.Canceled:
|
||||
statusCode = StatusClientClosedRequest
|
||||
default:
|
||||
if e, ok := err.(net.Error); ok {
|
||||
if e.Timeout() {
|
||||
statusCode = http.StatusGatewayTimeout
|
||||
} else {
|
||||
statusCode = http.StatusBadGateway
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
|
||||
w.WriteHeader(statusCode)
|
||||
_, werr := w.Write([]byte(statusText(statusCode)))
|
||||
if werr != nil {
|
||||
log.Debugf("Error while writing status code", werr)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func statusText(statusCode int) string {
|
||||
if statusCode == StatusClientClosedRequest {
|
||||
return StatusClientClosedRequestText
|
||||
}
|
||||
return http.StatusText(statusCode)
|
||||
}
|
37
pkg/server/service/proxy_test.go
Normal file
37
pkg/server/service/proxy_test.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
)
|
||||
|
||||
type staticTransport struct {
|
||||
res *http.Response
|
||||
}
|
||||
|
||||
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return t.res, nil
|
||||
}
|
||||
|
||||
func BenchmarkProxy(b *testing.B) {
|
||||
res := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
pool := newBufferPool()
|
||||
handler, _ := buildProxy(false, nil, &staticTransport{res}, pool, nil)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
713
pkg/server/service/proxy_websocket_test.go
Normal file
713
pkg/server/service/proxy_websocket_test.go
Normal file
|
@ -0,0 +1,713 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gorillawebsocket "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func TestWebSocketTCPClose(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
_, _, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, conn, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
).open()
|
||||
require.NoError(t, err)
|
||||
conn.Close()
|
||||
|
||||
serverErr := <-errChan
|
||||
|
||||
wsErr, ok := serverErr.(*gorillawebsocket.CloseError)
|
||||
assert.Equal(t, true, ok)
|
||||
assert.Equal(t, 1006, wsErr.Code)
|
||||
}
|
||||
|
||||
func TestWebSocketPingPong(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
var upgrader = gorillawebsocket.Upgrader{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
CheckOrigin: func(*http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
|
||||
ws, err := upgrader.Upgrade(writer, request, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ws.SetPingHandler(func(appData string) error {
|
||||
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, _ = ws.ReadMessage()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
defer conn.Close()
|
||||
|
||||
goodErr := fmt.Errorf("signal: %s", "Good data")
|
||||
badErr := fmt.Errorf("signal: %s", "Bad data")
|
||||
conn.SetPongHandler(func(data string) error {
|
||||
if data == "PingPong" {
|
||||
return goodErr
|
||||
}
|
||||
return badErr
|
||||
})
|
||||
|
||||
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = conn.ReadMessage()
|
||||
|
||||
if err != goodErr {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketEcho(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
msg := make([]byte, 4)
|
||||
_, err := conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(string(msg))
|
||||
|
||||
_, err = conn.Write(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(conn.ReadMessage())
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
}
|
||||
|
||||
func TestWebSocketPassHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
passHost bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "PassHost false",
|
||||
passHost: false,
|
||||
},
|
||||
{
|
||||
desc: "PassHost true",
|
||||
passHost: true,
|
||||
expected: "example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
f, err := buildProxy(test.passHost, nil, http.DefaultTransport, nil, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
req := conn.Request()
|
||||
|
||||
if test.passHost {
|
||||
require.Equal(t, test.expected, req.Host)
|
||||
} else {
|
||||
require.NotEqual(t, test.expected, req.Host)
|
||||
}
|
||||
|
||||
msg := make([]byte, 4)
|
||||
_, err = conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(string(msg))
|
||||
_, err = conn.Write(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
headers.Add("Host", "example.com")
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(conn.ReadMessage())
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
}}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithOrigin(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithQueryParams(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
assert.Equal(t, "test", r.URL.Query().Get("query"))
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws?query=test"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
conn.Close()
|
||||
}))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/%3A%2F%2F"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketUpgradeFailed(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(400)
|
||||
})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
|
||||
if path == "/ws" {
|
||||
// Set new backend URL
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
req.URL.Path = path
|
||||
f.ServeHTTP(w, req)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("upgrade", "websocket")
|
||||
req.Header.Add("Connection", "upgrade")
|
||||
|
||||
err = req.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First request works with 400
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
_, err := conn.Write([]byte("ok"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
srv := createTLSWebsocketServer()
|
||||
defer srv.Close()
|
||||
|
||||
forwarderWithoutTLSConfig, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL)
|
||||
defer proxyWithoutTLSConfig.Close()
|
||||
|
||||
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
|
||||
|
||||
_, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
forwarderWithTLSConfig, err := buildProxy(true, nil, transport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
|
||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
|
||||
|
||||
resp, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
const dialTimeout = time.Second
|
||||
|
||||
type websocketRequestOpt func(w *websocketRequest)
|
||||
|
||||
func withServer(server string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.ServerAddr = server
|
||||
}
|
||||
}
|
||||
|
||||
func withPath(path string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Path = path
|
||||
}
|
||||
}
|
||||
|
||||
func withData(data string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Data = data
|
||||
}
|
||||
}
|
||||
|
||||
func withOrigin(origin string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Origin = origin
|
||||
}
|
||||
}
|
||||
|
||||
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
|
||||
wsrequest := &websocketRequest{}
|
||||
for _, opt := range opts {
|
||||
opt(wsrequest)
|
||||
}
|
||||
if wsrequest.Origin == "" {
|
||||
wsrequest.Origin = "http://" + wsrequest.ServerAddr
|
||||
}
|
||||
if wsrequest.Config == nil {
|
||||
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
|
||||
}
|
||||
return wsrequest
|
||||
}
|
||||
|
||||
type websocketRequest struct {
|
||||
ServerAddr string
|
||||
Path string
|
||||
Data string
|
||||
Origin string
|
||||
Config *websocket.Config
|
||||
}
|
||||
|
||||
func (w *websocketRequest) send() (string, error) {
|
||||
conn, _, err := w.open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
if _, err := conn.Write([]byte(w.Data)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var msg = make([]byte, 512)
|
||||
var n int
|
||||
n, err = conn.Read(msg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
received := string(msg[:n])
|
||||
return received, nil
|
||||
}
|
||||
|
||||
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
|
||||
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn, err := websocket.NewClient(w.Config, client)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, client, err
|
||||
}
|
||||
|
||||
func parseURI(t *testing.T, uri string) *url.URL {
|
||||
out, err := url.ParseRequestURI(uri)
|
||||
require.NoError(t, err)
|
||||
return out
|
||||
}
|
||||
|
||||
func createProxyWithForwarder(t *testing.T, proxy http.Handler, url string) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
// Set new backend URL
|
||||
req.URL = parseURI(t, url)
|
||||
req.URL.Path = path
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
}))
|
||||
}
|
257
pkg/server/service/service.go
Normal file
257
pkg/server/service/service.go
Normal file
|
@ -0,0 +1,257 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/traefik/old/middlewares/pipelining"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/healthcheck"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/middlewares/accesslog"
|
||||
"github.com/containous/traefik/pkg/middlewares/emptybackendhandler"
|
||||
"github.com/containous/traefik/pkg/server/cookie"
|
||||
"github.com/containous/traefik/pkg/server/internal"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHealthCheckInterval = 30 * time.Second
|
||||
defaultHealthCheckTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// NewManager creates a new Manager
|
||||
func NewManager(configs map[string]*config.Service, defaultRoundTripper http.RoundTripper) *Manager {
|
||||
return &Manager{
|
||||
bufferPool: newBufferPool(),
|
||||
defaultRoundTripper: defaultRoundTripper,
|
||||
balancers: make(map[string][]healthcheck.BalancerHandler),
|
||||
configs: configs,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager The service manager
|
||||
type Manager struct {
|
||||
bufferPool httputil.BufferPool
|
||||
defaultRoundTripper http.RoundTripper
|
||||
balancers map[string][]healthcheck.BalancerHandler
|
||||
configs map[string]*config.Service
|
||||
}
|
||||
|
||||
// BuildHTTP Creates a http.Handler for a service configuration.
|
||||
func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
|
||||
|
||||
serviceName = internal.GetQualifiedName(ctx, serviceName)
|
||||
ctx = internal.AddProviderInContext(ctx, serviceName)
|
||||
|
||||
if conf, ok := m.configs[serviceName]; ok {
|
||||
// TODO Should handle multiple service types
|
||||
// FIXME Check if the service is declared multiple times with different types
|
||||
if conf.LoadBalancer != nil {
|
||||
return m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier)
|
||||
}
|
||||
return nil, fmt.Errorf("the service %q doesn't have any load balancer", serviceName)
|
||||
}
|
||||
return nil, fmt.Errorf("the service %q does not exits", serviceName)
|
||||
}
|
||||
|
||||
func (m *Manager) getLoadBalancerServiceHandler(
|
||||
ctx context.Context,
|
||||
serviceName string,
|
||||
service *config.LoadBalancerService,
|
||||
responseModifier func(*http.Response) error,
|
||||
) (http.Handler, error) {
|
||||
fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool, responseModifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
alHandler := func(next http.Handler) (http.Handler, error) {
|
||||
return accesslog.NewFieldHandler(next, accesslog.ServiceName, serviceName, accesslog.AddServiceFields), nil
|
||||
}
|
||||
|
||||
handler, err := alice.New().Append(alHandler).Then(pipelining.NewPipelining(fwd))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
balancer, err := m.getLoadBalancer(ctx, serviceName, service, handler)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO rename and checks
|
||||
m.balancers[serviceName] = append(m.balancers[serviceName], balancer)
|
||||
|
||||
// Empty (backend with no servers)
|
||||
return emptybackendhandler.New(balancer), nil
|
||||
}
|
||||
|
||||
// LaunchHealthCheck Launches the health checks.
|
||||
func (m *Manager) LaunchHealthCheck() {
|
||||
backendConfigs := make(map[string]*healthcheck.BackendConfig)
|
||||
|
||||
for serviceName, balancers := range m.balancers {
|
||||
ctx := log.With(context.Background(), log.Str(log.ServiceName, serviceName))
|
||||
|
||||
// FIXME aggregate
|
||||
balancer := balancers[0]
|
||||
|
||||
// FIXME Should all the services handle healthcheck? Handle different types
|
||||
service := m.configs[serviceName].LoadBalancer
|
||||
|
||||
// Health Check
|
||||
var backendHealthCheck *healthcheck.BackendConfig
|
||||
if hcOpts := buildHealthCheckOptions(ctx, balancer, serviceName, service.HealthCheck); hcOpts != nil {
|
||||
log.FromContext(ctx).Debugf("Setting up healthcheck for service %s with %s", serviceName, *hcOpts)
|
||||
|
||||
hcOpts.Transport = m.defaultRoundTripper
|
||||
backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, serviceName)
|
||||
}
|
||||
|
||||
if backendHealthCheck != nil {
|
||||
backendConfigs[serviceName] = backendHealthCheck
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME metrics and context
|
||||
healthcheck.GetHealthCheck().SetBackendsConfiguration(context.TODO(), backendConfigs)
|
||||
}
|
||||
|
||||
func buildHealthCheckOptions(ctx context.Context, lb healthcheck.BalancerHandler, backend string, hc *config.HealthCheck) *healthcheck.Options {
|
||||
if hc == nil || hc.Path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
interval := defaultHealthCheckInterval
|
||||
if hc.Interval != "" {
|
||||
intervalOverride, err := time.ParseDuration(hc.Interval)
|
||||
switch {
|
||||
case err != nil:
|
||||
logger.Errorf("Illegal health check interval for '%s': %s", backend, err)
|
||||
case intervalOverride <= 0:
|
||||
logger.Errorf("Health check interval smaller than zero for service '%s'", backend)
|
||||
default:
|
||||
interval = intervalOverride
|
||||
}
|
||||
}
|
||||
|
||||
timeout := defaultHealthCheckTimeout
|
||||
if hc.Timeout != "" {
|
||||
timeoutOverride, err := time.ParseDuration(hc.Timeout)
|
||||
switch {
|
||||
case err != nil:
|
||||
logger.Errorf("Illegal health check timeout for backend '%s': %s", backend, err)
|
||||
case timeoutOverride <= 0:
|
||||
logger.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend)
|
||||
default:
|
||||
timeout = timeoutOverride
|
||||
}
|
||||
}
|
||||
|
||||
if timeout >= interval {
|
||||
logger.Warnf("Health check timeout for backend '%s' should be lower than the health check interval. Interval set to timeout + 1 second (%s).", backend)
|
||||
}
|
||||
|
||||
return &healthcheck.Options{
|
||||
Scheme: hc.Scheme,
|
||||
Path: hc.Path,
|
||||
Port: hc.Port,
|
||||
Interval: interval,
|
||||
Timeout: timeout,
|
||||
LB: lb,
|
||||
Hostname: hc.Hostname,
|
||||
Headers: hc.Headers,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, service *config.LoadBalancerService, fwd http.Handler) (healthcheck.BalancerHandler, error) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
var stickySession *roundrobin.StickySession
|
||||
var cookieName string
|
||||
if stickiness := service.Stickiness; stickiness != nil {
|
||||
cookieName = cookie.GetName(stickiness.CookieName, serviceName)
|
||||
stickySession = roundrobin.NewStickySession(cookieName)
|
||||
}
|
||||
|
||||
var lb healthcheck.BalancerHandler
|
||||
|
||||
if service.Method == "drr" {
|
||||
logger.Debug("Creating drr load-balancer")
|
||||
rr, err := roundrobin.New(fwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if stickySession != nil {
|
||||
logger.Debugf("Sticky session cookie name: %v", cookieName)
|
||||
|
||||
lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
lb, err = roundrobin.NewRebalancer(rr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if service.Method != "wrr" {
|
||||
logger.Warnf("Invalid load-balancing method %q, fallback to 'wrr' method", service.Method)
|
||||
}
|
||||
|
||||
logger.Debug("Creating wrr load-balancer")
|
||||
|
||||
if stickySession != nil {
|
||||
logger.Debugf("Sticky session cookie name: %v", cookieName)
|
||||
|
||||
var err error
|
||||
lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
var err error
|
||||
lb, err = roundrobin.New(fwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.upsertServers(ctx, lb, service.Servers); err != nil {
|
||||
return nil, fmt.Errorf("error configuring load balancer for service %s: %v", serviceName, err)
|
||||
}
|
||||
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []config.Server) error {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
for name, srv := range servers {
|
||||
u, err := url.Parse(srv.URL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err)
|
||||
}
|
||||
|
||||
logger.WithField(log.ServerName, name).Debugf("Creating server %d at %s with weight %d", name, u, srv.Weight)
|
||||
|
||||
if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil {
|
||||
return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err)
|
||||
}
|
||||
|
||||
// FIXME Handle Metrics
|
||||
}
|
||||
return nil
|
||||
}
|
329
pkg/server/service/service_test.go
Normal file
329
pkg/server/service/service_test.go
Normal file
|
@ -0,0 +1,329 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/server/internal"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type MockForwarder struct{}
|
||||
|
||||
func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestGetLoadBalancer(t *testing.T) {
|
||||
sm := Manager{}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
serviceName string
|
||||
service *config.LoadBalancerService
|
||||
fwd http.Handler
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
desc: "Fails when provided an invalid URL",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: ":",
|
||||
Weight: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
fwd: &MockForwarder{},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
desc: "Succeeds when there are no servers",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{},
|
||||
fwd: &MockForwarder{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
desc: "Succeeds when stickiness is set",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Stickiness: &config.Stickiness{},
|
||||
},
|
||||
fwd: &MockForwarder{},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := sm.getLoadBalancer(context.Background(), test.serviceName, test.service, test.fwd)
|
||||
if test.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, handler)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, handler)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLoadBalancerServiceHandler(t *testing.T) {
|
||||
sm := NewManager(nil, http.DefaultTransport)
|
||||
|
||||
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", "first")
|
||||
}))
|
||||
defer server1.Close()
|
||||
|
||||
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", "second")
|
||||
}))
|
||||
defer server2.Close()
|
||||
|
||||
serverPassHost := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", "passhost")
|
||||
assert.Equal(t, "callme", r.Host)
|
||||
}))
|
||||
defer serverPassHost.Close()
|
||||
|
||||
serverPassHostFalse := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", "passhostfalse")
|
||||
assert.NotEqual(t, "callme", r.Host)
|
||||
}))
|
||||
defer serverPassHostFalse.Close()
|
||||
|
||||
type ExpectedResult struct {
|
||||
StatusCode int
|
||||
XFrom string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
serviceName string
|
||||
service *config.LoadBalancerService
|
||||
responseModifier func(*http.Response) error
|
||||
|
||||
expected []ExpectedResult
|
||||
}{
|
||||
{
|
||||
desc: "Load balances between the two servers",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server1.URL,
|
||||
Weight: 50,
|
||||
},
|
||||
{
|
||||
URL: server2.URL,
|
||||
Weight: 50,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "first",
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "second",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "StatusBadGateway when the server is not reachable",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: "http://foo",
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "ServiceUnavailable when no servers are available",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Servers: []config.Server{},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Always call the same server when stickiness is true",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Stickiness: &config.Stickiness{},
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server1.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
{
|
||||
URL: server2.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "first",
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "first",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "PassHost passes the host instead of the IP",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Stickiness: &config.Stickiness{},
|
||||
PassHostHeader: true,
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: serverPassHost.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "passhost",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "PassHost doesn't passe the host instead of the IP",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Stickiness: &config.Stickiness{},
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: serverPassHostFalse.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "passhostfalse",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service, test.responseModifier)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, handler)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://callme", nil)
|
||||
for _, expected := range test.expected {
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, expected.StatusCode, recorder.Code)
|
||||
assert.Equal(t, expected.XFrom, recorder.Header().Get("X-From"))
|
||||
|
||||
if len(recorder.Header().Get("Set-Cookie")) > 0 {
|
||||
req.Header.Set("Cookie", recorder.Header().Get("Set-Cookie"))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Build(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
serviceName string
|
||||
configs map[string]*config.Service
|
||||
providerName string
|
||||
}{
|
||||
{
|
||||
desc: "Simple service name",
|
||||
serviceName: "serviceName",
|
||||
configs: map[string]*config.Service{
|
||||
"serviceName": {
|
||||
LoadBalancer: &config.LoadBalancerService{Method: "wrr"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Service name with provider",
|
||||
serviceName: "provider-1.serviceName",
|
||||
configs: map[string]*config.Service{
|
||||
"provider-1.serviceName": {
|
||||
LoadBalancer: &config.LoadBalancerService{Method: "wrr"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Service name with provider in context",
|
||||
serviceName: "serviceName",
|
||||
configs: map[string]*config.Service{
|
||||
"provider-1.serviceName": {
|
||||
LoadBalancer: &config.LoadBalancerService{Method: "wrr"},
|
||||
},
|
||||
},
|
||||
providerName: "provider-1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(test.configs, http.DefaultTransport)
|
||||
|
||||
ctx := context.Background()
|
||||
if len(test.providerName) > 0 {
|
||||
ctx = internal.AddProviderInContext(ctx, test.providerName+".foobar")
|
||||
}
|
||||
|
||||
_, err := manager.BuildHTTP(ctx, test.serviceName, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME Add healthcheck tests
|
67
pkg/server/service/tcp/service.go
Normal file
67
pkg/server/service/tcp/service.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/server/internal"
|
||||
"github.com/containous/traefik/pkg/tcp"
|
||||
)
|
||||
|
||||
// Manager is the TCPHandlers factory
|
||||
type Manager struct {
|
||||
configs map[string]*config.TCPService
|
||||
}
|
||||
|
||||
// NewManager creates a new manager
|
||||
func NewManager(configs map[string]*config.TCPService) *Manager {
|
||||
return &Manager{
|
||||
configs: configs,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildTCP Creates a tcp.Handler for a service configuration.
|
||||
func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Handler, error) {
|
||||
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
|
||||
|
||||
serviceName = internal.GetQualifiedName(ctx, serviceName)
|
||||
ctx = internal.AddProviderInContext(ctx, serviceName)
|
||||
|
||||
if conf, ok := m.configs[serviceName]; ok {
|
||||
// FIXME Check if the service is declared multiple times with different types
|
||||
if conf.LoadBalancer != nil {
|
||||
loadBalancer := tcp.NewRRLoadBalancer()
|
||||
|
||||
var handler tcp.Handler
|
||||
for _, server := range conf.LoadBalancer.Servers {
|
||||
_, err := parseIP(server.Address)
|
||||
if err == nil {
|
||||
handler, _ = tcp.NewProxy(server.Address)
|
||||
loadBalancer.AddServer(handler)
|
||||
} else {
|
||||
log.FromContext(ctx).Errorf("Invalid IP address for a %s server %s: %v", serviceName, server.Address, err)
|
||||
}
|
||||
}
|
||||
return loadBalancer, nil
|
||||
}
|
||||
return nil, fmt.Errorf("the service %q doesn't have any TCP load balancer", serviceName)
|
||||
}
|
||||
return nil, fmt.Errorf("the service %q does not exits", serviceName)
|
||||
}
|
||||
|
||||
func parseIP(s string) (string, error) {
|
||||
ip, _, err := net.SplitHostPort(s)
|
||||
if err == nil {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
ipNoPort := net.ParseIP(s)
|
||||
if ipNoPort == nil {
|
||||
return "", fmt.Errorf("invalid IP Address %s", ipNoPort)
|
||||
}
|
||||
|
||||
return ipNoPort.String(), nil
|
||||
}
|
14
pkg/server/uuid/uuid.go
Normal file
14
pkg/server/uuid/uuid.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package uuid
|
||||
|
||||
import guuid "github.com/satori/go.uuid"
|
||||
|
||||
var uuid string
|
||||
|
||||
func init() {
|
||||
uuid = guuid.NewV4().String()
|
||||
}
|
||||
|
||||
// Get the instance UUID
|
||||
func Get() string {
|
||||
return uuid
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue