Remove old global config and use new static config

This commit is contained in:
SALLEYRON Julien 2018-11-27 17:42:04 +01:00 committed by Traefiker Bot
parent c39d21c178
commit 5d91c7e15c
114 changed files with 2485 additions and 3646 deletions

View file

@ -7,9 +7,11 @@ import (
"net/http"
"time"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/log"
"github.com/containous/traefik/old/configuration"
traefiktls "github.com/containous/traefik/tls"
"github.com/pkg/errors"
"golang.org/x/net/http2"
)
@ -22,26 +24,30 @@ func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, erro
return t.Transport.RoundTrip(req)
}
// createHTTPTransport creates an http.Transport configured with the GlobalConfiguration settings.
// createHTTPTransport creates an http.Transport configured with the Transport configuration settings.
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost
// in Traefik at this point in time. Setting this value to the default of 100 could lead to confusing
// behavior and backwards compatibility issues.
func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration) (*http.Transport, error) {
func createHTTPTransport(transportConfiguration *static.ServersTransport) (*http.Transport, error) {
if transportConfiguration == nil {
return nil, errors.New("no transport configuration given")
}
dialer := &net.Dialer{
Timeout: configuration.DefaultDialTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
if globalConfiguration.ForwardingTimeouts != nil {
dialer.Timeout = time.Duration(globalConfiguration.ForwardingTimeouts.DialTimeout)
if transportConfiguration.ForwardingTimeouts != nil {
dialer.Timeout = time.Duration(transportConfiguration.ForwardingTimeouts.DialTimeout)
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
MaxIdleConnsPerHost: globalConfiguration.MaxIdleConnsPerHost,
MaxIdleConnsPerHost: transportConfiguration.MaxIdleConnsPerHost,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
@ -56,17 +62,17 @@ func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration)
},
})
if globalConfiguration.ForwardingTimeouts != nil {
transport.ResponseHeaderTimeout = time.Duration(globalConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
if transportConfiguration.ForwardingTimeouts != nil {
transport.ResponseHeaderTimeout = time.Duration(transportConfiguration.ForwardingTimeouts.ResponseHeaderTimeout)
}
if globalConfiguration.InsecureSkipVerify {
if transportConfiguration.InsecureSkipVerify {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
if len(globalConfiguration.RootCAs) > 0 {
if len(transportConfiguration.RootCAs) > 0 {
transport.TLSClientConfig = &tls.Config{
RootCAs: createRootCACertPool(globalConfiguration.RootCAs),
RootCAs: createRootCACertPool(transportConfiguration.RootCAs),
}
}

View file

@ -38,6 +38,7 @@ func NewRouteAppenderAggregator(ctx context.Context, chainBuilder chainBuilder,
Statistics: conf.API.Statistics,
DashboardAssets: conf.API.DashboardAssets,
CurrentConfigurations: currentConfiguration,
Debug: conf.Global.Debug,
},
routerMiddlewares: chain,
})

View file

@ -39,6 +39,7 @@ func TestNewRouteAppenderAggregator(t *testing.T) {
{
desc: "API with auth, ping without auth",
staticConf: static.Configuration{
Global: &static.Global{},
API: &static.API{
EntryPoint: "traefik",
Middlewares: []string{"dumb"},
@ -46,10 +47,8 @@ func TestNewRouteAppenderAggregator(t *testing.T) {
Ping: &ping.Handler{
EntryPoint: "traefik",
},
EntryPoints: &static.EntryPoints{
EntryPointList: map[string]static.EntryPoint{
"traefik": {},
},
EntryPoints: static.EntryPoints{
"traefik": {},
},
},
middles: map[string]alice.Constructor{
@ -69,13 +68,12 @@ func TestNewRouteAppenderAggregator(t *testing.T) {
{
desc: "Wrong entrypoint name",
staticConf: static.Configuration{
Global: &static.Global{},
API: &static.API{
EntryPoint: "no",
},
EntryPoints: &static.EntryPoints{
EntryPointList: map[string]static.EntryPoint{
"traefik": {},
},
EntryPoints: static.EntryPoints{
"traefik": {},
},
},
expected: map[string]int{

View file

@ -44,8 +44,8 @@ type Manager struct {
}
// BuildHandlers Builds handler for all entry points
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string, defaultEntryPoints []string) map[string]http.Handler {
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints, defaultEntryPoints)
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string) map[string]http.Handler {
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints)
entryPointHandlers := make(map[string]http.Handler)
for entryPointName, routers := range entryPointsRouters {
@ -73,13 +73,13 @@ func contains(entryPoints []string, entryPointName string) bool {
return false
}
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string, defaultEntryPoints []string) map[string]map[string]*config.Router {
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string) map[string]map[string]*config.Router {
entryPointsRouters := make(map[string]map[string]*config.Router)
for rtName, rt := range m.configs {
eps := rt.EntryPoints
if len(eps) == 0 {
eps = defaultEntryPoints
eps = entryPoints
}
for _, entryPointName := range eps {
if !contains(entryPoints, entryPointName) {

View file

@ -27,13 +27,12 @@ func TestRouterManager_Get(t *testing.T) {
}
testCases := []struct {
desc string
routersConfig map[string]*config.Router
serviceConfig map[string]*config.Service
middlewaresConfig map[string]*config.Middleware
entryPoints []string
defaultEntryPoints []string
expected ExpectedResult
desc string
routersConfig map[string]*config.Router
serviceConfig map[string]*config.Service
middlewaresConfig map[string]*config.Middleware
entryPoints []string
expected ExpectedResult
}{
{
desc: "no middleware",
@ -81,9 +80,8 @@ func TestRouterManager_Get(t *testing.T) {
},
},
},
entryPoints: []string{"web"},
defaultEntryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusOK},
entryPoints: []string{"web"},
expected: ExpectedResult{StatusCode: http.StatusOK},
},
{
desc: "no middleware, no matching",
@ -209,7 +207,7 @@ func TestRouterManager_Get(t *testing.T) {
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, test.defaultEntryPoints)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
@ -309,7 +307,7 @@ func TestAccessLog(t *testing.T) {
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, test.defaultEntryPoints)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)

View file

@ -2,116 +2,49 @@ package server
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
stdlog "log"
"net"
"net/http"
"os"
"os/signal"
"sync"
"time"
"github.com/armon/go-proxyproto"
"github.com/containous/traefik/cluster"
"github.com/containous/traefik/config"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/h2c"
"github.com/containous/traefik/ip"
"github.com/containous/traefik/log"
"github.com/containous/traefik/metrics"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/middlewares/requestdecorator"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/provider"
"github.com/containous/traefik/safe"
"github.com/containous/traefik/server/middleware"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/tracing"
"github.com/containous/traefik/tracing/datadog"
"github.com/containous/traefik/tracing/jaeger"
"github.com/containous/traefik/tracing/zipkin"
"github.com/containous/traefik/types"
"github.com/sirupsen/logrus"
"github.com/xenolf/lego/acme"
)
func newHijackConnectionTracker() *hijackConnectionTracker {
return &hijackConnectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type hijackConnectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddHijackedConnection add a connection in the tracked connections list
func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
h.conns[conn] = struct{}{}
}
// RemoveHijackedConnection remove a connection from the tracked connections list
func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
delete(h.conns, conn)
}
// Shutdown wait for the connection closing
func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
h.lock.RLock()
if len(h.conns) == 0 {
return nil
}
h.lock.RUnlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list
func (h *hijackConnectionTracker) Close() {
for conn := range h.conns {
if err := conn.Close(); err != nil {
log.WithoutContext().Errorf("Error while closing Hijacked connection: %v", err)
}
delete(h.conns, conn)
}
}
// Server is the reverse-proxy/load-balancer engine
type Server struct {
serverEntryPoints serverEntryPoints
entryPoints EntryPoints
configurationChan chan config.Message
configurationValidatedChan chan config.Message
signals chan os.Signal
stopChan chan bool
currentConfigurations safe.Safe
providerConfigUpdateMap map[string]chan config.Message
globalConfiguration configuration.GlobalConfiguration
accessLoggerMiddleware *accesslog.Handler
tracer *tracing.Tracing
routinesPool *safe.Pool
leadership *cluster.Leadership
leadership *cluster.Leadership //FIXME Cluster
defaultRoundTripper http.RoundTripper
metricsRegistry metrics.Registry
provider provider.Provider
configurationListeners []func(config.Configuration)
entryPoints map[string]EntryPoint
requestDecorator *requestdecorator.RequestDecorator
providersThrottleDuration time.Duration
}
// RouteAppenderFactory the route appender factory interface
@ -119,86 +52,6 @@ type RouteAppenderFactory interface {
NewAppender(ctx context.Context, middlewaresBuilder *middleware.Builder, currentConfigurations *safe.Safe) types.RouteAppender
}
// EntryPoint entryPoint information (configuration + internalRouter)
type EntryPoint struct {
RouteAppenderFactory RouteAppenderFactory
Configuration *configuration.EntryPoint
OnDemandListener func(string) (*tls.Certificate, error)
TLSALPNGetter func(string) (*tls.Certificate, error)
CertificateStore *traefiktls.CertificateStore
}
type serverEntryPoints map[string]*serverEntryPoint
type serverEntryPoint struct {
httpServer *h2c.Server
listener net.Listener
httpRouter *middlewares.HandlerSwitcher
certs *traefiktls.CertificateStore
onDemandListener func(string) (*tls.Certificate, error)
tlsALPNGetter func(string) (*tls.Certificate, error)
hijackConnectionTracker *hijackConnectionTracker
}
func (s serverEntryPoint) Shutdown(ctx context.Context) {
var wg sync.WaitGroup
if s.httpServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.httpServer.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger := log.FromContext(ctx)
logger.Debugf("Wait server shutdown is over due to: %s", err)
err = s.httpServer.Close()
if err != nil {
logger.Error(err)
}
}
}
}()
}
if s.hijackConnectionTracker != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger := log.FromContext(ctx)
logger.Debugf("Wait hijack connection is over due to: %s", err)
s.hijackConnectionTracker.Close()
}
}
}()
}
wg.Wait()
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
if err = tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
return nil, err
}
return tc, nil
}
func setupTracing(conf *static.Tracing) tracing.TrackingBackend {
switch conf.Backend {
case jaeger.Name:
@ -214,13 +67,11 @@ func setupTracing(conf *static.Tracing) tracing.TrackingBackend {
}
// NewServer returns an initialized Server.
func NewServer(globalConfiguration configuration.GlobalConfiguration, provider provider.Provider, entrypoints map[string]EntryPoint) *Server {
func NewServer(staticConfiguration static.Configuration, provider provider.Provider, entryPoints EntryPoints) *Server {
server := &Server{}
server.entryPoints = entrypoints
server.provider = provider
server.globalConfiguration = globalConfiguration
server.serverEntryPoints = make(map[string]*serverEntryPoint)
server.entryPoints = entryPoints
server.configurationChan = make(chan config.Message, 100)
server.configurationValidatedChan = make(chan config.Message, 100)
server.signals = make(chan os.Signal, 1)
@ -230,41 +81,36 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p
server.currentConfigurations.Set(currentConfigurations)
server.providerConfigUpdateMap = make(map[string]chan config.Message)
transport, err := createHTTPTransport(globalConfiguration)
if staticConfiguration.Providers != nil {
server.providersThrottleDuration = time.Duration(staticConfiguration.Providers.ProvidersThrottleDuration)
}
transport, err := createHTTPTransport(staticConfiguration.ServersTransport)
if err != nil {
log.WithoutContext().Error(err)
log.WithoutContext().Errorf("Could not configure HTTP Transport, fallbacking on default transport: %v", err)
server.defaultRoundTripper = http.DefaultTransport
} else {
server.defaultRoundTripper = transport
}
if server.globalConfiguration.API != nil {
server.globalConfiguration.API.CurrentConfigurations = &server.currentConfigurations
}
server.routinesPool = safe.NewPool(context.Background())
if globalConfiguration.Tracing != nil {
trackingBackend := setupTracing(static.ConvertTracing(globalConfiguration.Tracing))
if staticConfiguration.Tracing != nil {
trackingBackend := setupTracing(staticConfiguration.Tracing)
var err error
server.tracer, err = tracing.NewTracing(globalConfiguration.Tracing.ServiceName, globalConfiguration.Tracing.SpanNameLimit, trackingBackend)
server.tracer, err = tracing.NewTracing(staticConfiguration.Tracing.ServiceName, staticConfiguration.Tracing.SpanNameLimit, trackingBackend)
if err != nil {
log.WithoutContext().Warnf("Unable to create tracer: %v", err)
}
}
server.requestDecorator = requestdecorator.New(static.ConvertHostResolverConfig(globalConfiguration.HostResolver))
server.requestDecorator = requestdecorator.New(staticConfiguration.HostResolver)
server.metricsRegistry = registerMetricClients(static.ConvertMetrics(globalConfiguration.Metrics))
server.metricsRegistry = registerMetricClients(staticConfiguration.Metrics)
if globalConfiguration.Cluster != nil {
// leadership creation if cluster mode
server.leadership = cluster.NewLeadership(server.routinesPool.Ctx(), globalConfiguration.Cluster)
}
if globalConfiguration.AccessLog != nil {
if staticConfiguration.AccessLog != nil {
var err error
server.accessLoggerMiddleware, err = accesslog.NewHandler(static.ConvertAccessLog(globalConfiguration.AccessLog))
server.accessLoggerMiddleware, err = accesslog.NewHandler(staticConfiguration.AccessLog)
if err != nil {
log.WithoutContext().Warnf("Unable to create access logger : %v", err)
}
@ -272,8 +118,17 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration, provider p
return server
}
// Start starts the server.
func (s *Server) Start() {
// Start starts the server and Stop/Close it when context is Done
func (s *Server) Start(ctx context.Context) {
go func() {
defer s.Close()
<-ctx.Done()
logger := log.FromContext(ctx)
logger.Info("I have to go...")
logger.Info("Stopping server gracefully")
s.Stop()
}()
s.startHTTPServers()
s.startLeadership()
s.routinesPool.Go(func(stop chan bool) {
@ -288,26 +143,6 @@ func (s *Server) Start() {
})
}
// StartWithContext starts the server and Stop/Close it when context is Done
func (s *Server) StartWithContext(ctx context.Context) {
go func() {
defer s.Close()
<-ctx.Done()
logger := log.FromContext(ctx)
logger.Info("I have to go...")
reqAcceptGraceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.RequestAcceptGraceTimeout)
if reqAcceptGraceTimeOut > 0 {
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
time.Sleep(reqAcceptGraceTimeOut)
}
logger.Info("Stopping server gracefully")
s.Stop()
}()
s.Start()
}
// Wait blocks until server is shutted down.
func (s *Server) Wait() {
<-s.stopChan
@ -318,21 +153,16 @@ func (s *Server) Stop() {
defer log.WithoutContext().Info("Server stopped")
var wg sync.WaitGroup
for sepn, sep := range s.serverEntryPoints {
for epn, ep := range s.entryPoints {
wg.Add(1)
go func(serverEntryPointName string, serverEntryPoint *serverEntryPoint) {
go func(entryPointName string, entryPoint *EntryPoint) {
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
defer wg.Done()
logger := log.WithoutContext().WithField(log.EntryPointName, serverEntryPointName)
graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
logger.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
entryPoint.Shutdown(ctx)
serverEntryPoint.Shutdown(ctx)
cancel()
logger.Debugf("Entry point %s closed", serverEntryPointName)
}(sepn, sep)
log.FromContext(ctx).Debugf("Entry point %s closed", entryPointName)
}(epn, ep)
}
wg.Wait()
s.stopChan <- true
@ -385,12 +215,15 @@ func (s *Server) stopLeadership() {
}
func (s *Server) startHTTPServers() {
s.serverEntryPoints = s.buildServerEntryPoints()
// Use an empty configuration in order to initialize the default handlers with internal routes
handlers := s.applyConfiguration(context.Background(), config.Configuration{})
for entryPointName, handler := range handlers {
s.entryPoints[entryPointName].httpRouter.UpdateHandler(handler)
}
for newServerEntryPointName, newServerEntryPoint := range s.serverEntryPoints {
ctx := log.With(context.Background(), log.Str(log.EntryPointName, newServerEntryPointName))
serverEntryPoint := s.setupServerEntryPoint(ctx, newServerEntryPointName, newServerEntryPoint)
go s.startServer(ctx, serverEntryPoint)
for entryPointName, entryPoint := range s.entryPoints {
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
go entryPoint.Start(ctx)
}
}
@ -416,39 +249,6 @@ func (s *Server) AddListener(listener func(config.Configuration)) {
s.configurationListeners = append(s.configurationListeners, listener)
}
// getCertificate allows to customize tlsConfig.GetCertificate behavior to get the certificates inserted dynamically
func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domainToCheck := types.CanonicalDomain(clientHello.ServerName)
if s.tlsALPNGetter != nil {
cert, err := s.tlsALPNGetter(domainToCheck)
if err != nil {
return nil, err
}
if cert != nil {
return cert, nil
}
}
bestCertificate := s.certs.GetBestCertificate(clientHello)
if bestCertificate != nil {
return bestCertificate, nil
}
if s.onDemandListener != nil && len(domainToCheck) > 0 {
// Only check for an onDemandCert if there is a domain name
return s.onDemandListener(domainToCheck)
}
if s.certs.SniStrict {
return nil, fmt.Errorf("strict SNI enabled - No certificate found for domain: %q, closing connection", domainToCheck)
}
log.WithoutContext().Debugf("Serving default certificate for request: %q", domainToCheck)
return s.certs.DefaultCertificate, nil
}
func (s *Server) startProvider() {
jsonConf, err := json.Marshal(s.provider)
if err != nil {
@ -466,229 +266,6 @@ func (s *Server) startProvider() {
})
}
// creates a TLS config that allows terminating HTTPS for multiple domains using SNI
func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TLS, router *middlewares.HandlerSwitcher) (*tls.Config, error) {
if tlsOption == nil {
return nil, nil
}
conf, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
if err != nil {
return nil, err
}
s.serverEntryPoints[entryPointName].certs.DynamicCerts.Set(make(map[string]*tls.Certificate))
// ensure http2 enabled
conf.NextProtos = []string{"h2", "http/1.1", acme.ACMETLS1Protocol}
if len(tlsOption.ClientCA.Files) > 0 {
pool := x509.NewCertPool()
for _, caFile := range tlsOption.ClientCA.Files {
data, err := caFile.Read()
if err != nil {
return nil, err
}
ok := pool.AppendCertsFromPEM(data)
if !ok {
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
}
}
conf.ClientCAs = pool
if tlsOption.ClientCA.Optional {
conf.ClientAuth = tls.VerifyClientCertIfGiven
} else {
conf.ClientAuth = tls.RequireAndVerifyClientCert
}
}
// FIXME onDemand
if s.globalConfiguration.ACME != nil {
// if entryPointName == s.globalConfiguration.ACME.EntryPoint {
// checkOnDemandDomain := func(domain string) bool {
// routeMatch := &mux.RouteMatch{}
// match := router.GetHandler().Match(&http.Request{URL: &url.URL{}, Host: domain}, routeMatch)
// if match && routeMatch.Route != nil {
// return true
// }
// return false
// }
//
// err := s.globalConfiguration.ACME.CreateClusterConfig(s.leadership, config, s.serverEntryPoints[entryPointName].certs.DynamicCerts, checkOnDemandDomain)
// if err != nil {
// return nil, err
// }
// }
} else {
conf.GetCertificate = s.serverEntryPoints[entryPointName].getCertificate
}
if len(conf.Certificates) != 0 {
certMap := s.buildNameOrIPToCertificate(conf.Certificates)
if s.entryPoints[entryPointName].CertificateStore != nil {
s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(certMap)
}
}
// Remove certs from the TLS config object
conf.Certificates = []tls.Certificate{}
// Set the minimum TLS version if set in the config TOML
if minConst, exists := traefiktls.MinVersion[tlsOption.MinVersion]; exists {
conf.PreferServerCipherSuites = true
conf.MinVersion = minConst
}
// Set the list of CipherSuites if set in the config TOML
if tlsOption.CipherSuites != nil {
// if our list of CipherSuites is defined in the entryPoint config, we can re-initialize the suites list as empty
conf.CipherSuites = make([]uint16, 0)
for _, cipher := range tlsOption.CipherSuites {
if cipherConst, exists := traefiktls.CipherSuites[cipher]; exists {
conf.CipherSuites = append(conf.CipherSuites, cipherConst)
} else {
// CipherSuite listed in the toml does not exist in our listed
return nil, fmt.Errorf("invalid CipherSuite: %s", cipher)
}
}
}
return conf, nil
}
func (s *Server) startServer(ctx context.Context, serverEntryPoint *serverEntryPoint) {
logger := log.FromContext(ctx)
logger.Infof("Starting server on %s", serverEntryPoint.httpServer.Addr)
var err error
if serverEntryPoint.httpServer.TLSConfig != nil {
err = serverEntryPoint.httpServer.ServeTLS(serverEntryPoint.listener, "", "")
} else {
err = serverEntryPoint.httpServer.Serve(serverEntryPoint.listener)
}
if err != http.ErrServerClosed {
logger.Error("Cannot create server: %v", err)
}
}
func (s *Server) setupServerEntryPoint(ctx context.Context, newServerEntryPointName string, newServerEntryPoint *serverEntryPoint) *serverEntryPoint {
newSrv, listener, err := s.prepareServer(ctx, newServerEntryPointName, s.entryPoints[newServerEntryPointName].Configuration, newServerEntryPoint.httpRouter)
if err != nil {
log.FromContext(ctx).Fatalf("Error preparing server: %v", err)
}
serverEntryPoint := s.serverEntryPoints[newServerEntryPointName]
serverEntryPoint.httpServer = newSrv
serverEntryPoint.listener = listener
serverEntryPoint.hijackConnectionTracker = newHijackConnectionTracker()
serverEntryPoint.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateHijacked:
serverEntryPoint.hijackConnectionTracker.AddHijackedConnection(conn)
case http.StateClosed:
serverEntryPoint.hijackConnectionTracker.RemoveHijackedConnection(conn)
}
}
return serverEntryPoint
}
func (s *Server) prepareServer(ctx context.Context, entryPointName string, entryPoint *configuration.EntryPoint, router *middlewares.HandlerSwitcher) (*h2c.Server, net.Listener, error) {
logger := log.FromContext(ctx)
readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(s.globalConfiguration)
logger.
WithField("readTimeout", readTimeout).
WithField("writeTimeout", writeTimeout).
WithField("idleTimeout", idleTimeout).
Infof("Preparing server %+v", entryPoint)
tlsConfig, err := s.createTLSConfig(entryPointName, entryPoint.TLS, router)
if err != nil {
return nil, nil, fmt.Errorf("error creating TLS config: %v", err)
}
listener, err := net.Listen("tcp", entryPoint.Address)
if err != nil {
return nil, nil, fmt.Errorf("error opening listener: %v", err)
}
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
if entryPoint.ProxyProtocol != nil {
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
if err != nil {
return nil, nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
}
}
httpServerLogger := stdlog.New(logger.WriterLevel(logrus.DebugLevel), "", 0)
return &h2c.Server{
Server: &http.Server{
Addr: entryPoint.Address,
Handler: router,
TLSConfig: tlsConfig,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
IdleTimeout: idleTimeout,
ErrorLog: httpServerLogger,
},
},
listener,
nil
}
func buildProxyProtocolListener(ctx context.Context, entryPoint *configuration.EntryPoint, listener net.Listener) (net.Listener, error) {
var sourceCheck func(addr net.Addr) (bool, error)
if entryPoint.ProxyProtocol.Insecure {
sourceCheck = func(_ net.Addr) (bool, error) {
return true, nil
}
} else {
checker, err := ip.NewChecker(entryPoint.ProxyProtocol.TrustedIPs)
if err != nil {
return nil, err
}
sourceCheck = func(addr net.Addr) (bool, error) {
ipAddr, ok := addr.(*net.TCPAddr)
if !ok {
return false, fmt.Errorf("type error %v", addr)
}
return checker.ContainsIP(ipAddr.IP), nil
}
}
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
return &proxyproto.Listener{
Listener: listener,
SourceCheck: sourceCheck,
}, nil
}
func buildServerTimeouts(globalConfig configuration.GlobalConfiguration) (readTimeout, writeTimeout, idleTimeout time.Duration) {
readTimeout = time.Duration(0)
writeTimeout = time.Duration(0)
if globalConfig.RespondingTimeouts != nil {
readTimeout = time.Duration(globalConfig.RespondingTimeouts.ReadTimeout)
writeTimeout = time.Duration(globalConfig.RespondingTimeouts.WriteTimeout)
}
if globalConfig.RespondingTimeouts != nil {
idleTimeout = time.Duration(globalConfig.RespondingTimeouts.IdleTimeout)
} else {
idleTimeout = configuration.DefaultIdleTimeout
}
return readTimeout, writeTimeout, idleTimeout
}
func registerMetricClients(metricsConfig *types.Metrics) metrics.Registry {
if metricsConfig == nil {
return metrics.NewVoidRegistry()
@ -734,24 +311,3 @@ func stopMetricsClients() {
metrics.StopStatsd()
metrics.StopInfluxDB()
}
func (s *Server) buildNameOrIPToCertificate(certs []tls.Certificate) map[string]*tls.Certificate {
certMap := make(map[string]*tls.Certificate)
for i := range certs {
cert := &certs[i]
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
continue
}
if len(x509Cert.Subject.CommonName) > 0 {
certMap[x509Cert.Subject.CommonName] = cert
}
for _, san := range x509Cert.DNSNames {
certMap[san] = cert
}
for _, ipSan := range x509Cert.IPAddresses {
certMap[ipSan.String()] = cert
}
}
return certMap
}

View file

@ -12,19 +12,15 @@ import (
"github.com/containous/alice"
"github.com/containous/mux"
"github.com/containous/traefik/config"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/middlewares/requestdecorator"
"github.com/containous/traefik/middlewares/tracing"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/responsemodifiers"
"github.com/containous/traefik/server/middleware"
"github.com/containous/traefik/server/router"
"github.com/containous/traefik/server/service"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/tls/generate"
"github.com/eapache/channels"
"github.com/sirupsen/logrus"
)
@ -44,25 +40,25 @@ func (s *Server) loadConfiguration(configMsg config.Message) {
s.metricsRegistry.ConfigReloadsCounter().Add(1)
handlers, certificates := s.loadConfig(newConfigurations, s.globalConfiguration)
handlers, certificates := s.loadConfig(newConfigurations)
s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
for entryPointName, handler := range handlers {
s.serverEntryPoints[entryPointName].httpRouter.UpdateHandler(handler)
s.entryPoints[entryPointName].httpRouter.UpdateHandler(handler)
}
for entryPointName, serverEntryPoint := range s.serverEntryPoints {
for entryPointName, entryPoint := range s.entryPoints {
eLogger := logger.WithField(log.EntryPointName, entryPointName)
if s.entryPoints[entryPointName].Configuration.TLS == nil {
if entryPoint.Certs == nil {
if len(certificates[entryPointName]) > 0 {
eLogger.Debugf("Cannot configure certificates for the non-TLS %s entryPoint.", entryPointName)
}
} else {
serverEntryPoint.certs.DynamicCerts.Set(certificates[entryPointName])
serverEntryPoint.certs.ResetCache()
entryPoint.Certs.DynamicCerts.Set(certificates[entryPointName])
entryPoint.Certs.ResetCache()
}
eLogger.Infof("Server configuration reloaded on %s", s.serverEntryPoints[entryPointName].httpServer.Addr)
eLogger.Infof("Server configuration reloaded on %s", s.entryPoints[entryPointName].httpServer.Addr)
}
s.currentConfigurations.Set(newConfigurations)
@ -76,7 +72,7 @@ func (s *Server) loadConfiguration(configMsg config.Message) {
// loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic
// provider configurations.
func (s *Server) loadConfig(configurations config.Configurations, globalConfiguration configuration.GlobalConfiguration) (map[string]http.Handler, map[string]map[string]*tls.Certificate) {
func (s *Server) loadConfig(configurations config.Configurations) (map[string]http.Handler, map[string]map[string]*tls.Certificate) {
ctx := context.TODO()
@ -106,14 +102,12 @@ func (s *Server) loadConfig(configurations config.Configurations, globalConfigur
// Get new certificates list sorted per entry points
// Update certificates
entryPointsCertificates := s.loadHTTPSConfiguration(configurations, globalConfiguration.DefaultEntryPoints)
entryPointsCertificates := s.loadHTTPSConfiguration(configurations)
return handlers, entryPointsCertificates
}
func (s *Server) applyConfiguration(ctx context.Context, configuration config.Configuration) map[string]http.Handler {
staticConfiguration := static.ConvertStaticConf(s.globalConfiguration)
var entryPoints []string
for entryPointName := range s.entryPoints {
entryPoints = append(entryPoints, entryPointName)
@ -125,7 +119,7 @@ func (s *Server) applyConfiguration(ctx context.Context, configuration config.Co
routerManager := router.NewManager(configuration.Routers, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(ctx, entryPoints, staticConfiguration.EntryPoints.Defaults)
handlers := routerManager.BuildHandlers(ctx, entryPoints)
routerHandlers := make(map[string]http.Handler)
@ -145,7 +139,7 @@ func (s *Server) applyConfiguration(ctx context.Context, configuration config.Co
if h, ok := handlers[entryPointName]; ok {
internalMuxRouter.NotFoundHandler = h
} else {
internalMuxRouter.NotFoundHandler = s.buildDefaultHTTPRouter()
internalMuxRouter.NotFoundHandler = buildDefaultHTTPRouter()
}
routerHandlers[entryPointName] = internalMuxRouter
@ -174,7 +168,6 @@ func (s *Server) applyConfiguration(ctx context.Context, configuration config.Co
}
func (s *Server) preLoadConfiguration(configMsg config.Message) {
providersThrottleDuration := time.Duration(s.globalConfiguration.ProvidersThrottleDuration)
s.defaultConfigurationValues(configMsg.Configuration)
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
@ -199,7 +192,7 @@ func (s *Server) preLoadConfiguration(configMsg config.Message) {
providerConfigUpdateCh = make(chan config.Message)
s.providerConfigUpdateMap[configMsg.ProviderName] = providerConfigUpdateCh
s.routinesPool.Go(func(stop chan bool) {
s.throttleProviderConfigReload(providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop)
s.throttleProviderConfigReload(s.providersThrottleDuration, s.configurationValidatedChan, providerConfigUpdateCh, stop)
})
}
@ -263,12 +256,12 @@ func (s *Server) postLoadConfiguration() {
// metrics.OnConfigurationUpdate(activeConfig)
// }
if s.globalConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() {
return
}
// FIXME acme
// if s.globalConfiguration.ACME.OnHostRule {
// if s.staticConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() {
// return
// }
//
// if s.staticConfiguration.ACME.OnHostRule {
// currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
// for _, config := range currentConfigurations {
// for _, frontend := range config.Frontends {
@ -277,7 +270,7 @@ func (s *Server) postLoadConfiguration() {
// // and is configured with ACME
// acmeEnabled := false
// for _, entryPoint := range frontend.EntryPoints {
// if s.globalConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
// if s.staticConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
// acmeEnabled = true
// break
// }
@ -292,7 +285,7 @@ func (s *Server) postLoadConfiguration() {
// } else if len(domains) == 0 {
// log.Debugf("No domain parsed in rule %q", route.Rule)
// } else {
// s.globalConfiguration.ACME.LoadCertificateForDomains(domains)
// s.staticConfiguration.ACME.LoadCertificateForDomains(domains)
// }
// }
// }
@ -302,69 +295,23 @@ func (s *Server) postLoadConfiguration() {
}
// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically
func (s *Server) loadHTTPSConfiguration(configurations config.Configurations, defaultEntryPoints configuration.DefaultEntryPoints) map[string]map[string]*tls.Certificate {
func (s *Server) loadHTTPSConfiguration(configurations config.Configurations) map[string]map[string]*tls.Certificate {
var entryPoints []string
for entryPointName := range s.entryPoints {
entryPoints = append(entryPoints, entryPointName)
}
newEPCertificates := make(map[string]map[string]*tls.Certificate)
// Get all certificates
for _, config := range configurations {
if config.TLS != nil && len(config.TLS) > 0 {
traefiktls.SortTLSPerEntryPoints(config.TLS, newEPCertificates, defaultEntryPoints)
traefiktls.SortTLSPerEntryPoints(config.TLS, newEPCertificates, entryPoints)
}
}
return newEPCertificates
}
func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
serverEntryPoints := make(map[string]*serverEntryPoint)
ctx := context.Background()
handlers := s.applyConfiguration(ctx, config.Configuration{})
for entryPointName, entryPoint := range s.entryPoints {
serverEntryPoints[entryPointName] = &serverEntryPoint{
httpRouter: middlewares.NewHandlerSwitcher(handlers[entryPointName]),
onDemandListener: entryPoint.OnDemandListener,
tlsALPNGetter: entryPoint.TLSALPNGetter,
}
if entryPoint.CertificateStore != nil {
serverEntryPoints[entryPointName].certs = entryPoint.CertificateStore
} else {
serverEntryPoints[entryPointName].certs = traefiktls.NewCertificateStore()
}
if entryPoint.Configuration.TLS != nil {
logger := log.FromContext(ctx).WithField(log.EntryPointName, entryPointName)
serverEntryPoints[entryPointName].certs.SniStrict = entryPoint.Configuration.TLS.SniStrict
if entryPoint.Configuration.TLS.DefaultCertificate != nil {
cert, err := buildDefaultCertificate(entryPoint.Configuration.TLS.DefaultCertificate)
if err != nil {
logger.Error(err)
continue
}
serverEntryPoints[entryPointName].certs.DefaultCertificate = cert
} else {
cert, err := generate.DefaultCertificate()
if err != nil {
logger.Error(err)
continue
}
serverEntryPoints[entryPointName].certs.DefaultCertificate = cert
}
if len(entryPoint.Configuration.TLS.Certificates) > 0 {
config, _ := entryPoint.Configuration.TLS.Certificates.CreateTLSConfig(entryPointName)
certMap := s.buildNameOrIPToCertificate(config.Certificates)
serverEntryPoints[entryPointName].certs.StaticCerts.Set(certMap)
}
}
}
return serverEntryPoints
}
func (s *Server) buildDefaultHTTPRouter() *mux.Router {
func buildDefaultHTTPRouter() *mux.Router {
rt := mux.NewRouter()
rt.NotFoundHandler = http.HandlerFunc(http.NotFound)
rt.SkipClean(true)

View file

@ -7,7 +7,7 @@ import (
"time"
"github.com/containous/traefik/config"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/config/static"
th "github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/stretchr/testify/assert"
@ -51,14 +51,8 @@ f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
-----END RSA PRIVATE KEY-----`)
)
func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
globalConfig := configuration.GlobalConfiguration{
DefaultEntryPoints: []string{"http", "https"},
}
entryPoints := map[string]EntryPoint{
"https": {Configuration: &configuration.EntryPoint{TLS: &tls.TLS{}}},
"http": {Configuration: &configuration.EntryPoint{}},
}
func TestServerLoadCertificateWithTLSEntryPoints(t *testing.T) {
staticConfig := static.Configuration{}
dynamicConfigs := config.Configurations{
"config": &config.Configuration{
@ -73,9 +67,16 @@ func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
},
}
srv := NewServer(globalConfig, nil, entryPoints)
_, mapsCerts := srv.loadConfig(dynamicConfigs, globalConfig)
if len(mapsCerts["https"]) == 0 {
srv := NewServer(staticConfig, nil, EntryPoints{
"https": &EntryPoint{
Certs: tls.NewCertificateStore(),
},
"https2": &EntryPoint{
Certs: tls.NewCertificateStore(),
},
})
_, mapsCerts := srv.loadConfig(dynamicConfigs)
if len(mapsCerts["https"]) == 0 || len(mapsCerts["https2"]) == 0 {
t.Fatal("got error: https entryPoint must have TLS certificates.")
}
}
@ -86,15 +87,11 @@ func TestReuseService(t *testing.T) {
}))
defer testServer.Close()
entryPoints := map[string]EntryPoint{
"http": {Configuration: &configuration.EntryPoint{
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
}},
entryPoints := EntryPoints{
"http": &EntryPoint{},
}
globalConfig := configuration.GlobalConfiguration{
DefaultEntryPoints: []string{"http"},
}
staticConfig := static.Configuration{}
dynamicConfigs := config.Configurations{
"config": th.BuildConfiguration(
@ -118,14 +115,14 @@ func TestReuseService(t *testing.T) {
),
}
srv := NewServer(globalConfig, nil, entryPoints)
srv := NewServer(staticConfig, nil, entryPoints)
serverEntryPoints, _ := srv.loadConfig(dynamicConfigs, globalConfig)
entrypointsHandlers, _ := srv.loadConfig(dynamicConfigs)
// Test that the /ok path returns a status 200.
responseRecorderOk := &httptest.ResponseRecorder{}
requestOk := httptest.NewRequest(http.MethodGet, testServer.URL+"/ok", nil)
serverEntryPoints["http"].ServeHTTP(responseRecorderOk, requestOk)
entrypointsHandlers["http"].ServeHTTP(responseRecorderOk, requestOk)
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
@ -133,7 +130,7 @@ func TestReuseService(t *testing.T) {
// the basic authentication defined on the frontend.
responseRecorderUnauthorized := &httptest.ResponseRecorder{}
requestUnauthorized := httptest.NewRequest(http.MethodGet, testServer.URL+"/unauthorized", nil)
serverEntryPoints["http"].ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
entrypointsHandlers["http"].ServeHTTP(responseRecorderUnauthorized, requestUnauthorized)
assert.Equal(t, http.StatusUnauthorized, responseRecorderUnauthorized.Result().StatusCode, "status code")
}
@ -147,8 +144,8 @@ func TestThrottleProviderConfigReload(t *testing.T) {
stop <- true
}()
globalConfig := configuration.GlobalConfiguration{}
server := NewServer(globalConfig, nil, nil)
staticConfiguration := static.Configuration{}
server := NewServer(staticConfiguration, nil, nil)
go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop)

426
server/server_entrypoint.go Normal file
View file

@ -0,0 +1,426 @@
package server
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
stdlog "log"
"net"
"net/http"
"sync"
"time"
"github.com/armon/go-proxyproto"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/h2c"
"github.com/containous/traefik/ip"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/old/configuration"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/tls/generate"
"github.com/containous/traefik/types"
"github.com/sirupsen/logrus"
"github.com/xenolf/lego/acme"
)
// EntryPoints map of EntryPoint
type EntryPoints map[string]*EntryPoint
// NewEntryPoint creates a new EntryPoint
func NewEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*EntryPoint, error) {
logger := log.FromContext(ctx)
var err error
router := middlewares.NewHandlerSwitcher(buildDefaultHTTPRouter())
tracker := newHijackConnectionTracker()
listener, err := buildListener(ctx, configuration)
if err != nil {
logger.Fatalf("Error preparing server: %v", err)
}
var tlsConfig *tls.Config
var certificateStore *traefiktls.CertificateStore
if configuration.TLS != nil {
certificateStore, err = buildCertificateStore(*configuration.TLS)
if err != nil {
return nil, fmt.Errorf("error creating certificate store: %v", err)
}
tlsConfig, err = buildTLSConfig(*configuration.TLS)
if err != nil {
return nil, fmt.Errorf("error creating TLS config: %v", err)
}
}
entryPoint := &EntryPoint{
httpRouter: router,
transportConfiguration: configuration.Transport,
hijackConnectionTracker: tracker,
listener: listener,
httpServer: buildServer(ctx, configuration, tlsConfig, router, tracker),
Certs: certificateStore,
}
if tlsConfig != nil {
tlsConfig.GetCertificate = entryPoint.getCertificate
}
return entryPoint, nil
}
// EntryPoint holds everything about the entry point (httpServer, listener etc...)
type EntryPoint struct {
RouteAppenderFactory RouteAppenderFactory
httpServer *h2c.Server
listener net.Listener
httpRouter *middlewares.HandlerSwitcher
Certs *traefiktls.CertificateStore
OnDemandListener func(string) (*tls.Certificate, error)
TLSALPNGetter func(string) (*tls.Certificate, error)
hijackConnectionTracker *hijackConnectionTracker
transportConfiguration *static.EntryPointsTransport
}
// Start starts listening for traffic
func (s *EntryPoint) Start(ctx context.Context) {
logger := log.FromContext(ctx)
logger.Infof("Starting server on %s", s.httpServer.Addr)
var err error
if s.httpServer.TLSConfig != nil {
err = s.httpServer.ServeTLS(s.listener, "", "")
} else {
err = s.httpServer.Serve(s.listener)
}
if err != http.ErrServerClosed {
logger.Error("Cannot start server: %v", err)
}
}
// Shutdown handles the entrypoint shutdown process
func (s EntryPoint) Shutdown(ctx context.Context) {
logger := log.FromContext(ctx)
reqAcceptGraceTimeOut := time.Duration(s.transportConfiguration.LifeCycle.RequestAcceptGraceTimeout)
if reqAcceptGraceTimeOut > 0 {
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
time.Sleep(reqAcceptGraceTimeOut)
}
graceTimeOut := time.Duration(s.transportConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(ctx, graceTimeOut)
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
var wg sync.WaitGroup
if s.httpServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.httpServer.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait server shutdown is overdue to: %s", err)
err = s.httpServer.Close()
if err != nil {
logger.Error(err)
}
}
}
}()
}
if s.hijackConnectionTracker != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait hijack connection is overdue to: %s", err)
s.hijackConnectionTracker.Close()
}
}
}()
}
wg.Wait()
cancel()
}
// getCertificate allows to customize tlsConfig.GetCertificate behavior to get the certificates inserted dynamically
func (s *EntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domainToCheck := types.CanonicalDomain(clientHello.ServerName)
if s.TLSALPNGetter != nil {
cert, err := s.TLSALPNGetter(domainToCheck)
if err != nil {
return nil, err
}
if cert != nil {
return cert, nil
}
}
bestCertificate := s.Certs.GetBestCertificate(clientHello)
if bestCertificate != nil {
return bestCertificate, nil
}
if s.OnDemandListener != nil && len(domainToCheck) > 0 {
// Only check for an onDemandCert if there is a domain name
return s.OnDemandListener(domainToCheck)
}
if s.Certs.SniStrict {
return nil, fmt.Errorf("strict SNI enabled - No certificate found for domain: %q, closing connection", domainToCheck)
}
log.WithoutContext().Debugf("Serving default certificate for request: %q", domainToCheck)
return s.Certs.DefaultCertificate, nil
}
func newHijackConnectionTracker() *hijackConnectionTracker {
return &hijackConnectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type hijackConnectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddHijackedConnection add a connection in the tracked connections list
func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
h.conns[conn] = struct{}{}
}
// RemoveHijackedConnection remove a connection from the tracked connections list
func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
delete(h.conns, conn)
}
// Shutdown wait for the connection closing
func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
h.lock.RLock()
if len(h.conns) == 0 {
return nil
}
h.lock.RUnlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list
func (h *hijackConnectionTracker) Close() {
for conn := range h.conns {
if err := conn.Close(); err != nil {
log.WithoutContext().Errorf("Error while closing Hijacked connection: %v", err)
}
delete(h.conns, conn)
}
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
if err = tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
return nil, err
}
return tc, nil
}
func buildProxyProtocolListener(ctx context.Context, entryPoint *static.EntryPoint, listener net.Listener) (net.Listener, error) {
var sourceCheck func(addr net.Addr) (bool, error)
if entryPoint.ProxyProtocol.Insecure {
sourceCheck = func(_ net.Addr) (bool, error) {
return true, nil
}
} else {
checker, err := ip.NewChecker(entryPoint.ProxyProtocol.TrustedIPs)
if err != nil {
return nil, err
}
sourceCheck = func(addr net.Addr) (bool, error) {
ipAddr, ok := addr.(*net.TCPAddr)
if !ok {
return false, fmt.Errorf("type error %v", addr)
}
return checker.ContainsIP(ipAddr.IP), nil
}
}
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
return &proxyproto.Listener{
Listener: listener,
SourceCheck: sourceCheck,
}, nil
}
func buildServerTimeouts(entryPointsTransport static.EntryPointsTransport) (readTimeout, writeTimeout, idleTimeout time.Duration) {
readTimeout = time.Duration(0)
writeTimeout = time.Duration(0)
if entryPointsTransport.RespondingTimeouts != nil {
readTimeout = time.Duration(entryPointsTransport.RespondingTimeouts.ReadTimeout)
writeTimeout = time.Duration(entryPointsTransport.RespondingTimeouts.WriteTimeout)
}
if entryPointsTransport.RespondingTimeouts != nil {
idleTimeout = time.Duration(entryPointsTransport.RespondingTimeouts.IdleTimeout)
} else {
idleTimeout = configuration.DefaultIdleTimeout
}
return readTimeout, writeTimeout, idleTimeout
}
func buildListener(ctx context.Context, entryPoint *static.EntryPoint) (net.Listener, error) {
listener, err := net.Listen("tcp", entryPoint.Address)
if err != nil {
return nil, fmt.Errorf("error opening listener: %v", err)
}
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
if entryPoint.ProxyProtocol != nil {
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
if err != nil {
return nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
}
}
return listener, nil
}
func buildCertificateStore(tlsOption traefiktls.TLS) (*traefiktls.CertificateStore, error) {
certificateStore := traefiktls.NewCertificateStore()
certificateStore.DynamicCerts.Set(make(map[string]*tls.Certificate))
certificateStore.SniStrict = tlsOption.SniStrict
if tlsOption.DefaultCertificate != nil {
cert, err := buildDefaultCertificate(tlsOption.DefaultCertificate)
if err != nil {
return nil, err
}
certificateStore.DefaultCertificate = cert
} else {
cert, err := generate.DefaultCertificate()
if err != nil {
return nil, err
}
certificateStore.DefaultCertificate = cert
}
return certificateStore, nil
}
func buildServer(ctx context.Context, configuration *static.EntryPoint, tlsConfig *tls.Config, router http.Handler, tracker *hijackConnectionTracker) *h2c.Server {
logger := log.FromContext(ctx)
readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(*configuration.Transport)
logger.
WithField("readTimeout", readTimeout).
WithField("writeTimeout", writeTimeout).
WithField("idleTimeout", idleTimeout).
Infof("Preparing server")
return &h2c.Server{
Server: &http.Server{
Addr: configuration.Address,
Handler: router,
TLSConfig: tlsConfig,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
IdleTimeout: idleTimeout,
ErrorLog: stdlog.New(logger.WriterLevel(logrus.DebugLevel), "", 0),
ConnState: func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateHijacked:
tracker.AddHijackedConnection(conn)
case http.StateClosed:
tracker.RemoveHijackedConnection(conn)
}
},
},
}
}
// creates a TLS config that allows terminating HTTPS for multiple domains using SNI
func buildTLSConfig(tlsOption traefiktls.TLS) (*tls.Config, error) {
conf := &tls.Config{}
// ensure http2 enabled
conf.NextProtos = []string{"h2", "http/1.1", acme.ACMETLS1Protocol}
if len(tlsOption.ClientCA.Files) > 0 {
pool := x509.NewCertPool()
for _, caFile := range tlsOption.ClientCA.Files {
data, err := caFile.Read()
if err != nil {
return nil, err
}
ok := pool.AppendCertsFromPEM(data)
if !ok {
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
}
}
conf.ClientCAs = pool
if tlsOption.ClientCA.Optional {
conf.ClientAuth = tls.VerifyClientCertIfGiven
} else {
conf.ClientAuth = tls.RequireAndVerifyClientCert
}
}
// Set the minimum TLS version if set in the config TOML
if minConst, exists := traefiktls.MinVersion[tlsOption.MinVersion]; exists {
conf.PreferServerCipherSuites = true
conf.MinVersion = minConst
}
// Set the list of CipherSuites if set in the config TOML
if tlsOption.CipherSuites != nil {
// if our list of CipherSuites is defined in the entryPoint config, we can re-initialize the suites list as empty
conf.CipherSuites = make([]uint16, 0)
for _, cipher := range tlsOption.CipherSuites {
if cipherConst, exists := traefiktls.CipherSuites[cipher]; exists {
conf.CipherSuites = append(conf.CipherSuites, cipherConst)
} else {
// CipherSuite listed in the toml does not exist in our listed
return nil, fmt.Errorf("invalid CipherSuite: %s", cipher)
}
}
}
return conf, nil
}

View file

@ -1,76 +1,18 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/mux"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/config/static"
th "github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPrepareServerTimeouts(t *testing.T) {
testCases := []struct {
desc string
globalConfig configuration.GlobalConfiguration
expectedIdleTimeout time.Duration
expectedReadTimeout time.Duration
expectedWriteTimeout time.Duration
}{
{
desc: "full configuration",
globalConfig: configuration.GlobalConfiguration{
RespondingTimeouts: &configuration.RespondingTimeouts{
IdleTimeout: parse.Duration(10 * time.Second),
ReadTimeout: parse.Duration(12 * time.Second),
WriteTimeout: parse.Duration(14 * time.Second),
},
},
expectedIdleTimeout: 10 * time.Second,
expectedReadTimeout: 12 * time.Second,
expectedWriteTimeout: 14 * time.Second,
},
{
desc: "using defaults",
globalConfig: configuration.GlobalConfiguration{},
expectedIdleTimeout: 180 * time.Second,
expectedReadTimeout: 0 * time.Second,
expectedWriteTimeout: 0 * time.Second,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
entryPointName := "http"
entryPoint := &configuration.EntryPoint{
Address: "localhost:0",
ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true},
}
router := middlewares.NewHandlerSwitcher(mux.NewRouter())
srv := NewServer(test.globalConfig, nil, nil)
httpServer, _, err := srv.prepareServer(context.Background(), entryPointName, entryPoint, router)
require.NoError(t, err, "Unexpected error when preparing srv")
assert.Equal(t, test.expectedIdleTimeout, httpServer.IdleTimeout, "IdleTimeout")
assert.Equal(t, test.expectedReadTimeout, httpServer.ReadTimeout, "ReadTimeout")
assert.Equal(t, test.expectedWriteTimeout, httpServer.WriteTimeout, "WriteTimeout")
})
}
}
func TestListenProvidersSkipsEmptyConfigs(t *testing.T) {
server, stop, invokeStopChan := setupListenProvider(10 * time.Millisecond)
defer invokeStopChan()
@ -186,14 +128,13 @@ func setupListenProvider(throttleDuration time.Duration) (server *Server, stop c
stop <- true
}
globalConfig := configuration.GlobalConfiguration{
EntryPoints: configuration.EntryPoints{
"http": &configuration.EntryPoint{},
staticConfiguration := static.Configuration{
Providers: &static.Providers{
ProvidersThrottleDuration: parse.Duration(throttleDuration),
},
ProvidersThrottleDuration: parse.Duration(throttleDuration),
}
server = NewServer(globalConfig, nil, nil)
server = NewServer(staticConfiguration, nil, nil)
go server.listenProviders(stop)
return server, stop, invokeStopChan
@ -309,14 +250,14 @@ func TestServerResponseEmptyBackend(t *testing.T) {
}))
defer testServer.Close()
globalConfig := configuration.GlobalConfiguration{}
entryPointsConfig := map[string]EntryPoint{
"http": {Configuration: &configuration.EntryPoint{ForwardedHeaders: &configuration.ForwardedHeaders{Insecure: true}}},
globalConfig := static.Configuration{}
entryPointsConfig := EntryPoints{
"http": &EntryPoint{},
}
dynamicConfigs := config.Configurations{"config": test.config(testServer.URL)}
srv := NewServer(globalConfig, nil, entryPointsConfig)
entryPoints, _ := srv.loadConfig(dynamicConfigs, globalConfig)
entryPoints, _ := srv.loadConfig(dynamicConfigs)
responseRecorder := &httptest.ResponseRecorder{}
request := httptest.NewRequest(http.MethodGet, testServer.URL+requestPath, nil)