Dynamic Configuration Refactoring
This commit is contained in:
parent
d3ae88f108
commit
a09dfa3ce1
452 changed files with 21023 additions and 9419 deletions
292
server/middleware/middlewares.go
Normal file
292
server/middleware/middlewares.go
Normal file
|
@ -0,0 +1,292 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares/addprefix"
|
||||
"github.com/containous/traefik/middlewares/auth"
|
||||
"github.com/containous/traefik/middlewares/buffering"
|
||||
"github.com/containous/traefik/middlewares/chain"
|
||||
"github.com/containous/traefik/middlewares/circuitbreaker"
|
||||
"github.com/containous/traefik/middlewares/compress"
|
||||
"github.com/containous/traefik/middlewares/customerrors"
|
||||
"github.com/containous/traefik/middlewares/headers"
|
||||
"github.com/containous/traefik/middlewares/ipwhitelist"
|
||||
"github.com/containous/traefik/middlewares/maxconnection"
|
||||
"github.com/containous/traefik/middlewares/passtlsclientcert"
|
||||
"github.com/containous/traefik/middlewares/ratelimiter"
|
||||
"github.com/containous/traefik/middlewares/redirect"
|
||||
"github.com/containous/traefik/middlewares/replacepath"
|
||||
"github.com/containous/traefik/middlewares/replacepathregex"
|
||||
"github.com/containous/traefik/middlewares/retry"
|
||||
"github.com/containous/traefik/middlewares/stripprefix"
|
||||
"github.com/containous/traefik/middlewares/stripprefixregex"
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Builder the middleware builder
|
||||
type Builder struct {
|
||||
configs map[string]*config.Middleware
|
||||
serviceBuilder serviceBuilder
|
||||
}
|
||||
|
||||
type serviceBuilder interface {
|
||||
Build(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
||||
}
|
||||
|
||||
// NewBuilder creates a new Builder
|
||||
func NewBuilder(configs map[string]*config.Middleware, serviceBuilder serviceBuilder) *Builder {
|
||||
return &Builder{configs: configs, serviceBuilder: serviceBuilder}
|
||||
}
|
||||
|
||||
// BuildChain creates a middleware chain
|
||||
func (b *Builder) BuildChain(ctx context.Context, middlewares []string) (*alice.Chain, error) {
|
||||
chain := alice.New()
|
||||
for _, middlewareName := range middlewares {
|
||||
if _, ok := b.configs[middlewareName]; !ok {
|
||||
return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
|
||||
}
|
||||
|
||||
constructor, err := b.buildConstructor(ctx, middlewareName, *b.configs[middlewareName])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if constructor != nil {
|
||||
chain = chain.Append(constructor)
|
||||
}
|
||||
}
|
||||
return &chain, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildConstructor(ctx context.Context, middlewareName string, config config.Middleware) (alice.Constructor, error) {
|
||||
var middleware alice.Constructor
|
||||
badConf := errors.New("cannot create middleware %q: multi-types middleware not supported, consider declaring two different pieces of middleware instead")
|
||||
|
||||
// AddPrefix
|
||||
if config.AddPrefix != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return addprefix.New(ctx, next, *config.AddPrefix, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// BasicAuth
|
||||
if config.BasicAuth != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return auth.NewBasic(ctx, next, *config.BasicAuth, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Buffering
|
||||
if config.Buffering != nil && config.MaxConn.Amount != 0 {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return buffering.New(ctx, next, *config.Buffering, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Chain
|
||||
if config.Chain != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return chain.New(ctx, next, *config.Chain, b, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker
|
||||
if config.CircuitBreaker != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return circuitbreaker.New(ctx, next, *config.CircuitBreaker, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Compress
|
||||
if config.Compress != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return compress.New(ctx, next, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// CustomErrors
|
||||
if config.Errors != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return customerrors.New(ctx, next, *config.Errors, b.serviceBuilder, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// DigestAuth
|
||||
if config.DigestAuth != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return auth.NewDigest(ctx, next, *config.DigestAuth, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardAuth
|
||||
if config.ForwardAuth != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return auth.NewForward(ctx, next, *config.ForwardAuth, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Headers
|
||||
if config.Headers != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return headers.New(ctx, next, *config.Headers, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// IPWhiteList
|
||||
if config.IPWhiteList != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return ipwhitelist.New(ctx, next, *config.IPWhiteList, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// MaxConn
|
||||
if config.MaxConn != nil && config.MaxConn.Amount != 0 {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return maxconnection.New(ctx, next, *config.MaxConn, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// PassTLSClientCert
|
||||
if config.PassTLSClientCert != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return passtlsclientcert.New(ctx, next, *config.PassTLSClientCert, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit
|
||||
if config.RateLimit != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return ratelimiter.New(ctx, next, *config.RateLimit, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect
|
||||
if config.Redirect != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return redirect.New(ctx, next, *config.Redirect, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// ReplacePath
|
||||
if config.ReplacePath != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return replacepath.New(ctx, next, *config.ReplacePath, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// ReplacePathRegex
|
||||
if config.ReplacePathRegex != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return replacepathregex.New(ctx, next, *config.ReplacePathRegex, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// Retry
|
||||
if config.Retry != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
// FIXME missing metrics / accessLog
|
||||
return retry.New(ctx, next, *config.Retry, retry.Listeners{}, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// StripPrefix
|
||||
if config.StripPrefix != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return stripprefix.New(ctx, next, *config.StripPrefix, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
// StripPrefixRegex
|
||||
if config.StripPrefixRegex != nil {
|
||||
if middleware == nil {
|
||||
middleware = func(next http.Handler) (http.Handler, error) {
|
||||
return stripprefixregex.New(ctx, next, *config.StripPrefixRegex, middlewareName)
|
||||
}
|
||||
} else {
|
||||
return nil, badConf
|
||||
}
|
||||
}
|
||||
|
||||
return tracing.Wrap(ctx, middleware), nil
|
||||
}
|
127
server/middleware/middlewares_test.go
Normal file
127
server/middleware/middlewares_test.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMiddlewaresRegistry_BuildMiddlewareCircuitBreaker(t *testing.T) {
|
||||
testConfig := map[string]*config.Middleware{
|
||||
"empty": {
|
||||
CircuitBreaker: &config.CircuitBreaker{
|
||||
Expression: "",
|
||||
},
|
||||
},
|
||||
"foo": {
|
||||
CircuitBreaker: &config.CircuitBreaker{
|
||||
Expression: "NetworkErrorRatio() > 0.5",
|
||||
},
|
||||
},
|
||||
}
|
||||
middlewaresBuilder := NewBuilder(testConfig, nil)
|
||||
|
||||
emptyHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
middlewareID string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
desc: "Should fail at creating a circuit breaker with an empty expression",
|
||||
expectedError: true,
|
||||
middlewareID: "empty",
|
||||
}, {
|
||||
desc: "Should create a circuit breaker with a valid expression",
|
||||
expectedError: false,
|
||||
middlewareID: "foo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
constructor, err := middlewaresBuilder.buildConstructor(context.Background(), test.middlewareID, *testConfig[test.middlewareID])
|
||||
require.NoError(t, err)
|
||||
|
||||
middleware, err2 := constructor(emptyHandler)
|
||||
|
||||
if test.expectedError {
|
||||
require.Error(t, err2)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, middleware)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewaresRegistry_BuildChainNilConfig(t *testing.T) {
|
||||
testConfig := map[string]*config.Middleware{
|
||||
"empty": {},
|
||||
}
|
||||
middlewaresBuilder := NewBuilder(testConfig, nil)
|
||||
|
||||
chain, err := middlewaresBuilder.BuildChain(context.Background(), []string{"empty"})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = chain.Then(nil)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestMiddlewaresRegistry_BuildMiddlewareAddPrefix(t *testing.T) {
|
||||
testConfig := map[string]*config.Middleware{
|
||||
"empty": {
|
||||
AddPrefix: &config.AddPrefix{
|
||||
Prefix: "",
|
||||
},
|
||||
},
|
||||
"foo": {
|
||||
AddPrefix: &config.AddPrefix{
|
||||
Prefix: "foo/",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
middlewaresBuilder := NewBuilder(testConfig, nil)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
middlewareID string
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
desc: "Should not create an emty AddPrefix middleware when given an empty prefix",
|
||||
middlewareID: "empty",
|
||||
expectedError: true,
|
||||
}, {
|
||||
desc: "Should create an AddPrefix middleware when given a valid configuration",
|
||||
middlewareID: "foo",
|
||||
expectedError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
constructor, err := middlewaresBuilder.buildConstructor(context.Background(), test.middlewareID, *testConfig[test.middlewareID])
|
||||
require.NoError(t, err)
|
||||
|
||||
middleware, err2 := constructor(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}))
|
||||
|
||||
if test.expectedError {
|
||||
require.Error(t, err2)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, middleware)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
94
server/roundtripper.go
Normal file
94
server/roundtripper.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type h2cTransportWrapper struct {
|
||||
*http2.Transport
|
||||
}
|
||||
|
||||
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.URL.Scheme = "http"
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings.
|
||||
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
|
||||
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost
|
||||
// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing
|
||||
// behavior and backwards compatibility issues.
|
||||
func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) (*http.Transport, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: configuration.DefaultDialTimeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}
|
||||
|
||||
if globalConfiguration.ForwardingTimeouts != nil {
|
||||
dialer.Timeout = time.Duration(globalConfiguration.ForwardingTimeouts.DialTimeout)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialer.DialContext,
|
||||
MaxIdleConnsPerHost: globalConfiguration.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
transport.RegisterProtocol("h2c", &h2cTransportWrapper{
|
||||
Transport: &http2.Transport{
|
||||
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
return net.Dial(netw, addr)
|
||||
},
|
||||
AllowHTTP: true,
|
||||
},
|
||||
})
|
||||
|
||||
if globalConfiguration.ForwardingTimeouts != nil {
|
||||
transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
if globalConfiguration.InsecureSkipVerify {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
|
||||
if len(globalConfiguration.RootCAs) > 0 {
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: createRootCACertPool(globalConfiguration.RootCAs),
|
||||
}
|
||||
}
|
||||
|
||||
err := http2.ConfigureTransport(transport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
func createRootCACertPool(rootCAs traefiktls.FilesOrContents) *x509.CertPool {
|
||||
roots := x509.NewCertPool()
|
||||
|
||||
for _, cert := range rootCAs {
|
||||
certContent, err := cert.Read()
|
||||
if err != nil {
|
||||
log.WithoutContext().Error("Error while read RootCAs", err)
|
||||
continue
|
||||
}
|
||||
roots.AppendCertsFromPEM(certContent)
|
||||
}
|
||||
|
||||
return roots
|
||||
}
|
119
server/router/route_appender_aggregator.go
Normal file
119
server/router/route_appender_aggregator.go
Normal file
|
@ -0,0 +1,119 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/api"
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/metrics"
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/containous/traefik/types"
|
||||
)
|
||||
|
||||
// chainBuilder The contract of the middleware builder
|
||||
type chainBuilder interface {
|
||||
BuildChain(ctx context.Context, middlewares []string) (*alice.Chain, error)
|
||||
}
|
||||
|
||||
// NewRouteAppenderAggregator Creates a new RouteAppenderAggregator
|
||||
func NewRouteAppenderAggregator(ctx context.Context, chainBuilder chainBuilder, conf static.Configuration, entryPointName string, currentConfiguration *safe.Safe) *RouteAppenderAggregator {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
aggregator := &RouteAppenderAggregator{}
|
||||
|
||||
// FIXME add REST
|
||||
|
||||
if conf.API != nil && conf.API.EntryPoint == entryPointName {
|
||||
chain, err := chainBuilder.BuildChain(ctx, conf.API.Middlewares)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
} else {
|
||||
aggregator.AddAppender(&WithMiddleware{
|
||||
appender: api.Handler{
|
||||
EntryPoint: conf.API.EntryPoint,
|
||||
Dashboard: conf.API.Dashboard,
|
||||
Statistics: conf.API.Statistics,
|
||||
DashboardAssets: conf.API.DashboardAssets,
|
||||
CurrentConfigurations: currentConfiguration,
|
||||
},
|
||||
routerMiddlewares: chain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if conf.Ping != nil && conf.Ping.EntryPoint == entryPointName {
|
||||
chain, err := chainBuilder.BuildChain(ctx, conf.Ping.Middlewares)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
} else {
|
||||
aggregator.AddAppender(&WithMiddleware{
|
||||
appender: conf.Ping,
|
||||
routerMiddlewares: chain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if conf.Metrics != nil && conf.Metrics.Prometheus != nil && conf.Metrics.Prometheus.EntryPoint == entryPointName {
|
||||
chain, err := chainBuilder.BuildChain(ctx, conf.Metrics.Prometheus.Middlewares)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
} else {
|
||||
aggregator.AddAppender(&WithMiddleware{
|
||||
appender: metrics.PrometheusHandler{},
|
||||
routerMiddlewares: chain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return aggregator
|
||||
}
|
||||
|
||||
// RouteAppenderAggregator RouteAppender that aggregate other RouteAppender
|
||||
type RouteAppenderAggregator struct {
|
||||
appenders []types.RouteAppender
|
||||
}
|
||||
|
||||
// Append Adds routes to the router
|
||||
func (r *RouteAppenderAggregator) Append(systemRouter *mux.Router) {
|
||||
for _, router := range r.appenders {
|
||||
router.Append(systemRouter)
|
||||
}
|
||||
}
|
||||
|
||||
// AddAppender adds a router in the aggregator
|
||||
func (r *RouteAppenderAggregator) AddAppender(router types.RouteAppender) {
|
||||
r.appenders = append(r.appenders, router)
|
||||
}
|
||||
|
||||
// WithMiddleware router with internal middleware
|
||||
type WithMiddleware struct {
|
||||
appender types.RouteAppender
|
||||
routerMiddlewares *alice.Chain
|
||||
}
|
||||
|
||||
// Append Adds routes to the router
|
||||
func (wm *WithMiddleware) Append(systemRouter *mux.Router) {
|
||||
realRouter := systemRouter.PathPrefix("/").Subrouter()
|
||||
|
||||
wm.appender.Append(realRouter)
|
||||
|
||||
if err := realRouter.Walk(wrapRoute(wm.routerMiddlewares)); err != nil {
|
||||
log.WithoutContext().Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// wrapRoute with middlewares
|
||||
func wrapRoute(middlewares *alice.Chain) func(*mux.Route, *mux.Router, []*mux.Route) error {
|
||||
return func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
|
||||
handler, err := middlewares.Then(route.GetHandler())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
route.Handler(handler)
|
||||
return nil
|
||||
}
|
||||
}
|
116
server/router/route_appender_aggregator_test.go
Normal file
116
server/router/route_appender_aggregator_test.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/ping"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type ChainBuilderMock struct {
|
||||
middles map[string]alice.Constructor
|
||||
}
|
||||
|
||||
func (c *ChainBuilderMock) BuildChain(ctx context.Context, middles []string) (*alice.Chain, error) {
|
||||
chain := alice.New()
|
||||
|
||||
for _, mName := range middles {
|
||||
if constructor, ok := c.middles[mName]; ok {
|
||||
chain = chain.Append(constructor)
|
||||
}
|
||||
}
|
||||
|
||||
return &chain, nil
|
||||
}
|
||||
|
||||
func TestNewRouteAppenderAggregator(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
staticConf static.Configuration
|
||||
middles map[string]alice.Constructor
|
||||
expected map[string]int
|
||||
}{
|
||||
{
|
||||
desc: "API with auth, ping without auth",
|
||||
staticConf: static.Configuration{
|
||||
API: &static.API{
|
||||
EntryPoint: "traefik",
|
||||
Middlewares: []string{"dumb"},
|
||||
},
|
||||
Ping: &ping.Handler{
|
||||
EntryPoint: "traefik",
|
||||
},
|
||||
EntryPoints: &static.EntryPoints{
|
||||
EntryPointList: map[string]static.EntryPoint{
|
||||
"traefik": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
middles: map[string]alice.Constructor{
|
||||
"dumb": func(_ http.Handler) (http.Handler, error) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}), nil
|
||||
},
|
||||
},
|
||||
expected: map[string]int{
|
||||
"/wrong": http.StatusBadGateway,
|
||||
"/ping": http.StatusOK,
|
||||
//"/.well-known/acme-challenge/token": http.StatusNotFound, // FIXME
|
||||
"/api/providers": http.StatusUnauthorized,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Wrong entrypoint name",
|
||||
staticConf: static.Configuration{
|
||||
API: &static.API{
|
||||
EntryPoint: "no",
|
||||
},
|
||||
EntryPoints: &static.EntryPoints{
|
||||
EntryPointList: map[string]static.EntryPoint{
|
||||
"traefik": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]int{
|
||||
"/api/providers": http.StatusBadGateway,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
chainBuilder := &ChainBuilderMock{middles: test.middles}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
router := NewRouteAppenderAggregator(ctx, chainBuilder, test.staticConf, "traefik", nil)
|
||||
|
||||
internalMuxRouter := mux.NewRouter()
|
||||
router.Append(internalMuxRouter)
|
||||
|
||||
internalMuxRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
})
|
||||
|
||||
actual := make(map[string]int)
|
||||
for calledURL := range test.expected {
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, calledURL, nil)
|
||||
internalMuxRouter.ServeHTTP(recorder, request)
|
||||
actual[calledURL] = recorder.Code
|
||||
}
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
38
server/router/route_appender_factory.go
Normal file
38
server/router/route_appender_factory.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/provider/acme"
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/containous/traefik/server/middleware"
|
||||
"github.com/containous/traefik/types"
|
||||
)
|
||||
|
||||
// NewRouteAppenderFactory Creates a new RouteAppenderFactory
|
||||
func NewRouteAppenderFactory(staticConfiguration static.Configuration, entryPointName string, acmeProvider *acme.Provider) *RouteAppenderFactory {
|
||||
return &RouteAppenderFactory{
|
||||
staticConfiguration: staticConfiguration,
|
||||
entryPointName: entryPointName,
|
||||
acmeProvider: acmeProvider,
|
||||
}
|
||||
}
|
||||
|
||||
// RouteAppenderFactory A factory of RouteAppender
|
||||
type RouteAppenderFactory struct {
|
||||
staticConfiguration static.Configuration
|
||||
entryPointName string
|
||||
acmeProvider *acme.Provider
|
||||
}
|
||||
|
||||
// NewAppender Creates a new RouteAppender
|
||||
func (r *RouteAppenderFactory) NewAppender(ctx context.Context, middlewaresBuilder *middleware.Builder, currentConfiguration *safe.Safe) types.RouteAppender {
|
||||
aggregator := NewRouteAppenderAggregator(ctx, middlewaresBuilder, r.staticConfiguration, r.entryPointName, currentConfiguration)
|
||||
|
||||
if r.acmeProvider != nil && r.acmeProvider.HTTPChallenge != nil && r.acmeProvider.HTTPChallenge.EntryPoint == r.entryPointName {
|
||||
aggregator.AddAppender(r.acmeProvider)
|
||||
}
|
||||
|
||||
return aggregator
|
||||
}
|
183
server/router/router.go
Normal file
183
server/router/router.go
Normal file
|
@ -0,0 +1,183 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
"github.com/containous/traefik/middlewares/recovery"
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/containous/traefik/responsemodifiers"
|
||||
"github.com/containous/traefik/server/middleware"
|
||||
"github.com/containous/traefik/server/service"
|
||||
)
|
||||
|
||||
const (
|
||||
recoveryMiddlewareName = "traefik-internal-recovery"
|
||||
)
|
||||
|
||||
// NewManager Creates a new Manager
|
||||
func NewManager(routers map[string]*config.Router,
|
||||
serviceManager *service.Manager, middlewaresBuilder *middleware.Builder, modifierBuilder *responsemodifiers.Builder,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
routerHandlers: make(map[string]http.Handler),
|
||||
configs: routers,
|
||||
serviceManager: serviceManager,
|
||||
middlewaresBuilder: middlewaresBuilder,
|
||||
modifierBuilder: modifierBuilder,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager A route/router manager
|
||||
type Manager struct {
|
||||
routerHandlers map[string]http.Handler
|
||||
configs map[string]*config.Router
|
||||
serviceManager *service.Manager
|
||||
middlewaresBuilder *middleware.Builder
|
||||
modifierBuilder *responsemodifiers.Builder
|
||||
}
|
||||
|
||||
// BuildHandlers Builds handler for all entry points
|
||||
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string, defaultEntryPoints []string) map[string]http.Handler {
|
||||
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints, defaultEntryPoints)
|
||||
|
||||
entryPointHandlers := make(map[string]http.Handler)
|
||||
for entryPointName, routers := range entryPointsRouters {
|
||||
ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName))
|
||||
|
||||
handler, err := m.buildEntryPointHandler(ctx, routers)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
continue
|
||||
}
|
||||
entryPointHandlers[entryPointName] = handler
|
||||
}
|
||||
|
||||
m.serviceManager.LaunchHealthCheck()
|
||||
|
||||
return entryPointHandlers
|
||||
}
|
||||
|
||||
func contains(entryPoints []string, entryPointName string) bool {
|
||||
for _, name := range entryPoints {
|
||||
if name == entryPointName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string, defaultEntryPoints []string) map[string]map[string]*config.Router {
|
||||
entryPointsRouters := make(map[string]map[string]*config.Router)
|
||||
|
||||
for rtName, rt := range m.configs {
|
||||
eps := rt.EntryPoints
|
||||
if len(eps) == 0 {
|
||||
eps = defaultEntryPoints
|
||||
}
|
||||
for _, entryPointName := range eps {
|
||||
if !contains(entryPoints, entryPointName) {
|
||||
log.FromContext(log.With(ctx, log.Str(log.EntryPointName, entryPointName))).
|
||||
Errorf("entryPoint %q doesn't exist", entryPointName)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := entryPointsRouters[entryPointName]; !ok {
|
||||
entryPointsRouters[entryPointName] = make(map[string]*config.Router)
|
||||
}
|
||||
|
||||
entryPointsRouters[entryPointName][rtName] = rt
|
||||
}
|
||||
}
|
||||
|
||||
return entryPointsRouters
|
||||
}
|
||||
|
||||
func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*config.Router) (http.Handler, error) {
|
||||
router := mux.NewRouter().
|
||||
SkipClean(true)
|
||||
|
||||
for routerName, routerConfig := range configs {
|
||||
ctx = log.With(ctx, log.Str(log.RouterName, routerName))
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
handler, err := m.buildRouterHandler(ctx, routerName)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = addRoute(ctx, router, routerConfig.Rule, routerConfig.Priority, handler)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
router.SortRoutes()
|
||||
|
||||
chain := alice.New()
|
||||
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
||||
return recovery.New(ctx, next, recoveryMiddlewareName)
|
||||
})
|
||||
|
||||
return chain.Then(router)
|
||||
}
|
||||
|
||||
func (m *Manager) buildRouterHandler(ctx context.Context, routerName string) (http.Handler, error) {
|
||||
if handler, ok := m.routerHandlers[routerName]; ok {
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
configRouter, ok := m.configs[routerName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no configuration for %s", routerName)
|
||||
}
|
||||
|
||||
handler, err := m.buildHandler(ctx, configRouter, routerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handlerWithAccessLog, err := alice.New(func(next http.Handler) (http.Handler, error) {
|
||||
return accesslog.NewFieldHandler(next, accesslog.RouterName, routerName, nil), nil
|
||||
}).Then(handler)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
m.routerHandlers[routerName] = handler
|
||||
} else {
|
||||
m.routerHandlers[routerName] = handlerWithAccessLog
|
||||
}
|
||||
|
||||
return m.routerHandlers[routerName], nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildHandler(ctx context.Context, router *config.Router, routerName string) (http.Handler, error) {
|
||||
rm := m.modifierBuilder.Build(ctx, router.Middlewares)
|
||||
|
||||
sHandler, err := m.serviceManager.Build(ctx, router.Service, rm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mHandler, err := m.middlewaresBuilder.BuildChain(ctx, router.Middlewares)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
alHandler := func(next http.Handler) (http.Handler, error) {
|
||||
return accesslog.NewFieldHandler(next, accesslog.ServiceName, router.Service, accesslog.AddServiceFields), nil
|
||||
}
|
||||
|
||||
tHandler := func(next http.Handler) (http.Handler, error) {
|
||||
return tracing.NewForwarder(ctx, routerName, router.Service, next), nil
|
||||
}
|
||||
|
||||
return alice.New().Append(alHandler).Extend(*mHandler).Append(tHandler).Then(sHandler)
|
||||
}
|
334
server/router/router_test.go
Normal file
334
server/router/router_test.go
Normal file
|
@ -0,0 +1,334 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
"github.com/containous/traefik/middlewares/requestdecorator"
|
||||
"github.com/containous/traefik/responsemodifiers"
|
||||
"github.com/containous/traefik/server/middleware"
|
||||
"github.com/containous/traefik/server/service"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRouterManager_Get(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
type ExpectedResult struct {
|
||||
StatusCode int
|
||||
RequestHeaders map[string]string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
routersConfig map[string]*config.Router
|
||||
serviceConfig map[string]*config.Service
|
||||
middlewaresConfig map[string]*config.Middleware
|
||||
entryPoints []string
|
||||
defaultEntryPoints []string
|
||||
expected ExpectedResult
|
||||
}{
|
||||
{
|
||||
desc: "no middleware",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host:foo.bar",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{StatusCode: http.StatusOK},
|
||||
},
|
||||
{
|
||||
desc: "no middleware, default entry point",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
Service: "foo-service",
|
||||
Rule: "Host:foo.bar",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
defaultEntryPoints: []string{"web"},
|
||||
expected: ExpectedResult{StatusCode: http.StatusOK},
|
||||
},
|
||||
{
|
||||
desc: "no middleware, no matching",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host:bar.bar",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{StatusCode: http.StatusNotFound},
|
||||
},
|
||||
{
|
||||
desc: "middleware: headers > auth",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Middlewares: []string{"headers-middle", "auth-middle"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host:foo.bar",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
middlewaresConfig: map[string]*config.Middleware{
|
||||
"auth-middle": {
|
||||
BasicAuth: &config.BasicAuth{
|
||||
Users: []string{"toto:titi"},
|
||||
},
|
||||
},
|
||||
"headers-middle": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RequestHeaders: map[string]string{
|
||||
"X-Apero": "beer",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "middleware: auth > header",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Middlewares: []string{"auth-middle", "headers-middle"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host:foo.bar",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
middlewaresConfig: map[string]*config.Middleware{
|
||||
"auth-middle": {
|
||||
BasicAuth: &config.BasicAuth{
|
||||
Users: []string{"toto:titi"},
|
||||
},
|
||||
},
|
||||
"headers-middle": {
|
||||
Headers: &config.Headers{
|
||||
CustomRequestHeaders: map[string]string{"X-Apero": "beer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: ExpectedResult{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
RequestHeaders: map[string]string{
|
||||
"X-Apero": "",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceManager := service.NewManager(test.serviceConfig, http.DefaultTransport)
|
||||
middlewaresBuilder := middleware.NewBuilder(test.middlewaresConfig, serviceManager)
|
||||
responseModifierFactory := responsemodifiers.NewBuilder(test.middlewaresConfig)
|
||||
|
||||
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, test.defaultEntryPoints)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
reqHost := requestdecorator.New(nil)
|
||||
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
|
||||
|
||||
assert.Equal(t, test.expected.StatusCode, w.Code)
|
||||
|
||||
for key, value := range test.expected.RequestHeaders {
|
||||
assert.Equal(t, value, req.Header.Get(key))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccessLog(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
routersConfig map[string]*config.Router
|
||||
serviceConfig map[string]*config.Service
|
||||
middlewaresConfig map[string]*config.Middleware
|
||||
entryPoints []string
|
||||
defaultEntryPoints []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "apply routerName in accesslog (first match)",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host:foo.bar",
|
||||
},
|
||||
"bar": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host:bar.foo",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: "foo",
|
||||
},
|
||||
{
|
||||
desc: "apply routerName in accesslog (second match)",
|
||||
routersConfig: map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host:bar.foo",
|
||||
},
|
||||
"bar": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host:foo.bar",
|
||||
},
|
||||
},
|
||||
serviceConfig: map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
entryPoints: []string{"web"},
|
||||
expected: "bar",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
serviceManager := service.NewManager(test.serviceConfig, http.DefaultTransport)
|
||||
middlewaresBuilder := middleware.NewBuilder(test.middlewaresConfig, serviceManager)
|
||||
responseModifierFactory := responsemodifiers.NewBuilder(test.middlewaresConfig)
|
||||
|
||||
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, test.defaultEntryPoints)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
accesslogger, err := accesslog.NewHandler(&types.AccessLog{
|
||||
Format: "json",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
reqHost := requestdecorator.New(nil)
|
||||
|
||||
accesslogger.ServeHTTP(w, req, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
|
||||
|
||||
data := accesslog.GetLogData(req)
|
||||
require.NotNil(t, data)
|
||||
|
||||
assert.Equal(t, test.expected, data.Core[accesslog.RouterName])
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
163
server/router/rules.go
Normal file
163
server/router/rules.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares/requestdecorator"
|
||||
)
|
||||
|
||||
func addRoute(ctx context.Context, router *mux.Router, rule string, priority int, handler http.Handler) error {
|
||||
matchers, err := parseRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if priority == 0 {
|
||||
priority = len(rule)
|
||||
}
|
||||
|
||||
route := router.NewRoute().Handler(handler).Priority(priority)
|
||||
for _, matcher := range matchers {
|
||||
matcher(route)
|
||||
if route.GetError() != nil {
|
||||
log.FromContext(ctx).Error(route.GetError())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRule(rule string) ([]func(*mux.Route), error) {
|
||||
funcs := map[string]func(*mux.Route, ...string){
|
||||
"Host": host,
|
||||
"HostRegexp": hostRegexp,
|
||||
"Path": path,
|
||||
"PathPrefix": pathPrefix,
|
||||
"Method": methods,
|
||||
"Headers": headers,
|
||||
"HeadersRegexp": headersRegexp,
|
||||
"Query": query,
|
||||
}
|
||||
|
||||
splitRule := func(c rune) bool {
|
||||
return c == ';'
|
||||
}
|
||||
parsedRules := strings.FieldsFunc(rule, splitRule)
|
||||
|
||||
var matchers []func(*mux.Route)
|
||||
|
||||
for _, expression := range parsedRules {
|
||||
expParts := strings.Split(expression, ":")
|
||||
if len(expParts) > 1 && len(expParts[1]) > 0 {
|
||||
if fn, ok := funcs[expParts[0]]; ok {
|
||||
|
||||
parseOr := func(c rune) bool {
|
||||
return c == ','
|
||||
}
|
||||
|
||||
exp := strings.FieldsFunc(strings.Join(expParts[1:], ":"), parseOr)
|
||||
|
||||
var trimmedExp []string
|
||||
for _, value := range exp {
|
||||
trimmedExp = append(trimmedExp, strings.TrimSpace(value))
|
||||
}
|
||||
|
||||
// FIXME struct for onhostrule ?
|
||||
matcher := func(rt *mux.Route) {
|
||||
fn(rt, trimmedExp...)
|
||||
}
|
||||
|
||||
matchers = append(matchers, matcher)
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid matcher: %s", expression)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return matchers, nil
|
||||
}
|
||||
|
||||
func path(route *mux.Route, paths ...string) {
|
||||
rt := route.Subrouter()
|
||||
for _, path := range paths {
|
||||
tmpRt := rt.Path(path)
|
||||
if tmpRt.GetError() != nil {
|
||||
log.WithoutContext().WithField("paths", strings.Join(paths, ",")).Error(tmpRt.GetError())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func pathPrefix(route *mux.Route, paths ...string) {
|
||||
rt := route.Subrouter()
|
||||
for _, path := range paths {
|
||||
tmpRt := rt.PathPrefix(path)
|
||||
if tmpRt.GetError() != nil {
|
||||
log.WithoutContext().WithField("paths", strings.Join(paths, ",")).Error(tmpRt.GetError())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func host(route *mux.Route, hosts ...string) {
|
||||
for i, host := range hosts {
|
||||
hosts[i] = strings.ToLower(host)
|
||||
}
|
||||
|
||||
route.MatcherFunc(func(req *http.Request, route *mux.RouteMatch) bool {
|
||||
reqHost := requestdecorator.GetCanonizedHost(req.Context())
|
||||
if len(reqHost) == 0 {
|
||||
log.FromContext(req.Context()).Warnf("Could not retrieve CanonizedHost, rejecting %s", req.Host)
|
||||
return false
|
||||
}
|
||||
|
||||
flatH := requestdecorator.GetCNAMEFlatten(req.Context())
|
||||
if len(flatH) > 0 {
|
||||
for _, host := range hosts {
|
||||
if strings.EqualFold(reqHost, host) || strings.EqualFold(flatH, host) {
|
||||
return true
|
||||
}
|
||||
log.FromContext(req.Context()).Debugf("CNAMEFlattening: request %s which resolved to %s, is not matched to route %s", reqHost, flatH, host)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
if reqHost == host {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
func hostRegexp(route *mux.Route, hosts ...string) {
|
||||
router := route.Subrouter()
|
||||
for _, host := range hosts {
|
||||
router.Host(host)
|
||||
}
|
||||
}
|
||||
|
||||
func methods(route *mux.Route, methods ...string) {
|
||||
route.Methods(methods...)
|
||||
}
|
||||
|
||||
func headers(route *mux.Route, headers ...string) {
|
||||
route.Headers(headers...)
|
||||
}
|
||||
|
||||
func headersRegexp(route *mux.Route, headers ...string) {
|
||||
route.HeadersRegexp(headers...)
|
||||
}
|
||||
|
||||
func query(route *mux.Route, query ...string) {
|
||||
var queries []string
|
||||
for _, elem := range query {
|
||||
queries = append(queries, strings.Split(elem, "=")...)
|
||||
}
|
||||
|
||||
route.Queries(queries...)
|
||||
}
|
442
server/router/rules_test.go
Normal file
442
server/router/rules_test.go
Normal file
|
@ -0,0 +1,442 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/middlewares/requestdecorator"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_addRoute(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
rule string
|
||||
headers map[string]string
|
||||
expected map[string]int
|
||||
}{
|
||||
{
|
||||
desc: "no rule",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "PathPrefix",
|
||||
rule: "PathPrefix:/foo",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "wrong PathPrefix",
|
||||
rule: "PathPrefix:/bar",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Host",
|
||||
rule: "Host:localhost",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "wrong Host",
|
||||
rule: "Host:nope",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Host and PathPrefix",
|
||||
rule: "Host:localhost;PathPrefix:/foo",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Host and PathPrefix: wrong PathPrefix",
|
||||
rule: "Host:localhost;PathPrefix:/bar",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Host and PathPrefix: wrong Host",
|
||||
rule: "Host:nope;PathPrefix:/bar",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Host and PathPrefix: Host OR, first host",
|
||||
rule: "Host:nope,localhost;PathPrefix:/foo",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Host and PathPrefix: Host OR, second host",
|
||||
rule: "Host:nope,localhost;PathPrefix:/foo",
|
||||
expected: map[string]int{
|
||||
"http://nope/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Host and PathPrefix: Host OR, first host and wrong PathPrefix",
|
||||
rule: "Host:nope,localhost;PathPrefix:/bar",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "HostRegexp with capturing group",
|
||||
rule: "HostRegexp: {subdomain:(foo\\.)?bar\\.com}",
|
||||
expected: map[string]int{
|
||||
"http://foo.bar.com": http.StatusOK,
|
||||
"http://bar.com": http.StatusOK,
|
||||
"http://fooubar.com": http.StatusNotFound,
|
||||
"http://barucom": http.StatusNotFound,
|
||||
"http://barcom": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "HostRegexp with non capturing group",
|
||||
rule: "HostRegexp: {subdomain:(?:foo\\.)?bar\\.com}",
|
||||
expected: map[string]int{
|
||||
"http://foo.bar.com": http.StatusOK,
|
||||
"http://bar.com": http.StatusOK,
|
||||
"http://fooubar.com": http.StatusNotFound,
|
||||
"http://barucom": http.StatusNotFound,
|
||||
"http://barcom": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Methods with GET",
|
||||
rule: "Method: GET",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Methods with GET and POST",
|
||||
rule: "Method: GET,POST",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Methods with POST",
|
||||
rule: "Method: POST",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusMethodNotAllowed,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Header with matching header",
|
||||
rule: "Headers: Content-Type,application/json",
|
||||
headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Header without matching header",
|
||||
rule: "Headers: Content-Type,application/foo",
|
||||
headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "HeaderRegExp with matching header",
|
||||
rule: "HeadersRegexp: Content-Type, application/(text|json)",
|
||||
headers: map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "HeaderRegExp without matching header",
|
||||
rule: "HeadersRegexp: Content-Type, application/(text|json)",
|
||||
headers: map[string]string{
|
||||
"Content-Type": "application/foo",
|
||||
},
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "HeaderRegExp with matching second header",
|
||||
rule: "HeadersRegexp: Content-Type, application/(text|json)",
|
||||
headers: map[string]string{
|
||||
"Content-Type": "application/text",
|
||||
},
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo": http.StatusOK,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Query with multiple params",
|
||||
rule: "Query: foo=bar, bar=baz",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo?foo=bar&bar=baz": http.StatusOK,
|
||||
"http://localhost/foo?bar=baz": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Invalid rule syntax",
|
||||
rule: "Query:param_one=true, /path2;Path: /path1",
|
||||
expected: map[string]int{
|
||||
"http://localhost/foo?bar=baz": http.StatusNotFound,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.SkipClean(true)
|
||||
|
||||
err := addRoute(context.Background(), router, test.rule, 0, handler)
|
||||
require.NoError(t, err)
|
||||
|
||||
// RequestDecorator is necessary for the host rule
|
||||
reqHost := requestdecorator.New(nil)
|
||||
|
||||
results := make(map[string]int)
|
||||
for calledURL := range test.expected {
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, calledURL, nil)
|
||||
for key, value := range test.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
reqHost.ServeHTTP(w, req, router.ServeHTTP)
|
||||
results[calledURL] = w.Code
|
||||
}
|
||||
assert.Equal(t, test.expected, results)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_addRoutePriority(t *testing.T) {
|
||||
type Case struct {
|
||||
xFrom string
|
||||
rule string
|
||||
priority int
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
path string
|
||||
cases []Case
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Higher priority on second rule",
|
||||
path: "/my",
|
||||
cases: []Case{
|
||||
{
|
||||
xFrom: "header1",
|
||||
rule: "PathPrefix:/my",
|
||||
priority: 10,
|
||||
},
|
||||
{
|
||||
xFrom: "header2",
|
||||
rule: "PathPrefix:/my",
|
||||
priority: 20,
|
||||
},
|
||||
},
|
||||
expected: "header2",
|
||||
},
|
||||
{
|
||||
desc: "Higher priority on first rule",
|
||||
path: "/my",
|
||||
cases: []Case{
|
||||
{
|
||||
xFrom: "header1",
|
||||
rule: "PathPrefix:/my",
|
||||
priority: 20,
|
||||
},
|
||||
{
|
||||
xFrom: "header2",
|
||||
rule: "PathPrefix:/my",
|
||||
priority: 10,
|
||||
},
|
||||
},
|
||||
expected: "header1",
|
||||
},
|
||||
{
|
||||
desc: "Higher priority on second rule with different rule",
|
||||
path: "/mypath",
|
||||
cases: []Case{
|
||||
{
|
||||
xFrom: "header1",
|
||||
rule: "PathPrefix:/mypath",
|
||||
priority: 10,
|
||||
},
|
||||
{
|
||||
xFrom: "header2",
|
||||
rule: "PathPrefix:/my",
|
||||
priority: 20,
|
||||
},
|
||||
},
|
||||
expected: "header2",
|
||||
},
|
||||
{
|
||||
desc: "Higher priority on longest rule (longest first)",
|
||||
path: "/mypath",
|
||||
cases: []Case{
|
||||
{
|
||||
xFrom: "header1",
|
||||
rule: "PathPrefix:/mypath",
|
||||
},
|
||||
{
|
||||
xFrom: "header2",
|
||||
rule: "PathPrefix:/my",
|
||||
},
|
||||
},
|
||||
expected: "header1",
|
||||
},
|
||||
{
|
||||
desc: "Higher priority on longest rule (longest second)",
|
||||
path: "/mypath",
|
||||
cases: []Case{
|
||||
{
|
||||
xFrom: "header1",
|
||||
rule: "PathPrefix:/my",
|
||||
},
|
||||
{
|
||||
xFrom: "header2",
|
||||
rule: "PathPrefix:/mypath",
|
||||
},
|
||||
},
|
||||
expected: "header2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
router := mux.NewRouter()
|
||||
|
||||
for _, route := range test.cases {
|
||||
route := route
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", route.xFrom)
|
||||
})
|
||||
err := addRoute(context.Background(), router, route.rule, route.priority, handler)
|
||||
require.NoError(t, err, route)
|
||||
}
|
||||
|
||||
router.SortRoutes()
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, test.path, nil)
|
||||
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, test.expected, w.Header().Get("X-From"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostRegexp(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
hostExp string
|
||||
urls map[string]bool
|
||||
}{
|
||||
{
|
||||
desc: "capturing group",
|
||||
hostExp: "{subdomain:(foo\\.)?bar\\.com}",
|
||||
urls: map[string]bool{
|
||||
"http://foo.bar.com": true,
|
||||
"http://bar.com": true,
|
||||
"http://fooubar.com": false,
|
||||
"http://barucom": false,
|
||||
"http://barcom": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "non capturing group",
|
||||
hostExp: "{subdomain:(?:foo\\.)?bar\\.com}",
|
||||
urls: map[string]bool{
|
||||
"http://foo.bar.com": true,
|
||||
"http://bar.com": true,
|
||||
"http://fooubar.com": false,
|
||||
"http://barucom": false,
|
||||
"http://barcom": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "regex insensitive",
|
||||
hostExp: "{dummy:[A-Za-z-]+\\.bar\\.com}",
|
||||
urls: map[string]bool{
|
||||
"http://FOO.bar.com": true,
|
||||
"http://foo.bar.com": true,
|
||||
"http://fooubar.com": false,
|
||||
"http://barucom": false,
|
||||
"http://barcom": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insensitive host",
|
||||
hostExp: "{dummy:[a-z-]+\\.bar\\.com}",
|
||||
urls: map[string]bool{
|
||||
"http://FOO.bar.com": true,
|
||||
"http://foo.bar.com": true,
|
||||
"http://fooubar.com": false,
|
||||
"http://barucom": false,
|
||||
"http://barcom": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insensitive host simple",
|
||||
hostExp: "foo.bar.com",
|
||||
urls: map[string]bool{
|
||||
"http://FOO.bar.com": true,
|
||||
"http://foo.bar.com": true,
|
||||
"http://fooubar.com": false,
|
||||
"http://barucom": false,
|
||||
"http://barcom": false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rt := &mux.Route{}
|
||||
hostRegexp(rt, test.hostExp)
|
||||
|
||||
for testURL, match := range test.urls {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, testURL, nil)
|
||||
assert.Equal(t, match, rt.Match(req, &mux.RouteMatch{}), testURL)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
367
server/server.go
367
server/server.go
|
@ -9,35 +9,36 @@ import (
|
|||
stdlog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-proxyproto"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/cluster"
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/h2c"
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/metrics"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/containous/traefik/middlewares/requestdecorator"
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
"github.com/containous/traefik/provider"
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/containous/traefik/server/middleware"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/containous/traefik/tracing/datadog"
|
||||
"github.com/containous/traefik/tracing/jaeger"
|
||||
"github.com/containous/traefik/tracing/zipkin"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0)
|
||||
|
||||
func newHijackConnectionTracker() *hijackConnectionTracker {
|
||||
return &hijackConnectionTracker{
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
|
@ -85,7 +86,7 @@ func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
|
|||
func (h *hijackConnectionTracker) Close() {
|
||||
for conn := range h.conns {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("Error while closing Hijacked conn: %v", err)
|
||||
log.WithoutContext().Errorf("Error while closing Hijacked connection: %v", err)
|
||||
}
|
||||
delete(h.conns, conn)
|
||||
}
|
||||
|
@ -93,33 +94,38 @@ func (h *hijackConnectionTracker) Close() {
|
|||
|
||||
// Server is the reverse-proxy/load-balancer engine
|
||||
type Server struct {
|
||||
serverEntryPoints serverEntryPoints
|
||||
configurationChan chan types.ConfigMessage
|
||||
configurationValidatedChan chan types.ConfigMessage
|
||||
signals chan os.Signal
|
||||
stopChan chan bool
|
||||
currentConfigurations safe.Safe
|
||||
providerConfigUpdateMap map[string]chan types.ConfigMessage
|
||||
globalConfiguration configuration.GlobalConfiguration
|
||||
accessLoggerMiddleware *accesslog.LogHandler
|
||||
tracingMiddleware *tracing.Tracing
|
||||
routinesPool *safe.Pool
|
||||
leadership *cluster.Leadership
|
||||
defaultForwardingRoundTripper http.RoundTripper
|
||||
metricsRegistry metrics.Registry
|
||||
provider provider.Provider
|
||||
configurationListeners []func(types.Configuration)
|
||||
entryPoints map[string]EntryPoint
|
||||
bufferPool httputil.BufferPool
|
||||
serverEntryPoints serverEntryPoints
|
||||
configurationChan chan config.Message
|
||||
configurationValidatedChan chan config.Message
|
||||
signals chan os.Signal
|
||||
stopChan chan bool
|
||||
currentConfigurations safe.Safe
|
||||
providerConfigUpdateMap map[string]chan config.Message
|
||||
globalConfiguration configuration.GlobalConfiguration
|
||||
accessLoggerMiddleware *accesslog.Handler
|
||||
tracer *tracing.Tracing
|
||||
routinesPool *safe.Pool
|
||||
leadership *cluster.Leadership
|
||||
defaultRoundTripper http.RoundTripper
|
||||
metricsRegistry metrics.Registry
|
||||
provider provider.Provider
|
||||
configurationListeners []func(config.Configuration)
|
||||
entryPoints map[string]EntryPoint
|
||||
requestDecorator *requestdecorator.RequestDecorator
|
||||
}
|
||||
|
||||
// RouteAppenderFactory the route appender factory interface
|
||||
type RouteAppenderFactory interface {
|
||||
NewAppender(ctx context.Context, middlewaresBuilder *middleware.Builder, currentConfigurations *safe.Safe) types.RouteAppender
|
||||
}
|
||||
|
||||
// EntryPoint entryPoint information (configuration + internalRouter)
|
||||
type EntryPoint struct {
|
||||
InternalRouter types.InternalRouter
|
||||
Configuration *configuration.EntryPoint
|
||||
OnDemandListener func(string) (*tls.Certificate, error)
|
||||
TLSALPNGetter func(string) (*tls.Certificate, error)
|
||||
CertificateStore *traefiktls.CertificateStore
|
||||
RouteAppenderFactory RouteAppenderFactory
|
||||
Configuration *configuration.EntryPoint
|
||||
OnDemandListener func(string) (*tls.Certificate, error)
|
||||
TLSALPNGetter func(string) (*tls.Certificate, error)
|
||||
CertificateStore *traefiktls.CertificateStore
|
||||
}
|
||||
|
||||
type serverEntryPoints map[string]*serverEntryPoint
|
||||
|
@ -142,10 +148,11 @@ func (s serverEntryPoint) Shutdown(ctx context.Context) {
|
|||
defer wg.Done()
|
||||
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
log.Debugf("Wait server shutdown is over due to: %s", err)
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Debugf("Wait server shutdown is over due to: %s", err)
|
||||
err = s.httpServer.Close()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -158,7 +165,8 @@ func (s serverEntryPoint) Shutdown(ctx context.Context) {
|
|||
defer wg.Done()
|
||||
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
log.Debugf("Wait hijack connection is over due to: %s", err)
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Debugf("Wait hijack connection is over due to: %s", err)
|
||||
s.hijackConnectionTracker.Close()
|
||||
}
|
||||
}
|
||||
|
@ -179,11 +187,32 @@ func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tc.SetKeepAlive(true)
|
||||
tc.SetKeepAlivePeriod(3 * time.Minute)
|
||||
|
||||
if err = tc.SetKeepAlive(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
func setupTracing(conf *static.Tracing) tracing.TrackingBackend {
|
||||
switch conf.Backend {
|
||||
case jaeger.Name:
|
||||
return conf.Jaeger
|
||||
case zipkin.Name:
|
||||
return conf.Zipkin
|
||||
case datadog.Name:
|
||||
return conf.DataDog
|
||||
default:
|
||||
log.WithoutContext().Warnf("Could not initialize tracing: unknown tracer %q", conf.Backend)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewServer returns an initialized Server.
|
||||
func NewServer(globalConfiguration configuration.GlobalConfiguration, provider provider.Provider, entrypoints map[string]EntryPoint) *Server {
|
||||
server := &Server{}
|
||||
|
@ -192,36 +221,41 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p
|
|||
server.provider = provider
|
||||
server.globalConfiguration = globalConfiguration
|
||||
server.serverEntryPoints = make(map[string]*serverEntryPoint)
|
||||
server.configurationChan = make(chan types.ConfigMessage, 100)
|
||||
server.configurationValidatedChan = make(chan types.ConfigMessage, 100)
|
||||
server.configurationChan = make(chan config.Message, 100)
|
||||
server.configurationValidatedChan = make(chan config.Message, 100)
|
||||
server.signals = make(chan os.Signal, 1)
|
||||
server.stopChan = make(chan bool, 1)
|
||||
server.configureSignals()
|
||||
currentConfigurations := make(types.Configurations)
|
||||
currentConfigurations := make(config.Configurations)
|
||||
server.currentConfigurations.Set(currentConfigurations)
|
||||
server.providerConfigUpdateMap = make(map[string]chan types.ConfigMessage)
|
||||
server.providerConfigUpdateMap = make(map[string]chan config.Message)
|
||||
|
||||
transport, err := createHTTPTransport(globalConfiguration)
|
||||
if err != nil {
|
||||
log.WithoutContext().Error(err)
|
||||
server.defaultRoundTripper = http.DefaultTransport
|
||||
} else {
|
||||
server.defaultRoundTripper = transport
|
||||
}
|
||||
|
||||
if server.globalConfiguration.API != nil {
|
||||
server.globalConfiguration.API.CurrentConfigurations = &server.currentConfigurations
|
||||
}
|
||||
|
||||
server.bufferPool = newBufferPool()
|
||||
|
||||
server.routinesPool = safe.NewPool(context.Background())
|
||||
|
||||
transport, err := createHTTPTransport(globalConfiguration)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create HTTP transport: %v", err)
|
||||
if globalConfiguration.Tracing != nil {
|
||||
trackingBackend := setupTracing(static.ConvertTracing(globalConfiguration.Tracing))
|
||||
var err error
|
||||
server.tracer, err = tracing.NewTracing(globalConfiguration.Tracing.ServiceName, globalConfiguration.Tracing.SpanNameLimit, trackingBackend)
|
||||
if err != nil {
|
||||
log.WithoutContext().Warnf("Unable to create tracer: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
server.defaultForwardingRoundTripper = transport
|
||||
server.requestDecorator = requestdecorator.New(static.ConvertHostResolverConfig(globalConfiguration.HostResolver))
|
||||
|
||||
server.tracingMiddleware = globalConfiguration.Tracing
|
||||
if server.tracingMiddleware != nil && server.tracingMiddleware.Backend != "" {
|
||||
server.tracingMiddleware.Setup()
|
||||
}
|
||||
|
||||
server.metricsRegistry = registerMetricClients(globalConfiguration.Metrics)
|
||||
server.metricsRegistry = registerMetricClients(static.ConvertMetrics(globalConfiguration.Metrics))
|
||||
|
||||
if globalConfiguration.Cluster != nil {
|
||||
// leadership creation if cluster mode
|
||||
|
@ -230,9 +264,9 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p
|
|||
|
||||
if globalConfiguration.AccessLog != nil {
|
||||
var err error
|
||||
server.accessLoggerMiddleware, err = accesslog.NewLogHandler(globalConfiguration.AccessLog)
|
||||
server.accessLoggerMiddleware, err = accesslog.NewHandler(static.ConvertAccessLog(globalConfiguration.AccessLog))
|
||||
if err != nil {
|
||||
log.Warnf("Unable to create log handler: %s", err)
|
||||
log.WithoutContext().Warnf("Unable to create access logger : %v", err)
|
||||
}
|
||||
}
|
||||
return server
|
||||
|
@ -259,13 +293,16 @@ func (s *Server) StartWithContext(ctx context.Context) {
|
|||
go func() {
|
||||
defer s.Close()
|
||||
<-ctx.Done()
|
||||
log.Info("I have to go...")
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Info("I have to go...")
|
||||
|
||||
reqAcceptGraceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.RequestAcceptGraceTimeout)
|
||||
if reqAcceptGraceTimeOut > 0 {
|
||||
log.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
|
||||
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
|
||||
time.Sleep(reqAcceptGraceTimeOut)
|
||||
}
|
||||
log.Info("Stopping server gracefully")
|
||||
|
||||
logger.Info("Stopping server gracefully")
|
||||
s.Stop()
|
||||
}()
|
||||
s.Start()
|
||||
|
@ -278,18 +315,23 @@ func (s *Server) Wait() {
|
|||
|
||||
// Stop stops the server
|
||||
func (s *Server) Stop() {
|
||||
defer log.Info("Server stopped")
|
||||
defer log.WithoutContext().Info("Server stopped")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for sepn, sep := range s.serverEntryPoints {
|
||||
wg.Add(1)
|
||||
go func(serverEntryPointName string, serverEntryPoint *serverEntryPoint) {
|
||||
defer wg.Done()
|
||||
logger := log.WithoutContext().WithField(log.EntryPointName, serverEntryPointName)
|
||||
|
||||
graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
|
||||
log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
|
||||
logger.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
|
||||
|
||||
serverEntryPoint.Shutdown(ctx)
|
||||
cancel()
|
||||
log.Debugf("Entrypoint %s closed", serverEntryPointName)
|
||||
|
||||
logger.Debugf("Entry point %s closed", serverEntryPointName)
|
||||
}(sepn, sep)
|
||||
}
|
||||
wg.Wait()
|
||||
|
@ -307,6 +349,7 @@ func (s *Server) Close() {
|
|||
panic("Timeout while stopping traefik, killing instance ✝")
|
||||
}
|
||||
}(ctx)
|
||||
|
||||
stopMetricsClients()
|
||||
s.stopLeadership()
|
||||
s.routinesPool.Cleanup()
|
||||
|
@ -315,11 +358,17 @@ func (s *Server) Close() {
|
|||
signal.Stop(s.signals)
|
||||
close(s.signals)
|
||||
close(s.stopChan)
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
if err := s.accessLoggerMiddleware.Close(); err != nil {
|
||||
log.Errorf("Error closing access log file: %s", err)
|
||||
log.WithoutContext().Errorf("Could not close the access log file: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.tracer != nil {
|
||||
s.tracer.Close()
|
||||
}
|
||||
|
||||
cancel()
|
||||
}
|
||||
|
||||
|
@ -339,8 +388,9 @@ func (s *Server) startHTTPServers() {
|
|||
s.serverEntryPoints = s.buildServerEntryPoints()
|
||||
|
||||
for newServerEntryPointName, newServerEntryPoint := range s.serverEntryPoints {
|
||||
serverEntryPoint := s.setupServerEntryPoint(newServerEntryPointName, newServerEntryPoint)
|
||||
go s.startServer(serverEntryPoint)
|
||||
ctx := log.With(context.Background(), log.Str(log.EntryPointName, newServerEntryPointName))
|
||||
serverEntryPoint := s.setupServerEntryPoint(ctx, newServerEntryPointName, newServerEntryPoint)
|
||||
go s.startServer(ctx, serverEntryPoint)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -359,9 +409,9 @@ func (s *Server) listenProviders(stop chan bool) {
|
|||
}
|
||||
|
||||
// AddListener adds a new listener function used when new configuration is provided
|
||||
func (s *Server) AddListener(listener func(types.Configuration)) {
|
||||
func (s *Server) AddListener(listener func(config.Configuration)) {
|
||||
if s.configurationListeners == nil {
|
||||
s.configurationListeners = make([]func(types.Configuration), 0)
|
||||
s.configurationListeners = make([]func(config.Configuration), 0)
|
||||
}
|
||||
s.configurationListeners = append(s.configurationListeners, listener)
|
||||
}
|
||||
|
@ -395,22 +445,23 @@ func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tl
|
|||
return nil, fmt.Errorf("strict SNI enabled - No certificate found for domain: %q, closing connection", domainToCheck)
|
||||
}
|
||||
|
||||
log.Debugf("Serving default cert for request: %q", domainToCheck)
|
||||
log.WithoutContext().Debugf("Serving default certificate for request: %q", domainToCheck)
|
||||
return s.certs.DefaultCertificate, nil
|
||||
}
|
||||
|
||||
func (s *Server) startProvider() {
|
||||
// start providers
|
||||
jsonConf, err := json.Marshal(s.provider)
|
||||
if err != nil {
|
||||
log.Debugf("Unable to marshal provider conf %T with error: %v", s.provider, err)
|
||||
log.WithoutContext().Debugf("Unable to marshal provider configuration %T: %v", s.provider, err)
|
||||
}
|
||||
log.Infof("Starting provider %T %s", s.provider, jsonConf)
|
||||
|
||||
log.WithoutContext().Infof("Starting provider %T %s", s.provider, jsonConf)
|
||||
currentProvider := s.provider
|
||||
|
||||
safe.Go(func() {
|
||||
err := currentProvider.Provide(s.configurationChan, s.routinesPool)
|
||||
if err != nil {
|
||||
log.Errorf("Error starting provider %T: %s", s.provider, err)
|
||||
log.WithoutContext().Errorf("Error starting provider %T: %s", s.provider, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -421,7 +472,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
config, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
|
||||
conf, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -429,7 +480,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
|
|||
s.serverEntryPoints[entryPointName].certs.DynamicCerts.Set(make(map[string]*tls.Certificate))
|
||||
|
||||
// ensure http2 enabled
|
||||
config.NextProtos = []string{"h2", "http/1.1", acme.ACMETLS1Protocol}
|
||||
conf.NextProtos = []string{"h2", "http/1.1", acme.ACMETLS1Protocol}
|
||||
|
||||
if len(tlsOption.ClientCA.Files) > 0 {
|
||||
pool := x509.NewCertPool()
|
||||
|
@ -443,55 +494,59 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
|
|||
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
|
||||
}
|
||||
}
|
||||
config.ClientCAs = pool
|
||||
conf.ClientCAs = pool
|
||||
if tlsOption.ClientCA.Optional {
|
||||
config.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
conf.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
} else {
|
||||
config.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
conf.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
}
|
||||
|
||||
if s.globalConfiguration.ACME != nil && entryPointName == s.globalConfiguration.ACME.EntryPoint {
|
||||
checkOnDemandDomain := func(domain string) bool {
|
||||
routeMatch := &mux.RouteMatch{}
|
||||
match := router.GetHandler().Match(&http.Request{URL: &url.URL{}, Host: domain}, routeMatch)
|
||||
if match && routeMatch.Route != nil {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
err := s.globalConfiguration.ACME.CreateClusterConfig(s.leadership, config, s.serverEntryPoints[entryPointName].certs.DynamicCerts, checkOnDemandDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// FIXME onDemand
|
||||
if s.globalConfiguration.ACME != nil {
|
||||
// if entryPointName == s.globalConfiguration.ACME.EntryPoint {
|
||||
// checkOnDemandDomain := func(domain string) bool {
|
||||
// routeMatch := &mux.RouteMatch{}
|
||||
// match := router.GetHandler().Match(&http.Request{URL: &url.URL{}, Host: domain}, routeMatch)
|
||||
// if match && routeMatch.Route != nil {
|
||||
// return true
|
||||
// }
|
||||
// return false
|
||||
// }
|
||||
//
|
||||
// err := s.globalConfiguration.ACME.CreateClusterConfig(s.leadership, config, s.serverEntryPoints[entryPointName].certs.DynamicCerts, checkOnDemandDomain)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// }
|
||||
} else {
|
||||
config.GetCertificate = s.serverEntryPoints[entryPointName].getCertificate
|
||||
if len(config.Certificates) != 0 {
|
||||
certMap := s.buildNameOrIPToCertificate(config.Certificates)
|
||||
|
||||
if s.entryPoints[entryPointName].CertificateStore != nil {
|
||||
s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(certMap)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove certs from the TLS config object
|
||||
config.Certificates = []tls.Certificate{}
|
||||
conf.GetCertificate = s.serverEntryPoints[entryPointName].getCertificate
|
||||
}
|
||||
|
||||
if len(conf.Certificates) != 0 {
|
||||
certMap := s.buildNameOrIPToCertificate(conf.Certificates)
|
||||
|
||||
if s.entryPoints[entryPointName].CertificateStore != nil {
|
||||
s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(certMap)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove certs from the TLS config object
|
||||
conf.Certificates = []tls.Certificate{}
|
||||
|
||||
// Set the minimum TLS version if set in the config TOML
|
||||
if minConst, exists := traefiktls.MinVersion[s.entryPoints[entryPointName].Configuration.TLS.MinVersion]; exists {
|
||||
config.PreferServerCipherSuites = true
|
||||
config.MinVersion = minConst
|
||||
if minConst, exists := traefiktls.MinVersion[tlsOption.MinVersion]; exists {
|
||||
conf.PreferServerCipherSuites = true
|
||||
conf.MinVersion = minConst
|
||||
}
|
||||
|
||||
// Set the list of CipherSuites if set in the config TOML
|
||||
if s.entryPoints[entryPointName].Configuration.TLS.CipherSuites != nil {
|
||||
// if our list of CipherSuites is defined in the entrypoint config, we can re-initilize the suites list as empty
|
||||
config.CipherSuites = make([]uint16, 0)
|
||||
for _, cipher := range s.entryPoints[entryPointName].Configuration.TLS.CipherSuites {
|
||||
if tlsOption.CipherSuites != nil {
|
||||
// if our list of CipherSuites is defined in the entryPoint config, we can re-initialize the suites list as empty
|
||||
conf.CipherSuites = make([]uint16, 0)
|
||||
for _, cipher := range tlsOption.CipherSuites {
|
||||
if cipherConst, exists := traefiktls.CipherSuites[cipher]; exists {
|
||||
config.CipherSuites = append(config.CipherSuites, cipherConst)
|
||||
conf.CipherSuites = append(conf.CipherSuites, cipherConst)
|
||||
} else {
|
||||
// CipherSuite listed in the toml does not exist in our listed
|
||||
return nil, fmt.Errorf("invalid CipherSuite: %s", cipher)
|
||||
|
@ -499,11 +554,12 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
|
|||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
func (s *Server) startServer(serverEntryPoint *serverEntryPoint) {
|
||||
log.Infof("Starting server on %s", serverEntryPoint.httpServer.Addr)
|
||||
func (s *Server) startServer(ctx context.Context, serverEntryPoint *serverEntryPoint) {
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Infof("Starting server on %s", serverEntryPoint.httpServer.Addr)
|
||||
|
||||
var err error
|
||||
if serverEntryPoint.httpServer.TLSConfig != nil {
|
||||
|
@ -513,19 +569,14 @@ func (s *Server) startServer(serverEntryPoint *serverEntryPoint) {
|
|||
}
|
||||
|
||||
if err != http.ErrServerClosed {
|
||||
log.Error("Error creating server: ", err)
|
||||
logger.Error("Cannot create server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServerEntryPoint *serverEntryPoint) *serverEntryPoint {
|
||||
serverMiddlewares, err := s.buildServerEntryPointMiddlewares(newServerEntryPointName)
|
||||
func (s *Server) setupServerEntryPoint(ctx context.Context, newServerEntryPointName string, newServerEntryPoint *serverEntryPoint) *serverEntryPoint {
|
||||
newSrv, listener, err := s.prepareServer(ctx, newServerEntryPointName, s.entryPoints[newServerEntryPointName].Configuration, newServerEntryPoint.httpRouter)
|
||||
if err != nil {
|
||||
log.Fatal("Error preparing server: ", err)
|
||||
}
|
||||
|
||||
newSrv, listener, err := s.prepareServer(newServerEntryPointName, s.entryPoints[newServerEntryPointName].Configuration, newServerEntryPoint.httpRouter, serverMiddlewares)
|
||||
if err != nil {
|
||||
log.Fatal("Error preparing server: ", err)
|
||||
log.FromContext(ctx).Fatalf("Error preparing server: %v", err)
|
||||
}
|
||||
|
||||
serverEntryPoint := s.serverEntryPoints[newServerEntryPointName]
|
||||
|
@ -545,19 +596,15 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer
|
|||
return serverEntryPoint
|
||||
}
|
||||
|
||||
func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher, middlewares []negroni.Handler) (*h2c.Server, net.Listener, error) {
|
||||
func (s *Server) prepareServer(ctx context.Context, entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher) (*h2c.Server, net.Listener, error) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(s.globalConfiguration)
|
||||
log.Infof("Preparing server %s %+v with readTimeout=%s writeTimeout=%s idleTimeout=%s", entryPointName, entryPoint, readTimeout, writeTimeout, idleTimeout)
|
||||
|
||||
// middlewares
|
||||
n := negroni.New()
|
||||
for _, middleware := range middlewares {
|
||||
n.Use(middleware)
|
||||
}
|
||||
n.UseHandler(router)
|
||||
|
||||
internalMuxRouter := s.buildInternalRouter(entryPointName)
|
||||
internalMuxRouter.NotFoundHandler = n
|
||||
logger.
|
||||
WithField("readTimeout", readTimeout).
|
||||
WithField("writeTimeout", writeTimeout).
|
||||
WithField("idleTimeout", idleTimeout).
|
||||
Infof("Preparing server %+v", entryPoint)
|
||||
|
||||
tlsConfig, err := s.createTLSConfig(entryPointName, entryPoint.TLS, router)
|
||||
if err != nil {
|
||||
|
@ -572,16 +619,18 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.
|
|||
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
|
||||
|
||||
if entryPoint.ProxyProtocol != nil {
|
||||
listener, err = buildProxyProtocolListener(entryPoint, listener)
|
||||
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
httpServerLogger := stdlog.New(logger.WriterLevel(logrus.DebugLevel), "", 0)
|
||||
|
||||
return &h2c.Server{
|
||||
Server: &http.Server{
|
||||
Addr: entryPoint.Address,
|
||||
Handler: internalMuxRouter,
|
||||
Handler: router,
|
||||
TLSConfig: tlsConfig,
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
|
@ -593,7 +642,7 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration.
|
|||
nil
|
||||
}
|
||||
|
||||
func buildProxyProtocolListener(entryPoint *configuration.EntryPoint, listener net.Listener) (net.Listener, error) {
|
||||
func buildProxyProtocolListener(ctx context.Context, entryPoint *configuration.EntryPoint, listener net.Listener) (net.Listener, error) {
|
||||
var sourceCheck func(addr net.Addr) (bool, error)
|
||||
if entryPoint.ProxyProtocol.Insecure {
|
||||
sourceCheck = func(_ net.Addr) (bool, error) {
|
||||
|
@ -615,7 +664,7 @@ func buildProxyProtocolListener(entryPoint *configuration.EntryPoint, listener n
|
|||
}
|
||||
}
|
||||
|
||||
log.Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
|
||||
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
|
||||
|
||||
return &proxyproto.Listener{
|
||||
Listener: listener,
|
||||
|
@ -623,23 +672,6 @@ func buildProxyProtocolListener(entryPoint *configuration.EntryPoint, listener n
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildInternalRouter(entryPointName string) *mux.Router {
|
||||
internalMuxRouter := mux.NewRouter()
|
||||
internalMuxRouter.StrictSlash(!s.globalConfiguration.KeepTrailingSlash)
|
||||
internalMuxRouter.SkipClean(true)
|
||||
|
||||
if entryPoint, ok := s.entryPoints[entryPointName]; ok && entryPoint.InternalRouter != nil {
|
||||
entryPoint.InternalRouter.AddRoutes(internalMuxRouter)
|
||||
|
||||
if s.globalConfiguration.API != nil && s.globalConfiguration.API.EntryPoint == entryPointName && s.leadership != nil {
|
||||
s.leadership.AddRoutes(internalMuxRouter)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
return internalMuxRouter
|
||||
}
|
||||
|
||||
func buildServerTimeouts(globalConfig configuration.GlobalConfiguration) (readTimeout, writeTimeout, idleTimeout time.Duration) {
|
||||
readTimeout = time.Duration(0)
|
||||
writeTimeout = time.Duration(0)
|
||||
|
@ -663,24 +695,35 @@ func registerMetricClients(metricsConfig *types.Metrics) metrics.Registry {
|
|||
}
|
||||
|
||||
var registries []metrics.Registry
|
||||
|
||||
if metricsConfig.Prometheus != nil {
|
||||
prometheusRegister := metrics.RegisterPrometheus(metricsConfig.Prometheus)
|
||||
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "prometheus"))
|
||||
prometheusRegister := metrics.RegisterPrometheus(ctx, metricsConfig.Prometheus)
|
||||
if prometheusRegister != nil {
|
||||
registries = append(registries, prometheusRegister)
|
||||
log.Debug("Configured Prometheus metrics")
|
||||
log.FromContext(ctx).Debug("Configured Prometheus metrics")
|
||||
}
|
||||
}
|
||||
|
||||
if metricsConfig.Datadog != nil {
|
||||
registries = append(registries, metrics.RegisterDatadog(metricsConfig.Datadog))
|
||||
log.Debugf("Configured DataDog metrics pushing to %s once every %s", metricsConfig.Datadog.Address, metricsConfig.Datadog.PushInterval)
|
||||
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "datadog"))
|
||||
registries = append(registries, metrics.RegisterDatadog(ctx, metricsConfig.Datadog))
|
||||
log.FromContext(ctx).Debugf("Configured DataDog metrics: pushing to %s once every %s",
|
||||
metricsConfig.Datadog.Address, metricsConfig.Datadog.PushInterval)
|
||||
}
|
||||
|
||||
if metricsConfig.StatsD != nil {
|
||||
registries = append(registries, metrics.RegisterStatsd(metricsConfig.StatsD))
|
||||
log.Debugf("Configured StatsD metrics pushing to %s once every %s", metricsConfig.StatsD.Address, metricsConfig.StatsD.PushInterval)
|
||||
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "statsd"))
|
||||
registries = append(registries, metrics.RegisterStatsd(ctx, metricsConfig.StatsD))
|
||||
log.FromContext(ctx).Debugf("Configured StatsD metrics: pushing to %s once every %s",
|
||||
metricsConfig.StatsD.Address, metricsConfig.StatsD.PushInterval)
|
||||
}
|
||||
|
||||
if metricsConfig.InfluxDB != nil {
|
||||
registries = append(registries, metrics.RegisterInfluxDB(metricsConfig.InfluxDB))
|
||||
log.Debugf("Configured InfluxDB metrics pushing to %s once every %s", metricsConfig.InfluxDB.Address, metricsConfig.InfluxDB.PushInterval)
|
||||
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "influxdb"))
|
||||
registries = append(registries, metrics.RegisterInfluxDB(ctx, metricsConfig.InfluxDB))
|
||||
log.FromContext(ctx).Debugf("Configured InfluxDB metrics: pushing to %s once every %s",
|
||||
metricsConfig.InfluxDB.Address, metricsConfig.InfluxDB.PushInterval)
|
||||
}
|
||||
|
||||
return metrics.NewMultiRegistry(registries)
|
||||
|
|
|
@ -1,41 +1,42 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/hostresolver"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/metrics"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/pipelining"
|
||||
"github.com/containous/traefik/rules"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
"github.com/containous/traefik/middlewares/requestdecorator"
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
"github.com/containous/traefik/responsemodifiers"
|
||||
"github.com/containous/traefik/server/middleware"
|
||||
"github.com/containous/traefik/server/router"
|
||||
"github.com/containous/traefik/server/service"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/tls/generate"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/eapache/channels"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
)
|
||||
|
||||
// loadConfiguration manages dynamically frontends, backends and TLS configurations
|
||||
func (s *Server) loadConfiguration(configMsg types.ConfigMessage) {
|
||||
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
|
||||
func (s *Server) loadConfiguration(configMsg config.Message) {
|
||||
logger := log.FromContext(log.With(context.Background(), log.Str(log.ProviderName, configMsg.ProviderName)))
|
||||
|
||||
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
|
||||
|
||||
// Copy configurations to new map so we don't change current if LoadConfig fails
|
||||
newConfigurations := make(types.Configurations)
|
||||
newConfigurations := make(config.Configurations)
|
||||
for k, v := range currentConfigurations {
|
||||
newConfigurations[k] = v
|
||||
}
|
||||
|
@ -43,22 +44,25 @@ func (s *Server) loadConfiguration(configMsg types.ConfigMessage) {
|
|||
|
||||
s.metricsRegistry.ConfigReloadsCounter().Add(1)
|
||||
|
||||
newServerEntryPoints := s.loadConfig(newConfigurations, s.globalConfiguration)
|
||||
handlers, certificates := s.loadConfig(newConfigurations, s.globalConfiguration)
|
||||
|
||||
s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
|
||||
|
||||
for newServerEntryPointName, newServerEntryPoint := range newServerEntryPoints {
|
||||
s.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler())
|
||||
for entryPointName, handler := range handlers {
|
||||
s.serverEntryPoints[entryPointName].httpRouter.UpdateHandler(handler)
|
||||
}
|
||||
|
||||
if s.entryPoints[newServerEntryPointName].Configuration.TLS == nil {
|
||||
if newServerEntryPoint.certs.ContainsCertificates() {
|
||||
log.Debugf("Certificates not added to non-TLS entryPoint %s.", newServerEntryPointName)
|
||||
for entryPointName, serverEntryPoint := range s.serverEntryPoints {
|
||||
eLogger := logger.WithField(log.EntryPointName, entryPointName)
|
||||
if s.entryPoints[entryPointName].Configuration.TLS == nil {
|
||||
if len(certificates[entryPointName]) > 0 {
|
||||
eLogger.Debugf("Cannot configure certificates for the non-TLS %s entryPoint.", entryPointName)
|
||||
}
|
||||
} else {
|
||||
s.serverEntryPoints[newServerEntryPointName].certs.DynamicCerts.Set(newServerEntryPoint.certs.DynamicCerts.Get())
|
||||
s.serverEntryPoints[newServerEntryPointName].certs.ResetCache()
|
||||
serverEntryPoint.certs.DynamicCerts.Set(certificates[entryPointName])
|
||||
serverEntryPoint.certs.ResetCache()
|
||||
}
|
||||
log.Infof("Server configuration reloaded on %s", s.serverEntryPoints[newServerEntryPointName].httpServer.Addr)
|
||||
eLogger.Infof("Server configuration reloaded on %s", s.serverEntryPoints[entryPointName].httpServer.Addr)
|
||||
}
|
||||
|
||||
s.currentConfigurations.Set(newConfigurations)
|
||||
|
@ -72,251 +76,127 @@ func (s *Server) loadConfiguration(configMsg types.ConfigMessage) {
|
|||
|
||||
// loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic
|
||||
// provider configurations.
|
||||
func (s *Server) loadConfig(configurations types.Configurations, globalConfiguration configuration.GlobalConfiguration) map[string]*serverEntryPoint {
|
||||
func (s *Server) loadConfig(configurations config.Configurations, globalConfiguration configuration.GlobalConfiguration) (map[string]http.Handler, map[string]map[string]*tls.Certificate) {
|
||||
|
||||
serverEntryPoints := s.buildServerEntryPoints()
|
||||
ctx := context.TODO()
|
||||
|
||||
backendsHandlers := map[string]http.Handler{}
|
||||
backendsHealthCheck := map[string]*healthcheck.BackendConfig{}
|
||||
|
||||
var postConfigs []handlerPostConfig
|
||||
|
||||
for providerName, config := range configurations {
|
||||
frontendNames := sortedFrontendNamesForConfig(config)
|
||||
|
||||
for _, frontendName := range frontendNames {
|
||||
frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config,
|
||||
serverEntryPoints,
|
||||
backendsHandlers, backendsHealthCheck)
|
||||
if err != nil {
|
||||
log.Errorf("%v. Skipping frontend %s...", err, frontendName)
|
||||
}
|
||||
|
||||
if len(frontendPostConfigs) > 0 {
|
||||
postConfigs = append(postConfigs, frontendPostConfigs...)
|
||||
}
|
||||
// FIXME manage duplicates
|
||||
conf := config.Configuration{
|
||||
Routers: make(map[string]*config.Router),
|
||||
Middlewares: make(map[string]*config.Middleware),
|
||||
Services: make(map[string]*config.Service),
|
||||
}
|
||||
for _, config := range configurations {
|
||||
for key, value := range config.Middlewares {
|
||||
conf.Middlewares[key] = value
|
||||
}
|
||||
|
||||
for key, value := range config.Services {
|
||||
conf.Services[key] = value
|
||||
}
|
||||
|
||||
for key, value := range config.Routers {
|
||||
conf.Routers[key] = value
|
||||
}
|
||||
|
||||
conf.TLS = append(conf.TLS, config.TLS...)
|
||||
}
|
||||
|
||||
for _, postConfig := range postConfigs {
|
||||
err := postConfig(backendsHandlers)
|
||||
if err != nil {
|
||||
log.Errorf("middleware post configuration error: %v", err)
|
||||
}
|
||||
}
|
||||
handlers := s.applyConfiguration(ctx, conf)
|
||||
|
||||
healthcheck.GetHealthCheck(s.metricsRegistry).SetBackendsConfiguration(s.routinesPool.Ctx(), backendsHealthCheck)
|
||||
|
||||
// Get new certificates list sorted per entrypoints
|
||||
// Get new certificates list sorted per entry points
|
||||
// Update certificates
|
||||
entryPointsCertificates := s.loadHTTPSConfiguration(configurations, globalConfiguration.DefaultEntryPoints)
|
||||
|
||||
// Sort routes and update certificates
|
||||
for serverEntryPointName, serverEntryPoint := range serverEntryPoints {
|
||||
serverEntryPoint.httpRouter.GetHandler().SortRoutes()
|
||||
if _, exists := entryPointsCertificates[serverEntryPointName]; exists {
|
||||
serverEntryPoint.certs.DynamicCerts.Set(entryPointsCertificates[serverEntryPointName])
|
||||
}
|
||||
}
|
||||
|
||||
return serverEntryPoints
|
||||
return handlers, entryPointsCertificates
|
||||
}
|
||||
|
||||
func (s *Server) loadFrontendConfig(
|
||||
providerName string, frontendName string, config *types.Configuration,
|
||||
serverEntryPoints map[string]*serverEntryPoint,
|
||||
backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig,
|
||||
) ([]handlerPostConfig, error) {
|
||||
func (s *Server) applyConfiguration(ctx context.Context, configuration config.Configuration) map[string]http.Handler {
|
||||
staticConfiguration := static.ConvertStaticConf(s.globalConfiguration)
|
||||
|
||||
frontend := config.Frontends[frontendName]
|
||||
hostResolver := buildHostResolver(s.globalConfiguration)
|
||||
|
||||
if len(frontend.EntryPoints) == 0 {
|
||||
return nil, fmt.Errorf("no entrypoint defined for frontend %s", frontendName)
|
||||
var entryPoints []string
|
||||
for entryPointName := range s.entryPoints {
|
||||
entryPoints = append(entryPoints, entryPointName)
|
||||
}
|
||||
|
||||
backend := config.Backends[frontend.Backend]
|
||||
if backend == nil {
|
||||
return nil, fmt.Errorf("undefined backend '%s' for frontend %s", frontend.Backend, frontendName)
|
||||
}
|
||||
serviceManager := service.NewManager(configuration.Services, s.defaultRoundTripper)
|
||||
middlewaresBuilder := middleware.NewBuilder(configuration.Middlewares, serviceManager)
|
||||
responseModifierFactory := responsemodifiers.NewBuilder(configuration.Middlewares)
|
||||
|
||||
frontendHash, err := frontend.Hash()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calculating hash value for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
routerManager := router.NewManager(configuration.Routers, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
var postConfigs []handlerPostConfig
|
||||
handlers := routerManager.BuildHandlers(ctx, entryPoints, staticConfiguration.EntryPoints.Defaults)
|
||||
|
||||
for _, entryPointName := range frontend.EntryPoints {
|
||||
log.Debugf("Wiring frontend %s to entryPoint %s", frontendName, entryPointName)
|
||||
routerHandlers := make(map[string]http.Handler)
|
||||
|
||||
entryPoint := s.entryPoints[entryPointName].Configuration
|
||||
for _, entryPointName := range entryPoints {
|
||||
internalMuxRouter := mux.NewRouter().
|
||||
SkipClean(true)
|
||||
|
||||
if backendsHandlers[entryPointName+providerName+frontendHash] == nil {
|
||||
log.Debugf("Creating backend %s", frontend.Backend)
|
||||
ctx = log.With(ctx, log.Str(log.EntryPointName, entryPointName))
|
||||
|
||||
handlers, responseModifier, postConfig, err := s.buildMiddlewares(frontendName, frontend, config.Backends, entryPointName, providerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
factory := s.entryPoints[entryPointName].RouteAppenderFactory
|
||||
if factory != nil {
|
||||
// FIXME remove currentConfigurations
|
||||
appender := factory.NewAppender(ctx, middlewaresBuilder, &s.currentConfigurations)
|
||||
appender.Append(internalMuxRouter)
|
||||
}
|
||||
|
||||
if postConfig != nil {
|
||||
postConfigs = append(postConfigs, postConfig)
|
||||
}
|
||||
|
||||
fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, responseModifier, backend)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
lb, healthCheckConfig, err := s.buildBalancerMiddlewares(frontendName, frontend, backend, fwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handler used by error pages
|
||||
if backendsHandlers[entryPointName+providerName+frontend.Backend] == nil {
|
||||
backendsHandlers[entryPointName+providerName+frontend.Backend] = lb
|
||||
}
|
||||
|
||||
if healthCheckConfig != nil {
|
||||
backendsHealthCheck[entryPointName+providerName+frontendHash] = healthCheckConfig
|
||||
}
|
||||
|
||||
n := negroni.New()
|
||||
|
||||
for _, handler := range handlers {
|
||||
n.Use(handler)
|
||||
}
|
||||
|
||||
n.UseHandler(lb)
|
||||
|
||||
backendsHandlers[entryPointName+providerName+frontendHash] = n
|
||||
if h, ok := handlers[entryPointName]; ok {
|
||||
internalMuxRouter.NotFoundHandler = h
|
||||
} else {
|
||||
log.Debugf("Reusing backend %s [%s - %s - %s - %s]",
|
||||
frontend.Backend, entryPointName, providerName, frontendName, frontendHash)
|
||||
internalMuxRouter.NotFoundHandler = s.buildDefaultHTTPRouter()
|
||||
}
|
||||
|
||||
serverRoute, err := buildServerRoute(serverEntryPoints[entryPointName], frontendName, frontend, hostResolver)
|
||||
routerHandlers[entryPointName] = internalMuxRouter
|
||||
|
||||
chain := alice.New()
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
chain = chain.Append(accesslog.WrapHandler(s.accessLoggerMiddleware))
|
||||
}
|
||||
|
||||
if s.tracer != nil {
|
||||
chain = chain.Append(tracing.WrapEntryPointHandler(ctx, s.tracer, entryPointName))
|
||||
}
|
||||
|
||||
chain = chain.Append(requestdecorator.WrapHandler(s.requestDecorator))
|
||||
|
||||
handler, err := chain.Then(internalMuxRouter.NotFoundHandler)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handler := buildMatcherMiddlewares(serverRoute, backendsHandlers[entryPointName+providerName+frontendHash])
|
||||
serverRoute.Route.Handler(handler)
|
||||
|
||||
err = serverRoute.Route.GetError()
|
||||
if err != nil {
|
||||
// FIXME error management
|
||||
log.Errorf("Error building route: %s", err)
|
||||
log.FromContext(ctx).Error(err)
|
||||
continue
|
||||
}
|
||||
internalMuxRouter.NotFoundHandler = handler
|
||||
}
|
||||
|
||||
return postConfigs, nil
|
||||
return routerHandlers
|
||||
}
|
||||
|
||||
func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint,
|
||||
frontendName string, frontend *types.Frontend,
|
||||
responseModifier modifyResponse, backend *types.Backend) (http.Handler, error) {
|
||||
|
||||
roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create RoundTripper for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
var flushInterval parse.Duration
|
||||
if backend.ResponseForwarding != nil {
|
||||
err := flushInterval.Set(backend.ResponseForwarding.FlushInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating flush interval for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
}
|
||||
|
||||
var fwd http.Handler
|
||||
fwd, err = forward.New(
|
||||
forward.Stream(true),
|
||||
forward.PassHostHeader(frontend.PassHostHeader),
|
||||
forward.RoundTripper(roundTripper),
|
||||
forward.ResponseModifier(responseModifier),
|
||||
forward.BufferPool(s.bufferPool),
|
||||
forward.StreamingFlushInterval(time.Duration(flushInterval)),
|
||||
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
|
||||
server := req.Context().Value(http.ServerContextKey).(*http.Server)
|
||||
if server != nil {
|
||||
connState := server.ConnState
|
||||
if connState != nil {
|
||||
connState(conn, http.StateClosed)
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating forwarder for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
if s.tracingMiddleware.IsEnabled() {
|
||||
tm := s.tracingMiddleware.NewForwarderMiddleware(frontendName, frontend.Backend)
|
||||
|
||||
next := fwd
|
||||
fwd = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tm.ServeHTTP(w, r, next.ServeHTTP)
|
||||
})
|
||||
}
|
||||
|
||||
fwd = pipelining.NewPipelining(fwd)
|
||||
|
||||
return fwd, nil
|
||||
}
|
||||
|
||||
func buildServerRoute(serverEntryPoint *serverEntryPoint, frontendName string, frontend *types.Frontend, hostResolver *hostresolver.Resolver) (*types.ServerRoute, error) {
|
||||
serverRoute := &types.ServerRoute{Route: serverEntryPoint.httpRouter.GetHandler().NewRoute().Name(frontendName)}
|
||||
|
||||
priority := 0
|
||||
for routeName, route := range frontend.Routes {
|
||||
rls := rules.Rules{Route: serverRoute, HostResolver: hostResolver}
|
||||
newRoute, err := rls.Parse(route.Rule)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating route for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
serverRoute.Route = newRoute
|
||||
|
||||
priority += len(route.Rule)
|
||||
log.Debugf("Creating route %s %s", routeName, route.Rule)
|
||||
}
|
||||
|
||||
if frontend.Priority > 0 {
|
||||
serverRoute.Route.Priority(frontend.Priority)
|
||||
} else {
|
||||
serverRoute.Route.Priority(priority)
|
||||
}
|
||||
|
||||
return serverRoute, nil
|
||||
}
|
||||
|
||||
func (s *Server) preLoadConfiguration(configMsg types.ConfigMessage) {
|
||||
func (s *Server) preLoadConfiguration(configMsg config.Message) {
|
||||
providersThrottleDuration := time.Duration(s.globalConfiguration.ProvidersThrottleDuration)
|
||||
s.defaultConfigurationValues(configMsg.Configuration)
|
||||
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
|
||||
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
|
||||
|
||||
logger := log.WithoutContext().WithField(log.ProviderName, configMsg.ProviderName)
|
||||
if log.GetLevel() == logrus.DebugLevel {
|
||||
jsonConf, _ := json.Marshal(configMsg.Configuration)
|
||||
log.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
|
||||
logger.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
|
||||
}
|
||||
|
||||
if configMsg.Configuration == nil || configMsg.Configuration.Backends == nil && configMsg.Configuration.Frontends == nil && configMsg.Configuration.TLS == nil {
|
||||
log.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
|
||||
if configMsg.Configuration == nil || configMsg.Configuration.Routers == nil && configMsg.Configuration.Services == nil && configMsg.Configuration.Middlewares == nil && configMsg.Configuration.TLS == nil {
|
||||
logger.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
|
||||
return
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) {
|
||||
log.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
|
||||
logger.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
|
||||
return
|
||||
}
|
||||
|
||||
providerConfigUpdateCh, ok := s.providerConfigUpdateMap[configMsg.ProviderName]
|
||||
if !ok {
|
||||
providerConfigUpdateCh = make(chan types.ConfigMessage)
|
||||
providerConfigUpdateCh = make(chan config.Message)
|
||||
s.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh
|
||||
s.routinesPool.Go(func(stop chan bool) {
|
||||
s.throttleProviderConfigReload(providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop)
|
||||
|
@ -326,74 +206,8 @@ func (s *Server) preLoadConfiguration(configMsg types.ConfigMessage) {
|
|||
providerConfigUpdateCh <- configMsg
|
||||
}
|
||||
|
||||
func (s *Server) defaultConfigurationValues(configuration *types.Configuration) {
|
||||
if configuration == nil || configuration.Frontends == nil {
|
||||
return
|
||||
}
|
||||
s.configureFrontends(configuration.Frontends)
|
||||
configureBackends(configuration.Backends)
|
||||
}
|
||||
|
||||
func (s *Server) configureFrontends(frontends map[string]*types.Frontend) {
|
||||
defaultEntrypoints := s.globalConfiguration.DefaultEntryPoints
|
||||
|
||||
for frontendName, frontend := range frontends {
|
||||
// default endpoints if not defined in frontends
|
||||
if len(frontend.EntryPoints) == 0 {
|
||||
frontend.EntryPoints = defaultEntrypoints
|
||||
}
|
||||
|
||||
frontendEntryPoints, undefinedEntryPoints := s.filterEntryPoints(frontend.EntryPoints)
|
||||
if len(undefinedEntryPoints) > 0 {
|
||||
log.Errorf("Undefined entry point(s) '%s' for frontend %s", strings.Join(undefinedEntryPoints, ","), frontendName)
|
||||
}
|
||||
|
||||
frontend.EntryPoints = frontendEntryPoints
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) filterEntryPoints(entryPoints []string) ([]string, []string) {
|
||||
var frontendEntryPoints []string
|
||||
var undefinedEntryPoints []string
|
||||
|
||||
for _, fepName := range entryPoints {
|
||||
var exist bool
|
||||
|
||||
for epName := range s.entryPoints {
|
||||
if epName == fepName {
|
||||
exist = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if exist {
|
||||
frontendEntryPoints = append(frontendEntryPoints, fepName)
|
||||
} else {
|
||||
undefinedEntryPoints = append(undefinedEntryPoints, fepName)
|
||||
}
|
||||
}
|
||||
|
||||
return frontendEntryPoints, undefinedEntryPoints
|
||||
}
|
||||
|
||||
func configureBackends(backends map[string]*types.Backend) {
|
||||
for backendName := range backends {
|
||||
backend := backends[backendName]
|
||||
|
||||
_, err := types.NewLoadBalancerMethod(backend.LoadBalancer)
|
||||
if err != nil {
|
||||
log.Debugf("Backend %s: %v", backendName, err)
|
||||
|
||||
var stickiness *types.Stickiness
|
||||
if backend.LoadBalancer != nil {
|
||||
stickiness = backend.LoadBalancer.Stickiness
|
||||
}
|
||||
backend.LoadBalancer = &types.LoadBalancer{
|
||||
Method: "wrr",
|
||||
Stickiness: stickiness,
|
||||
}
|
||||
}
|
||||
}
|
||||
func (s *Server) defaultConfigurationValues(configuration *config.Configuration) {
|
||||
// FIXME create a config hook
|
||||
}
|
||||
|
||||
func (s *Server) listenConfigurations(stop chan bool) {
|
||||
|
@ -414,7 +228,7 @@ func (s *Server) listenConfigurations(stop chan bool) {
|
|||
// It will immediately publish a new configuration and then only publish the next configuration after the throttle duration.
|
||||
// Note that in the case it receives N new configs in the timeframe of the throttle duration after publishing,
|
||||
// it will publish the last of the newly received configurations.
|
||||
func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- types.ConfigMessage, in <-chan types.ConfigMessage, stop chan bool) {
|
||||
func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish chan<- config.Message, in <-chan config.Message, stop chan bool) {
|
||||
ring := channels.NewRingChannel(1)
|
||||
defer ring.Close()
|
||||
|
||||
|
@ -424,7 +238,7 @@ func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish ch
|
|||
case <-stop:
|
||||
return
|
||||
case nextConfig := <-ring.Out():
|
||||
if config, ok := nextConfig.(types.ConfigMessage); ok {
|
||||
if config, ok := nextConfig.(config.Message); ok {
|
||||
publish <- config
|
||||
time.Sleep(throttle)
|
||||
}
|
||||
|
@ -442,95 +256,53 @@ func (s *Server) throttleProviderConfigReload(throttle time.Duration, publish ch
|
|||
}
|
||||
}
|
||||
|
||||
func buildMatcherMiddlewares(serverRoute *types.ServerRoute, handler http.Handler) http.Handler {
|
||||
// path replace - This needs to always be the very last on the handler chain (first in the order in this function)
|
||||
// -- Replacing Path should happen at the very end of the Modifier chain, after all the Matcher+Modifiers ran
|
||||
if len(serverRoute.ReplacePath) > 0 {
|
||||
handler = &middlewares.ReplacePath{
|
||||
Path: serverRoute.ReplacePath,
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
if len(serverRoute.ReplacePathRegex) > 0 {
|
||||
sp := strings.Split(serverRoute.ReplacePathRegex, " ")
|
||||
if len(sp) == 2 {
|
||||
handler = middlewares.NewReplacePathRegexHandler(sp[0], sp[1], handler)
|
||||
} else {
|
||||
log.Warnf("Invalid syntax for ReplacePathRegex: %s. Separate the regular expression and the replacement by a space.", serverRoute.ReplacePathRegex)
|
||||
}
|
||||
}
|
||||
|
||||
// add prefix - This needs to always be right before ReplacePath on the chain (second in order in this function)
|
||||
// -- Adding Path Prefix should happen after all *Strip Matcher+Modifiers ran, but before Replace (in case it's configured)
|
||||
if len(serverRoute.AddPrefix) > 0 {
|
||||
handler = &middlewares.AddPrefix{
|
||||
Prefix: serverRoute.AddPrefix,
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// strip prefix
|
||||
if len(serverRoute.StripPrefixes) > 0 {
|
||||
handler = &middlewares.StripPrefix{
|
||||
Prefixes: serverRoute.StripPrefixes,
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
// strip prefix with regex
|
||||
if len(serverRoute.StripPrefixesRegex) > 0 {
|
||||
handler = middlewares.NewStripPrefixRegex(handler, serverRoute.StripPrefixesRegex)
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
func (s *Server) postLoadConfiguration() {
|
||||
if s.metricsRegistry.IsEnabled() {
|
||||
activeConfig := s.currentConfigurations.Get().(types.Configurations)
|
||||
metrics.OnConfigurationUpdate(activeConfig)
|
||||
}
|
||||
// FIXME metrics
|
||||
// if s.metricsRegistry.IsEnabled() {
|
||||
// activeConfig := s.currentConfigurations.Get().(config.Configurations)
|
||||
// metrics.OnConfigurationUpdate(activeConfig)
|
||||
// }
|
||||
|
||||
if s.globalConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() {
|
||||
return
|
||||
}
|
||||
|
||||
if s.globalConfiguration.ACME.OnHostRule {
|
||||
currentConfigurations := s.currentConfigurations.Get().(types.Configurations)
|
||||
for _, config := range currentConfigurations {
|
||||
for _, frontend := range config.Frontends {
|
||||
|
||||
// check if one of the frontend entrypoints is configured with TLS
|
||||
// and is configured with ACME
|
||||
acmeEnabled := false
|
||||
for _, entryPoint := range frontend.EntryPoints {
|
||||
if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
|
||||
acmeEnabled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if acmeEnabled {
|
||||
for _, route := range frontend.Routes {
|
||||
rls := rules.Rules{}
|
||||
domains, err := rls.ParseDomains(route.Rule)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing domains: %v", err)
|
||||
} else if len(domains) == 0 {
|
||||
log.Debugf("No domain parsed in rule %q", route.Rule)
|
||||
} else {
|
||||
s.globalConfiguration.ACME.LoadCertificateForDomains(domains)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// FIXME acme
|
||||
// if s.globalConfiguration.ACME.OnHostRule {
|
||||
// currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
|
||||
// for _, config := range currentConfigurations {
|
||||
// for _, frontend := range config.Frontends {
|
||||
//
|
||||
// // check if one of the frontend entrypoints is configured with TLS
|
||||
// // and is configured with ACME
|
||||
// acmeEnabled := false
|
||||
// for _, entryPoint := range frontend.EntryPoints {
|
||||
// if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
|
||||
// acmeEnabled = true
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if acmeEnabled {
|
||||
// for _, route := range frontend.Routes {
|
||||
// rls := rules.Rules{}
|
||||
// domains, err := rls.ParseDomains(route.Rule)
|
||||
// if err != nil {
|
||||
// log.Errorf("Error parsing domains: %v", err)
|
||||
// } else if len(domains) == 0 {
|
||||
// log.Debugf("No domain parsed in rule %q", route.Rule)
|
||||
// } else {
|
||||
// s.globalConfiguration.ACME.LoadCertificateForDomains(domains)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically
|
||||
func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) map[string]map[string]*tls.Certificate {
|
||||
func (s *Server) loadHTTPSConfiguration(configurations config.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) map[string]map[string]*tls.Certificate {
|
||||
newEPCertificates := make(map[string]map[string]*tls.Certificate)
|
||||
// Get all certificates
|
||||
for _, config := range configurations {
|
||||
|
@ -543,9 +315,14 @@ func (s *Server) loadHTTPSConfiguration(configurations types.Configurations, def
|
|||
|
||||
func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
|
||||
serverEntryPoints := make(map[string]*serverEntryPoint)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
handlers := s.applyConfiguration(ctx, config.Configuration{})
|
||||
|
||||
for entryPointName, entryPoint := range s.entryPoints {
|
||||
serverEntryPoints[entryPointName] = &serverEntryPoint{
|
||||
httpRouter: middlewares.NewHandlerSwitcher(s.buildDefaultHTTPRouter()),
|
||||
httpRouter: middlewares.NewHandlerSwitcher(handlers[entryPointName]),
|
||||
onDemandListener: entryPoint.OnDemandListener,
|
||||
tlsALPNGetter: entryPoint.TLSALPNGetter,
|
||||
}
|
||||
|
@ -557,19 +334,21 @@ func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
|
|||
}
|
||||
|
||||
if entryPoint.Configuration.TLS != nil {
|
||||
logger := log.FromContext(ctx).WithField(log.EntryPointName, entryPointName)
|
||||
|
||||
serverEntryPoints[entryPointName].certs.SniStrict = entryPoint.Configuration.TLS.SniStrict
|
||||
|
||||
if entryPoint.Configuration.TLS.DefaultCertificate != nil {
|
||||
cert, err := buildDefaultCertificate(entryPoint.Configuration.TLS.DefaultCertificate)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
serverEntryPoints[entryPointName].certs.DefaultCertificate = cert
|
||||
} else {
|
||||
cert, err := generate.DefaultCertificate()
|
||||
if err != nil {
|
||||
log.Errorf("failed to generate default certificate: %v", err)
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
serverEntryPoints[entryPointName].certs.DefaultCertificate = cert
|
||||
|
@ -585,6 +364,13 @@ func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
|
|||
return serverEntryPoints
|
||||
}
|
||||
|
||||
func (s *Server) buildDefaultHTTPRouter() *mux.Router {
|
||||
rt := mux.NewRouter()
|
||||
rt.NotFoundHandler = http.HandlerFunc(http.NotFound)
|
||||
rt.SkipClean(true)
|
||||
return rt
|
||||
}
|
||||
|
||||
func buildDefaultCertificate(defaultCertificate *traefiktls.Certificate) (*tls.Certificate, error) {
|
||||
certFile, err := defaultCertificate.CertFile.Read()
|
||||
if err != nil {
|
||||
|
@ -602,31 +388,3 @@ func buildDefaultCertificate(defaultCertificate *traefiktls.Certificate) (*tls.C
|
|||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildDefaultHTTPRouter() *mux.Router {
|
||||
rt := mux.NewRouter()
|
||||
rt.NotFoundHandler = s.wrapHTTPHandlerWithAccessLog(http.HandlerFunc(http.NotFound), "backend not found")
|
||||
rt.StrictSlash(!s.globalConfiguration.KeepTrailingSlash)
|
||||
rt.SkipClean(true)
|
||||
return rt
|
||||
}
|
||||
|
||||
func sortedFrontendNamesForConfig(configuration *types.Configuration) []string {
|
||||
var keys []string
|
||||
for key := range configuration.Frontends {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func buildHostResolver(globalConfig configuration.GlobalConfiguration) *hostresolver.Resolver {
|
||||
if globalConfig.HostResolver != nil {
|
||||
return &hostresolver.Resolver{
|
||||
CnameFlattening: globalConfig.HostResolver.CnameFlattening,
|
||||
ResolvConfig: globalConfig.HostResolver.ResolvConfig,
|
||||
ResolvDepth: globalConfig.HostResolver.ResolvDepth,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,25 +1,16 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/rules"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
th "github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
|
||||
|
@ -60,137 +51,6 @@ f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
|
|||
-----END RSA PRIVATE KEY-----`)
|
||||
)
|
||||
|
||||
type testLoadBalancer struct{}
|
||||
|
||||
func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) Servers() []*url.URL {
|
||||
return []*url.URL{}
|
||||
}
|
||||
|
||||
func TestServerLoadConfigHealthCheckOptions(t *testing.T) {
|
||||
healthChecks := []*types.HealthCheck{
|
||||
nil,
|
||||
{
|
||||
Path: "/path",
|
||||
},
|
||||
}
|
||||
|
||||
for _, lbMethod := range []string{"Wrr", "Drr"} {
|
||||
for _, healthCheck := range healthChecks {
|
||||
t.Run(fmt.Sprintf("%s/hc=%t", lbMethod, healthCheck != nil), func(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
HealthCheck: &configuration.HealthCheckConfig{
|
||||
Interval: parse.Duration(5 * time.Second),
|
||||
Timeout: parse.Duration(3 * time.Second),
|
||||
},
|
||||
}
|
||||
entryPoints := map[string]EntryPoint{
|
||||
"http": {
|
||||
Configuration: &configuration.EntryPoint{
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
Frontends: map[string]*types.Frontend{
|
||||
"frontend": {
|
||||
EntryPoints: []string{"http"},
|
||||
Backend: "backend",
|
||||
},
|
||||
},
|
||||
Backends: map[string]*types.Backend{
|
||||
"backend": {
|
||||
Servers: map[string]types.Server{
|
||||
"server": {
|
||||
URL: "http://localhost",
|
||||
},
|
||||
},
|
||||
LoadBalancer: &types.LoadBalancer{
|
||||
Method: lbMethod,
|
||||
},
|
||||
HealthCheck: healthCheck,
|
||||
},
|
||||
},
|
||||
TLS: []*tls.Configuration{
|
||||
{
|
||||
Certificate: &tls.Certificate{
|
||||
CertFile: localhostCert,
|
||||
KeyFile: localhostKey,
|
||||
},
|
||||
EntryPoints: []string{"http"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
|
||||
_ = srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
|
||||
expectedNumHealthCheckBackends := 0
|
||||
if healthCheck != nil {
|
||||
expectedNumHealthCheckBackends = 1
|
||||
}
|
||||
assert.Len(t, healthcheck.GetHealthCheck(th.NewCollectingHealthCheckMetrics()).Backends, expectedNumHealthCheckBackends, "health check backends")
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerLoadConfigEmptyBasicAuth(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
EntryPoints: configuration.EntryPoints{
|
||||
"http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}},
|
||||
},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
Frontends: map[string]*types.Frontend{
|
||||
"frontend": {
|
||||
EntryPoints: []string{"http"},
|
||||
Backend: "backend",
|
||||
},
|
||||
},
|
||||
Backends: map[string]*types.Backend{
|
||||
"backend": {
|
||||
Servers: map[string]types.Server{
|
||||
"server": {
|
||||
URL: "http://localhost",
|
||||
},
|
||||
},
|
||||
LoadBalancer: &types.LoadBalancer{
|
||||
Method: "Wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
entryPoints := map[string]EntryPoint{}
|
||||
for key, value := range globalConfig.EntryPoints {
|
||||
entryPoints[key] = EntryPoint{
|
||||
Configuration: value,
|
||||
}
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
_ = srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
}
|
||||
|
||||
func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
DefaultEntryPoints: []string{"http", "https"},
|
||||
|
@ -200,8 +60,8 @@ func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
|
|||
"http": {Configuration: &configuration.EntryPoint{}},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
dynamicConfigs := config.Configurations{
|
||||
"config": &config.Configuration{
|
||||
TLS: []*tls.Configuration{
|
||||
{
|
||||
Certificate: &tls.Certificate{
|
||||
|
@ -214,61 +74,58 @@ func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
|
|||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
|
||||
mapEntryPoints := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
if !mapEntryPoints["https"].certs.ContainsCertificates() {
|
||||
_, mapsCerts := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
if len(mapsCerts["https"]) == 0 {
|
||||
t.Fatal("got error: https entryPoint must have TLS certificates.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReuseBackend(t *testing.T) {
|
||||
func TestReuseService(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
DefaultEntryPoints: []string{"http"},
|
||||
}
|
||||
|
||||
entryPoints := map[string]EntryPoint{
|
||||
"http": {Configuration: &configuration.EntryPoint{
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
}},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
DefaultEntryPoints: []string{"http"},
|
||||
}
|
||||
|
||||
dynamicConfigs := config.Configurations{
|
||||
"config": th.BuildConfiguration(
|
||||
th.WithFrontends(
|
||||
th.WithFrontend("backend",
|
||||
th.WithFrontendName("frontend0"),
|
||||
th.WithRouters(
|
||||
th.WithRouter("foo",
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule("Path:/ok")),
|
||||
th.WithRouter("foo2",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute("/ok", "Path: /ok"))),
|
||||
th.WithFrontend("backend",
|
||||
th.WithFrontendName("frontend1"),
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute("/unauthorized", "Path: /unauthorized")),
|
||||
th.WithFrontEndAuth(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"foo:bar"},
|
||||
},
|
||||
})),
|
||||
th.WithRule("Path:/unauthorized"),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRouterMiddlewares("basicauth")),
|
||||
),
|
||||
th.WithBackends(th.WithBackendNew("backend",
|
||||
th.WithMiddlewares(th.WithMiddleware("basicauth",
|
||||
th.WithBasicAuth(&config.BasicAuth{Users: []string{"foo:bar"}}),
|
||||
)),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr"),
|
||||
th.WithServersNew(th.WithServerNew(testServer.URL))),
|
||||
th.WithServers(th.WithServer(testServer.URL))),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
|
||||
serverEntryPoints := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
serverEntryPoints, _ := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
|
||||
// Test that the /ok path returns a status 200.
|
||||
responseRecorderOk := &httptest.ResponseRecorder{}
|
||||
requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil)
|
||||
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderOk, requestOk)
|
||||
serverEntryPoints["http"].ServeHTTP(responseRecorderOk, requestOk)
|
||||
|
||||
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
|
||||
|
||||
|
@ -276,15 +133,15 @@ func TestReuseBackend(t *testing.T) {
|
|||
// the basic authentication defined on the frontend.
|
||||
responseRecorderUnauthorized := &httptest.ResponseRecorder{}
|
||||
requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil)
|
||||
serverEntryPoints["http"].httpRouter.ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
|
||||
serverEntryPoints["http"].ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code")
|
||||
}
|
||||
|
||||
func TestThrottleProviderConfigReload(t *testing.T) {
|
||||
throttleDuration := 30 * time.Millisecond
|
||||
publishConfig := make(chan types.ConfigMessage)
|
||||
providerConfig := make(chan types.ConfigMessage)
|
||||
publishConfig := make(chan config.Message)
|
||||
providerConfig := make(chan config.Message)
|
||||
stop := make(chan bool)
|
||||
defer func() {
|
||||
stop <- true
|
||||
|
@ -312,7 +169,7 @@ func TestThrottleProviderConfigReload(t *testing.T) {
|
|||
|
||||
// publish 5 new configs, one new config each 10 milliseconds
|
||||
for i := 0; i < 5; i++ {
|
||||
providerConfig <- types.ConfigMessage{}
|
||||
providerConfig <- config.Message{}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
|
@ -335,205 +192,3 @@ func TestThrottleProviderConfigReload(t *testing.T) {
|
|||
t.Error("Last config was not published in time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerMultipleFrontendRules(t *testing.T) {
|
||||
testCases := []struct {
|
||||
expression string
|
||||
requestURL string
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
expression: "Host:foo.bar",
|
||||
requestURL: "http://foo.bar",
|
||||
expectedURL: "http://foo.bar",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefix:/management;ReplacePath:/health",
|
||||
requestURL: "http://foo.bar/management",
|
||||
expectedURL: "http://foo.bar/health",
|
||||
},
|
||||
{
|
||||
expression: "Host:foo.bar;AddPrefix:/blah",
|
||||
requestURL: "http://foo.bar/baz",
|
||||
expectedURL: "http://foo.bar/blah/baz",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+}",
|
||||
requestURL: "http://foo.bar/one/some/12345/four",
|
||||
expectedURL: "http://foo.bar/four",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefixStripRegex:/one/{two}/{three:[0-9]+};AddPrefix:/zero",
|
||||
requestURL: "http://foo.bar/one/some/12345/four",
|
||||
expectedURL: "http://foo.bar/zero/four",
|
||||
},
|
||||
{
|
||||
expression: "AddPrefix:/blah;ReplacePath:/baz",
|
||||
requestURL: "http://foo.bar/hello",
|
||||
expectedURL: "http://foo.bar/baz",
|
||||
},
|
||||
{
|
||||
expression: "PathPrefixStrip:/management;ReplacePath:/health",
|
||||
requestURL: "http://foo.bar/management",
|
||||
expectedURL: "http://foo.bar/health",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.expression, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
router := mux.NewRouter()
|
||||
route := router.NewRoute()
|
||||
serverRoute := &types.ServerRoute{Route: route}
|
||||
reqHostMid := &middlewares.RequestHost{}
|
||||
rls := &rules.Rules{Route: serverRoute}
|
||||
|
||||
expression := test.expression
|
||||
routeResult, err := rls.Parse(expression)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Error while building route for %s: %+v", expression, err)
|
||||
}
|
||||
|
||||
request := th.MustNewRequest(http.MethodGet, test.requestURL, nil)
|
||||
var routeMatch bool
|
||||
reqHostMid.ServeHTTP(nil, request, func(w http.ResponseWriter, r *http.Request) {
|
||||
routeMatch = routeResult.Match(r, &mux.RouteMatch{Route: routeResult})
|
||||
})
|
||||
|
||||
if !routeMatch {
|
||||
t.Fatalf("Rule %s doesn't match", expression)
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, test.expectedURL, r.URL.String(), "URL")
|
||||
})
|
||||
|
||||
hd := buildMatcherMiddlewares(serverRoute, handler)
|
||||
serverRoute.Route.Handler(hd)
|
||||
|
||||
serverRoute.Route.GetHandler().ServeHTTP(nil, request)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerBuildHealthCheckOptions(t *testing.T) {
|
||||
lb := &testLoadBalancer{}
|
||||
globalInterval := 15 * time.Second
|
||||
globalTimeout := 3 * time.Second
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
hc *types.HealthCheck
|
||||
expectedOpts *healthcheck.Options
|
||||
}{
|
||||
{
|
||||
desc: "nil health check",
|
||||
hc: nil,
|
||||
expectedOpts: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty path",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "",
|
||||
},
|
||||
expectedOpts: nil,
|
||||
},
|
||||
{
|
||||
desc: "unparseable interval",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "unparseable",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: globalInterval,
|
||||
LB: lb,
|
||||
Timeout: 3 * time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "sub-zero interval",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "-42s",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: globalInterval,
|
||||
LB: lb,
|
||||
Timeout: 3 * time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "parseable interval",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "5m",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: 5 * time.Minute,
|
||||
LB: lb,
|
||||
Timeout: 3 * time.Second,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "unparseable timeout",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "15s",
|
||||
Timeout: "unparseable",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: globalInterval,
|
||||
Timeout: globalTimeout,
|
||||
LB: lb,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "sub-zero timeout",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "15s",
|
||||
Timeout: "-42s",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: globalInterval,
|
||||
Timeout: globalTimeout,
|
||||
LB: lb,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "parseable timeout",
|
||||
hc: &types.HealthCheck{
|
||||
Path: "/path",
|
||||
Interval: "15s",
|
||||
Timeout: "10s",
|
||||
},
|
||||
expectedOpts: &healthcheck.Options{
|
||||
Path: "/path",
|
||||
Interval: globalInterval,
|
||||
Timeout: 10 * time.Second,
|
||||
LB: lb,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
opts := buildHealthCheckOptions(lb, "backend", test.hc, &configuration.HealthCheckConfig{
|
||||
Interval: parse.Duration(globalInterval),
|
||||
Timeout: parse.Duration(globalTimeout),
|
||||
})
|
||||
assert.Equal(t, test.expectedOpts, opts, "health check options")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,439 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
"github.com/containous/traefik/server/cookie"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/vulcand/oxy/buffer"
|
||||
"github.com/vulcand/oxy/connlimit"
|
||||
"github.com/vulcand/oxy/ratelimit"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type h2cTransportWrapper struct {
|
||||
*http2.Transport
|
||||
}
|
||||
|
||||
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.URL.Scheme = "http"
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func (s *Server) buildBalancerMiddlewares(frontendName string, frontend *types.Frontend, backend *types.Backend, fwd http.Handler) (http.Handler, *healthcheck.BackendConfig, error) {
|
||||
balancer, err := s.buildLoadBalancer(frontendName, frontend.Backend, backend, fwd)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Health Check
|
||||
var backendHealthCheck *healthcheck.BackendConfig
|
||||
if hcOpts := buildHealthCheckOptions(balancer, frontend.Backend, backend.HealthCheck, s.globalConfiguration.HealthCheck); hcOpts != nil {
|
||||
log.Debugf("Setting up backend health check %s", *hcOpts)
|
||||
|
||||
hcOpts.Transport = s.defaultForwardingRoundTripper
|
||||
backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, frontend.Backend)
|
||||
}
|
||||
|
||||
// Empty (backend with no servers)
|
||||
var lb http.Handler = middlewares.NewEmptyBackendHandler(balancer)
|
||||
|
||||
// Rate Limit
|
||||
if frontend.RateLimit != nil && len(frontend.RateLimit.RateSet) > 0 {
|
||||
handler, err := buildRateLimiter(lb, frontend.RateLimit)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating rate limiter: %v", err)
|
||||
}
|
||||
|
||||
lb = s.wrapHTTPHandlerWithAccessLog(
|
||||
s.tracingMiddleware.NewHTTPHandlerWrapper("Rate limit", handler, false),
|
||||
fmt.Sprintf("rate limit for %s", frontendName),
|
||||
)
|
||||
}
|
||||
|
||||
// Max Connections
|
||||
if backend.MaxConn != nil && backend.MaxConn.Amount != 0 {
|
||||
log.Debugf("Creating load-balancer connection limit")
|
||||
|
||||
handler, err := buildMaxConn(lb, backend.MaxConn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
lb = s.wrapHTTPHandlerWithAccessLog(handler, fmt.Sprintf("connection limit for %s", frontendName))
|
||||
}
|
||||
|
||||
// Retry
|
||||
if s.globalConfiguration.Retry != nil {
|
||||
handler := s.buildRetryMiddleware(lb, s.globalConfiguration.Retry, len(backend.Servers), frontend.Backend)
|
||||
lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Retry", handler, false)
|
||||
}
|
||||
|
||||
// Buffering
|
||||
if backend.Buffering != nil {
|
||||
handler, err := buildBufferingMiddleware(lb, backend.Buffering)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error setting up buffering middleware: %s", err)
|
||||
}
|
||||
|
||||
// TODO refactor ?
|
||||
lb = handler
|
||||
}
|
||||
|
||||
// Circuit Breaker
|
||||
if backend.CircuitBreaker != nil {
|
||||
log.Debugf("Creating circuit breaker %s", backend.CircuitBreaker.Expression)
|
||||
|
||||
expression := backend.CircuitBreaker.Expression
|
||||
circuitBreaker, err := middlewares.NewCircuitBreaker(lb, expression, middlewares.NewCircuitBreakerOptions(expression))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating circuit breaker: %v", err)
|
||||
}
|
||||
|
||||
lb = s.tracingMiddleware.NewHTTPHandlerWrapper("Circuit breaker", circuitBreaker, false)
|
||||
}
|
||||
|
||||
return lb, backendHealthCheck, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildLoadBalancer(frontendName string, backendName string, backend *types.Backend, fwd http.Handler) (healthcheck.BalancerHandler, error) {
|
||||
var rr *roundrobin.RoundRobin
|
||||
var saveFrontend http.Handler
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
saveUsername := accesslog.NewSaveUsername(fwd)
|
||||
saveBackend := accesslog.NewSaveBackend(saveUsername, backendName)
|
||||
saveFrontend = accesslog.NewSaveFrontend(saveBackend, frontendName)
|
||||
rr, _ = roundrobin.New(saveFrontend)
|
||||
} else {
|
||||
rr, _ = roundrobin.New(fwd)
|
||||
}
|
||||
|
||||
var stickySession *roundrobin.StickySession
|
||||
var cookieName string
|
||||
if stickiness := backend.LoadBalancer.Stickiness; stickiness != nil {
|
||||
cookieName = cookie.GetName(stickiness.CookieName, backendName)
|
||||
stickySession = roundrobin.NewStickySession(cookieName)
|
||||
}
|
||||
|
||||
lbMethod, err := types.NewLoadBalancerMethod(backend.LoadBalancer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading load balancer method '%+v' for frontend %s: %v", backend.LoadBalancer, frontendName, err)
|
||||
}
|
||||
|
||||
var lb healthcheck.BalancerHandler
|
||||
|
||||
switch lbMethod {
|
||||
case types.Drr:
|
||||
log.Debug("Creating load-balancer drr")
|
||||
|
||||
if stickySession != nil {
|
||||
log.Debugf("Sticky session with cookie %v", cookieName)
|
||||
|
||||
lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
lb, err = roundrobin.NewRebalancer(rr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
case types.Wrr:
|
||||
log.Debug("Creating load-balancer wrr")
|
||||
|
||||
if stickySession != nil {
|
||||
log.Debugf("Sticky session with cookie %v", cookieName)
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
lb, err = roundrobin.New(saveFrontend, roundrobin.EnableStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
lb = rr
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("invalid load-balancing method %q", lbMethod)
|
||||
}
|
||||
|
||||
if err := s.configureLBServers(lb, backend, backendName); err != nil {
|
||||
return nil, fmt.Errorf("error configuring load balancer for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
func (s *Server) configureLBServers(lb healthcheck.BalancerHandler, backend *types.Backend, backendName string) error {
|
||||
for name, srv := range backend.Servers {
|
||||
u, err := url.Parse(srv.URL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err)
|
||||
}
|
||||
|
||||
log.Debugf("Creating server %s at %s with weight %d", name, u, srv.Weight)
|
||||
|
||||
if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil {
|
||||
return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err)
|
||||
}
|
||||
|
||||
s.metricsRegistry.BackendServerUpGauge().With("backend", backendName, "url", srv.URL).Set(1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRoundTripper will either use server.defaultForwardingRoundTripper or create a new one
|
||||
// given a custom TLS configuration is passed and the passTLSCert option is set to true.
|
||||
func (s *Server) getRoundTripper(entryPointName string, passTLSCert bool, tls *traefiktls.TLS) (http.RoundTripper, error) {
|
||||
if passTLSCert {
|
||||
tlsConfig, err := createClientTLSConfig(entryPointName, tls)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create TLSClientConfig: %v", err)
|
||||
}
|
||||
|
||||
transport, err := createHTTPTransport(s.globalConfiguration)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create HTTP transport: %v", err)
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
return s.defaultForwardingRoundTripper, nil
|
||||
}
|
||||
|
||||
// createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings.
|
||||
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
|
||||
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost
|
||||
// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing
|
||||
// behavior and backwards compatibility issues.
|
||||
func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) (*http.Transport, error) {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: configuration.DefaultDialTimeout,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}
|
||||
|
||||
if globalConfiguration.ForwardingTimeouts != nil {
|
||||
dialer.Timeout = time.Duration(globalConfiguration.ForwardingTimeouts.DialTimeout)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: dialer.DialContext,
|
||||
MaxIdleConnsPerHost: globalConfiguration.MaxIdleConnsPerHost,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
transport.RegisterProtocol("h2c", &h2cTransportWrapper{
|
||||
Transport: &http2.Transport{
|
||||
DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
|
||||
return net.Dial(netw, addr)
|
||||
},
|
||||
AllowHTTP: true,
|
||||
},
|
||||
})
|
||||
|
||||
if globalConfiguration.ForwardingTimeouts != nil {
|
||||
transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
if globalConfiguration.InsecureSkipVerify {
|
||||
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
|
||||
if len(globalConfiguration.RootCAs) > 0 {
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
RootCAs: createRootCACertPool(globalConfiguration.RootCAs),
|
||||
}
|
||||
}
|
||||
|
||||
err := http2.ConfigureTransport(transport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
func createRootCACertPool(rootCAs traefiktls.FilesOrContents) *x509.CertPool {
|
||||
roots := x509.NewCertPool()
|
||||
|
||||
for _, cert := range rootCAs {
|
||||
certContent, err := cert.Read()
|
||||
if err != nil {
|
||||
log.Error("Error while read RootCAs", err)
|
||||
continue
|
||||
}
|
||||
roots.AppendCertsFromPEM(certContent)
|
||||
}
|
||||
|
||||
return roots
|
||||
}
|
||||
|
||||
func createClientTLSConfig(entryPointName string, tlsOption *traefiktls.TLS) (*tls.Config, error) {
|
||||
if tlsOption == nil {
|
||||
return nil, errors.New("no TLS provided")
|
||||
}
|
||||
|
||||
config, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(tlsOption.ClientCA.Files) > 0 {
|
||||
pool := x509.NewCertPool()
|
||||
for _, caFile := range tlsOption.ClientCA.Files {
|
||||
data, err := caFile.Read()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !pool.AppendCertsFromPEM(data) {
|
||||
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
|
||||
}
|
||||
}
|
||||
config.RootCAs = pool
|
||||
}
|
||||
|
||||
config.BuildNameToCertificate()
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildRetryMiddleware(handler http.Handler, retry *configuration.Retry, countServers int, backendName string) http.Handler {
|
||||
retryListeners := middlewares.RetryListeners{}
|
||||
if s.metricsRegistry.IsEnabled() {
|
||||
retryListeners = append(retryListeners, middlewares.NewMetricsRetryListener(s.metricsRegistry, backendName))
|
||||
}
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
retryListeners = append(retryListeners, &accesslog.SaveRetries{})
|
||||
}
|
||||
|
||||
retryAttempts := countServers
|
||||
if retry.Attempts > 0 {
|
||||
retryAttempts = retry.Attempts
|
||||
}
|
||||
|
||||
log.Debugf("Creating retries max attempts %d", retryAttempts)
|
||||
|
||||
return middlewares.NewRetry(retryAttempts, handler, retryListeners)
|
||||
}
|
||||
|
||||
func buildRateLimiter(handler http.Handler, rlConfig *types.RateLimit) (http.Handler, error) {
|
||||
extractFunc, err := utils.NewExtractor(rlConfig.ExtractorFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debugf("Creating load-balancer rate limiter")
|
||||
|
||||
rateSet := ratelimit.NewRateSet()
|
||||
for _, rate := range rlConfig.RateSet {
|
||||
if err := rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return ratelimit.New(handler, extractFunc, rateSet)
|
||||
}
|
||||
|
||||
func buildBufferingMiddleware(handler http.Handler, config *types.Buffering) (http.Handler, error) {
|
||||
log.Debugf("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'",
|
||||
config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes,
|
||||
config.MaxResponseBodyBytes, config.RetryExpression)
|
||||
|
||||
return buffer.New(
|
||||
handler,
|
||||
buffer.MemRequestBodyBytes(config.MemRequestBodyBytes),
|
||||
buffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes),
|
||||
buffer.MemResponseBodyBytes(config.MemResponseBodyBytes),
|
||||
buffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes),
|
||||
buffer.CondSetter(len(config.RetryExpression) > 0, buffer.Retry(config.RetryExpression)),
|
||||
)
|
||||
}
|
||||
|
||||
func buildMaxConn(lb http.Handler, maxConns *types.MaxConn) (http.Handler, error) {
|
||||
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
log.Debugf("Creating load-balancer connection limit")
|
||||
|
||||
handler, err := connlimit.New(lb, extractFunc, maxConns.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
func buildHealthCheckOptions(lb healthcheck.BalancerHandler, backend string, hc *types.HealthCheck, hcConfig *configuration.HealthCheckConfig) *healthcheck.Options {
|
||||
if hc == nil || hc.Path == "" || hcConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
interval := time.Duration(hcConfig.Interval)
|
||||
if hc.Interval != "" {
|
||||
intervalOverride, err := time.ParseDuration(hc.Interval)
|
||||
if err != nil {
|
||||
log.Errorf("Illegal health check interval for backend '%s': %s", backend, err)
|
||||
} else if intervalOverride <= 0 {
|
||||
log.Errorf("Health check interval smaller than zero for backend '%s', backend", backend)
|
||||
} else {
|
||||
interval = intervalOverride
|
||||
}
|
||||
}
|
||||
|
||||
timeout := time.Duration(hcConfig.Timeout)
|
||||
if hc.Timeout != "" {
|
||||
timeoutOverride, err := time.ParseDuration(hc.Timeout)
|
||||
if err != nil {
|
||||
log.Errorf("Illegal health check timeout for backend '%s': %s", backend, err)
|
||||
} else if timeoutOverride <= 0 {
|
||||
log.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend)
|
||||
} else {
|
||||
timeout = timeoutOverride
|
||||
}
|
||||
}
|
||||
|
||||
if timeout >= interval {
|
||||
log.Warnf("Health check timeout for backend '%s' should be lower than the health check interval. Interval set to timeout + 1 second (%s).", backend)
|
||||
}
|
||||
|
||||
return &healthcheck.Options{
|
||||
Scheme: hc.Scheme,
|
||||
Path: hc.Path,
|
||||
Port: hc.Port,
|
||||
Interval: interval,
|
||||
Timeout: timeout,
|
||||
LB: lb,
|
||||
Hostname: hc.Hostname,
|
||||
Headers: hc.Headers,
|
||||
}
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConfigureBackends(t *testing.T) {
|
||||
validMethod := "Drr"
|
||||
defaultMethod := "wrr"
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
lb *types.LoadBalancer
|
||||
expectedMethod string
|
||||
expectedStickiness *types.Stickiness
|
||||
}{
|
||||
{
|
||||
desc: "valid load balancer method with sticky enabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: validMethod,
|
||||
Stickiness: &types.Stickiness{},
|
||||
},
|
||||
expectedMethod: validMethod,
|
||||
expectedStickiness: &types.Stickiness{},
|
||||
},
|
||||
{
|
||||
desc: "valid load balancer method with sticky disabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: validMethod,
|
||||
Stickiness: nil,
|
||||
},
|
||||
expectedMethod: validMethod,
|
||||
},
|
||||
{
|
||||
desc: "invalid load balancer method with sticky enabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: "Invalid",
|
||||
Stickiness: &types.Stickiness{},
|
||||
},
|
||||
expectedMethod: defaultMethod,
|
||||
expectedStickiness: &types.Stickiness{},
|
||||
},
|
||||
{
|
||||
desc: "invalid load balancer method with sticky disabled",
|
||||
lb: &types.LoadBalancer{
|
||||
Method: "Invalid",
|
||||
Stickiness: nil,
|
||||
},
|
||||
expectedMethod: defaultMethod,
|
||||
},
|
||||
{
|
||||
desc: "missing load balancer",
|
||||
lb: nil,
|
||||
expectedMethod: defaultMethod,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
backend := &types.Backend{
|
||||
LoadBalancer: test.lb,
|
||||
}
|
||||
|
||||
configureBackends(map[string]*types.Backend{
|
||||
"backend": backend,
|
||||
})
|
||||
|
||||
expected := types.LoadBalancer{
|
||||
Method: test.expectedMethod,
|
||||
Stickiness: test.expectedStickiness,
|
||||
}
|
||||
|
||||
assert.Equal(t, expected, *backend.LoadBalancer)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,350 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
mauth "github.com/containous/traefik/middlewares/auth"
|
||||
"github.com/containous/traefik/middlewares/errorpages"
|
||||
"github.com/containous/traefik/middlewares/forwardedheaders"
|
||||
"github.com/containous/traefik/middlewares/redirect"
|
||||
"github.com/containous/traefik/types"
|
||||
thoas_stats "github.com/thoas/stats"
|
||||
"github.com/unrolled/secure"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
type handlerPostConfig func(backendsHandlers map[string]http.Handler) error
|
||||
|
||||
type modifyResponse func(*http.Response) error
|
||||
|
||||
func (s *Server) buildMiddlewares(frontendName string, frontend *types.Frontend,
|
||||
backends map[string]*types.Backend, entryPointName string, providerName string) ([]negroni.Handler, modifyResponse, handlerPostConfig, error) {
|
||||
|
||||
var middle []negroni.Handler
|
||||
var postConfig handlerPostConfig
|
||||
|
||||
// Error pages
|
||||
if len(frontend.Errors) > 0 {
|
||||
handlers, err := buildErrorPagesMiddleware(frontendName, frontend, backends, entryPointName, providerName)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
postConfig = errorPagesPostConfig(handlers)
|
||||
|
||||
for _, handler := range handlers {
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
}
|
||||
|
||||
// Metrics
|
||||
if s.metricsRegistry.IsEnabled() {
|
||||
handler := middlewares.NewBackendMetricsMiddleware(s.metricsRegistry, frontend.Backend)
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
// Whitelist
|
||||
ipWhitelistMiddleware, err := buildIPWhiteLister(frontend.WhiteList, s.entryPoints[entryPointName].Configuration.ClientIPStrategy)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error creating IP Whitelister: %s", err)
|
||||
}
|
||||
if ipWhitelistMiddleware != nil {
|
||||
log.Debugf("Configured IP Whitelists: %v", frontend.WhiteList.SourceRange)
|
||||
|
||||
handler := s.tracingMiddleware.NewNegroniHandlerWrapper(
|
||||
"IP whitelist",
|
||||
s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for %s", frontendName)),
|
||||
false)
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
// Redirect
|
||||
if frontend.Redirect != nil && entryPointName != frontend.Redirect.EntryPoint {
|
||||
rewrite, err := s.buildRedirectHandler(entryPointName, frontend.Redirect)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error creating Frontend Redirect: %v", err)
|
||||
}
|
||||
|
||||
handler := s.wrapNegroniHandlerWithAccessLog(rewrite, fmt.Sprintf("frontend redirect for %s", frontendName))
|
||||
middle = append(middle, handler)
|
||||
|
||||
log.Debugf("Frontend %s redirect created", frontendName)
|
||||
}
|
||||
|
||||
// Header
|
||||
headerMiddleware := middlewares.NewHeaderFromStruct(frontend.Headers)
|
||||
if headerMiddleware != nil {
|
||||
log.Debugf("Adding header middleware for frontend %s", frontendName)
|
||||
|
||||
handler := s.tracingMiddleware.NewNegroniHandlerWrapper("Header", headerMiddleware, false)
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
// Secure
|
||||
secureMiddleware := middlewares.NewSecure(frontend.Headers)
|
||||
if secureMiddleware != nil {
|
||||
log.Debugf("Adding secure middleware for frontend %s", frontendName)
|
||||
|
||||
handler := negroni.HandlerFunc(secureMiddleware.HandlerFuncWithNextForRequestOnly)
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
// Authentication
|
||||
if frontend.Auth != nil {
|
||||
authMiddleware, err := mauth.NewAuthenticator(frontend.Auth, s.tracingMiddleware)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
handler := s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for %s", frontendName))
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
// TLSClientHeaders
|
||||
tlsClientHeadersMiddleware := middlewares.NewTLSClientHeaders(frontend)
|
||||
if tlsClientHeadersMiddleware != nil {
|
||||
log.Debugf("Adding TLSClientHeaders middleware for frontend %s", frontendName)
|
||||
|
||||
handler := s.tracingMiddleware.NewNegroniHandlerWrapper("TLSClientHeaders", tlsClientHeadersMiddleware, false)
|
||||
middle = append(middle, handler)
|
||||
}
|
||||
|
||||
return middle, buildModifyResponse(secureMiddleware, headerMiddleware), postConfig, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildServerEntryPointMiddlewares(serverEntryPointName string) ([]negroni.Handler, error) {
|
||||
serverMiddlewares := []negroni.Handler{middlewares.NegroniRecoverHandler()}
|
||||
|
||||
if s.tracingMiddleware.IsEnabled() {
|
||||
serverMiddlewares = append(serverMiddlewares, s.tracingMiddleware.NewEntryPoint(serverEntryPointName))
|
||||
}
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
serverMiddlewares = append(serverMiddlewares, s.accessLoggerMiddleware)
|
||||
}
|
||||
|
||||
if s.metricsRegistry.IsEnabled() {
|
||||
serverMiddlewares = append(serverMiddlewares, middlewares.NewEntryPointMetricsMiddleware(s.metricsRegistry, serverEntryPointName))
|
||||
}
|
||||
|
||||
if s.globalConfiguration.API != nil {
|
||||
if s.globalConfiguration.API.Stats == nil {
|
||||
s.globalConfiguration.API.Stats = thoas_stats.New()
|
||||
}
|
||||
serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.Stats)
|
||||
if s.globalConfiguration.API.Statistics != nil {
|
||||
if s.globalConfiguration.API.StatsRecorder == nil {
|
||||
s.globalConfiguration.API.StatsRecorder = middlewares.NewStatsRecorder(s.globalConfiguration.API.Statistics.RecentErrors)
|
||||
}
|
||||
serverMiddlewares = append(serverMiddlewares, s.globalConfiguration.API.StatsRecorder)
|
||||
}
|
||||
}
|
||||
|
||||
if s.entryPoints[serverEntryPointName].Configuration.Redirect != nil {
|
||||
redirectHandlers, err := s.buildEntryPointRedirect()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create redirect middleware: %v", err)
|
||||
}
|
||||
serverMiddlewares = append(serverMiddlewares, redirectHandlers[serverEntryPointName])
|
||||
}
|
||||
|
||||
if s.entryPoints[serverEntryPointName].Configuration.Auth != nil {
|
||||
authMiddleware, err := mauth.NewAuthenticator(s.entryPoints[serverEntryPointName].Configuration.Auth, s.tracingMiddleware)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create authentication middleware: %v", err)
|
||||
}
|
||||
serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(authMiddleware, fmt.Sprintf("Auth for entrypoint %s", serverEntryPointName)))
|
||||
}
|
||||
|
||||
if s.entryPoints[serverEntryPointName].Configuration.Compress != nil {
|
||||
serverMiddlewares = append(serverMiddlewares, &middlewares.Compress{})
|
||||
}
|
||||
|
||||
if s.entryPoints[serverEntryPointName].Configuration.ForwardedHeaders != nil {
|
||||
xForwardedMiddleware, err := forwardedheaders.NewXforwarded(
|
||||
s.entryPoints[serverEntryPointName].Configuration.ForwardedHeaders.Insecure,
|
||||
s.entryPoints[serverEntryPointName].Configuration.ForwardedHeaders.TrustedIPs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create xforwarded headers middleware: %v", err)
|
||||
}
|
||||
serverMiddlewares = append(serverMiddlewares, xForwardedMiddleware)
|
||||
}
|
||||
|
||||
ipWhitelistMiddleware, err := buildIPWhiteLister(s.entryPoints[serverEntryPointName].Configuration.WhiteList, s.entryPoints[serverEntryPointName].Configuration.ClientIPStrategy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create ip whitelist middleware: %v", err)
|
||||
}
|
||||
if ipWhitelistMiddleware != nil {
|
||||
serverMiddlewares = append(serverMiddlewares, s.wrapNegroniHandlerWithAccessLog(ipWhitelistMiddleware, fmt.Sprintf("ipwhitelister for entrypoint %s", serverEntryPointName)))
|
||||
}
|
||||
|
||||
// RequestHost Cannonizer
|
||||
serverMiddlewares = append(serverMiddlewares, &middlewares.RequestHost{})
|
||||
|
||||
return serverMiddlewares, nil
|
||||
}
|
||||
|
||||
func errorPagesPostConfig(epHandlers []*errorpages.Handler) handlerPostConfig {
|
||||
return func(backendsHandlers map[string]http.Handler) error {
|
||||
for _, errorPageHandler := range epHandlers {
|
||||
if handler, ok := backendsHandlers[errorPageHandler.BackendName]; ok {
|
||||
err := errorPageHandler.PostLoad(handler)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure error pages for backend %s: %v", errorPageHandler.BackendName, err)
|
||||
}
|
||||
} else {
|
||||
err := errorPageHandler.PostLoad(nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to configure error pages for %s: %v", errorPageHandler.FallbackURL, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func buildErrorPagesMiddleware(frontendName string, frontend *types.Frontend, backends map[string]*types.Backend, entryPointName string, providerName string) ([]*errorpages.Handler, error) {
|
||||
var errorPageHandlers []*errorpages.Handler
|
||||
|
||||
for errorPageName, errorPage := range frontend.Errors {
|
||||
if frontend.Backend == errorPage.Backend {
|
||||
log.Errorf("Error when creating error page %q for frontend %q: error pages backend %q is the same as backend for the frontend (infinite call risk).",
|
||||
errorPageName, frontendName, errorPage.Backend)
|
||||
} else if backends[errorPage.Backend] == nil {
|
||||
log.Errorf("Error when creating error page %q for frontend %q: the backend %q doesn't exist.",
|
||||
errorPageName, frontendName, errorPage.Backend)
|
||||
} else {
|
||||
errorPagesHandler, err := errorpages.NewHandler(errorPage, entryPointName+providerName+errorPage.Backend)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating error pages: %v", err)
|
||||
}
|
||||
|
||||
if errorPageServer, ok := backends[errorPage.Backend].Servers["error"]; ok {
|
||||
errorPagesHandler.FallbackURL = errorPageServer.URL
|
||||
}
|
||||
|
||||
errorPageHandlers = append(errorPageHandlers, errorPagesHandler)
|
||||
}
|
||||
}
|
||||
|
||||
return errorPageHandlers, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildBasicAuthMiddleware(authData []string) (*mauth.Authenticator, error) {
|
||||
users := types.Users{}
|
||||
for _, user := range authData {
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
auth := &types.Auth{}
|
||||
auth.Basic = &types.Basic{
|
||||
Users: users,
|
||||
}
|
||||
|
||||
authMiddleware, err := mauth.NewAuthenticator(auth, s.tracingMiddleware)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating Basic Auth: %v", err)
|
||||
}
|
||||
|
||||
return authMiddleware, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildEntryPointRedirect() (map[string]negroni.Handler, error) {
|
||||
redirectHandlers := map[string]negroni.Handler{}
|
||||
|
||||
for entryPointName, ep := range s.entryPoints {
|
||||
entryPoint := ep.Configuration
|
||||
|
||||
if entryPoint.Redirect != nil && entryPointName != entryPoint.Redirect.EntryPoint {
|
||||
handler, err := s.buildRedirectHandler(entryPointName, entryPoint.Redirect)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error loading configuration for entrypoint %s: %v", entryPointName, err)
|
||||
}
|
||||
|
||||
handlerToUse := s.wrapNegroniHandlerWithAccessLog(handler, fmt.Sprintf("entrypoint redirect for %s", entryPointName))
|
||||
redirectHandlers[entryPointName] = handlerToUse
|
||||
}
|
||||
}
|
||||
|
||||
return redirectHandlers, nil
|
||||
}
|
||||
|
||||
func (s *Server) buildRedirectHandler(srcEntryPointName string, opt *types.Redirect) (negroni.Handler, error) {
|
||||
// entry point redirect
|
||||
if len(opt.EntryPoint) > 0 {
|
||||
entryPoint := s.entryPoints[opt.EntryPoint].Configuration
|
||||
if entryPoint == nil {
|
||||
return nil, fmt.Errorf("unknown target entrypoint %q", srcEntryPointName)
|
||||
}
|
||||
log.Debugf("Creating entry point redirect %s -> %s", srcEntryPointName, opt.EntryPoint)
|
||||
return redirect.NewEntryPointHandler(entryPoint, opt.Permanent)
|
||||
}
|
||||
|
||||
// regex redirect
|
||||
redirection, err := redirect.NewRegexHandler(opt.Regex, opt.Replacement, opt.Permanent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log.Debugf("Creating regex redirect %s -> %s -> %s", srcEntryPointName, opt.Regex, opt.Replacement)
|
||||
|
||||
return redirection, nil
|
||||
}
|
||||
|
||||
func buildIPWhiteLister(whiteList *types.WhiteList, ipStrategy *types.IPStrategy) (*middlewares.IPWhiteLister, error) {
|
||||
if whiteList == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if whiteList.IPStrategy != nil {
|
||||
ipStrategy = whiteList.IPStrategy
|
||||
}
|
||||
|
||||
strategy, err := ipStrategy.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return middlewares.NewIPWhiteLister(whiteList.SourceRange, strategy)
|
||||
}
|
||||
|
||||
func (s *Server) wrapNegroniHandlerWithAccessLog(handler negroni.Handler, frontendName string) negroni.Handler {
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
saveUsername := accesslog.NewSaveNegroniUsername(handler)
|
||||
saveBackend := accesslog.NewSaveNegroniBackend(saveUsername, "Traefik")
|
||||
saveFrontend := accesslog.NewSaveNegroniFrontend(saveBackend, frontendName)
|
||||
return saveFrontend
|
||||
}
|
||||
return handler
|
||||
}
|
||||
|
||||
func (s *Server) wrapHTTPHandlerWithAccessLog(handler http.Handler, frontendName string) http.Handler {
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
saveUsername := accesslog.NewSaveUsername(handler)
|
||||
saveBackend := accesslog.NewSaveBackend(saveUsername, "Traefik")
|
||||
saveFrontend := accesslog.NewSaveFrontend(saveBackend, frontendName)
|
||||
return saveFrontend
|
||||
}
|
||||
return handler
|
||||
}
|
||||
|
||||
func buildModifyResponse(secure *secure.Secure, header *middlewares.HeaderStruct) func(res *http.Response) error {
|
||||
return func(res *http.Response) error {
|
||||
if secure != nil {
|
||||
if err := secure.ModifyResponseHeaders(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if header != nil {
|
||||
if err := header.ModifyResponseHeaders(res); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -1,270 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/metrics"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
th "github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestServerEntryPointWhitelistConfig(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
entrypoint *configuration.EntryPoint
|
||||
expectMiddleware bool
|
||||
}{
|
||||
{
|
||||
desc: "no whitelist middleware if no config on entrypoint",
|
||||
entrypoint: &configuration.EntryPoint{
|
||||
Address: ":0",
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
expectMiddleware: false,
|
||||
},
|
||||
{
|
||||
desc: "whitelist middleware should be added if configured on entrypoint",
|
||||
entrypoint: &configuration.EntryPoint{
|
||||
Address: ":0",
|
||||
WhiteList: &types.WhiteList{
|
||||
SourceRange: []string{
|
||||
"127.0.0.1/32",
|
||||
},
|
||||
},
|
||||
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
expectMiddleware: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srv := Server{
|
||||
globalConfiguration: configuration.GlobalConfiguration{},
|
||||
metricsRegistry: metrics.NewVoidRegistry(),
|
||||
entryPoints: map[string]EntryPoint{
|
||||
"test": {
|
||||
Configuration: test.entrypoint,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv.serverEntryPoints = srv.buildServerEntryPoints()
|
||||
srvEntryPoint := srv.setupServerEntryPoint("test", srv.serverEntryPoints["test"])
|
||||
handler := srvEntryPoint.httpServer.Handler.(*mux.Router).NotFoundHandler.(*negroni.Negroni)
|
||||
|
||||
found := false
|
||||
for _, handler := range handler.Handlers() {
|
||||
if reflect.TypeOf(handler) == reflect.TypeOf((*middlewares.IPWhiteLister)(nil)) {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if found && !test.expectMiddleware {
|
||||
t.Error("ip whitelist middleware was installed even though it should not")
|
||||
}
|
||||
|
||||
if !found && test.expectMiddleware {
|
||||
t.Error("ip whitelist middleware was not installed even though it should have")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildIPWhiteLister(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whitelistSourceRange []string
|
||||
whiteList *types.WhiteList
|
||||
middlewareConfigured bool
|
||||
errMessage string
|
||||
}{
|
||||
{
|
||||
desc: "no whitelists configured",
|
||||
whitelistSourceRange: nil,
|
||||
middlewareConfigured: false,
|
||||
errMessage: "",
|
||||
},
|
||||
{
|
||||
desc: "whitelists configured",
|
||||
whiteList: &types.WhiteList{
|
||||
SourceRange: []string{
|
||||
"1.2.3.4/24",
|
||||
"fe80::/16",
|
||||
},
|
||||
},
|
||||
middlewareConfigured: true,
|
||||
errMessage: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
middleware, err := buildIPWhiteLister(test.whiteList, nil)
|
||||
|
||||
if test.errMessage != "" {
|
||||
require.EqualError(t, err, test.errMessage)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
|
||||
if test.middlewareConfigured {
|
||||
require.NotNil(t, middleware, "not expected middleware to be configured")
|
||||
} else {
|
||||
require.Nil(t, middleware, "expected middleware to be configured")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRedirectHandler(t *testing.T) {
|
||||
srv := Server{
|
||||
globalConfiguration: configuration.GlobalConfiguration{},
|
||||
entryPoints: map[string]EntryPoint{
|
||||
"http": {Configuration: &configuration.EntryPoint{Address: ":80"}},
|
||||
"https": {Configuration: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}}},
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
srcEntryPointName string
|
||||
url string
|
||||
entryPoint *configuration.EntryPoint
|
||||
redirect *types.Redirect
|
||||
expectedURL string
|
||||
}{
|
||||
{
|
||||
desc: "redirect regex",
|
||||
srcEntryPointName: "http",
|
||||
url: "http://foo.com",
|
||||
redirect: &types.Redirect{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
entryPoint: &configuration.EntryPoint{
|
||||
Address: ":80",
|
||||
Redirect: &types.Redirect{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
},
|
||||
expectedURL: "https://foobar.com",
|
||||
},
|
||||
{
|
||||
desc: "redirect entry point",
|
||||
srcEntryPointName: "http",
|
||||
url: "http://foo:80",
|
||||
redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
},
|
||||
entryPoint: &configuration.EntryPoint{
|
||||
Address: ":80",
|
||||
Redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
},
|
||||
},
|
||||
expectedURL: "https://foo:443",
|
||||
},
|
||||
{
|
||||
desc: "redirect entry point with regex (ignored)",
|
||||
srcEntryPointName: "http",
|
||||
url: "http://foo.com:80",
|
||||
redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
entryPoint: &configuration.EntryPoint{
|
||||
Address: ":80",
|
||||
Redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)$`,
|
||||
Replacement: "https://$1{{\"bar\"}}$2",
|
||||
},
|
||||
},
|
||||
expectedURL: "https://foo.com:443",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rewrite, err := srv.buildRedirectHandler(test.srcEntryPointName, test.redirect)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := th.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
rewrite.ServeHTTP(recorder, req, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Location", "fail")
|
||||
}))
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expectedURL, location.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerGenericFrontendAuthFail(t *testing.T) {
|
||||
globalConfig := configuration.GlobalConfiguration{
|
||||
EntryPoints: configuration.EntryPoints{
|
||||
"http": &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}},
|
||||
},
|
||||
}
|
||||
entryPoints := map[string]EntryPoint{
|
||||
"http": {
|
||||
Configuration: globalConfig.EntryPoints["http"],
|
||||
},
|
||||
}
|
||||
|
||||
dynamicConfigs := types.Configurations{
|
||||
"config": &types.Configuration{
|
||||
Frontends: map[string]*types.Frontend{
|
||||
"frontend": {
|
||||
EntryPoints: []string{"http"},
|
||||
Backend: "backend",
|
||||
Auth: &types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{""},
|
||||
}},
|
||||
},
|
||||
},
|
||||
Backends: map[string]*types.Backend{
|
||||
"backend": {
|
||||
Servers: map[string]types.Server{
|
||||
"server": {
|
||||
URL: "http://localhost",
|
||||
},
|
||||
},
|
||||
LoadBalancer: &types.LoadBalancer{
|
||||
Method: "Wrr",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
|
||||
_ = srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
}
|
|
@ -21,16 +21,16 @@ func (s *Server) listenSignals(stop chan bool) {
|
|||
case sig := <-s.signals:
|
||||
switch sig {
|
||||
case syscall.SIGUSR1:
|
||||
log.Infof("Closing and re-opening log files for rotation: %+v", sig)
|
||||
log.WithoutContext().Infof("Closing and re-opening log files for rotation: %+v", sig)
|
||||
|
||||
if s.accessLoggerMiddleware != nil {
|
||||
if err := s.accessLoggerMiddleware.Rotate(); err != nil {
|
||||
log.Errorf("Error rotating access log: %v", err)
|
||||
log.WithoutContext().Errorf("Error rotating access log: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := log.RotateFile(); err != nil {
|
||||
log.Errorf("Error rotating traefik log: %v", err)
|
||||
log.WithoutContext().Errorf("Error rotating traefik log: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,13 +9,12 @@ import (
|
|||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
th "github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/unrolled/secure"
|
||||
)
|
||||
|
||||
func TestPrepareServerTimeouts(t *testing.T) {
|
||||
|
@ -62,7 +61,7 @@ func TestPrepareServerTimeouts(t *testing.T) {
|
|||
router := middlewares.NewHandlerSwitcher(mux.NewRouter())
|
||||
|
||||
srv := NewServer(test.globalConfig, nil, nil)
|
||||
httpServer, _, err := srv.prepareServer(entryPointName, entryPoint, router, nil)
|
||||
httpServer, _, err := srv.prepareServer(context.Background(), entryPointName, entryPoint, router)
|
||||
require.NoError(t, err, "Unexpected error when preparing srv")
|
||||
|
||||
assert.Equal(t, test.expectedIdleTimeout, httpServer.IdleTimeout, "IdleTimeout")
|
||||
|
@ -87,7 +86,7 @@ func TestListenProvidersSkipsEmptyConfigs(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
server.configurationChan <- types.ConfigMessage{ProviderName: "kubernetes"}
|
||||
server.configurationChan <- config.Message{ProviderName: "kubernetes"}
|
||||
|
||||
// give some time so that the configuration can be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
@ -103,12 +102,12 @@ func TestListenProvidersSkipsSameConfigurationForProvider(t *testing.T) {
|
|||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case config := <-server.configurationValidatedChan:
|
||||
case conf := <-server.configurationValidatedChan:
|
||||
// set the current configuration
|
||||
// this is usually done in the processing part of the published configuration
|
||||
// so we have to emulate the behavior here
|
||||
currentConfigurations := server.currentConfigurations.Get().(types.Configurations)
|
||||
currentConfigurations[config.ProviderName] = config.Configuration
|
||||
currentConfigurations := server.currentConfigurations.Get().(config.Configurations)
|
||||
currentConfigurations[conf.ProviderName] = conf.Configuration
|
||||
server.currentConfigurations.Set(currentConfigurations)
|
||||
|
||||
publishedConfigCount++
|
||||
|
@ -119,19 +118,19 @@ func TestListenProvidersSkipsSameConfigurationForProvider(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
config := th.BuildConfiguration(
|
||||
th.WithFrontends(th.WithFrontend("backend")),
|
||||
th.WithBackends(th.WithBackendNew("backend")),
|
||||
conf := th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo")),
|
||||
th.WithLoadBalancerServices(th.WithService("bar")),
|
||||
)
|
||||
|
||||
// provide a configuration
|
||||
server.configurationChan <- types.ConfigMessage{ProviderName: "kubernetes", Configuration: config}
|
||||
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
|
||||
|
||||
// give some time so that the configuration can be processed
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
|
||||
// provide the same configuration a second time
|
||||
server.configurationChan <- types.ConfigMessage{ProviderName: "kubernetes", Configuration: config}
|
||||
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
|
||||
|
||||
// give some time so that the configuration can be processed
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
@ -160,12 +159,12 @@ func TestListenProvidersPublishesConfigForEachProvider(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
config := th.BuildConfiguration(
|
||||
th.WithFrontends(th.WithFrontend("backend")),
|
||||
th.WithBackends(th.WithBackendNew("backend")),
|
||||
conf := th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo")),
|
||||
th.WithLoadBalancerServices(th.WithService("bar")),
|
||||
)
|
||||
server.configurationChan <- types.ConfigMessage{ProviderName: "kubernetes", Configuration: config}
|
||||
server.configurationChan <- types.ConfigMessage{ProviderName: "marathon", Configuration: config}
|
||||
server.configurationChan <- config.Message{ProviderName: "kubernetes", Configuration: conf}
|
||||
server.configurationChan <- config.Message{ProviderName: "marathon", Configuration: conf}
|
||||
|
||||
select {
|
||||
case <-consumePublishedConfigsDone:
|
||||
|
@ -206,20 +205,21 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config func(testServerURL string) *types.Configuration
|
||||
config func(testServerURL string) *config.Configuration
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
desc: "Ok",
|
||||
config: func(testServerURL string) *types.Configuration {
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithFrontends(th.WithFrontend("backend",
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithBackends(th.WithBackendNew("backend",
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr"),
|
||||
th.WithServersNew(th.WithServerNew(testServerURL))),
|
||||
th.WithServers(th.WithServer(testServerURL))),
|
||||
),
|
||||
)
|
||||
},
|
||||
|
@ -227,20 +227,21 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "No Frontend",
|
||||
config: func(testServerURL string) *types.Configuration {
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
return th.BuildConfiguration()
|
||||
},
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Drr",
|
||||
config: func(testServerURL string) *types.Configuration {
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithFrontends(th.WithFrontend("backend",
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithBackends(th.WithBackendNew("backend",
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("drr")),
|
||||
),
|
||||
)
|
||||
|
@ -249,14 +250,15 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Drr Sticky",
|
||||
config: func(testServerURL string) *types.Configuration {
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithFrontends(th.WithFrontend("backend",
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithBackends(th.WithBackendNew("backend",
|
||||
th.WithLBMethod("drr"), th.WithLBSticky("test")),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("drr"), th.WithStickiness("test")),
|
||||
),
|
||||
)
|
||||
},
|
||||
|
@ -264,13 +266,14 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Wrr",
|
||||
config: func(testServerURL string) *types.Configuration {
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithFrontends(th.WithFrontend("backend",
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithBackends(th.WithBackendNew("backend",
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr")),
|
||||
),
|
||||
)
|
||||
|
@ -279,14 +282,15 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Wrr Sticky",
|
||||
config: func(testServerURL string) *types.Configuration {
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithFrontends(th.WithFrontend("backend",
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRoutes(th.WithRoute(requestPath, routeRule))),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule(routeRule)),
|
||||
),
|
||||
th.WithBackends(th.WithBackendNew("backend",
|
||||
th.WithLBMethod("wrr"), th.WithLBSticky("test")),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr"), th.WithStickiness("test")),
|
||||
),
|
||||
)
|
||||
},
|
||||
|
@ -309,134 +313,17 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
entryPointsConfig := map[string]EntryPoint{
|
||||
"http": {Configuration: &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}}},
|
||||
}
|
||||
dynamicConfigs := types.Configurations{"config": test.config(testServer.URL)}
|
||||
dynamicConfigs := config.Configurations{"config": test.config(testServer.URL)}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPointsConfig)
|
||||
entryPoints := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
entryPoints, _ := srv.loadConfig(dynamicConfigs, globalConfig)
|
||||
|
||||
responseRecorder := &httptest.ResponseRecorder{}
|
||||
request := httptest.NewRequest(http.MethodGet, testServer.URL+requestPath, nil)
|
||||
|
||||
entryPoints["http"].httpRouter.ServeHTTP(responseRecorder, request)
|
||||
entryPoints["http"].ServeHTTP(responseRecorder, request)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, responseRecorder.Result().StatusCode, "status code")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockContext struct {
|
||||
headers http.Header
|
||||
}
|
||||
|
||||
func (c mockContext) Deadline() (deadline time.Time, ok bool) {
|
||||
return deadline, ok
|
||||
}
|
||||
|
||||
func (c mockContext) Done() <-chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
close(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
func (c mockContext) Err() error {
|
||||
return context.DeadlineExceeded
|
||||
}
|
||||
|
||||
func (c mockContext) Value(key interface{}) interface{} {
|
||||
return c.headers
|
||||
}
|
||||
|
||||
func TestNewServerWithResponseModifiers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
headerMiddleware *middlewares.HeaderStruct
|
||||
secureMiddleware *secure.Secure
|
||||
ctx context.Context
|
||||
expected map[string]string
|
||||
}{
|
||||
{
|
||||
desc: "header and secure nil",
|
||||
headerMiddleware: nil,
|
||||
secureMiddleware: nil,
|
||||
ctx: mockContext{},
|
||||
expected: map[string]string{
|
||||
"X-Default": "powpow",
|
||||
"Referrer-Policy": "same-origin",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "header middleware not nil",
|
||||
headerMiddleware: middlewares.NewHeaderFromStruct(&types.Headers{
|
||||
CustomResponseHeaders: map[string]string{
|
||||
"X-Default": "powpow",
|
||||
},
|
||||
}),
|
||||
secureMiddleware: nil,
|
||||
ctx: mockContext{},
|
||||
expected: map[string]string{
|
||||
"X-Default": "powpow",
|
||||
"Referrer-Policy": "same-origin",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "secure middleware not nil",
|
||||
headerMiddleware: nil,
|
||||
secureMiddleware: middlewares.NewSecure(&types.Headers{
|
||||
ReferrerPolicy: "no-referrer",
|
||||
}),
|
||||
ctx: mockContext{
|
||||
headers: http.Header{"Referrer-Policy": []string{"no-referrer"}},
|
||||
},
|
||||
expected: map[string]string{
|
||||
"X-Default": "powpow",
|
||||
"Referrer-Policy": "no-referrer",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "header and secure middleware not nil",
|
||||
headerMiddleware: middlewares.NewHeaderFromStruct(&types.Headers{
|
||||
CustomResponseHeaders: map[string]string{
|
||||
"Referrer-Policy": "powpow",
|
||||
},
|
||||
}),
|
||||
secureMiddleware: middlewares.NewSecure(&types.Headers{
|
||||
ReferrerPolicy: "no-referrer",
|
||||
}),
|
||||
ctx: mockContext{
|
||||
headers: http.Header{"Referrer-Policy": []string{"no-referrer"}},
|
||||
},
|
||||
expected: map[string]string{
|
||||
"X-Default": "powpow",
|
||||
"Referrer-Policy": "powpow",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Add("X-Default", "powpow")
|
||||
headers.Add("Referrer-Policy", "same-origin")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||
|
||||
res := &http.Response{
|
||||
Request: req.WithContext(test.ctx),
|
||||
Header: headers,
|
||||
}
|
||||
|
||||
responseModifier := buildModifyResponse(test.secureMiddleware, test.headerMiddleware)
|
||||
err := responseModifier(res)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(test.expected), len(res.Header))
|
||||
|
||||
for k, v := range test.expected {
|
||||
assert.Equal(t, v, res.Header.Get(k))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package server
|
||||
package service
|
||||
|
||||
import "sync"
|
||||
|
280
server/service/service.go
Normal file
280
server/service/service.go
Normal file
|
@ -0,0 +1,280 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares/emptybackendhandler"
|
||||
"github.com/containous/traefik/old/middlewares/pipelining"
|
||||
"github.com/containous/traefik/server/cookie"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultHealthCheckInterval = 30 * time.Second
|
||||
defaultHealthCheckTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// See oxy/roundrobin/rr.go
|
||||
type balancerHandler interface {
|
||||
Servers() []*url.URL
|
||||
ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
ServerWeight(u *url.URL) (int, bool)
|
||||
RemoveServer(u *url.URL) error
|
||||
UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error
|
||||
NextServer() (*url.URL, error)
|
||||
Next() http.Handler
|
||||
}
|
||||
|
||||
// NewManager creates a new Manager
|
||||
func NewManager(configs map[string]*config.Service, defaultRoundTripper http.RoundTripper) *Manager {
|
||||
return &Manager{
|
||||
bufferPool: newBufferPool(),
|
||||
defaultRoundTripper: defaultRoundTripper,
|
||||
balancers: make(map[string][]healthcheck.BalancerHandler),
|
||||
configs: configs,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager The service manager
|
||||
type Manager struct {
|
||||
bufferPool httputil.BufferPool
|
||||
defaultRoundTripper http.RoundTripper
|
||||
balancers map[string][]healthcheck.BalancerHandler
|
||||
configs map[string]*config.Service
|
||||
}
|
||||
|
||||
// Build Creates a http.Handler for a service configuration.
|
||||
func (m *Manager) Build(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
|
||||
|
||||
// TODO refactor ?
|
||||
if conf, ok := m.configs[serviceName]; ok {
|
||||
// FIXME Should handle multiple service types
|
||||
return m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier)
|
||||
}
|
||||
return nil, fmt.Errorf("the service %q does not exits", serviceName)
|
||||
}
|
||||
|
||||
func (m *Manager) getLoadBalancerServiceHandler(
|
||||
ctx context.Context,
|
||||
serviceName string,
|
||||
service *config.LoadBalancerService,
|
||||
responseModifier func(*http.Response) error,
|
||||
) (http.Handler, error) {
|
||||
|
||||
fwd, err := m.buildForwarder(service.PassHostHeader, service.ResponseForwarding, responseModifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fwd = pipelining.NewPipelining(fwd)
|
||||
|
||||
rr, err := roundrobin.New(fwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
balancer, err := m.getLoadBalancer(ctx, serviceName, service, fwd, rr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO rename and checks
|
||||
m.balancers[serviceName] = append(m.balancers[serviceName], balancer)
|
||||
|
||||
// Empty (backend with no servers)
|
||||
return emptybackendhandler.New(balancer), nil
|
||||
}
|
||||
|
||||
// LaunchHealthCheck Launches the health checks.
|
||||
func (m *Manager) LaunchHealthCheck() {
|
||||
backendConfigs := make(map[string]*healthcheck.BackendConfig)
|
||||
|
||||
for serviceName, balancers := range m.balancers {
|
||||
ctx := log.With(context.Background(), log.Str(log.ServiceName, serviceName))
|
||||
|
||||
// FIXME aggregate
|
||||
balancer := balancers[0]
|
||||
|
||||
// FIXME Should all the services handle healthcheck? Handle different types
|
||||
service := m.configs[serviceName].LoadBalancer
|
||||
|
||||
// Health Check
|
||||
var backendHealthCheck *healthcheck.BackendConfig
|
||||
if hcOpts := buildHealthCheckOptions(ctx, balancer, serviceName, service.HealthCheck); hcOpts != nil {
|
||||
log.FromContext(ctx).Debugf("Setting up healthcheck for service %s with %s", serviceName, *hcOpts)
|
||||
|
||||
hcOpts.Transport = m.defaultRoundTripper
|
||||
backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, serviceName)
|
||||
}
|
||||
|
||||
if backendHealthCheck != nil {
|
||||
backendConfigs[serviceName] = backendHealthCheck
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME metrics and context
|
||||
healthcheck.GetHealthCheck().SetBackendsConfiguration(context.TODO(), backendConfigs)
|
||||
}
|
||||
|
||||
func buildHealthCheckOptions(ctx context.Context, lb healthcheck.BalancerHandler, backend string, hc *config.HealthCheck) *healthcheck.Options {
|
||||
if hc == nil || hc.Path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
interval := defaultHealthCheckInterval
|
||||
if hc.Interval != "" {
|
||||
intervalOverride, err := time.ParseDuration(hc.Interval)
|
||||
if err != nil {
|
||||
logger.Errorf("Illegal health check interval for '%s': %s", backend, err)
|
||||
} else if intervalOverride <= 0 {
|
||||
logger.Errorf("Health check interval smaller than zero for service '%s'", backend)
|
||||
} else {
|
||||
interval = intervalOverride
|
||||
}
|
||||
}
|
||||
|
||||
timeout := defaultHealthCheckTimeout
|
||||
if hc.Timeout != "" {
|
||||
timeoutOverride, err := time.ParseDuration(hc.Timeout)
|
||||
if err != nil {
|
||||
logger.Errorf("Illegal health check timeout for backend '%s': %s", backend, err)
|
||||
} else if timeoutOverride <= 0 {
|
||||
logger.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend)
|
||||
} else {
|
||||
timeout = timeoutOverride
|
||||
}
|
||||
}
|
||||
|
||||
if timeout >= interval {
|
||||
logger.Warnf("Health check timeout for backend '%s' should be lower than the health check interval. Interval set to timeout + 1 second (%s).", backend)
|
||||
}
|
||||
|
||||
return &healthcheck.Options{
|
||||
Scheme: hc.Scheme,
|
||||
Path: hc.Path,
|
||||
Port: hc.Port,
|
||||
Interval: interval,
|
||||
Timeout: timeout,
|
||||
LB: lb,
|
||||
Hostname: hc.Hostname,
|
||||
Headers: hc.Headers,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, service *config.LoadBalancerService, fwd http.Handler, rr balancerHandler) (healthcheck.BalancerHandler, error) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
var stickySession *roundrobin.StickySession
|
||||
var cookieName string
|
||||
if stickiness := service.Stickiness; stickiness != nil {
|
||||
cookieName = cookie.GetName(stickiness.CookieName, serviceName)
|
||||
stickySession = roundrobin.NewStickySession(cookieName)
|
||||
}
|
||||
|
||||
var lb healthcheck.BalancerHandler
|
||||
var err error
|
||||
|
||||
if service.Method == "drr" {
|
||||
logger.Debug("Creating drr load-balancer")
|
||||
|
||||
if stickySession != nil {
|
||||
logger.Debugf("Sticky session cookie name: %v", cookieName)
|
||||
|
||||
lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
lb, err = roundrobin.NewRebalancer(rr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if service.Method != "wrr" {
|
||||
logger.Warnf("Invalid load-balancing method %q, fallback to 'wrr' method", service.Method)
|
||||
}
|
||||
|
||||
logger.Debug("Creating wrr load-balancer")
|
||||
|
||||
if stickySession != nil {
|
||||
logger.Debugf("Sticky session cookie name: %v", cookieName)
|
||||
|
||||
lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
lb = rr
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.upsertServers(ctx, lb, service.Servers); err != nil {
|
||||
return nil, fmt.Errorf("error configuring load balancer for service %s: %v", serviceName, err)
|
||||
}
|
||||
|
||||
return lb, nil
|
||||
}
|
||||
|
||||
func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []config.Server) error {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
for name, srv := range servers {
|
||||
u, err := url.Parse(srv.URL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err)
|
||||
}
|
||||
|
||||
logger.WithField(log.ServerName, name).Debugf("Creating server %d at %s with weight %d", name, u, srv.Weight)
|
||||
|
||||
if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil {
|
||||
return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err)
|
||||
}
|
||||
|
||||
// FIXME Handle Metrics
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildForwarder(passHostHeader bool, responseForwarding *config.ResponseForwarding, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
|
||||
var flushInterval parse.Duration
|
||||
if responseForwarding != nil {
|
||||
err := flushInterval.Set(responseForwarding.FlushInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating flush interval: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return forward.New(
|
||||
forward.Stream(true),
|
||||
forward.PassHostHeader(passHostHeader),
|
||||
forward.RoundTripper(m.defaultRoundTripper),
|
||||
forward.ResponseModifier(responseModifier),
|
||||
forward.BufferPool(m.bufferPool),
|
||||
forward.StreamingFlushInterval(time.Duration(flushInterval)),
|
||||
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
|
||||
server := req.Context().Value(http.ServerContextKey).(*http.Server)
|
||||
if server != nil {
|
||||
connState := server.ConnState
|
||||
if connState != nil {
|
||||
connState(conn, http.StateClosed)
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
327
server/service/service_test.go
Normal file
327
server/service/service_test.go
Normal file
|
@ -0,0 +1,327 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
type MockRR struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (*MockRR) Servers() []*url.URL {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (*MockRR) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (*MockRR) ServerWeight(u *url.URL) (int, bool) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (*MockRR) RemoveServer(u *url.URL) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *MockRR) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
|
||||
return m.err
|
||||
}
|
||||
|
||||
func (*MockRR) NextServer() (*url.URL, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (*MockRR) Next() http.Handler {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
type MockForwarder struct{}
|
||||
|
||||
func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestGetLoadBalancer(t *testing.T) {
|
||||
sm := Manager{}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
serviceName string
|
||||
service *config.LoadBalancerService
|
||||
fwd http.Handler
|
||||
rr balancerHandler
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
desc: "Fails when provided an invalid URL",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: ":",
|
||||
Weight: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
fwd: &MockForwarder{},
|
||||
rr: &MockRR{},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
desc: "Fails when the server upsert fails",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: "http://foo",
|
||||
Weight: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
fwd: &MockForwarder{},
|
||||
rr: &MockRR{err: errors.New("upsert fails")},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
desc: "Succeeds when there are no servers",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{},
|
||||
fwd: &MockForwarder{},
|
||||
rr: &MockRR{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
desc: "Succeeds when stickiness is set",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Stickiness: &config.Stickiness{},
|
||||
},
|
||||
fwd: &MockForwarder{},
|
||||
rr: &MockRR{},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := sm.getLoadBalancer(context.Background(), test.serviceName, test.service, test.fwd, test.rr)
|
||||
if test.expectError {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, handler)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, handler)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLoadBalancerServiceHandler(t *testing.T) {
|
||||
sm := NewManager(nil, http.DefaultTransport)
|
||||
|
||||
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", "first")
|
||||
}))
|
||||
defer server1.Close()
|
||||
|
||||
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", "second")
|
||||
}))
|
||||
defer server2.Close()
|
||||
|
||||
serverPassHost := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", "passhost")
|
||||
assert.Equal(t, "callme", r.Host)
|
||||
}))
|
||||
defer serverPassHost.Close()
|
||||
|
||||
serverPassHostFalse := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", "passhostfalse")
|
||||
assert.NotEqual(t, "callme", r.Host)
|
||||
}))
|
||||
defer serverPassHostFalse.Close()
|
||||
|
||||
type ExpectedResult struct {
|
||||
StatusCode int
|
||||
XFrom string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
serviceName string
|
||||
service *config.LoadBalancerService
|
||||
responseModifier func(*http.Response) error
|
||||
|
||||
expected []ExpectedResult
|
||||
}{
|
||||
{
|
||||
desc: "Load balances between the two servers",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server1.URL,
|
||||
Weight: 50,
|
||||
},
|
||||
{
|
||||
URL: server2.URL,
|
||||
Weight: 50,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "first",
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "second",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "StatusBadGateway when the server is not reachable",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: "http://foo",
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "ServiceUnavailable when no servers are available",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Servers: []config.Server{},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Always call the same server when stickiness is true",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Stickiness: &config.Stickiness{},
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server1.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
{
|
||||
URL: server2.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "first",
|
||||
},
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "first",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "PassHost passes the host instead of the IP",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Stickiness: &config.Stickiness{},
|
||||
PassHostHeader: true,
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: serverPassHost.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "passhost",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "PassHost doesn't passe the host instead of the IP",
|
||||
serviceName: "test",
|
||||
service: &config.LoadBalancerService{
|
||||
Stickiness: &config.Stickiness{},
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: serverPassHostFalse.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
expected: []ExpectedResult{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
XFrom: "passhostfalse",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service, test.responseModifier)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, handler)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://callme", nil)
|
||||
for _, expected := range test.expected {
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, expected.StatusCode, recorder.Code)
|
||||
assert.Equal(t, expected.XFrom, recorder.Header().Get("X-From"))
|
||||
|
||||
if len(recorder.Header().Get("Set-Cookie")) > 0 {
|
||||
req.Header.Set("Cookie", recorder.Header().Get("Set-Cookie"))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME Add healthcheck tests
|
Loading…
Add table
Add a link
Reference in a new issue