Move code to pkg

This commit is contained in:
Ludovic Fernandez 2019-03-15 09:42:03 +01:00 committed by Traefiker Bot
parent bd4c822670
commit f1b085fa36
465 changed files with 656 additions and 680 deletions

57
pkg/server/aggregator.go Normal file
View file

@ -0,0 +1,57 @@
package server
import (
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/server/internal"
"github.com/containous/traefik/pkg/tls"
)
func mergeConfiguration(configurations config.Configurations) config.Configuration {
conf := config.Configuration{
HTTP: &config.HTTPConfiguration{
Routers: make(map[string]*config.Router),
Middlewares: make(map[string]*config.Middleware),
Services: make(map[string]*config.Service),
},
TCP: &config.TCPConfiguration{
Routers: make(map[string]*config.TCPRouter),
Services: make(map[string]*config.TCPService),
},
TLSOptions: make(map[string]tls.TLS),
TLSStores: make(map[string]tls.Store),
}
for provider, configuration := range configurations {
if configuration.HTTP != nil {
for routerName, router := range configuration.HTTP.Routers {
conf.HTTP.Routers[internal.MakeQualifiedName(provider, routerName)] = router
}
for middlewareName, middleware := range configuration.HTTP.Middlewares {
conf.HTTP.Middlewares[internal.MakeQualifiedName(provider, middlewareName)] = middleware
}
for serviceName, service := range configuration.HTTP.Services {
conf.HTTP.Services[internal.MakeQualifiedName(provider, serviceName)] = service
}
}
if configuration.TCP != nil {
for routerName, router := range configuration.TCP.Routers {
conf.TCP.Routers[internal.MakeQualifiedName(provider, routerName)] = router
}
for serviceName, service := range configuration.TCP.Services {
conf.TCP.Services[internal.MakeQualifiedName(provider, serviceName)] = service
}
}
conf.TLS = append(conf.TLS, configuration.TLS...)
for key, store := range configuration.TLSStores {
conf.TLSStores[key] = store
}
for key, config := range configuration.TLSOptions {
conf.TLSOptions[key] = config
}
}
return conf
}

View file

@ -0,0 +1,110 @@
package server
import (
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/stretchr/testify/assert"
)
func TestAggregator(t *testing.T) {
testCases := []struct {
desc string
given config.Configurations
expected *config.HTTPConfiguration
}{
{
desc: "Nil returns an empty configuration",
given: nil,
expected: &config.HTTPConfiguration{
Routers: make(map[string]*config.Router),
Middlewares: make(map[string]*config.Middleware),
Services: make(map[string]*config.Service),
},
},
{
desc: "Returns fully qualified elements from a mono-provider configuration map",
given: config.Configurations{
"provider-1": &config.Configuration{
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"router-1": {},
},
Middlewares: map[string]*config.Middleware{
"middleware-1": {},
},
Services: map[string]*config.Service{
"service-1": {},
},
},
},
},
expected: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"provider-1.router-1": {},
},
Middlewares: map[string]*config.Middleware{
"provider-1.middleware-1": {},
},
Services: map[string]*config.Service{
"provider-1.service-1": {},
},
},
},
{
desc: "Returns fully qualified elements from a multi-provider configuration map",
given: config.Configurations{
"provider-1": &config.Configuration{
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"router-1": {},
},
Middlewares: map[string]*config.Middleware{
"middleware-1": {},
},
Services: map[string]*config.Service{
"service-1": {},
},
},
},
"provider-2": &config.Configuration{
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"router-1": {},
},
Middlewares: map[string]*config.Middleware{
"middleware-1": {},
},
Services: map[string]*config.Service{
"service-1": {},
},
},
},
},
expected: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"provider-1.router-1": {},
"provider-2.router-1": {},
},
Middlewares: map[string]*config.Middleware{
"provider-1.middleware-1": {},
"provider-2.middleware-1": {},
},
Services: map[string]*config.Service{
"provider-1.service-1": {},
"provider-2.service-1": {},
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := mergeConfiguration(test.given)
assert.Equal(t, test.expected, actual.HTTP)
})
}
}

View file

@ -0,0 +1,57 @@
package cookie
import (
"crypto/sha1"
"fmt"
"strings"
"github.com/containous/traefik/pkg/log"
)
const cookieNameLength = 6
// GetName of a cookie
func GetName(cookieName string, backendName string) string {
if len(cookieName) != 0 {
return sanitizeName(cookieName)
}
return GenerateName(backendName)
}
// GenerateName Generate a hashed name
func GenerateName(backendName string) string {
data := []byte("_TRAEFIK_BACKEND_" + backendName)
hash := sha1.New()
_, err := hash.Write(data)
if err != nil {
// Impossible case
log.Errorf("Fail to create cookie name: %v", err)
}
return fmt.Sprintf("_%x", hash.Sum(nil))[:cookieNameLength]
}
// sanitizeName According to [RFC 2616](https://www.ietf.org/rfc/rfc2616.txt) section 2.2
func sanitizeName(backend string) string {
sanitizer := func(r rune) rune {
switch r {
case '!', '#', '$', '%', '&', '\'', '*', '+', '-', '.', '^', '`', '|', '~':
return r
}
switch {
case 'a' <= r && r <= 'z':
fallthrough
case 'A' <= r && r <= 'Z':
fallthrough
case '0' <= r && r <= '9':
return r
default:
return '_'
}
}
return strings.Map(sanitizer, backend)
}

View file

@ -0,0 +1,83 @@
package cookie
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGetName(t *testing.T) {
testCases := []struct {
desc string
cookieName string
backendName string
expectedCookieName string
}{
{
desc: "with backend name, without cookie name",
cookieName: "",
backendName: "/my/BACKEND-v1.0~rc1",
expectedCookieName: "_5f7bc",
},
{
desc: "without backend name, with cookie name",
cookieName: "/my/BACKEND-v1.0~rc1",
backendName: "",
expectedCookieName: "_my_BACKEND-v1.0~rc1",
},
{
desc: "with backend name, with cookie name",
cookieName: "containous",
backendName: "treafik",
expectedCookieName: "containous",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
cookieName := GetName(test.cookieName, test.backendName)
assert.Equal(t, test.expectedCookieName, cookieName)
})
}
}
func Test_sanitizeName(t *testing.T) {
testCases := []struct {
desc string
srcName string
expectedName string
}{
{
desc: "with /",
srcName: "/my/BACKEND-v1.0~rc1",
expectedName: "_my_BACKEND-v1.0~rc1",
},
{
desc: "some chars",
srcName: "!#$%&'()*+-./:<=>?@[]^_`{|}~",
expectedName: "!#$%&'__*+-._________^_`_|_~",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
cookieName := sanitizeName(test.srcName)
assert.Equal(t, test.expectedName, cookieName, "Cookie name")
})
}
}
func TestGenerateName(t *testing.T) {
cookieName := GenerateName("containous")
assert.Len(t, "_8a7bc", 6)
assert.Equal(t, "_8a7bc", cookieName)
}

View file

@ -0,0 +1,45 @@
package internal
import (
"context"
"strings"
"github.com/containous/traefik/pkg/log"
)
type contextKey int
const (
providerKey contextKey = iota
)
// AddProviderInContext Adds the provider name in the context
func AddProviderInContext(ctx context.Context, elementName string) context.Context {
parts := strings.Split(elementName, ".")
if len(parts) == 1 {
log.FromContext(ctx).Debugf("Could not find a provider for %s.", elementName)
return ctx
}
if name, ok := ctx.Value(providerKey).(string); ok && name == parts[0] {
return ctx
}
return context.WithValue(ctx, providerKey, parts[0])
}
// GetQualifiedName Gets the fully qualified name.
func GetQualifiedName(ctx context.Context, elementName string) string {
parts := strings.Split(elementName, ".")
if len(parts) == 1 {
if providerName, ok := ctx.Value(providerKey).(string); ok {
return MakeQualifiedName(providerName, parts[0])
}
}
return elementName
}
// MakeQualifiedName Creates a qualified name for an element
func MakeQualifiedName(providerName string, elementName string) string {
return providerName + "." + elementName
}

View file

@ -0,0 +1,343 @@
package middleware
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/containous/alice"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares/addprefix"
"github.com/containous/traefik/pkg/middlewares/auth"
"github.com/containous/traefik/pkg/middlewares/buffering"
"github.com/containous/traefik/pkg/middlewares/chain"
"github.com/containous/traefik/pkg/middlewares/circuitbreaker"
"github.com/containous/traefik/pkg/middlewares/compress"
"github.com/containous/traefik/pkg/middlewares/customerrors"
"github.com/containous/traefik/pkg/middlewares/headers"
"github.com/containous/traefik/pkg/middlewares/ipwhitelist"
"github.com/containous/traefik/pkg/middlewares/maxconnection"
"github.com/containous/traefik/pkg/middlewares/passtlsclientcert"
"github.com/containous/traefik/pkg/middlewares/ratelimiter"
"github.com/containous/traefik/pkg/middlewares/redirect"
"github.com/containous/traefik/pkg/middlewares/replacepath"
"github.com/containous/traefik/pkg/middlewares/replacepathregex"
"github.com/containous/traefik/pkg/middlewares/retry"
"github.com/containous/traefik/pkg/middlewares/stripprefix"
"github.com/containous/traefik/pkg/middlewares/stripprefixregex"
"github.com/containous/traefik/pkg/middlewares/tracing"
"github.com/containous/traefik/pkg/server/internal"
"github.com/pkg/errors"
)
type middlewareStackType int
const (
middlewareStackKey middlewareStackType = iota
)
// Builder the middleware builder
type Builder struct {
configs map[string]*config.Middleware
serviceBuilder serviceBuilder
}
type serviceBuilder interface {
BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
}
// NewBuilder creates a new Builder
func NewBuilder(configs map[string]*config.Middleware, serviceBuilder serviceBuilder) *Builder {
return &Builder{configs: configs, serviceBuilder: serviceBuilder}
}
// BuildChain creates a middleware chain
func (b *Builder) BuildChain(ctx context.Context, middlewares []string) *alice.Chain {
chain := alice.New()
for _, name := range middlewares {
middlewareName := internal.GetQualifiedName(ctx, name)
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
constructorContext := internal.AddProviderInContext(ctx, middlewareName)
if _, ok := b.configs[middlewareName]; !ok {
return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
}
var err error
if constructorContext, err = checkRecursivity(constructorContext, middlewareName); err != nil {
return nil, err
}
constructor, err := b.buildConstructor(constructorContext, middlewareName, *b.configs[middlewareName])
if err != nil {
return nil, fmt.Errorf("error during instanciation of %s: %v", middlewareName, err)
}
return constructor(next)
})
}
return &chain
}
func checkRecursivity(ctx context.Context, middlewareName string) (context.Context, error) {
currentStack, ok := ctx.Value(middlewareStackKey).([]string)
if !ok {
currentStack = []string{}
}
if inSlice(middlewareName, currentStack) {
return ctx, fmt.Errorf("could not instantiate middleware %s: recursion detected in %s", middlewareName, strings.Join(append(currentStack, middlewareName), "->"))
}
return context.WithValue(ctx, middlewareStackKey, append(currentStack, middlewareName)), nil
}
func (b *Builder) buildConstructor(ctx context.Context, middlewareName string, config config.Middleware) (alice.Constructor, error) {
var middleware alice.Constructor
badConf := errors.New("cannot create middleware %q: multi-types middleware not supported, consider declaring two different pieces of middleware instead")
// AddPrefix
if config.AddPrefix != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return addprefix.New(ctx, next, *config.AddPrefix, middlewareName)
}
} else {
return nil, badConf
}
}
// BasicAuth
if config.BasicAuth != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return auth.NewBasic(ctx, next, *config.BasicAuth, middlewareName)
}
} else {
return nil, badConf
}
}
// Buffering
if config.Buffering != nil && config.MaxConn.Amount != 0 {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return buffering.New(ctx, next, *config.Buffering, middlewareName)
}
} else {
return nil, badConf
}
}
// Chain
if config.Chain != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return chain.New(ctx, next, *config.Chain, b, middlewareName)
}
} else {
return nil, badConf
}
}
// CircuitBreaker
if config.CircuitBreaker != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return circuitbreaker.New(ctx, next, *config.CircuitBreaker, middlewareName)
}
} else {
return nil, badConf
}
}
// Compress
if config.Compress != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return compress.New(ctx, next, middlewareName)
}
} else {
return nil, badConf
}
}
// CustomErrors
if config.Errors != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return customerrors.New(ctx, next, *config.Errors, b.serviceBuilder, middlewareName)
}
} else {
return nil, badConf
}
}
// DigestAuth
if config.DigestAuth != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return auth.NewDigest(ctx, next, *config.DigestAuth, middlewareName)
}
} else {
return nil, badConf
}
}
// ForwardAuth
if config.ForwardAuth != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return auth.NewForward(ctx, next, *config.ForwardAuth, middlewareName)
}
} else {
return nil, badConf
}
}
// Headers
if config.Headers != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return headers.New(ctx, next, *config.Headers, middlewareName)
}
} else {
return nil, badConf
}
}
// IPWhiteList
if config.IPWhiteList != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return ipwhitelist.New(ctx, next, *config.IPWhiteList, middlewareName)
}
} else {
return nil, badConf
}
}
// MaxConn
if config.MaxConn != nil && config.MaxConn.Amount != 0 {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return maxconnection.New(ctx, next, *config.MaxConn, middlewareName)
}
} else {
return nil, badConf
}
}
// PassTLSClientCert
if config.PassTLSClientCert != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return passtlsclientcert.New(ctx, next, *config.PassTLSClientCert, middlewareName)
}
} else {
return nil, badConf
}
}
// RateLimit
if config.RateLimit != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return ratelimiter.New(ctx, next, *config.RateLimit, middlewareName)
}
} else {
return nil, badConf
}
}
// RedirectRegex
if config.RedirectRegex != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return redirect.NewRedirectRegex(ctx, next, *config.RedirectRegex, middlewareName)
}
} else {
return nil, badConf
}
}
// RedirectScheme
if config.RedirectScheme != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return redirect.NewRedirectScheme(ctx, next, *config.RedirectScheme, middlewareName)
}
} else {
return nil, badConf
}
}
// ReplacePath
if config.ReplacePath != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return replacepath.New(ctx, next, *config.ReplacePath, middlewareName)
}
} else {
return nil, badConf
}
}
// ReplacePathRegex
if config.ReplacePathRegex != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return replacepathregex.New(ctx, next, *config.ReplacePathRegex, middlewareName)
}
} else {
return nil, badConf
}
}
// Retry
if config.Retry != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
// FIXME missing metrics / accessLog
return retry.New(ctx, next, *config.Retry, retry.Listeners{}, middlewareName)
}
} else {
return nil, badConf
}
}
// StripPrefix
if config.StripPrefix != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return stripprefix.New(ctx, next, *config.StripPrefix, middlewareName)
}
} else {
return nil, badConf
}
}
// StripPrefixRegex
if config.StripPrefixRegex != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return stripprefixregex.New(ctx, next, *config.StripPrefixRegex, middlewareName)
}
} else {
return nil, badConf
}
}
if middleware == nil {
return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
}
return tracing.Wrap(ctx, middleware), nil
}
func inSlice(element string, stack []string) bool {
for _, value := range stack {
if value == element {
return true
}
}
return false
}

View file

@ -0,0 +1,391 @@
package middleware
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/server/internal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBuilder_buildConstructorCircuitBreaker(t *testing.T) {
testConfig := map[string]*config.Middleware{
"empty": {
CircuitBreaker: &config.CircuitBreaker{
Expression: "",
},
},
"foo": {
CircuitBreaker: &config.CircuitBreaker{
Expression: "NetworkErrorRatio() > 0.5",
},
},
}
middlewaresBuilder := NewBuilder(testConfig, nil)
emptyHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
testCases := []struct {
desc string
middlewareID string
expectedError bool
}{
{
desc: "Should fail at creating a circuit breaker with an empty expression",
expectedError: true,
middlewareID: "empty",
}, {
desc: "Should create a circuit breaker with a valid expression",
expectedError: false,
middlewareID: "foo",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
constructor, err := middlewaresBuilder.buildConstructor(context.Background(), test.middlewareID, *testConfig[test.middlewareID])
require.NoError(t, err)
middleware, err2 := constructor(emptyHandler)
if test.expectedError {
require.Error(t, err2)
} else {
require.NoError(t, err)
require.NotNil(t, middleware)
}
})
}
}
func TestBuilder_BuildChainNilConfig(t *testing.T) {
testConfig := map[string]*config.Middleware{
"empty": {},
}
middlewaresBuilder := NewBuilder(testConfig, nil)
chain := middlewaresBuilder.BuildChain(context.Background(), []string{"empty"})
_, err := chain.Then(nil)
require.Error(t, err)
}
func TestBuilder_BuildChainNonExistentChain(t *testing.T) {
testConfig := map[string]*config.Middleware{
"foobar": {},
}
middlewaresBuilder := NewBuilder(testConfig, nil)
chain := middlewaresBuilder.BuildChain(context.Background(), []string{"empty"})
_, err := chain.Then(nil)
require.Error(t, err)
}
func TestBuilder_buildConstructorAddPrefix(t *testing.T) {
testConfig := map[string]*config.Middleware{
"empty": {
AddPrefix: &config.AddPrefix{
Prefix: "",
},
},
"foo": {
AddPrefix: &config.AddPrefix{
Prefix: "foo/",
},
},
}
middlewaresBuilder := NewBuilder(testConfig, nil)
testCases := []struct {
desc string
middlewareID string
expectedError bool
}{
{
desc: "Should not create an empty AddPrefix middleware when given an empty prefix",
middlewareID: "empty",
expectedError: true,
}, {
desc: "Should create an AddPrefix middleware when given a valid configuration",
middlewareID: "foo",
expectedError: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
constructor, err := middlewaresBuilder.buildConstructor(context.Background(), test.middlewareID, *testConfig[test.middlewareID])
require.NoError(t, err)
middleware, err2 := constructor(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}))
if test.expectedError {
require.Error(t, err2)
} else {
require.NoError(t, err)
require.NotNil(t, middleware)
}
})
}
}
func TestBuild_BuildChainWithContext(t *testing.T) {
testCases := []struct {
desc string
buildChain []string
configuration map[string]*config.Middleware
expected map[string]string
contextProvider string
expectedError error
}{
{
desc: "Simple middleware",
buildChain: []string{"middleware-1"},
configuration: map[string]*config.Middleware{
"middleware-1": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"middleware-1": "value-middleware-1"},
},
},
},
expected: map[string]string{"middleware-1": "value-middleware-1"},
},
{
desc: "Middleware that references a chain",
buildChain: []string{"middleware-chain-1"},
configuration: map[string]*config.Middleware{
"middleware-1": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"middleware-1": "value-middleware-1"},
},
},
"middleware-chain-1": {
Chain: &config.Chain{
Middlewares: []string{"middleware-1"},
},
},
},
expected: map[string]string{"middleware-1": "value-middleware-1"},
},
{
desc: "Should prefix the middlewareName with the provider in the context",
buildChain: []string{"middleware-1"},
configuration: map[string]*config.Middleware{
"provider-1.middleware-1": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"provider-1.middleware-1": "value-middleware-1"},
},
},
},
expected: map[string]string{"provider-1.middleware-1": "value-middleware-1"},
contextProvider: "provider-1",
},
{
desc: "Should not prefix a qualified middlewareName with the provider in the context",
buildChain: []string{"provider-1.middleware-1"},
configuration: map[string]*config.Middleware{
"provider-1.middleware-1": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"provider-1.middleware-1": "value-middleware-1"},
},
},
},
expected: map[string]string{"provider-1.middleware-1": "value-middleware-1"},
contextProvider: "provider-1",
},
{
desc: "Should be context aware if a chain references another middleware",
buildChain: []string{"provider-1.middleware-chain-1"},
configuration: map[string]*config.Middleware{
"provider-1.middleware-1": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"middleware-1": "value-middleware-1"},
},
},
"provider-1.middleware-chain-1": {
Chain: &config.Chain{
Middlewares: []string{"middleware-1"},
},
},
},
expected: map[string]string{"middleware-1": "value-middleware-1"},
},
{
desc: "Should handle nested chains with different context",
buildChain: []string{"provider-1.middleware-chain-1", "middleware-chain-1"},
configuration: map[string]*config.Middleware{
"provider-1.middleware-1": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"middleware-1": "value-middleware-1"},
},
},
"provider-1.middleware-2": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"middleware-2": "value-middleware-2"},
},
},
"provider-1.middleware-chain-1": {
Chain: &config.Chain{
Middlewares: []string{"middleware-1"},
},
},
"provider-1.middleware-chain-2": {
Chain: &config.Chain{
Middlewares: []string{"middleware-2"},
},
},
"provider-2.middleware-chain-1": {
Chain: &config.Chain{
Middlewares: []string{"provider-1.middleware-2", "provider-1.middleware-chain-2"},
},
},
},
expected: map[string]string{"middleware-1": "value-middleware-1", "middleware-2": "value-middleware-2"},
contextProvider: "provider-2",
},
{
desc: "Detects recursion in Middleware chain",
buildChain: []string{"m1"},
configuration: map[string]*config.Middleware{
"ok": {
Retry: &config.Retry{},
},
"m1": {
Chain: &config.Chain{
Middlewares: []string{"m2"},
},
},
"m2": {
Chain: &config.Chain{
Middlewares: []string{"ok", "m3"},
},
},
"m3": {
Chain: &config.Chain{
Middlewares: []string{"m1"},
},
},
},
expectedError: errors.New("could not instantiate middleware m1: recursion detected in m1->m2->m3->m1"),
},
{
desc: "Detects recursion in Middleware chain",
buildChain: []string{"provider.m1"},
configuration: map[string]*config.Middleware{
"provider2.ok": {
Retry: &config.Retry{},
},
"provider.m1": {
Chain: &config.Chain{
Middlewares: []string{"provider2.m2"},
},
},
"provider2.m2": {
Chain: &config.Chain{
Middlewares: []string{"ok", "provider.m3"},
},
},
"provider.m3": {
Chain: &config.Chain{
Middlewares: []string{"m1"},
},
},
},
expectedError: errors.New("could not instantiate middleware provider.m1: recursion detected in provider.m1->provider2.m2->provider.m3->provider.m1"),
},
{
buildChain: []string{"ok", "m0"},
configuration: map[string]*config.Middleware{
"ok": {
Retry: &config.Retry{},
},
"m0": {
Chain: &config.Chain{
Middlewares: []string{"m0"},
},
},
},
expectedError: errors.New("could not instantiate middleware m0: recursion detected in m0->m0"),
},
{
desc: "Detects MiddlewareChain that references a Chain that references a Chain with a missing middleware",
buildChain: []string{"m0"},
configuration: map[string]*config.Middleware{
"m0": {
Chain: &config.Chain{
Middlewares: []string{"m1"},
},
},
"m1": {
Chain: &config.Chain{
Middlewares: []string{"m2"},
},
},
"m2": {
Chain: &config.Chain{
Middlewares: []string{"m3"},
},
},
"m3": {
Chain: &config.Chain{
Middlewares: []string{"m2"},
},
},
},
expectedError: errors.New("could not instantiate middleware m2: recursion detected in m0->m1->m2->m3->m2"),
},
{
desc: "--",
buildChain: []string{"m0"},
configuration: map[string]*config.Middleware{
"m0": {
Chain: &config.Chain{
Middlewares: []string{"m0"},
},
},
},
expectedError: errors.New("could not instantiate middleware m0: recursion detected in m0->m0"),
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
ctx := context.Background()
if len(test.contextProvider) > 0 {
ctx = internal.AddProviderInContext(ctx, test.contextProvider+".foobar")
}
builder := NewBuilder(test.configuration, nil)
result := builder.BuildChain(ctx, test.buildChain)
handlers, err := result.Then(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }))
if test.expectedError != nil {
require.NotNil(t, err)
require.Equal(t, test.expectedError.Error(), err.Error())
} else {
require.NoError(t, err)
recorder := httptest.NewRecorder()
request, _ := http.NewRequest(http.MethodGet, "http://foo/", nil)
handlers.ServeHTTP(recorder, request)
for key, value := range test.expected {
assert.Equal(t, value, request.Header.Get(key))
}
}
})
}
}

100
pkg/server/roundtripper.go Normal file
View file

@ -0,0 +1,100 @@
package server
import (
"crypto/tls"
"crypto/x509"
"net"
"net/http"
"time"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/pkg/config/static"
"github.com/containous/traefik/pkg/log"
traefiktls "github.com/containous/traefik/pkg/tls"
"github.com/pkg/errors"
"golang.org/x/net/http2"
)
type h2cTransportWrapper struct {
*http2.Transport
}
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "http"
return t.Transport.RoundTrip(req)
}
// createHTTPTransport creates an http.Transport configured with the Transport configuration settings.
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost
// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing
// behavior and backwards compatibility issues.
func createHTTPTransport(transportConfiguration *static.ServersTransport) (*http.Transport, error) {
if transportConfiguration == nil {
return nil, errors.New("no transport configuration given")
}
dialer := &net.Dialer{
Timeout: configuration.DefaultDialTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
if transportConfiguration.ForwardingTimeouts != nil {
dialer.Timeout = time.Duration(transportConfiguration.ForwardingTimeouts.DialTimeout)
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
MaxIdleConnsPerHost: transportConfiguration.MaxIdleConnsPerHost,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
transport.RegisterProtocol("h2c", &h2cTransportWrapper{
Transport: &http2.Transport{
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
return net.Dial(netw, addr)
},
AllowHTTP: true,
},
})
if transportConfiguration.ForwardingTimeouts != nil {
transport.ResponseHeaderTimeout = time.Duration(transportConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
}
if transportConfiguration.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
if len(transportConfiguration.RootCAs) > 0 {
transport.TLSClientConfig = &tls.Config{
RootCAs: createRootCACertPool(transportConfiguration.RootCAs),
}
}
err := http2.ConfigureTransport(transport)
if err != nil {
return nil, err
}
return transport, nil
}
func createRootCACertPool(rootCAs traefiktls.FilesOrContents) *x509.CertPool {
roots := x509.NewCertPool()
for _, cert := range rootCAs {
certContent, err := cert.Read()
if err != nil {
log.WithoutContext().Error("Error while read RootCAs", err)
continue
}
roots.AppendCertsFromPEM(certContent)
}
return roots
}

View file

@ -0,0 +1,109 @@
package router
import (
"context"
"github.com/containous/alice"
"github.com/containous/mux"
"github.com/containous/traefik/pkg/api"
"github.com/containous/traefik/pkg/config/static"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/metrics"
"github.com/containous/traefik/pkg/safe"
"github.com/containous/traefik/pkg/types"
)
// chainBuilder The contract of the middleware builder
type chainBuilder interface {
BuildChain(ctx context.Context, middlewares []string) *alice.Chain
}
// NewRouteAppenderAggregator Creates a new RouteAppenderAggregator
func NewRouteAppenderAggregator(ctx context.Context, chainBuilder chainBuilder, conf static.Configuration, entryPointName string, currentConfiguration *safe.Safe) *RouteAppenderAggregator {
aggregator := &RouteAppenderAggregator{}
if conf.Providers != nil && conf.Providers.Rest != nil {
aggregator.AddAppender(conf.Providers.Rest)
}
if conf.API != nil && conf.API.EntryPoint == entryPointName {
chain := chainBuilder.BuildChain(ctx, conf.API.Middlewares)
aggregator.AddAppender(&WithMiddleware{
appender: api.Handler{
EntryPoint: conf.API.EntryPoint,
Dashboard: conf.API.Dashboard,
Statistics: conf.API.Statistics,
DashboardAssets: conf.API.DashboardAssets,
CurrentConfigurations: currentConfiguration,
Debug: conf.Global.Debug,
},
routerMiddlewares: chain,
})
}
if conf.Ping != nil && conf.Ping.EntryPoint == entryPointName {
chain := chainBuilder.BuildChain(ctx, conf.Ping.Middlewares)
aggregator.AddAppender(&WithMiddleware{
appender: conf.Ping,
routerMiddlewares: chain,
})
}
if conf.Metrics != nil && conf.Metrics.Prometheus != nil && conf.Metrics.Prometheus.EntryPoint == entryPointName {
chain := chainBuilder.BuildChain(ctx, conf.Metrics.Prometheus.Middlewares)
aggregator.AddAppender(&WithMiddleware{
appender: metrics.PrometheusHandler{},
routerMiddlewares: chain,
})
}
return aggregator
}
// RouteAppenderAggregator RouteAppender that aggregate other RouteAppender
type RouteAppenderAggregator struct {
appenders []types.RouteAppender
}
// Append Adds routes to the router
func (r *RouteAppenderAggregator) Append(systemRouter *mux.Router) {
for _, router := range r.appenders {
router.Append(systemRouter)
}
}
// AddAppender adds a router in the aggregator
func (r *RouteAppenderAggregator) AddAppender(router types.RouteAppender) {
r.appenders = append(r.appenders, router)
}
// WithMiddleware router with internal middleware
type WithMiddleware struct {
appender types.RouteAppender
routerMiddlewares *alice.Chain
}
// Append Adds routes to the router
func (wm *WithMiddleware) Append(systemRouter *mux.Router) {
realRouter := systemRouter.PathPrefix("/").Subrouter()
wm.appender.Append(realRouter)
if err := realRouter.Walk(wrapRoute(wm.routerMiddlewares)); err != nil {
log.WithoutContext().Error(err)
}
}
// wrapRoute with middlewares
func wrapRoute(middlewares *alice.Chain) func(*mux.Route, *mux.Router, []*mux.Route) error {
return func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
handler, err := middlewares.Then(route.GetHandler())
if err != nil {
return err
}
route.Handler(handler)
return nil
}
}

View file

@ -0,0 +1,114 @@
package router
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/alice"
"github.com/containous/mux"
"github.com/containous/traefik/pkg/config/static"
"github.com/containous/traefik/pkg/ping"
"github.com/stretchr/testify/assert"
)
type ChainBuilderMock struct {
middles map[string]alice.Constructor
}
func (c *ChainBuilderMock) BuildChain(ctx context.Context, middles []string) *alice.Chain {
chain := alice.New()
for _, mName := range middles {
if constructor, ok := c.middles[mName]; ok {
chain = chain.Append(constructor)
}
}
return &chain
}
func TestNewRouteAppenderAggregator(t *testing.T) {
testCases := []struct {
desc string
staticConf static.Configuration
middles map[string]alice.Constructor
expected map[string]int
}{
{
desc: "API with auth, ping without auth",
staticConf: static.Configuration{
Global: &static.Global{},
API: &static.API{
EntryPoint: "traefik",
Middlewares: []string{"dumb"},
},
Ping: &ping.Handler{
EntryPoint: "traefik",
},
EntryPoints: static.EntryPoints{
"traefik": {},
},
},
middles: map[string]alice.Constructor{
"dumb": func(_ http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}), nil
},
},
expected: map[string]int{
"/wrong": http.StatusBadGateway,
"/ping": http.StatusOK,
// "/.well-known/acme-challenge/token": http.StatusNotFound, // FIXME
"/api/providers": http.StatusUnauthorized,
},
},
{
desc: "Wrong entrypoint name",
staticConf: static.Configuration{
Global: &static.Global{},
API: &static.API{
EntryPoint: "no",
},
EntryPoints: static.EntryPoints{
"traefik": {},
},
},
expected: map[string]int{
"/api/providers": http.StatusBadGateway,
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
chainBuilder := &ChainBuilderMock{middles: test.middles}
ctx := context.Background()
router := NewRouteAppenderAggregator(ctx, chainBuilder, test.staticConf, "traefik", nil)
internalMuxRouter := mux.NewRouter()
router.Append(internalMuxRouter)
internalMuxRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
})
actual := make(map[string]int)
for calledURL := range test.expected {
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, calledURL, nil)
internalMuxRouter.ServeHTTP(recorder, request)
actual[calledURL] = recorder.Code
}
assert.Equal(t, test.expected, actual)
})
}
}

View file

@ -0,0 +1,38 @@
package router
import (
"context"
"github.com/containous/traefik/pkg/config/static"
"github.com/containous/traefik/pkg/provider/acme"
"github.com/containous/traefik/pkg/safe"
"github.com/containous/traefik/pkg/server/middleware"
"github.com/containous/traefik/pkg/types"
)
// NewRouteAppenderFactory Creates a new RouteAppenderFactory
func NewRouteAppenderFactory(staticConfiguration static.Configuration, entryPointName string, acmeProvider *acme.Provider) *RouteAppenderFactory {
return &RouteAppenderFactory{
staticConfiguration: staticConfiguration,
entryPointName: entryPointName,
acmeProvider: acmeProvider,
}
}
// RouteAppenderFactory A factory of RouteAppender
type RouteAppenderFactory struct {
staticConfiguration static.Configuration
entryPointName string
acmeProvider *acme.Provider
}
// NewAppender Creates a new RouteAppender
func (r *RouteAppenderFactory) NewAppender(ctx context.Context, middlewaresBuilder *middleware.Builder, currentConfiguration *safe.Safe) types.RouteAppender {
aggregator := NewRouteAppenderAggregator(ctx, middlewaresBuilder, r.staticConfiguration, r.entryPointName, currentConfiguration)
if r.acmeProvider != nil && r.acmeProvider.HTTPChallenge != nil && r.acmeProvider.HTTPChallenge.EntryPoint == r.entryPointName {
aggregator.AddAppender(r.acmeProvider)
}
return aggregator
}

195
pkg/server/router/router.go Normal file
View file

@ -0,0 +1,195 @@
package router
import (
"context"
"fmt"
"net/http"
"github.com/containous/alice"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/middlewares/recovery"
"github.com/containous/traefik/pkg/middlewares/tracing"
"github.com/containous/traefik/pkg/responsemodifiers"
"github.com/containous/traefik/pkg/rules"
"github.com/containous/traefik/pkg/server/internal"
"github.com/containous/traefik/pkg/server/middleware"
"github.com/containous/traefik/pkg/server/service"
)
const (
recoveryMiddlewareName = "traefik-internal-recovery"
)
// NewManager Creates a new Manager
func NewManager(routers map[string]*config.Router,
serviceManager *service.Manager, middlewaresBuilder *middleware.Builder, modifierBuilder *responsemodifiers.Builder,
) *Manager {
return &Manager{
routerHandlers: make(map[string]http.Handler),
configs: routers,
serviceManager: serviceManager,
middlewaresBuilder: middlewaresBuilder,
modifierBuilder: modifierBuilder,
}
}
// Manager A route/router manager
type Manager struct {
routerHandlers map[string]http.Handler
configs map[string]*config.Router
serviceManager *service.Manager
middlewaresBuilder *middleware.Builder
modifierBuilder *responsemodifiers.Builder
}
// BuildHandlers Builds handler for all entry points
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string, tls bool) map[string]http.Handler {
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints, tls)
entryPointHandlers := make(map[string]http.Handler)
for entryPointName, routers := range entryPointsRouters {
entryPointName := entryPointName
ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName))
handler, err := m.buildEntryPointHandler(ctx, routers)
if err != nil {
log.FromContext(ctx).Error(err)
continue
}
handlerWithAccessLog, err := alice.New(func(next http.Handler) (http.Handler, error) {
return accesslog.NewFieldHandler(next, log.EntryPointName, entryPointName, accesslog.AddOriginFields), nil
}).Then(handler)
if err != nil {
log.FromContext(ctx).Error(err)
entryPointHandlers[entryPointName] = handler
} else {
entryPointHandlers[entryPointName] = handlerWithAccessLog
}
}
m.serviceManager.LaunchHealthCheck()
return entryPointHandlers
}
func contains(entryPoints []string, entryPointName string) bool {
for _, name := range entryPoints {
if name == entryPointName {
return true
}
}
return false
}
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string, tls bool) map[string]map[string]*config.Router {
entryPointsRouters := make(map[string]map[string]*config.Router)
for rtName, rt := range m.configs {
if (tls && rt.TLS == nil) || (!tls && rt.TLS != nil) {
continue
}
eps := rt.EntryPoints
if len(eps) == 0 {
eps = entryPoints
}
for _, entryPointName := range eps {
if !contains(entryPoints, entryPointName) {
log.FromContext(log.With(ctx, log.Str(log.EntryPointName, entryPointName))).
Errorf("entryPoint %q doesn't exist", entryPointName)
continue
}
if _, ok := entryPointsRouters[entryPointName]; !ok {
entryPointsRouters[entryPointName] = make(map[string]*config.Router)
}
entryPointsRouters[entryPointName][rtName] = rt
}
}
return entryPointsRouters
}
func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*config.Router) (http.Handler, error) {
router, err := rules.NewRouter()
if err != nil {
return nil, err
}
for routerName, routerConfig := range configs {
ctxRouter := log.With(ctx, log.Str(log.RouterName, routerName))
logger := log.FromContext(ctxRouter)
ctxRouter = internal.AddProviderInContext(ctxRouter, routerName)
handler, err := m.buildRouterHandler(ctxRouter, routerName)
if err != nil {
logger.Error(err)
continue
}
err = router.AddRoute(routerConfig.Rule, routerConfig.Priority, handler)
if err != nil {
logger.Error(err)
continue
}
}
router.SortRoutes()
chain := alice.New()
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
return recovery.New(ctx, next, recoveryMiddlewareName)
})
return chain.Then(router)
}
func (m *Manager) buildRouterHandler(ctx context.Context, routerName string) (http.Handler, error) {
if handler, ok := m.routerHandlers[routerName]; ok {
return handler, nil
}
configRouter, ok := m.configs[routerName]
if !ok {
return nil, fmt.Errorf("no configuration for %s", routerName)
}
handler, err := m.buildHTTPHandler(ctx, configRouter, routerName)
if err != nil {
return nil, err
}
handlerWithAccessLog, err := alice.New(func(next http.Handler) (http.Handler, error) {
return accesslog.NewFieldHandler(next, accesslog.RouterName, routerName, nil), nil
}).Then(handler)
if err != nil {
log.FromContext(ctx).Error(err)
m.routerHandlers[routerName] = handler
} else {
m.routerHandlers[routerName] = handlerWithAccessLog
}
return m.routerHandlers[routerName], nil
}
func (m *Manager) buildHTTPHandler(ctx context.Context, router *config.Router, routerName string) (http.Handler, error) {
rm := m.modifierBuilder.Build(ctx, router.Middlewares)
sHandler, err := m.serviceManager.BuildHTTP(ctx, router.Service, rm)
if err != nil {
return nil, err
}
mHandler := m.middlewaresBuilder.BuildChain(ctx, router.Middlewares)
tHandler := func(next http.Handler) (http.Handler, error) {
return tracing.NewForwarder(ctx, routerName, router.Service, next), nil
}
return alice.New().Extend(*mHandler).Append(tHandler).Then(sHandler)
}

View file

@ -0,0 +1,531 @@
package router
import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/middlewares/requestdecorator"
"github.com/containous/traefik/pkg/responsemodifiers"
"github.com/containous/traefik/pkg/server/middleware"
"github.com/containous/traefik/pkg/server/service"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/containous/traefik/pkg/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRouterManager_Get(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
type ExpectedResult struct {
StatusCode int
RequestHeaders map[string]string
}
testCases := []struct {
desc string
routersConfig map[string]*config.Router
serviceConfig map[string]*config.Service
middlewaresConfig map[string]*config.Middleware
entryPoints []string
expected ExpectedResult
}{
{
desc: "no middleware",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`foo.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusOK},
},
{
desc: "no load balancer",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`foo.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {},
},
entryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusNotFound},
},
{
desc: "no middleware, default entry point",
routersConfig: map[string]*config.Router{
"foo": {
Service: "foo-service",
Rule: "Host(`foo.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusOK},
},
{
desc: "no middleware, no matching",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`bar.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusNotFound},
},
{
desc: "middleware: headers > auth",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Middlewares: []string{"headers-middle", "auth-middle"},
Service: "foo-service",
Rule: "Host(`foo.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
middlewaresConfig: map[string]*config.Middleware{
"auth-middle": {
BasicAuth: &config.BasicAuth{
Users: []string{"toto:titi"},
},
},
"headers-middle": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{
StatusCode: http.StatusUnauthorized,
RequestHeaders: map[string]string{
"X-Apero": "beer",
},
},
},
{
desc: "middleware: auth > header",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Middlewares: []string{"auth-middle", "headers-middle"},
Service: "foo-service",
Rule: "Host(`foo.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
middlewaresConfig: map[string]*config.Middleware{
"auth-middle": {
BasicAuth: &config.BasicAuth{
Users: []string{"toto:titi"},
},
},
"headers-middle": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{
StatusCode: http.StatusUnauthorized,
RequestHeaders: map[string]string{
"X-Apero": "",
},
},
},
{
desc: "no middleware with provider name",
routersConfig: map[string]*config.Router{
"provider-1.foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`foo.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"provider-1.foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusOK},
},
{
desc: "no middleware with specified provider name",
routersConfig: map[string]*config.Router{
"provider-1.foo": {
EntryPoints: []string{"web"},
Service: "provider-2.foo-service",
Rule: "Host(`foo.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"provider-2.foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusOK},
},
{
desc: "middleware: chain with provider name",
routersConfig: map[string]*config.Router{
"provider-1.foo": {
EntryPoints: []string{"web"},
Middlewares: []string{"provider-2.chain-middle", "headers-middle"},
Service: "foo-service",
Rule: "Host(`foo.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"provider-1.foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
middlewaresConfig: map[string]*config.Middleware{
"provider-2.chain-middle": {
Chain: &config.Chain{Middlewares: []string{"auth-middle"}},
},
"provider-2.auth-middle": {
BasicAuth: &config.BasicAuth{
Users: []string{"toto:titi"},
},
},
"provider-1.headers-middle": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{
StatusCode: http.StatusUnauthorized,
RequestHeaders: map[string]string{
"X-Apero": "",
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
serviceManager := service.NewManager(test.serviceConfig, http.DefaultTransport)
middlewaresBuilder := middleware.NewBuilder(test.middlewaresConfig, serviceManager)
responseModifierFactory := responsemodifiers.NewBuilder(test.middlewaresConfig)
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
reqHost := requestdecorator.New(nil)
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
assert.Equal(t, test.expected.StatusCode, w.Code)
for key, value := range test.expected.RequestHeaders {
assert.Equal(t, value, req.Header.Get(key))
}
})
}
}
func TestAccessLog(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
testCases := []struct {
desc string
routersConfig map[string]*config.Router
serviceConfig map[string]*config.Service
middlewaresConfig map[string]*config.Middleware
entryPoints []string
expected string
}{
{
desc: "apply routerName in accesslog (first match)",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`foo.bar`)",
},
"bar": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`bar.foo`)",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: "foo",
},
{
desc: "apply routerName in accesslog (second match)",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`bar.foo`)",
},
"bar": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`foo.bar`)",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: "bar",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
serviceManager := service.NewManager(test.serviceConfig, http.DefaultTransport)
middlewaresBuilder := middleware.NewBuilder(test.middlewaresConfig, serviceManager)
responseModifierFactory := responsemodifiers.NewBuilder(test.middlewaresConfig)
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
accesslogger, err := accesslog.NewHandler(&types.AccessLog{
Format: "json",
})
require.NoError(t, err)
reqHost := requestdecorator.New(nil)
accesslogger.ServeHTTP(w, req, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
data := accesslog.GetLogData(req)
require.NotNil(t, data)
assert.Equal(t, test.expected, data.Core[accesslog.RouterName])
}))
})
}
}
type staticTransport struct {
res *http.Response
}
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return t.res, nil
}
func BenchmarkRouterServe(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
res := &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("")),
}
routersConfig := map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`foo.bar`) && Path(`/`)",
},
}
serviceConfig := map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
}
entryPoints := []string{"web"}
serviceManager := service.NewManager(serviceConfig, &staticTransport{res})
middlewaresBuilder := middleware.NewBuilder(map[string]*config.Middleware{}, serviceManager)
responseModifierFactory := responsemodifiers.NewBuilder(map[string]*config.Middleware{})
routerManager := NewManager(routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), entryPoints, false)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
reqHost := requestdecorator.New(nil)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
}
}
func BenchmarkService(b *testing.B) {
res := &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("")),
}
serviceConfig := map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: "tchouck",
Weight: 1,
},
},
Method: "wrr",
},
},
}
serviceManager := service.NewManager(serviceConfig, &staticTransport{res})
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
handler, _ := serviceManager.BuildHTTP(context.Background(), "foo-service", nil)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
handler.ServeHTTP(w, req)
}
}

View file

@ -0,0 +1,139 @@
package tcp
import (
"context"
"crypto/tls"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/rules"
"github.com/containous/traefik/pkg/server/internal"
tcpservice "github.com/containous/traefik/pkg/server/service/tcp"
"github.com/containous/traefik/pkg/tcp"
)
// NewManager Creates a new Manager
func NewManager(routers map[string]*config.TCPRouter,
serviceManager *tcpservice.Manager,
httpHandlers map[string]http.Handler,
httpsHandlers map[string]http.Handler,
tlsConfig *tls.Config,
) *Manager {
return &Manager{
configs: routers,
serviceManager: serviceManager,
httpHandlers: httpHandlers,
httpsHandlers: httpsHandlers,
tlsConfig: tlsConfig,
}
}
// Manager is a route/router manager
type Manager struct {
configs map[string]*config.TCPRouter
serviceManager *tcpservice.Manager
httpHandlers map[string]http.Handler
httpsHandlers map[string]http.Handler
tlsConfig *tls.Config
}
// BuildHandlers builds the handlers for the given entrypoints
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string) map[string]*tcp.Router {
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints)
entryPointHandlers := make(map[string]*tcp.Router)
for _, entryPointName := range entryPoints {
entryPointName := entryPointName
routers := entryPointsRouters[entryPointName]
ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName))
handler, err := m.buildEntryPointHandler(ctx, routers, m.httpHandlers[entryPointName], m.httpsHandlers[entryPointName])
if err != nil {
log.FromContext(ctx).Error(err)
continue
}
entryPointHandlers[entryPointName] = handler
}
return entryPointHandlers
}
func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*config.TCPRouter, handlerHTTP http.Handler, handlerHTTPS http.Handler) (*tcp.Router, error) {
router := &tcp.Router{}
router.HTTPHandler(handlerHTTP)
router.HTTPSHandler(handlerHTTPS, m.tlsConfig)
for routerName, routerConfig := range configs {
ctxRouter := log.With(ctx, log.Str(log.RouterName, routerName))
logger := log.FromContext(ctxRouter)
ctxRouter = internal.AddProviderInContext(ctxRouter, routerName)
handler, err := m.serviceManager.BuildTCP(ctxRouter, routerConfig.Service)
if err != nil {
logger.Error(err)
continue
}
domains, err := rules.ParseHostSNI(routerConfig.Rule)
if err != nil {
log.WithoutContext().Debugf("Unknown rule %s", routerConfig.Rule)
continue
}
for _, domain := range domains {
log.WithoutContext().Debugf("Add route %s on TCP", domain)
switch {
case routerConfig.TLS != nil:
if routerConfig.TLS.Passthrough {
router.AddRoute(domain, handler)
} else {
router.AddRouteTLS(domain, handler, m.tlsConfig)
}
case domain == "*":
router.AddCatchAllNoTLS(handler)
default:
logger.Warn("TCP Router ignored, cannot specify a Host rule without TLS")
}
}
}
return router, nil
}
func contains(entryPoints []string, entryPointName string) bool {
for _, name := range entryPoints {
if name == entryPointName {
return true
}
}
return false
}
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string) map[string]map[string]*config.TCPRouter {
entryPointsRouters := make(map[string]map[string]*config.TCPRouter)
for rtName, rt := range m.configs {
eps := rt.EntryPoints
if len(eps) == 0 {
eps = entryPoints
}
for _, entryPointName := range eps {
if !contains(entryPoints, entryPointName) {
log.FromContext(log.With(ctx, log.Str(log.EntryPointName, entryPointName))).
Errorf("entryPoint %q doesn't exist", entryPointName)
continue
}
if _, ok := entryPointsRouters[entryPointName]; !ok {
entryPointsRouters[entryPointName] = make(map[string]*config.TCPRouter)
}
entryPointsRouters[entryPointName][rtName] = rt
}
}
return entryPointsRouters
}

307
pkg/server/server.go Normal file
View file

@ -0,0 +1,307 @@
package server
import (
"context"
"encoding/json"
"net/http"
"os"
"os/signal"
"sync"
"time"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/config/static"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/metrics"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/middlewares/requestdecorator"
"github.com/containous/traefik/pkg/provider"
"github.com/containous/traefik/pkg/safe"
"github.com/containous/traefik/pkg/server/middleware"
"github.com/containous/traefik/pkg/tls"
"github.com/containous/traefik/pkg/tracing"
"github.com/containous/traefik/pkg/tracing/datadog"
"github.com/containous/traefik/pkg/tracing/instana"
"github.com/containous/traefik/pkg/tracing/jaeger"
"github.com/containous/traefik/pkg/tracing/zipkin"
"github.com/containous/traefik/pkg/types"
)
// Server is the reverse-proxy/load-balancer engine
type Server struct {
entryPointsTCP TCPEntryPoints
configurationChan chan config.Message
configurationValidatedChan chan config.Message
signals chan os.Signal
stopChan chan bool
currentConfigurations safe.Safe
providerConfigUpdateMap map[string]chan config.Message
accessLoggerMiddleware *accesslog.Handler
tracer *tracing.Tracing
routinesPool *safe.Pool
defaultRoundTripper http.RoundTripper
metricsRegistry metrics.Registry
provider provider.Provider
configurationListeners []func(config.Configuration)
requestDecorator *requestdecorator.RequestDecorator
providersThrottleDuration time.Duration
tlsManager *tls.Manager
}
// RouteAppenderFactory the route appender factory interface
type RouteAppenderFactory interface {
NewAppender(ctx context.Context, middlewaresBuilder *middleware.Builder, currentConfigurations *safe.Safe) types.RouteAppender
}
func setupTracing(conf *static.Tracing) tracing.TrackingBackend {
switch conf.Backend {
case jaeger.Name:
return conf.Jaeger
case zipkin.Name:
return conf.Zipkin
case datadog.Name:
return conf.DataDog
case instana.Name:
return conf.Instana
default:
log.WithoutContext().Warnf("Could not initialize tracing: unknown tracer %q", conf.Backend)
return nil
}
}
// NewServer returns an initialized Server.
func NewServer(staticConfiguration static.Configuration, provider provider.Provider, entryPoints TCPEntryPoints, tlsManager *tls.Manager) *Server {
server := &Server{}
server.provider = provider
server.entryPointsTCP = entryPoints
server.configurationChan = make(chan config.Message, 100)
server.configurationValidatedChan = make(chan config.Message, 100)
server.signals = make(chan os.Signal, 1)
server.stopChan = make(chan bool, 1)
server.configureSignals()
currentConfigurations := make(config.Configurations)
server.currentConfigurations.Set(currentConfigurations)
server.providerConfigUpdateMap = make(map[string]chan config.Message)
server.tlsManager = tlsManager
if staticConfiguration.Providers != nil {
server.providersThrottleDuration = time.Duration(staticConfiguration.Providers.ProvidersThrottleDuration)
}
transport, err := createHTTPTransport(staticConfiguration.ServersTransport)
if err != nil {
log.WithoutContext().Errorf("Could not configure HTTP Transport, fallbacking on default transport: %v", err)
server.defaultRoundTripper = http.DefaultTransport
} else {
server.defaultRoundTripper = transport
}
server.routinesPool = safe.NewPool(context.Background())
if staticConfiguration.Tracing != nil {
trackingBackend := setupTracing(staticConfiguration.Tracing)
var err error
server.tracer, err = tracing.NewTracing(staticConfiguration.Tracing.ServiceName, staticConfiguration.Tracing.SpanNameLimit, trackingBackend)
if err != nil {
log.WithoutContext().Warnf("Unable to create tracer: %v", err)
}
}
server.requestDecorator = requestdecorator.New(staticConfiguration.HostResolver)
server.metricsRegistry = registerMetricClients(staticConfiguration.Metrics)
if staticConfiguration.AccessLog != nil {
var err error
server.accessLoggerMiddleware, err = accesslog.NewHandler(staticConfiguration.AccessLog)
if err != nil {
log.WithoutContext().Warnf("Unable to create access logger : %v", err)
}
}
return server
}
// Start starts the server and Stop/Close it when context is Done
func (s *Server) Start(ctx context.Context) {
go func() {
defer s.Close()
<-ctx.Done()
logger := log.FromContext(ctx)
logger.Info("I have to go...")
logger.Info("Stopping server gracefully")
s.Stop()
}()
s.startTCPServers()
s.routinesPool.Go(func(stop chan bool) {
s.listenProviders(stop)
})
s.routinesPool.Go(func(stop chan bool) {
s.listenConfigurations(stop)
})
s.startProvider()
s.routinesPool.Go(func(stop chan bool) {
s.listenSignals(stop)
})
}
// Wait blocks until server is shutted down.
func (s *Server) Wait() {
<-s.stopChan
}
// Stop stops the server
func (s *Server) Stop() {
defer log.WithoutContext().Info("Server stopped")
var wg sync.WaitGroup
for epn, ep := range s.entryPointsTCP {
wg.Add(1)
go func(entryPointName string, entryPoint *TCPEntryPoint) {
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
defer wg.Done()
entryPoint.Shutdown(ctx)
log.FromContext(ctx).Debugf("Entry point %s closed", entryPointName)
}(epn, ep)
}
wg.Wait()
s.stopChan <- true
}
// Close destroys the server
func (s *Server) Close() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func(ctx context.Context) {
<-ctx.Done()
if ctx.Err() == context.Canceled {
return
} else if ctx.Err() == context.DeadlineExceeded {
panic("Timeout while stopping traefik, killing instance ✝")
}
}(ctx)
stopMetricsClients()
s.routinesPool.Cleanup()
close(s.configurationChan)
close(s.configurationValidatedChan)
signal.Stop(s.signals)
close(s.signals)
close(s.stopChan)
if s.accessLoggerMiddleware != nil {
if err := s.accessLoggerMiddleware.Close(); err != nil {
log.WithoutContext().Errorf("Could not close the access log file: %s", err)
}
}
if s.tracer != nil {
s.tracer.Close()
}
cancel()
}
func (s *Server) startTCPServers() {
// Use an empty configuration in order to initialize the default handlers with internal routes
routers := s.loadConfigurationTCP(config.Configurations{})
for entryPointName, router := range routers {
s.entryPointsTCP[entryPointName].switchRouter(router)
}
for entryPointName, serverEntryPoint := range s.entryPointsTCP {
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
go serverEntryPoint.startTCP(ctx)
}
}
func (s *Server) listenProviders(stop chan bool) {
for {
select {
case <-stop:
return
case configMsg, ok := <-s.configurationChan:
if !ok {
return
}
if configMsg.Configuration != nil {
s.preLoadConfiguration(configMsg)
} else {
log.Debugf("Received nil configuration from provider %q, skipping.", configMsg.ProviderName)
}
}
}
}
// AddListener adds a new listener function used when new configuration is provided
func (s *Server) AddListener(listener func(config.Configuration)) {
if s.configurationListeners == nil {
s.configurationListeners = make([]func(config.Configuration), 0)
}
s.configurationListeners = append(s.configurationListeners, listener)
}
func (s *Server) startProvider() {
jsonConf, err := json.Marshal(s.provider)
if err != nil {
log.WithoutContext().Debugf("Unable to marshal provider configuration %T: %v", s.provider, err)
}
log.WithoutContext().Infof("Starting provider %T %s", s.provider, jsonConf)
currentProvider := s.provider
safe.Go(func() {
err := currentProvider.Provide(s.configurationChan, s.routinesPool)
if err != nil {
log.WithoutContext().Errorf("Error starting provider %T: %s", s.provider, err)
}
})
}
func registerMetricClients(metricsConfig *types.Metrics) metrics.Registry {
if metricsConfig == nil {
return metrics.NewVoidRegistry()
}
var registries []metrics.Registry
if metricsConfig.Prometheus != nil {
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "prometheus"))
prometheusRegister := metrics.RegisterPrometheus(ctx, metricsConfig.Prometheus)
if prometheusRegister != nil {
registries = append(registries, prometheusRegister)
log.FromContext(ctx).Debug("Configured Prometheus metrics")
}
}
if metricsConfig.Datadog != nil {
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "datadog"))
registries = append(registries, metrics.RegisterDatadog(ctx, metricsConfig.Datadog))
log.FromContext(ctx).Debugf("Configured DataDog metrics: pushing to %s once every %s",
metricsConfig.Datadog.Address, metricsConfig.Datadog.PushInterval)
}
if metricsConfig.StatsD != nil {
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "statsd"))
registries = append(registries, metrics.RegisterStatsd(ctx, metricsConfig.StatsD))
log.FromContext(ctx).Debugf("Configured StatsD metrics: pushing to %s once every %s",
metricsConfig.StatsD.Address, metricsConfig.StatsD.PushInterval)
}
if metricsConfig.InfluxDB != nil {
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "influxdb"))
registries = append(registries, metrics.RegisterInfluxDB(ctx, metricsConfig.InfluxDB))
log.FromContext(ctx).Debugf("Configured InfluxDB metrics: pushing to %s once every %s",
metricsConfig.InfluxDB.Address, metricsConfig.InfluxDB.PushInterval)
}
return metrics.NewMultiRegistry(registries)
}
func stopMetricsClients() {
metrics.StopDatadog()
metrics.StopStatsd()
metrics.StopInfluxDB()
}

View file

@ -0,0 +1,272 @@
package server
import (
"context"
"crypto/tls"
"encoding/json"
"net/http"
"reflect"
"time"
"github.com/containous/alice"
"github.com/containous/mux"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/middlewares/requestdecorator"
"github.com/containous/traefik/pkg/middlewares/tracing"
"github.com/containous/traefik/pkg/responsemodifiers"
"github.com/containous/traefik/pkg/server/middleware"
"github.com/containous/traefik/pkg/server/router"
routertcp "github.com/containous/traefik/pkg/server/router/tcp"
"github.com/containous/traefik/pkg/server/service"
"github.com/containous/traefik/pkg/server/service/tcp"
tcpCore "github.com/containous/traefik/pkg/tcp"
"github.com/eapache/channels"
"github.com/sirupsen/logrus"
)
// loadConfiguration manages dynamically routers, middlewares, servers and TLS configurations
func (s *Server) loadConfiguration(configMsg config.Message) {
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
// Copy configurations to new map so we don't change current if LoadConfig fails
newConfigurations := make(config.Configurations)
for k, v := range currentConfigurations {
newConfigurations[k] = v
}
newConfigurations[configMsg.ProviderName] = configMsg.Configuration
s.metricsRegistry.ConfigReloadsCounter().Add(1)
handlersTCP := s.loadConfigurationTCP(newConfigurations)
for entryPointName, router := range handlersTCP {
s.entryPointsTCP[entryPointName].switchRouter(router)
}
s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
s.currentConfigurations.Set(newConfigurations)
for _, listener := range s.configurationListeners {
listener(*configMsg.Configuration)
}
s.postLoadConfiguration()
}
// loadConfigurationTCP returns a new gorilla.mux Route from the specified global configuration and the dynamic
// provider configurations.
func (s *Server) loadConfigurationTCP(configurations config.Configurations) map[string]*tcpCore.Router {
ctx := context.TODO()
var entryPoints []string
for entryPointName := range s.entryPointsTCP {
entryPoints = append(entryPoints, entryPointName)
}
conf := mergeConfiguration(configurations)
s.tlsManager.UpdateConfigs(conf.TLSStores, conf.TLSOptions, conf.TLS)
handlersNonTLS, handlersTLS := s.createHTTPHandlers(ctx, *conf.HTTP, entryPoints)
routersTCP := s.createTCPRouters(ctx, conf.TCP, entryPoints, handlersNonTLS, handlersTLS, s.tlsManager.Get("default", "default"))
return routersTCP
}
func (s *Server) createTCPRouters(ctx context.Context, configuration *config.TCPConfiguration, entryPoints []string, handlers map[string]http.Handler, handlersTLS map[string]http.Handler, tlsConfig *tls.Config) map[string]*tcpCore.Router {
if configuration == nil {
return make(map[string]*tcpCore.Router)
}
serviceManager := tcp.NewManager(configuration.Services)
routerManager := routertcp.NewManager(configuration.Routers, serviceManager, handlers, handlersTLS, tlsConfig)
return routerManager.BuildHandlers(ctx, entryPoints)
}
func (s *Server) createHTTPHandlers(ctx context.Context, configuration config.HTTPConfiguration, entryPoints []string) (map[string]http.Handler, map[string]http.Handler) {
serviceManager := service.NewManager(configuration.Services, s.defaultRoundTripper)
middlewaresBuilder := middleware.NewBuilder(configuration.Middlewares, serviceManager)
responseModifierFactory := responsemodifiers.NewBuilder(configuration.Middlewares)
routerManager := router.NewManager(configuration.Routers, serviceManager, middlewaresBuilder, responseModifierFactory)
handlersNonTLS := routerManager.BuildHandlers(ctx, entryPoints, false)
handlersTLS := routerManager.BuildHandlers(ctx, entryPoints, true)
routerHandlers := make(map[string]http.Handler)
for _, entryPointName := range entryPoints {
internalMuxRouter := mux.NewRouter().
SkipClean(true)
ctx = log.With(ctx, log.Str(log.EntryPointName, entryPointName))
factory := s.entryPointsTCP[entryPointName].RouteAppenderFactory
if factory != nil {
// FIXME remove currentConfigurations
appender := factory.NewAppender(ctx, middlewaresBuilder, &s.currentConfigurations)
appender.Append(internalMuxRouter)
}
if h, ok := handlersNonTLS[entryPointName]; ok {
internalMuxRouter.NotFoundHandler = h
} else {
internalMuxRouter.NotFoundHandler = buildDefaultHTTPRouter()
}
routerHandlers[entryPointName] = internalMuxRouter
chain := alice.New()
if s.accessLoggerMiddleware != nil {
chain = chain.Append(accesslog.WrapHandler(s.accessLoggerMiddleware))
}
if s.tracer != nil {
chain = chain.Append(tracing.WrapEntryPointHandler(ctx, s.tracer, entryPointName))
}
chain = chain.Append(requestdecorator.WrapHandler(s.requestDecorator))
handler, err := chain.Then(internalMuxRouter.NotFoundHandler)
if err != nil {
log.FromContext(ctx).Error(err)
continue
}
internalMuxRouter.NotFoundHandler = handler
handlerTLS, ok := handlersTLS[entryPointName]
if ok {
handlerTLSWithMiddlewares, err := chain.Then(handlerTLS)
if err != nil {
log.FromContext(ctx).Error(err)
continue
}
handlersTLS[entryPointName] = handlerTLSWithMiddlewares
}
}
return routerHandlers, handlersTLS
}
func isEmptyConfiguration(conf *config.Configuration) bool {
if conf == nil {
return true
}
if conf.TCP == nil {
conf.TCP = &config.TCPConfiguration{}
}
if conf.HTTP == nil {
conf.HTTP = &config.HTTPConfiguration{}
}
return conf.HTTP.Routers == nil &&
conf.HTTP.Services == nil &&
conf.HTTP.Middlewares == nil &&
conf.TLS == nil &&
conf.TCP.Routers == nil &&
conf.TCP.Services == nil
}
func (s *Server) preLoadConfiguration(configMsg config.Message) {
s.defaultConfigurationValues(configMsg.Configuration.HTTP)
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
logger := log.WithoutContext().WithField(log.ProviderName, configMsg.ProviderName)
if log.GetLevel() == logrus.DebugLevel {
jsonConf, _ := json.Marshal(configMsg.Configuration)
logger.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
}
if isEmptyConfiguration(configMsg.Configuration) {
logger.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
return
}
if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) {
logger.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
return
}
providerConfigUpdateCh, ok := s.providerConfigUpdateMap[configMsg.ProviderName]
if !ok {
providerConfigUpdateCh = make(chan config.Message)
s.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh
s.routinesPool.Go(func(stop chan bool) {
s.throttleProviderConfigReload(s.providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop)
})
}
providerConfigUpdateCh <- configMsg
}
func (s *Server) defaultConfigurationValues(configuration *config.HTTPConfiguration) {
// FIXME create a config hook
}
func (s *Server) listenConfigurations(stop chan bool) {
for {
select {
case <-stop:
return
case configMsg, ok := <-s.configurationValidatedChan:
if !ok || configMsg.Configuration == nil {
return
}
s.loadConfiguration(configMsg)
}
}
}
// throttleProviderConfigReload throttles the configuration reload speed for a single provider.
// It will immediately publish a new configuration and then only publish the next configuration after the throttle duration.
// Note that in the case it receives N new configs in the timeframe of the throttle duration after publishing,
// it will publish the last of the newly received configurations.
func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- config.Message, in <-chan config.Message, stop chan bool) {
ring := channels.NewRingChannel(1)
defer ring.Close()
s.routinesPool.Go(func(stop chan bool) {
for {
select {
case <-stop:
return
case nextConfig := <-ring.Out():
if config, ok := nextConfig.(config.Message); ok {
publish <- config
time.Sleep(throttle)
}
}
}
})
for {
select {
case <-stop:
return
case nextConfig := <-in:
ring.In() <- nextConfig
}
}
}
func (s *Server) postLoadConfiguration() {
// FIXME metrics
// if s.metricsRegistry.IsEnabled() {
// activeConfig := s.currentConfigurations.Get().(config.Configurations)
// metrics.OnConfigurationUpdate(activeConfig)
// }
}
func buildDefaultHTTPRouter() *mux.Router {
rt := mux.NewRouter()
rt.NotFoundHandler = http.HandlerFunc(http.NotFound)
rt.SkipClean(true)
return rt
}

View file

@ -0,0 +1,121 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/config/static"
th "github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestReuseService(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
entryPoints := TCPEntryPoints{
"http": &TCPEntryPoint{},
}
staticConfig := static.Configuration{}
dynamicConfigs := th.BuildConfiguration(
th.WithRouters(
th.WithRouter("foo",
th.WithServiceName("bar"),
th.WithRule("Path(`/ok`)")),
th.WithRouter("foo2",
th.WithEntryPoints("http"),
th.WithRule("Path(`/unauthorized`)"),
th.WithServiceName("bar"),
th.WithRouterMiddlewares("basicauth")),
),
th.WithMiddlewares(th.WithMiddleware("basicauth",
th.WithBasicAuth(&config.BasicAuth{Users: []string{"foo:bar"}}),
)),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr"),
th.WithServers(th.WithServer(testServer.URL))),
),
)
srv := NewServer(staticConfig, nil, entryPoints, nil)
entrypointsHandlers, _ := srv.createHTTPHandlers(context.Background(), *dynamicConfigs, []string{"http"})
// Test that the /ok path returns a status 200.
responseRecorderOk := &httptest.ResponseRecorder{}
requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil)
entrypointsHandlers["http"].ServeHTTP(responseRecorderOk, requestOk)
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
// Test that the /unauthorized path returns a 401 because of
// the basic authentication defined on the frontend.
responseRecorderUnauthorized := &httptest.ResponseRecorder{}
requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil)
entrypointsHandlers["http"].ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code")
}
func TestThrottleProviderConfigReload(t *testing.T) {
throttleDuration := 30 * time.Millisecond
publishConfig := make(chan config.Message)
providerConfig := make(chan config.Message)
stop := make(chan bool)
defer func() {
stop <- true
}()
staticConfiguration := static.Configuration{}
server := NewServer(staticConfiguration, nil, nil, nil)
go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop)
publishedConfigCount := 0
stopConsumeConfigs := make(chan bool)
go func() {
for {
select {
case <-stop:
return
case <-stopConsumeConfigs:
return
case <-publishConfig:
publishedConfigCount++
}
}
}()
// publish 5 new configs, one new config each 10 milliseconds
for i := 0; i < 5; i++ {
providerConfig <- config.Message{}
time.Sleep(10 * time.Millisecond)
}
// after 50 milliseconds 5 new configs were published
// with a throttle duration of 30 milliseconds this means, we should have received 2 new configs
assert.Equal(t, 2, publishedConfigCount, "times configs were published")
stopConsumeConfigs <- true
select {
case <-publishConfig:
// There should be exactly one more message that we receive after ~60 milliseconds since the start of the test.
select {
case <-publishConfig:
t.Error("extra config publication found")
case <-time.After(100 * time.Millisecond):
return
}
case <-time.After(100 * time.Millisecond):
t.Error("Last config was not published in time")
}
}

View file

@ -0,0 +1,385 @@
package server
import (
"context"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/armon/go-proxyproto"
"github.com/containous/traefik/pkg/config/static"
"github.com/containous/traefik/pkg/h2c"
"github.com/containous/traefik/pkg/ip"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/forwardedheaders"
"github.com/containous/traefik/pkg/safe"
"github.com/containous/traefik/pkg/tcp"
)
type httpForwarder struct {
net.Listener
connChan chan net.Conn
}
func newHTTPForwarder(ln net.Listener) *httpForwarder {
return &httpForwarder{
Listener: ln,
connChan: make(chan net.Conn),
}
}
// ServeTCP uses the connection to serve it later in "Accept"
func (h *httpForwarder) ServeTCP(conn net.Conn) {
h.connChan <- conn
}
// Accept retrieves a served connection in ServeTCP
func (h *httpForwarder) Accept() (net.Conn, error) {
conn := <-h.connChan
return conn, nil
}
// TCPEntryPoints holds a map of TCPEntryPoint (the entrypoint names being the keys)
type TCPEntryPoints map[string]*TCPEntryPoint
// TCPEntryPoint is the TCP server
type TCPEntryPoint struct {
listener net.Listener
switcher *tcp.HandlerSwitcher
RouteAppenderFactory RouteAppenderFactory
transportConfiguration *static.EntryPointsTransport
tracker *connectionTracker
httpServer *httpServer
httpsServer *httpServer
}
// NewTCPEntryPoint creates a new TCPEntryPoint
func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*TCPEntryPoint, error) {
tracker := newConnectionTracker()
listener, err := buildListener(ctx, configuration)
if err != nil {
return nil, fmt.Errorf("error preparing server: %v", err)
}
router := &tcp.Router{}
httpServer, err := createHTTPServer(listener, configuration, true)
if err != nil {
return nil, fmt.Errorf("error preparing httpServer: %v", err)
}
router.HTTPForwarder(httpServer.Forwarder)
httpsServer, err := createHTTPServer(listener, configuration, false)
if err != nil {
return nil, fmt.Errorf("error preparing httpsServer: %v", err)
}
router.HTTPSForwarder(httpsServer.Forwarder)
tcpSwitcher := &tcp.HandlerSwitcher{}
tcpSwitcher.Switch(router)
return &TCPEntryPoint{
listener: listener,
switcher: tcpSwitcher,
transportConfiguration: configuration.Transport,
tracker: tracker,
httpServer: httpServer,
httpsServer: httpsServer,
}, nil
}
func (e *TCPEntryPoint) startTCP(ctx context.Context) {
log.FromContext(ctx).Debugf("Start TCP Server")
for {
conn, err := e.listener.Accept()
if err != nil {
log.Error(err)
return
}
safe.Go(func() {
e.switcher.ServeTCP(newTrackedConnection(conn, e.tracker))
})
}
}
// Shutdown stops the TCP connections
func (e *TCPEntryPoint) Shutdown(ctx context.Context) {
logger := log.FromContext(ctx)
reqAcceptGraceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.RequestAcceptGraceTimeout)
if reqAcceptGraceTimeOut > 0 {
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
time.Sleep(reqAcceptGraceTimeOut)
}
graceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(ctx, graceTimeOut)
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
var wg sync.WaitGroup
if e.httpServer.Server != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := e.httpServer.Server.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait server shutdown is overdue to: %s", err)
err = e.httpServer.Server.Close()
if err != nil {
logger.Error(err)
}
}
}
}()
}
if e.httpsServer.Server != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := e.httpsServer.Server.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait server shutdown is overdue to: %s", err)
err = e.httpsServer.Server.Close()
if err != nil {
logger.Error(err)
}
}
}
}()
}
if e.tracker != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := e.tracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait hijack connection is overdue to: %s", err)
e.tracker.Close()
}
}
}()
}
wg.Wait()
cancel()
}
func (e *TCPEntryPoint) switchRouter(router *tcp.Router) {
router.HTTPForwarder(e.httpServer.Forwarder)
router.HTTPSForwarder(e.httpsServer.Forwarder)
httpHandler := router.GetHTTPHandler()
httpsHandler := router.GetHTTPSHandler()
if httpHandler == nil {
httpHandler = buildDefaultHTTPRouter()
}
if httpsHandler == nil {
httpsHandler = buildDefaultHTTPRouter()
}
e.httpServer.Switcher.UpdateHandler(httpHandler)
e.httpsServer.Switcher.UpdateHandler(httpsHandler)
e.switcher.Switch(router)
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
if err = tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
return nil, err
}
return tc, nil
}
func buildProxyProtocolListener(ctx context.Context, entryPoint *static.EntryPoint, listener net.Listener) (net.Listener, error) {
var sourceCheck func(addr net.Addr) (bool, error)
if entryPoint.ProxyProtocol.Insecure {
sourceCheck = func(_ net.Addr) (bool, error) {
return true, nil
}
} else {
checker, err := ip.NewChecker(entryPoint.ProxyProtocol.TrustedIPs)
if err != nil {
return nil, err
}
sourceCheck = func(addr net.Addr) (bool, error) {
ipAddr, ok := addr.(*net.TCPAddr)
if !ok {
return false, fmt.Errorf("type error %v", addr)
}
return checker.ContainsIP(ipAddr.IP), nil
}
}
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
return &proxyproto.Listener{
Listener: listener,
SourceCheck: sourceCheck,
}, nil
}
func buildListener(ctx context.Context, entryPoint *static.EntryPoint) (net.Listener, error) {
listener, err := net.Listen("tcp", entryPoint.Address)
if err != nil {
return nil, fmt.Errorf("error opening listener: %v", err)
}
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
if entryPoint.ProxyProtocol != nil {
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
if err != nil {
return nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
}
}
return listener, nil
}
func newConnectionTracker() *connectionTracker {
return &connectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type connectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddConnection add a connection in the tracked connections list
func (c *connectionTracker) AddConnection(conn net.Conn) {
c.lock.Lock()
defer c.lock.Unlock()
c.conns[conn] = struct{}{}
}
// RemoveConnection remove a connection from the tracked connections list
func (c *connectionTracker) RemoveConnection(conn net.Conn) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.conns, conn)
}
// Shutdown wait for the connection closing
func (c *connectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
c.lock.RLock()
if len(c.conns) == 0 {
return nil
}
c.lock.RUnlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list
func (c *connectionTracker) Close() {
for conn := range c.conns {
if err := conn.Close(); err != nil {
log.WithoutContext().Errorf("Error while closing connection: %v", err)
}
c.RemoveConnection(conn)
}
}
type stoppableServer interface {
Shutdown(context.Context) error
Close() error
Serve(listener net.Listener) error
}
type httpServer struct {
Server stoppableServer
Forwarder *httpForwarder
Switcher *middlewares.HTTPHandlerSwitcher
}
func createHTTPServer(ln net.Listener, configuration *static.EntryPoint, withH2c bool) (*httpServer, error) {
httpSwitcher := middlewares.NewHandlerSwitcher(buildDefaultHTTPRouter())
handler, err := forwardedheaders.NewXForwarded(
configuration.ForwardedHeaders.Insecure,
configuration.ForwardedHeaders.TrustedIPs,
httpSwitcher)
if err != nil {
return nil, err
}
var serverHTTP stoppableServer
if withH2c {
serverHTTP = &h2c.Server{
Server: &http.Server{
Handler: handler,
},
}
} else {
serverHTTP = &http.Server{
Handler: handler,
}
}
listener := newHTTPForwarder(ln)
go func() {
err := serverHTTP.Serve(listener)
if err != nil {
log.Errorf("Error while starting server: %v", err)
}
}()
return &httpServer{
Server: serverHTTP,
Forwarder: listener,
Switcher: httpSwitcher,
}, nil
}
func newTrackedConnection(conn net.Conn, tracker *connectionTracker) *trackedConnection {
tracker.AddConnection(conn)
return &trackedConnection{
Conn: conn,
tracker: tracker,
}
}
type trackedConnection struct {
tracker *connectionTracker
net.Conn
}
func (t *trackedConnection) Close() error {
t.tracker.RemoveConnection(t.Conn)
return t.Conn.Close()
}

View file

@ -0,0 +1,142 @@
package server
import (
"bufio"
"context"
"net"
"net/http"
"testing"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/config/static"
"github.com/containous/traefik/pkg/tcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestShutdownHTTP(t *testing.T) {
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: &static.EntryPointsTransport{
LifeCycle: &static.LifeCycle{
RequestAcceptGraceTimeout: 0,
GraceTimeOut: parse.Duration(5 * time.Second),
},
},
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
go entryPoint.startTCP(context.Background())
router := &tcp.Router{}
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
time.Sleep(1 * time.Second)
rw.WriteHeader(http.StatusOK)
}))
entryPoint.switchRouter(router)
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
require.NoError(t, err)
go entryPoint.Shutdown(context.Background())
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
require.NoError(t, err)
err = request.Write(conn)
require.NoError(t, err)
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
require.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
}
func TestShutdownHTTPHijacked(t *testing.T) {
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: &static.EntryPointsTransport{
LifeCycle: &static.LifeCycle{
RequestAcceptGraceTimeout: 0,
GraceTimeOut: parse.Duration(5 * time.Second),
},
},
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
go entryPoint.startTCP(context.Background())
router := &tcp.Router{}
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
conn, _, err := rw.(http.Hijacker).Hijack()
require.NoError(t, err)
time.Sleep(1 * time.Second)
resp := http.Response{StatusCode: http.StatusOK}
err = resp.Write(conn)
require.NoError(t, err)
}))
entryPoint.switchRouter(router)
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
require.NoError(t, err)
go entryPoint.Shutdown(context.Background())
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
require.NoError(t, err)
err = request.Write(conn)
require.NoError(t, err)
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
require.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
}
func TestShutdownTCPConn(t *testing.T) {
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: &static.EntryPointsTransport{
LifeCycle: &static.LifeCycle{
RequestAcceptGraceTimeout: 0,
GraceTimeOut: parse.Duration(5 * time.Second),
},
},
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
go entryPoint.startTCP(context.Background())
router := &tcp.Router{}
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn net.Conn) {
_, err := http.ReadRequest(bufio.NewReader(conn))
require.NoError(t, err)
time.Sleep(1 * time.Second)
resp := http.Response{StatusCode: http.StatusOK}
err = resp.Write(conn)
require.NoError(t, err)
}))
entryPoint.switchRouter(router)
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
require.NoError(t, err)
go entryPoint.Shutdown(context.Background())
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
require.NoError(t, err)
err = request.Write(conn)
require.NoError(t, err)
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
require.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
}

View file

@ -0,0 +1,37 @@
// +build !windows
package server
import (
"os/signal"
"syscall"
"github.com/containous/traefik/pkg/log"
)
func (s *Server) configureSignals() {
signal.Notify(s.signals, syscall.SIGUSR1)
}
func (s *Server) listenSignals(stop chan bool) {
for {
select {
case <-stop:
return
case sig := <-s.signals:
if sig == syscall.SIGUSR1 {
log.WithoutContext().Infof("Closing and re-opening log files for rotation: %+v", sig)
if s.accessLoggerMiddleware != nil {
if err := s.accessLoggerMiddleware.Rotate(); err != nil {
log.WithoutContext().Errorf("Error rotating access log: %v", err)
}
}
if err := log.RotateFile(); err != nil {
log.WithoutContext().Errorf("Error rotating traefik log: %v", err)
}
}
}
}
}

View file

@ -0,0 +1,7 @@
// +build windows
package server
func (s *Server) configureSignals() {}
func (s *Server) listenSignals(stop chan bool) {}

271
pkg/server/server_test.go Normal file
View file

@ -0,0 +1,271 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/config/static"
th "github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestListenProvidersSkipsEmptyConfigs(t *testing.T) {
server, stop, invokeStopChan := setupListenProvider(10 * time.Millisecond)
defer invokeStopChan()
go func() {
for {
select {
case <-stop:
return
case <-server.configurationValidatedChan:
t.Error("An empty configuration was published but it should not")
}
}
}()
server.configurationChan <- config.Message{ProviderName: "kubernetes"}
// give some time so that the configuration can be processed
time.Sleep(100 * time.Millisecond)
}
func TestListenProvidersSkipsSameConfigurationForProvider(t *testing.T) {
server, stop, invokeStopChan := setupListenProvider(10 * time.Millisecond)
defer invokeStopChan()
publishedConfigCount := 0
go func() {
for {
select {
case <-stop:
return
case conf := <-server.configurationValidatedChan:
// set the current configuration
// this is usually done in the processing part of the published configuration
// so we have to emulate the behavior here
currentConfigurations := server.currentConfigurations.Get().(config.Configurations)
currentConfigurations[conf.ProviderName] = conf.Configuration
server.currentConfigurations.Set(currentConfigurations)
publishedConfigCount++
if publishedConfigCount > 1 {
t.Error("Same configuration should not be published multiple times")
}
}
}
}()
conf := &config.Configuration{}
conf.HTTP = th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo")),
th.WithLoadBalancerServices(th.WithService("bar")),
)
// provide a configuration
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
// give some time so that the configuration can be processed
time.Sleep(20 * time.Millisecond)
// provide the same configuration a second time
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
// give some time so that the configuration can be processed
time.Sleep(100 * time.Millisecond)
}
func TestListenProvidersPublishesConfigForEachProvider(t *testing.T) {
server, stop, invokeStopChan := setupListenProvider(10 * time.Millisecond)
defer invokeStopChan()
publishedProviderConfigCount := map[string]int{}
publishedConfigCount := 0
consumePublishedConfigsDone := make(chan bool)
go func() {
for {
select {
case <-stop:
return
case newConfig := <-server.configurationValidatedChan:
publishedProviderConfigCount[newConfig.ProviderName]++
publishedConfigCount++
if publishedConfigCount == 2 {
consumePublishedConfigsDone <- true
return
}
}
}
}()
conf := &config.Configuration{}
conf.HTTP = th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo")),
th.WithLoadBalancerServices(th.WithService("bar")),
)
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
server.configurationChan <- config.Message{ProviderName: "marathon", Configuration: conf}
select {
case <-consumePublishedConfigsDone:
if val := publishedProviderConfigCount["kubernetes"]; val != 1 {
t.Errorf("Got %d configuration publication(s) for provider %q, want 1", val, "kubernetes")
}
if val := publishedProviderConfigCount["marathon"]; val != 1 {
t.Errorf("Got %d configuration publication(s) for provider %q, want 1", val, "marathon")
}
case <-time.After(100 * time.Millisecond):
t.Errorf("Published configurations were not consumed in time")
}
}
// setupListenProvider configures the Server and starts listenProviders
func setupListenProvider(throttleDuration time.Duration) (server *Server, stop chan bool, invokeStopChan func()) {
stop = make(chan bool)
invokeStopChan = func() {
stop <- true
}
staticConfiguration := static.Configuration{
Providers: &static.Providers{
ProvidersThrottleDuration: parse.Duration(throttleDuration),
},
}
server = NewServer(staticConfiguration, nil, nil, nil)
go server.listenProviders(stop)
return server, stop, invokeStopChan
}
func TestServerResponseEmptyBackend(t *testing.T) {
const requestPath = "/path"
const routeRule = "Path(`" + requestPath + "`)"
testCases := []struct {
desc string
config func(testServerURL string) *config.HTTPConfiguration
expectedStatusCode int
}{
{
desc: "Ok",
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr"),
th.WithServers(th.WithServer(testServerURL))),
),
)
},
expectedStatusCode: http.StatusOK,
},
{
desc: "No Frontend",
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration()
},
expectedStatusCode: http.StatusNotFound,
},
{
desc: "Empty Backend LB-Drr",
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("drr")),
),
)
},
expectedStatusCode: http.StatusServiceUnavailable,
},
{
desc: "Empty Backend LB-Drr Sticky",
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("drr"), th.WithStickiness("test")),
),
)
},
expectedStatusCode: http.StatusServiceUnavailable,
},
{
desc: "Empty Backend LB-Wrr",
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr")),
),
)
},
expectedStatusCode: http.StatusServiceUnavailable,
},
{
desc: "Empty Backend LB-Wrr Sticky",
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr"), th.WithStickiness("test")),
),
)
},
expectedStatusCode: http.StatusServiceUnavailable,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
globalConfig := static.Configuration{}
entryPointsConfig := TCPEntryPoints{
"http": &TCPEntryPoint{},
}
srv := NewServer(globalConfig, nil, entryPointsConfig, nil)
entryPoints, _ := srv.createHTTPHandlers(context.Background(), *test.config(testServer.URL), []string{"http"})
responseRecorder := &httptest.ResponseRecorder{}
request := httptest.NewRequest(http.MethodGet, testServer.URL+requestPath, nil)
entryPoints["http"].ServeHTTP(responseRecorder, request)
assert.Equal(t, test.expectedStatusCode, responseRecorder.Result().StatusCode, "status code")
})
}
}

View file

@ -0,0 +1,27 @@
package service
import "sync"
const bufferPoolSize = 32 * 1024
func newBufferPool() *bufferPool {
return &bufferPool{
pool: sync.Pool{
New: func() interface{} {
return make([]byte, bufferPoolSize)
},
},
}
}
type bufferPool struct {
pool sync.Pool
}
func (b *bufferPool) Get() []byte {
return b.pool.Get().([]byte)
}
func (b *bufferPool) Put(bytes []byte) {
b.pool.Put(bytes)
}

100
pkg/server/service/proxy.go Normal file
View file

@ -0,0 +1,100 @@
package service
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
)
// StatusClientClosedRequest non-standard HTTP status code for client disconnection
const StatusClientClosedRequest = 499
// StatusClientClosedRequestText non-standard HTTP status for client disconnection
const StatusClientClosedRequestText = "Client Closed Request"
func buildProxy(passHostHeader bool, responseForwarding *config.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool, responseModifier func(*http.Response) error) (http.Handler, error) {
var flushInterval parse.Duration
if responseForwarding != nil {
err := flushInterval.Set(responseForwarding.FlushInterval)
if err != nil {
return nil, fmt.Errorf("error creating flush interval: %v", err)
}
}
if flushInterval == 0 {
flushInterval = parse.Duration(100 * time.Millisecond)
}
proxy := &httputil.ReverseProxy{
Director: func(outReq *http.Request) {
u := outReq.URL
if outReq.RequestURI != "" {
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
if err == nil {
u = parsedURL
}
}
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
outReq.URL.RawQuery = u.RawQuery
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.Proto = "HTTP/1.1"
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
// Do not pass client Host header unless optsetter PassHostHeader is set.
if !passHostHeader {
outReq.Host = outReq.URL.Host
}
},
Transport: defaultRoundTripper,
FlushInterval: time.Duration(flushInterval),
ModifyResponse: responseModifier,
BufferPool: bufferPool,
ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) {
statusCode := http.StatusInternalServerError
switch {
case err == io.EOF:
statusCode = http.StatusBadGateway
case err == context.Canceled:
statusCode = StatusClientClosedRequest
default:
if e, ok := err.(net.Error); ok {
if e.Timeout() {
statusCode = http.StatusGatewayTimeout
} else {
statusCode = http.StatusBadGateway
}
}
}
log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
w.WriteHeader(statusCode)
_, werr := w.Write([]byte(statusText(statusCode)))
if werr != nil {
log.Debugf("Error while writing status code", werr)
}
},
}
return proxy, nil
}
func statusText(statusCode int) string {
if statusCode == StatusClientClosedRequest {
return StatusClientClosedRequestText
}
return http.StatusText(statusCode)
}

View file

@ -0,0 +1,37 @@
package service
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/containous/traefik/pkg/testhelpers"
)
type staticTransport struct {
res *http.Response
}
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return t.res, nil
}
func BenchmarkProxy(b *testing.B) {
res := &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("")),
}
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
pool := newBufferPool()
handler, _ := buildProxy(false, nil, &staticTransport{res}, pool, nil)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
handler.ServeHTTP(w, req)
}
}

View file

@ -0,0 +1,713 @@
package service
import (
"bufio"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
gorillawebsocket "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
)
func TestWebSocketTCPClose(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
errChan := make(chan error, 1)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
_, _, err := c.ReadMessage()
if err != nil {
errChan <- err
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
proxyAddr := proxy.Listener.Addr().String()
_, conn, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
).open()
require.NoError(t, err)
conn.Close()
serverErr := <-errChan
wsErr, ok := serverErr.(*gorillawebsocket.CloseError)
assert.Equal(t, true, ok)
assert.Equal(t, 1006, wsErr.Code)
}
func TestWebSocketPingPong(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
var upgrader = gorillawebsocket.Upgrader{
HandshakeTimeout: 10 * time.Second,
CheckOrigin: func(*http.Request) bool {
return true
},
}
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
ws, err := upgrader.Upgrade(writer, request, nil)
require.NoError(t, err)
ws.SetPingHandler(func(appData string) error {
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
require.NoError(t, err)
return nil
})
_, _, _ = ws.ReadMessage()
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
defer conn.Close()
goodErr := fmt.Errorf("signal: %s", "Good data")
badErr := fmt.Errorf("signal: %s", "Bad data")
conn.SetPongHandler(func(data string) error {
if data == "PingPong" {
return goodErr
}
return badErr
})
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
require.NoError(t, err)
_, _, err = conn.ReadMessage()
if err != goodErr {
require.NoError(t, err)
}
}
func TestWebSocketEcho(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
msg := make([]byte, 4)
_, err := conn.Read(msg)
require.NoError(t, err)
fmt.Println(string(msg))
_, err = conn.Write(msg)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
fmt.Println(conn.ReadMessage())
err = conn.Close()
require.NoError(t, err)
}
func TestWebSocketPassHost(t *testing.T) {
testCases := []struct {
desc string
passHost bool
expected string
}{
{
desc: "PassHost false",
passHost: false,
},
{
desc: "PassHost true",
passHost: true,
expected: "example.com",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
f, err := buildProxy(test.passHost, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
req := conn.Request()
if test.passHost {
require.Equal(t, test.expected, req.Host)
} else {
require.NotEqual(t, test.expected, req.Host)
}
msg := make([]byte, 4)
_, err = conn.Read(msg)
require.NoError(t, err)
fmt.Println(string(msg))
_, err = conn.Write(msg)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
headers.Add("Host", "example.com")
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
fmt.Println(conn.ReadMessage())
err = conn.Close()
require.NoError(t, err)
})
}
}
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
return true
}}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
withOrigin("http://127.0.0.2"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithOrigin(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
_, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
withOrigin("http://127.0.0.2"),
).send()
require.EqualError(t, err, "bad status")
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithQueryParams(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "test", r.URL.Query().Get("query"))
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws?query=test"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
conn.Close()
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
defer conn.Close()
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
}
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/%3A%2F%2F"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketUpgradeFailed(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(400)
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
if path == "/ws" {
// Set new backend URL
req.URL = parseURI(t, srv.URL)
req.URL.Path = path
f.ServeHTTP(w, req)
} else {
w.WriteHeader(200)
}
}))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
require.NoError(t, err)
defer conn.Close()
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
require.NoError(t, err)
req.Header.Add("upgrade", "websocket")
req.Header.Add("Connection", "upgrade")
err = req.Write(conn)
require.NoError(t, err)
// First request works with 400
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, req)
require.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode)
}
func TestForwardsWebsocketTraffic(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
_, err := conn.Write([]byte("ok"))
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
func TestWebSocketTransferTLSConfig(t *testing.T) {
srv := createTLSWebsocketServer()
defer srv.Close()
forwarderWithoutTLSConfig, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL)
defer proxyWithoutTLSConfig.Close()
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
_, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.EqualError(t, err, "bad status")
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
forwarderWithTLSConfig, err := buildProxy(true, nil, transport, nil, nil)
require.NoError(t, err)
proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
resp, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
const dialTimeout = time.Second
type websocketRequestOpt func(w *websocketRequest)
func withServer(server string) websocketRequestOpt {
return func(w *websocketRequest) {
w.ServerAddr = server
}
}
func withPath(path string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Path = path
}
}
func withData(data string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Data = data
}
}
func withOrigin(origin string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Origin = origin
}
}
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
wsrequest := &websocketRequest{}
for _, opt := range opts {
opt(wsrequest)
}
if wsrequest.Origin == "" {
wsrequest.Origin = "http://" + wsrequest.ServerAddr
}
if wsrequest.Config == nil {
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
}
return wsrequest
}
type websocketRequest struct {
ServerAddr string
Path string
Data string
Origin string
Config *websocket.Config
}
func (w *websocketRequest) send() (string, error) {
conn, _, err := w.open()
if err != nil {
return "", err
}
defer conn.Close()
if _, err := conn.Write([]byte(w.Data)); err != nil {
return "", err
}
var msg = make([]byte, 512)
var n int
n, err = conn.Read(msg)
if err != nil {
return "", err
}
received := string(msg[:n])
return received, nil
}
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
if err != nil {
return nil, nil, err
}
conn, err := websocket.NewClient(w.Config, client)
if err != nil {
return nil, nil, err
}
return conn, client, err
}
func parseURI(t *testing.T, uri string) *url.URL {
out, err := url.ParseRequestURI(uri)
require.NoError(t, err)
return out
}
func createProxyWithForwarder(t *testing.T, proxy http.Handler, url string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = parseURI(t, url)
req.URL.Path = path
proxy.ServeHTTP(w, req)
}))
}

View file

@ -0,0 +1,257 @@
package service
import (
"context"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/containous/alice"
"github.com/containous/traefik/old/middlewares/pipelining"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/healthcheck"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/middlewares/emptybackendhandler"
"github.com/containous/traefik/pkg/server/cookie"
"github.com/containous/traefik/pkg/server/internal"
"github.com/vulcand/oxy/roundrobin"
)
const (
defaultHealthCheckInterval = 30 * time.Second
defaultHealthCheckTimeout = 5 * time.Second
)
// NewManager creates a new Manager
func NewManager(configs map[string]*config.Service, defaultRoundTripper http.RoundTripper) *Manager {
return &Manager{
bufferPool: newBufferPool(),
defaultRoundTripper: defaultRoundTripper,
balancers: make(map[string][]healthcheck.BalancerHandler),
configs: configs,
}
}
// Manager The service manager
type Manager struct {
bufferPool httputil.BufferPool
defaultRoundTripper http.RoundTripper
balancers map[string][]healthcheck.BalancerHandler
configs map[string]*config.Service
}
// BuildHTTP Creates a http.Handler for a service configuration.
func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
serviceName = internal.GetQualifiedName(ctx, serviceName)
ctx = internal.AddProviderInContext(ctx, serviceName)
if conf, ok := m.configs[serviceName]; ok {
// TODO Should handle multiple service types
// FIXME Check if the service is declared multiple times with different types
if conf.LoadBalancer != nil {
return m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier)
}
return nil, fmt.Errorf("the service %q doesn't have any load balancer", serviceName)
}
return nil, fmt.Errorf("the service %q does not exits", serviceName)
}
func (m *Manager) getLoadBalancerServiceHandler(
ctx context.Context,
serviceName string,
service *config.LoadBalancerService,
responseModifier func(*http.Response) error,
) (http.Handler, error) {
fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool, responseModifier)
if err != nil {
return nil, err
}
alHandler := func(next http.Handler) (http.Handler, error) {
return accesslog.NewFieldHandler(next, accesslog.ServiceName, serviceName, accesslog.AddServiceFields), nil
}
handler, err := alice.New().Append(alHandler).Then(pipelining.NewPipelining(fwd))
if err != nil {
return nil, err
}
balancer, err := m.getLoadBalancer(ctx, serviceName, service, handler)
if err != nil {
return nil, err
}
// TODO rename and checks
m.balancers[serviceName] = append(m.balancers[serviceName], balancer)
// Empty (backend with no servers)
return emptybackendhandler.New(balancer), nil
}
// LaunchHealthCheck Launches the health checks.
func (m *Manager) LaunchHealthCheck() {
backendConfigs := make(map[string]*healthcheck.BackendConfig)
for serviceName, balancers := range m.balancers {
ctx := log.With(context.Background(), log.Str(log.ServiceName, serviceName))
// FIXME aggregate
balancer := balancers[0]
// FIXME Should all the services handle healthcheck? Handle different types
service := m.configs[serviceName].LoadBalancer
// Health Check
var backendHealthCheck *healthcheck.BackendConfig
if hcOpts := buildHealthCheckOptions(ctx, balancer, serviceName, service.HealthCheck); hcOpts != nil {
log.FromContext(ctx).Debugf("Setting up healthcheck for service %s with %s", serviceName, *hcOpts)
hcOpts.Transport = m.defaultRoundTripper
backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, serviceName)
}
if backendHealthCheck != nil {
backendConfigs[serviceName] = backendHealthCheck
}
}
// FIXME metrics and context
healthcheck.GetHealthCheck().SetBackendsConfiguration(context.TODO(), backendConfigs)
}
func buildHealthCheckOptions(ctx context.Context, lb healthcheck.BalancerHandler, backend string, hc *config.HealthCheck) *healthcheck.Options {
if hc == nil || hc.Path == "" {
return nil
}
logger := log.FromContext(ctx)
interval := defaultHealthCheckInterval
if hc.Interval != "" {
intervalOverride, err := time.ParseDuration(hc.Interval)
switch {
case err != nil:
logger.Errorf("Illegal health check interval for '%s': %s", backend, err)
case intervalOverride <= 0:
logger.Errorf("Health check interval smaller than zero for service '%s'", backend)
default:
interval = intervalOverride
}
}
timeout := defaultHealthCheckTimeout
if hc.Timeout != "" {
timeoutOverride, err := time.ParseDuration(hc.Timeout)
switch {
case err != nil:
logger.Errorf("Illegal health check timeout for backend '%s': %s", backend, err)
case timeoutOverride <= 0:
logger.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend)
default:
timeout = timeoutOverride
}
}
if timeout >= interval {
logger.Warnf("Health check timeout for backend '%s' should be lower than the health check interval. Interval set to timeout + 1 second (%s).", backend)
}
return &healthcheck.Options{
Scheme: hc.Scheme,
Path: hc.Path,
Port: hc.Port,
Interval: interval,
Timeout: timeout,
LB: lb,
Hostname: hc.Hostname,
Headers: hc.Headers,
}
}
func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, service *config.LoadBalancerService, fwd http.Handler) (healthcheck.BalancerHandler, error) {
logger := log.FromContext(ctx)
var stickySession *roundrobin.StickySession
var cookieName string
if stickiness := service.Stickiness; stickiness != nil {
cookieName = cookie.GetName(stickiness.CookieName, serviceName)
stickySession = roundrobin.NewStickySession(cookieName)
}
var lb healthcheck.BalancerHandler
if service.Method == "drr" {
logger.Debug("Creating drr load-balancer")
rr, err := roundrobin.New(fwd)
if err != nil {
return nil, err
}
if stickySession != nil {
logger.Debugf("Sticky session cookie name: %v", cookieName)
lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
lb, err = roundrobin.NewRebalancer(rr)
if err != nil {
return nil, err
}
}
} else {
if service.Method != "wrr" {
logger.Warnf("Invalid load-balancing method %q, fallback to 'wrr' method", service.Method)
}
logger.Debug("Creating wrr load-balancer")
if stickySession != nil {
logger.Debugf("Sticky session cookie name: %v", cookieName)
var err error
lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
var err error
lb, err = roundrobin.New(fwd)
if err != nil {
return nil, err
}
}
}
if err := m.upsertServers(ctx, lb, service.Servers); err != nil {
return nil, fmt.Errorf("error configuring load balancer for service %s: %v", serviceName, err)
}
return lb, nil
}
func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []config.Server) error {
logger := log.FromContext(ctx)
for name, srv := range servers {
u, err := url.Parse(srv.URL)
if err != nil {
return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err)
}
logger.WithField(log.ServerName, name).Debugf("Creating server %d at %s with weight %d", name, u, srv.Weight)
if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil {
return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err)
}
// FIXME Handle Metrics
}
return nil
}

View file

@ -0,0 +1,329 @@
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/server/internal"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type MockForwarder struct{}
func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) {
panic("implement me")
}
func TestGetLoadBalancer(t *testing.T) {
sm := Manager{}
testCases := []struct {
desc string
serviceName string
service *config.LoadBalancerService
fwd http.Handler
expectError bool
}{
{
desc: "Fails when provided an invalid URL",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: ":",
Weight: 0,
},
},
},
fwd: &MockForwarder{},
expectError: true,
},
{
desc: "Succeeds when there are no servers",
serviceName: "test",
service: &config.LoadBalancerService{},
fwd: &MockForwarder{},
expectError: false,
},
{
desc: "Succeeds when stickiness is set",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
},
fwd: &MockForwarder{},
expectError: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := sm.getLoadBalancer(context.Background(), test.serviceName, test.service, test.fwd)
if test.expectError {
require.Error(t, err)
assert.Nil(t, handler)
} else {
require.NoError(t, err)
assert.NotNil(t, handler)
}
})
}
}
func TestGetLoadBalancerServiceHandler(t *testing.T) {
sm := NewManager(nil, http.DefaultTransport)
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "first")
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "second")
}))
defer server2.Close()
serverPassHost := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "passhost")
assert.Equal(t, "callme", r.Host)
}))
defer serverPassHost.Close()
serverPassHostFalse := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "passhostfalse")
assert.NotEqual(t, "callme", r.Host)
}))
defer serverPassHostFalse.Close()
type ExpectedResult struct {
StatusCode int
XFrom string
}
testCases := []struct {
desc string
serviceName string
service *config.LoadBalancerService
responseModifier func(*http.Response) error
expected []ExpectedResult
}{
{
desc: "Load balances between the two servers",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server1.URL,
Weight: 50,
},
{
URL: server2.URL,
Weight: 50,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "first",
},
{
StatusCode: http.StatusOK,
XFrom: "second",
},
},
},
{
desc: "StatusBadGateway when the server is not reachable",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: "http://foo",
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusBadGateway,
},
},
},
{
desc: "ServiceUnavailable when no servers are available",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusServiceUnavailable,
},
},
},
{
desc: "Always call the same server when stickiness is true",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
Servers: []config.Server{
{
URL: server1.URL,
Weight: 1,
},
{
URL: server2.URL,
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "first",
},
{
StatusCode: http.StatusOK,
XFrom: "first",
},
},
},
{
desc: "PassHost passes the host instead of the IP",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
PassHostHeader: true,
Servers: []config.Server{
{
URL: serverPassHost.URL,
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "passhost",
},
},
},
{
desc: "PassHost doesn't passe the host instead of the IP",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
Servers: []config.Server{
{
URL: serverPassHostFalse.URL,
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "passhostfalse",
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service, test.responseModifier)
assert.NoError(t, err)
assert.NotNil(t, handler)
req := testhelpers.MustNewRequest(http.MethodGet, "http://callme", nil)
for _, expected := range test.expected {
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, expected.StatusCode, recorder.Code)
assert.Equal(t, expected.XFrom, recorder.Header().Get("X-From"))
if len(recorder.Header().Get("Set-Cookie")) > 0 {
req.Header.Set("Cookie", recorder.Header().Get("Set-Cookie"))
}
}
})
}
}
func TestManager_Build(t *testing.T) {
testCases := []struct {
desc string
serviceName string
configs map[string]*config.Service
providerName string
}{
{
desc: "Simple service name",
serviceName: "serviceName",
configs: map[string]*config.Service{
"serviceName": {
LoadBalancer: &config.LoadBalancerService{Method: "wrr"},
},
},
},
{
desc: "Service name with provider",
serviceName: "provider-1.serviceName",
configs: map[string]*config.Service{
"provider-1.serviceName": {
LoadBalancer: &config.LoadBalancerService{Method: "wrr"},
},
},
},
{
desc: "Service name with provider in context",
serviceName: "serviceName",
configs: map[string]*config.Service{
"provider-1.serviceName": {
LoadBalancer: &config.LoadBalancerService{Method: "wrr"},
},
},
providerName: "provider-1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
manager := NewManager(test.configs, http.DefaultTransport)
ctx := context.Background()
if len(test.providerName) > 0 {
ctx = internal.AddProviderInContext(ctx, test.providerName+".foobar")
}
_, err := manager.BuildHTTP(ctx, test.serviceName, nil)
require.NoError(t, err)
})
}
}
// FIXME Add healthcheck tests

View file

@ -0,0 +1,67 @@
package tcp
import (
"context"
"fmt"
"net"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/server/internal"
"github.com/containous/traefik/pkg/tcp"
)
// Manager is the TCPHandlers factory
type Manager struct {
configs map[string]*config.TCPService
}
// NewManager creates a new manager
func NewManager(configs map[string]*config.TCPService) *Manager {
return &Manager{
configs: configs,
}
}
// BuildTCP Creates a tcp.Handler for a service configuration.
func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Handler, error) {
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
serviceName = internal.GetQualifiedName(ctx, serviceName)
ctx = internal.AddProviderInContext(ctx, serviceName)
if conf, ok := m.configs[serviceName]; ok {
// FIXME Check if the service is declared multiple times with different types
if conf.LoadBalancer != nil {
loadBalancer := tcp.NewRRLoadBalancer()
var handler tcp.Handler
for _, server := range conf.LoadBalancer.Servers {
_, err := parseIP(server.Address)
if err == nil {
handler, _ = tcp.NewProxy(server.Address)
loadBalancer.AddServer(handler)
} else {
log.FromContext(ctx).Errorf("Invalid IP address for a %s server %s: %v", serviceName, server.Address, err)
}
}
return loadBalancer, nil
}
return nil, fmt.Errorf("the service %q doesn't have any TCP load balancer", serviceName)
}
return nil, fmt.Errorf("the service %q does not exits", serviceName)
}
func parseIP(s string) (string, error) {
ip, _, err := net.SplitHostPort(s)
if err == nil {
return ip, nil
}
ipNoPort := net.ParseIP(s)
if ipNoPort == nil {
return "", fmt.Errorf("invalid IP Address %s", ipNoPort)
}
return ipNoPort.String(), nil
}

14
pkg/server/uuid/uuid.go Normal file
View file

@ -0,0 +1,14 @@
package uuid
import guuid "github.com/satori/go.uuid"
var uuid string
func init() {
uuid = guuid.NewV4().String()
}
// Get the instance UUID
func Get() string {
return uuid
}