Add a new protocol
Co-authored-by: Gérald Croës <gerald@containo.us>
This commit is contained in:
parent
0ca2149408
commit
4a68d29ce2
231 changed files with 6895 additions and 4395 deletions
|
@ -3,26 +3,54 @@ package server
|
|||
import (
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/server/internal"
|
||||
"github.com/containous/traefik/tls"
|
||||
)
|
||||
|
||||
func mergeConfiguration(configurations config.Configurations) config.Configuration {
|
||||
conf := config.Configuration{
|
||||
Routers: make(map[string]*config.Router),
|
||||
Middlewares: make(map[string]*config.Middleware),
|
||||
Services: make(map[string]*config.Service),
|
||||
HTTP: &config.HTTPConfiguration{
|
||||
Routers: make(map[string]*config.Router),
|
||||
Middlewares: make(map[string]*config.Middleware),
|
||||
Services: make(map[string]*config.Service),
|
||||
},
|
||||
TCP: &config.TCPConfiguration{
|
||||
Routers: make(map[string]*config.TCPRouter),
|
||||
Services: make(map[string]*config.TCPService),
|
||||
},
|
||||
TLSOptions: make(map[string]tls.TLS),
|
||||
TLSStores: make(map[string]tls.Store),
|
||||
}
|
||||
|
||||
for provider, configuration := range configurations {
|
||||
for routerName, router := range configuration.Routers {
|
||||
conf.Routers[internal.MakeQualifiedName(provider, routerName)] = router
|
||||
if configuration.HTTP != nil {
|
||||
for routerName, router := range configuration.HTTP.Routers {
|
||||
conf.HTTP.Routers[internal.MakeQualifiedName(provider, routerName)] = router
|
||||
}
|
||||
for middlewareName, middleware := range configuration.HTTP.Middlewares {
|
||||
conf.HTTP.Middlewares[internal.MakeQualifiedName(provider, middlewareName)] = middleware
|
||||
}
|
||||
for serviceName, service := range configuration.HTTP.Services {
|
||||
conf.HTTP.Services[internal.MakeQualifiedName(provider, serviceName)] = service
|
||||
}
|
||||
}
|
||||
for middlewareName, middleware := range configuration.Middlewares {
|
||||
conf.Middlewares[internal.MakeQualifiedName(provider, middlewareName)] = middleware
|
||||
}
|
||||
for serviceName, service := range configuration.Services {
|
||||
conf.Services[internal.MakeQualifiedName(provider, serviceName)] = service
|
||||
|
||||
if configuration.TCP != nil {
|
||||
for routerName, router := range configuration.TCP.Routers {
|
||||
conf.TCP.Routers[internal.MakeQualifiedName(provider, routerName)] = router
|
||||
}
|
||||
for serviceName, service := range configuration.TCP.Services {
|
||||
conf.TCP.Services[internal.MakeQualifiedName(provider, serviceName)] = service
|
||||
}
|
||||
}
|
||||
conf.TLS = append(conf.TLS, configuration.TLS...)
|
||||
|
||||
for key, store := range configuration.TLSStores {
|
||||
conf.TLSStores[key] = store
|
||||
}
|
||||
|
||||
for key, config := range configuration.TLSOptions {
|
||||
conf.TLSOptions[key] = config
|
||||
}
|
||||
}
|
||||
|
||||
return conf
|
||||
|
|
|
@ -11,12 +11,12 @@ func TestAggregator(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
given config.Configurations
|
||||
expected config.Configuration
|
||||
expected *config.HTTPConfiguration
|
||||
}{
|
||||
{
|
||||
desc: "Nil returns an empty configuration",
|
||||
given: nil,
|
||||
expected: config.Configuration{
|
||||
expected: &config.HTTPConfiguration{
|
||||
Routers: make(map[string]*config.Router),
|
||||
Middlewares: make(map[string]*config.Middleware),
|
||||
Services: make(map[string]*config.Service),
|
||||
|
@ -26,18 +26,20 @@ func TestAggregator(t *testing.T) {
|
|||
desc: "Returns fully qualified elements from a mono-provider configuration map",
|
||||
given: config.Configurations{
|
||||
"provider-1": &config.Configuration{
|
||||
Routers: map[string]*config.Router{
|
||||
"router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"service-1": {},
|
||||
HTTP: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"service-1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: config.Configuration{
|
||||
expected: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"provider-1.router-1": {},
|
||||
},
|
||||
|
@ -53,29 +55,33 @@ func TestAggregator(t *testing.T) {
|
|||
desc: "Returns fully qualified elements from a multi-provider configuration map",
|
||||
given: config.Configurations{
|
||||
"provider-1": &config.Configuration{
|
||||
Routers: map[string]*config.Router{
|
||||
"router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"service-1": {},
|
||||
HTTP: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"service-1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider-2": &config.Configuration{
|
||||
Routers: map[string]*config.Router{
|
||||
"router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"service-1": {},
|
||||
HTTP: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"router-1": {},
|
||||
},
|
||||
Middlewares: map[string]*config.Middleware{
|
||||
"middleware-1": {},
|
||||
},
|
||||
Services: map[string]*config.Service{
|
||||
"service-1": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: config.Configuration{
|
||||
expected: &config.HTTPConfiguration{
|
||||
Routers: map[string]*config.Router{
|
||||
"provider-1.router-1": {},
|
||||
"provider-2.router-1": {},
|
||||
|
@ -96,8 +102,9 @@ func TestAggregator(t *testing.T) {
|
|||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := mergeConfiguration(test.given)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
assert.Equal(t, test.expected, actual.HTTP)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ type Builder struct {
|
|||
}
|
||||
|
||||
type serviceBuilder interface {
|
||||
Build(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
||||
BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
||||
}
|
||||
|
||||
// NewBuilder creates a new Builder
|
||||
|
@ -57,9 +57,9 @@ func (b *Builder) BuildChain(ctx context.Context, middlewares []string) *alice.C
|
|||
chain := alice.New()
|
||||
for _, name := range middlewares {
|
||||
middlewareName := internal.GetQualifiedName(ctx, name)
|
||||
constructorContext := internal.AddProviderInContext(ctx, middlewareName)
|
||||
|
||||
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
||||
constructorContext := internal.AddProviderInContext(ctx, middlewareName)
|
||||
if _, ok := b.configs[middlewareName]; !ok {
|
||||
return nil, fmt.Errorf("middleware %q does not exist", middlewareName)
|
||||
}
|
||||
|
|
|
@ -45,8 +45,8 @@ type Manager struct {
|
|||
}
|
||||
|
||||
// BuildHandlers Builds handler for all entry points
|
||||
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string) map[string]http.Handler {
|
||||
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints)
|
||||
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string, tls bool) map[string]http.Handler {
|
||||
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints, tls)
|
||||
|
||||
entryPointHandlers := make(map[string]http.Handler)
|
||||
for entryPointName, routers := range entryPointsRouters {
|
||||
|
@ -84,10 +84,14 @@ func contains(entryPoints []string, entryPointName string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string) map[string]map[string]*config.Router {
|
||||
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string, tls bool) map[string]map[string]*config.Router {
|
||||
entryPointsRouters := make(map[string]map[string]*config.Router)
|
||||
|
||||
for rtName, rt := range m.configs {
|
||||
if (tls && rt.TLS == nil) || (!tls && rt.TLS != nil) {
|
||||
continue
|
||||
}
|
||||
|
||||
eps := rt.EntryPoints
|
||||
if len(eps) == 0 {
|
||||
eps = entryPoints
|
||||
|
@ -155,7 +159,7 @@ func (m *Manager) buildRouterHandler(ctx context.Context, routerName string) (ht
|
|||
return nil, fmt.Errorf("no configuration for %s", routerName)
|
||||
}
|
||||
|
||||
handler, err := m.buildHandler(ctx, configRouter, routerName)
|
||||
handler, err := m.buildHTTPHandler(ctx, configRouter, routerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -173,10 +177,10 @@ func (m *Manager) buildRouterHandler(ctx context.Context, routerName string) (ht
|
|||
return m.routerHandlers[routerName], nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildHandler(ctx context.Context, router *config.Router, routerName string) (http.Handler, error) {
|
||||
func (m *Manager) buildHTTPHandler(ctx context.Context, router *config.Router, routerName string) (http.Handler, error) {
|
||||
rm := m.modifierBuilder.Build(ctx, router.Middlewares)
|
||||
|
||||
sHandler, err := m.serviceManager.Build(ctx, router.Service, rm)
|
||||
sHandler, err := m.serviceManager.BuildHTTP(ctx, router.Service, rm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -2,8 +2,10 @@ package router
|
|||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
|
@ -318,7 +320,7 @@ func TestRouterManager_Get(t *testing.T) {
|
|||
|
||||
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints)
|
||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
@ -417,7 +419,7 @@ func TestAccessLog(t *testing.T) {
|
|||
|
||||
routerManager := NewManager(test.routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints)
|
||||
handlers := routerManager.BuildHandlers(context.Background(), test.entryPoints, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
@ -440,3 +442,90 @@ func TestAccessLog(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
type staticTransport struct {
|
||||
res *http.Response
|
||||
}
|
||||
|
||||
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return t.res, nil
|
||||
}
|
||||
|
||||
func BenchmarkRouterServe(b *testing.B) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
res := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
routersConfig := map[string]*config.Router{
|
||||
"foo": {
|
||||
EntryPoints: []string{"web"},
|
||||
Service: "foo-service",
|
||||
Rule: "Host(`foo.bar`) && Path(`/`)",
|
||||
},
|
||||
}
|
||||
serviceConfig := map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: server.URL,
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
}
|
||||
entryPoints := []string{"web"}
|
||||
|
||||
serviceManager := service.NewManager(serviceConfig, &staticTransport{res})
|
||||
middlewaresBuilder := middleware.NewBuilder(map[string]*config.Middleware{}, serviceManager)
|
||||
responseModifierFactory := responsemodifiers.NewBuilder(map[string]*config.Middleware{})
|
||||
|
||||
routerManager := NewManager(routersConfig, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlers := routerManager.BuildHandlers(context.Background(), entryPoints, false)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
reqHost := requestdecorator.New(nil)
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
reqHost.ServeHTTP(w, req, handlers["web"].ServeHTTP)
|
||||
}
|
||||
|
||||
}
|
||||
func BenchmarkService(b *testing.B) {
|
||||
res := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
|
||||
serviceConfig := map[string]*config.Service{
|
||||
"foo-service": {
|
||||
LoadBalancer: &config.LoadBalancerService{
|
||||
Servers: []config.Server{
|
||||
{
|
||||
URL: "tchouck",
|
||||
Weight: 1,
|
||||
},
|
||||
},
|
||||
Method: "wrr",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
serviceManager := service.NewManager(serviceConfig, &staticTransport{res})
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
handler, _ := serviceManager.BuildHTTP(context.Background(), "foo-service", nil)
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
139
server/router/tcp/router.go
Normal file
139
server/router/tcp/router.go
Normal file
|
@ -0,0 +1,139 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/rules"
|
||||
"github.com/containous/traefik/server/internal"
|
||||
tcpservice "github.com/containous/traefik/server/service/tcp"
|
||||
"github.com/containous/traefik/tcp"
|
||||
)
|
||||
|
||||
// NewManager Creates a new Manager
|
||||
func NewManager(routers map[string]*config.TCPRouter,
|
||||
serviceManager *tcpservice.Manager,
|
||||
httpHandlers map[string]http.Handler,
|
||||
httpsHandlers map[string]http.Handler,
|
||||
tlsConfig *tls.Config,
|
||||
) *Manager {
|
||||
return &Manager{
|
||||
configs: routers,
|
||||
serviceManager: serviceManager,
|
||||
httpHandlers: httpHandlers,
|
||||
httpsHandlers: httpsHandlers,
|
||||
tlsConfig: tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// Manager is a route/router manager
|
||||
type Manager struct {
|
||||
configs map[string]*config.TCPRouter
|
||||
serviceManager *tcpservice.Manager
|
||||
httpHandlers map[string]http.Handler
|
||||
httpsHandlers map[string]http.Handler
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// BuildHandlers builds the handlers for the given entrypoints
|
||||
func (m *Manager) BuildHandlers(rootCtx context.Context, entryPoints []string) map[string]*tcp.Router {
|
||||
entryPointsRouters := m.filteredRouters(rootCtx, entryPoints)
|
||||
|
||||
entryPointHandlers := make(map[string]*tcp.Router)
|
||||
for _, entryPointName := range entryPoints {
|
||||
entryPointName := entryPointName
|
||||
|
||||
routers := entryPointsRouters[entryPointName]
|
||||
|
||||
ctx := log.With(rootCtx, log.Str(log.EntryPointName, entryPointName))
|
||||
|
||||
handler, err := m.buildEntryPointHandler(ctx, routers, m.httpHandlers[entryPointName], m.httpsHandlers[entryPointName])
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
continue
|
||||
}
|
||||
entryPointHandlers[entryPointName] = handler
|
||||
}
|
||||
return entryPointHandlers
|
||||
}
|
||||
|
||||
func (m *Manager) buildEntryPointHandler(ctx context.Context, configs map[string]*config.TCPRouter, handlerHTTP http.Handler, handlerHTTPS http.Handler) (*tcp.Router, error) {
|
||||
router := &tcp.Router{}
|
||||
|
||||
router.HTTPHandler(handlerHTTP)
|
||||
router.HTTPSHandler(handlerHTTPS, m.tlsConfig)
|
||||
for routerName, routerConfig := range configs {
|
||||
ctxRouter := log.With(ctx, log.Str(log.RouterName, routerName))
|
||||
logger := log.FromContext(ctxRouter)
|
||||
|
||||
ctxRouter = internal.AddProviderInContext(ctxRouter, routerName)
|
||||
|
||||
handler, err := m.serviceManager.BuildTCP(ctxRouter, routerConfig.Service)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
continue
|
||||
}
|
||||
|
||||
domains, err := rules.ParseHostSNI(routerConfig.Rule)
|
||||
if err != nil {
|
||||
log.WithoutContext().Debugf("Unknown rule %s", routerConfig.Rule)
|
||||
continue
|
||||
}
|
||||
for _, domain := range domains {
|
||||
log.WithoutContext().Debugf("Add route %s on TCP", domain)
|
||||
switch {
|
||||
case routerConfig.TLS != nil:
|
||||
if routerConfig.TLS.Passthrough {
|
||||
router.AddRoute(domain, handler)
|
||||
} else {
|
||||
router.AddRouteTLS(domain, handler, m.tlsConfig)
|
||||
|
||||
}
|
||||
case domain == "*":
|
||||
router.AddCatchAllNoTLS(handler)
|
||||
default:
|
||||
logger.Warn("TCP Router ignored, cannot specify a Host rule without TLS")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return router, nil
|
||||
}
|
||||
|
||||
func contains(entryPoints []string, entryPointName string) bool {
|
||||
for _, name := range entryPoints {
|
||||
if name == entryPointName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) filteredRouters(ctx context.Context, entryPoints []string) map[string]map[string]*config.TCPRouter {
|
||||
entryPointsRouters := make(map[string]map[string]*config.TCPRouter)
|
||||
|
||||
for rtName, rt := range m.configs {
|
||||
eps := rt.EntryPoints
|
||||
if len(eps) == 0 {
|
||||
eps = entryPoints
|
||||
}
|
||||
for _, entryPointName := range eps {
|
||||
if !contains(entryPoints, entryPointName) {
|
||||
log.FromContext(log.With(ctx, log.Str(log.EntryPointName, entryPointName))).
|
||||
Errorf("entryPoint %q doesn't exist", entryPointName)
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := entryPointsRouters[entryPointName]; !ok {
|
||||
entryPointsRouters[entryPointName] = make(map[string]*config.TCPRouter)
|
||||
}
|
||||
|
||||
entryPointsRouters[entryPointName][rtName] = rt
|
||||
}
|
||||
}
|
||||
|
||||
return entryPointsRouters
|
||||
}
|
|
@ -9,7 +9,6 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/cluster"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/log"
|
||||
|
@ -19,6 +18,7 @@ import (
|
|||
"github.com/containous/traefik/provider"
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/containous/traefik/server/middleware"
|
||||
"github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/containous/traefik/tracing/datadog"
|
||||
"github.com/containous/traefik/tracing/instana"
|
||||
|
@ -29,7 +29,7 @@ import (
|
|||
|
||||
// Server is the reverse-proxy/load-balancer engine
|
||||
type Server struct {
|
||||
entryPoints EntryPoints
|
||||
entryPointsTCP TCPEntryPoints
|
||||
configurationChan chan config.Message
|
||||
configurationValidatedChan chan config.Message
|
||||
signals chan os.Signal
|
||||
|
@ -39,13 +39,13 @@ type Server struct {
|
|||
accessLoggerMiddleware *accesslog.Handler
|
||||
tracer *tracing.Tracing
|
||||
routinesPool *safe.Pool
|
||||
leadership *cluster.Leadership //FIXME Cluster
|
||||
defaultRoundTripper http.RoundTripper
|
||||
metricsRegistry metrics.Registry
|
||||
provider provider.Provider
|
||||
configurationListeners []func(config.Configuration)
|
||||
requestDecorator *requestdecorator.RequestDecorator
|
||||
providersThrottleDuration time.Duration
|
||||
tlsManager *tls.Manager
|
||||
}
|
||||
|
||||
// RouteAppenderFactory the route appender factory interface
|
||||
|
@ -70,11 +70,11 @@ func setupTracing(conf *static.Tracing) tracing.TrackingBackend {
|
|||
}
|
||||
|
||||
// NewServer returns an initialized Server.
|
||||
func NewServer(staticConfiguration static.Configuration, provider provider.Provider, entryPoints EntryPoints) *Server {
|
||||
func NewServer(staticConfiguration static.Configuration, provider provider.Provider, entryPoints TCPEntryPoints, tlsManager *tls.Manager) *Server {
|
||||
server := &Server{}
|
||||
|
||||
server.provider = provider
|
||||
server.entryPoints = entryPoints
|
||||
server.entryPointsTCP = entryPoints
|
||||
server.configurationChan = make(chan config.Message, 100)
|
||||
server.configurationValidatedChan = make(chan config.Message, 100)
|
||||
server.signals = make(chan os.Signal, 1)
|
||||
|
@ -83,6 +83,7 @@ func NewServer(staticConfiguration static.Configuration, provider provider.Provi
|
|||
currentConfigurations := make(config.Configurations)
|
||||
server.currentConfigurations.Set(currentConfigurations)
|
||||
server.providerConfigUpdateMap = make(map[string]chan config.Message)
|
||||
server.tlsManager = tlsManager
|
||||
|
||||
if staticConfiguration.Providers != nil {
|
||||
server.providersThrottleDuration = time.Duration(staticConfiguration.Providers.ProvidersThrottleDuration)
|
||||
|
@ -132,8 +133,7 @@ func (s *Server) Start(ctx context.Context) {
|
|||
s.Stop()
|
||||
}()
|
||||
|
||||
s.startHTTPServers()
|
||||
s.startLeadership()
|
||||
s.startTCPServers()
|
||||
s.routinesPool.Go(func(stop chan bool) {
|
||||
s.listenProviders(stop)
|
||||
})
|
||||
|
@ -156,9 +156,9 @@ func (s *Server) Stop() {
|
|||
defer log.WithoutContext().Info("Server stopped")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for epn, ep := range s.entryPoints {
|
||||
for epn, ep := range s.entryPointsTCP {
|
||||
wg.Add(1)
|
||||
go func(entryPointName string, entryPoint *EntryPoint) {
|
||||
go func(entryPointName string, entryPoint *TCPEntryPoint) {
|
||||
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
|
||||
defer wg.Done()
|
||||
|
||||
|
@ -184,7 +184,6 @@ func (s *Server) Close() {
|
|||
}(ctx)
|
||||
|
||||
stopMetricsClients()
|
||||
s.stopLeadership()
|
||||
s.routinesPool.Cleanup()
|
||||
close(s.configurationChan)
|
||||
close(s.configurationValidatedChan)
|
||||
|
@ -205,28 +204,16 @@ func (s *Server) Close() {
|
|||
cancel()
|
||||
}
|
||||
|
||||
func (s *Server) startLeadership() {
|
||||
if s.leadership != nil {
|
||||
s.leadership.Participate(s.routinesPool)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) stopLeadership() {
|
||||
if s.leadership != nil {
|
||||
s.leadership.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) startHTTPServers() {
|
||||
func (s *Server) startTCPServers() {
|
||||
// Use an empty configuration in order to initialize the default handlers with internal routes
|
||||
handlers := s.applyConfiguration(context.Background(), config.Configuration{})
|
||||
for entryPointName, handler := range handlers {
|
||||
s.entryPoints[entryPointName].switcher.UpdateHandler(handler)
|
||||
routers := s.loadConfigurationTCP(config.Configurations{})
|
||||
for entryPointName, router := range routers {
|
||||
s.entryPointsTCP[entryPointName].switchRouter(router)
|
||||
}
|
||||
|
||||
for entryPointName, entryPoint := range s.entryPoints {
|
||||
for entryPointName, serverEntryPoint := range s.entryPointsTCP {
|
||||
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
|
||||
go entryPoint.Start(ctx)
|
||||
go serverEntryPoint.startTCP(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"time"
|
||||
|
@ -19,16 +18,16 @@ import (
|
|||
"github.com/containous/traefik/responsemodifiers"
|
||||
"github.com/containous/traefik/server/middleware"
|
||||
"github.com/containous/traefik/server/router"
|
||||
routertcp "github.com/containous/traefik/server/router/tcp"
|
||||
"github.com/containous/traefik/server/service"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/server/service/tcp"
|
||||
tcpCore "github.com/containous/traefik/tcp"
|
||||
"github.com/eapache/channels"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// loadConfiguration manages dynamically routers, middlewares, servers and TLS configurations
|
||||
func (s *Server) loadConfiguration(configMsg config.Message) {
|
||||
logger := log.FromContext(log.With(context.Background(), log.Str(log.ProviderName, configMsg.ProviderName)))
|
||||
|
||||
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
|
||||
|
||||
// Copy configurations to new map so we don't change current if LoadConfig fails
|
||||
|
@ -40,27 +39,13 @@ func (s *Server) loadConfiguration(configMsg config.Message) {
|
|||
|
||||
s.metricsRegistry.ConfigReloadsCounter().Add(1)
|
||||
|
||||
handlers, certificates := s.loadConfig(newConfigurations)
|
||||
handlersTCP := s.loadConfigurationTCP(newConfigurations)
|
||||
for entryPointName, router := range handlersTCP {
|
||||
s.entryPointsTCP[entryPointName].switchRouter(router)
|
||||
}
|
||||
|
||||
s.metricsRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
|
||||
|
||||
for entryPointName, handler := range handlers {
|
||||
s.entryPoints[entryPointName].switcher.UpdateHandler(handler)
|
||||
}
|
||||
|
||||
for entryPointName, entryPoint := range s.entryPoints {
|
||||
eLogger := logger.WithField(log.EntryPointName, entryPointName)
|
||||
if entryPoint.Certs == nil {
|
||||
if len(certificates[entryPointName]) > 0 {
|
||||
eLogger.Debugf("Cannot configure certificates for the non-TLS %s entryPoint.", entryPointName)
|
||||
}
|
||||
} else {
|
||||
entryPoint.Certs.DynamicCerts.Set(certificates[entryPointName])
|
||||
entryPoint.Certs.ResetCache()
|
||||
}
|
||||
eLogger.Infof("Server configuration reloaded on %s", s.entryPoints[entryPointName].httpServer.Addr)
|
||||
}
|
||||
|
||||
s.currentConfigurations.Set(newConfigurations)
|
||||
|
||||
for _, listener := range s.configurationListeners {
|
||||
|
@ -70,35 +55,48 @@ func (s *Server) loadConfiguration(configMsg config.Message) {
|
|||
s.postLoadConfiguration()
|
||||
}
|
||||
|
||||
// loadConfig returns a new gorilla.mux Route from the specified global configuration and the dynamic
|
||||
// loadConfigurationTCP returns a new gorilla.mux Route from the specified global configuration and the dynamic
|
||||
// provider configurations.
|
||||
func (s *Server) loadConfig(configurations config.Configurations) (map[string]http.Handler, map[string]map[string]*tls.Certificate) {
|
||||
|
||||
func (s *Server) loadConfigurationTCP(configurations config.Configurations) map[string]*tcpCore.Router {
|
||||
ctx := context.TODO()
|
||||
|
||||
conf := mergeConfiguration(configurations)
|
||||
handlers := s.applyConfiguration(ctx, conf)
|
||||
|
||||
// Get new certificates list sorted per entry points
|
||||
// Update certificates
|
||||
entryPointsCertificates := s.loadHTTPSConfiguration(configurations)
|
||||
|
||||
return handlers, entryPointsCertificates
|
||||
}
|
||||
|
||||
func (s *Server) applyConfiguration(ctx context.Context, configuration config.Configuration) map[string]http.Handler {
|
||||
var entryPoints []string
|
||||
for entryPointName := range s.entryPoints {
|
||||
for entryPointName := range s.entryPointsTCP {
|
||||
entryPoints = append(entryPoints, entryPointName)
|
||||
}
|
||||
|
||||
conf := mergeConfiguration(configurations)
|
||||
|
||||
s.tlsManager.UpdateConfigs(conf.TLSStores, conf.TLSOptions, conf.TLS)
|
||||
|
||||
handlersNonTLS, handlersTLS := s.createHTTPHandlers(ctx, *conf.HTTP, entryPoints)
|
||||
|
||||
routersTCP := s.createTCPRouters(ctx, conf.TCP, entryPoints, handlersNonTLS, handlersTLS, s.tlsManager.Get("default", "default"))
|
||||
|
||||
return routersTCP
|
||||
}
|
||||
|
||||
func (s *Server) createTCPRouters(ctx context.Context, configuration *config.TCPConfiguration, entryPoints []string, handlers map[string]http.Handler, handlersTLS map[string]http.Handler, tlsConfig *tls.Config) map[string]*tcpCore.Router {
|
||||
if configuration == nil {
|
||||
return make(map[string]*tcpCore.Router)
|
||||
}
|
||||
|
||||
serviceManager := tcp.NewManager(configuration.Services)
|
||||
routerManager := routertcp.NewManager(configuration.Routers, serviceManager, handlers, handlersTLS, tlsConfig)
|
||||
|
||||
return routerManager.BuildHandlers(ctx, entryPoints)
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) createHTTPHandlers(ctx context.Context, configuration config.HTTPConfiguration, entryPoints []string) (map[string]http.Handler, map[string]http.Handler) {
|
||||
serviceManager := service.NewManager(configuration.Services, s.defaultRoundTripper)
|
||||
middlewaresBuilder := middleware.NewBuilder(configuration.Middlewares, serviceManager)
|
||||
responseModifierFactory := responsemodifiers.NewBuilder(configuration.Middlewares)
|
||||
|
||||
routerManager := router.NewManager(configuration.Routers, serviceManager, middlewaresBuilder, responseModifierFactory)
|
||||
|
||||
handlers := routerManager.BuildHandlers(ctx, entryPoints)
|
||||
handlersNonTLS := routerManager.BuildHandlers(ctx, entryPoints, false)
|
||||
handlersTLS := routerManager.BuildHandlers(ctx, entryPoints, true)
|
||||
|
||||
routerHandlers := make(map[string]http.Handler)
|
||||
|
||||
|
@ -108,14 +106,14 @@ func (s *Server) applyConfiguration(ctx context.Context, configuration config.Co
|
|||
|
||||
ctx = log.With(ctx, log.Str(log.EntryPointName, entryPointName))
|
||||
|
||||
factory := s.entryPoints[entryPointName].RouteAppenderFactory
|
||||
factory := s.entryPointsTCP[entryPointName].RouteAppenderFactory
|
||||
if factory != nil {
|
||||
// FIXME remove currentConfigurations
|
||||
appender := factory.NewAppender(ctx, middlewaresBuilder, &s.currentConfigurations)
|
||||
appender.Append(internalMuxRouter)
|
||||
}
|
||||
|
||||
if h, ok := handlers[entryPointName]; ok {
|
||||
if h, ok := handlersNonTLS[entryPointName]; ok {
|
||||
internalMuxRouter.NotFoundHandler = h
|
||||
} else {
|
||||
internalMuxRouter.NotFoundHandler = buildDefaultHTTPRouter()
|
||||
|
@ -141,13 +139,42 @@ func (s *Server) applyConfiguration(ctx context.Context, configuration config.Co
|
|||
continue
|
||||
}
|
||||
internalMuxRouter.NotFoundHandler = handler
|
||||
|
||||
handlerTLS, ok := handlersTLS[entryPointName]
|
||||
if ok {
|
||||
handlerTLSWithMiddlewares, err := chain.Then(handlerTLS)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
continue
|
||||
}
|
||||
handlersTLS[entryPointName] = handlerTLSWithMiddlewares
|
||||
}
|
||||
}
|
||||
|
||||
return routerHandlers
|
||||
return routerHandlers, handlersTLS
|
||||
}
|
||||
|
||||
func isEmptyConfiguration(conf *config.Configuration) bool {
|
||||
if conf == nil {
|
||||
return true
|
||||
}
|
||||
if conf.TCP == nil {
|
||||
conf.TCP = &config.TCPConfiguration{}
|
||||
}
|
||||
if conf.HTTP == nil {
|
||||
conf.HTTP = &config.HTTPConfiguration{}
|
||||
}
|
||||
|
||||
return conf.HTTP.Routers == nil &&
|
||||
conf.HTTP.Services == nil &&
|
||||
conf.HTTP.Middlewares == nil &&
|
||||
conf.TLS == nil &&
|
||||
conf.TCP.Routers == nil &&
|
||||
conf.TCP.Services == nil
|
||||
}
|
||||
|
||||
func (s *Server) preLoadConfiguration(configMsg config.Message) {
|
||||
s.defaultConfigurationValues(configMsg.Configuration)
|
||||
s.defaultConfigurationValues(configMsg.Configuration.HTTP)
|
||||
currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
|
||||
|
||||
logger := log.WithoutContext().WithField(log.ProviderName, configMsg.ProviderName)
|
||||
|
@ -156,7 +183,7 @@ func (s *Server) preLoadConfiguration(configMsg config.Message) {
|
|||
logger.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
|
||||
}
|
||||
|
||||
if configMsg.Configuration == nil || configMsg.Configuration.Routers == nil && configMsg.Configuration.Services == nil && configMsg.Configuration.Middlewares == nil && configMsg.Configuration.TLS == nil {
|
||||
if isEmptyConfiguration(configMsg.Configuration) {
|
||||
logger.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
|
||||
return
|
||||
}
|
||||
|
@ -178,7 +205,7 @@ func (s *Server) preLoadConfiguration(configMsg config.Message) {
|
|||
providerConfigUpdateCh <- configMsg
|
||||
}
|
||||
|
||||
func (s *Server) defaultConfigurationValues(configuration *config.Configuration) {
|
||||
func (s *Server) defaultConfigurationValues(configuration *config.HTTPConfiguration) {
|
||||
// FIXME create a config hook
|
||||
}
|
||||
|
||||
|
@ -235,59 +262,6 @@ func (s *Server) postLoadConfiguration() {
|
|||
// metrics.OnConfigurationUpdate(activeConfig)
|
||||
// }
|
||||
|
||||
// FIXME acme
|
||||
// if s.staticConfiguration.ACME == nil || s.leadership == nil || !s.leadership.IsLeader() {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// if s.staticConfiguration.ACME.OnHostRule {
|
||||
// currentConfigurations := s.currentConfigurations.Get().(config.Configurations)
|
||||
// for _, config := range currentConfigurations {
|
||||
// for _, frontend := range config.Frontends {
|
||||
//
|
||||
// // check if one of the frontend entrypoints is configured with TLS
|
||||
// // and is configured with ACME
|
||||
// acmeEnabled := false
|
||||
// for _, entryPoint := range frontend.EntryPoints {
|
||||
// if s.staticConfiguration.ACME.EntryPoint == entryPoint && s.entryPoints[entryPoint].Configuration.TLS != nil {
|
||||
// acmeEnabled = true
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// if acmeEnabled {
|
||||
// for _, route := range frontend.Routes {
|
||||
// rls := rules.Rules{}
|
||||
// domains, err := rls.ParseDomains(route.Rule)
|
||||
// if err != nil {
|
||||
// log.Errorf("Error parsing domains: %v", err)
|
||||
// } else if len(domains) == 0 {
|
||||
// log.Debugf("No domain parsed in rule %q", route.Rule)
|
||||
// } else {
|
||||
// s.staticConfiguration.ACME.LoadCertificateForDomains(domains)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically
|
||||
func (s *Server) loadHTTPSConfiguration(configurations config.Configurations) map[string]map[string]*tls.Certificate {
|
||||
var entryPoints []string
|
||||
for entryPointName := range s.entryPoints {
|
||||
entryPoints = append(entryPoints, entryPointName)
|
||||
}
|
||||
|
||||
newEPCertificates := make(map[string]map[string]*tls.Certificate)
|
||||
// Get all certificates
|
||||
for _, config := range configurations {
|
||||
if config.TLS != nil && len(config.TLS) > 0 {
|
||||
traefiktls.SortTLSPerEntryPoints(config.TLS, newEPCertificates, entryPoints)
|
||||
}
|
||||
}
|
||||
return newEPCertificates
|
||||
}
|
||||
|
||||
func buildDefaultHTTPRouter() *mux.Router {
|
||||
|
@ -296,21 +270,3 @@ func buildDefaultHTTPRouter() *mux.Router {
|
|||
rt.SkipClean(true)
|
||||
return rt
|
||||
}
|
||||
|
||||
func buildDefaultCertificate(defaultCertificate *traefiktls.Certificate) (*tls.Certificate, error) {
|
||||
certFile, err := defaultCertificate.CertFile.Read()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get cert file content: %v", err)
|
||||
}
|
||||
|
||||
keyFile, err := defaultCertificate.KeyFile.Read()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get key file content: %v", err)
|
||||
}
|
||||
|
||||
cert, err := tls.X509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load X509 key pair: %v", err)
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
@ -9,115 +10,44 @@ import (
|
|||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/config/static"
|
||||
th "github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/tls"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
|
||||
// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
|
||||
// generated from src/crypto/tls:
|
||||
// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
|
||||
var (
|
||||
localhostCert = tls.FileOrContent(`-----BEGIN CERTIFICATE-----
|
||||
MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
|
||||
MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
|
||||
MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
|
||||
iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
|
||||
iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
|
||||
rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
|
||||
BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
|
||||
AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
|
||||
AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
|
||||
tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
|
||||
h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
|
||||
fblo6RBxUQ==
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
// LocalhostKey is the private key for localhostCert.
|
||||
localhostKey = tls.FileOrContent(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
|
||||
SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
|
||||
l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
|
||||
AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
|
||||
3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
|
||||
uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
|
||||
qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
|
||||
jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
|
||||
fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
|
||||
fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
|
||||
y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
|
||||
qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
|
||||
f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
)
|
||||
|
||||
func TestServerLoadCertificateWithTLSEntryPoints(t *testing.T) {
|
||||
staticConfig := static.Configuration{}
|
||||
|
||||
dynamicConfigs := config.Configurations{
|
||||
"config": &config.Configuration{
|
||||
TLS: []*tls.Configuration{
|
||||
{
|
||||
Certificate: &tls.Certificate{
|
||||
CertFile: localhostCert,
|
||||
KeyFile: localhostKey,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
srv := NewServer(staticConfig, nil, EntryPoints{
|
||||
"https": &EntryPoint{
|
||||
Certs: tls.NewCertificateStore(),
|
||||
},
|
||||
"https2": &EntryPoint{
|
||||
Certs: tls.NewCertificateStore(),
|
||||
},
|
||||
})
|
||||
_, mapsCerts := srv.loadConfig(dynamicConfigs)
|
||||
if len(mapsCerts["https"]) == 0 || len(mapsCerts["https2"]) == 0 {
|
||||
t.Fatal("got error: https entryPoint must have TLS certificates.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReuseService(t *testing.T) {
|
||||
testServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer testServer.Close()
|
||||
|
||||
entryPoints := EntryPoints{
|
||||
"http": &EntryPoint{},
|
||||
entryPoints := TCPEntryPoints{
|
||||
"http": &TCPEntryPoint{},
|
||||
}
|
||||
|
||||
staticConfig := static.Configuration{}
|
||||
|
||||
dynamicConfigs := config.Configurations{
|
||||
"config": th.BuildConfiguration(
|
||||
th.WithRouters(
|
||||
th.WithRouter("foo",
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule("Path(`/ok`)")),
|
||||
th.WithRouter("foo2",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRule("Path(`/unauthorized`)"),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRouterMiddlewares("basicauth")),
|
||||
),
|
||||
th.WithMiddlewares(th.WithMiddleware("basicauth",
|
||||
th.WithBasicAuth(&config.BasicAuth{Users: []string{"foo:bar"}}),
|
||||
)),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr"),
|
||||
th.WithServers(th.WithServer(testServer.URL))),
|
||||
),
|
||||
dynamicConfigs := th.BuildConfiguration(
|
||||
th.WithRouters(
|
||||
th.WithRouter("foo",
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRule("Path(`/ok`)")),
|
||||
th.WithRouter("foo2",
|
||||
th.WithEntryPoints("http"),
|
||||
th.WithRule("Path(`/unauthorized`)"),
|
||||
th.WithServiceName("bar"),
|
||||
th.WithRouterMiddlewares("basicauth")),
|
||||
),
|
||||
}
|
||||
th.WithMiddlewares(th.WithMiddleware("basicauth",
|
||||
th.WithBasicAuth(&config.BasicAuth{Users: []string{"foo:bar"}}),
|
||||
)),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr"),
|
||||
th.WithServers(th.WithServer(testServer.URL))),
|
||||
),
|
||||
)
|
||||
|
||||
srv := NewServer(staticConfig, nil, entryPoints)
|
||||
srv := NewServer(staticConfig, nil, entryPoints, nil)
|
||||
|
||||
entrypointsHandlers, _ := srv.loadConfig(dynamicConfigs)
|
||||
entrypointsHandlers, _ := srv.createHTTPHandlers(context.Background(), *dynamicConfigs, []string{"http"})
|
||||
|
||||
// Test that the /ok path returns a status 200.
|
||||
responseRecorderOk := &httptest.ResponseRecorder{}
|
||||
|
@ -145,7 +75,7 @@ func TestThrottleProviderConfigReload(t *testing.T) {
|
|||
}()
|
||||
|
||||
staticConfiguration := static.Configuration{}
|
||||
server := NewServer(staticConfiguration, nil, nil)
|
||||
server := NewServer(staticConfiguration, nil, nil, nil)
|
||||
|
||||
go server.throttleProviderConfigReload(throttleDuration, publishConfig, providerConfig, stop)
|
||||
|
||||
|
|
|
@ -1,434 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
stdlog "log"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-proxyproto"
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/h2c"
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/forwardedheaders"
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/tls/generate"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/xenolf/lego/challenge/tlsalpn01"
|
||||
)
|
||||
|
||||
// EntryPoints map of EntryPoint
|
||||
type EntryPoints map[string]*EntryPoint
|
||||
|
||||
// NewEntryPoint creates a new EntryPoint
|
||||
func NewEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*EntryPoint, error) {
|
||||
var err error
|
||||
|
||||
switcher := middlewares.NewHandlerSwitcher(buildDefaultHTTPRouter())
|
||||
handler, err := forwardedheaders.NewXForwarded(
|
||||
configuration.ForwardedHeaders.Insecure,
|
||||
configuration.ForwardedHeaders.TrustedIPs,
|
||||
switcher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tracker := newHijackConnectionTracker()
|
||||
|
||||
listener, err := buildListener(ctx, configuration)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error preparing server: %v", err)
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
var certificateStore *traefiktls.CertificateStore
|
||||
if configuration.TLS != nil {
|
||||
certificateStore, err = buildCertificateStore(*configuration.TLS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating certificate store: %v", err)
|
||||
}
|
||||
|
||||
tlsConfig, err = buildTLSConfig(*configuration.TLS)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating TLS config: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
entryPoint := &EntryPoint{
|
||||
switcher: switcher,
|
||||
transportConfiguration: configuration.Transport,
|
||||
hijackConnectionTracker: tracker,
|
||||
listener: listener,
|
||||
httpServer: buildServer(ctx, configuration, tlsConfig, handler, tracker),
|
||||
Certs: certificateStore,
|
||||
}
|
||||
|
||||
if tlsConfig != nil {
|
||||
tlsConfig.GetCertificate = entryPoint.getCertificate
|
||||
}
|
||||
|
||||
return entryPoint, nil
|
||||
}
|
||||
|
||||
// EntryPoint holds everything about the entry point (httpServer, listener etc...)
|
||||
type EntryPoint struct {
|
||||
RouteAppenderFactory RouteAppenderFactory
|
||||
httpServer *h2c.Server
|
||||
listener net.Listener
|
||||
switcher *middlewares.HandlerSwitcher
|
||||
Certs *traefiktls.CertificateStore
|
||||
OnDemandListener func(string) (*tls.Certificate, error)
|
||||
TLSALPNGetter func(string) (*tls.Certificate, error)
|
||||
hijackConnectionTracker *hijackConnectionTracker
|
||||
transportConfiguration *static.EntryPointsTransport
|
||||
}
|
||||
|
||||
// Start starts listening for traffic
|
||||
func (s *EntryPoint) Start(ctx context.Context) {
|
||||
logger := log.FromContext(ctx)
|
||||
logger.Infof("Starting server on %s", s.httpServer.Addr)
|
||||
|
||||
var err error
|
||||
if s.httpServer.TLSConfig != nil {
|
||||
err = s.httpServer.ServeTLS(s.listener, "", "")
|
||||
} else {
|
||||
err = s.httpServer.Serve(s.listener)
|
||||
}
|
||||
|
||||
if err != http.ErrServerClosed {
|
||||
logger.Error("Cannot start server: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown handles the entrypoint shutdown process
|
||||
func (s EntryPoint) Shutdown(ctx context.Context) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
reqAcceptGraceTimeOut := time.Duration(s.transportConfiguration.LifeCycle.RequestAcceptGraceTimeout)
|
||||
if reqAcceptGraceTimeOut > 0 {
|
||||
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
|
||||
time.Sleep(reqAcceptGraceTimeOut)
|
||||
}
|
||||
|
||||
graceTimeOut := time.Duration(s.transportConfiguration.LifeCycle.GraceTimeOut)
|
||||
ctx, cancel := context.WithTimeout(ctx, graceTimeOut)
|
||||
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
if s.httpServer != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.httpServer.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.Debugf("Wait server shutdown is overdue to: %s", err)
|
||||
err = s.httpServer.Close()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if s.hijackConnectionTracker != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.Debugf("Wait hijack connection is overdue to: %s", err)
|
||||
s.hijackConnectionTracker.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
cancel()
|
||||
}
|
||||
|
||||
// getCertificate allows to customize tlsConfig.GetCertificate behavior to get the certificates inserted dynamically
|
||||
func (s *EntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
domainToCheck := types.CanonicalDomain(clientHello.ServerName)
|
||||
|
||||
if s.TLSALPNGetter != nil {
|
||||
cert, err := s.TLSALPNGetter(domainToCheck)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cert != nil {
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
bestCertificate := s.Certs.GetBestCertificate(clientHello)
|
||||
if bestCertificate != nil {
|
||||
return bestCertificate, nil
|
||||
}
|
||||
|
||||
if s.OnDemandListener != nil && len(domainToCheck) > 0 {
|
||||
// Only check for an onDemandCert if there is a domain name
|
||||
return s.OnDemandListener(domainToCheck)
|
||||
}
|
||||
|
||||
if s.Certs.SniStrict {
|
||||
return nil, fmt.Errorf("strict SNI enabled - No certificate found for domain: %q, closing connection", domainToCheck)
|
||||
}
|
||||
|
||||
log.WithoutContext().Debugf("Serving default certificate for request: %q", domainToCheck)
|
||||
return s.Certs.DefaultCertificate, nil
|
||||
}
|
||||
|
||||
func newHijackConnectionTracker() *hijackConnectionTracker {
|
||||
return &hijackConnectionTracker{
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
type hijackConnectionTracker struct {
|
||||
conns map[net.Conn]struct{}
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// AddHijackedConnection add a connection in the tracked connections list
|
||||
func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
h.conns[conn] = struct{}{}
|
||||
}
|
||||
|
||||
// RemoveHijackedConnection remove a connection from the tracked connections list
|
||||
func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
delete(h.conns, conn)
|
||||
}
|
||||
|
||||
// Shutdown wait for the connection closing
|
||||
func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
h.lock.RLock()
|
||||
if len(h.conns) == 0 {
|
||||
return nil
|
||||
}
|
||||
h.lock.RUnlock()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close close all the connections in the tracked connections list
|
||||
func (h *hijackConnectionTracker) Close() {
|
||||
for conn := range h.conns {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.WithoutContext().Errorf("Error while closing Hijacked connection: %v", err)
|
||||
}
|
||||
delete(h.conns, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections.
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = tc.SetKeepAlive(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
func buildProxyProtocolListener(ctx context.Context, entryPoint *static.EntryPoint, listener net.Listener) (net.Listener, error) {
|
||||
var sourceCheck func(addr net.Addr) (bool, error)
|
||||
if entryPoint.ProxyProtocol.Insecure {
|
||||
sourceCheck = func(_ net.Addr) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
} else {
|
||||
checker, err := ip.NewChecker(entryPoint.ProxyProtocol.TrustedIPs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sourceCheck = func(addr net.Addr) (bool, error) {
|
||||
ipAddr, ok := addr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("type error %v", addr)
|
||||
}
|
||||
|
||||
return checker.ContainsIP(ipAddr.IP), nil
|
||||
}
|
||||
}
|
||||
|
||||
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
|
||||
|
||||
return &proxyproto.Listener{
|
||||
Listener: listener,
|
||||
SourceCheck: sourceCheck,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildServerTimeouts(entryPointsTransport static.EntryPointsTransport) (readTimeout, writeTimeout, idleTimeout time.Duration) {
|
||||
readTimeout = time.Duration(0)
|
||||
writeTimeout = time.Duration(0)
|
||||
if entryPointsTransport.RespondingTimeouts != nil {
|
||||
readTimeout = time.Duration(entryPointsTransport.RespondingTimeouts.ReadTimeout)
|
||||
writeTimeout = time.Duration(entryPointsTransport.RespondingTimeouts.WriteTimeout)
|
||||
}
|
||||
|
||||
if entryPointsTransport.RespondingTimeouts != nil {
|
||||
idleTimeout = time.Duration(entryPointsTransport.RespondingTimeouts.IdleTimeout)
|
||||
} else {
|
||||
idleTimeout = configuration.DefaultIdleTimeout
|
||||
}
|
||||
|
||||
return readTimeout, writeTimeout, idleTimeout
|
||||
}
|
||||
|
||||
func buildListener(ctx context.Context, entryPoint *static.EntryPoint) (net.Listener, error) {
|
||||
listener, err := net.Listen("tcp", entryPoint.Address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening listener: %v", err)
|
||||
}
|
||||
|
||||
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
|
||||
|
||||
if entryPoint.ProxyProtocol != nil {
|
||||
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
|
||||
}
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func buildCertificateStore(tlsOption traefiktls.TLS) (*traefiktls.CertificateStore, error) {
|
||||
certificateStore := traefiktls.NewCertificateStore()
|
||||
certificateStore.DynamicCerts.Set(make(map[string]*tls.Certificate))
|
||||
|
||||
certificateStore.SniStrict = tlsOption.SniStrict
|
||||
|
||||
if tlsOption.DefaultCertificate != nil {
|
||||
cert, err := buildDefaultCertificate(tlsOption.DefaultCertificate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certificateStore.DefaultCertificate = cert
|
||||
} else {
|
||||
cert, err := generate.DefaultCertificate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
certificateStore.DefaultCertificate = cert
|
||||
}
|
||||
return certificateStore, nil
|
||||
}
|
||||
|
||||
func buildServer(ctx context.Context, configuration *static.EntryPoint, tlsConfig *tls.Config, router http.Handler, tracker *hijackConnectionTracker) *h2c.Server {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
readTimeout, writeTimeout, idleTimeout := buildServerTimeouts(*configuration.Transport)
|
||||
logger.
|
||||
WithField("readTimeout", readTimeout).
|
||||
WithField("writeTimeout", writeTimeout).
|
||||
WithField("idleTimeout", idleTimeout).
|
||||
Infof("Preparing server")
|
||||
|
||||
return &h2c.Server{
|
||||
Server: &http.Server{
|
||||
Addr: configuration.Address,
|
||||
Handler: router,
|
||||
TLSConfig: tlsConfig,
|
||||
ReadTimeout: readTimeout,
|
||||
WriteTimeout: writeTimeout,
|
||||
IdleTimeout: idleTimeout,
|
||||
ErrorLog: stdlog.New(logger.WriterLevel(logrus.DebugLevel), "", 0),
|
||||
ConnState: func(conn net.Conn, state http.ConnState) {
|
||||
switch state {
|
||||
case http.StateHijacked:
|
||||
tracker.AddHijackedConnection(conn)
|
||||
case http.StateClosed:
|
||||
tracker.RemoveHijackedConnection(conn)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// creates a TLS config that allows terminating HTTPS for multiple domains using SNI
|
||||
func buildTLSConfig(tlsOption traefiktls.TLS) (*tls.Config, error) {
|
||||
conf := &tls.Config{}
|
||||
|
||||
// ensure http2 enabled
|
||||
conf.NextProtos = []string{"h2", "http/1.1", tlsalpn01.ACMETLS1Protocol}
|
||||
|
||||
if len(tlsOption.ClientCA.Files) > 0 {
|
||||
pool := x509.NewCertPool()
|
||||
for _, caFile := range tlsOption.ClientCA.Files {
|
||||
data, err := caFile.Read()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ok := pool.AppendCertsFromPEM(data)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid certificate(s) in %s", caFile)
|
||||
}
|
||||
}
|
||||
conf.ClientCAs = pool
|
||||
if tlsOption.ClientCA.Optional {
|
||||
conf.ClientAuth = tls.VerifyClientCertIfGiven
|
||||
} else {
|
||||
conf.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
}
|
||||
|
||||
// Set the minimum TLS version if set in the config TOML
|
||||
if minConst, exists := traefiktls.MinVersion[tlsOption.MinVersion]; exists {
|
||||
conf.PreferServerCipherSuites = true
|
||||
conf.MinVersion = minConst
|
||||
}
|
||||
|
||||
// Set the list of CipherSuites if set in the config TOML
|
||||
if tlsOption.CipherSuites != nil {
|
||||
// if our list of CipherSuites is defined in the entryPoint config, we can re-initialize the suites list as empty
|
||||
conf.CipherSuites = make([]uint16, 0)
|
||||
for _, cipher := range tlsOption.CipherSuites {
|
||||
if cipherConst, exists := traefiktls.CipherSuites[cipher]; exists {
|
||||
conf.CipherSuites = append(conf.CipherSuites, cipherConst)
|
||||
} else {
|
||||
// CipherSuite listed in the toml does not exist in our listed
|
||||
return nil, fmt.Errorf("invalid CipherSuite: %s", cipher)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return conf, nil
|
||||
}
|
385
server/server_entrypoint_tcp.go
Normal file
385
server/server_entrypoint_tcp.go
Normal file
|
@ -0,0 +1,385 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-proxyproto"
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/h2c"
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/forwardedheaders"
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/containous/traefik/tcp"
|
||||
)
|
||||
|
||||
type httpForwarder struct {
|
||||
net.Listener
|
||||
connChan chan net.Conn
|
||||
}
|
||||
|
||||
func newHTTPForwarder(ln net.Listener) *httpForwarder {
|
||||
return &httpForwarder{
|
||||
Listener: ln,
|
||||
connChan: make(chan net.Conn),
|
||||
}
|
||||
}
|
||||
|
||||
// ServeTCP uses the connection to serve it later in "Accept"
|
||||
func (h *httpForwarder) ServeTCP(conn net.Conn) {
|
||||
h.connChan <- conn
|
||||
}
|
||||
|
||||
// Accept retrieves a served connection in ServeTCP
|
||||
func (h *httpForwarder) Accept() (net.Conn, error) {
|
||||
conn := <-h.connChan
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// TCPEntryPoints holds a map of TCPEntryPoint (the entrypoint names being the keys)
|
||||
type TCPEntryPoints map[string]*TCPEntryPoint
|
||||
|
||||
// TCPEntryPoint is the TCP server
|
||||
type TCPEntryPoint struct {
|
||||
listener net.Listener
|
||||
switcher *tcp.HandlerSwitcher
|
||||
RouteAppenderFactory RouteAppenderFactory
|
||||
transportConfiguration *static.EntryPointsTransport
|
||||
tracker *connectionTracker
|
||||
httpServer *httpServer
|
||||
httpsServer *httpServer
|
||||
}
|
||||
|
||||
// NewTCPEntryPoint creates a new TCPEntryPoint
|
||||
func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*TCPEntryPoint, error) {
|
||||
tracker := newConnectionTracker()
|
||||
|
||||
listener, err := buildListener(ctx, configuration)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error preparing server: %v", err)
|
||||
}
|
||||
|
||||
router := &tcp.Router{}
|
||||
|
||||
httpServer, err := createHTTPServer(listener, configuration, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error preparing httpServer: %v", err)
|
||||
}
|
||||
|
||||
router.HTTPForwarder(httpServer.Forwarder)
|
||||
|
||||
httpsServer, err := createHTTPServer(listener, configuration, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error preparing httpsServer: %v", err)
|
||||
}
|
||||
|
||||
router.HTTPSForwarder(httpsServer.Forwarder)
|
||||
|
||||
tcpSwitcher := &tcp.HandlerSwitcher{}
|
||||
tcpSwitcher.Switch(router)
|
||||
|
||||
return &TCPEntryPoint{
|
||||
listener: listener,
|
||||
switcher: tcpSwitcher,
|
||||
transportConfiguration: configuration.Transport,
|
||||
tracker: tracker,
|
||||
httpServer: httpServer,
|
||||
httpsServer: httpsServer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *TCPEntryPoint) startTCP(ctx context.Context) {
|
||||
log.FromContext(ctx).Debugf("Start TCP Server")
|
||||
|
||||
for {
|
||||
conn, err := e.listener.Accept()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
safe.Go(func() {
|
||||
e.switcher.ServeTCP(newTrackedConnection(conn, e.tracker))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown stops the TCP connections
|
||||
func (e *TCPEntryPoint) Shutdown(ctx context.Context) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
reqAcceptGraceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.RequestAcceptGraceTimeout)
|
||||
if reqAcceptGraceTimeOut > 0 {
|
||||
logger.Infof("Waiting %s for incoming requests to cease", reqAcceptGraceTimeOut)
|
||||
time.Sleep(reqAcceptGraceTimeOut)
|
||||
}
|
||||
|
||||
graceTimeOut := time.Duration(e.transportConfiguration.LifeCycle.GraceTimeOut)
|
||||
ctx, cancel := context.WithTimeout(ctx, graceTimeOut)
|
||||
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
if e.httpServer.Server != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := e.httpServer.Server.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.Debugf("Wait server shutdown is overdue to: %s", err)
|
||||
err = e.httpServer.Server.Close()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if e.httpsServer.Server != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := e.httpsServer.Server.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.Debugf("Wait server shutdown is overdue to: %s", err)
|
||||
err = e.httpsServer.Server.Close()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if e.tracker != nil {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := e.tracker.Shutdown(ctx); err != nil {
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
logger.Debugf("Wait hijack connection is overdue to: %s", err)
|
||||
e.tracker.Close()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
cancel()
|
||||
}
|
||||
|
||||
func (e *TCPEntryPoint) switchRouter(router *tcp.Router) {
|
||||
router.HTTPForwarder(e.httpServer.Forwarder)
|
||||
router.HTTPSForwarder(e.httpsServer.Forwarder)
|
||||
|
||||
httpHandler := router.GetHTTPHandler()
|
||||
httpsHandler := router.GetHTTPSHandler()
|
||||
if httpHandler == nil {
|
||||
httpHandler = buildDefaultHTTPRouter()
|
||||
}
|
||||
if httpsHandler == nil {
|
||||
httpsHandler = buildDefaultHTTPRouter()
|
||||
}
|
||||
|
||||
e.httpServer.Switcher.UpdateHandler(httpHandler)
|
||||
e.httpsServer.Switcher.UpdateHandler(httpsHandler)
|
||||
e.switcher.Switch(router)
|
||||
}
|
||||
|
||||
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
|
||||
// connections.
|
||||
type tcpKeepAliveListener struct {
|
||||
*net.TCPListener
|
||||
}
|
||||
|
||||
func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
|
||||
tc, err := ln.AcceptTCP()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = tc.SetKeepAlive(true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = tc.SetKeepAlivePeriod(3 * time.Minute); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
func buildProxyProtocolListener(ctx context.Context, entryPoint *static.EntryPoint, listener net.Listener) (net.Listener, error) {
|
||||
var sourceCheck func(addr net.Addr) (bool, error)
|
||||
if entryPoint.ProxyProtocol.Insecure {
|
||||
sourceCheck = func(_ net.Addr) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
} else {
|
||||
checker, err := ip.NewChecker(entryPoint.ProxyProtocol.TrustedIPs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sourceCheck = func(addr net.Addr) (bool, error) {
|
||||
ipAddr, ok := addr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return false, fmt.Errorf("type error %v", addr)
|
||||
}
|
||||
|
||||
return checker.ContainsIP(ipAddr.IP), nil
|
||||
}
|
||||
}
|
||||
|
||||
log.FromContext(ctx).Infof("Enabling ProxyProtocol for trusted IPs %v", entryPoint.ProxyProtocol.TrustedIPs)
|
||||
|
||||
return &proxyproto.Listener{
|
||||
Listener: listener,
|
||||
SourceCheck: sourceCheck,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildListener(ctx context.Context, entryPoint *static.EntryPoint) (net.Listener, error) {
|
||||
listener, err := net.Listen("tcp", entryPoint.Address)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening listener: %v", err)
|
||||
}
|
||||
|
||||
listener = tcpKeepAliveListener{listener.(*net.TCPListener)}
|
||||
|
||||
if entryPoint.ProxyProtocol != nil {
|
||||
listener, err = buildProxyProtocolListener(ctx, entryPoint, listener)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating proxy protocol listener: %v", err)
|
||||
}
|
||||
}
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
func newConnectionTracker() *connectionTracker {
|
||||
return &connectionTracker{
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
type connectionTracker struct {
|
||||
conns map[net.Conn]struct{}
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// AddConnection add a connection in the tracked connections list
|
||||
func (c *connectionTracker) AddConnection(conn net.Conn) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
c.conns[conn] = struct{}{}
|
||||
}
|
||||
|
||||
// RemoveConnection remove a connection from the tracked connections list
|
||||
func (c *connectionTracker) RemoveConnection(conn net.Conn) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
delete(c.conns, conn)
|
||||
}
|
||||
|
||||
// Shutdown wait for the connection closing
|
||||
func (c *connectionTracker) Shutdown(ctx context.Context) error {
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
c.lock.RLock()
|
||||
if len(c.conns) == 0 {
|
||||
return nil
|
||||
}
|
||||
c.lock.RUnlock()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close close all the connections in the tracked connections list
|
||||
func (c *connectionTracker) Close() {
|
||||
for conn := range c.conns {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.WithoutContext().Errorf("Error while closing connection: %v", err)
|
||||
}
|
||||
delete(c.conns, conn)
|
||||
}
|
||||
}
|
||||
|
||||
type stoppableServer interface {
|
||||
Shutdown(context.Context) error
|
||||
Close() error
|
||||
Serve(listener net.Listener) error
|
||||
}
|
||||
|
||||
type httpServer struct {
|
||||
Server stoppableServer
|
||||
Forwarder *httpForwarder
|
||||
Switcher *middlewares.HTTPHandlerSwitcher
|
||||
}
|
||||
|
||||
func createHTTPServer(ln net.Listener, configuration *static.EntryPoint, withH2c bool) (*httpServer, error) {
|
||||
httpSwitcher := middlewares.NewHandlerSwitcher(buildDefaultHTTPRouter())
|
||||
handler, err := forwardedheaders.NewXForwarded(
|
||||
configuration.ForwardedHeaders.Insecure,
|
||||
configuration.ForwardedHeaders.TrustedIPs,
|
||||
httpSwitcher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var serverHTTP stoppableServer
|
||||
|
||||
if withH2c {
|
||||
serverHTTP = &h2c.Server{
|
||||
Server: &http.Server{
|
||||
Handler: handler,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
serverHTTP = &http.Server{
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
listener := newHTTPForwarder(ln)
|
||||
go func() {
|
||||
err := serverHTTP.Serve(listener)
|
||||
if err != nil {
|
||||
log.Errorf("Error while starting server: %v", err)
|
||||
}
|
||||
}()
|
||||
return &httpServer{
|
||||
Server: serverHTTP,
|
||||
Forwarder: listener,
|
||||
Switcher: httpSwitcher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTrackedConnection(conn net.Conn, tracker *connectionTracker) *trackedConnection {
|
||||
tracker.AddConnection(conn)
|
||||
return &trackedConnection{
|
||||
Conn: conn,
|
||||
tracker: tracker,
|
||||
}
|
||||
}
|
||||
|
||||
type trackedConnection struct {
|
||||
tracker *connectionTracker
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (t *trackedConnection) Close() error {
|
||||
t.tracker.RemoveConnection(t.Conn)
|
||||
return t.Conn.Close()
|
||||
}
|
141
server/server_entrypoint_tcp_test.go
Normal file
141
server/server_entrypoint_tcp_test.go
Normal file
|
@ -0,0 +1,141 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/tcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestShutdownHTTP(t *testing.T) {
|
||||
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
||||
Address: ":0",
|
||||
Transport: &static.EntryPointsTransport{
|
||||
LifeCycle: &static.LifeCycle{
|
||||
RequestAcceptGraceTimeout: 0,
|
||||
GraceTimeOut: 5000000000,
|
||||
},
|
||||
},
|
||||
ForwardedHeaders: &static.ForwardedHeaders{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.startTCP(context.Background())
|
||||
|
||||
router := &tcp.Router{}
|
||||
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
time.Sleep(time.Second * 1)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
entryPoint.switchRouter(router)
|
||||
|
||||
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.Shutdown(context.Background())
|
||||
|
||||
request, err := http.NewRequest("GET", "http://127.0.0.1:8082", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = request.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestShutdownHTTPHijacked(t *testing.T) {
|
||||
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
||||
Address: ":0",
|
||||
Transport: &static.EntryPointsTransport{
|
||||
LifeCycle: &static.LifeCycle{
|
||||
RequestAcceptGraceTimeout: 0,
|
||||
GraceTimeOut: 5000000000,
|
||||
},
|
||||
},
|
||||
ForwardedHeaders: &static.ForwardedHeaders{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.startTCP(context.Background())
|
||||
|
||||
router := &tcp.Router{}
|
||||
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
conn, _, err := rw.(http.Hijacker).Hijack()
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 1)
|
||||
|
||||
resp := http.Response{StatusCode: http.StatusOK}
|
||||
err = resp.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
}))
|
||||
entryPoint.switchRouter(router)
|
||||
|
||||
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.Shutdown(context.Background())
|
||||
|
||||
request, err := http.NewRequest("GET", "http://127.0.0.1:8082", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = request.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
}
|
||||
|
||||
func TestShutdownTCPConn(t *testing.T) {
|
||||
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
|
||||
Address: ":0",
|
||||
Transport: &static.EntryPointsTransport{
|
||||
LifeCycle: &static.LifeCycle{
|
||||
RequestAcceptGraceTimeout: 0,
|
||||
GraceTimeOut: 5000000000,
|
||||
},
|
||||
},
|
||||
ForwardedHeaders: &static.ForwardedHeaders{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.startTCP(context.Background())
|
||||
|
||||
router := &tcp.Router{}
|
||||
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn net.Conn) {
|
||||
_, err := http.ReadRequest(bufio.NewReader(conn))
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second * 1)
|
||||
|
||||
resp := http.Response{StatusCode: http.StatusOK}
|
||||
err = resp.Write(conn)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
entryPoint.switchRouter(router)
|
||||
|
||||
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
go entryPoint.Shutdown(context.Background())
|
||||
|
||||
request, err := http.NewRequest("GET", "http://127.0.0.1:8082", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = request.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, resp.StatusCode, http.StatusOK)
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
@ -59,8 +60,8 @@ func TestListenProvidersSkipsSameConfigurationForProvider(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}()
|
||||
|
||||
conf := th.BuildConfiguration(
|
||||
conf := &config.Configuration{}
|
||||
conf.HTTP = th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo")),
|
||||
th.WithLoadBalancerServices(th.WithService("bar")),
|
||||
)
|
||||
|
@ -101,7 +102,8 @@ func TestListenProvidersPublishesConfigForEachProvider(t *testing.T) {
|
|||
}
|
||||
}()
|
||||
|
||||
conf := th.BuildConfiguration(
|
||||
conf := &config.Configuration{}
|
||||
conf.HTTP = th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo")),
|
||||
th.WithLoadBalancerServices(th.WithService("bar")),
|
||||
)
|
||||
|
@ -134,7 +136,7 @@ func setupListenProvider(throttleDuration time.Duration) (server *Server, stop c
|
|||
},
|
||||
}
|
||||
|
||||
server = NewServer(staticConfiguration, nil, nil)
|
||||
server = NewServer(staticConfiguration, nil, nil, nil)
|
||||
go server.listenProviders(stop)
|
||||
|
||||
return server, stop, invokeStopChan
|
||||
|
@ -146,12 +148,12 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config func(testServerURL string) *config.Configuration
|
||||
config func(testServerURL string) *config.HTTPConfiguration
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
desc: "Ok",
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
|
@ -168,14 +170,14 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "No Frontend",
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration()
|
||||
},
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Drr",
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
|
@ -191,7 +193,7 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Drr Sticky",
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
|
@ -207,7 +209,7 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Wrr",
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
|
@ -223,7 +225,7 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "Empty Backend LB-Wrr Sticky",
|
||||
config: func(testServerURL string) *config.Configuration {
|
||||
config: func(testServerURL string) *config.HTTPConfiguration {
|
||||
return th.BuildConfiguration(
|
||||
th.WithRouters(th.WithRouter("foo",
|
||||
th.WithEntryPoints("http"),
|
||||
|
@ -251,13 +253,12 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
defer testServer.Close()
|
||||
|
||||
globalConfig := static.Configuration{}
|
||||
entryPointsConfig := EntryPoints{
|
||||
"http": &EntryPoint{},
|
||||
entryPointsConfig := TCPEntryPoints{
|
||||
"http": &TCPEntryPoint{},
|
||||
}
|
||||
dynamicConfigs := config.Configurations{"config": test.config(testServer.URL)}
|
||||
|
||||
srv := NewServer(globalConfig, nil, entryPointsConfig)
|
||||
entryPoints, _ := srv.loadConfig(dynamicConfigs)
|
||||
srv := NewServer(globalConfig, nil, entryPointsConfig, nil)
|
||||
entryPoints, _ := srv.createHTTPHandlers(context.Background(), *test.config(testServer.URL), []string{"http"})
|
||||
|
||||
responseRecorder := &httptest.ResponseRecorder{}
|
||||
request := httptest.NewRequest(http.MethodGet, testServer.URL+requestPath, nil)
|
||||
|
|
100
server/service/proxy.go
Normal file
100
server/service/proxy.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/log"
|
||||
)
|
||||
|
||||
// StatusClientClosedRequest non-standard HTTP status code for client disconnection
|
||||
const StatusClientClosedRequest = 499
|
||||
|
||||
// StatusClientClosedRequestText non-standard HTTP status for client disconnection
|
||||
const StatusClientClosedRequestText = "Client Closed Request"
|
||||
|
||||
func buildProxy(passHostHeader bool, responseForwarding *config.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
var flushInterval parse.Duration
|
||||
if responseForwarding != nil {
|
||||
err := flushInterval.Set(responseForwarding.FlushInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating flush interval: %v", err)
|
||||
}
|
||||
}
|
||||
if flushInterval == 0 {
|
||||
flushInterval = parse.Duration(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Director: func(outReq *http.Request) {
|
||||
u := outReq.URL
|
||||
if outReq.RequestURI != "" {
|
||||
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
|
||||
if err == nil {
|
||||
u = parsedURL
|
||||
}
|
||||
}
|
||||
|
||||
outReq.URL.Path = u.Path
|
||||
outReq.URL.RawPath = u.RawPath
|
||||
outReq.URL.RawQuery = u.RawQuery
|
||||
outReq.RequestURI = "" // Outgoing request should not have RequestURI
|
||||
|
||||
outReq.Proto = "HTTP/1.1"
|
||||
outReq.ProtoMajor = 1
|
||||
outReq.ProtoMinor = 1
|
||||
|
||||
// Do not pass client Host header unless optsetter PassHostHeader is set.
|
||||
if !passHostHeader {
|
||||
outReq.Host = outReq.URL.Host
|
||||
}
|
||||
|
||||
},
|
||||
Transport: defaultRoundTripper,
|
||||
FlushInterval: time.Duration(flushInterval),
|
||||
ModifyResponse: responseModifier,
|
||||
BufferPool: bufferPool,
|
||||
ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) {
|
||||
statusCode := http.StatusInternalServerError
|
||||
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
statusCode = http.StatusBadGateway
|
||||
case err == context.Canceled:
|
||||
statusCode = StatusClientClosedRequest
|
||||
default:
|
||||
if e, ok := err.(net.Error); ok {
|
||||
if e.Timeout() {
|
||||
statusCode = http.StatusGatewayTimeout
|
||||
} else {
|
||||
statusCode = http.StatusBadGateway
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
|
||||
w.WriteHeader(statusCode)
|
||||
_, werr := w.Write([]byte(statusText(statusCode)))
|
||||
if werr != nil {
|
||||
log.Debugf("Error while writing status code", werr)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func statusText(statusCode int) string {
|
||||
if statusCode == StatusClientClosedRequest {
|
||||
return StatusClientClosedRequestText
|
||||
}
|
||||
return http.StatusText(statusCode)
|
||||
}
|
37
server/service/proxy_test.go
Normal file
37
server/service/proxy_test.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
)
|
||||
|
||||
type staticTransport struct {
|
||||
res *http.Response
|
||||
}
|
||||
|
||||
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return t.res, nil
|
||||
}
|
||||
|
||||
func BenchmarkProxy(b *testing.B) {
|
||||
res := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
pool := newBufferPool()
|
||||
handler, _ := buildProxy(false, nil, &staticTransport{res}, pool, nil)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
713
server/service/proxy_websocket_test.go
Normal file
713
server/service/proxy_websocket_test.go
Normal file
|
@ -0,0 +1,713 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gorillawebsocket "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func TestWebSocketTCPClose(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
_, _, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, conn, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
).open()
|
||||
require.NoError(t, err)
|
||||
conn.Close()
|
||||
|
||||
serverErr := <-errChan
|
||||
|
||||
wsErr, ok := serverErr.(*gorillawebsocket.CloseError)
|
||||
assert.Equal(t, true, ok)
|
||||
assert.Equal(t, 1006, wsErr.Code)
|
||||
}
|
||||
|
||||
func TestWebSocketPingPong(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
var upgrader = gorillawebsocket.Upgrader{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
CheckOrigin: func(*http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
|
||||
ws, err := upgrader.Upgrade(writer, request, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ws.SetPingHandler(func(appData string) error {
|
||||
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, _ = ws.ReadMessage()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
defer conn.Close()
|
||||
|
||||
goodErr := fmt.Errorf("signal: %s", "Good data")
|
||||
badErr := fmt.Errorf("signal: %s", "Bad data")
|
||||
conn.SetPongHandler(func(data string) error {
|
||||
if data == "PingPong" {
|
||||
return goodErr
|
||||
}
|
||||
return badErr
|
||||
})
|
||||
|
||||
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = conn.ReadMessage()
|
||||
|
||||
if err != goodErr {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketEcho(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
msg := make([]byte, 4)
|
||||
_, err := conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(string(msg))
|
||||
|
||||
_, err = conn.Write(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(conn.ReadMessage())
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
}
|
||||
|
||||
func TestWebSocketPassHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
passHost bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "PassHost false",
|
||||
passHost: false,
|
||||
},
|
||||
{
|
||||
desc: "PassHost true",
|
||||
passHost: true,
|
||||
expected: "example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
f, err := buildProxy(test.passHost, nil, http.DefaultTransport, nil, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
req := conn.Request()
|
||||
|
||||
if test.passHost {
|
||||
require.Equal(t, test.expected, req.Host)
|
||||
} else {
|
||||
require.NotEqual(t, test.expected, req.Host)
|
||||
}
|
||||
|
||||
msg := make([]byte, 4)
|
||||
_, err = conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(string(msg))
|
||||
_, err = conn.Write(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
headers.Add("Host", "example.com")
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(conn.ReadMessage())
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
}}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithOrigin(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithQueryParams(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
assert.Equal(t, "test", r.URL.Query().Get("query"))
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws?query=test"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
conn.Close()
|
||||
}))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/%3A%2F%2F"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketUpgradeFailed(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(400)
|
||||
})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
|
||||
if path == "/ws" {
|
||||
// Set new backend URL
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
req.URL.Path = path
|
||||
f.ServeHTTP(w, req)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("upgrade", "websocket")
|
||||
req.Header.Add("Connection", "upgrade")
|
||||
|
||||
err = req.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First request works with 400
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
_, err := conn.Write([]byte("ok"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
srv := createTLSWebsocketServer()
|
||||
defer srv.Close()
|
||||
|
||||
forwarderWithoutTLSConfig, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL)
|
||||
defer proxyWithoutTLSConfig.Close()
|
||||
|
||||
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
|
||||
|
||||
_, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
forwarderWithTLSConfig, err := buildProxy(true, nil, transport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
|
||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
|
||||
|
||||
resp, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
const dialTimeout = time.Second
|
||||
|
||||
type websocketRequestOpt func(w *websocketRequest)
|
||||
|
||||
func withServer(server string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.ServerAddr = server
|
||||
}
|
||||
}
|
||||
|
||||
func withPath(path string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Path = path
|
||||
}
|
||||
}
|
||||
|
||||
func withData(data string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Data = data
|
||||
}
|
||||
}
|
||||
|
||||
func withOrigin(origin string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Origin = origin
|
||||
}
|
||||
}
|
||||
|
||||
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
|
||||
wsrequest := &websocketRequest{}
|
||||
for _, opt := range opts {
|
||||
opt(wsrequest)
|
||||
}
|
||||
if wsrequest.Origin == "" {
|
||||
wsrequest.Origin = "http://" + wsrequest.ServerAddr
|
||||
}
|
||||
if wsrequest.Config == nil {
|
||||
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
|
||||
}
|
||||
return wsrequest
|
||||
}
|
||||
|
||||
type websocketRequest struct {
|
||||
ServerAddr string
|
||||
Path string
|
||||
Data string
|
||||
Origin string
|
||||
Config *websocket.Config
|
||||
}
|
||||
|
||||
func (w *websocketRequest) send() (string, error) {
|
||||
conn, _, err := w.open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
if _, err := conn.Write([]byte(w.Data)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var msg = make([]byte, 512)
|
||||
var n int
|
||||
n, err = conn.Read(msg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
received := string(msg[:n])
|
||||
return received, nil
|
||||
}
|
||||
|
||||
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
|
||||
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn, err := websocket.NewClient(w.Config, client)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, client, err
|
||||
}
|
||||
|
||||
func parseURI(t *testing.T, uri string) *url.URL {
|
||||
out, err := url.ParseRequestURI(uri)
|
||||
require.NoError(t, err)
|
||||
return out
|
||||
}
|
||||
|
||||
func createProxyWithForwarder(t *testing.T, proxy http.Handler, url string) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
// Set new backend URL
|
||||
req.URL = parseURI(t, url)
|
||||
req.URL.Path = path
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
}))
|
||||
}
|
|
@ -3,14 +3,12 @@ package service
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/log"
|
||||
|
@ -19,7 +17,6 @@ import (
|
|||
"github.com/containous/traefik/old/middlewares/pipelining"
|
||||
"github.com/containous/traefik/server/cookie"
|
||||
"github.com/containous/traefik/server/internal"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
|
@ -46,8 +43,8 @@ type Manager struct {
|
|||
configs map[string]*config.Service
|
||||
}
|
||||
|
||||
// Build Creates a http.Handler for a service configuration.
|
||||
func (m *Manager) Build(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
// BuildHTTP Creates a http.Handler for a service configuration.
|
||||
func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
|
||||
|
||||
serviceName = internal.GetQualifiedName(ctx, serviceName)
|
||||
|
@ -55,6 +52,7 @@ func (m *Manager) Build(rootCtx context.Context, serviceName string, responseMod
|
|||
|
||||
if conf, ok := m.configs[serviceName]; ok {
|
||||
// TODO Should handle multiple service types
|
||||
// FIXME Check if the service is declared multiple times with different types
|
||||
if conf.LoadBalancer != nil {
|
||||
return m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier)
|
||||
}
|
||||
|
@ -69,8 +67,7 @@ func (m *Manager) getLoadBalancerServiceHandler(
|
|||
service *config.LoadBalancerService,
|
||||
responseModifier func(*http.Response) error,
|
||||
) (http.Handler, error) {
|
||||
|
||||
fwd, err := m.buildForwarder(service.PassHostHeader, service.ResponseForwarding, responseModifier)
|
||||
fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool, responseModifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -258,32 +255,3 @@ func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHand
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildForwarder(passHostHeader bool, responseForwarding *config.ResponseForwarding, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
|
||||
var flushInterval parse.Duration
|
||||
if responseForwarding != nil {
|
||||
err := flushInterval.Set(responseForwarding.FlushInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating flush interval: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return forward.New(
|
||||
forward.Stream(true),
|
||||
forward.PassHostHeader(passHostHeader),
|
||||
forward.RoundTripper(m.defaultRoundTripper),
|
||||
forward.ResponseModifier(responseModifier),
|
||||
forward.BufferPool(m.bufferPool),
|
||||
forward.StreamingFlushInterval(time.Duration(flushInterval)),
|
||||
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
|
||||
server := req.Context().Value(http.ServerContextKey).(*http.Server)
|
||||
if server != nil {
|
||||
connState := server.ConnState
|
||||
if connState != nil {
|
||||
connState(conn, http.StateClosed)
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -320,7 +320,7 @@ func TestManager_Build(t *testing.T) {
|
|||
ctx = internal.AddProviderInContext(ctx, test.providerName+".foobar")
|
||||
}
|
||||
|
||||
_, err := manager.Build(ctx, test.serviceName, nil)
|
||||
_, err := manager.BuildHTTP(ctx, test.serviceName, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
67
server/service/tcp/service.go
Normal file
67
server/service/tcp/service.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/server/internal"
|
||||
"github.com/containous/traefik/tcp"
|
||||
)
|
||||
|
||||
// Manager is the TCPHandlers factory
|
||||
type Manager struct {
|
||||
configs map[string]*config.TCPService
|
||||
}
|
||||
|
||||
// NewManager creates a new manager
|
||||
func NewManager(configs map[string]*config.TCPService) *Manager {
|
||||
return &Manager{
|
||||
configs: configs,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildTCP Creates a tcp.Handler for a service configuration.
|
||||
func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Handler, error) {
|
||||
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
|
||||
|
||||
serviceName = internal.GetQualifiedName(ctx, serviceName)
|
||||
ctx = internal.AddProviderInContext(ctx, serviceName)
|
||||
|
||||
if conf, ok := m.configs[serviceName]; ok {
|
||||
// FIXME Check if the service is declared multiple times with different types
|
||||
if conf.LoadBalancer != nil {
|
||||
loadBalancer := tcp.NewRRLoadBalancer()
|
||||
|
||||
var handler tcp.Handler
|
||||
for _, server := range conf.LoadBalancer.Servers {
|
||||
_, err := parseIP(server.Address)
|
||||
if err == nil {
|
||||
handler, _ = tcp.NewProxy(server.Address)
|
||||
loadBalancer.AddServer(handler)
|
||||
} else {
|
||||
log.FromContext(ctx).Errorf("Invalid IP address for a %s server %s: %v", serviceName, server.Address, err)
|
||||
}
|
||||
}
|
||||
return loadBalancer, nil
|
||||
}
|
||||
return nil, fmt.Errorf("the service %q doesn't have any TCP load balancer", serviceName)
|
||||
}
|
||||
return nil, fmt.Errorf("the service %q does not exits", serviceName)
|
||||
}
|
||||
|
||||
func parseIP(s string) (string, error) {
|
||||
ip, _, err := net.SplitHostPort(s)
|
||||
if err == nil {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
ipNoPort := net.ParseIP(s)
|
||||
if ipNoPort == nil {
|
||||
return "", fmt.Errorf("invalid IP Address %s", ipNoPort)
|
||||
}
|
||||
|
||||
return ipNoPort.String(), nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue