Dynamic Configuration Refactoring

This commit is contained in:
Ludovic Fernandez 2018-11-14 10:18:03 +01:00 committed by Traefiker Bot
parent d3ae88f108
commit a09dfa3ce1
452 changed files with 21023 additions and 9419 deletions

View file

@ -0,0 +1,292 @@
package middleware
import (
"context"
"fmt"
"net/http"
"github.com/containous/alice"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares/addprefix"
"github.com/containous/traefik/middlewares/auth"
"github.com/containous/traefik/middlewares/buffering"
"github.com/containous/traefik/middlewares/chain"
"github.com/containous/traefik/middlewares/circuitbreaker"
"github.com/containous/traefik/middlewares/compress"
"github.com/containous/traefik/middlewares/customerrors"
"github.com/containous/traefik/middlewares/headers"
"github.com/containous/traefik/middlewares/ipwhitelist"
"github.com/containous/traefik/middlewares/maxconnection"
"github.com/containous/traefik/middlewares/passtlsclientcert"
"github.com/containous/traefik/middlewares/ratelimiter"
"github.com/containous/traefik/middlewares/redirect"
"github.com/containous/traefik/middlewares/replacepath"
"github.com/containous/traefik/middlewares/replacepathregex"
"github.com/containous/traefik/middlewares/retry"
"github.com/containous/traefik/middlewares/stripprefix"
"github.com/containous/traefik/middlewares/stripprefixregex"
"github.com/containous/traefik/middlewares/tracing"
"github.com/pkg/errors"
)
// Builder the middleware builder
type Builder struct {
configs map[string]*config.Middleware
serviceBuilder serviceBuilder
}
type serviceBuilder interface {
Build(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
}
// NewBuilder creates a new Builder
func NewBuilder(configs map[string]*config.Middleware, serviceBuilder serviceBuilder) *Builder {
return &Builder{configs: configs, serviceBuilder: serviceBuilder}
}
// BuildChain creates a middleware chain
func (b *Builder) BuildChain(ctx context.Context, middlewares []string) (*alice.Chain, error) {
chain := alice.New()
for _, middlewareName := range middlewares {
if _, ok := b.configs[middlewareName]; !ok {
return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
}
constructor, err := b.buildConstructor(ctx, middlewareName, *b.configs[middlewareName])
if err != nil {
return nil, err
}
if constructor != nil {
chain = chain.Append(constructor)
}
}
return &chain, nil
}
func (b *Builder) buildConstructor(ctx context.Context, middlewareName string, config config.Middleware) (alice.Constructor, error) {
var middleware alice.Constructor
badConf := errors.New("cannot create middleware %q: multi-types middleware not supported, consider declaring two different pieces of middleware instead")
// AddPrefix
if config.AddPrefix != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return addprefix.New(ctx, next, *config.AddPrefix, middlewareName)
}
} else {
return nil, badConf
}
}
// BasicAuth
if config.BasicAuth != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return auth.NewBasic(ctx, next, *config.BasicAuth, middlewareName)
}
} else {
return nil, badConf
}
}
// Buffering
if config.Buffering != nil && config.MaxConn.Amount != 0 {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return buffering.New(ctx, next, *config.Buffering, middlewareName)
}
} else {
return nil, badConf
}
}
// Chain
if config.Chain != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return chain.New(ctx, next, *config.Chain, b, middlewareName)
}
} else {
return nil, badConf
}
}
// CircuitBreaker
if config.CircuitBreaker != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return circuitbreaker.New(ctx, next, *config.CircuitBreaker, middlewareName)
}
} else {
return nil, badConf
}
}
// Compress
if config.Compress != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return compress.New(ctx, next, middlewareName)
}
} else {
return nil, badConf
}
}
// CustomErrors
if config.Errors != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return customerrors.New(ctx, next, *config.Errors, b.serviceBuilder, middlewareName)
}
} else {
return nil, badConf
}
}
// DigestAuth
if config.DigestAuth != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return auth.NewDigest(ctx, next, *config.DigestAuth, middlewareName)
}
} else {
return nil, badConf
}
}
// ForwardAuth
if config.ForwardAuth != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return auth.NewForward(ctx, next, *config.ForwardAuth, middlewareName)
}
} else {
return nil, badConf
}
}
// Headers
if config.Headers != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return headers.New(ctx, next, *config.Headers, middlewareName)
}
} else {
return nil, badConf
}
}
// IPWhiteList
if config.IPWhiteList != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return ipwhitelist.New(ctx, next, *config.IPWhiteList, middlewareName)
}
} else {
return nil, badConf
}
}
// MaxConn
if config.MaxConn != nil && config.MaxConn.Amount != 0 {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return maxconnection.New(ctx, next, *config.MaxConn, middlewareName)
}
} else {
return nil, badConf
}
}
// PassTLSClientCert
if config.PassTLSClientCert != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return passtlsclientcert.New(ctx, next, *config.PassTLSClientCert, middlewareName)
}
} else {
return nil, badConf
}
}
// RateLimit
if config.RateLimit != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return ratelimiter.New(ctx, next, *config.RateLimit, middlewareName)
}
} else {
return nil, badConf
}
}
// Redirect
if config.Redirect != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return redirect.New(ctx, next, *config.Redirect, middlewareName)
}
} else {
return nil, badConf
}
}
// ReplacePath
if config.ReplacePath != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return replacepath.New(ctx, next, *config.ReplacePath, middlewareName)
}
} else {
return nil, badConf
}
}
// ReplacePathRegex
if config.ReplacePathRegex != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return replacepathregex.New(ctx, next, *config.ReplacePathRegex, middlewareName)
}
} else {
return nil, badConf
}
}
// Retry
if config.Retry != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
// FIXME missing metrics / accessLog
return retry.New(ctx, next, *config.Retry, retry.Listeners{}, middlewareName)
}
} else {
return nil, badConf
}
}
// StripPrefix
if config.StripPrefix != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return stripprefix.New(ctx, next, *config.StripPrefix, middlewareName)
}
} else {
return nil, badConf
}
}
// StripPrefixRegex
if config.StripPrefixRegex != nil {
if middleware == nil {
middleware = func(next http.Handler) (http.Handler, error) {
return stripprefixregex.New(ctx, next, *config.StripPrefixRegex, middlewareName)
}
} else {
return nil, badConf
}
}
return tracing.Wrap(ctx, middleware), nil
}

View file

@ -0,0 +1,127 @@
package middleware
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/config"
"github.com/stretchr/testify/require"
)
func TestMiddlewaresRegistry_BuildMiddlewareCircuitBreaker(t *testing.T) {
testConfig := map[string]*config.Middleware{
"empty": {
CircuitBreaker: &config.CircuitBreaker{
Expression: "",
},
},
"foo": {
CircuitBreaker: &config.CircuitBreaker{
Expression: "NetworkErrorRatio() > 0.5",
},
},
}
middlewaresBuilder := NewBuilder(testConfig, nil)
emptyHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
testCases := []struct {
desc string
middlewareID string
expectedError bool
}{
{
desc: "Should fail at creating a circuit breaker with an empty expression",
expectedError: true,
middlewareID: "empty",
}, {
desc: "Should create a circuit breaker with a valid expression",
expectedError: false,
middlewareID: "foo",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
constructor, err := middlewaresBuilder.buildConstructor(context.Background(), test.middlewareID, *testConfig[test.middlewareID])
require.NoError(t, err)
middleware, err2 := constructor(emptyHandler)
if test.expectedError {
require.Error(t, err2)
} else {
require.NoError(t, err)
require.NotNil(t, middleware)
}
})
}
}
func TestMiddlewaresRegistry_BuildChainNilConfig(t *testing.T) {
testConfig := map[string]*config.Middleware{
"empty": {},
}
middlewaresBuilder := NewBuilder(testConfig, nil)
chain, err := middlewaresBuilder.BuildChain(context.Background(), []string{"empty"})
require.NoError(t, err)
_, err = chain.Then(nil)
require.NoError(t, err)
}
func TestMiddlewaresRegistry_BuildMiddlewareAddPrefix(t *testing.T) {
testConfig := map[string]*config.Middleware{
"empty": {
AddPrefix: &config.AddPrefix{
Prefix: "",
},
},
"foo": {
AddPrefix: &config.AddPrefix{
Prefix: "foo/",
},
},
}
middlewaresBuilder := NewBuilder(testConfig, nil)
testCases := []struct {
desc string
middlewareID string
expectedError bool
}{
{
desc: "Should not create an emty AddPrefix middleware when given an empty prefix",
middlewareID: "empty",
expectedError: true,
}, {
desc: "Should create an AddPrefix middleware when given a valid configuration",
middlewareID: "foo",
expectedError: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
constructor, err := middlewaresBuilder.buildConstructor(context.Background(), test.middlewareID, *testConfig[test.middlewareID])
require.NoError(t, err)
middleware, err2 := constructor(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}))
if test.expectedError {
require.Error(t, err2)
} else {
require.NoError(t, err)
require.NotNil(t, middleware)
}
})
}
}

94
server/roundtripper.go Normal file
View file

@ -0,0 +1,94 @@
package server
import (
"crypto/tls"
"crypto/x509"
"net"
"net/http"
"time"
"github.com/containous/traefik/log"
"github.com/containous/traefik/old/configuration"
traefiktls "github.com/containous/traefik/tls"
"golang.org/x/net/http2"
)
type h2cTransportWrapper struct {
*http2.Transport
}
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "http"
return t.Transport.RoundTrip(req)
}
// createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings.
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost
// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing
// behavior and backwards compatibility issues.
func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) (*http.Transport, error) {
dialer := &net.Dialer{
Timeout: configuration.DefaultDialTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
if globalConfiguration.ForwardingTimeouts != nil {
dialer.Timeout = time.Duration(globalConfiguration.ForwardingTimeouts.DialTimeout)
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
MaxIdleConnsPerHost: globalConfiguration.MaxIdleConnsPerHost,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
transport.RegisterProtocol("h2c", &h2cTransportWrapper{
Transport: &http2.Transport{
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
return net.Dial(netw, addr)
},
AllowHTTP: true,
},
})
if globalConfiguration.ForwardingTimeouts != nil {
transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
}
if globalConfiguration.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
if len(globalConfiguration.RootCAs) > 0 {
transport.TLSClientConfig = &tls.Config{
RootCAs: createRootCACertPool(globalConfiguration.RootCAs),
}
}
err := http2.ConfigureTransport(transport)
if err != nil {
return nil, err
}
return transport, nil
}
func createRootCACertPool(rootCAs traefiktls.FilesOrContents) *x509.CertPool {
roots := x509.NewCertPool()
for _, cert := range rootCAs {
certContent, err := cert.Read()
if err != nil {
log.WithoutContext().Error("Error while read RootCAs", err)
continue
}
roots.AppendCertsFromPEM(certContent)
}
return roots
}

View file

@ -0,0 +1,119 @@
package router
import (
"context"
"github.com/containous/alice"
"github.com/containous/mux"
"github.com/containous/traefik/api"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/log"
"github.com/containous/traefik/metrics"
"github.com/containous/traefik/safe"
"github.com/containous/traefik/types"
)
// chainBuilder The contract of the middleware builder
type chainBuilder interface {
BuildChain(ctx context.Context, middlewares []string) (*alice.Chain, error)
}
// NewRouteAppenderAggregator Creates a new RouteAppenderAggregator
func NewRouteAppenderAggregator(ctx context.Context, chainBuilder chainBuilder, conf static.Configuration, entryPointName string, currentConfiguration *safe.Safe) *RouteAppenderAggregator {
logger := log.FromContext(ctx)
aggregator := &RouteAppenderAggregator{}
// FIXME add REST
if conf.API != nil && conf.API.EntryPoint == entryPointName {
chain, err := chainBuilder.BuildChain(ctx, conf.API.Middlewares)
if err != nil {
logger.Error(err)
} else {
aggregator.AddAppender(&WithMiddleware{
appender: api.Handler{
EntryPoint: conf.API.EntryPoint,
Dashboard: conf.API.Dashboard,
Statistics: conf.API.Statistics,
DashboardAssets: conf.API.DashboardAssets,
CurrentConfigurations: currentConfiguration,
},
routerMiddlewares: chain,
})
}
}
if conf.Ping != nil && conf.Ping.EntryPoint == entryPointName {
chain, err := chainBuilder.BuildChain(ctx, conf.Ping.Middlewares)
if err != nil {
logger.Error(err)
} else {
aggregator.AddAppender(&WithMiddleware{
appender: conf.Ping,
routerMiddlewares: chain,
})
}
}
if conf.Metrics != nil && conf.Metrics.Prometheus != nil && conf.Metrics.Prometheus.EntryPoint == entryPointName {
chain, err := chainBuilder.BuildChain(ctx, conf.Metrics.Prometheus.Middlewares)
if err != nil {
logger.Error(err)
} else {
aggregator.AddAppender(&WithMiddleware{
appender: metrics.PrometheusHandler{},
routerMiddlewares: chain,
})
}
}
return aggregator
}
// RouteAppenderAggregator RouteAppender that aggregate other RouteAppender
type RouteAppenderAggregator struct {
appenders []types.RouteAppender
}
// Append Adds routes to the router
func (r *RouteAppenderAggregator) Append(systemRouter *mux.Router) {
for _, router := range r.appenders {
router.Append(systemRouter)
}
}
// AddAppender adds a router in the aggregator
func (r *RouteAppenderAggregator) AddAppender(router types.RouteAppender) {
r.appenders = append(r.appenders, router)
}
// WithMiddleware router with internal middleware
type WithMiddleware struct {
appender types.RouteAppender
routerMiddlewares *alice.Chain
}
// Append Adds routes to the router
func (wm *WithMiddleware) Append(systemRouter *mux.Router) {
realRouter := systemRouter.PathPrefix("/").Subrouter()
wm.appender.Append(realRouter)
if err := realRouter.Walk(wrapRoute(wm.routerMiddlewares)); err != nil {
log.WithoutContext().Error(err)
}
}
// wrapRoute with middlewares
func wrapRoute(middlewares *alice.Chain) func(*mux.Route, *mux.Router, []*mux.Route) error {
return func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
handler, err := middlewares.Then(route.GetHandler())
if err != nil {
return err
}
route.Handler(handler)
return nil
}
}

View file

@ -0,0 +1,116 @@
package router
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/alice"
"github.com/containous/mux"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/ping"
"github.com/stretchr/testify/assert"
)
type ChainBuilderMock struct {
middles map[string]alice.Constructor
}
func (c *ChainBuilderMock) BuildChain(ctx context.Context, middles []string) (*alice.Chain, error) {
chain := alice.New()
for _, mName := range middles {
if constructor, ok := c.middles[mName]; ok {
chain = chain.Append(constructor)
}
}
return &chain, nil
}
func TestNewRouteAppenderAggregator(t *testing.T) {
testCases := []struct {
desc string
staticConf static.Configuration
middles map[string]alice.Constructor
expected map[string]int
}{
{
desc: "API with auth, ping without auth",
staticConf: static.Configuration{
API: &static.API{
EntryPoint: "traefik",
Middlewares: []string{"dumb"},
},
Ping: &ping.Handler{
EntryPoint: "traefik",
},
EntryPoints: &static.EntryPoints{
EntryPointList: map[string]static.EntryPoint{
"traefik": {},
},
},
},
middles: map[string]alice.Constructor{
"dumb": func(_ http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}), nil
},
},
expected: map[string]int{
"/wrong": http.StatusBadGateway,
"/ping": http.StatusOK,
//"/.well-known/acme-challenge/token": http.StatusNotFound, // FIXME
"/api/providers": http.StatusUnauthorized,
},
},
{
desc: "Wrong entrypoint name",
staticConf: static.Configuration{
API: &static.API{
EntryPoint: "no",
},
EntryPoints: &static.EntryPoints{
EntryPointList: map[string]static.EntryPoint{
"traefik": {},
},
},
},
expected: map[string]int{
"/api/providers": http.StatusBadGateway,
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
chainBuilder := &ChainBuilderMock{middles: test.middles}
ctx := context.Background()
router := NewRouteAppenderAggregator(ctx, chainBuilder, test.staticConf, "traefik", nil)
internalMuxRouter := mux.NewRouter()
router.Append(internalMuxRouter)
internalMuxRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
})
actual := make(map[string]int)
for calledURL := range test.expected {
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, calledURL, nil)
internalMuxRouter.ServeHTTP(recorder, request)
actual[calledURL] = recorder.Code
}
assert.Equal(t, test.expected, actual)
})
}
}

View file

@ -0,0 +1,38 @@
package router
import (
"context"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/provider/acme"
"github.com/containous/traefik/safe"
"github.com/containous/traefik/server/middleware"
"github.com/containous/traefik/types"
)
// NewRouteAppenderFactory Creates a new RouteAppenderFactory
func NewRouteAppenderFactory(staticConfiguration static.Configuration, entryPointName string, acmeProvider *acme.Provider) *RouteAppenderFactory {
return &RouteAppenderFactory{
staticConfiguration: staticConfiguration,
entryPointName: entryPointName,
acmeProvider: acmeProvider,
}
}
// RouteAppenderFactory A factory of RouteAppender
type RouteAppenderFactory struct {
staticConfiguration static.Configuration
entryPointName string
acmeProvider *acme.Provider
}
// NewAppender Creates a new RouteAppender
func (r *RouteAppenderFactory) NewAppender(ctx context.Context, middlewaresBuilder *middleware.Builder, currentConfiguration *safe.Safe) types.RouteAppender {
aggregator := NewRouteAppenderAggregator(ctx, middlewaresBuilder, r.staticConfiguration, r.entryPointName, currentConfiguration)
if r.acmeProvider != nil && r.acmeProvider.HTTPChallenge != nil && r.acmeProvider.HTTPChallenge.EntryPoint == r.entryPointName {
aggregator.AddAppender(r.acmeProvider)
}
return aggregator
}

183
server/router/router.go Normal file
View file

@ -0,0 +1,183 @@
package router
import (
"context"
"fmt"
"net/http"
"github.com/containous/alice"
"github.com/containous/mux"
"github.com/containous/traefik/config"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/middlewares/recovery"
"github.com/containous/traefik/middlewares/tracing"
"github.com/containous/traefik/responsemodifiers"
"github.com/containous/traefik/server/middleware"
"github.com/containous/traefik/server/service"
)
const (
recoveryMiddlewareName = "traefik-internal-recovery"
)
// NewManager Creates a new Manager
func NewManager(routers map[string]*config.Router,
serviceManager *service.Manager, middlewaresBuilder *middleware.Builder, modifierBuilder *responsemodifiers.Builder,
) *Manager {
return &Manager{
routerHandlers: make(map[string]http.Handler),
configs: routers,
serviceManager: serviceManager,
middlewaresBuilder: middlewaresBuilder,
modifierBuilder: modifierBuilder,
}
}
// Manager A route/router manager
type Manager struct {
routerHandlers map[string]http.Handler
configs map[string]*config.Router
serviceManager *service.Manager
middlewaresBuilder *middleware.Builder
modifierBuilder *responsemodifiers.Builder
}
// BuildHandlers Builds handler for all entry points
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string, defaultEntryPoints []string) map[string]http.Handler {
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints, defaultEntryPoints)
entryPointHandlers := make(map[string]http.Handler)
for entryPointName, routers := range entryPointsRouters {
ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName))
handler, err := m.buildEntryPointHandler(ctx, routers)
if err != nil {
log.FromContext(ctx).Error(err)
continue
}
entryPointHandlers[entryPointName] = handler
}
m.serviceManager.LaunchHealthCheck()
return entryPointHandlers
}
func contains(entryPoints []string, entryPointName string) bool {
for _, name := range entryPoints {
if name == entryPointName {
return true
}
}
return false
}
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string, defaultEntryPoints []string) map[string]map[string]*config.Router {
entryPointsRouters := make(map[string]map[string]*config.Router)
for rtName, rt := range m.configs {
eps := rt.EntryPoints
if len(eps) == 0 {
eps = defaultEntryPoints
}
for _, entryPointName := range eps {
if !contains(entryPoints, entryPointName) {
log.FromContext(log.With(ctx, log.Str(log.EntryPointName, entryPointName))).
Errorf("entryPoint %q doesn't exist", entryPointName)
continue
}
if _, ok := entryPointsRouters[entryPointName]; !ok {
entryPointsRouters[entryPointName] = make(map[string]*config.Router)
}
entryPointsRouters[entryPointName][rtName] = rt
}
}
return entryPointsRouters
}
func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*config.Router) (http.Handler, error) {
router := mux.NewRouter().
SkipClean(true)
for routerName, routerConfig := range configs {
ctx = log.With(ctx, log.Str(log.RouterName, routerName))
logger := log.FromContext(ctx)
handler, err := m.buildRouterHandler(ctx, routerName)
if err != nil {
logger.Error(err)
continue
}
err = addRoute(ctx, router, routerConfig.Rule, routerConfig.Priority, handler)
if err != nil {
logger.Error(err)
continue
}
}
router.SortRoutes()
chain := alice.New()
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
return recovery.New(ctx, next, recoveryMiddlewareName)
})
return chain.Then(router)
}
func (m *Manager) buildRouterHandler(ctx context.Context, routerName string) (http.Handler, error) {
if handler, ok := m.routerHandlers[routerName]; ok {
return handler, nil
}
configRouter, ok := m.configs[routerName]
if !ok {
return nil, fmt.Errorf("no configuration for %s", routerName)
}
handler, err := m.buildHandler(ctx, configRouter, routerName)
if err != nil {
return nil, err
}
handlerWithAccessLog, err := alice.New(func(next http.Handler) (http.Handler, error) {
return accesslog.NewFieldHandler(next, accesslog.RouterName, routerName, nil), nil
}).Then(handler)
if err != nil {
log.FromContext(ctx).Error(err)
m.routerHandlers[routerName] = handler
} else {
m.routerHandlers[routerName] = handlerWithAccessLog
}
return m.routerHandlers[routerName], nil
}
func (m *Manager) buildHandler(ctx context.Context, router *config.Router, routerName string) (http.Handler, error) {
rm := m.modifierBuilder.Build(ctx, router.Middlewares)
sHandler, err := m.serviceManager.Build(ctx, router.Service, rm)
if err != nil {
return nil, err
}
mHandler, err := m.middlewaresBuilder.BuildChain(ctx, router.Middlewares)
if err != nil {
return nil, err
}
alHandler := func(next http.Handler) (http.Handler, error) {
return accesslog.NewFieldHandler(next, accesslog.ServiceName, router.Service, accesslog.AddServiceFields), nil
}
tHandler := func(next http.Handler) (http.Handler, error) {
return tracing.NewForwarder(ctx, routerName, router.Service, next), nil
}
return alice.New().Append(alHandler).Extend(*mHandler).Append(tHandler).Then(sHandler)
}

View file

@ -0,0 +1,334 @@
package router
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/middlewares/requestdecorator"
"github.com/containous/traefik/responsemodifiers"
"github.com/containous/traefik/server/middleware"
"github.com/containous/traefik/server/service"
"github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRouterManager_Get(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
type ExpectedResult struct {
StatusCode int
RequestHeaders map[string]string
}
testCases := []struct {
desc string
routersConfig map[string]*config.Router
serviceConfig map[string]*config.Service
middlewaresConfig map[string]*config.Middleware
entryPoints []string
defaultEntryPoints []string
expected ExpectedResult
}{
{
desc: "no middleware",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host:foo.bar",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusOK},
},
{
desc: "no middleware, default entry point",
routersConfig: map[string]*config.Router{
"foo": {
Service: "foo-service",
Rule: "Host:foo.bar",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
defaultEntryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusOK},
},
{
desc: "no middleware, no matching",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host:bar.bar",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusNotFound},
},
{
desc: "middleware: headers > auth",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Middlewares: []string{"headers-middle", "auth-middle"},
Service: "foo-service",
Rule: "Host:foo.bar",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
middlewaresConfig: map[string]*config.Middleware{
"auth-middle": {
BasicAuth: &config.BasicAuth{
Users: []string{"toto:titi"},
},
},
"headers-middle": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{
StatusCode: http.StatusUnauthorized,
RequestHeaders: map[string]string{
"X-Apero": "beer",
},
},
},
{
desc: "middleware: auth > header",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Middlewares: []string{"auth-middle", "headers-middle"},
Service: "foo-service",
Rule: "Host:foo.bar",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
middlewaresConfig: map[string]*config.Middleware{
"auth-middle": {
BasicAuth: &config.BasicAuth{
Users: []string{"toto:titi"},
},
},
"headers-middle": {
Headers: &config.Headers{
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
},
},
},
entryPoints: []string{"web"},
expected: ExpectedResult{
StatusCode: http.StatusUnauthorized,
RequestHeaders: map[string]string{
"X-Apero": "",
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
serviceManager := service.NewManager(test.serviceConfig, http.DefaultTransport)
middlewaresBuilder := middleware.NewBuilder(test.middlewaresConfig, serviceManager)
responseModifierFactory := responsemodifiers.NewBuilder(test.middlewaresConfig)
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, test.defaultEntryPoints)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
reqHost := requestdecorator.New(nil)
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
assert.Equal(t, test.expected.StatusCode, w.Code)
for key, value := range test.expected.RequestHeaders {
assert.Equal(t, value, req.Header.Get(key))
}
})
}
}
func TestAccessLog(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
testCases := []struct {
desc string
routersConfig map[string]*config.Router
serviceConfig map[string]*config.Service
middlewaresConfig map[string]*config.Middleware
entryPoints []string
defaultEntryPoints []string
expected string
}{
{
desc: "apply routerName in accesslog (first match)",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host:foo.bar",
},
"bar": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host:bar.foo",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: "foo",
},
{
desc: "apply routerName in accesslog (second match)",
routersConfig: map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host:bar.foo",
},
"bar": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host:foo.bar",
},
},
serviceConfig: map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
},
entryPoints: []string{"web"},
expected: "bar",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
serviceManager := service.NewManager(test.serviceConfig, http.DefaultTransport)
middlewaresBuilder := middleware.NewBuilder(test.middlewaresConfig, serviceManager)
responseModifierFactory := responsemodifiers.NewBuilder(test.middlewaresConfig)
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, test.defaultEntryPoints)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
accesslogger, err := accesslog.NewHandler(&types.AccessLog{
Format: "json",
})
require.NoError(t, err)
reqHost := requestdecorator.New(nil)
accesslogger.ServeHTTP(w, req, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
data := accesslog.GetLogData(req)
require.NotNil(t, data)
assert.Equal(t, test.expected, data.Core[accesslog.RouterName])
}))
})
}
}

163
server/router/rules.go Normal file
View file

@ -0,0 +1,163 @@
package router
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/containous/mux"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares/requestdecorator"
)
func addRoute(ctx context.Context, router *mux.Router, rule string, priority int, handler http.Handler) error {
matchers, err := parseRule(rule)
if err != nil {
return err
}
if priority == 0 {
priority = len(rule)
}
route := router.NewRoute().Handler(handler).Priority(priority)
for _, matcher := range matchers {
matcher(route)
if route.GetError() != nil {
log.FromContext(ctx).Error(route.GetError())
}
}
return nil
}
func parseRule(rule string) ([]func(*mux.Route), error) {
funcs := map[string]func(*mux.Route, ...string){
"Host": host,
"HostRegexp": hostRegexp,
"Path": path,
"PathPrefix": pathPrefix,
"Method": methods,
"Headers": headers,
"HeadersRegexp": headersRegexp,
"Query": query,
}
splitRule := func(c rune) bool {
return c == ';'
}
parsedRules := strings.FieldsFunc(rule, splitRule)
var matchers []func(*mux.Route)
for _, expression := range parsedRules {
expParts := strings.Split(expression, ":")
if len(expParts) > 1 && len(expParts[1]) > 0 {
if fn, ok := funcs[expParts[0]]; ok {
parseOr := func(c rune) bool {
return c == ','
}
exp := strings.FieldsFunc(strings.Join(expParts[1:], ":"), parseOr)
var trimmedExp []string
for _, value := range exp {
trimmedExp = append(trimmedExp, strings.TrimSpace(value))
}
// FIXME struct for onhostrule ?
matcher := func(rt *mux.Route) {
fn(rt, trimmedExp...)
}
matchers = append(matchers, matcher)
} else {
return nil, fmt.Errorf("invalid matcher: %s", expression)
}
}
}
return matchers, nil
}
func path(route *mux.Route, paths ...string) {
rt := route.Subrouter()
for _, path := range paths {
tmpRt := rt.Path(path)
if tmpRt.GetError() != nil {
log.WithoutContext().WithField("paths", strings.Join(paths, ",")).Error(tmpRt.GetError())
}
}
}
func pathPrefix(route *mux.Route, paths ...string) {
rt := route.Subrouter()
for _, path := range paths {
tmpRt := rt.PathPrefix(path)
if tmpRt.GetError() != nil {
log.WithoutContext().WithField("paths", strings.Join(paths, ",")).Error(tmpRt.GetError())
}
}
}
func host(route *mux.Route, hosts ...string) {
for i, host := range hosts {
hosts[i] = strings.ToLower(host)
}
route.MatcherFunc(func(req *http.Request, route *mux.RouteMatch) bool {
reqHost := requestdecorator.GetCanonizedHost(req.Context())
if len(reqHost) == 0 {
log.FromContext(req.Context()).Warnf("Could not retrieve CanonizedHost, rejecting %s", req.Host)
return false
}
flatH := requestdecorator.GetCNAMEFlatten(req.Context())
if len(flatH) > 0 {
for _, host := range hosts {
if strings.EqualFold(reqHost, host) || strings.EqualFold(flatH, host) {
return true
}
log.FromContext(req.Context()).Debugf("CNAMEFlattening: request %s which resolved to %s, is not matched to route %s", reqHost, flatH, host)
}
return false
}
for _, host := range hosts {
if reqHost == host {
return true
}
}
return false
})
}
func hostRegexp(route *mux.Route, hosts ...string) {
router := route.Subrouter()
for _, host := range hosts {
router.Host(host)
}
}
func methods(route *mux.Route, methods ...string) {
route.Methods(methods...)
}
func headers(route *mux.Route, headers ...string) {
route.Headers(headers...)
}
func headersRegexp(route *mux.Route, headers ...string) {
route.HeadersRegexp(headers...)
}
func query(route *mux.Route, query ...string) {
var queries []string
for _, elem := range query {
queries = append(queries, strings.Split(elem, "=")...)
}
route.Queries(queries...)
}

442
server/router/rules_test.go Normal file
View file

@ -0,0 +1,442 @@
package router
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/mux"
"github.com/containous/traefik/middlewares/requestdecorator"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_addRoute(t *testing.T) {
testCases := []struct {
desc string
rule string
headers map[string]string
expected map[string]int
}{
{
desc: "no rule",
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "PathPrefix",
rule: "PathPrefix:/foo",
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "wrong PathPrefix",
rule: "PathPrefix:/bar",
expected: map[string]int{
"http://localhost/foo": http.StatusNotFound,
},
},
{
desc: "Host",
rule: "Host:localhost",
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "wrong Host",
rule: "Host:nope",
expected: map[string]int{
"http://localhost/foo": http.StatusNotFound,
},
},
{
desc: "Host and PathPrefix",
rule: "Host:localhost;PathPrefix:/foo",
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "Host and PathPrefix: wrong PathPrefix",
rule: "Host:localhost;PathPrefix:/bar",
expected: map[string]int{
"http://localhost/foo": http.StatusNotFound,
},
},
{
desc: "Host and PathPrefix: wrong Host",
rule: "Host:nope;PathPrefix:/bar",
expected: map[string]int{
"http://localhost/foo": http.StatusNotFound,
},
},
{
desc: "Host and PathPrefix: Host OR, first host",
rule: "Host:nope,localhost;PathPrefix:/foo",
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "Host and PathPrefix: Host OR, second host",
rule: "Host:nope,localhost;PathPrefix:/foo",
expected: map[string]int{
"http://nope/foo": http.StatusOK,
},
},
{
desc: "Host and PathPrefix: Host OR, first host and wrong PathPrefix",
rule: "Host:nope,localhost;PathPrefix:/bar",
expected: map[string]int{
"http://localhost/foo": http.StatusNotFound,
},
},
{
desc: "HostRegexp with capturing group",
rule: "HostRegexp: {subdomain:(foo\\.)?bar\\.com}",
expected: map[string]int{
"http://foo.bar.com": http.StatusOK,
"http://bar.com": http.StatusOK,
"http://fooubar.com": http.StatusNotFound,
"http://barucom": http.StatusNotFound,
"http://barcom": http.StatusNotFound,
},
},
{
desc: "HostRegexp with non capturing group",
rule: "HostRegexp: {subdomain:(?:foo\\.)?bar\\.com}",
expected: map[string]int{
"http://foo.bar.com": http.StatusOK,
"http://bar.com": http.StatusOK,
"http://fooubar.com": http.StatusNotFound,
"http://barucom": http.StatusNotFound,
"http://barcom": http.StatusNotFound,
},
},
{
desc: "Methods with GET",
rule: "Method: GET",
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "Methods with GET and POST",
rule: "Method: GET,POST",
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "Methods with POST",
rule: "Method: POST",
expected: map[string]int{
"http://localhost/foo": http.StatusMethodNotAllowed,
},
},
{
desc: "Header with matching header",
rule: "Headers: Content-Type,application/json",
headers: map[string]string{
"Content-Type": "application/json",
},
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "Header without matching header",
rule: "Headers: Content-Type,application/foo",
headers: map[string]string{
"Content-Type": "application/json",
},
expected: map[string]int{
"http://localhost/foo": http.StatusNotFound,
},
},
{
desc: "HeaderRegExp with matching header",
rule: "HeadersRegexp: Content-Type, application/(text|json)",
headers: map[string]string{
"Content-Type": "application/json",
},
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "HeaderRegExp without matching header",
rule: "HeadersRegexp: Content-Type, application/(text|json)",
headers: map[string]string{
"Content-Type": "application/foo",
},
expected: map[string]int{
"http://localhost/foo": http.StatusNotFound,
},
},
{
desc: "HeaderRegExp with matching second header",
rule: "HeadersRegexp: Content-Type, application/(text|json)",
headers: map[string]string{
"Content-Type": "application/text",
},
expected: map[string]int{
"http://localhost/foo": http.StatusOK,
},
},
{
desc: "Query with multiple params",
rule: "Query: foo=bar, bar=baz",
expected: map[string]int{
"http://localhost/foo?foo=bar&bar=baz": http.StatusOK,
"http://localhost/foo?bar=baz": http.StatusNotFound,
},
},
{
desc: "Invalid rule syntax",
rule: "Query:param_one=true, /path2;Path: /path1",
expected: map[string]int{
"http://localhost/foo?bar=baz": http.StatusNotFound,
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
router := mux.NewRouter()
router.SkipClean(true)
err := addRoute(context.Background(), router, test.rule, 0, handler)
require.NoError(t, err)
// RequestDecorator is necessary for the host rule
reqHost := requestdecorator.New(nil)
results := make(map[string]int)
for calledURL := range test.expected {
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, calledURL, nil)
for key, value := range test.headers {
req.Header.Set(key, value)
}
reqHost.ServeHTTP(w, req, router.ServeHTTP)
results[calledURL] = w.Code
}
assert.Equal(t, test.expected, results)
})
}
}
func Test_addRoutePriority(t *testing.T) {
type Case struct {
xFrom string
rule string
priority int
}
testCases := []struct {
desc string
path string
cases []Case
expected string
}{
{
desc: "Higher priority on second rule",
path: "/my",
cases: []Case{
{
xFrom: "header1",
rule: "PathPrefix:/my",
priority: 10,
},
{
xFrom: "header2",
rule: "PathPrefix:/my",
priority: 20,
},
},
expected: "header2",
},
{
desc: "Higher priority on first rule",
path: "/my",
cases: []Case{
{
xFrom: "header1",
rule: "PathPrefix:/my",
priority: 20,
},
{
xFrom: "header2",
rule: "PathPrefix:/my",
priority: 10,
},
},
expected: "header1",
},
{
desc: "Higher priority on second rule with different rule",
path: "/mypath",
cases: []Case{
{
xFrom: "header1",
rule: "PathPrefix:/mypath",
priority: 10,
},
{
xFrom: "header2",
rule: "PathPrefix:/my",
priority: 20,
},
},
expected: "header2",
},
{
desc: "Higher priority on longest rule (longest first)",
path: "/mypath",
cases: []Case{
{
xFrom: "header1",
rule: "PathPrefix:/mypath",
},
{
xFrom: "header2",
rule: "PathPrefix:/my",
},
},
expected: "header1",
},
{
desc: "Higher priority on longest rule (longest second)",
path: "/mypath",
cases: []Case{
{
xFrom: "header1",
rule: "PathPrefix:/my",
},
{
xFrom: "header2",
rule: "PathPrefix:/mypath",
},
},
expected: "header2",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
router := mux.NewRouter()
for _, route := range test.cases {
route := route
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", route.xFrom)
})
err := addRoute(context.Background(), router, route.rule, route.priority, handler)
require.NoError(t, err, route)
}
router.SortRoutes()
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, test.path, nil)
router.ServeHTTP(w, req)
assert.Equal(t, test.expected, w.Header().Get("X-From"))
})
}
}
func TestHostRegexp(t *testing.T) {
testCases := []struct {
desc string
hostExp string
urls map[string]bool
}{
{
desc: "capturing group",
hostExp: "{subdomain:(foo\\.)?bar\\.com}",
urls: map[string]bool{
"http://foo.bar.com": true,
"http://bar.com": true,
"http://fooubar.com": false,
"http://barucom": false,
"http://barcom": false,
},
},
{
desc: "non capturing group",
hostExp: "{subdomain:(?:foo\\.)?bar\\.com}",
urls: map[string]bool{
"http://foo.bar.com": true,
"http://bar.com": true,
"http://fooubar.com": false,
"http://barucom": false,
"http://barcom": false,
},
},
{
desc: "regex insensitive",
hostExp: "{dummy:[A-Za-z-]+\\.bar\\.com}",
urls: map[string]bool{
"http://FOO.bar.com": true,
"http://foo.bar.com": true,
"http://fooubar.com": false,
"http://barucom": false,
"http://barcom": false,
},
},
{
desc: "insensitive host",
hostExp: "{dummy:[a-z-]+\\.bar\\.com}",
urls: map[string]bool{
"http://FOO.bar.com": true,
"http://foo.bar.com": true,
"http://fooubar.com": false,
"http://barucom": false,
"http://barcom": false,
},
},
{
desc: "insensitive host simple",
hostExp: "foo.bar.com",
urls: map[string]bool{
"http://FOO.bar.com": true,
"http://foo.bar.com": true,
"http://fooubar.com": false,
"http://barucom": false,
"http://barcom": false,
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rt := &mux.Route{}
hostRegexp(rt, test.hostExp)
for testURL, match := range test.urls {
req := testhelpers.MustNewRequest(http.MethodGet, testURL, nil)
assert.Equal(t, match, rt.Match(req, &mux.RouteMatch{}), testURL)
}
})
}
}

View file

@ -9,35 +9,36 @@ import (
stdlog "log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
"sync"
"time"
"github.com/armon/go-proxyproto"
"github.com/containous/mux"
"github.com/containous/traefik/cluster"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/config"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/h2c"
"github.com/containous/traefik/ip"
"github.com/containous/traefik/log"
"github.com/containous/traefik/metrics"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/middlewares/tracing"
"github.com/containous/traefik/middlewares/requestdecorator"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/provider"
"github.com/containous/traefik/safe"
"github.com/containous/traefik/server/middleware"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/tracing"
"github.com/containous/traefik/tracing/datadog"
"github.com/containous/traefik/tracing/jaeger"
"github.com/containous/traefik/tracing/zipkin"
"github.com/containous/traefik/types"
"github.com/sirupsen/logrus"
"github.com/urfave/negroni"
"github.com/xenolf/lego/acme"
)
var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0)
func newHijackConnectionTracker() *hijackConnectionTracker {
return &hijackConnectionTracker{
conns: make(map[net.Conn]struct{}),
@ -85,7 +86,7 @@ func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
func (h *hijackConnectionTracker) Close() {
for conn := range h.conns {
if err := conn.Close(); err != nil {
log.Errorf("Error while closing Hijacked conn: %v", err)
log.WithoutContext().Errorf("Error while closing Hijacked connection: %v", err)
}
delete(h.conns, conn)
}
@ -93,33 +94,38 @@ func (h *hijackConnectionTracker) Close() {
// Server is the reverse-proxy/load-balancer engine
type Server struct {
serverEntryPoints serverEntryPoints
configurationChan chan types.ConfigMessage
configurationValidatedChan chan types.ConfigMessage
signals chan os.Signal
stopChan chan bool
currentConfigurations safe.Safe
providerConfigUpdateMap map[string]chan types.ConfigMessage
globalConfiguration configuration.GlobalConfiguration
accessLoggerMiddleware *accesslog.LogHandler
tracingMiddleware *tracing.Tracing
routinesPool *safe.Pool
leadership *cluster.Leadership
defaultForwardingRoundTripper http.RoundTripper
metricsRegistry metrics.Registry
provider provider.Provider
configurationListeners []func(types.Configuration)
entryPoints map[string]EntryPoint
bufferPool httputil.BufferPool
serverEntryPoints serverEntryPoints
configurationChan chan config.Message
configurationValidatedChan chan config.Message
signals chan os.Signal
stopChan chan bool
currentConfigurations safe.Safe
providerConfigUpdateMap map[string]chan config.Message
globalConfiguration configuration.GlobalConfiguration
accessLoggerMiddleware *accesslog.Handler
tracer *tracing.Tracing
routinesPool *safe.Pool
leadership *cluster.Leadership
defaultRoundTripper http.RoundTripper
metricsRegistry metrics.Registry
provider provider.Provider
configurationListeners []func(config.Configuration)
entryPoints map[string]EntryPoint
requestDecorator *requestdecorator.RequestDecorator
}
// RouteAppenderFactory the route appender factory interface
type RouteAppenderFactory interface {
NewAppender(ctx context.Context, middlewaresBuilder *middleware.Builder, currentConfigurations *safe.Safe) types.RouteAppender
}
// EntryPoint entryPoint information (configuration + internalRouter)
type EntryPoint struct {
InternalRouter types.InternalRouter
Configuration *configuration.EntryPoint
OnDemandListener func(string) (*tls.Certificate, error)
TLSALPNGetter func(string) (*tls.Certificate, error)
CertificateStore *traefiktls.CertificateStore
RouteAppenderFactory RouteAppenderFactory
Configuration *configuration.EntryPoint
OnDemandListener func(string) (*tls.Certificate, error)
TLSALPNGetter func(string) (*tls.Certificate, error)
CertificateStore *traefiktls.CertificateStore
}
type serverEntryPoints map[string]*serverEntryPoint
@ -142,10 +148,11 @@ func (s serverEntryPoint) Shutdown(ctx context.Context) {
defer wg.Done()
if err := s.httpServer.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait server shutdown is over due to: %s", err)
logger := log.FromContext(ctx)
logger.Debugf("Wait server shutdown is over due to: %s", err)
err = s.httpServer.Close()
if err != nil {
log.Error(err)
logger.Error(err)
}
}
}
@ -158,7 +165,8 @@ func (s serverEntryPoint) Shutdown(ctx context.Context) {
defer wg.Done()
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait hijack connection is over due to: %s", err)
logger := log.FromContext(ctx)
logger.Debugf("Wait hijack connection is over due to: %s", err)
s.hijackConnectionTracker.Close()
}
}
@ -179,11 +187,32 @@ func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
if err != nil {
return nil, err
}
tc.SetKeepAlive(true)
tc.SetKeepAlivePeriod(3 * time.Minute)
if err = tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
return nil, err
}
return tc, nil
}
func setupTracing(conf *static.Tracing) tracing.TrackingBackend {
switch conf.Backend {
case jaeger.Name:
return conf.Jaeger
case zipkin.Name:
return conf.Zipkin
case datadog.Name:
return conf.DataDog
default:
log.WithoutContext().Warnf("Could not initialize tracing: unknown tracer %q", conf.Backend)
return nil
}
}
// NewServer returns an initialized Server.
func NewServer(globalConfiguration configuration.GlobalConfiguration, provider provider.Provider, entrypoints map[string]EntryPoint) *Server {
server := &Server{}
@ -192,36 +221,41 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p
server.provider = provider
server.globalConfiguration = globalConfiguration
server.serverEntryPoints = make(map[string]*serverEntryPoint)
server.configurationChan = make(chan types.ConfigMessage, 100)
server.configurationValidatedChan = make(chan types.ConfigMessage, 100)
server.configurationChan = make(chan config.Message, 100)
server.configurationValidatedChan = make(chan config.Message, 100)
server.signals = make(chan os.Signal, 1)
server.stopChan = make(chan bool, 1)
server.configureSignals()
currentConfigurations := make(types.Configurations)
currentConfigurations := make(config.Configurations)
server.currentConfigurations.Set(currentConfigurations)
server.providerConfigUpdateMap = make(map[string]chan types.ConfigMessage)
server.providerConfigUpdateMap = make(map[string]chan config.Message)
transport, err := createHTTPTransport(globalConfiguration)
if err != nil {
log.WithoutContext().Error(err)
server.defaultRoundTripper = http.DefaultTransport
} else {
server.defaultRoundTripper = transport
}
if server.globalConfiguration.API != nil {
server.globalConfiguration.API.CurrentConfigurations = &server.currentConfigurations
}
server.bufferPool = newBufferPool()
server.routinesPool = safe.NewPool(context.Background())
transport, err := createHTTPTransport(globalConfiguration)
if err != nil {
log.Errorf("failed to create HTTP transport: %v", err)
if globalConfiguration.Tracing != nil {
trackingBackend := setupTracing(static.ConvertTracing(globalConfiguration.Tracing))
var err error
server.tracer, err = tracing.NewTracing(globalConfiguration.Tracing.ServiceName, globalConfiguration.Tracing.SpanNameLimit, trackingBackend)
if err != nil {
log.WithoutContext().Warnf("Unable to create tracer: %v", err)
}
}
server.defaultForwardingRoundTripper = transport
server.requestDecorator = requestdecorator.New(static.ConvertHostResolverConfig(globalConfiguration.HostResolver))
server.tracingMiddleware = globalConfiguration.Tracing
if server.tracingMiddleware != nil && server.tracingMiddleware.Backend != "" {
server.tracingMiddleware.Setup()
}
server.metricsRegistry = registerMetricClients(globalConfiguration.Metrics)
server.metricsRegistry = registerMetricClients(static.ConvertMetrics(globalConfiguration.Metrics))
if globalConfiguration.Cluster != nil {
// leadership creation if cluster mode
@ -230,9 +264,9 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p
if globalConfiguration.AccessLog != nil {
var err error
server.accessLoggerMiddleware, err = accesslog.NewLogHandler(globalConfiguration.AccessLog)
server.accessLoggerMiddleware, err = accesslog.NewHandler(static.ConvertAccessLog(globalConfiguration.AccessLog))
if err != nil {
log.Warnf("Unable to create log handler: %s", err)
log.WithoutContext().Warnf("Unable to create access logger : %v", err)
}
}
return server
@ -259,13 +293,16 @@ func (s *Server) StartWithContext(ctx context.Context) {
go func() {
defer s.Close()
<-ctx.Done()
log.Info("I have to go...")
logger := log.FromContext(ctx)
logger.Info("I have to go...")
reqAcceptGraceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.RequestAcceptGraceTimeout)
if reqAcceptGraceTimeOut > 0 {
log.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
time.Sleep(reqAcceptGraceTimeOut)
}
log.Info("Stopping server gracefully")
logger.Info("Stopping server gracefully")
s.Stop()
}()
s.Start()
@ -278,18 +315,23 @@ func (s *Server) Wait() {
// Stop stops the server
func (s *Server) Stop() {
defer log.Info("Server stopped")
defer log.WithoutContext().Info("Server stopped")
var wg sync.WaitGroup
for sepn, sep := range s.serverEntryPoints {
wg.Add(1)
go func(serverEntryPointName string, serverEntryPoint *serverEntryPoint) {
defer wg.Done()
logger := log.WithoutContext().WithField(log.EntryPointName, serverEntryPointName)
graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
logger.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
serverEntryPoint.Shutdown(ctx)
cancel()
log.Debugf("Entrypoint %s closed", serverEntryPointName)
logger.Debugf("Entry point %s closed", serverEntryPointName)
}(sepn, sep)
}
wg.Wait()
@ -307,6 +349,7 @@ func (s *Server) Close() {
panic("Timeout while stopping traefik, killing instance ✝")
}
}(ctx)
stopMetricsClients()
s.stopLeadership()
s.routinesPool.Cleanup()
@ -315,11 +358,17 @@ func (s *Server) Close() {
signal.Stop(s.signals)
close(s.signals)
close(s.stopChan)
if s.accessLoggerMiddleware != nil {
if err := s.accessLoggerMiddleware.Close(); err != nil {
log.Errorf("Error closing access log file: %s", err)
log.WithoutContext().Errorf("Could not close the access log file: %s", err)
}
}
if s.tracer != nil {
s.tracer.Close()
}
cancel()
}
@ -339,8 +388,9 @@ func (s *Server) startHTTPServers() {
s.serverEntryPoints = s.buildServerEntryPoints()
for newServerEntryPointName, newServerEntryPoint := range s.serverEntryPoints {
serverEntryPoint := s.setupServerEntryPoint(newServerEntryPointName, newServerEntryPoint)
go s.startServer(serverEntryPoint)
ctx := log.With(context.Background(), log.Str(log.EntryPointName, newServerEntryPointName))
serverEntryPoint := s.setupServerEntryPoint(ctx, newServerEntryPointName, newServerEntryPoint)
go s.startServer(ctx, serverEntryPoint)
}
}
@ -359,9 +409,9 @@ func (s *Server) listenProviders(stop chan bool) {
}
// AddListener adds a new listener function used when new configuration is provided
func (s *Server) AddListener(listener func(types.Configuration)) {
func (s *Server) AddListener(listener func(config.Configuration)) {
if s.configurationListeners == nil {
s.configurationListeners = make([]func(types.Configuration), 0)
s.configurationListeners = make([]func(config.Configuration), 0)
}
s.configurationListeners = append(s.configurationListeners, listener)
}
@ -395,22 +445,23 @@ func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tl
return nil, fmt.Errorf("strict SNI enabled - No certificate found for domain: %q, closing connection", domainToCheck)
}
log.Debugf("Serving default cert for request: %q", domainToCheck)
log.WithoutContext().Debugf("Serving default certificate for request: %q", domainToCheck)
return s.certs.DefaultCertificate, nil
}
func (s *Server) startProvider() {
// start providers
jsonConf, err := json.Marshal(s.provider)
if err != nil {
log.Debugf("Unable to marshal provider conf %T with error: %v", s.provider, err)
log.WithoutContext().Debugf("Unable to marshal provider configuration %T: %v", s.provider, err)
}
log.Infof("Starting provider %T %s", s.provider, jsonConf)
log.WithoutContext().Infof("Starting provider %T %s", s.provider, jsonConf)
currentProvider := s.provider
safe.Go(func() {
err := currentProvider.Provide(s.configurationChan, s.routinesPool)
if err != nil {
log.Errorf("Error starting provider %T: %s", s.provider, err)
log.WithoutContext().Errorf("Error starting provider %T: %s", s.provider, err)
}
})
}
@ -421,7 +472,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
return nil, nil
}
config, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
conf, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
if err != nil {
return nil, err
}
@ -429,7 +480,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
s.serverEntryPoints[entryPointName].certs.DynamicCerts.Set(make(map[string]*tls.Certificate))
// ensure http2 enabled
config.NextProtos = []string{"h2", "http/1.1", acme.ACMETLS1Protocol}
conf.NextProtos = []string{"h2", "http/1.1", acme.ACMETLS1Protocol}
if len(tlsOption.ClientCA.Files) > 0 {
pool := x509.NewCertPool()
@ -443,55 +494,59 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
}
}
config.ClientCAs = pool
conf.ClientCAs = pool
if tlsOption.ClientCA.Optional {
config.ClientAuth = tls.VerifyClientCertIfGiven
conf.ClientAuth = tls.VerifyClientCertIfGiven
} else {
config.ClientAuth = tls.RequireAndVerifyClientCert
conf.ClientAuth = tls.RequireAndVerifyClientCert
}
}
if s.globalConfiguration.ACME != nil && entryPointName == s.globalConfiguration.ACME.EntryPoint {
checkOnDemandDomain := func(domain string) bool {
routeMatch := &mux.RouteMatch{}
match := router.GetHandler().Match(&http.Request{URL: &url.URL{}, Host: domain}, routeMatch)
if match && routeMatch.Route != nil {
return true
}
return false
}
err := s.globalConfiguration.ACME.CreateClusterConfig(s.leadership, config, s.serverEntryPoints[entryPointName].certs.DynamicCerts, checkOnDemandDomain)
if err != nil {
return nil, err
}
// FIXME onDemand
if s.globalConfiguration.ACME != nil {
// if entryPointName == s.globalConfiguration.ACME.EntryPoint {
// checkOnDemandDomain := func(domain string) bool {
// routeMatch := &mux.RouteMatch{}
// match := router.GetHandler().Match(&http.Request{URL: &url.URL{}, Host: domain}, routeMatch)
// if match && routeMatch.Route != nil {
// return true
// }
// return false
// }
//
// err := s.globalConfiguration.ACME.CreateClusterConfig(s.leadership, config, s.serverEntryPoints[entryPointName].certs.DynamicCerts, checkOnDemandDomain)
// if err != nil {
// return nil, err
// }
// }
} else {
config.GetCertificate = s.serverEntryPoints[entryPointName].getCertificate
if len(config.Certificates) != 0 {
certMap := s.buildNameOrIPToCertificate(config.Certificates)
if s.entryPoints[entryPointName].CertificateStore != nil {
s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(certMap)
}
}
// Remove certs from the TLS config object
config.Certificates = []tls.Certificate{}
conf.GetCertificate = s.serverEntryPoints[entryPointName].getCertificate
}
if len(conf.Certificates) != 0 {
certMap := s.buildNameOrIPToCertificate(conf.Certificates)
if s.entryPoints[entryPointName].CertificateStore != nil {
s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(certMap)
}
}
// Remove certs from the TLS config object
conf.Certificates = []tls.Certificate{}
// Set the minimum TLS version if set in the config TOML
if minConst, exists := traefiktls.MinVersion[s.entryPoints[entryPointName].Configuration.TLS.MinVersion]; exists {
config.PreferServerCipherSuites = true
config.MinVersion = minConst
if minConst, exists := traefiktls.MinVersion[tlsOption.MinVersion]; exists {
conf.PreferServerCipherSuites = true
conf.MinVersion = minConst
}
// Set the list of CipherSuites if set in the config TOML
if s.entryPoints[entryPointName].Configuration.TLS.CipherSuites != nil {
// if our list of CipherSuites is defined in the entrypoint config, we can re-initilize the suites list as empty
config.CipherSuites = make([]uint16, 0)
for _, cipher := range s.entryPoints[entryPointName].Configuration.TLS.CipherSuites {
if tlsOption.CipherSuites != nil {
// if our list of CipherSuites is defined in the entryPoint config, we can re-initialize the suites list as empty
conf.CipherSuites = make([]uint16, 0)
for _, cipher := range tlsOption.CipherSuites {
if cipherConst, exists := traefiktls.CipherSuites[cipher]; exists {
config.CipherSuites = append(config.CipherSuites, cipherConst)
conf.CipherSuites = append(conf.CipherSuites, cipherConst)
} else {
// CipherSuite listed in the toml does not exist in our listed
return nil, fmt.Errorf("invalid CipherSuite: %s", cipher)
@ -499,11 +554,12 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
}
}
return config, nil
return conf, nil
}
func (s *Server) startServer(serverEntryPoint *serverEntryPoint) {
log.Infof("Starting server on %s", serverEntryPoint.httpServer.Addr)
func (s *Server) startServer(ctx context.Context, serverEntryPoint *serverEntryPoint) {
logger := log.FromContext(ctx)
logger.Infof("Starting server on %s", serverEntryPoint.httpServer.Addr)
var err error
if serverEntryPoint.httpServer.TLSConfig != nil {
@ -513,19 +569,14 @@ func (s *Server) startServer(serverEntryPoint *serverEntryPoint) {
}
if err != http.ErrServerClosed {
log.Error("Error creating server: ", err)
logger.Error("Cannot create server: %v", err)
}
}
func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServerEntryPoint *serverEntryPoint) *serverEntryPoint {
serverMiddlewares, err := s.buildServerEntryPointMiddlewares(newServerEntryPointName)
func (s *Server) setupServerEntryPoint(ctx context.Context, newServerEntryPointName string, newServerEntryPoint *serverEntryPoint) *serverEntryPoint {
newSrv, listener, err := s.prepareServer(ctx, newServerEntryPointName, s.entryPoints[newServerEntryPointName].Configuration, newServerEntryPoint.httpRouter)
if err != nil {
log.Fatal("Error preparing server: ", err)
}
newSrv, listener, err := s.prepareServer(newServerEntryPointName, s.entryPoints[newServerEntryPointName].Configuration, newServerEntryPoint.httpRouter, serverMiddlewares)
if err != nil {
log.Fatal("Error preparing server: ", err)
log.FromContext(ctx).Fatalf("Error preparing server: %v", err)
}
serverEntryPoint := s.serverEntryPoints[newServerEntryPointName]
@ -545,19 +596,15 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer
return serverEntryPoint
}
func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher, middlewares []negroni.Handler) (*h2c.Server, net.Listener, error) {
func (s *Server) prepareServer(ctx context.Context, entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher) (*h2c.Server, net.Listener, error) {
logger := log.FromContext(ctx)
readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(s.globalConfiguration)
log.Infof("Preparing server %s %+v with readTimeout=%s writeTimeout=%s idleTimeout=%s", entryPointName, entryPoint, readTimeout, writeTimeout, idleTimeout)
// middlewares
n := negroni.New()
for _, middleware := range middlewares {
n.Use(middleware)
}
n.UseHandler(router)
internalMuxRouter := s.buildInternalRouter(entryPointName)
internalMuxRouter.NotFoundHandler = n
logger.
WithField("readTimeout", readTimeout).
WithField("writeTimeout", writeTimeout).
WithField("idleTimeout", idleTimeout).
Infof("Preparing server %+v", entryPoint)
tlsConfig, err := s.createTLSConfig(entryPointName, entryPoint.TLS, router)
if err != nil {
@ -572,16 +619,18 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
if entryPoint.ProxyProtocol != nil {
listener, err = buildProxyProtocolListener(entryPoint, listener)
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
if err != nil {
return nil, nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
}
}
httpServerLogger := stdlog.New(logger.WriterLevel(logrus.DebugLevel), "", 0)
return &h2c.Server{
Server: &http.Server{
Addr: entryPoint.Address,
Handler: internalMuxRouter,
Handler: router,
TLSConfig: tlsConfig,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
@ -593,7 +642,7 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.
nil
}
func buildProxyProtocolListener(entryPoint *configuration.EntryPoint, listener net.Listener) (net.Listener, error) {
func buildProxyProtocolListener(ctx context.Context, entryPoint *configuration.EntryPoint, listener net.Listener) (net.Listener, error) {
var sourceCheck func(addr net.Addr) (bool, error)
if entryPoint.ProxyProtocol.Insecure {
sourceCheck = func(_ net.Addr) (bool, error) {
@ -615,7 +664,7 @@ func buildProxyProtocolListener(entryPoint *configuration.EntryPoint, listener n
}
}
log.Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
return &proxyproto.Listener{
Listener: listener,
@ -623,23 +672,6 @@ func buildProxyProtocolListener(entryPoint *configuration.EntryPoint, listener n
}, nil
}
func (s *Server) buildInternalRouter(entryPointName string) *mux.Router {
internalMuxRouter := mux.NewRouter()
internalMuxRouter.StrictSlash(!s.globalConfiguration.KeepTrailingSlash)
internalMuxRouter.SkipClean(true)
if entryPoint, ok := s.entryPoints[entryPointName]; ok && entryPoint.InternalRouter != nil {
entryPoint.InternalRouter.AddRoutes(internalMuxRouter)
if s.globalConfiguration.API != nil && s.globalConfiguration.API.EntryPoint == entryPointName && s.leadership != nil {
s.leadership.AddRoutes(internalMuxRouter)
}
}
return internalMuxRouter
}
func buildServerTimeouts(globalConfig configuration.GlobalConfiguration) (readTimeout, writeTimeout, idleTimeout time.Duration) {
readTimeout = time.Duration(0)
writeTimeout = time.Duration(0)
@ -663,24 +695,35 @@ func registerMetricClients(metricsConfig *types.Metrics) metrics.Registry {
}
var registries []metrics.Registry
if metricsConfig.Prometheus != nil {
prometheusRegister := metrics.RegisterPrometheus(metricsConfig.Prometheus)
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "prometheus"))
prometheusRegister := metrics.RegisterPrometheus(ctx, metricsConfig.Prometheus)
if prometheusRegister != nil {
registries = append(registries, prometheusRegister)
log.Debug("Configured Prometheus metrics")
log.FromContext(ctx).Debug("Configured Prometheus metrics")
}
}
if metricsConfig.Datadog != nil {
registries = append(registries, metrics.RegisterDatadog(metricsConfig.Datadog))
log.Debugf("Configured DataDog metrics pushing to %s once every %s", metricsConfig.Datadog.Address, metricsConfig.Datadog.PushInterval)
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "datadog"))
registries = append(registries, metrics.RegisterDatadog(ctx, metricsConfig.Datadog))
log.FromContext(ctx).Debugf("Configured DataDog metrics: pushing to %s once every %s",
metricsConfig.Datadog.Address, metricsConfig.Datadog.PushInterval)
}
if metricsConfig.StatsD != nil {
registries = append(registries, metrics.RegisterStatsd(metricsConfig.StatsD))
log.Debugf("Configured StatsD metrics pushing to %s once every %s", metricsConfig.StatsD.Address, metricsConfig.StatsD.PushInterval)
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "statsd"))
registries = append(registries, metrics.RegisterStatsd(ctx, metricsConfig.StatsD))
log.FromContext(ctx).Debugf("Configured StatsD metrics: pushing to %s once every %s",
metricsConfig.StatsD.Address, metricsConfig.StatsD.PushInterval)
}
if metricsConfig.InfluxDB != nil {
registries = append(registries, metrics.RegisterInfluxDB(metricsConfig.InfluxDB))
log.Debugf("Configured InfluxDB metrics pushing to %s once every %s", metricsConfig.InfluxDB.Address, metricsConfig.InfluxDB.PushInterval)
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "influxdb"))
registries = append(registries, metrics.RegisterInfluxDB(ctx, metricsConfig.InfluxDB))
log.FromContext(ctx).Debugf("Configured InfluxDB metrics: pushing to %s once every %s",
metricsConfig.InfluxDB.Address, metricsConfig.InfluxDB.PushInterval)
}
return metrics.NewMultiRegistry(registries)

View file

@ -1,41 +1,42 @@
package server
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net"
"net/http"
"reflect"
"sort"
"strings"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/alice"
"github.com/containous/mux"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/healthcheck"
"github.com/containous/traefik/hostresolver"
"github.com/containous/traefik/config"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/log"
"github.com/containous/traefik/metrics"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/pipelining"
"github.com/containous/traefik/rules"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/middlewares/requestdecorator"
"github.com/containous/traefik/middlewares/tracing"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/responsemodifiers"
"github.com/containous/traefik/server/middleware"
"github.com/containous/traefik/server/router"
"github.com/containous/traefik/server/service"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/tls/generate"
"github.com/containous/traefik/types"
"github.com/eapache/channels"
"github.com/sirupsen/logrus"
"github.com/urfave/negroni"
"github.com/vulcand/oxy/forward"
)
// loadConfiguration manages dynamically frontends, backends and TLS configurations
func (s *Server) loadConfiguration(configMsg types.ConfigMessage) {
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
func (s *Server) loadConfiguration(configMsg config.Message) {
logger := log.FromContext(log.With(context.Background(), log.Str(log.ProviderName, configMsg.ProviderName)))
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
// Copy configurations to new map so we don't change current if LoadConfig fails
newConfigurations := make(types.Configurations)
newConfigurations := make(config.Configurations)
for k, v := range currentConfigurations {
newConfigurations[k] = v
}
@ -43,22 +44,25 @@ func (s *Server) loadConfiguration(configMsg types.ConfigMessage) {
s.metricsRegistry.ConfigReloadsCounter().Add(1)
newServerEntryPoints := s.loadConfig(newConfigurations, s.globalConfiguration)
handlers, certificates := s.loadConfig(newConfigurations, s.globalConfiguration)
s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
for newServerEntryPointName, newServerEntryPoint := range newServerEntryPoints {
s.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler())
for entryPointName, handler := range handlers {
s.serverEntryPoints[entryPointName].httpRouter.UpdateHandler(handler)
}
if s.entryPoints[newServerEntryPointName].Configuration.TLS == nil {
if newServerEntryPoint.certs.ContainsCertificates() {
log.Debugf("Certificates not added to non-TLS entryPoint %s.", newServerEntryPointName)
for entryPointName, serverEntryPoint := range s.serverEntryPoints {
eLogger := logger.WithField(log.EntryPointName, entryPointName)
if s.entryPoints[entryPointName].Configuration.TLS == nil {
if len(certificates[entryPointName]) > 0 {
eLogger.Debugf("Cannot configure certificates for the non-TLS %s entryPoint.", entryPointName)
}
} else {
s.serverEntryPoints[newServerEntryPointName].certs.DynamicCerts.Set(newServerEntryPoint.certs.DynamicCerts.Get())
s.serverEntryPoints[newServerEntryPointName].certs.ResetCache()
serverEntryPoint.certs.DynamicCerts.Set(certificates[entryPointName])
serverEntryPoint.certs.ResetCache()
}
log.Infof("Server configuration reloaded on %s", s.serverEntryPoints[newServerEntryPointName].httpServer.Addr)
eLogger.Infof("Server configuration reloaded on %s", s.serverEntryPoints[entryPointName].httpServer.Addr)
}
s.currentConfigurations.Set(newConfigurations)
@ -72,251 +76,127 @@ func (s *Server) loadConfiguration(configMsg types.ConfigMessage) {
// loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic
// provider configurations.
func (s *Server) loadConfig(configurations types.Configurations, globalConfiguration configuration.GlobalConfiguration) map[string]*serverEntryPoint {
func (s *Server) loadConfig(configurations config.Configurations, globalConfiguration configuration.GlobalConfiguration) (map[string]http.Handler, map[string]map[string]*tls.Certificate) {
serverEntryPoints := s.buildServerEntryPoints()
ctx := context.TODO()
backendsHandlers := map[string]http.Handler{}
backendsHealthCheck := map[string]*healthcheck.BackendConfig{}
var postConfigs []handlerPostConfig
for providerName, config := range configurations {
frontendNames := sortedFrontendNamesForConfig(config)
for _, frontendName := range frontendNames {
frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config,
serverEntryPoints,
backendsHandlers, backendsHealthCheck)
if err != nil {
log.Errorf("%v. Skipping frontend %s...", err, frontendName)
}
if len(frontendPostConfigs) > 0 {
postConfigs = append(postConfigs, frontendPostConfigs...)
}
// FIXME manage duplicates
conf := config.Configuration{
Routers: make(map[string]*config.Router),
Middlewares: make(map[string]*config.Middleware),
Services: make(map[string]*config.Service),
}
for _, config := range configurations {
for key, value := range config.Middlewares {
conf.Middlewares[key] = value
}
for key, value := range config.Services {
conf.Services[key] = value
}
for key, value := range config.Routers {
conf.Routers[key] = value
}
conf.TLS = append(conf.TLS, config.TLS...)
}
for _, postConfig := range postConfigs {
err := postConfig(backendsHandlers)
if err != nil {
log.Errorf("middleware post configuration error: %v", err)
}
}
handlers := s.applyConfiguration(ctx, conf)
healthcheck.GetHealthCheck(s.metricsRegistry).SetBackendsConfiguration(s.routinesPool.Ctx(), backendsHealthCheck)
// Get new certificates list sorted per entrypoints
// Get new certificates list sorted per entry points
// Update certificates
entryPointsCertificates := s.loadHTTPSConfiguration(configurations, globalConfiguration.DefaultEntryPoints)
// Sort routes and update certificates
for serverEntryPointName, serverEntryPoint := range serverEntryPoints {
serverEntryPoint.httpRouter.GetHandler().SortRoutes()
if _, exists := entryPointsCertificates[serverEntryPointName]; exists {
serverEntryPoint.certs.DynamicCerts.Set(entryPointsCertificates[serverEntryPointName])
}
}
return serverEntryPoints
return handlers, entryPointsCertificates
}
func (s *Server) loadFrontendConfig(
providerName string, frontendName string, config *types.Configuration,
serverEntryPoints map[string]*serverEntryPoint,
backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig,
) ([]handlerPostConfig, error) {
func (s *Server) applyConfiguration(ctx context.Context, configuration config.Configuration) map[string]http.Handler {
staticConfiguration := static.ConvertStaticConf(s.globalConfiguration)
frontend := config.Frontends[frontendName]
hostResolver := buildHostResolver(s.globalConfiguration)
if len(frontend.EntryPoints) == 0 {
return nil, fmt.Errorf("no entrypoint defined for frontend %s", frontendName)
var entryPoints []string
for entryPointName := range s.entryPoints {
entryPoints = append(entryPoints, entryPointName)
}
backend := config.Backends[frontend.Backend]
if backend == nil {
return nil, fmt.Errorf("undefined backend '%s' for frontend %s", frontend.Backend, frontendName)
}
serviceManager := service.NewManager(configuration.Services, s.defaultRoundTripper)
middlewaresBuilder := middleware.NewBuilder(configuration.Middlewares, serviceManager)
responseModifierFactory := responsemodifiers.NewBuilder(configuration.Middlewares)
frontendHash, err := frontend.Hash()
if err != nil {
return nil, fmt.Errorf("error calculating hash value for frontend %s: %v", frontendName, err)
}
routerManager := router.NewManager(configuration.Routers, serviceManager, middlewaresBuilder, responseModifierFactory)
var postConfigs []handlerPostConfig
handlers := routerManager.BuildHandlers(ctx, entryPoints, staticConfiguration.EntryPoints.Defaults)
for _, entryPointName := range frontend.EntryPoints {
log.Debugf("Wiring frontend %s to entryPoint %s", frontendName, entryPointName)
routerHandlers := make(map[string]http.Handler)
entryPoint := s.entryPoints[entryPointName].Configuration
for _, entryPointName := range entryPoints {
internalMuxRouter := mux.NewRouter().
SkipClean(true)
if backendsHandlers[entryPointName+providerName+frontendHash] == nil {
log.Debugf("Creating backend %s", frontend.Backend)
ctx = log.With(ctx, log.Str(log.EntryPointName, entryPointName))
handlers, responseModifier, postConfig, err := s.buildMiddlewares(frontendName, frontend, config.Backends, entryPointName, providerName)
if err != nil {
return nil, err
}
factory := s.entryPoints[entryPointName].RouteAppenderFactory
if factory != nil {
// FIXME remove currentConfigurations
appender := factory.NewAppender(ctx, middlewaresBuilder, &s.currentConfigurations)
appender.Append(internalMuxRouter)
}
if postConfig != nil {
postConfigs = append(postConfigs, postConfig)
}
fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, responseModifier, backend)
if err != nil {
return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err)
}
lb, healthCheckConfig, err := s.buildBalancerMiddlewares(frontendName, frontend, backend, fwd)
if err != nil {
return nil, err
}
// Handler used by error pages
if backendsHandlers[entryPointName+providerName+frontend.Backend] == nil {
backendsHandlers[entryPointName+providerName+frontend.Backend] = lb
}
if healthCheckConfig != nil {
backendsHealthCheck[entryPointName+providerName+frontendHash] = healthCheckConfig
}
n := negroni.New()
for _, handler := range handlers {
n.Use(handler)
}
n.UseHandler(lb)
backendsHandlers[entryPointName+providerName+frontendHash] = n
if h, ok := handlers[entryPointName]; ok {
internalMuxRouter.NotFoundHandler = h
} else {
log.Debugf("Reusing backend %s [%s - %s - %s - %s]",
frontend.Backend, entryPointName, providerName, frontendName, frontendHash)
internalMuxRouter.NotFoundHandler = s.buildDefaultHTTPRouter()
}
serverRoute, err := buildServerRoute(serverEntryPoints[entryPointName], frontendName, frontend, hostResolver)
routerHandlers[entryPointName] = internalMuxRouter
chain := alice.New()
if s.accessLoggerMiddleware != nil {
chain = chain.Append(accesslog.WrapHandler(s.accessLoggerMiddleware))
}
if s.tracer != nil {
chain = chain.Append(tracing.WrapEntryPointHandler(ctx, s.tracer, entryPointName))
}
chain = chain.Append(requestdecorator.WrapHandler(s.requestDecorator))
handler, err := chain.Then(internalMuxRouter.NotFoundHandler)
if err != nil {
return nil, err
}
handler := buildMatcherMiddlewares(serverRoute, backendsHandlers[entryPointName+providerName+frontendHash])
serverRoute.Route.Handler(handler)
err = serverRoute.Route.GetError()
if err != nil {
// FIXME error management
log.Errorf("Error building route: %s", err)
log.FromContext(ctx).Error(err)
continue
}
internalMuxRouter.NotFoundHandler = handler
}
return postConfigs, nil
return routerHandlers
}
func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint,
frontendName string, frontend *types.Frontend,
responseModifier modifyResponse, backend *types.Backend) (http.Handler, error) {
roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS)
if err != nil {
return nil, fmt.Errorf("failed to create RoundTripper for frontend %s: %v", frontendName, err)
}
var flushInterval parse.Duration
if backend.ResponseForwarding != nil {
err := flushInterval.Set(backend.ResponseForwarding.FlushInterval)
if err != nil {
return nil, fmt.Errorf("error creating flush interval for frontend %s: %v", frontendName, err)
}
}
var fwd http.Handler
fwd, err = forward.New(
forward.Stream(true),
forward.PassHostHeader(frontend.PassHostHeader),
forward.RoundTripper(roundTripper),
forward.ResponseModifier(responseModifier),
forward.BufferPool(s.bufferPool),
forward.StreamingFlushInterval(time.Duration(flushInterval)),
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
server := req.Context().Value(http.ServerContextKey).(*http.Server)
if server != nil {
connState := server.ConnState
if connState != nil {
connState(conn, http.StateClosed)
}
}
}),
)
if err != nil {
return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err)
}
if s.tracingMiddleware.IsEnabled() {
tm := s.tracingMiddleware.NewForwarderMiddleware(frontendName, frontend.Backend)
next := fwd
fwd = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tm.ServeHTTP(w, r, next.ServeHTTP)
})
}
fwd = pipelining.NewPipelining(fwd)
return fwd, nil
}
func buildServerRoute(serverEntryPoint *serverEntryPoint, frontendName string, frontend *types.Frontend, hostResolver *hostresolver.Resolver) (*types.ServerRoute, error) {
serverRoute := &types.ServerRoute{Route: serverEntryPoint.httpRouter.GetHandler().NewRoute().Name(frontendName)}
priority := 0
for routeName, route := range frontend.Routes {
rls := rules.Rules{Route: serverRoute, HostResolver: hostResolver}
newRoute, err := rls.Parse(route.Rule)
if err != nil {
return nil, fmt.Errorf("error creating route for frontend %s: %v", frontendName, err)
}
serverRoute.Route = newRoute
priority += len(route.Rule)
log.Debugf("Creating route %s %s", routeName, route.Rule)
}
if frontend.Priority > 0 {
serverRoute.Route.Priority(frontend.Priority)
} else {
serverRoute.Route.Priority(priority)
}
return serverRoute, nil
}
func (s *Server) preLoadConfiguration(configMsg types.ConfigMessage) {
func (s *Server) preLoadConfiguration(configMsg config.Message) {
providersThrottleDuration := time.Duration(s.globalConfiguration.ProvidersThrottleDuration)
s.defaultConfigurationValues(configMsg.Configuration)
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
logger := log.WithoutContext().WithField(log.ProviderName, configMsg.ProviderName)
if log.GetLevel() == logrus.DebugLevel {
jsonConf, _ := json.Marshal(configMsg.Configuration)
log.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
logger.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
}
if configMsg.Configuration == nil || configMsg.Configuration.Backends == nil && configMsg.Configuration.Frontends == nil && configMsg.Configuration.TLS == nil {
log.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
if configMsg.Configuration == nil || configMsg.Configuration.Routers == nil && configMsg.Configuration.Services == nil && configMsg.Configuration.Middlewares == nil && configMsg.Configuration.TLS == nil {
logger.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
return
}
if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) {
log.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
logger.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
return
}
providerConfigUpdateCh, ok := s.providerConfigUpdateMap[configMsg.ProviderName]
if !ok {
providerConfigUpdateCh = make(chan types.ConfigMessage)
providerConfigUpdateCh = make(chan config.Message)
s.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh
s.routinesPool.Go(func(stop chan bool) {
s.throttleProviderConfigReload(providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop)
@ -326,74 +206,8 @@ func (s *Server) preLoadConfiguration(configMsg types.ConfigMessage) {
providerConfigUpdateCh <- configMsg
}
func (s *Server) defaultConfigurationValues(configuration *types.Configuration) {
if configuration == nil || configuration.Frontends == nil {
return
}
s.configureFrontends(configuration.Frontends)
configureBackends(configuration.Backends)
}
func (s *Server) configureFrontends(frontends map[string]*types.Frontend) {
defaultEntrypoints := s.globalConfiguration.DefaultEntryPoints
for frontendName, frontend := range frontends {
// default endpoints if not defined in frontends
if len(frontend.EntryPoints) == 0 {
frontend.EntryPoints = defaultEntrypoints
}
frontendEntryPoints, undefinedEntryPoints := s.filterEntryPoints(frontend.EntryPoints)
if len(undefinedEntryPoints) > 0 {
log.Errorf("Undefined entry point(s) '%s' for frontend %s", strings.Join(undefinedEntryPoints, ","), frontendName)
}
frontend.EntryPoints = frontendEntryPoints
}
}
func (s *Server) filterEntryPoints(entryPoints []string) ([]string, []string) {
var frontendEntryPoints []string
var undefinedEntryPoints []string
for _, fepName := range entryPoints {
var exist bool
for epName := range s.entryPoints {
if epName == fepName {
exist = true
break
}
}
if exist {
frontendEntryPoints = append(frontendEntryPoints, fepName)
} else {
undefinedEntryPoints = append(undefinedEntryPoints, fepName)
}
}
return frontendEntryPoints, undefinedEntryPoints
}
func configureBackends(backends map[string]*types.Backend) {
for backendName := range backends {
backend := backends[backendName]
_, err := types.NewLoadBalancerMethod(backend.LoadBalancer)
if err != nil {
log.Debugf("Backend %s: %v", backendName, err)
var stickiness *types.Stickiness
if backend.LoadBalancer != nil {
stickiness = backend.LoadBalancer.Stickiness
}
backend.LoadBalancer = &types.LoadBalancer{
Method: "wrr",
Stickiness: stickiness,
}
}
}
func (s *Server) defaultConfigurationValues(configuration *config.Configuration) {
// FIXME create a config hook
}
func (s *Server) listenConfigurations(stop chan bool) {
@ -414,7 +228,7 @@ func (s *Server) listenConfigurations(stop chan bool) {
// It will immediately publish a new configuration and then only publish the next configuration after the throttle duration.
// Note that in the case it receives N new configs in the timeframe of the throttle duration after publishing,
// it will publish the last of the newly received configurations.
func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- types.ConfigMessage, in <-chan types.ConfigMessage, stop chan bool) {
func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- config.Message, in <-chan config.Message, stop chan bool) {
ring := channels.NewRingChannel(1)
defer ring.Close()
@ -424,7 +238,7 @@ func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish ch
case <-stop:
return
case nextConfig := <-ring.Out():
if config, ok := nextConfig.(types.ConfigMessage); ok {
if config, ok := nextConfig.(config.Message); ok {
publish <- config
time.Sleep(throttle)
}
@ -442,95 +256,53 @@ func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish ch
}
}
func buildMatcherMiddlewares(serverRoute *types.ServerRoute, handler http.Handler) http.Handler {
// path replace - This needs to always be the very last on the handler chain (first in the order in this function)
// -- Replacing Path should happen at the very end of the Modifier chain, after all the Matcher+Modifiers ran
if len(serverRoute.ReplacePath) > 0 {
handler = &middlewares.ReplacePath{
Path: serverRoute.ReplacePath,
Handler: handler,
}
}
if len(serverRoute.ReplacePathRegex) > 0 {
sp := strings.Split(serverRoute.ReplacePathRegex, " ")
if len(sp) == 2 {
handler = middlewares.NewReplacePathRegexHandler(sp[0], sp[1], handler)
} else {
log.Warnf("Invalid syntax for ReplacePathRegex: %s. Separate the regular expression and the replacement by a space.", serverRoute.ReplacePathRegex)
}
}
// add prefix - This needs to always be right before ReplacePath on the chain (second in order in this function)
// -- Adding Path Prefix should happen after all *Strip Matcher+Modifiers ran, but before Replace (in case it's configured)
if len(serverRoute.AddPrefix) > 0 {
handler = &middlewares.AddPrefix{
Prefix: serverRoute.AddPrefix,
Handler: handler,
}
}
// strip prefix
if len(serverRoute.StripPrefixes) > 0 {
handler = &middlewares.StripPrefix{
Prefixes: serverRoute.StripPrefixes,
Handler: handler,
}
}
// strip prefix with regex
if len(serverRoute.StripPrefixesRegex) > 0 {
handler = middlewares.NewStripPrefixRegex(handler, serverRoute.StripPrefixesRegex)
}
return handler
}
func (s *Server) postLoadConfiguration() {
if s.metricsRegistry.IsEnabled() {
activeConfig := s.currentConfigurations.Get().(types.Configurations)
metrics.OnConfigurationUpdate(activeConfig)
}
// FIXME metrics
// if s.metricsRegistry.IsEnabled() {
// activeConfig := s.currentConfigurations.Get().(config.Configurations)
// metrics.OnConfigurationUpdate(activeConfig)
// }
if s.globalConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() {
return
}
if s.globalConfiguration.ACME.OnHostRule {
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
for _, config := range currentConfigurations {
for _, frontend := range config.Frontends {
// check if one of the frontend entrypoints is configured with TLS
// and is configured with ACME
acmeEnabled := false
for _, entryPoint := range frontend.EntryPoints {
if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
acmeEnabled = true
break
}
}
if acmeEnabled {
for _, route := range frontend.Routes {
rls := rules.Rules{}
domains, err := rls.ParseDomains(route.Rule)
if err != nil {
log.Errorf("Error parsing domains: %v", err)
} else if len(domains) == 0 {
log.Debugf("No domain parsed in rule %q", route.Rule)
} else {
s.globalConfiguration.ACME.LoadCertificateForDomains(domains)
}
}
}
}
}
}
// FIXME acme
// if s.globalConfiguration.ACME.OnHostRule {
// currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
// for _, config := range currentConfigurations {
// for _, frontend := range config.Frontends {
//
// // check if one of the frontend entrypoints is configured with TLS
// // and is configured with ACME
// acmeEnabled := false
// for _, entryPoint := range frontend.EntryPoints {
// if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
// acmeEnabled = true
// break
// }
// }
//
// if acmeEnabled {
// for _, route := range frontend.Routes {
// rls := rules.Rules{}
// domains, err := rls.ParseDomains(route.Rule)
// if err != nil {
// log.Errorf("Error parsing domains: %v", err)
// } else if len(domains) == 0 {
// log.Debugf("No domain parsed in rule %q", route.Rule)
// } else {
// s.globalConfiguration.ACME.LoadCertificateForDomains(domains)
// }
// }
// }
// }
// }
// }
}
// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically
func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) map[string]map[string]*tls.Certificate {
func (s *Server) loadHTTPSConfiguration(configurations config.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) map[string]map[string]*tls.Certificate {
newEPCertificates := make(map[string]map[string]*tls.Certificate)
// Get all certificates
for _, config := range configurations {
@ -543,9 +315,14 @@ func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, def
func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
serverEntryPoints := make(map[string]*serverEntryPoint)
ctx := context.Background()
handlers := s.applyConfiguration(ctx, config.Configuration{})
for entryPointName, entryPoint := range s.entryPoints {
serverEntryPoints[entryPointName] = &serverEntryPoint{
httpRouter: middlewares.NewHandlerSwitcher(s.buildDefaultHTTPRouter()),
httpRouter: middlewares.NewHandlerSwitcher(handlers[entryPointName]),
onDemandListener: entryPoint.OnDemandListener,
tlsALPNGetter: entryPoint.TLSALPNGetter,
}
@ -557,19 +334,21 @@ func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
}
if entryPoint.Configuration.TLS != nil {
logger := log.FromContext(ctx).WithField(log.EntryPointName, entryPointName)
serverEntryPoints[entryPointName].certs.SniStrict = entryPoint.Configuration.TLS.SniStrict
if entryPoint.Configuration.TLS.DefaultCertificate != nil {
cert, err := buildDefaultCertificate(entryPoint.Configuration.TLS.DefaultCertificate)
if err != nil {
log.Error(err)
logger.Error(err)
continue
}
serverEntryPoints[entryPointName].certs.DefaultCertificate = cert
} else {
cert, err := generate.DefaultCertificate()
if err != nil {
log.Errorf("failed to generate default certificate: %v", err)
logger.Error(err)
continue
}
serverEntryPoints[entryPointName].certs.DefaultCertificate = cert
@ -585,6 +364,13 @@ func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
return serverEntryPoints
}
func (s *Server) buildDefaultHTTPRouter() *mux.Router {
rt := mux.NewRouter()
rt.NotFoundHandler = http.HandlerFunc(http.NotFound)
rt.SkipClean(true)
return rt
}
func buildDefaultCertificate(defaultCertificate *traefiktls.Certificate) (*tls.Certificate, error) {
certFile, err := defaultCertificate.CertFile.Read()
if err != nil {
@ -602,31 +388,3 @@ func buildDefaultCertificate(defaultCertificate *traefiktls.Certificate) (*tls.C
}
return &cert, nil
}
func (s *Server) buildDefaultHTTPRouter() *mux.Router {
rt := mux.NewRouter()
rt.NotFoundHandler = s.wrapHTTPHandlerWithAccessLog(http.HandlerFunc(http.NotFound), "backend not found")
rt.StrictSlash(!s.globalConfiguration.KeepTrailingSlash)
rt.SkipClean(true)
return rt
}
func sortedFrontendNamesForConfig(configuration *types.Configuration) []string {
var keys []string
for key := range configuration.Frontends {
keys = append(keys, key)
}
sort.Strings(keys)
return keys
}
func buildHostResolver(globalConfig configuration.GlobalConfiguration) *hostresolver.Resolver {
if globalConfig.HostResolver != nil {
return &hostresolver.Resolver{
CnameFlattening: globalConfig.HostResolver.CnameFlattening,
ResolvConfig: globalConfig.HostResolver.ResolvConfig,
ResolvDepth: globalConfig.HostResolver.ResolvDepth,
}
}
return nil
}

View file

@ -1,25 +1,16 @@
package server
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/mux"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/healthcheck"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/rules"
"github.com/containous/traefik/config"
"github.com/containous/traefik/old/configuration"
th "github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/roundrobin"
)
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
@ -60,137 +51,6 @@ f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
-----END RSA PRIVATE KEY-----`)
)
type testLoadBalancer struct{}
func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// noop
}
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
return nil
}
func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
return nil
}
func (lb *testLoadBalancer) Servers() []*url.URL {
return []*url.URL{}
}
func TestServerLoadConfigHealthCheckOptions(t *testing.T) {
healthChecks := []*types.HealthCheck{
nil,
{
Path: "/path",
},
}
for _, lbMethod := range []string{"Wrr", "Drr"} {
for _, healthCheck := range healthChecks {
t.Run(fmt.Sprintf("%s/hc=%t", lbMethod, healthCheck != nil), func(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
HealthCheck: &configuration.HealthCheckConfig{
Interval: parse.Duration(5 * time.Second),
Timeout: parse.Duration(3 * time.Second),
},
}
entryPoints := map[string]EntryPoint{
"http": {
Configuration: &configuration.EntryPoint{
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
},
},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
Frontends: map[string]*types.Frontend{
"frontend": {
EntryPoints: []string{"http"},
Backend: "backend",
},
},
Backends: map[string]*types.Backend{
"backend": {
Servers: map[string]types.Server{
"server": {
URL: "http://localhost",
},
},
LoadBalancer: &types.LoadBalancer{
Method: lbMethod,
},
HealthCheck: healthCheck,
},
},
TLS: []*tls.Configuration{
{
Certificate: &tls.Certificate{
CertFile: localhostCert,
KeyFile: localhostKey,
},
EntryPoints: []string{"http"},
},
},
},
}
srv := NewServer(globalConfig, nil, entryPoints)
_ = srv.loadConfig(dynamicConfigs, globalConfig)
expectedNumHealthCheckBackends := 0
if healthCheck != nil {
expectedNumHealthCheckBackends = 1
}
assert.Len(t, healthcheck.GetHealthCheck(th.NewCollectingHealthCheckMetrics()).Backends, expectedNumHealthCheckBackends, "health check backends")
})
}
}
}
func TestServerLoadConfigEmptyBasicAuth(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
EntryPoints: configuration.EntryPoints{
"http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}},
},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
Frontends: map[string]*types.Frontend{
"frontend": {
EntryPoints: []string{"http"},
Backend: "backend",
},
},
Backends: map[string]*types.Backend{
"backend": {
Servers: map[string]types.Server{
"server": {
URL: "http://localhost",
},
},
LoadBalancer: &types.LoadBalancer{
Method: "Wrr",
},
},
},
},
}
entryPoints := map[string]EntryPoint{}
for key, value := range globalConfig.EntryPoints {
entryPoints[key] = EntryPoint{
Configuration: value,
}
}
srv := NewServer(globalConfig, nil, entryPoints)
_ = srv.loadConfig(dynamicConfigs, globalConfig)
}
func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
DefaultEntryPoints: []string{"http", "https"},
@ -200,8 +60,8 @@ func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
"http": {Configuration: &configuration.EntryPoint{}},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
dynamicConfigs := config.Configurations{
"config": &config.Configuration{
TLS: []*tls.Configuration{
{
Certificate: &tls.Certificate{
@ -214,61 +74,58 @@ func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
}
srv := NewServer(globalConfig, nil, entryPoints)
mapEntryPoints := srv.loadConfig(dynamicConfigs, globalConfig)
if !mapEntryPoints["https"].certs.ContainsCertificates() {
_, mapsCerts := srv.loadConfig(dynamicConfigs, globalConfig)
if len(mapsCerts["https"]) == 0 {
t.Fatal("got error: https entryPoint must have TLS certificates.")
}
}
func TestReuseBackend(t *testing.T) {
func TestReuseService(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
globalConfig := configuration.GlobalConfiguration{
DefaultEntryPoints: []string{"http"},
}
entryPoints := map[string]EntryPoint{
"http": {Configuration: &configuration.EntryPoint{
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
}},
}
dynamicConfigs := types.Configurations{
globalConfig := configuration.GlobalConfiguration{
DefaultEntryPoints: []string{"http"},
}
dynamicConfigs := config.Configurations{
"config": th.BuildConfiguration(
th.WithFrontends(
th.WithFrontend("backend",
th.WithFrontendName("frontend0"),
th.WithRouters(
th.WithRouter("foo",
th.WithServiceName("bar"),
th.WithRule("Path:/ok")),
th.WithRouter("foo2",
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute("/ok", "Path: /ok"))),
th.WithFrontend("backend",
th.WithFrontendName("frontend1"),
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute("/unauthorized", "Path: /unauthorized")),
th.WithFrontEndAuth(&types.Auth{
Basic: &types.Basic{
Users: []string{"foo:bar"},
},
})),
th.WithRule("Path:/unauthorized"),
th.WithServiceName("bar"),
th.WithRouterMiddlewares("basicauth")),
),
th.WithBackends(th.WithBackendNew("backend",
th.WithMiddlewares(th.WithMiddleware("basicauth",
th.WithBasicAuth(&config.BasicAuth{Users: []string{"foo:bar"}}),
)),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr"),
th.WithServersNew(th.WithServerNew(testServer.URL))),
th.WithServers(th.WithServer(testServer.URL))),
),
),
}
srv := NewServer(globalConfig, nil, entryPoints)
serverEntryPoints := srv.loadConfig(dynamicConfigs, globalConfig)
serverEntryPoints, _ := srv.loadConfig(dynamicConfigs, globalConfig)
// Test that the /ok path returns a status 200.
responseRecorderOk := &httptest.ResponseRecorder{}
requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil)
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderOk, requestOk)
serverEntryPoints["http"].ServeHTTP(responseRecorderOk, requestOk)
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
@ -276,15 +133,15 @@ func TestReuseBackend(t *testing.T) {
// the basic authentication defined on the frontend.
responseRecorderUnauthorized := &httptest.ResponseRecorder{}
requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil)
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
serverEntryPoints["http"].ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code")
}
func TestThrottleProviderConfigReload(t *testing.T) {
throttleDuration := 30 * time.Millisecond
publishConfig := make(chan types.ConfigMessage)
providerConfig := make(chan types.ConfigMessage)
publishConfig := make(chan config.Message)
providerConfig := make(chan config.Message)
stop := make(chan bool)
defer func() {
stop <- true
@ -312,7 +169,7 @@ func TestThrottleProviderConfigReload(t *testing.T) {
// publish 5 new configs, one new config each 10 milliseconds
for i := 0; i < 5; i++ {
providerConfig <- types.ConfigMessage{}
providerConfig <- config.Message{}
time.Sleep(10 * time.Millisecond)
}
@ -335,205 +192,3 @@ func TestThrottleProviderConfigReload(t *testing.T) {
t.Error("Last config was not published in time")
}
}
func TestServerMultipleFrontendRules(t *testing.T) {
testCases := []struct {
expression string
requestURL string
expectedURL string
}{
{
expression: "Host:foo.bar",
requestURL: "http://foo.bar",
expectedURL: "http://foo.bar",
},
{
expression: "PathPrefix:/management;ReplacePath:/health",
requestURL: "http://foo.bar/management",
expectedURL: "http://foo.bar/health",
},
{
expression: "Host:foo.bar;AddPrefix:/blah",
requestURL: "http://foo.bar/baz",
expectedURL: "http://foo.bar/blah/baz",
},
{
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+}",
requestURL: "http://foo.bar/one/some/12345/four",
expectedURL: "http://foo.bar/four",
},
{
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+};AddPrefix:/zero",
requestURL: "http://foo.bar/one/some/12345/four",
expectedURL: "http://foo.bar/zero/four",
},
{
expression: "AddPrefix:/blah;ReplacePath:/baz",
requestURL: "http://foo.bar/hello",
expectedURL: "http://foo.bar/baz",
},
{
expression: "PathPrefixStrip:/management;ReplacePath:/health",
requestURL: "http://foo.bar/management",
expectedURL: "http://foo.bar/health",
},
}
for _, test := range testCases {
test := test
t.Run(test.expression, func(t *testing.T) {
t.Parallel()
router := mux.NewRouter()
route := router.NewRoute()
serverRoute := &types.ServerRoute{Route: route}
reqHostMid := &middlewares.RequestHost{}
rls := &rules.Rules{Route: serverRoute}
expression := test.expression
routeResult, err := rls.Parse(expression)
if err != nil {
t.Fatalf("Error while building route for %s: %+v", expression, err)
}
request := th.MustNewRequest(http.MethodGet, test.requestURL, nil)
var routeMatch bool
reqHostMid.ServeHTTP(nil, request, func(w http.ResponseWriter, r *http.Request) {
routeMatch = routeResult.Match(r, &mux.RouteMatch{Route: routeResult})
})
if !routeMatch {
t.Fatalf("Rule %s doesn't match", expression)
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, test.expectedURL, r.URL.String(), "URL")
})
hd := buildMatcherMiddlewares(serverRoute, handler)
serverRoute.Route.Handler(hd)
serverRoute.Route.GetHandler().ServeHTTP(nil, request)
})
}
}
func TestServerBuildHealthCheckOptions(t *testing.T) {
lb := &testLoadBalancer{}
globalInterval := 15 * time.Second
globalTimeout := 3 * time.Second
testCases := []struct {
desc string
hc *types.HealthCheck
expectedOpts *healthcheck.Options
}{
{
desc: "nil health check",
hc: nil,
expectedOpts: nil,
},
{
desc: "empty path",
hc: &types.HealthCheck{
Path: "",
},
expectedOpts: nil,
},
{
desc: "unparseable interval",
hc: &types.HealthCheck{
Path: "/path",
Interval: "unparseable",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: globalInterval,
LB: lb,
Timeout: 3 * time.Second,
},
},
{
desc: "sub-zero interval",
hc: &types.HealthCheck{
Path: "/path",
Interval: "-42s",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: globalInterval,
LB: lb,
Timeout: 3 * time.Second,
},
},
{
desc: "parseable interval",
hc: &types.HealthCheck{
Path: "/path",
Interval: "5m",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: 5 * time.Minute,
LB: lb,
Timeout: 3 * time.Second,
},
},
{
desc: "unparseable timeout",
hc: &types.HealthCheck{
Path: "/path",
Interval: "15s",
Timeout: "unparseable",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: globalInterval,
Timeout: globalTimeout,
LB: lb,
},
},
{
desc: "sub-zero timeout",
hc: &types.HealthCheck{
Path: "/path",
Interval: "15s",
Timeout: "-42s",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: globalInterval,
Timeout: globalTimeout,
LB: lb,
},
},
{
desc: "parseable timeout",
hc: &types.HealthCheck{
Path: "/path",
Interval: "15s",
Timeout: "10s",
},
expectedOpts: &healthcheck.Options{
Path: "/path",
Interval: globalInterval,
Timeout: 10 * time.Second,
LB: lb,
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
opts := buildHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{
Interval: parse.Duration(globalInterval),
Timeout: parse.Duration(globalTimeout),
})
assert.Equal(t, test.expectedOpts, opts, "health check options")
})
}
}

View file

@ -1,439 +0,0 @@
package server
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"time"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/healthcheck"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/server/cookie"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/types"
"github.com/vulcand/oxy/buffer"
"github.com/vulcand/oxy/connlimit"
"github.com/vulcand/oxy/ratelimit"
"github.com/vulcand/oxy/roundrobin"
"github.com/vulcand/oxy/utils"
"golang.org/x/net/http2"
)
type h2cTransportWrapper struct {
*http2.Transport
}
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
req.URL.Scheme = "http"
return t.Transport.RoundTrip(req)
}
func (s *Server) buildBalancerMiddlewares(frontendName string, frontend *types.Frontend, backend *types.Backend, fwd http.Handler) (http.Handler, *healthcheck.BackendConfig, error) {
balancer, err := s.buildLoadBalancer(frontendName, frontend.Backend, backend, fwd)
if err != nil {
return nil, nil, err
}
// Health Check
var backendHealthCheck *healthcheck.BackendConfig
if hcOpts := buildHealthCheckOptions(balancer, frontend.Backend, backend.HealthCheck, s.globalConfiguration.HealthCheck); hcOpts != nil {
log.Debugf("Setting up backend health check %s", *hcOpts)
hcOpts.Transport = s.defaultForwardingRoundTripper
backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, frontend.Backend)
}
// Empty (backend with no servers)
var lb http.Handler = middlewares.NewEmptyBackendHandler(balancer)
// Rate Limit
if frontend.RateLimit != nil && len(frontend.RateLimit.RateSet) > 0 {
handler, err := buildRateLimiter(lb, frontend.RateLimit)
if err != nil {
return nil, nil, fmt.Errorf("error creating rate limiter: %v", err)
}
lb = s.wrapHTTPHandlerWithAccessLog(
s.tracingMiddleware.NewHTTPHandlerWrapper("Rate limit", handler, false),
fmt.Sprintf("rate limit for %s", frontendName),
)
}
// Max Connections
if backend.MaxConn != nil && backend.MaxConn.Amount != 0 {
log.Debugf("Creating load-balancer connection limit")
handler, err := buildMaxConn(lb, backend.MaxConn)
if err != nil {
return nil, nil, err
}
lb = s.wrapHTTPHandlerWithAccessLog(handler, fmt.Sprintf("connection limit for %s", frontendName))
}
// Retry
if s.globalConfiguration.Retry != nil {
handler := s.buildRetryMiddleware(lb, s.globalConfiguration.Retry, len(backend.Servers), frontend.Backend)
lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Retry", handler, false)
}
// Buffering
if backend.Buffering != nil {
handler, err := buildBufferingMiddleware(lb, backend.Buffering)
if err != nil {
return nil, nil, fmt.Errorf("error setting up buffering middleware: %s", err)
}
// TODO refactor ?
lb = handler
}
// Circuit Breaker
if backend.CircuitBreaker != nil {
log.Debugf("Creating circuit breaker %s", backend.CircuitBreaker.Expression)
expression := backend.CircuitBreaker.Expression
circuitBreaker, err := middlewares.NewCircuitBreaker(lb, expression, middlewares.NewCircuitBreakerOptions(expression))
if err != nil {
return nil, nil, fmt.Errorf("error creating circuit breaker: %v", err)
}
lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Circuit breaker", circuitBreaker, false)
}
return lb, backendHealthCheck, nil
}
func (s *Server) buildLoadBalancer(frontendName string, backendName string, backend *types.Backend, fwd http.Handler) (healthcheck.BalancerHandler, error) {
var rr *roundrobin.RoundRobin
var saveFrontend http.Handler
if s.accessLoggerMiddleware != nil {
saveUsername := accesslog.NewSaveUsername(fwd)
saveBackend := accesslog.NewSaveBackend(saveUsername, backendName)
saveFrontend = accesslog.NewSaveFrontend(saveBackend, frontendName)
rr, _ = roundrobin.New(saveFrontend)
} else {
rr, _ = roundrobin.New(fwd)
}
var stickySession *roundrobin.StickySession
var cookieName string
if stickiness := backend.LoadBalancer.Stickiness; stickiness != nil {
cookieName = cookie.GetName(stickiness.CookieName, backendName)
stickySession = roundrobin.NewStickySession(cookieName)
}
lbMethod, err := types.NewLoadBalancerMethod(backend.LoadBalancer)
if err != nil {
return nil, fmt.Errorf("error loading load balancer method '%+v' for frontend %s: %v", backend.LoadBalancer, frontendName, err)
}
var lb healthcheck.BalancerHandler
switch lbMethod {
case types.Drr:
log.Debug("Creating load-balancer drr")
if stickySession != nil {
log.Debugf("Sticky session with cookie %v", cookieName)
lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
lb, err = roundrobin.NewRebalancer(rr)
if err != nil {
return nil, err
}
}
case types.Wrr:
log.Debug("Creating load-balancer wrr")
if stickySession != nil {
log.Debugf("Sticky session with cookie %v", cookieName)
if s.accessLoggerMiddleware != nil {
lb, err = roundrobin.New(saveFrontend, roundrobin.EnableStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession))
if err != nil {
return nil, err
}
}
} else {
lb = rr
}
default:
return nil, fmt.Errorf("invalid load-balancing method %q", lbMethod)
}
if err := s.configureLBServers(lb, backend, backendName); err != nil {
return nil, fmt.Errorf("error configuring load balancer for frontend %s: %v", frontendName, err)
}
return lb, nil
}
func (s *Server) configureLBServers(lb healthcheck.BalancerHandler, backend *types.Backend, backendName string) error {
for name, srv := range backend.Servers {
u, err := url.Parse(srv.URL)
if err != nil {
return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err)
}
log.Debugf("Creating server %s at %s with weight %d", name, u, srv.Weight)
if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil {
return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err)
}
s.metricsRegistry.BackendServerUpGauge().With("backend", backendName, "url", srv.URL).Set(1)
}
return nil
}
// getRoundTripper will either use server.defaultForwardingRoundTripper or create a new one
// given a custom TLS configuration is passed and the passTLSCert option is set to true.
func (s *Server) getRoundTripper(entryPointName string, passTLSCert bool, tls *traefiktls.TLS) (http.RoundTripper, error) {
if passTLSCert {
tlsConfig, err := createClientTLSConfig(entryPointName, tls)
if err != nil {
return nil, fmt.Errorf("failed to create TLSClientConfig: %v", err)
}
transport, err := createHTTPTransport(s.globalConfiguration)
if err != nil {
return nil, fmt.Errorf("failed to create HTTP transport: %v", err)
}
transport.TLSClientConfig = tlsConfig
return transport, nil
}
return s.defaultForwardingRoundTripper, nil
}
// createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings.
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost
// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing
// behavior and backwards compatibility issues.
func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) (*http.Transport, error) {
dialer := &net.Dialer{
Timeout: configuration.DefaultDialTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
if globalConfiguration.ForwardingTimeouts != nil {
dialer.Timeout = time.Duration(globalConfiguration.ForwardingTimeouts.DialTimeout)
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
MaxIdleConnsPerHost: globalConfiguration.MaxIdleConnsPerHost,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
transport.RegisterProtocol("h2c", &h2cTransportWrapper{
Transport: &http2.Transport{
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
return net.Dial(netw, addr)
},
AllowHTTP: true,
},
})
if globalConfiguration.ForwardingTimeouts != nil {
transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
}
if globalConfiguration.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
if len(globalConfiguration.RootCAs) > 0 {
transport.TLSClientConfig = &tls.Config{
RootCAs: createRootCACertPool(globalConfiguration.RootCAs),
}
}
err := http2.ConfigureTransport(transport)
if err != nil {
return nil, err
}
return transport, nil
}
func createRootCACertPool(rootCAs traefiktls.FilesOrContents) *x509.CertPool {
roots := x509.NewCertPool()
for _, cert := range rootCAs {
certContent, err := cert.Read()
if err != nil {
log.Error("Error while read RootCAs", err)
continue
}
roots.AppendCertsFromPEM(certContent)
}
return roots
}
func createClientTLSConfig(entryPointName string, tlsOption *traefiktls.TLS) (*tls.Config, error) {
if tlsOption == nil {
return nil, errors.New("no TLS provided")
}
config, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
if err != nil {
return nil, err
}
if len(tlsOption.ClientCA.Files) > 0 {
pool := x509.NewCertPool()
for _, caFile := range tlsOption.ClientCA.Files {
data, err := caFile.Read()
if err != nil {
return nil, err
}
if !pool.AppendCertsFromPEM(data) {
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
}
}
config.RootCAs = pool
}
config.BuildNameToCertificate()
return config, nil
}
func (s *Server) buildRetryMiddleware(handler http.Handler, retry *configuration.Retry, countServers int, backendName string) http.Handler {
retryListeners := middlewares.RetryListeners{}
if s.metricsRegistry.IsEnabled() {
retryListeners = append(retryListeners, middlewares.NewMetricsRetryListener(s.metricsRegistry, backendName))
}
if s.accessLoggerMiddleware != nil {
retryListeners = append(retryListeners, &accesslog.SaveRetries{})
}
retryAttempts := countServers
if retry.Attempts > 0 {
retryAttempts = retry.Attempts
}
log.Debugf("Creating retries max attempts %d", retryAttempts)
return middlewares.NewRetry(retryAttempts, handler, retryListeners)
}
func buildRateLimiter(handler http.Handler, rlConfig *types.RateLimit) (http.Handler, error) {
extractFunc, err := utils.NewExtractor(rlConfig.ExtractorFunc)
if err != nil {
return nil, err
}
log.Debugf("Creating load-balancer rate limiter")
rateSet := ratelimit.NewRateSet()
for _, rate := range rlConfig.RateSet {
if err := rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
return nil, err
}
}
return ratelimit.New(handler, extractFunc, rateSet)
}
func buildBufferingMiddleware(handler http.Handler, config *types.Buffering) (http.Handler, error) {
log.Debugf("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'",
config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes,
config.MaxResponseBodyBytes, config.RetryExpression)
return buffer.New(
handler,
buffer.MemRequestBodyBytes(config.MemRequestBodyBytes),
buffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes),
buffer.MemResponseBodyBytes(config.MemResponseBodyBytes),
buffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes),
buffer.CondSetter(len(config.RetryExpression) > 0, buffer.Retry(config.RetryExpression)),
)
}
func buildMaxConn(lb http.Handler, maxConns *types.MaxConn) (http.Handler, error) {
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
log.Debugf("Creating load-balancer connection limit")
handler, err := connlimit.New(lb, extractFunc, maxConns.Amount)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
return handler, nil
}
func buildHealthCheckOptions(lb healthcheck.BalancerHandler, backend string, hc *types.HealthCheck, hcConfig *configuration.HealthCheckConfig) *healthcheck.Options {
if hc == nil || hc.Path == "" || hcConfig == nil {
return nil
}
interval := time.Duration(hcConfig.Interval)
if hc.Interval != "" {
intervalOverride, err := time.ParseDuration(hc.Interval)
if err != nil {
log.Errorf("Illegal health check interval for backend '%s': %s", backend, err)
} else if intervalOverride <= 0 {
log.Errorf("Health check interval smaller than zero for backend '%s', backend", backend)
} else {
interval = intervalOverride
}
}
timeout := time.Duration(hcConfig.Timeout)
if hc.Timeout != "" {
timeoutOverride, err := time.ParseDuration(hc.Timeout)
if err != nil {
log.Errorf("Illegal health check timeout for backend '%s': %s", backend, err)
} else if timeoutOverride <= 0 {
log.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend)
} else {
timeout = timeoutOverride
}
}
if timeout >= interval {
log.Warnf("Health check timeout for backend '%s' should be lower than the health check interval. Interval set to timeout + 1 second (%s).", backend)
}
return &healthcheck.Options{
Scheme: hc.Scheme,
Path: hc.Path,
Port: hc.Port,
Interval: interval,
Timeout: timeout,
LB: lb,
Hostname: hc.Hostname,
Headers: hc.Headers,
}
}

View file

@ -1,81 +0,0 @@
package server
import (
"testing"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
)
func TestConfigureBackends(t *testing.T) {
validMethod := "Drr"
defaultMethod := "wrr"
testCases := []struct {
desc string
lb *types.LoadBalancer
expectedMethod string
expectedStickiness *types.Stickiness
}{
{
desc: "valid load balancer method with sticky enabled",
lb: &types.LoadBalancer{
Method: validMethod,
Stickiness: &types.Stickiness{},
},
expectedMethod: validMethod,
expectedStickiness: &types.Stickiness{},
},
{
desc: "valid load balancer method with sticky disabled",
lb: &types.LoadBalancer{
Method: validMethod,
Stickiness: nil,
},
expectedMethod: validMethod,
},
{
desc: "invalid load balancer method with sticky enabled",
lb: &types.LoadBalancer{
Method: "Invalid",
Stickiness: &types.Stickiness{},
},
expectedMethod: defaultMethod,
expectedStickiness: &types.Stickiness{},
},
{
desc: "invalid load balancer method with sticky disabled",
lb: &types.LoadBalancer{
Method: "Invalid",
Stickiness: nil,
},
expectedMethod: defaultMethod,
},
{
desc: "missing load balancer",
lb: nil,
expectedMethod: defaultMethod,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
backend := &types.Backend{
LoadBalancer: test.lb,
}
configureBackends(map[string]*types.Backend{
"backend": backend,
})
expected := types.LoadBalancer{
Method: test.expectedMethod,
Stickiness: test.expectedStickiness,
}
assert.Equal(t, expected, *backend.LoadBalancer)
})
}
}

View file

@ -1,350 +0,0 @@
package server
import (
"fmt"
"net/http"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/accesslog"
mauth "github.com/containous/traefik/middlewares/auth"
"github.com/containous/traefik/middlewares/errorpages"
"github.com/containous/traefik/middlewares/forwardedheaders"
"github.com/containous/traefik/middlewares/redirect"
"github.com/containous/traefik/types"
thoas_stats "github.com/thoas/stats"
"github.com/unrolled/secure"
"github.com/urfave/negroni"
)
type handlerPostConfig func(backendsHandlers map[string]http.Handler) error
type modifyResponse func(*http.Response) error
func (s *Server) buildMiddlewares(frontendName string, frontend *types.Frontend,
backends map[string]*types.Backend, entryPointName string, providerName string) ([]negroni.Handler, modifyResponse, handlerPostConfig, error) {
var middle []negroni.Handler
var postConfig handlerPostConfig
// Error pages
if len(frontend.Errors) > 0 {
handlers, err := buildErrorPagesMiddleware(frontendName, frontend, backends, entryPointName, providerName)
if err != nil {
return nil, nil, nil, err
}
postConfig = errorPagesPostConfig(handlers)
for _, handler := range handlers {
middle = append(middle, handler)
}
}
// Metrics
if s.metricsRegistry.IsEnabled() {
handler := middlewares.NewBackendMetricsMiddleware(s.metricsRegistry, frontend.Backend)
middle = append(middle, handler)
}
// Whitelist
ipWhitelistMiddleware, err := buildIPWhiteLister(frontend.WhiteList, s.entryPoints[entryPointName].Configuration.ClientIPStrategy)
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating IP Whitelister: %s", err)
}
if ipWhitelistMiddleware != nil {
log.Debugf("Configured IP Whitelists: %v", frontend.WhiteList.SourceRange)
handler := s.tracingMiddleware.NewNegroniHandlerWrapper(
"IP whitelist",
s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for %s", frontendName)),
false)
middle = append(middle, handler)
}
// Redirect
if frontend.Redirect != nil && entryPointName != frontend.Redirect.EntryPoint {
rewrite, err := s.buildRedirectHandler(entryPointName, frontend.Redirect)
if err != nil {
return nil, nil, nil, fmt.Errorf("error creating Frontend Redirect: %v", err)
}
handler := s.wrapNegroniHandlerWithAccessLog(rewrite, fmt.Sprintf("frontend redirect for %s", frontendName))
middle = append(middle, handler)
log.Debugf("Frontend %s redirect created", frontendName)
}
// Header
headerMiddleware := middlewares.NewHeaderFromStruct(frontend.Headers)
if headerMiddleware != nil {
log.Debugf("Adding header middleware for frontend %s", frontendName)
handler := s.tracingMiddleware.NewNegroniHandlerWrapper("Header", headerMiddleware, false)
middle = append(middle, handler)
}
// Secure
secureMiddleware := middlewares.NewSecure(frontend.Headers)
if secureMiddleware != nil {
log.Debugf("Adding secure middleware for frontend %s", frontendName)
handler := negroni.HandlerFunc(secureMiddleware.HandlerFuncWithNextForRequestOnly)
middle = append(middle, handler)
}
// Authentication
if frontend.Auth != nil {
authMiddleware, err := mauth.NewAuthenticator(frontend.Auth, s.tracingMiddleware)
if err != nil {
return nil, nil, nil, err
}
handler := s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for %s", frontendName))
middle = append(middle, handler)
}
// TLSClientHeaders
tlsClientHeadersMiddleware := middlewares.NewTLSClientHeaders(frontend)
if tlsClientHeadersMiddleware != nil {
log.Debugf("Adding TLSClientHeaders middleware for frontend %s", frontendName)
handler := s.tracingMiddleware.NewNegroniHandlerWrapper("TLSClientHeaders", tlsClientHeadersMiddleware, false)
middle = append(middle, handler)
}
return middle, buildModifyResponse(secureMiddleware, headerMiddleware), postConfig, nil
}
func (s *Server) buildServerEntryPointMiddlewares(serverEntryPointName string) ([]negroni.Handler, error) {
serverMiddlewares := []negroni.Handler{middlewares.NegroniRecoverHandler()}
if s.tracingMiddleware.IsEnabled() {
serverMiddlewares = append(serverMiddlewares, s.tracingMiddleware.NewEntryPoint(serverEntryPointName))
}
if s.accessLoggerMiddleware != nil {
serverMiddlewares = append(serverMiddlewares, s.accessLoggerMiddleware)
}
if s.metricsRegistry.IsEnabled() {
serverMiddlewares = append(serverMiddlewares, middlewares.NewEntryPointMetricsMiddleware(s.metricsRegistry, serverEntryPointName))
}
if s.globalConfiguration.API != nil {
if s.globalConfiguration.API.Stats == nil {
s.globalConfiguration.API.Stats = thoas_stats.New()
}
serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.Stats)
if s.globalConfiguration.API.Statistics != nil {
if s.globalConfiguration.API.StatsRecorder == nil {
s.globalConfiguration.API.StatsRecorder = middlewares.NewStatsRecorder(s.globalConfiguration.API.Statistics.RecentErrors)
}
serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.StatsRecorder)
}
}
if s.entryPoints[serverEntryPointName].Configuration.Redirect != nil {
redirectHandlers, err := s.buildEntryPointRedirect()
if err != nil {
return nil, fmt.Errorf("failed to create redirect middleware: %v", err)
}
serverMiddlewares = append(serverMiddlewares, redirectHandlers[serverEntryPointName])
}
if s.entryPoints[serverEntryPointName].Configuration.Auth != nil {
authMiddleware, err := mauth.NewAuthenticator(s.entryPoints[serverEntryPointName].Configuration.Auth, s.tracingMiddleware)
if err != nil {
return nil, fmt.Errorf("failed to create authentication middleware: %v", err)
}
serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for entrypoint %s", serverEntryPointName)))
}
if s.entryPoints[serverEntryPointName].Configuration.Compress != nil {
serverMiddlewares = append(serverMiddlewares, &middlewares.Compress{})
}
if s.entryPoints[serverEntryPointName].Configuration.ForwardedHeaders != nil {
xForwardedMiddleware, err := forwardedheaders.NewXforwarded(
s.entryPoints[serverEntryPointName].Configuration.ForwardedHeaders.Insecure,
s.entryPoints[serverEntryPointName].Configuration.ForwardedHeaders.TrustedIPs,
)
if err != nil {
return nil, fmt.Errorf("failed to create xforwarded headers middleware: %v", err)
}
serverMiddlewares = append(serverMiddlewares, xForwardedMiddleware)
}
ipWhitelistMiddleware, err := buildIPWhiteLister(s.entryPoints[serverEntryPointName].Configuration.WhiteList, s.entryPoints[serverEntryPointName].Configuration.ClientIPStrategy)
if err != nil {
return nil, fmt.Errorf("failed to create ip whitelist middleware: %v", err)
}
if ipWhitelistMiddleware != nil {
serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for entrypoint %s", serverEntryPointName)))
}
// RequestHost Cannonizer
serverMiddlewares = append(serverMiddlewares, &middlewares.RequestHost{})
return serverMiddlewares, nil
}
func errorPagesPostConfig(epHandlers []*errorpages.Handler) handlerPostConfig {
return func(backendsHandlers map[string]http.Handler) error {
for _, errorPageHandler := range epHandlers {
if handler, ok := backendsHandlers[errorPageHandler.BackendName]; ok {
err := errorPageHandler.PostLoad(handler)
if err != nil {
return fmt.Errorf("failed to configure error pages for backend %s: %v", errorPageHandler.BackendName, err)
}
} else {
err := errorPageHandler.PostLoad(nil)
if err != nil {
return fmt.Errorf("failed to configure error pages for %s: %v", errorPageHandler.FallbackURL, err)
}
}
}
return nil
}
}
func buildErrorPagesMiddleware(frontendName string, frontend *types.Frontend, backends map[string]*types.Backend, entryPointName string, providerName string) ([]*errorpages.Handler, error) {
var errorPageHandlers []*errorpages.Handler
for errorPageName, errorPage := range frontend.Errors {
if frontend.Backend == errorPage.Backend {
log.Errorf("Error when creating error page %q for frontend %q: error pages backend %q is the same as backend for the frontend (infinite call risk).",
errorPageName, frontendName, errorPage.Backend)
} else if backends[errorPage.Backend] == nil {
log.Errorf("Error when creating error page %q for frontend %q: the backend %q doesn't exist.",
errorPageName, frontendName, errorPage.Backend)
} else {
errorPagesHandler, err := errorpages.NewHandler(errorPage, entryPointName+providerName+errorPage.Backend)
if err != nil {
return nil, fmt.Errorf("error creating error pages: %v", err)
}
if errorPageServer, ok := backends[errorPage.Backend].Servers["error"]; ok {
errorPagesHandler.FallbackURL = errorPageServer.URL
}
errorPageHandlers = append(errorPageHandlers, errorPagesHandler)
}
}
return errorPageHandlers, nil
}
func (s *Server) buildBasicAuthMiddleware(authData []string) (*mauth.Authenticator, error) {
users := types.Users{}
for _, user := range authData {
users = append(users, user)
}
auth := &types.Auth{}
auth.Basic = &types.Basic{
Users: users,
}
authMiddleware, err := mauth.NewAuthenticator(auth, s.tracingMiddleware)
if err != nil {
return nil, fmt.Errorf("error creating Basic Auth: %v", err)
}
return authMiddleware, nil
}
func (s *Server) buildEntryPointRedirect() (map[string]negroni.Handler, error) {
redirectHandlers := map[string]negroni.Handler{}
for entryPointName, ep := range s.entryPoints {
entryPoint := ep.Configuration
if entryPoint.Redirect != nil && entryPointName != entryPoint.Redirect.EntryPoint {
handler, err := s.buildRedirectHandler(entryPointName, entryPoint.Redirect)
if err != nil {
return nil, fmt.Errorf("error loading configuration for entrypoint %s: %v", entryPointName, err)
}
handlerToUse := s.wrapNegroniHandlerWithAccessLog(handler, fmt.Sprintf("entrypoint redirect for %s", entryPointName))
redirectHandlers[entryPointName] = handlerToUse
}
}
return redirectHandlers, nil
}
func (s *Server) buildRedirectHandler(srcEntryPointName string, opt *types.Redirect) (negroni.Handler, error) {
// entry point redirect
if len(opt.EntryPoint) > 0 {
entryPoint := s.entryPoints[opt.EntryPoint].Configuration
if entryPoint == nil {
return nil, fmt.Errorf("unknown target entrypoint %q", srcEntryPointName)
}
log.Debugf("Creating entry point redirect %s -> %s", srcEntryPointName, opt.EntryPoint)
return redirect.NewEntryPointHandler(entryPoint, opt.Permanent)
}
// regex redirect
redirection, err := redirect.NewRegexHandler(opt.Regex, opt.Replacement, opt.Permanent)
if err != nil {
return nil, err
}
log.Debugf("Creating regex redirect %s -> %s -> %s", srcEntryPointName, opt.Regex, opt.Replacement)
return redirection, nil
}
func buildIPWhiteLister(whiteList *types.WhiteList, ipStrategy *types.IPStrategy) (*middlewares.IPWhiteLister, error) {
if whiteList == nil {
return nil, nil
}
if whiteList.IPStrategy != nil {
ipStrategy = whiteList.IPStrategy
}
strategy, err := ipStrategy.Get()
if err != nil {
return nil, err
}
return middlewares.NewIPWhiteLister(whiteList.SourceRange, strategy)
}
func (s *Server) wrapNegroniHandlerWithAccessLog(handler negroni.Handler, frontendName string) negroni.Handler {
if s.accessLoggerMiddleware != nil {
saveUsername := accesslog.NewSaveNegroniUsername(handler)
saveBackend := accesslog.NewSaveNegroniBackend(saveUsername, "Traefik")
saveFrontend := accesslog.NewSaveNegroniFrontend(saveBackend, frontendName)
return saveFrontend
}
return handler
}
func (s *Server) wrapHTTPHandlerWithAccessLog(handler http.Handler, frontendName string) http.Handler {
if s.accessLoggerMiddleware != nil {
saveUsername := accesslog.NewSaveUsername(handler)
saveBackend := accesslog.NewSaveBackend(saveUsername, "Traefik")
saveFrontend := accesslog.NewSaveFrontend(saveBackend, frontendName)
return saveFrontend
}
return handler
}
func buildModifyResponse(secure *secure.Secure, header *middlewares.HeaderStruct) func(res *http.Response) error {
return func(res *http.Response) error {
if secure != nil {
if err := secure.ModifyResponseHeaders(res); err != nil {
return err
}
}
if header != nil {
if err := header.ModifyResponseHeaders(res); err != nil {
return err
}
}
return nil
}
}

View file

@ -1,270 +0,0 @@
package server
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/containous/mux"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/metrics"
"github.com/containous/traefik/middlewares"
th "github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func TestServerEntryPointWhitelistConfig(t *testing.T) {
testCases := []struct {
desc string
entrypoint *configuration.EntryPoint
expectMiddleware bool
}{
{
desc: "no whitelist middleware if no config on entrypoint",
entrypoint: &configuration.EntryPoint{
Address: ":0",
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
},
expectMiddleware: false,
},
{
desc: "whitelist middleware should be added if configured on entrypoint",
entrypoint: &configuration.EntryPoint{
Address: ":0",
WhiteList: &types.WhiteList{
SourceRange: []string{
"127.0.0.1/32",
},
},
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
},
expectMiddleware: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
srv := Server{
globalConfiguration: configuration.GlobalConfiguration{},
metricsRegistry: metrics.NewVoidRegistry(),
entryPoints: map[string]EntryPoint{
"test": {
Configuration: test.entrypoint,
},
},
}
srv.serverEntryPoints = srv.buildServerEntryPoints()
srvEntryPoint := srv.setupServerEntryPoint("test", srv.serverEntryPoints["test"])
handler := srvEntryPoint.httpServer.Handler.(*mux.Router).NotFoundHandler.(*negroni.Negroni)
found := false
for _, handler := range handler.Handlers() {
if reflect.TypeOf(handler) == reflect.TypeOf((*middlewares.IPWhiteLister)(nil)) {
found = true
}
}
if found && !test.expectMiddleware {
t.Error("ip whitelist middleware was installed even though it should not")
}
if !found && test.expectMiddleware {
t.Error("ip whitelist middleware was not installed even though it should have")
}
})
}
}
func TestBuildIPWhiteLister(t *testing.T) {
testCases := []struct {
desc string
whitelistSourceRange []string
whiteList *types.WhiteList
middlewareConfigured bool
errMessage string
}{
{
desc: "no whitelists configured",
whitelistSourceRange: nil,
middlewareConfigured: false,
errMessage: "",
},
{
desc: "whitelists configured",
whiteList: &types.WhiteList{
SourceRange: []string{
"1.2.3.4/24",
"fe80::/16",
},
},
middlewareConfigured: true,
errMessage: "",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
middleware, err := buildIPWhiteLister(test.whiteList, nil)
if test.errMessage != "" {
require.EqualError(t, err, test.errMessage)
} else {
assert.NoError(t, err)
if test.middlewareConfigured {
require.NotNil(t, middleware, "not expected middleware to be configured")
} else {
require.Nil(t, middleware, "expected middleware to be configured")
}
}
})
}
}
func TestBuildRedirectHandler(t *testing.T) {
srv := Server{
globalConfiguration: configuration.GlobalConfiguration{},
entryPoints: map[string]EntryPoint{
"http": {Configuration: &configuration.EntryPoint{Address: ":80"}},
"https": {Configuration: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}},
},
}
testCases := []struct {
desc string
srcEntryPointName string
url string
entryPoint *configuration.EntryPoint
redirect *types.Redirect
expectedURL string
}{
{
desc: "redirect regex",
srcEntryPointName: "http",
url: "http://foo.com",
redirect: &types.Redirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
entryPoint: &configuration.EntryPoint{
Address: ":80",
Redirect: &types.Redirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
},
expectedURL: "https://foobar.com",
},
{
desc: "redirect entry point",
srcEntryPointName: "http",
url: "http://foo:80",
redirect: &types.Redirect{
EntryPoint: "https",
},
entryPoint: &configuration.EntryPoint{
Address: ":80",
Redirect: &types.Redirect{
EntryPoint: "https",
},
},
expectedURL: "https://foo:443",
},
{
desc: "redirect entry point with regex (ignored)",
srcEntryPointName: "http",
url: "http://foo.com:80",
redirect: &types.Redirect{
EntryPoint: "https",
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
entryPoint: &configuration.EntryPoint{
Address: ":80",
Redirect: &types.Redirect{
EntryPoint: "https",
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
Replacement: "https://$1{{\"bar\"}}$2",
},
},
expectedURL: "https://foo.com:443",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rewrite, err := srv.buildRedirectHandler(test.srcEntryPointName, test.redirect)
require.NoError(t, err)
req := th.MustNewRequest(http.MethodGet, test.url, nil)
recorder := httptest.NewRecorder()
rewrite.ServeHTTP(recorder, req, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Location", "fail")
}))
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
})
}
}
func TestServerGenericFrontendAuthFail(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
EntryPoints: configuration.EntryPoints{
"http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}},
},
}
entryPoints := map[string]EntryPoint{
"http": {
Configuration: globalConfig.EntryPoints["http"],
},
}
dynamicConfigs := types.Configurations{
"config": &types.Configuration{
Frontends: map[string]*types.Frontend{
"frontend": {
EntryPoints: []string{"http"},
Backend: "backend",
Auth: &types.Auth{
Basic: &types.Basic{
Users: []string{""},
}},
},
},
Backends: map[string]*types.Backend{
"backend": {
Servers: map[string]types.Server{
"server": {
URL: "http://localhost",
},
},
LoadBalancer: &types.LoadBalancer{
Method: "Wrr",
},
},
},
},
}
srv := NewServer(globalConfig, nil, entryPoints)
_ = srv.loadConfig(dynamicConfigs, globalConfig)
}

View file

@ -21,16 +21,16 @@ func (s *Server) listenSignals(stop chan bool) {
case sig := <-s.signals:
switch sig {
case syscall.SIGUSR1:
log.Infof("Closing and re-opening log files for rotation: %+v", sig)
log.WithoutContext().Infof("Closing and re-opening log files for rotation: %+v", sig)
if s.accessLoggerMiddleware != nil {
if err := s.accessLoggerMiddleware.Rotate(); err != nil {
log.Errorf("Error rotating access log: %v", err)
log.WithoutContext().Errorf("Error rotating access log: %v", err)
}
}
if err := log.RotateFile(); err != nil {
log.Errorf("Error rotating traefik log: %v", err)
log.WithoutContext().Errorf("Error rotating traefik log: %v", err)
}
}
}

View file

@ -9,13 +9,12 @@ import (
"github.com/containous/flaeg/parse"
"github.com/containous/mux"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/old/configuration"
th "github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/unrolled/secure"
)
func TestPrepareServerTimeouts(t *testing.T) {
@ -62,7 +61,7 @@ func TestPrepareServerTimeouts(t *testing.T) {
router := middlewares.NewHandlerSwitcher(mux.NewRouter())
srv := NewServer(test.globalConfig, nil, nil)
httpServer, _, err := srv.prepareServer(entryPointName, entryPoint, router, nil)
httpServer, _, err := srv.prepareServer(context.Background(), entryPointName, entryPoint, router)
require.NoError(t, err, "Unexpected error when preparing srv")
assert.Equal(t, test.expectedIdleTimeout, httpServer.IdleTimeout, "IdleTimeout")
@ -87,7 +86,7 @@ func TestListenProvidersSkipsEmptyConfigs(t *testing.T) {
}
}()
server.configurationChan <- types.ConfigMessage{ProviderName: "kubernetes"}
server.configurationChan <- config.Message{ProviderName: "kubernetes"}
// give some time so that the configuration can be processed
time.Sleep(100 * time.Millisecond)
@ -103,12 +102,12 @@ func TestListenProvidersSkipsSameConfigurationForProvider(t *testing.T) {
select {
case <-stop:
return
case config := <-server.configurationValidatedChan:
case conf := <-server.configurationValidatedChan:
// set the current configuration
// this is usually done in the processing part of the published configuration
// so we have to emulate the behavior here
currentConfigurations := server.currentConfigurations.Get().(types.Configurations)
currentConfigurations[config.ProviderName] = config.Configuration
currentConfigurations := server.currentConfigurations.Get().(config.Configurations)
currentConfigurations[conf.ProviderName] = conf.Configuration
server.currentConfigurations.Set(currentConfigurations)
publishedConfigCount++
@ -119,19 +118,19 @@ func TestListenProvidersSkipsSameConfigurationForProvider(t *testing.T) {
}
}()
config := th.BuildConfiguration(
th.WithFrontends(th.WithFrontend("backend")),
th.WithBackends(th.WithBackendNew("backend")),
conf := th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo")),
th.WithLoadBalancerServices(th.WithService("bar")),
)
// provide a configuration
server.configurationChan <- types.ConfigMessage{ProviderName: "kubernetes", Configuration: config}
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
// give some time so that the configuration can be processed
time.Sleep(20 * time.Millisecond)
// provide the same configuration a second time
server.configurationChan <- types.ConfigMessage{ProviderName: "kubernetes", Configuration: config}
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
// give some time so that the configuration can be processed
time.Sleep(100 * time.Millisecond)
@ -160,12 +159,12 @@ func TestListenProvidersPublishesConfigForEachProvider(t *testing.T) {
}
}()
config := th.BuildConfiguration(
th.WithFrontends(th.WithFrontend("backend")),
th.WithBackends(th.WithBackendNew("backend")),
conf := th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo")),
th.WithLoadBalancerServices(th.WithService("bar")),
)
server.configurationChan <- types.ConfigMessage{ProviderName: "kubernetes", Configuration: config}
server.configurationChan <- types.ConfigMessage{ProviderName: "marathon", Configuration: config}
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
server.configurationChan <- config.Message{ProviderName: "marathon", Configuration: conf}
select {
case <-consumePublishedConfigsDone:
@ -206,20 +205,21 @@ func TestServerResponseEmptyBackend(t *testing.T) {
testCases := []struct {
desc string
config func(testServerURL string) *types.Configuration
config func(testServerURL string) *config.Configuration
expectedStatusCode int
}{
{
desc: "Ok",
config: func(testServerURL string) *types.Configuration {
config: func(testServerURL string) *config.Configuration {
return th.BuildConfiguration(
th.WithFrontends(th.WithFrontend("backend",
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithBackends(th.WithBackendNew("backend",
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr"),
th.WithServersNew(th.WithServerNew(testServerURL))),
th.WithServers(th.WithServer(testServerURL))),
),
)
},
@ -227,20 +227,21 @@ func TestServerResponseEmptyBackend(t *testing.T) {
},
{
desc: "No Frontend",
config: func(testServerURL string) *types.Configuration {
config: func(testServerURL string) *config.Configuration {
return th.BuildConfiguration()
},
expectedStatusCode: http.StatusNotFound,
},
{
desc: "Empty Backend LB-Drr",
config: func(testServerURL string) *types.Configuration {
config: func(testServerURL string) *config.Configuration {
return th.BuildConfiguration(
th.WithFrontends(th.WithFrontend("backend",
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithBackends(th.WithBackendNew("backend",
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("drr")),
),
)
@ -249,14 +250,15 @@ func TestServerResponseEmptyBackend(t *testing.T) {
},
{
desc: "Empty Backend LB-Drr Sticky",
config: func(testServerURL string) *types.Configuration {
config: func(testServerURL string) *config.Configuration {
return th.BuildConfiguration(
th.WithFrontends(th.WithFrontend("backend",
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithBackends(th.WithBackendNew("backend",
th.WithLBMethod("drr"), th.WithLBSticky("test")),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("drr"), th.WithStickiness("test")),
),
)
},
@ -264,13 +266,14 @@ func TestServerResponseEmptyBackend(t *testing.T) {
},
{
desc: "Empty Backend LB-Wrr",
config: func(testServerURL string) *types.Configuration {
config: func(testServerURL string) *config.Configuration {
return th.BuildConfiguration(
th.WithFrontends(th.WithFrontend("backend",
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithBackends(th.WithBackendNew("backend",
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr")),
),
)
@ -279,14 +282,15 @@ func TestServerResponseEmptyBackend(t *testing.T) {
},
{
desc: "Empty Backend LB-Wrr Sticky",
config: func(testServerURL string) *types.Configuration {
config: func(testServerURL string) *config.Configuration {
return th.BuildConfiguration(
th.WithFrontends(th.WithFrontend("backend",
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
th.WithServiceName("bar"),
th.WithRule(routeRule)),
),
th.WithBackends(th.WithBackendNew("backend",
th.WithLBMethod("wrr"), th.WithLBSticky("test")),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr"), th.WithStickiness("test")),
),
)
},
@ -309,134 +313,17 @@ func TestServerResponseEmptyBackend(t *testing.T) {
entryPointsConfig := map[string]EntryPoint{
"http": {Configuration: &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}}},
}
dynamicConfigs := types.Configurations{"config": test.config(testServer.URL)}
dynamicConfigs := config.Configurations{"config": test.config(testServer.URL)}
srv := NewServer(globalConfig, nil, entryPointsConfig)
entryPoints := srv.loadConfig(dynamicConfigs, globalConfig)
entryPoints, _ := srv.loadConfig(dynamicConfigs, globalConfig)
responseRecorder := &httptest.ResponseRecorder{}
request := httptest.NewRequest(http.MethodGet, testServer.URL+requestPath, nil)
entryPoints["http"].httpRouter.ServeHTTP(responseRecorder, request)
entryPoints["http"].ServeHTTP(responseRecorder, request)
assert.Equal(t, test.expectedStatusCode, responseRecorder.Result().StatusCode, "status code")
})
}
}
type mockContext struct {
headers http.Header
}
func (c mockContext) Deadline() (deadline time.Time, ok bool) {
return deadline, ok
}
func (c mockContext) Done() <-chan struct{} {
ch := make(chan struct{})
close(ch)
return ch
}
func (c mockContext) Err() error {
return context.DeadlineExceeded
}
func (c mockContext) Value(key interface{}) interface{} {
return c.headers
}
func TestNewServerWithResponseModifiers(t *testing.T) {
testCases := []struct {
desc string
headerMiddleware *middlewares.HeaderStruct
secureMiddleware *secure.Secure
ctx context.Context
expected map[string]string
}{
{
desc: "header and secure nil",
headerMiddleware: nil,
secureMiddleware: nil,
ctx: mockContext{},
expected: map[string]string{
"X-Default": "powpow",
"Referrer-Policy": "same-origin",
},
},
{
desc: "header middleware not nil",
headerMiddleware: middlewares.NewHeaderFromStruct(&types.Headers{
CustomResponseHeaders: map[string]string{
"X-Default": "powpow",
},
}),
secureMiddleware: nil,
ctx: mockContext{},
expected: map[string]string{
"X-Default": "powpow",
"Referrer-Policy": "same-origin",
},
},
{
desc: "secure middleware not nil",
headerMiddleware: nil,
secureMiddleware: middlewares.NewSecure(&types.Headers{
ReferrerPolicy: "no-referrer",
}),
ctx: mockContext{
headers: http.Header{"Referrer-Policy": []string{"no-referrer"}},
},
expected: map[string]string{
"X-Default": "powpow",
"Referrer-Policy": "no-referrer",
},
},
{
desc: "header and secure middleware not nil",
headerMiddleware: middlewares.NewHeaderFromStruct(&types.Headers{
CustomResponseHeaders: map[string]string{
"Referrer-Policy": "powpow",
},
}),
secureMiddleware: middlewares.NewSecure(&types.Headers{
ReferrerPolicy: "no-referrer",
}),
ctx: mockContext{
headers: http.Header{"Referrer-Policy": []string{"no-referrer"}},
},
expected: map[string]string{
"X-Default": "powpow",
"Referrer-Policy": "powpow",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
headers := make(http.Header)
headers.Add("X-Default", "powpow")
headers.Add("Referrer-Policy", "same-origin")
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
res := &http.Response{
Request: req.WithContext(test.ctx),
Header: headers,
}
responseModifier := buildModifyResponse(test.secureMiddleware, test.headerMiddleware)
err := responseModifier(res)
assert.NoError(t, err)
assert.Equal(t, len(test.expected), len(res.Header))
for k, v := range test.expected {
assert.Equal(t, v, res.Header.Get(k))
}
})
}
}

View file

@ -1,4 +1,4 @@
package server
package service
import "sync"

280
server/service/service.go Normal file
View file

@ -0,0 +1,280 @@
package service
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/config"
"github.com/containous/traefik/healthcheck"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares/emptybackendhandler"
"github.com/containous/traefik/old/middlewares/pipelining"
"github.com/containous/traefik/server/cookie"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/roundrobin"
)
const (
defaultHealthCheckInterval = 30 * time.Second
defaultHealthCheckTimeout = 5 * time.Second
)
// See oxy/roundrobin/rr.go
type balancerHandler interface {
Servers() []*url.URL
ServeHTTP(w http.ResponseWriter, req *http.Request)
ServerWeight(u *url.URL) (int, bool)
RemoveServer(u *url.URL) error
UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error
NextServer() (*url.URL, error)
Next() http.Handler
}
// NewManager creates a new Manager
func NewManager(configs map[string]*config.Service, defaultRoundTripper http.RoundTripper) *Manager {
return &Manager{
bufferPool: newBufferPool(),
defaultRoundTripper: defaultRoundTripper,
balancers: make(map[string][]healthcheck.BalancerHandler),
configs: configs,
}
}
// Manager The service manager
type Manager struct {
bufferPool httputil.BufferPool
defaultRoundTripper http.RoundTripper
balancers map[string][]healthcheck.BalancerHandler
configs map[string]*config.Service
}
// Build Creates a http.Handler for a service configuration.
func (m *Manager) Build(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
// TODO refactor ?
if conf, ok := m.configs[serviceName]; ok {
// FIXME Should handle multiple service types
return m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier)
}
return nil, fmt.Errorf("the service %q does not exits", serviceName)
}
func (m *Manager) getLoadBalancerServiceHandler(
ctx context.Context,
serviceName string,
service *config.LoadBalancerService,
responseModifier func(*http.Response) error,
) (http.Handler, error) {
fwd, err := m.buildForwarder(service.PassHostHeader, service.ResponseForwarding, responseModifier)
if err != nil {
return nil, err
}
fwd = pipelining.NewPipelining(fwd)
rr, err := roundrobin.New(fwd)
if err != nil {
return nil, err
}
balancer, err := m.getLoadBalancer(ctx, serviceName, service, fwd, rr)
if err != nil {
return nil, err
}
// TODO rename and checks
m.balancers[serviceName] = append(m.balancers[serviceName], balancer)
// Empty (backend with no servers)
return emptybackendhandler.New(balancer), nil
}
// LaunchHealthCheck Launches the health checks.
func (m *Manager) LaunchHealthCheck() {
backendConfigs := make(map[string]*healthcheck.BackendConfig)
for serviceName, balancers := range m.balancers {
ctx := log.With(context.Background(), log.Str(log.ServiceName, serviceName))
// FIXME aggregate
balancer := balancers[0]
// FIXME Should all the services handle healthcheck? Handle different types
service := m.configs[serviceName].LoadBalancer
// Health Check
var backendHealthCheck *healthcheck.BackendConfig
if hcOpts := buildHealthCheckOptions(ctx, balancer, serviceName, service.HealthCheck); hcOpts != nil {
log.FromContext(ctx).Debugf("Setting up healthcheck for service %s with %s", serviceName, *hcOpts)
hcOpts.Transport = m.defaultRoundTripper
backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, serviceName)
}
if backendHealthCheck != nil {
backendConfigs[serviceName] = backendHealthCheck
}
}
// FIXME metrics and context
healthcheck.GetHealthCheck().SetBackendsConfiguration(context.TODO(), backendConfigs)
}
func buildHealthCheckOptions(ctx context.Context, lb healthcheck.BalancerHandler, backend string, hc *config.HealthCheck) *healthcheck.Options {
if hc == nil || hc.Path == "" {
return nil
}
logger := log.FromContext(ctx)
interval := defaultHealthCheckInterval
if hc.Interval != "" {
intervalOverride, err := time.ParseDuration(hc.Interval)
if err != nil {
logger.Errorf("Illegal health check interval for '%s': %s", backend, err)
} else if intervalOverride <= 0 {
logger.Errorf("Health check interval smaller than zero for service '%s'", backend)
} else {
interval = intervalOverride
}
}
timeout := defaultHealthCheckTimeout
if hc.Timeout != "" {
timeoutOverride, err := time.ParseDuration(hc.Timeout)
if err != nil {
logger.Errorf("Illegal health check timeout for backend '%s': %s", backend, err)
} else if timeoutOverride <= 0 {
logger.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend)
} else {
timeout = timeoutOverride
}
}
if timeout >= interval {
logger.Warnf("Health check timeout for backend '%s' should be lower than the health check interval. Interval set to timeout + 1 second (%s).", backend)
}
return &healthcheck.Options{
Scheme: hc.Scheme,
Path: hc.Path,
Port: hc.Port,
Interval: interval,
Timeout: timeout,
LB: lb,
Hostname: hc.Hostname,
Headers: hc.Headers,
}
}
func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, service *config.LoadBalancerService, fwd http.Handler, rr balancerHandler) (healthcheck.BalancerHandler, error) {
logger := log.FromContext(ctx)
var stickySession *roundrobin.StickySession
var cookieName string
if stickiness := service.Stickiness; stickiness != nil {
cookieName = cookie.GetName(stickiness.CookieName, serviceName)
stickySession = roundrobin.NewStickySession(cookieName)
}
var lb healthcheck.BalancerHandler
var err error
if service.Method == "drr" {
logger.Debug("Creating drr load-balancer")
if stickySession != nil {
logger.Debugf("Sticky session cookie name: %v", cookieName)
lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
lb, err = roundrobin.NewRebalancer(rr)
if err != nil {
return nil, err
}
}
} else {
if service.Method != "wrr" {
logger.Warnf("Invalid load-balancing method %q, fallback to 'wrr' method", service.Method)
}
logger.Debug("Creating wrr load-balancer")
if stickySession != nil {
logger.Debugf("Sticky session cookie name: %v", cookieName)
lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
lb = rr
}
}
if err := m.upsertServers(ctx, lb, service.Servers); err != nil {
return nil, fmt.Errorf("error configuring load balancer for service %s: %v", serviceName, err)
}
return lb, nil
}
func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []config.Server) error {
logger := log.FromContext(ctx)
for name, srv := range servers {
u, err := url.Parse(srv.URL)
if err != nil {
return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err)
}
logger.WithField(log.ServerName, name).Debugf("Creating server %d at %s with weight %d", name, u, srv.Weight)
if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil {
return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err)
}
// FIXME Handle Metrics
}
return nil
}
func (m *Manager) buildForwarder(passHostHeader bool, responseForwarding *config.ResponseForwarding, responseModifier func(*http.Response) error) (http.Handler, error) {
var flushInterval parse.Duration
if responseForwarding != nil {
err := flushInterval.Set(responseForwarding.FlushInterval)
if err != nil {
return nil, fmt.Errorf("error creating flush interval: %v", err)
}
}
return forward.New(
forward.Stream(true),
forward.PassHostHeader(passHostHeader),
forward.RoundTripper(m.defaultRoundTripper),
forward.ResponseModifier(responseModifier),
forward.BufferPool(m.bufferPool),
forward.StreamingFlushInterval(time.Duration(flushInterval)),
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
server := req.Context().Value(http.ServerContextKey).(*http.Server)
if server != nil {
connState := server.ConnState
if connState != nil {
connState(conn, http.StateClosed)
}
}
}),
)
}

View file

@ -0,0 +1,327 @@
package service
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/roundrobin"
)
type MockRR struct {
err error
}
func (*MockRR) Servers() []*url.URL {
panic("implement me")
}
func (*MockRR) ServeHTTP(w http.ResponseWriter, req *http.Request) {
panic("implement me")
}
func (*MockRR) ServerWeight(u *url.URL) (int, bool) {
panic("implement me")
}
func (*MockRR) RemoveServer(u *url.URL) error {
panic("implement me")
}
func (m *MockRR) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
return m.err
}
func (*MockRR) NextServer() (*url.URL, error) {
panic("implement me")
}
func (*MockRR) Next() http.Handler {
panic("implement me")
}
type MockForwarder struct{}
func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) {
panic("implement me")
}
func TestGetLoadBalancer(t *testing.T) {
sm := Manager{}
testCases := []struct {
desc string
serviceName string
service *config.LoadBalancerService
fwd http.Handler
rr balancerHandler
expectError bool
}{
{
desc: "Fails when provided an invalid URL",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: ":",
Weight: 0,
},
},
},
fwd: &MockForwarder{},
rr: &MockRR{},
expectError: true,
},
{
desc: "Fails when the server upsert fails",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: "http://foo",
Weight: 0,
},
},
},
fwd: &MockForwarder{},
rr: &MockRR{err: errors.New("upsert fails")},
expectError: true,
},
{
desc: "Succeeds when there are no servers",
serviceName: "test",
service: &config.LoadBalancerService{},
fwd: &MockForwarder{},
rr: &MockRR{},
expectError: false,
},
{
desc: "Succeeds when stickiness is set",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
},
fwd: &MockForwarder{},
rr: &MockRR{},
expectError: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := sm.getLoadBalancer(context.Background(), test.serviceName, test.service, test.fwd, test.rr)
if test.expectError {
require.Error(t, err)
assert.Nil(t, handler)
} else {
require.NoError(t, err)
assert.NotNil(t, handler)
}
})
}
}
func TestGetLoadBalancerServiceHandler(t *testing.T) {
sm := NewManager(nil, http.DefaultTransport)
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "first")
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "second")
}))
defer server2.Close()
serverPassHost := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "passhost")
assert.Equal(t, "callme", r.Host)
}))
defer serverPassHost.Close()
serverPassHostFalse := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "passhostfalse")
assert.NotEqual(t, "callme", r.Host)
}))
defer serverPassHostFalse.Close()
type ExpectedResult struct {
StatusCode int
XFrom string
}
testCases := []struct {
desc string
serviceName string
service *config.LoadBalancerService
responseModifier func(*http.Response) error
expected []ExpectedResult
}{
{
desc: "Load balances between the two servers",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server1.URL,
Weight: 50,
},
{
URL: server2.URL,
Weight: 50,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "first",
},
{
StatusCode: http.StatusOK,
XFrom: "second",
},
},
},
{
desc: "StatusBadGateway when the server is not reachable",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: "http://foo",
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusBadGateway,
},
},
},
{
desc: "ServiceUnavailable when no servers are available",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusServiceUnavailable,
},
},
},
{
desc: "Always call the same server when stickiness is true",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
Servers: []config.Server{
{
URL: server1.URL,
Weight: 1,
},
{
URL: server2.URL,
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "first",
},
{
StatusCode: http.StatusOK,
XFrom: "first",
},
},
},
{
desc: "PassHost passes the host instead of the IP",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
PassHostHeader: true,
Servers: []config.Server{
{
URL: serverPassHost.URL,
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "passhost",
},
},
},
{
desc: "PassHost doesn't passe the host instead of the IP",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
Servers: []config.Server{
{
URL: serverPassHostFalse.URL,
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "passhostfalse",
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service, test.responseModifier)
assert.NoError(t, err)
assert.NotNil(t, handler)
req := testhelpers.MustNewRequest(http.MethodGet, "http://callme", nil)
for _, expected := range test.expected {
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, expected.StatusCode, recorder.Code)
assert.Equal(t, expected.XFrom, recorder.Header().Get("X-From"))
if len(recorder.Header().Get("Set-Cookie")) > 0 {
req.Header.Set("Cookie", recorder.Header().Get("Set-Cookie"))
}
}
})
}
}
// FIXME Add healthcheck tests