1
0
Fork 0

Add a new protocol

Co-authored-by: Gérald Croës <gerald@containo.us>
This commit is contained in:
Julien Salleyron 2019-03-14 09:30:04 +01:00 committed by Traefiker Bot
parent 0ca2149408
commit 4a68d29ce2
231 changed files with 6895 additions and 4395 deletions

View file

@ -3,26 +3,54 @@ package server
import (
"github.com/containous/traefik/config"
"github.com/containous/traefik/server/internal"
"github.com/containous/traefik/tls"
)
func mergeConfiguration(configurations config.Configurations) config.Configuration {
conf := config.Configuration{
Routers: make(map[string]*config.Router),
Middlewares: make(map[string]*config.Middleware),
Services: make(map[string]*config.Service),
HTTP: &config.HTTPConfiguration{
Routers: make(map[string]*config.Router),
Middlewares: make(map[string]*config.Middleware),
Services: make(map[string]*config.Service),
},
TCP: &config.TCPConfiguration{
Routers: make(map[string]*config.TCPRouter),
Services: make(map[string]*config.TCPService),
},
TLSOptions: make(map[string]tls.TLS),
TLSStores: make(map[string]tls.Store),
}
for provider, configuration := range configurations {
for routerName, router := range configuration.Routers {
conf.Routers[internal.MakeQualifiedName(provider, routerName)] = router
if configuration.HTTP != nil {
for routerName, router := range configuration.HTTP.Routers {
conf.HTTP.Routers[internal.MakeQualifiedName(provider, routerName)] = router
}
for middlewareName, middleware := range configuration.HTTP.Middlewares {
conf.HTTP.Middlewares[internal.MakeQualifiedName(provider, middlewareName)] = middleware
}
for serviceName, service := range configuration.HTTP.Services {
conf.HTTP.Services[internal.MakeQualifiedName(provider, serviceName)] = service
}
}
for middlewareName, middleware := range configuration.Middlewares {
conf.Middlewares[internal.MakeQualifiedName(provider, middlewareName)] = middleware
}
for serviceName, service := range configuration.Services {
conf.Services[internal.MakeQualifiedName(provider, serviceName)] = service
if configuration.TCP != nil {
for routerName, router := range configuration.TCP.Routers {
conf.TCP.Routers[internal.MakeQualifiedName(provider, routerName)] = router
}
for serviceName, service := range configuration.TCP.Services {
conf.TCP.Services[internal.MakeQualifiedName(provider, serviceName)] = service
}
}
conf.TLS = append(conf.TLS, configuration.TLS...)
for key, store := range configuration.TLSStores {
conf.TLSStores[key] = store
}
for key, config := range configuration.TLSOptions {
conf.TLSOptions[key] = config
}
}
return conf

View file

@ -11,12 +11,12 @@ func TestAggregator(t *testing.T) {
testCases := []struct {
desc string
given config.Configurations
expected config.Configuration
expected *config.HTTPConfiguration
}{
{
desc: "Nil returns an empty configuration",
given: nil,
expected: config.Configuration{
expected: &config.HTTPConfiguration{
Routers: make(map[string]*config.Router),
Middlewares: make(map[string]*config.Middleware),
Services: make(map[string]*config.Service),
@ -26,18 +26,20 @@ func TestAggregator(t *testing.T) {
desc: "Returns fully qualified elements from a mono-provider configuration map",
given: config.Configurations{
"provider-1": &config.Configuration{
Routers: map[string]*config.Router{
"router-1": {},
},
Middlewares: map[string]*config.Middleware{
"middleware-1": {},
},
Services: map[string]*config.Service{
"service-1": {},
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"router-1": {},
},
Middlewares: map[string]*config.Middleware{
"middleware-1": {},
},
Services: map[string]*config.Service{
"service-1": {},
},
},
},
},
expected: config.Configuration{
expected: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"provider-1.router-1": {},
},
@ -53,29 +55,33 @@ func TestAggregator(t *testing.T) {
desc: "Returns fully qualified elements from a multi-provider configuration map",
given: config.Configurations{
"provider-1": &config.Configuration{
Routers: map[string]*config.Router{
"router-1": {},
},
Middlewares: map[string]*config.Middleware{
"middleware-1": {},
},
Services: map[string]*config.Service{
"service-1": {},
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"router-1": {},
},
Middlewares: map[string]*config.Middleware{
"middleware-1": {},
},
Services: map[string]*config.Service{
"service-1": {},
},
},
},
"provider-2": &config.Configuration{
Routers: map[string]*config.Router{
"router-1": {},
},
Middlewares: map[string]*config.Middleware{
"middleware-1": {},
},
Services: map[string]*config.Service{
"service-1": {},
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"router-1": {},
},
Middlewares: map[string]*config.Middleware{
"middleware-1": {},
},
Services: map[string]*config.Service{
"service-1": {},
},
},
},
},
expected: config.Configuration{
expected: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"provider-1.router-1": {},
"provider-2.router-1": {},
@ -96,8 +102,9 @@ func TestAggregator(t *testing.T) {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := mergeConfiguration(test.given)
assert.Equal(t, test.expected, actual)
assert.Equal(t, test.expected, actual.HTTP)
})
}
}

View file

@ -44,7 +44,7 @@ type Builder struct {
}
type serviceBuilder interface {
Build(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
}
// NewBuilder creates a new Builder
@ -57,9 +57,9 @@ func (b *Builder) BuildChain(ctx context.Context, middlewares []string) *alice.C
chain := alice.New()
for _, name := range middlewares {
middlewareName := internal.GetQualifiedName(ctx, name)
constructorContext := internal.AddProviderInContext(ctx, middlewareName)
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
constructorContext := internal.AddProviderInContext(ctx, middlewareName)
if _, ok := b.configs[middlewareName]; !ok {
return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
}

View file

@ -45,8 +45,8 @@ type Manager struct {
}
// BuildHandlers Builds handler for all entry points
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string) map[string]http.Handler {
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints)
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string, tls bool) map[string]http.Handler {
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints, tls)
entryPointHandlers := make(map[string]http.Handler)
for entryPointName, routers := range entryPointsRouters {
@ -84,10 +84,14 @@ func contains(entryPoints []string, entryPointName string) bool {
return false
}
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string) map[string]map[string]*config.Router {
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string, tls bool) map[string]map[string]*config.Router {
entryPointsRouters := make(map[string]map[string]*config.Router)
for rtName, rt := range m.configs {
if (tls && rt.TLS == nil) || (!tls && rt.TLS != nil) {
continue
}
eps := rt.EntryPoints
if len(eps) == 0 {
eps = entryPoints
@ -155,7 +159,7 @@ func (m *Manager) buildRouterHandler(ctx context.Context, routerName string) (ht
return nil, fmt.Errorf("no configuration for %s", routerName)
}
handler, err := m.buildHandler(ctx, configRouter, routerName)
handler, err := m.buildHTTPHandler(ctx, configRouter, routerName)
if err != nil {
return nil, err
}
@ -173,10 +177,10 @@ func (m *Manager) buildRouterHandler(ctx context.Context, routerName string) (ht
return m.routerHandlers[routerName], nil
}
func (m *Manager) buildHandler(ctx context.Context, router *config.Router, routerName string) (http.Handler, error) {
func (m *Manager) buildHTTPHandler(ctx context.Context, router *config.Router, routerName string) (http.Handler, error) {
rm := m.modifierBuilder.Build(ctx, router.Middlewares)
sHandler, err := m.serviceManager.Build(ctx, router.Service, rm)
sHandler, err := m.serviceManager.BuildHTTP(ctx, router.Service, rm)
if err != nil {
return nil, err
}

View file

@ -2,8 +2,10 @@ package router
import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/containous/traefik/config"
@ -318,7 +320,7 @@ func TestRouterManager_Get(t *testing.T) {
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
@ -417,7 +419,7 @@ func TestAccessLog(t *testing.T) {
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints)
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
@ -440,3 +442,90 @@ func TestAccessLog(t *testing.T) {
})
}
}
type staticTransport struct {
res *http.Response
}
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return t.res, nil
}
func BenchmarkRouterServe(b *testing.B) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
res := &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("")),
}
routersConfig := map[string]*config.Router{
"foo": {
EntryPoints: []string{"web"},
Service: "foo-service",
Rule: "Host(`foo.bar`) && Path(`/`)",
},
}
serviceConfig := map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server.URL,
Weight: 1,
},
},
Method: "wrr",
},
},
}
entryPoints := []string{"web"}
serviceManager := service.NewManager(serviceConfig, &staticTransport{res})
middlewaresBuilder := middleware.NewBuilder(map[string]*config.Middleware{}, serviceManager)
responseModifierFactory := responsemodifiers.NewBuilder(map[string]*config.Middleware{})
routerManager := NewManager(routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(context.Background(), entryPoints, false)
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
reqHost := requestdecorator.New(nil)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
}
}
func BenchmarkService(b *testing.B) {
res := &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("")),
}
serviceConfig := map[string]*config.Service{
"foo-service": {
LoadBalancer: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: "tchouck",
Weight: 1,
},
},
Method: "wrr",
},
},
}
serviceManager := service.NewManager(serviceConfig, &staticTransport{res})
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
handler, _ := serviceManager.BuildHTTP(context.Background(), "foo-service", nil)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
handler.ServeHTTP(w, req)
}
}

139
server/router/tcp/router.go Normal file
View file

@ -0,0 +1,139 @@
package tcp
import (
"context"
"crypto/tls"
"net/http"
"github.com/containous/traefik/config"
"github.com/containous/traefik/log"
"github.com/containous/traefik/rules"
"github.com/containous/traefik/server/internal"
tcpservice "github.com/containous/traefik/server/service/tcp"
"github.com/containous/traefik/tcp"
)
// NewManager Creates a new Manager
func NewManager(routers map[string]*config.TCPRouter,
serviceManager *tcpservice.Manager,
httpHandlers map[string]http.Handler,
httpsHandlers map[string]http.Handler,
tlsConfig *tls.Config,
) *Manager {
return &Manager{
configs: routers,
serviceManager: serviceManager,
httpHandlers: httpHandlers,
httpsHandlers: httpsHandlers,
tlsConfig: tlsConfig,
}
}
// Manager is a route/router manager
type Manager struct {
configs map[string]*config.TCPRouter
serviceManager *tcpservice.Manager
httpHandlers map[string]http.Handler
httpsHandlers map[string]http.Handler
tlsConfig *tls.Config
}
// BuildHandlers builds the handlers for the given entrypoints
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string) map[string]*tcp.Router {
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints)
entryPointHandlers := make(map[string]*tcp.Router)
for _, entryPointName := range entryPoints {
entryPointName := entryPointName
routers := entryPointsRouters[entryPointName]
ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName))
handler, err := m.buildEntryPointHandler(ctx, routers, m.httpHandlers[entryPointName], m.httpsHandlers[entryPointName])
if err != nil {
log.FromContext(ctx).Error(err)
continue
}
entryPointHandlers[entryPointName] = handler
}
return entryPointHandlers
}
func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*config.TCPRouter, handlerHTTP http.Handler, handlerHTTPS http.Handler) (*tcp.Router, error) {
router := &tcp.Router{}
router.HTTPHandler(handlerHTTP)
router.HTTPSHandler(handlerHTTPS, m.tlsConfig)
for routerName, routerConfig := range configs {
ctxRouter := log.With(ctx, log.Str(log.RouterName, routerName))
logger := log.FromContext(ctxRouter)
ctxRouter = internal.AddProviderInContext(ctxRouter, routerName)
handler, err := m.serviceManager.BuildTCP(ctxRouter, routerConfig.Service)
if err != nil {
logger.Error(err)
continue
}
domains, err := rules.ParseHostSNI(routerConfig.Rule)
if err != nil {
log.WithoutContext().Debugf("Unknown rule %s", routerConfig.Rule)
continue
}
for _, domain := range domains {
log.WithoutContext().Debugf("Add route %s on TCP", domain)
switch {
case routerConfig.TLS != nil:
if routerConfig.TLS.Passthrough {
router.AddRoute(domain, handler)
} else {
router.AddRouteTLS(domain, handler, m.tlsConfig)
}
case domain == "*":
router.AddCatchAllNoTLS(handler)
default:
logger.Warn("TCP Router ignored, cannot specify a Host rule without TLS")
}
}
}
return router, nil
}
func contains(entryPoints []string, entryPointName string) bool {
for _, name := range entryPoints {
if name == entryPointName {
return true
}
}
return false
}
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string) map[string]map[string]*config.TCPRouter {
entryPointsRouters := make(map[string]map[string]*config.TCPRouter)
for rtName, rt := range m.configs {
eps := rt.EntryPoints
if len(eps) == 0 {
eps = entryPoints
}
for _, entryPointName := range eps {
if !contains(entryPoints, entryPointName) {
log.FromContext(log.With(ctx, log.Str(log.EntryPointName, entryPointName))).
Errorf("entryPoint %q doesn't exist", entryPointName)
continue
}
if _, ok := entryPointsRouters[entryPointName]; !ok {
entryPointsRouters[entryPointName] = make(map[string]*config.TCPRouter)
}
entryPointsRouters[entryPointName][rtName] = rt
}
}
return entryPointsRouters
}

View file

@ -9,7 +9,6 @@ import (
"sync"
"time"
"github.com/containous/traefik/cluster"
"github.com/containous/traefik/config"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/log"
@ -19,6 +18,7 @@ import (
"github.com/containous/traefik/provider"
"github.com/containous/traefik/safe"
"github.com/containous/traefik/server/middleware"
"github.com/containous/traefik/tls"
"github.com/containous/traefik/tracing"
"github.com/containous/traefik/tracing/datadog"
"github.com/containous/traefik/tracing/instana"
@ -29,7 +29,7 @@ import (
// Server is the reverse-proxy/load-balancer engine
type Server struct {
entryPoints EntryPoints
entryPointsTCP TCPEntryPoints
configurationChan chan config.Message
configurationValidatedChan chan config.Message
signals chan os.Signal
@ -39,13 +39,13 @@ type Server struct {
accessLoggerMiddleware *accesslog.Handler
tracer *tracing.Tracing
routinesPool *safe.Pool
leadership *cluster.Leadership //FIXME Cluster
defaultRoundTripper http.RoundTripper
metricsRegistry metrics.Registry
provider provider.Provider
configurationListeners []func(config.Configuration)
requestDecorator *requestdecorator.RequestDecorator
providersThrottleDuration time.Duration
tlsManager *tls.Manager
}
// RouteAppenderFactory the route appender factory interface
@ -70,11 +70,11 @@ func setupTracing(conf *static.Tracing) tracing.TrackingBackend {
}
// NewServer returns an initialized Server.
func NewServer(staticConfiguration static.Configuration, provider provider.Provider, entryPoints EntryPoints) *Server {
func NewServer(staticConfiguration static.Configuration, provider provider.Provider, entryPoints TCPEntryPoints, tlsManager *tls.Manager) *Server {
server := &Server{}
server.provider = provider
server.entryPoints = entryPoints
server.entryPointsTCP = entryPoints
server.configurationChan = make(chan config.Message, 100)
server.configurationValidatedChan = make(chan config.Message, 100)
server.signals = make(chan os.Signal, 1)
@ -83,6 +83,7 @@ func NewServer(staticConfiguration static.Configuration, provider provider.Provi
currentConfigurations := make(config.Configurations)
server.currentConfigurations.Set(currentConfigurations)
server.providerConfigUpdateMap = make(map[string]chan config.Message)
server.tlsManager = tlsManager
if staticConfiguration.Providers != nil {
server.providersThrottleDuration = time.Duration(staticConfiguration.Providers.ProvidersThrottleDuration)
@ -132,8 +133,7 @@ func (s *Server) Start(ctx context.Context) {
s.Stop()
}()
s.startHTTPServers()
s.startLeadership()
s.startTCPServers()
s.routinesPool.Go(func(stop chan bool) {
s.listenProviders(stop)
})
@ -156,9 +156,9 @@ func (s *Server) Stop() {
defer log.WithoutContext().Info("Server stopped")
var wg sync.WaitGroup
for epn, ep := range s.entryPoints {
for epn, ep := range s.entryPointsTCP {
wg.Add(1)
go func(entryPointName string, entryPoint *EntryPoint) {
go func(entryPointName string, entryPoint *TCPEntryPoint) {
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
defer wg.Done()
@ -184,7 +184,6 @@ func (s *Server) Close() {
}(ctx)
stopMetricsClients()
s.stopLeadership()
s.routinesPool.Cleanup()
close(s.configurationChan)
close(s.configurationValidatedChan)
@ -205,28 +204,16 @@ func (s *Server) Close() {
cancel()
}
func (s *Server) startLeadership() {
if s.leadership != nil {
s.leadership.Participate(s.routinesPool)
}
}
func (s *Server) stopLeadership() {
if s.leadership != nil {
s.leadership.Stop()
}
}
func (s *Server) startHTTPServers() {
func (s *Server) startTCPServers() {
// Use an empty configuration in order to initialize the default handlers with internal routes
handlers := s.applyConfiguration(context.Background(), config.Configuration{})
for entryPointName, handler := range handlers {
s.entryPoints[entryPointName].switcher.UpdateHandler(handler)
routers := s.loadConfigurationTCP(config.Configurations{})
for entryPointName, router := range routers {
s.entryPointsTCP[entryPointName].switchRouter(router)
}
for entryPointName, entryPoint := range s.entryPoints {
for entryPointName, serverEntryPoint := range s.entryPointsTCP {
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
go entryPoint.Start(ctx)
go serverEntryPoint.startTCP(ctx)
}
}

View file

@ -4,7 +4,6 @@ import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"reflect"
"time"
@ -19,16 +18,16 @@ import (
"github.com/containous/traefik/responsemodifiers"
"github.com/containous/traefik/server/middleware"
"github.com/containous/traefik/server/router"
routertcp "github.com/containous/traefik/server/router/tcp"
"github.com/containous/traefik/server/service"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/server/service/tcp"
tcpCore "github.com/containous/traefik/tcp"
"github.com/eapache/channels"
"github.com/sirupsen/logrus"
)
// loadConfiguration manages dynamically routers, middlewares, servers and TLS configurations
func (s *Server) loadConfiguration(configMsg config.Message) {
logger := log.FromContext(log.With(context.Background(), log.Str(log.ProviderName, configMsg.ProviderName)))
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
// Copy configurations to new map so we don't change current if LoadConfig fails
@ -40,27 +39,13 @@ func (s *Server) loadConfiguration(configMsg config.Message) {
s.metricsRegistry.ConfigReloadsCounter().Add(1)
handlers, certificates := s.loadConfig(newConfigurations)
handlersTCP := s.loadConfigurationTCP(newConfigurations)
for entryPointName, router := range handlersTCP {
s.entryPointsTCP[entryPointName].switchRouter(router)
}
s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
for entryPointName, handler := range handlers {
s.entryPoints[entryPointName].switcher.UpdateHandler(handler)
}
for entryPointName, entryPoint := range s.entryPoints {
eLogger := logger.WithField(log.EntryPointName, entryPointName)
if entryPoint.Certs == nil {
if len(certificates[entryPointName]) > 0 {
eLogger.Debugf("Cannot configure certificates for the non-TLS %s entryPoint.", entryPointName)
}
} else {
entryPoint.Certs.DynamicCerts.Set(certificates[entryPointName])
entryPoint.Certs.ResetCache()
}
eLogger.Infof("Server configuration reloaded on %s", s.entryPoints[entryPointName].httpServer.Addr)
}
s.currentConfigurations.Set(newConfigurations)
for _, listener := range s.configurationListeners {
@ -70,35 +55,48 @@ func (s *Server) loadConfiguration(configMsg config.Message) {
s.postLoadConfiguration()
}
// loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic
// loadConfigurationTCP returns a new gorilla.mux Route from the specified global configuration and the dynamic
// provider configurations.
func (s *Server) loadConfig(configurations config.Configurations) (map[string]http.Handler, map[string]map[string]*tls.Certificate) {
func (s *Server) loadConfigurationTCP(configurations config.Configurations) map[string]*tcpCore.Router {
ctx := context.TODO()
conf := mergeConfiguration(configurations)
handlers := s.applyConfiguration(ctx, conf)
// Get new certificates list sorted per entry points
// Update certificates
entryPointsCertificates := s.loadHTTPSConfiguration(configurations)
return handlers, entryPointsCertificates
}
func (s *Server) applyConfiguration(ctx context.Context, configuration config.Configuration) map[string]http.Handler {
var entryPoints []string
for entryPointName := range s.entryPoints {
for entryPointName := range s.entryPointsTCP {
entryPoints = append(entryPoints, entryPointName)
}
conf := mergeConfiguration(configurations)
s.tlsManager.UpdateConfigs(conf.TLSStores, conf.TLSOptions, conf.TLS)
handlersNonTLS, handlersTLS := s.createHTTPHandlers(ctx, *conf.HTTP, entryPoints)
routersTCP := s.createTCPRouters(ctx, conf.TCP, entryPoints, handlersNonTLS, handlersTLS, s.tlsManager.Get("default", "default"))
return routersTCP
}
func (s *Server) createTCPRouters(ctx context.Context, configuration *config.TCPConfiguration, entryPoints []string, handlers map[string]http.Handler, handlersTLS map[string]http.Handler, tlsConfig *tls.Config) map[string]*tcpCore.Router {
if configuration == nil {
return make(map[string]*tcpCore.Router)
}
serviceManager := tcp.NewManager(configuration.Services)
routerManager := routertcp.NewManager(configuration.Routers, serviceManager, handlers, handlersTLS, tlsConfig)
return routerManager.BuildHandlers(ctx, entryPoints)
}
func (s *Server) createHTTPHandlers(ctx context.Context, configuration config.HTTPConfiguration, entryPoints []string) (map[string]http.Handler, map[string]http.Handler) {
serviceManager := service.NewManager(configuration.Services, s.defaultRoundTripper)
middlewaresBuilder := middleware.NewBuilder(configuration.Middlewares, serviceManager)
responseModifierFactory := responsemodifiers.NewBuilder(configuration.Middlewares)
routerManager := router.NewManager(configuration.Routers, serviceManager, middlewaresBuilder, responseModifierFactory)
handlers := routerManager.BuildHandlers(ctx, entryPoints)
handlersNonTLS := routerManager.BuildHandlers(ctx, entryPoints, false)
handlersTLS := routerManager.BuildHandlers(ctx, entryPoints, true)
routerHandlers := make(map[string]http.Handler)
@ -108,14 +106,14 @@ func (s *Server) applyConfiguration(ctx context.Context, configuration config.Co
ctx = log.With(ctx, log.Str(log.EntryPointName, entryPointName))
factory := s.entryPoints[entryPointName].RouteAppenderFactory
factory := s.entryPointsTCP[entryPointName].RouteAppenderFactory
if factory != nil {
// FIXME remove currentConfigurations
appender := factory.NewAppender(ctx, middlewaresBuilder, &s.currentConfigurations)
appender.Append(internalMuxRouter)
}
if h, ok := handlers[entryPointName]; ok {
if h, ok := handlersNonTLS[entryPointName]; ok {
internalMuxRouter.NotFoundHandler = h
} else {
internalMuxRouter.NotFoundHandler = buildDefaultHTTPRouter()
@ -141,13 +139,42 @@ func (s *Server) applyConfiguration(ctx context.Context, configuration config.Co
continue
}
internalMuxRouter.NotFoundHandler = handler
handlerTLS, ok := handlersTLS[entryPointName]
if ok {
handlerTLSWithMiddlewares, err := chain.Then(handlerTLS)
if err != nil {
log.FromContext(ctx).Error(err)
continue
}
handlersTLS[entryPointName] = handlerTLSWithMiddlewares
}
}
return routerHandlers
return routerHandlers, handlersTLS
}
func isEmptyConfiguration(conf *config.Configuration) bool {
if conf == nil {
return true
}
if conf.TCP == nil {
conf.TCP = &config.TCPConfiguration{}
}
if conf.HTTP == nil {
conf.HTTP = &config.HTTPConfiguration{}
}
return conf.HTTP.Routers == nil &&
conf.HTTP.Services == nil &&
conf.HTTP.Middlewares == nil &&
conf.TLS == nil &&
conf.TCP.Routers == nil &&
conf.TCP.Services == nil
}
func (s *Server) preLoadConfiguration(configMsg config.Message) {
s.defaultConfigurationValues(configMsg.Configuration)
s.defaultConfigurationValues(configMsg.Configuration.HTTP)
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
logger := log.WithoutContext().WithField(log.ProviderName, configMsg.ProviderName)
@ -156,7 +183,7 @@ func (s *Server) preLoadConfiguration(configMsg config.Message) {
logger.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
}
if configMsg.Configuration == nil || configMsg.Configuration.Routers == nil && configMsg.Configuration.Services == nil && configMsg.Configuration.Middlewares == nil && configMsg.Configuration.TLS == nil {
if isEmptyConfiguration(configMsg.Configuration) {
logger.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
return
}
@ -178,7 +205,7 @@ func (s *Server) preLoadConfiguration(configMsg config.Message) {
providerConfigUpdateCh <- configMsg
}
func (s *Server) defaultConfigurationValues(configuration *config.Configuration) {
func (s *Server) defaultConfigurationValues(configuration *config.HTTPConfiguration) {
// FIXME create a config hook
}
@ -235,59 +262,6 @@ func (s *Server) postLoadConfiguration() {
// metrics.OnConfigurationUpdate(activeConfig)
// }
// FIXME acme
// if s.staticConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() {
// return
// }
//
// if s.staticConfiguration.ACME.OnHostRule {
// currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
// for _, config := range currentConfigurations {
// for _, frontend := range config.Frontends {
//
// // check if one of the frontend entrypoints is configured with TLS
// // and is configured with ACME
// acmeEnabled := false
// for _, entryPoint := range frontend.EntryPoints {
// if s.staticConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
// acmeEnabled = true
// break
// }
// }
//
// if acmeEnabled {
// for _, route := range frontend.Routes {
// rls := rules.Rules{}
// domains, err := rls.ParseDomains(route.Rule)
// if err != nil {
// log.Errorf("Error parsing domains: %v", err)
// } else if len(domains) == 0 {
// log.Debugf("No domain parsed in rule %q", route.Rule)
// } else {
// s.staticConfiguration.ACME.LoadCertificateForDomains(domains)
// }
// }
// }
// }
// }
// }
}
// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically
func (s *Server) loadHTTPSConfiguration(configurations config.Configurations) map[string]map[string]*tls.Certificate {
var entryPoints []string
for entryPointName := range s.entryPoints {
entryPoints = append(entryPoints, entryPointName)
}
newEPCertificates := make(map[string]map[string]*tls.Certificate)
// Get all certificates
for _, config := range configurations {
if config.TLS != nil && len(config.TLS) > 0 {
traefiktls.SortTLSPerEntryPoints(config.TLS, newEPCertificates, entryPoints)
}
}
return newEPCertificates
}
func buildDefaultHTTPRouter() *mux.Router {
@ -296,21 +270,3 @@ func buildDefaultHTTPRouter() *mux.Router {
rt.SkipClean(true)
return rt
}
func buildDefaultCertificate(defaultCertificate *traefiktls.Certificate) (*tls.Certificate, error) {
certFile, err := defaultCertificate.CertFile.Read()
if err != nil {
return nil, fmt.Errorf("failed to get cert file content: %v", err)
}
keyFile, err := defaultCertificate.KeyFile.Read()
if err != nil {
return nil, fmt.Errorf("failed to get key file content: %v", err)
}
cert, err := tls.X509KeyPair(certFile, keyFile)
if err != nil {
return nil, fmt.Errorf("failed to load X509 key pair: %v", err)
}
return &cert, nil
}

View file

@ -1,6 +1,7 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"testing"
@ -9,115 +10,44 @@ import (
"github.com/containous/traefik/config"
"github.com/containous/traefik/config/static"
th "github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/stretchr/testify/assert"
)
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
// generated from src/crypto/tls:
// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
var (
localhostCert = tls.FileOrContent(`-----BEGIN CERTIFICATE-----
MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
fblo6RBxUQ==
-----END CERTIFICATE-----`)
// LocalhostKey is the private key for localhostCert.
localhostKey = tls.FileOrContent(`-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
-----END RSA PRIVATE KEY-----`)
)
func TestServerLoadCertificateWithTLSEntryPoints(t *testing.T) {
staticConfig := static.Configuration{}
dynamicConfigs := config.Configurations{
"config": &config.Configuration{
TLS: []*tls.Configuration{
{
Certificate: &tls.Certificate{
CertFile: localhostCert,
KeyFile: localhostKey,
},
},
},
},
}
srv := NewServer(staticConfig, nil, EntryPoints{
"https": &EntryPoint{
Certs: tls.NewCertificateStore(),
},
"https2": &EntryPoint{
Certs: tls.NewCertificateStore(),
},
})
_, mapsCerts := srv.loadConfig(dynamicConfigs)
if len(mapsCerts["https"]) == 0 || len(mapsCerts["https2"]) == 0 {
t.Fatal("got error: https entryPoint must have TLS certificates.")
}
}
func TestReuseService(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
defer testServer.Close()
entryPoints := EntryPoints{
"http": &EntryPoint{},
entryPoints := TCPEntryPoints{
"http": &TCPEntryPoint{},
}
staticConfig := static.Configuration{}
dynamicConfigs := config.Configurations{
"config": th.BuildConfiguration(
th.WithRouters(
th.WithRouter("foo",
th.WithServiceName("bar"),
th.WithRule("Path(`/ok`)")),
th.WithRouter("foo2",
th.WithEntryPoints("http"),
th.WithRule("Path(`/unauthorized`)"),
th.WithServiceName("bar"),
th.WithRouterMiddlewares("basicauth")),
),
th.WithMiddlewares(th.WithMiddleware("basicauth",
th.WithBasicAuth(&config.BasicAuth{Users: []string{"foo:bar"}}),
)),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr"),
th.WithServers(th.WithServer(testServer.URL))),
),
dynamicConfigs := th.BuildConfiguration(
th.WithRouters(
th.WithRouter("foo",
th.WithServiceName("bar"),
th.WithRule("Path(`/ok`)")),
th.WithRouter("foo2",
th.WithEntryPoints("http"),
th.WithRule("Path(`/unauthorized`)"),
th.WithServiceName("bar"),
th.WithRouterMiddlewares("basicauth")),
),
}
th.WithMiddlewares(th.WithMiddleware("basicauth",
th.WithBasicAuth(&config.BasicAuth{Users: []string{"foo:bar"}}),
)),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr"),
th.WithServers(th.WithServer(testServer.URL))),
),
)
srv := NewServer(staticConfig, nil, entryPoints)
srv := NewServer(staticConfig, nil, entryPoints, nil)
entrypointsHandlers, _ := srv.loadConfig(dynamicConfigs)
entrypointsHandlers, _ := srv.createHTTPHandlers(context.Background(), *dynamicConfigs, []string{"http"})
// Test that the /ok path returns a status 200.
responseRecorderOk := &httptest.ResponseRecorder{}
@ -145,7 +75,7 @@ func TestThrottleProviderConfigReload(t *testing.T) {
}()
staticConfiguration := static.Configuration{}
server := NewServer(staticConfiguration, nil, nil)
server := NewServer(staticConfiguration, nil, nil, nil)
go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop)

View file

@ -1,434 +0,0 @@
package server
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
stdlog "log"
"net"
"net/http"
"sync"
"time"
"github.com/armon/go-proxyproto"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/h2c"
"github.com/containous/traefik/ip"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/forwardedheaders"
"github.com/containous/traefik/old/configuration"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/tls/generate"
"github.com/containous/traefik/types"
"github.com/sirupsen/logrus"
"github.com/xenolf/lego/challenge/tlsalpn01"
)
// EntryPoints map of EntryPoint
type EntryPoints map[string]*EntryPoint
// NewEntryPoint creates a new EntryPoint
func NewEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*EntryPoint, error) {
var err error
switcher := middlewares.NewHandlerSwitcher(buildDefaultHTTPRouter())
handler, err := forwardedheaders.NewXForwarded(
configuration.ForwardedHeaders.Insecure,
configuration.ForwardedHeaders.TrustedIPs,
switcher)
if err != nil {
return nil, err
}
tracker := newHijackConnectionTracker()
listener, err := buildListener(ctx, configuration)
if err != nil {
return nil, fmt.Errorf("error preparing server: %v", err)
}
var tlsConfig *tls.Config
var certificateStore *traefiktls.CertificateStore
if configuration.TLS != nil {
certificateStore, err = buildCertificateStore(*configuration.TLS)
if err != nil {
return nil, fmt.Errorf("error creating certificate store: %v", err)
}
tlsConfig, err = buildTLSConfig(*configuration.TLS)
if err != nil {
return nil, fmt.Errorf("error creating TLS config: %v", err)
}
}
entryPoint := &EntryPoint{
switcher: switcher,
transportConfiguration: configuration.Transport,
hijackConnectionTracker: tracker,
listener: listener,
httpServer: buildServer(ctx, configuration, tlsConfig, handler, tracker),
Certs: certificateStore,
}
if tlsConfig != nil {
tlsConfig.GetCertificate = entryPoint.getCertificate
}
return entryPoint, nil
}
// EntryPoint holds everything about the entry point (httpServer, listener etc...)
type EntryPoint struct {
RouteAppenderFactory RouteAppenderFactory
httpServer *h2c.Server
listener net.Listener
switcher *middlewares.HandlerSwitcher
Certs *traefiktls.CertificateStore
OnDemandListener func(string) (*tls.Certificate, error)
TLSALPNGetter func(string) (*tls.Certificate, error)
hijackConnectionTracker *hijackConnectionTracker
transportConfiguration *static.EntryPointsTransport
}
// Start starts listening for traffic
func (s *EntryPoint) Start(ctx context.Context) {
logger := log.FromContext(ctx)
logger.Infof("Starting server on %s", s.httpServer.Addr)
var err error
if s.httpServer.TLSConfig != nil {
err = s.httpServer.ServeTLS(s.listener, "", "")
} else {
err = s.httpServer.Serve(s.listener)
}
if err != http.ErrServerClosed {
logger.Error("Cannot start server: %v", err)
}
}
// Shutdown handles the entrypoint shutdown process
func (s EntryPoint) Shutdown(ctx context.Context) {
logger := log.FromContext(ctx)
reqAcceptGraceTimeOut := time.Duration(s.transportConfiguration.LifeCycle.RequestAcceptGraceTimeout)
if reqAcceptGraceTimeOut > 0 {
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
time.Sleep(reqAcceptGraceTimeOut)
}
graceTimeOut := time.Duration(s.transportConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(ctx, graceTimeOut)
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
var wg sync.WaitGroup
if s.httpServer != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.httpServer.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait server shutdown is overdue to: %s", err)
err = s.httpServer.Close()
if err != nil {
logger.Error(err)
}
}
}
}()
}
if s.hijackConnectionTracker != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait hijack connection is overdue to: %s", err)
s.hijackConnectionTracker.Close()
}
}
}()
}
wg.Wait()
cancel()
}
// getCertificate allows to customize tlsConfig.GetCertificate behavior to get the certificates inserted dynamically
func (s *EntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domainToCheck := types.CanonicalDomain(clientHello.ServerName)
if s.TLSALPNGetter != nil {
cert, err := s.TLSALPNGetter(domainToCheck)
if err != nil {
return nil, err
}
if cert != nil {
return cert, nil
}
}
bestCertificate := s.Certs.GetBestCertificate(clientHello)
if bestCertificate != nil {
return bestCertificate, nil
}
if s.OnDemandListener != nil && len(domainToCheck) > 0 {
// Only check for an onDemandCert if there is a domain name
return s.OnDemandListener(domainToCheck)
}
if s.Certs.SniStrict {
return nil, fmt.Errorf("strict SNI enabled - No certificate found for domain: %q, closing connection", domainToCheck)
}
log.WithoutContext().Debugf("Serving default certificate for request: %q", domainToCheck)
return s.Certs.DefaultCertificate, nil
}
func newHijackConnectionTracker() *hijackConnectionTracker {
return &hijackConnectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type hijackConnectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddHijackedConnection add a connection in the tracked connections list
func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
h.conns[conn] = struct{}{}
}
// RemoveHijackedConnection remove a connection from the tracked connections list
func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
delete(h.conns, conn)
}
// Shutdown wait for the connection closing
func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
h.lock.RLock()
if len(h.conns) == 0 {
return nil
}
h.lock.RUnlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list
func (h *hijackConnectionTracker) Close() {
for conn := range h.conns {
if err := conn.Close(); err != nil {
log.WithoutContext().Errorf("Error while closing Hijacked connection: %v", err)
}
delete(h.conns, conn)
}
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
if err = tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
return nil, err
}
return tc, nil
}
func buildProxyProtocolListener(ctx context.Context, entryPoint *static.EntryPoint, listener net.Listener) (net.Listener, error) {
var sourceCheck func(addr net.Addr) (bool, error)
if entryPoint.ProxyProtocol.Insecure {
sourceCheck = func(_ net.Addr) (bool, error) {
return true, nil
}
} else {
checker, err := ip.NewChecker(entryPoint.ProxyProtocol.TrustedIPs)
if err != nil {
return nil, err
}
sourceCheck = func(addr net.Addr) (bool, error) {
ipAddr, ok := addr.(*net.TCPAddr)
if !ok {
return false, fmt.Errorf("type error %v", addr)
}
return checker.ContainsIP(ipAddr.IP), nil
}
}
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
return &proxyproto.Listener{
Listener: listener,
SourceCheck: sourceCheck,
}, nil
}
func buildServerTimeouts(entryPointsTransport static.EntryPointsTransport) (readTimeout, writeTimeout, idleTimeout time.Duration) {
readTimeout = time.Duration(0)
writeTimeout = time.Duration(0)
if entryPointsTransport.RespondingTimeouts != nil {
readTimeout = time.Duration(entryPointsTransport.RespondingTimeouts.ReadTimeout)
writeTimeout = time.Duration(entryPointsTransport.RespondingTimeouts.WriteTimeout)
}
if entryPointsTransport.RespondingTimeouts != nil {
idleTimeout = time.Duration(entryPointsTransport.RespondingTimeouts.IdleTimeout)
} else {
idleTimeout = configuration.DefaultIdleTimeout
}
return readTimeout, writeTimeout, idleTimeout
}
func buildListener(ctx context.Context, entryPoint *static.EntryPoint) (net.Listener, error) {
listener, err := net.Listen("tcp", entryPoint.Address)
if err != nil {
return nil, fmt.Errorf("error opening listener: %v", err)
}
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
if entryPoint.ProxyProtocol != nil {
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
if err != nil {
return nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
}
}
return listener, nil
}
func buildCertificateStore(tlsOption traefiktls.TLS) (*traefiktls.CertificateStore, error) {
certificateStore := traefiktls.NewCertificateStore()
certificateStore.DynamicCerts.Set(make(map[string]*tls.Certificate))
certificateStore.SniStrict = tlsOption.SniStrict
if tlsOption.DefaultCertificate != nil {
cert, err := buildDefaultCertificate(tlsOption.DefaultCertificate)
if err != nil {
return nil, err
}
certificateStore.DefaultCertificate = cert
} else {
cert, err := generate.DefaultCertificate()
if err != nil {
return nil, err
}
certificateStore.DefaultCertificate = cert
}
return certificateStore, nil
}
func buildServer(ctx context.Context, configuration *static.EntryPoint, tlsConfig *tls.Config, router http.Handler, tracker *hijackConnectionTracker) *h2c.Server {
logger := log.FromContext(ctx)
readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(*configuration.Transport)
logger.
WithField("readTimeout", readTimeout).
WithField("writeTimeout", writeTimeout).
WithField("idleTimeout", idleTimeout).
Infof("Preparing server")
return &h2c.Server{
Server: &http.Server{
Addr: configuration.Address,
Handler: router,
TLSConfig: tlsConfig,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
IdleTimeout: idleTimeout,
ErrorLog: stdlog.New(logger.WriterLevel(logrus.DebugLevel), "", 0),
ConnState: func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateHijacked:
tracker.AddHijackedConnection(conn)
case http.StateClosed:
tracker.RemoveHijackedConnection(conn)
}
},
},
}
}
// creates a TLS config that allows terminating HTTPS for multiple domains using SNI
func buildTLSConfig(tlsOption traefiktls.TLS) (*tls.Config, error) {
conf := &tls.Config{}
// ensure http2 enabled
conf.NextProtos = []string{"h2", "http/1.1", tlsalpn01.ACMETLS1Protocol}
if len(tlsOption.ClientCA.Files) > 0 {
pool := x509.NewCertPool()
for _, caFile := range tlsOption.ClientCA.Files {
data, err := caFile.Read()
if err != nil {
return nil, err
}
ok := pool.AppendCertsFromPEM(data)
if !ok {
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
}
}
conf.ClientCAs = pool
if tlsOption.ClientCA.Optional {
conf.ClientAuth = tls.VerifyClientCertIfGiven
} else {
conf.ClientAuth = tls.RequireAndVerifyClientCert
}
}
// Set the minimum TLS version if set in the config TOML
if minConst, exists := traefiktls.MinVersion[tlsOption.MinVersion]; exists {
conf.PreferServerCipherSuites = true
conf.MinVersion = minConst
}
// Set the list of CipherSuites if set in the config TOML
if tlsOption.CipherSuites != nil {
// if our list of CipherSuites is defined in the entryPoint config, we can re-initialize the suites list as empty
conf.CipherSuites = make([]uint16, 0)
for _, cipher := range tlsOption.CipherSuites {
if cipherConst, exists := traefiktls.CipherSuites[cipher]; exists {
conf.CipherSuites = append(conf.CipherSuites, cipherConst)
} else {
// CipherSuite listed in the toml does not exist in our listed
return nil, fmt.Errorf("invalid CipherSuite: %s", cipher)
}
}
}
return conf, nil
}

View file

@ -0,0 +1,385 @@
package server
import (
"context"
"fmt"
"net"
"net/http"
"sync"
"time"
"github.com/armon/go-proxyproto"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/h2c"
"github.com/containous/traefik/ip"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/forwardedheaders"
"github.com/containous/traefik/safe"
"github.com/containous/traefik/tcp"
)
type httpForwarder struct {
net.Listener
connChan chan net.Conn
}
func newHTTPForwarder(ln net.Listener) *httpForwarder {
return &httpForwarder{
Listener: ln,
connChan: make(chan net.Conn),
}
}
// ServeTCP uses the connection to serve it later in "Accept"
func (h *httpForwarder) ServeTCP(conn net.Conn) {
h.connChan <- conn
}
// Accept retrieves a served connection in ServeTCP
func (h *httpForwarder) Accept() (net.Conn, error) {
conn := <-h.connChan
return conn, nil
}
// TCPEntryPoints holds a map of TCPEntryPoint (the entrypoint names being the keys)
type TCPEntryPoints map[string]*TCPEntryPoint
// TCPEntryPoint is the TCP server
type TCPEntryPoint struct {
listener net.Listener
switcher *tcp.HandlerSwitcher
RouteAppenderFactory RouteAppenderFactory
transportConfiguration *static.EntryPointsTransport
tracker *connectionTracker
httpServer *httpServer
httpsServer *httpServer
}
// NewTCPEntryPoint creates a new TCPEntryPoint
func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*TCPEntryPoint, error) {
tracker := newConnectionTracker()
listener, err := buildListener(ctx, configuration)
if err != nil {
return nil, fmt.Errorf("error preparing server: %v", err)
}
router := &tcp.Router{}
httpServer, err := createHTTPServer(listener, configuration, true)
if err != nil {
return nil, fmt.Errorf("error preparing httpServer: %v", err)
}
router.HTTPForwarder(httpServer.Forwarder)
httpsServer, err := createHTTPServer(listener, configuration, false)
if err != nil {
return nil, fmt.Errorf("error preparing httpsServer: %v", err)
}
router.HTTPSForwarder(httpsServer.Forwarder)
tcpSwitcher := &tcp.HandlerSwitcher{}
tcpSwitcher.Switch(router)
return &TCPEntryPoint{
listener: listener,
switcher: tcpSwitcher,
transportConfiguration: configuration.Transport,
tracker: tracker,
httpServer: httpServer,
httpsServer: httpsServer,
}, nil
}
func (e *TCPEntryPoint) startTCP(ctx context.Context) {
log.FromContext(ctx).Debugf("Start TCP Server")
for {
conn, err := e.listener.Accept()
if err != nil {
log.Error(err)
return
}
safe.Go(func() {
e.switcher.ServeTCP(newTrackedConnection(conn, e.tracker))
})
}
}
// Shutdown stops the TCP connections
func (e *TCPEntryPoint) Shutdown(ctx context.Context) {
logger := log.FromContext(ctx)
reqAcceptGraceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.RequestAcceptGraceTimeout)
if reqAcceptGraceTimeOut > 0 {
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
time.Sleep(reqAcceptGraceTimeOut)
}
graceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(ctx, graceTimeOut)
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
var wg sync.WaitGroup
if e.httpServer.Server != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := e.httpServer.Server.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait server shutdown is overdue to: %s", err)
err = e.httpServer.Server.Close()
if err != nil {
logger.Error(err)
}
}
}
}()
}
if e.httpsServer.Server != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := e.httpsServer.Server.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait server shutdown is overdue to: %s", err)
err = e.httpsServer.Server.Close()
if err != nil {
logger.Error(err)
}
}
}
}()
}
if e.tracker != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := e.tracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait hijack connection is overdue to: %s", err)
e.tracker.Close()
}
}
}()
}
wg.Wait()
cancel()
}
func (e *TCPEntryPoint) switchRouter(router *tcp.Router) {
router.HTTPForwarder(e.httpServer.Forwarder)
router.HTTPSForwarder(e.httpsServer.Forwarder)
httpHandler := router.GetHTTPHandler()
httpsHandler := router.GetHTTPSHandler()
if httpHandler == nil {
httpHandler = buildDefaultHTTPRouter()
}
if httpsHandler == nil {
httpsHandler = buildDefaultHTTPRouter()
}
e.httpServer.Switcher.UpdateHandler(httpHandler)
e.httpsServer.Switcher.UpdateHandler(httpsHandler)
e.switcher.Switch(router)
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
// connections.
type tcpKeepAliveListener struct {
*net.TCPListener
}
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
tc, err := ln.AcceptTCP()
if err != nil {
return nil, err
}
if err = tc.SetKeepAlive(true); err != nil {
return nil, err
}
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
return nil, err
}
return tc, nil
}
func buildProxyProtocolListener(ctx context.Context, entryPoint *static.EntryPoint, listener net.Listener) (net.Listener, error) {
var sourceCheck func(addr net.Addr) (bool, error)
if entryPoint.ProxyProtocol.Insecure {
sourceCheck = func(_ net.Addr) (bool, error) {
return true, nil
}
} else {
checker, err := ip.NewChecker(entryPoint.ProxyProtocol.TrustedIPs)
if err != nil {
return nil, err
}
sourceCheck = func(addr net.Addr) (bool, error) {
ipAddr, ok := addr.(*net.TCPAddr)
if !ok {
return false, fmt.Errorf("type error %v", addr)
}
return checker.ContainsIP(ipAddr.IP), nil
}
}
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
return &proxyproto.Listener{
Listener: listener,
SourceCheck: sourceCheck,
}, nil
}
func buildListener(ctx context.Context, entryPoint *static.EntryPoint) (net.Listener, error) {
listener, err := net.Listen("tcp", entryPoint.Address)
if err != nil {
return nil, fmt.Errorf("error opening listener: %v", err)
}
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
if entryPoint.ProxyProtocol != nil {
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
if err != nil {
return nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
}
}
return listener, nil
}
func newConnectionTracker() *connectionTracker {
return &connectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type connectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddConnection add a connection in the tracked connections list
func (c *connectionTracker) AddConnection(conn net.Conn) {
c.lock.Lock()
defer c.lock.Unlock()
c.conns[conn] = struct{}{}
}
// RemoveConnection remove a connection from the tracked connections list
func (c *connectionTracker) RemoveConnection(conn net.Conn) {
c.lock.Lock()
defer c.lock.Unlock()
delete(c.conns, conn)
}
// Shutdown wait for the connection closing
func (c *connectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
c.lock.RLock()
if len(c.conns) == 0 {
return nil
}
c.lock.RUnlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list
func (c *connectionTracker) Close() {
for conn := range c.conns {
if err := conn.Close(); err != nil {
log.WithoutContext().Errorf("Error while closing connection: %v", err)
}
delete(c.conns, conn)
}
}
type stoppableServer interface {
Shutdown(context.Context) error
Close() error
Serve(listener net.Listener) error
}
type httpServer struct {
Server stoppableServer
Forwarder *httpForwarder
Switcher *middlewares.HTTPHandlerSwitcher
}
func createHTTPServer(ln net.Listener, configuration *static.EntryPoint, withH2c bool) (*httpServer, error) {
httpSwitcher := middlewares.NewHandlerSwitcher(buildDefaultHTTPRouter())
handler, err := forwardedheaders.NewXForwarded(
configuration.ForwardedHeaders.Insecure,
configuration.ForwardedHeaders.TrustedIPs,
httpSwitcher)
if err != nil {
return nil, err
}
var serverHTTP stoppableServer
if withH2c {
serverHTTP = &h2c.Server{
Server: &http.Server{
Handler: handler,
},
}
} else {
serverHTTP = &http.Server{
Handler: handler,
}
}
listener := newHTTPForwarder(ln)
go func() {
err := serverHTTP.Serve(listener)
if err != nil {
log.Errorf("Error while starting server: %v", err)
}
}()
return &httpServer{
Server: serverHTTP,
Forwarder: listener,
Switcher: httpSwitcher,
}, nil
}
func newTrackedConnection(conn net.Conn, tracker *connectionTracker) *trackedConnection {
tracker.AddConnection(conn)
return &trackedConnection{
Conn: conn,
tracker: tracker,
}
}
type trackedConnection struct {
tracker *connectionTracker
net.Conn
}
func (t *trackedConnection) Close() error {
t.tracker.RemoveConnection(t.Conn)
return t.Conn.Close()
}

View file

@ -0,0 +1,141 @@
package server
import (
"bufio"
"context"
"net"
"net/http"
"testing"
"time"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/tcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestShutdownHTTP(t *testing.T) {
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: &static.EntryPointsTransport{
LifeCycle: &static.LifeCycle{
RequestAcceptGraceTimeout: 0,
GraceTimeOut: 5000000000,
},
},
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
go entryPoint.startTCP(context.Background())
router := &tcp.Router{}
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
time.Sleep(time.Second * 1)
rw.WriteHeader(http.StatusOK)
}))
entryPoint.switchRouter(router)
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
require.NoError(t, err)
go entryPoint.Shutdown(context.Background())
request, err := http.NewRequest("GET", "http://127.0.0.1:8082", nil)
require.NoError(t, err)
err = request.Write(conn)
require.NoError(t, err)
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
require.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
}
func TestShutdownHTTPHijacked(t *testing.T) {
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: &static.EntryPointsTransport{
LifeCycle: &static.LifeCycle{
RequestAcceptGraceTimeout: 0,
GraceTimeOut: 5000000000,
},
},
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
go entryPoint.startTCP(context.Background())
router := &tcp.Router{}
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
conn, _, err := rw.(http.Hijacker).Hijack()
require.NoError(t, err)
time.Sleep(time.Second * 1)
resp := http.Response{StatusCode: http.StatusOK}
err = resp.Write(conn)
require.NoError(t, err)
}))
entryPoint.switchRouter(router)
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
require.NoError(t, err)
go entryPoint.Shutdown(context.Background())
request, err := http.NewRequest("GET", "http://127.0.0.1:8082", nil)
require.NoError(t, err)
err = request.Write(conn)
require.NoError(t, err)
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
require.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
}
func TestShutdownTCPConn(t *testing.T) {
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: &static.EntryPointsTransport{
LifeCycle: &static.LifeCycle{
RequestAcceptGraceTimeout: 0,
GraceTimeOut: 5000000000,
},
},
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
go entryPoint.startTCP(context.Background())
router := &tcp.Router{}
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn net.Conn) {
_, err := http.ReadRequest(bufio.NewReader(conn))
require.NoError(t, err)
time.Sleep(time.Second * 1)
resp := http.Response{StatusCode: http.StatusOK}
err = resp.Write(conn)
require.NoError(t, err)
}))
entryPoint.switchRouter(router)
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
require.NoError(t, err)
go entryPoint.Shutdown(context.Background())
request, err := http.NewRequest("GET", "http://127.0.0.1:8082", nil)
require.NoError(t, err)
err = request.Write(conn)
require.NoError(t, err)
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
require.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
}

View file

@ -1,6 +1,7 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"testing"
@ -59,8 +60,8 @@ func TestListenProvidersSkipsSameConfigurationForProvider(t *testing.T) {
}
}
}()
conf := th.BuildConfiguration(
conf := &config.Configuration{}
conf.HTTP = th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo")),
th.WithLoadBalancerServices(th.WithService("bar")),
)
@ -101,7 +102,8 @@ func TestListenProvidersPublishesConfigForEachProvider(t *testing.T) {
}
}()
conf := th.BuildConfiguration(
conf := &config.Configuration{}
conf.HTTP = th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo")),
th.WithLoadBalancerServices(th.WithService("bar")),
)
@ -134,7 +136,7 @@ func setupListenProvider(throttleDuration time.Duration) (server *Server, stop c
},
}
server = NewServer(staticConfiguration, nil, nil)
server = NewServer(staticConfiguration, nil, nil, nil)
go server.listenProviders(stop)
return server, stop, invokeStopChan
@ -146,12 +148,12 @@ func TestServerResponseEmptyBackend(t *testing.T) {
testCases := []struct {
desc string
config func(testServerURL string) *config.Configuration
config func(testServerURL string) *config.HTTPConfiguration
expectedStatusCode int
}{
{
desc: "Ok",
config: func(testServerURL string) *config.Configuration {
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
@ -168,14 +170,14 @@ func TestServerResponseEmptyBackend(t *testing.T) {
},
{
desc: "No Frontend",
config: func(testServerURL string) *config.Configuration {
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration()
},
expectedStatusCode: http.StatusNotFound,
},
{
desc: "Empty Backend LB-Drr",
config: func(testServerURL string) *config.Configuration {
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
@ -191,7 +193,7 @@ func TestServerResponseEmptyBackend(t *testing.T) {
},
{
desc: "Empty Backend LB-Drr Sticky",
config: func(testServerURL string) *config.Configuration {
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
@ -207,7 +209,7 @@ func TestServerResponseEmptyBackend(t *testing.T) {
},
{
desc: "Empty Backend LB-Wrr",
config: func(testServerURL string) *config.Configuration {
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
@ -223,7 +225,7 @@ func TestServerResponseEmptyBackend(t *testing.T) {
},
{
desc: "Empty Backend LB-Wrr Sticky",
config: func(testServerURL string) *config.Configuration {
config: func(testServerURL string) *config.HTTPConfiguration {
return th.BuildConfiguration(
th.WithRouters(th.WithRouter("foo",
th.WithEntryPoints("http"),
@ -251,13 +253,12 @@ func TestServerResponseEmptyBackend(t *testing.T) {
defer testServer.Close()
globalConfig := static.Configuration{}
entryPointsConfig := EntryPoints{
"http": &EntryPoint{},
entryPointsConfig := TCPEntryPoints{
"http": &TCPEntryPoint{},
}
dynamicConfigs := config.Configurations{"config": test.config(testServer.URL)}
srv := NewServer(globalConfig, nil, entryPointsConfig)
entryPoints, _ := srv.loadConfig(dynamicConfigs)
srv := NewServer(globalConfig, nil, entryPointsConfig, nil)
entryPoints, _ := srv.createHTTPHandlers(context.Background(), *test.config(testServer.URL), []string{"http"})
responseRecorder := &httptest.ResponseRecorder{}
request := httptest.NewRequest(http.MethodGet, testServer.URL+requestPath, nil)

100
server/service/proxy.go Normal file
View file

@ -0,0 +1,100 @@
package service
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/config"
"github.com/containous/traefik/log"
)
// StatusClientClosedRequest non-standard HTTP status code for client disconnection
const StatusClientClosedRequest = 499
// StatusClientClosedRequestText non-standard HTTP status for client disconnection
const StatusClientClosedRequestText = "Client Closed Request"
func buildProxy(passHostHeader bool, responseForwarding *config.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool, responseModifier func(*http.Response) error) (http.Handler, error) {
var flushInterval parse.Duration
if responseForwarding != nil {
err := flushInterval.Set(responseForwarding.FlushInterval)
if err != nil {
return nil, fmt.Errorf("error creating flush interval: %v", err)
}
}
if flushInterval == 0 {
flushInterval = parse.Duration(100 * time.Millisecond)
}
proxy := &httputil.ReverseProxy{
Director: func(outReq *http.Request) {
u := outReq.URL
if outReq.RequestURI != "" {
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
if err == nil {
u = parsedURL
}
}
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
outReq.URL.RawQuery = u.RawQuery
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.Proto = "HTTP/1.1"
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
// Do not pass client Host header unless optsetter PassHostHeader is set.
if !passHostHeader {
outReq.Host = outReq.URL.Host
}
},
Transport: defaultRoundTripper,
FlushInterval: time.Duration(flushInterval),
ModifyResponse: responseModifier,
BufferPool: bufferPool,
ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) {
statusCode := http.StatusInternalServerError
switch {
case err == io.EOF:
statusCode = http.StatusBadGateway
case err == context.Canceled:
statusCode = StatusClientClosedRequest
default:
if e, ok := err.(net.Error); ok {
if e.Timeout() {
statusCode = http.StatusGatewayTimeout
} else {
statusCode = http.StatusBadGateway
}
}
}
log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
w.WriteHeader(statusCode)
_, werr := w.Write([]byte(statusText(statusCode)))
if werr != nil {
log.Debugf("Error while writing status code", werr)
}
},
}
return proxy, nil
}
func statusText(statusCode int) string {
if statusCode == StatusClientClosedRequest {
return StatusClientClosedRequestText
}
return http.StatusText(statusCode)
}

View file

@ -0,0 +1,37 @@
package service
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/containous/traefik/testhelpers"
)
type staticTransport struct {
res *http.Response
}
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return t.res, nil
}
func BenchmarkProxy(b *testing.B) {
res := &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("")),
}
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
pool := newBufferPool()
handler, _ := buildProxy(false, nil, &staticTransport{res}, pool, nil)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
handler.ServeHTTP(w, req)
}
}

View file

@ -0,0 +1,713 @@
package service
import (
"bufio"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
gorillawebsocket "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
)
func TestWebSocketTCPClose(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
errChan := make(chan error, 1)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
_, _, err := c.ReadMessage()
if err != nil {
errChan <- err
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
proxyAddr := proxy.Listener.Addr().String()
_, conn, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
).open()
require.NoError(t, err)
conn.Close()
serverErr := <-errChan
wsErr, ok := serverErr.(*gorillawebsocket.CloseError)
assert.Equal(t, true, ok)
assert.Equal(t, 1006, wsErr.Code)
}
func TestWebSocketPingPong(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
var upgrader = gorillawebsocket.Upgrader{
HandshakeTimeout: 10 * time.Second,
CheckOrigin: func(*http.Request) bool {
return true
},
}
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
ws, err := upgrader.Upgrade(writer, request, nil)
require.NoError(t, err)
ws.SetPingHandler(func(appData string) error {
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
require.NoError(t, err)
return nil
})
_, _, _ = ws.ReadMessage()
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
defer conn.Close()
goodErr := fmt.Errorf("signal: %s", "Good data")
badErr := fmt.Errorf("signal: %s", "Bad data")
conn.SetPongHandler(func(data string) error {
if data == "PingPong" {
return goodErr
}
return badErr
})
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
require.NoError(t, err)
_, _, err = conn.ReadMessage()
if err != goodErr {
require.NoError(t, err)
}
}
func TestWebSocketEcho(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
msg := make([]byte, 4)
_, err := conn.Read(msg)
require.NoError(t, err)
fmt.Println(string(msg))
_, err = conn.Write(msg)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
fmt.Println(conn.ReadMessage())
err = conn.Close()
require.NoError(t, err)
}
func TestWebSocketPassHost(t *testing.T) {
testCases := []struct {
desc string
passHost bool
expected string
}{
{
desc: "PassHost false",
passHost: false,
},
{
desc: "PassHost true",
passHost: true,
expected: "example.com",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
f, err := buildProxy(test.passHost, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
req := conn.Request()
if test.passHost {
require.Equal(t, test.expected, req.Host)
} else {
require.NotEqual(t, test.expected, req.Host)
}
msg := make([]byte, 4)
_, err = conn.Read(msg)
require.NoError(t, err)
fmt.Println(string(msg))
_, err = conn.Write(msg)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
headers.Add("Host", "example.com")
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
fmt.Println(conn.ReadMessage())
err = conn.Close()
require.NoError(t, err)
})
}
}
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
return true
}}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
withOrigin("http://127.0.0.2"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithOrigin(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
_, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
withOrigin("http://127.0.0.2"),
).send()
require.EqualError(t, err, "bad status")
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithQueryParams(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "test", r.URL.Query().Get("query"))
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws?query=test"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
conn.Close()
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
defer conn.Close()
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
}
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/%3A%2F%2F"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketUpgradeFailed(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(400)
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
if path == "/ws" {
// Set new backend URL
req.URL = parseURI(t, srv.URL)
req.URL.Path = path
f.ServeHTTP(w, req)
} else {
w.WriteHeader(200)
}
}))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
require.NoError(t, err)
defer conn.Close()
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
require.NoError(t, err)
req.Header.Add("upgrade", "websocket")
req.Header.Add("Connection", "upgrade")
err = req.Write(conn)
require.NoError(t, err)
// First request works with 400
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, req)
require.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode)
}
func TestForwardsWebsocketTraffic(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
_, err := conn.Write([]byte("ok"))
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
func TestWebSocketTransferTLSConfig(t *testing.T) {
srv := createTLSWebsocketServer()
defer srv.Close()
forwarderWithoutTLSConfig, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL)
defer proxyWithoutTLSConfig.Close()
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
_, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.EqualError(t, err, "bad status")
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
forwarderWithTLSConfig, err := buildProxy(true, nil, transport, nil, nil)
require.NoError(t, err)
proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
resp, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
const dialTimeout = time.Second
type websocketRequestOpt func(w *websocketRequest)
func withServer(server string) websocketRequestOpt {
return func(w *websocketRequest) {
w.ServerAddr = server
}
}
func withPath(path string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Path = path
}
}
func withData(data string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Data = data
}
}
func withOrigin(origin string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Origin = origin
}
}
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
wsrequest := &websocketRequest{}
for _, opt := range opts {
opt(wsrequest)
}
if wsrequest.Origin == "" {
wsrequest.Origin = "http://" + wsrequest.ServerAddr
}
if wsrequest.Config == nil {
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
}
return wsrequest
}
type websocketRequest struct {
ServerAddr string
Path string
Data string
Origin string
Config *websocket.Config
}
func (w *websocketRequest) send() (string, error) {
conn, _, err := w.open()
if err != nil {
return "", err
}
defer conn.Close()
if _, err := conn.Write([]byte(w.Data)); err != nil {
return "", err
}
var msg = make([]byte, 512)
var n int
n, err = conn.Read(msg)
if err != nil {
return "", err
}
received := string(msg[:n])
return received, nil
}
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
if err != nil {
return nil, nil, err
}
conn, err := websocket.NewClient(w.Config, client)
if err != nil {
return nil, nil, err
}
return conn, client, err
}
func parseURI(t *testing.T, uri string) *url.URL {
out, err := url.ParseRequestURI(uri)
require.NoError(t, err)
return out
}
func createProxyWithForwarder(t *testing.T, proxy http.Handler, url string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = parseURI(t, url)
req.URL.Path = path
proxy.ServeHTTP(w, req)
}))
}

View file

@ -3,14 +3,12 @@ package service
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/containous/alice"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/config"
"github.com/containous/traefik/healthcheck"
"github.com/containous/traefik/log"
@ -19,7 +17,6 @@ import (
"github.com/containous/traefik/old/middlewares/pipelining"
"github.com/containous/traefik/server/cookie"
"github.com/containous/traefik/server/internal"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/roundrobin"
)
@ -46,8 +43,8 @@ type Manager struct {
configs map[string]*config.Service
}
// Build Creates a http.Handler for a service configuration.
func (m *Manager) Build(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
// BuildHTTP Creates a http.Handler for a service configuration.
func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
serviceName = internal.GetQualifiedName(ctx, serviceName)
@ -55,6 +52,7 @@ func (m *Manager) Build(rootCtx context.Context, serviceName string, responseMod
if conf, ok := m.configs[serviceName]; ok {
// TODO Should handle multiple service types
// FIXME Check if the service is declared multiple times with different types
if conf.LoadBalancer != nil {
return m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier)
}
@ -69,8 +67,7 @@ func (m *Manager) getLoadBalancerServiceHandler(
service *config.LoadBalancerService,
responseModifier func(*http.Response) error,
) (http.Handler, error) {
fwd, err := m.buildForwarder(service.PassHostHeader, service.ResponseForwarding, responseModifier)
fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool, responseModifier)
if err != nil {
return nil, err
}
@ -258,32 +255,3 @@ func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHand
}
return nil
}
func (m *Manager) buildForwarder(passHostHeader bool, responseForwarding *config.ResponseForwarding, responseModifier func(*http.Response) error) (http.Handler, error) {
var flushInterval parse.Duration
if responseForwarding != nil {
err := flushInterval.Set(responseForwarding.FlushInterval)
if err != nil {
return nil, fmt.Errorf("error creating flush interval: %v", err)
}
}
return forward.New(
forward.Stream(true),
forward.PassHostHeader(passHostHeader),
forward.RoundTripper(m.defaultRoundTripper),
forward.ResponseModifier(responseModifier),
forward.BufferPool(m.bufferPool),
forward.StreamingFlushInterval(time.Duration(flushInterval)),
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
server := req.Context().Value(http.ServerContextKey).(*http.Server)
if server != nil {
connState := server.ConnState
if connState != nil {
connState(conn, http.StateClosed)
}
}
}),
)
}

View file

@ -320,7 +320,7 @@ func TestManager_Build(t *testing.T) {
ctx = internal.AddProviderInContext(ctx, test.providerName+".foobar")
}
_, err := manager.Build(ctx, test.serviceName, nil)
_, err := manager.BuildHTTP(ctx, test.serviceName, nil)
require.NoError(t, err)
})
}

View file

@ -0,0 +1,67 @@
package tcp
import (
"context"
"fmt"
"net"
"github.com/containous/traefik/config"
"github.com/containous/traefik/log"
"github.com/containous/traefik/server/internal"
"github.com/containous/traefik/tcp"
)
// Manager is the TCPHandlers factory
type Manager struct {
configs map[string]*config.TCPService
}
// NewManager creates a new manager
func NewManager(configs map[string]*config.TCPService) *Manager {
return &Manager{
configs: configs,
}
}
// BuildTCP Creates a tcp.Handler for a service configuration.
func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Handler, error) {
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
serviceName = internal.GetQualifiedName(ctx, serviceName)
ctx = internal.AddProviderInContext(ctx, serviceName)
if conf, ok := m.configs[serviceName]; ok {
// FIXME Check if the service is declared multiple times with different types
if conf.LoadBalancer != nil {
loadBalancer := tcp.NewRRLoadBalancer()
var handler tcp.Handler
for _, server := range conf.LoadBalancer.Servers {
_, err := parseIP(server.Address)
if err == nil {
handler, _ = tcp.NewProxy(server.Address)
loadBalancer.AddServer(handler)
} else {
log.FromContext(ctx).Errorf("Invalid IP address for a %s server %s: %v", serviceName, server.Address, err)
}
}
return loadBalancer, nil
}
return nil, fmt.Errorf("the service %q doesn't have any TCP load balancer", serviceName)
}
return nil, fmt.Errorf("the service %q does not exits", serviceName)
}
func parseIP(s string) (string, error) {
ip, _, err := net.SplitHostPort(s)
if err == nil {
return ip, nil
}
ipNoPort := net.ParseIP(s)
if ipNoPort == nil {
return "", fmt.Errorf("invalid IP Address %s", ipNoPort)
}
return ipNoPort.String(), nil
}