Introduce a fast proxy mode to improve HTTP/1.1 performances with backends
Co-authored-by: Romain <rtribotte@users.noreply.github.com> Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
parent
a6db1cac37
commit
f8a78b3b25
39 changed files with 3173 additions and 378 deletions
|
@ -2,10 +2,12 @@ package router
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -18,7 +20,7 @@ import (
|
|||
"github.com/traefik/traefik/v3/pkg/server/middleware"
|
||||
"github.com/traefik/traefik/v3/pkg/server/service"
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
"github.com/traefik/traefik/v3/pkg/tls"
|
||||
traefiktls "github.com/traefik/traefik/v3/pkg/tls"
|
||||
)
|
||||
|
||||
func TestRouterManager_Get(t *testing.T) {
|
||||
|
@ -309,11 +311,12 @@ func TestRouterManager_Get(t *testing.T) {
|
|||
},
|
||||
})
|
||||
|
||||
roundTripperManager := service.NewRoundTripperManager(nil)
|
||||
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, roundTripperManager)
|
||||
transportManager := service.NewTransportManager(nil)
|
||||
transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, transportManager, proxyBuilderMock{})
|
||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil)
|
||||
tlsManager := tls.NewManager()
|
||||
tlsManager := traefiktls.NewManager()
|
||||
|
||||
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager)
|
||||
|
||||
|
@ -340,7 +343,7 @@ func TestRuntimeConfiguration(t *testing.T) {
|
|||
serviceConfig map[string]*dynamic.Service
|
||||
routerConfig map[string]*dynamic.Router
|
||||
middlewareConfig map[string]*dynamic.Middleware
|
||||
tlsOptions map[string]tls.Options
|
||||
tlsOptions map[string]traefiktls.Options
|
||||
expectedError int
|
||||
}{
|
||||
{
|
||||
|
@ -597,7 +600,7 @@ func TestRuntimeConfiguration(t *testing.T) {
|
|||
TLS: &dynamic.RouterTLSConfig{},
|
||||
},
|
||||
},
|
||||
tlsOptions: map[string]tls.Options{},
|
||||
tlsOptions: map[string]traefiktls.Options{},
|
||||
expectedError: 1,
|
||||
},
|
||||
{
|
||||
|
@ -624,9 +627,9 @@ func TestRuntimeConfiguration(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
tlsOptions: map[string]tls.Options{
|
||||
tlsOptions: map[string]traefiktls.Options{
|
||||
"broken-tlsOption": {
|
||||
ClientAuth: tls.ClientAuth{
|
||||
ClientAuth: traefiktls.ClientAuth{
|
||||
ClientAuthType: "foobar",
|
||||
},
|
||||
},
|
||||
|
@ -655,9 +658,9 @@ func TestRuntimeConfiguration(t *testing.T) {
|
|||
TLS: &dynamic.RouterTLSConfig{},
|
||||
},
|
||||
},
|
||||
tlsOptions: map[string]tls.Options{
|
||||
tlsOptions: map[string]traefiktls.Options{
|
||||
"default": {
|
||||
ClientAuth: tls.ClientAuth{
|
||||
ClientAuth: traefiktls.ClientAuth{
|
||||
ClientAuthType: "foobar",
|
||||
},
|
||||
},
|
||||
|
@ -682,11 +685,12 @@ func TestRuntimeConfiguration(t *testing.T) {
|
|||
},
|
||||
})
|
||||
|
||||
roundTripperManager := service.NewRoundTripperManager(nil)
|
||||
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, roundTripperManager)
|
||||
transportManager := service.NewTransportManager(nil)
|
||||
transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, transportManager, proxyBuilderMock{})
|
||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil)
|
||||
tlsManager := tls.NewManager()
|
||||
tlsManager := traefiktls.NewManager()
|
||||
tlsManager.UpdateConfigs(context.Background(), nil, test.tlsOptions, nil)
|
||||
|
||||
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager)
|
||||
|
@ -759,11 +763,12 @@ func TestProviderOnMiddlewares(t *testing.T) {
|
|||
},
|
||||
})
|
||||
|
||||
roundTripperManager := service.NewRoundTripperManager(nil)
|
||||
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, roundTripperManager)
|
||||
transportManager := service.NewTransportManager(nil)
|
||||
transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, transportManager, nil)
|
||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil)
|
||||
tlsManager := tls.NewManager()
|
||||
tlsManager := traefiktls.NewManager()
|
||||
|
||||
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager)
|
||||
|
||||
|
@ -775,14 +780,22 @@ func TestProviderOnMiddlewares(t *testing.T) {
|
|||
assert.Equal(t, []string{"m1@docker", "m2@docker", "m1@file"}, rtConf.Middlewares["chain@docker"].Chain.Middlewares)
|
||||
}
|
||||
|
||||
type staticRoundTripperGetter struct {
|
||||
type staticTransportManager struct {
|
||||
res *http.Response
|
||||
}
|
||||
|
||||
func (s staticRoundTripperGetter) Get(name string) (http.RoundTripper, error) {
|
||||
func (s staticTransportManager) GetRoundTripper(_ string) (http.RoundTripper, error) {
|
||||
return &staticTransport{res: s.res}, nil
|
||||
}
|
||||
|
||||
func (s staticTransportManager) GetTLSConfig(_ string) (*tls.Config, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (s staticTransportManager) Get(_ string) (*dynamic.ServersTransport, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
type staticTransport struct {
|
||||
res *http.Response
|
||||
}
|
||||
|
@ -829,9 +842,9 @@ func BenchmarkRouterServe(b *testing.B) {
|
|||
},
|
||||
})
|
||||
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, staticRoundTripperGetter{res})
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, staticTransportManager{res}, nil)
|
||||
middlewaresBuilder := middleware.NewBuilder(rtConf.Middlewares, serviceManager, nil)
|
||||
tlsManager := tls.NewManager()
|
||||
tlsManager := traefiktls.NewManager()
|
||||
|
||||
routerManager := NewManager(rtConf, serviceManager, middlewaresBuilder, nil, tlsManager)
|
||||
|
||||
|
@ -871,7 +884,7 @@ func BenchmarkService(b *testing.B) {
|
|||
},
|
||||
})
|
||||
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, staticRoundTripperGetter{res})
|
||||
serviceManager := service.NewManager(rtConf.Services, nil, nil, staticTransportManager{res}, nil)
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
|
@ -881,3 +894,13 @@ func BenchmarkService(b *testing.B) {
|
|||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
||||
|
||||
type proxyBuilderMock struct{}
|
||||
|
||||
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) {
|
||||
return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil
|
||||
}
|
||||
|
||||
func (p proxyBuilderMock) Update(_ map[string]*dynamic.ServersTransport) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
|
|
@ -3,7 +3,9 @@ package server
|
|||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
|
@ -48,9 +50,10 @@ func TestReuseService(t *testing.T) {
|
|||
),
|
||||
)
|
||||
|
||||
roundTripperManager := service.NewRoundTripperManager(nil)
|
||||
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, roundTripperManager, nil)
|
||||
transportManager := service.NewTransportManager(nil)
|
||||
transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
|
||||
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, transportManager, proxyBuilderMock{}, nil)
|
||||
tlsManager := tls.NewManager()
|
||||
|
||||
dialerManager := tcp.NewDialerManager(nil)
|
||||
|
@ -184,9 +187,10 @@ func TestServerResponseEmptyBackend(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
roundTripperManager := service.NewRoundTripperManager(nil)
|
||||
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, roundTripperManager, nil)
|
||||
transportManager := service.NewTransportManager(nil)
|
||||
transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
|
||||
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, transportManager, proxyBuilderMock{}, nil)
|
||||
tlsManager := tls.NewManager()
|
||||
|
||||
dialerManager := tcp.NewDialerManager(nil)
|
||||
|
@ -228,9 +232,10 @@ func TestInternalServices(t *testing.T) {
|
|||
),
|
||||
)
|
||||
|
||||
roundTripperManager := service.NewRoundTripperManager(nil)
|
||||
roundTripperManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, roundTripperManager, nil)
|
||||
transportManager := service.NewTransportManager(nil)
|
||||
transportManager.Update(map[string]*dynamic.ServersTransport{"default@internal": {}})
|
||||
|
||||
managerFactory := service.NewManagerFactory(staticConfig, nil, nil, transportManager, nil, nil)
|
||||
tlsManager := tls.NewManager()
|
||||
|
||||
dialerManager := tcp.NewDialerManager(nil)
|
||||
|
@ -246,3 +251,13 @@ func TestInternalServices(t *testing.T) {
|
|||
|
||||
assert.Equal(t, http.StatusOK, responseRecorderOk.Result().StatusCode, "status code")
|
||||
}
|
||||
|
||||
type proxyBuilderMock struct{}
|
||||
|
||||
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) {
|
||||
return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil
|
||||
}
|
||||
|
||||
func (p proxyBuilderMock) Update(_ map[string]*dynamic.ServersTransport) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
package service
|
||||
|
||||
import "sync"
|
||||
|
||||
const bufferPoolSize = 32 * 1024
|
||||
|
||||
func newBufferPool() *bufferPool {
|
||||
return &bufferPool{
|
||||
pool: sync.Pool{
|
||||
New: func() interface{} {
|
||||
return make([]byte, bufferPoolSize)
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type bufferPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func (b *bufferPool) Get() []byte {
|
||||
return b.pool.Get().([]byte)
|
||||
}
|
||||
|
||||
func (b *bufferPool) Put(bytes []byte) {
|
||||
b.pool.Put(bytes)
|
||||
}
|
|
@ -17,7 +17,8 @@ import (
|
|||
type ManagerFactory struct {
|
||||
observabilityMgr *middleware.ObservabilityMgr
|
||||
|
||||
roundTripperManager *RoundTripperManager
|
||||
transportManager *TransportManager
|
||||
proxyBuilder ProxyBuilder
|
||||
|
||||
api func(configuration *runtime.Configuration) http.Handler
|
||||
restHandler http.Handler
|
||||
|
@ -30,12 +31,13 @@ type ManagerFactory struct {
|
|||
}
|
||||
|
||||
// NewManagerFactory creates a new ManagerFactory.
|
||||
func NewManagerFactory(staticConfiguration static.Configuration, routinesPool *safe.Pool, observabilityMgr *middleware.ObservabilityMgr, roundTripperManager *RoundTripperManager, acmeHTTPHandler http.Handler) *ManagerFactory {
|
||||
func NewManagerFactory(staticConfiguration static.Configuration, routinesPool *safe.Pool, observabilityMgr *middleware.ObservabilityMgr, transportManager *TransportManager, proxyBuilder ProxyBuilder, acmeHTTPHandler http.Handler) *ManagerFactory {
|
||||
factory := &ManagerFactory{
|
||||
observabilityMgr: observabilityMgr,
|
||||
routinesPool: routinesPool,
|
||||
roundTripperManager: roundTripperManager,
|
||||
acmeHTTPHandler: acmeHTTPHandler,
|
||||
observabilityMgr: observabilityMgr,
|
||||
routinesPool: routinesPool,
|
||||
transportManager: transportManager,
|
||||
proxyBuilder: proxyBuilder,
|
||||
acmeHTTPHandler: acmeHTTPHandler,
|
||||
}
|
||||
|
||||
if staticConfiguration.API != nil {
|
||||
|
@ -73,7 +75,7 @@ func NewManagerFactory(staticConfiguration static.Configuration, routinesPool *s
|
|||
|
||||
// Build creates a service manager.
|
||||
func (f *ManagerFactory) Build(configuration *runtime.Configuration) *InternalHandlers {
|
||||
svcManager := NewManager(configuration.Services, f.observabilityMgr, f.routinesPool, f.roundTripperManager)
|
||||
svcManager := NewManager(configuration.Services, f.observabilityMgr, f.routinesPool, f.transportManager, f.proxyBuilder)
|
||||
|
||||
var apiHandler http.Handler
|
||||
if f.api != nil {
|
||||
|
|
|
@ -1,105 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/metrics"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||
"github.com/traefik/traefik/v3/pkg/tracing"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
type wrapper struct {
|
||||
semConvMetricRegistry *metrics.SemConvMetricsRegistry
|
||||
rt http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
start := time.Now()
|
||||
var span trace.Span
|
||||
var tracingCtx context.Context
|
||||
var tracer *tracing.Tracer
|
||||
if tracer = tracing.TracerFromContext(req.Context()); tracer != nil {
|
||||
tracingCtx, span = tracer.Start(req.Context(), "ReverseProxy", trace.WithSpanKind(trace.SpanKindClient))
|
||||
defer span.End()
|
||||
|
||||
req = req.WithContext(tracingCtx)
|
||||
|
||||
tracer.CaptureClientRequest(span, req)
|
||||
tracing.InjectContextIntoCarrier(req)
|
||||
}
|
||||
|
||||
var statusCode int
|
||||
var headers http.Header
|
||||
response, err := t.rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
statusCode = computeStatusCode(err)
|
||||
}
|
||||
if response != nil {
|
||||
statusCode = response.StatusCode
|
||||
headers = response.Header
|
||||
}
|
||||
|
||||
if tracer != nil {
|
||||
tracer.CaptureResponse(span, headers, statusCode, trace.SpanKindClient)
|
||||
}
|
||||
|
||||
end := time.Now()
|
||||
|
||||
// Ending the span as soon as the response is handled because we want to use the same end time for the trace and the metric.
|
||||
// If any errors happen earlier, this span will be close by the defer instruction.
|
||||
if span != nil {
|
||||
span.End(trace.WithTimestamp(end))
|
||||
}
|
||||
|
||||
if t.semConvMetricRegistry != nil && t.semConvMetricRegistry.HTTPClientRequestDuration() != nil {
|
||||
var attrs []attribute.KeyValue
|
||||
|
||||
if statusCode < 100 || statusCode >= 600 {
|
||||
attrs = append(attrs, attribute.Key("error.type").String(fmt.Sprintf("Invalid HTTP status code %d", statusCode)))
|
||||
} else if statusCode >= 400 {
|
||||
attrs = append(attrs, attribute.Key("error.type").String(strconv.Itoa(statusCode)))
|
||||
}
|
||||
|
||||
attrs = append(attrs, semconv.HTTPRequestMethodKey.String(req.Method))
|
||||
attrs = append(attrs, semconv.HTTPResponseStatusCode(statusCode))
|
||||
attrs = append(attrs, semconv.NetworkProtocolName(strings.ToLower(req.Proto)))
|
||||
attrs = append(attrs, semconv.NetworkProtocolVersion(observability.Proto(req.Proto)))
|
||||
attrs = append(attrs, semconv.ServerAddress(req.URL.Host))
|
||||
|
||||
_, port, err := net.SplitHostPort(req.URL.Host)
|
||||
if err != nil {
|
||||
switch req.URL.Scheme {
|
||||
case "http":
|
||||
attrs = append(attrs, semconv.ServerPort(80))
|
||||
case "https":
|
||||
attrs = append(attrs, semconv.ServerPort(443))
|
||||
}
|
||||
} else {
|
||||
intPort, _ := strconv.Atoi(port)
|
||||
attrs = append(attrs, semconv.ServerPort(intPort))
|
||||
}
|
||||
|
||||
attrs = append(attrs, semconv.URLScheme(req.Header.Get("X-Forwarded-Proto")))
|
||||
|
||||
t.semConvMetricRegistry.HTTPClientRequestDuration().Record(req.Context(), end.Sub(start).Seconds(), metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
||||
|
||||
func newObservabilityRoundTripper(semConvMetricRegistry *metrics.SemConvMetricsRegistry, rt http.RoundTripper) http.RoundTripper {
|
||||
return &wrapper{
|
||||
semConvMetricRegistry: semConvMetricRegistry,
|
||||
rt: rt,
|
||||
}
|
||||
}
|
|
@ -1,122 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
ptypes "github.com/traefik/paerser/types"
|
||||
"github.com/traefik/traefik/v3/pkg/metrics"
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/metric/metricdata"
|
||||
"go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest"
|
||||
)
|
||||
|
||||
func TestObservabilityRoundTripper_metrics(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
serverURL string
|
||||
statusCode int
|
||||
wantAttributes attribute.Set
|
||||
}{
|
||||
{
|
||||
desc: "not found status",
|
||||
serverURL: "http://www.test.com",
|
||||
statusCode: http.StatusNotFound,
|
||||
wantAttributes: attribute.NewSet(
|
||||
attribute.Key("error.type").String("404"),
|
||||
attribute.Key("http.request.method").String("GET"),
|
||||
attribute.Key("http.response.status_code").Int(404),
|
||||
attribute.Key("network.protocol.name").String("http/1.1"),
|
||||
attribute.Key("network.protocol.version").String("1.1"),
|
||||
attribute.Key("server.address").String("www.test.com"),
|
||||
attribute.Key("server.port").Int(80),
|
||||
attribute.Key("url.scheme").String("http"),
|
||||
),
|
||||
},
|
||||
{
|
||||
desc: "created status",
|
||||
serverURL: "https://www.test.com",
|
||||
statusCode: http.StatusCreated,
|
||||
wantAttributes: attribute.NewSet(
|
||||
attribute.Key("http.request.method").String("GET"),
|
||||
attribute.Key("http.response.status_code").Int(201),
|
||||
attribute.Key("network.protocol.name").String("http/1.1"),
|
||||
attribute.Key("network.protocol.version").String("1.1"),
|
||||
attribute.Key("server.address").String("www.test.com"),
|
||||
attribute.Key("server.port").Int(443),
|
||||
attribute.Key("url.scheme").String("http"),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var cfg types.OTLP
|
||||
(&cfg).SetDefaults()
|
||||
cfg.AddRoutersLabels = true
|
||||
cfg.PushInterval = ptypes.Duration(10 * time.Millisecond)
|
||||
rdr := sdkmetric.NewManualReader()
|
||||
|
||||
meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(rdr))
|
||||
// force the meter provider with manual reader to collect metrics for the test.
|
||||
metrics.SetMeterProvider(meterProvider)
|
||||
|
||||
semConvMetricRegistry, err := metrics.NewSemConvMetricRegistry(context.Background(), &cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, semConvMetricRegistry)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, test.serverURL+"/search?q=Opentelemetry", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
req.Header.Set("User-Agent", "rt-test")
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
|
||||
ort := newObservabilityRoundTripper(semConvMetricRegistry, mockRoundTripper{statusCode: test.statusCode})
|
||||
_, err = ort.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
got := metricdata.ResourceMetrics{}
|
||||
err = rdr.Collect(context.Background(), &got)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, got.ScopeMetrics, 1)
|
||||
|
||||
expected := metricdata.Metrics{
|
||||
Name: "http.client.request.duration",
|
||||
Description: "Duration of HTTP client requests.",
|
||||
Unit: "s",
|
||||
Data: metricdata.Histogram[float64]{
|
||||
DataPoints: []metricdata.HistogramDataPoint[float64]{
|
||||
{
|
||||
Attributes: test.wantAttributes,
|
||||
Count: 1,
|
||||
Bounds: []float64{0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1, 2.5, 5, 7.5, 10},
|
||||
BucketCounts: []uint64{0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||
Min: metricdata.NewExtrema[float64](1),
|
||||
Max: metricdata.NewExtrema[float64](1),
|
||||
Sum: 1,
|
||||
},
|
||||
},
|
||||
Temporality: metricdata.CumulativeTemporality,
|
||||
},
|
||||
}
|
||||
|
||||
metricdatatest.AssertEqual[metricdata.Metrics](t, expected, got.ScopeMetrics[0].Metrics[0], metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreValue())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockRoundTripper struct {
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (m mockRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: m.statusCode}, nil
|
||||
}
|
|
@ -1,133 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// StatusClientClosedRequest non-standard HTTP status code for client disconnection.
|
||||
const StatusClientClosedRequest = 499
|
||||
|
||||
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
|
||||
const StatusClientClosedRequestText = "Client Closed Request"
|
||||
|
||||
func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler {
|
||||
return &httputil.ReverseProxy{
|
||||
Director: directorBuilder(target, passHostHeader),
|
||||
Transport: roundTripper,
|
||||
FlushInterval: flushInterval,
|
||||
BufferPool: bufferPool,
|
||||
ErrorHandler: errorHandler,
|
||||
}
|
||||
}
|
||||
|
||||
func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Request) {
|
||||
return func(outReq *http.Request) {
|
||||
outReq.URL.Scheme = target.Scheme
|
||||
outReq.URL.Host = target.Host
|
||||
|
||||
u := outReq.URL
|
||||
if outReq.RequestURI != "" {
|
||||
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
|
||||
if err == nil {
|
||||
u = parsedURL
|
||||
}
|
||||
}
|
||||
|
||||
outReq.URL.Path = u.Path
|
||||
outReq.URL.RawPath = u.RawPath
|
||||
// If a plugin/middleware adds semicolons in query params, they should be urlEncoded.
|
||||
outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
|
||||
outReq.RequestURI = "" // Outgoing request should not have RequestURI
|
||||
|
||||
outReq.Proto = "HTTP/1.1"
|
||||
outReq.ProtoMajor = 1
|
||||
outReq.ProtoMinor = 1
|
||||
|
||||
// Do not pass client Host header unless optsetter PassHostHeader is set.
|
||||
if !passHostHeader {
|
||||
outReq.Host = outReq.URL.Host
|
||||
}
|
||||
|
||||
cleanWebSocketHeaders(outReq)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive,
|
||||
// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
|
||||
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
|
||||
// https://tools.ietf.org/html/rfc6455#page-20
|
||||
func cleanWebSocketHeaders(req *http.Request) {
|
||||
if !isWebSocketUpgrade(req) {
|
||||
return
|
||||
}
|
||||
|
||||
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
|
||||
delete(req.Header, "Sec-Websocket-Key")
|
||||
|
||||
req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"]
|
||||
delete(req.Header, "Sec-Websocket-Extensions")
|
||||
|
||||
req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"]
|
||||
delete(req.Header, "Sec-Websocket-Accept")
|
||||
|
||||
req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"]
|
||||
delete(req.Header, "Sec-Websocket-Protocol")
|
||||
|
||||
req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"]
|
||||
delete(req.Header, "Sec-Websocket-Version")
|
||||
}
|
||||
|
||||
func isWebSocketUpgrade(req *http.Request) bool {
|
||||
return httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") &&
|
||||
strings.EqualFold(req.Header.Get("Upgrade"), "websocket")
|
||||
}
|
||||
|
||||
func errorHandler(w http.ResponseWriter, req *http.Request, err error) {
|
||||
statusCode := computeStatusCode(err)
|
||||
|
||||
logger := log.Ctx(req.Context())
|
||||
logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode))
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
if _, werr := w.Write([]byte(statusText(statusCode))); werr != nil {
|
||||
logger.Debug().Err(werr).Msg("Error while writing status code")
|
||||
}
|
||||
}
|
||||
|
||||
func computeStatusCode(err error) int {
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
return http.StatusBadGateway
|
||||
case errors.Is(err, context.Canceled):
|
||||
return StatusClientClosedRequest
|
||||
default:
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
if netErr.Timeout() {
|
||||
return http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
}
|
||||
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func statusText(statusCode int) string {
|
||||
if statusCode == StatusClientClosedRequest {
|
||||
return StatusClientClosedRequestText
|
||||
}
|
||||
return http.StatusText(statusCode)
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
)
|
||||
|
||||
type staticTransport struct {
|
||||
res *http.Response
|
||||
}
|
||||
|
||||
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return t.res, nil
|
||||
}
|
||||
|
||||
func BenchmarkProxy(b *testing.B) {
|
||||
res := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
pool := newBufferPool()
|
||||
handler := buildSingleHostProxy(req.URL, false, 0, &staticTransport{res}, pool)
|
||||
|
||||
b.ReportAllocs()
|
||||
for range b.N {
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
|
@ -1,655 +0,0 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gorillawebsocket "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func TestWebSocketTCPClose(t *testing.T) {
|
||||
errChan := make(chan error, 1)
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
_, _, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, conn, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
).open()
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
serverErr := <-errChan
|
||||
|
||||
var wsErr *gorillawebsocket.CloseError
|
||||
require.ErrorAs(t, serverErr, &wsErr)
|
||||
assert.Equal(t, 1006, wsErr.Code)
|
||||
}
|
||||
|
||||
func TestWebSocketPingPong(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
CheckOrigin: func(*http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
|
||||
ws, err := upgrader.Upgrade(writer, request, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ws.SetPingHandler(func(appData string) error {
|
||||
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, _ = ws.ReadMessage()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
defer conn.Close()
|
||||
|
||||
goodErr := fmt.Errorf("signal: %s", "Good data")
|
||||
badErr := fmt.Errorf("signal: %s", "Bad data")
|
||||
conn.SetPongHandler(func(data string) error {
|
||||
if data == "PingPong" {
|
||||
return goodErr
|
||||
}
|
||||
return badErr
|
||||
})
|
||||
|
||||
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = conn.ReadMessage()
|
||||
|
||||
if !errors.Is(err, goodErr) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketEcho(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
msg := make([]byte, 4)
|
||||
_, err := conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(string(msg))
|
||||
|
||||
_, err = conn.Write(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(conn.ReadMessage())
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWebSocketPassHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
passHost bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "PassHost false",
|
||||
passHost: false,
|
||||
},
|
||||
{
|
||||
desc: "PassHost true",
|
||||
passHost: true,
|
||||
expected: "example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
req := conn.Request()
|
||||
|
||||
if test.passHost {
|
||||
require.Equal(t, test.expected, req.Host)
|
||||
} else {
|
||||
require.NotEqual(t, test.expected, req.Host)
|
||||
}
|
||||
|
||||
msg := make([]byte, 4)
|
||||
_, err := conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(string(msg))
|
||||
_, err = conn.Write(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
headers.Add("Host", "example.com")
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(conn.ReadMessage())
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
}}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithOrigin(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithQueryParams(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
assert.Equal(t, "test", r.URL.Query().Get("query"))
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws?query=test"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
conn.Close()
|
||||
}))
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil)
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/%3A%2F%2F"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketUpgradeFailed(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil)
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
|
||||
if path == "/ws" {
|
||||
// Set new backend URL
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
req.URL.Path = path
|
||||
f.ServeHTTP(w, req)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("upgrade", "websocket")
|
||||
req.Header.Add("Connection", "upgrade")
|
||||
|
||||
err = req.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First request works with 400
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
_, err := conn.Write([]byte("ok"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
srv := createTLSWebsocketServer()
|
||||
defer srv.Close()
|
||||
|
||||
proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxyWithoutTLSConfig.Close()
|
||||
|
||||
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
|
||||
|
||||
_, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, transport)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
|
||||
// Don't alter default transport to prevent side effects on other tests.
|
||||
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
defaultTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, srv.URL, defaultTransport)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
|
||||
|
||||
resp, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
const dialTimeout = time.Second
|
||||
|
||||
type websocketRequestOpt func(w *websocketRequest)
|
||||
|
||||
func withServer(server string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.ServerAddr = server
|
||||
}
|
||||
}
|
||||
|
||||
func withPath(path string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Path = path
|
||||
}
|
||||
}
|
||||
|
||||
func withData(data string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Data = data
|
||||
}
|
||||
}
|
||||
|
||||
func withOrigin(origin string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Origin = origin
|
||||
}
|
||||
}
|
||||
|
||||
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
|
||||
wsrequest := &websocketRequest{}
|
||||
for _, opt := range opts {
|
||||
opt(wsrequest)
|
||||
}
|
||||
if wsrequest.Origin == "" {
|
||||
wsrequest.Origin = "http://" + wsrequest.ServerAddr
|
||||
}
|
||||
if wsrequest.Config == nil {
|
||||
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
|
||||
}
|
||||
return wsrequest
|
||||
}
|
||||
|
||||
type websocketRequest struct {
|
||||
ServerAddr string
|
||||
Path string
|
||||
Data string
|
||||
Origin string
|
||||
Config *websocket.Config
|
||||
}
|
||||
|
||||
func (w *websocketRequest) send() (string, error) {
|
||||
conn, _, err := w.open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
if _, err := conn.Write([]byte(w.Data)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
msg := make([]byte, 512)
|
||||
var n int
|
||||
n, err = conn.Read(msg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
received := string(msg[:n])
|
||||
return received, nil
|
||||
}
|
||||
|
||||
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
|
||||
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn, err := websocket.NewClient(w.Config, client)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, client, err
|
||||
}
|
||||
|
||||
func parseURI(t *testing.T, uri string) *url.URL {
|
||||
t.Helper()
|
||||
|
||||
out, err := url.ParseRequestURI(uri)
|
||||
require.NoError(t, err)
|
||||
return out
|
||||
}
|
||||
|
||||
func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTripper) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
u := parseURI(t, uri)
|
||||
proxy := buildSingleHostProxy(u, true, 0, transport, nil)
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
// Set new backend URL
|
||||
req.URL = u
|
||||
req.URL.Path = path
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
return srv
|
||||
}
|
|
@ -9,7 +9,6 @@ import (
|
|||
"hash/fnv"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
@ -25,6 +24,7 @@ import (
|
|||
"github.com/traefik/traefik/v3/pkg/middlewares/capture"
|
||||
metricsMiddle "github.com/traefik/traefik/v3/pkg/middlewares/metrics"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||
"github.com/traefik/traefik/v3/pkg/safe"
|
||||
"github.com/traefik/traefik/v3/pkg/server/cookie"
|
||||
"github.com/traefik/traefik/v3/pkg/server/middleware"
|
||||
|
@ -40,17 +40,18 @@ const (
|
|||
defaultMaxBodySize int64 = -1
|
||||
)
|
||||
|
||||
// RoundTripperGetter is a roundtripper getter interface.
|
||||
type RoundTripperGetter interface {
|
||||
Get(name string) (http.RoundTripper, error)
|
||||
// ProxyBuilder builds reverse proxy handlers.
|
||||
type ProxyBuilder interface {
|
||||
Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error)
|
||||
Update(configs map[string]*dynamic.ServersTransport)
|
||||
}
|
||||
|
||||
// Manager The service manager.
|
||||
type Manager struct {
|
||||
routinePool *safe.Pool
|
||||
observabilityMgr *middleware.ObservabilityMgr
|
||||
bufferPool httputil.BufferPool
|
||||
roundTripperManager RoundTripperGetter
|
||||
routinePool *safe.Pool
|
||||
observabilityMgr *middleware.ObservabilityMgr
|
||||
transportManager httputil.TransportManager
|
||||
proxyBuilder ProxyBuilder
|
||||
|
||||
services map[string]http.Handler
|
||||
configs map[string]*runtime.ServiceInfo
|
||||
|
@ -59,16 +60,16 @@ type Manager struct {
|
|||
}
|
||||
|
||||
// NewManager creates a new Manager.
|
||||
func NewManager(configs map[string]*runtime.ServiceInfo, observabilityMgr *middleware.ObservabilityMgr, routinePool *safe.Pool, roundTripperManager RoundTripperGetter) *Manager {
|
||||
func NewManager(configs map[string]*runtime.ServiceInfo, observabilityMgr *middleware.ObservabilityMgr, routinePool *safe.Pool, transportManager httputil.TransportManager, proxyBuilder ProxyBuilder) *Manager {
|
||||
return &Manager{
|
||||
routinePool: routinePool,
|
||||
observabilityMgr: observabilityMgr,
|
||||
bufferPool: newBufferPool(),
|
||||
roundTripperManager: roundTripperManager,
|
||||
services: make(map[string]http.Handler),
|
||||
configs: configs,
|
||||
healthCheckers: make(map[string]*healthcheck.ServiceHealthChecker),
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
routinePool: routinePool,
|
||||
observabilityMgr: observabilityMgr,
|
||||
transportManager: transportManager,
|
||||
proxyBuilder: proxyBuilder,
|
||||
services: make(map[string]http.Handler),
|
||||
configs: configs,
|
||||
healthCheckers: make(map[string]*healthcheck.ServiceHealthChecker),
|
||||
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -298,9 +299,9 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
|
|||
logger.Debug().Msg("Creating load-balancer")
|
||||
|
||||
// TODO: should we keep this config value as Go is now handling stream response correctly?
|
||||
flushInterval := dynamic.DefaultFlushInterval
|
||||
flushInterval := time.Duration(dynamic.DefaultFlushInterval)
|
||||
if service.ResponseForwarding != nil {
|
||||
flushInterval = service.ResponseForwarding.FlushInterval
|
||||
flushInterval = time.Duration(service.ResponseForwarding.FlushInterval)
|
||||
}
|
||||
|
||||
if len(service.ServersTransport) > 0 {
|
||||
|
@ -317,11 +318,6 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
|
|||
passHostHeader = *service.PassHostHeader
|
||||
}
|
||||
|
||||
roundTripper, err := m.roundTripperManager.Get(service.ServersTransport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lb := wrr.New(service.Sticky, service.HealthCheck != nil)
|
||||
healthCheckTargets := make(map[string]*url.URL)
|
||||
|
||||
|
@ -341,14 +337,12 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
|
|||
|
||||
qualifiedSvcName := provider.GetQualifiedName(ctx, serviceName)
|
||||
|
||||
if m.observabilityMgr.ShouldAddTracing(qualifiedSvcName) || m.observabilityMgr.ShouldAddMetrics(qualifiedSvcName) {
|
||||
// Wrapping the roundTripper with the Tracing roundTripper,
|
||||
// to handle the reverseProxy client span creation.
|
||||
roundTripper = newObservabilityRoundTripper(m.observabilityMgr.SemConvMetricsRegistry(), roundTripper)
|
||||
shouldObserve := m.observabilityMgr.ShouldAddTracing(qualifiedSvcName) || m.observabilityMgr.ShouldAddMetrics(qualifiedSvcName)
|
||||
proxy, err := m.proxyBuilder.Build(service.ServersTransport, target, shouldObserve, passHostHeader, flushInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error building proxy for server URL %s: %w", server.URL, err)
|
||||
}
|
||||
|
||||
proxy := buildSingleHostProxy(target, passHostHeader, time.Duration(flushInterval), roundTripper, m.bufferPool)
|
||||
|
||||
// Prevents from enabling observability for internal resources.
|
||||
|
||||
if m.observabilityMgr.ShouldAddAccessLogs(qualifiedSvcName) {
|
||||
|
@ -393,6 +387,11 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
|
|||
}
|
||||
|
||||
if service.HealthCheck != nil {
|
||||
roundTripper, err := m.transportManager.GetRoundTripper(service.ServersTransport)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting RoundTripper: %w", err)
|
||||
}
|
||||
|
||||
m.healthCheckers[serviceName] = healthcheck.NewServiceHealthChecker(
|
||||
ctx,
|
||||
m.observabilityMgr.MetricsRegistry(),
|
||||
|
|
|
@ -2,6 +2,7 @@ package service
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
@ -14,13 +15,14 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/config/runtime"
|
||||
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||
"github.com/traefik/traefik/v3/pkg/server/provider"
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
)
|
||||
|
||||
func TestGetLoadBalancer(t *testing.T) {
|
||||
sm := Manager{
|
||||
roundTripperManager: newRtMock(),
|
||||
transportManager: &transportManagerMock{},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
|
@ -40,14 +42,14 @@ func TestGetLoadBalancer(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
fwd: &MockForwarder{},
|
||||
fwd: &forwarderMock{},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
desc: "Succeeds when there are no servers",
|
||||
serviceName: "test",
|
||||
service: &dynamic.ServersLoadBalancer{},
|
||||
fwd: &MockForwarder{},
|
||||
fwd: &forwarderMock{},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
|
@ -56,7 +58,7 @@ func TestGetLoadBalancer(t *testing.T) {
|
|||
service: &dynamic.ServersLoadBalancer{
|
||||
Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}},
|
||||
},
|
||||
fwd: &MockForwarder{},
|
||||
fwd: &forwarderMock{},
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
@ -79,11 +81,8 @@ func TestGetLoadBalancer(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestGetLoadBalancerServiceHandler(t *testing.T) {
|
||||
sm := NewManager(nil, nil, nil, &RoundTripperManager{
|
||||
roundTrippers: map[string]http.RoundTripper{
|
||||
"default@internal": http.DefaultTransport,
|
||||
},
|
||||
})
|
||||
pb := httputil.NewProxyBuilder(&transportManagerMock{}, nil)
|
||||
sm := NewManager(nil, nil, nil, transportManagerMock{}, pb)
|
||||
|
||||
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-From", "first")
|
||||
|
@ -139,7 +138,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
|
|||
desc: "Load balances between the two servers",
|
||||
serviceName: "test",
|
||||
service: &dynamic.ServersLoadBalancer{
|
||||
PassHostHeader: Bool(true),
|
||||
PassHostHeader: boolPtr(true),
|
||||
Servers: []dynamic.Server{
|
||||
{
|
||||
URL: server1.URL,
|
||||
|
@ -254,7 +253,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
|
|||
desc: "PassHost doesn't pass the host instead of the IP",
|
||||
serviceName: "test",
|
||||
service: &dynamic.ServersLoadBalancer{
|
||||
PassHostHeader: Bool(false),
|
||||
PassHostHeader: boolPtr(false),
|
||||
Sticky: &dynamic.Sticky{Cookie: &dynamic.Cookie{}},
|
||||
Servers: []dynamic.Server{
|
||||
{
|
||||
|
@ -359,11 +358,8 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
|
|||
|
||||
// This test is an adapted version of net/http/httputil.Test1xxResponses test.
|
||||
func Test1xxResponses(t *testing.T) {
|
||||
sm := NewManager(nil, nil, nil, &RoundTripperManager{
|
||||
roundTrippers: map[string]http.RoundTripper{
|
||||
"default@internal": http.DefaultTransport,
|
||||
},
|
||||
})
|
||||
pb := httputil.NewProxyBuilder(&transportManagerMock{}, nil)
|
||||
sm := NewManager(nil, nil, nil, &transportManagerMock{}, pb)
|
||||
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := w.Header()
|
||||
|
@ -499,11 +495,7 @@ func TestManager_Build(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(test.configs, nil, nil, &RoundTripperManager{
|
||||
roundTrippers: map[string]http.RoundTripper{
|
||||
"default@internal": http.DefaultTransport,
|
||||
},
|
||||
})
|
||||
manager := NewManager(test.configs, nil, nil, &transportManagerMock{}, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
if len(test.providerName) > 0 {
|
||||
|
@ -526,30 +518,30 @@ func TestMultipleTypeOnBuildHTTP(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
manager := NewManager(services, nil, nil, &RoundTripperManager{
|
||||
roundTrippers: map[string]http.RoundTripper{
|
||||
"default@internal": http.DefaultTransport,
|
||||
},
|
||||
})
|
||||
manager := NewManager(services, nil, nil, &transportManagerMock{}, nil)
|
||||
|
||||
_, err := manager.BuildHTTP(context.Background(), "test@file")
|
||||
assert.Error(t, err, "cannot create service: multi-types service not supported, consider declaring two different pieces of service instead")
|
||||
}
|
||||
|
||||
func Bool(v bool) *bool { return &v }
|
||||
func boolPtr(v bool) *bool { return &v }
|
||||
|
||||
type MockForwarder struct{}
|
||||
type forwarderMock struct{}
|
||||
|
||||
func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) {
|
||||
func (forwarderMock) ServeHTTP(http.ResponseWriter, *http.Request) {
|
||||
panic("not available")
|
||||
}
|
||||
|
||||
type rtMock struct{}
|
||||
type transportManagerMock struct{}
|
||||
|
||||
func newRtMock() RoundTripperGetter {
|
||||
return &rtMock{}
|
||||
func (t transportManagerMock) GetRoundTripper(_ string) (http.RoundTripper, error) {
|
||||
return &http.Transport{}, nil
|
||||
}
|
||||
|
||||
func (r *rtMock) Get(_ string) (http.RoundTripper, error) {
|
||||
return http.DefaultTransport, nil
|
||||
func (t transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (t transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) {
|
||||
return &dynamic.ServersTransport{}, nil
|
||||
}
|
||||
|
|
|
@ -11,6 +11,15 @@ import (
|
|||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type h2cTransportWrapper struct {
|
||||
*http2.Transport
|
||||
}
|
||||
|
||||
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.URL.Scheme = "http"
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
func newSmartRoundTripper(transport *http.Transport, forwardingTimeouts *dynamic.ForwardingTimeouts) (*smartRoundTripper, error) {
|
||||
transportHTTP1 := transport.Clone()
|
||||
|
||||
|
|
|
@ -22,52 +22,45 @@ import (
|
|||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
traefiktls "github.com/traefik/traefik/v3/pkg/tls"
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
type h2cTransportWrapper struct {
|
||||
*http2.Transport
|
||||
}
|
||||
|
||||
func (t *h2cTransportWrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.URL.Scheme = "http"
|
||||
return t.Transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// SpiffeX509Source allows to retrieve a x509 SVID and bundle.
|
||||
type SpiffeX509Source interface {
|
||||
x509svid.Source
|
||||
x509bundle.Source
|
||||
}
|
||||
|
||||
// NewRoundTripperManager creates a new RoundTripperManager.
|
||||
func NewRoundTripperManager(spiffeX509Source SpiffeX509Source) *RoundTripperManager {
|
||||
return &RoundTripperManager{
|
||||
roundTrippers: make(map[string]http.RoundTripper),
|
||||
configs: make(map[string]*dynamic.ServersTransport),
|
||||
spiffeX509Source: spiffeX509Source,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTripperManager handles roundtripper for the reverse proxy.
|
||||
type RoundTripperManager struct {
|
||||
// TransportManager handles transports for backend communication.
|
||||
type TransportManager struct {
|
||||
rtLock sync.RWMutex
|
||||
roundTrippers map[string]http.RoundTripper
|
||||
configs map[string]*dynamic.ServersTransport
|
||||
tlsConfigs map[string]*tls.Config
|
||||
|
||||
spiffeX509Source SpiffeX509Source
|
||||
}
|
||||
|
||||
// Update updates the roundtrippers configurations.
|
||||
func (r *RoundTripperManager) Update(newConfigs map[string]*dynamic.ServersTransport) {
|
||||
r.rtLock.Lock()
|
||||
defer r.rtLock.Unlock()
|
||||
// NewTransportManager creates a new TransportManager.
|
||||
func NewTransportManager(spiffeX509Source SpiffeX509Source) *TransportManager {
|
||||
return &TransportManager{
|
||||
roundTrippers: make(map[string]http.RoundTripper),
|
||||
configs: make(map[string]*dynamic.ServersTransport),
|
||||
tlsConfigs: make(map[string]*tls.Config),
|
||||
spiffeX509Source: spiffeX509Source,
|
||||
}
|
||||
}
|
||||
|
||||
for configName, config := range r.configs {
|
||||
// Update updates the transport configurations.
|
||||
func (t *TransportManager) Update(newConfigs map[string]*dynamic.ServersTransport) {
|
||||
t.rtLock.Lock()
|
||||
defer t.rtLock.Unlock()
|
||||
|
||||
for configName, config := range t.configs {
|
||||
newConfig, ok := newConfigs[configName]
|
||||
if !ok {
|
||||
delete(r.configs, configName)
|
||||
delete(r.roundTrippers, configName)
|
||||
delete(t.configs, configName)
|
||||
delete(t.roundTrippers, configName)
|
||||
delete(t.tlsConfigs, configName)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -76,50 +69,133 @@ func (r *RoundTripperManager) Update(newConfigs map[string]*dynamic.ServersTrans
|
|||
}
|
||||
|
||||
var err error
|
||||
r.roundTrippers[configName], err = r.createRoundTripper(newConfig)
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if tlsConfig, err = t.createTLSConfig(newConfig); err != nil {
|
||||
log.Error().Err(err).Msgf("Could not configure HTTP Transport %s TLS configuration, fallback on default TLS config", configName)
|
||||
}
|
||||
t.tlsConfigs[configName] = tlsConfig
|
||||
|
||||
t.roundTrippers[configName], err = t.createRoundTripper(newConfig, tlsConfig)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Could not configure HTTP Transport %s, fallback on default transport", configName)
|
||||
r.roundTrippers[configName] = http.DefaultTransport
|
||||
t.roundTrippers[configName] = http.DefaultTransport
|
||||
}
|
||||
}
|
||||
|
||||
for newConfigName, newConfig := range newConfigs {
|
||||
if _, ok := r.configs[newConfigName]; ok {
|
||||
if _, ok := t.configs[newConfigName]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
var err error
|
||||
r.roundTrippers[newConfigName], err = r.createRoundTripper(newConfig)
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
if tlsConfig, err = t.createTLSConfig(newConfig); err != nil {
|
||||
log.Error().Err(err).Msgf("Could not configure HTTP Transport %s TLS configuration, fallback on default TLS config", newConfigName)
|
||||
}
|
||||
t.tlsConfigs[newConfigName] = tlsConfig
|
||||
|
||||
t.roundTrippers[newConfigName], err = t.createRoundTripper(newConfig, tlsConfig)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Could not configure HTTP Transport %s, fallback on default transport", newConfigName)
|
||||
r.roundTrippers[newConfigName] = http.DefaultTransport
|
||||
t.roundTrippers[newConfigName] = http.DefaultTransport
|
||||
}
|
||||
}
|
||||
|
||||
r.configs = newConfigs
|
||||
t.configs = newConfigs
|
||||
}
|
||||
|
||||
// Get gets a roundtripper by name.
|
||||
func (r *RoundTripperManager) Get(name string) (http.RoundTripper, error) {
|
||||
// GetRoundTripper gets a roundtripper corresponding to the given transport name.
|
||||
func (t *TransportManager) GetRoundTripper(name string) (http.RoundTripper, error) {
|
||||
if len(name) == 0 {
|
||||
name = "default@internal"
|
||||
}
|
||||
|
||||
r.rtLock.RLock()
|
||||
defer r.rtLock.RUnlock()
|
||||
t.rtLock.RLock()
|
||||
defer t.rtLock.RUnlock()
|
||||
|
||||
if rt, ok := r.roundTrippers[name]; ok {
|
||||
if rt, ok := t.roundTrippers[name]; ok {
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("servers transport not found %s", name)
|
||||
}
|
||||
|
||||
// Get gets transport by name.
|
||||
func (t *TransportManager) Get(name string) (*dynamic.ServersTransport, error) {
|
||||
if len(name) == 0 {
|
||||
name = "default@internal"
|
||||
}
|
||||
|
||||
t.rtLock.RLock()
|
||||
defer t.rtLock.RUnlock()
|
||||
|
||||
if rt, ok := t.configs[name]; ok {
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("servers transport not found %s", name)
|
||||
}
|
||||
|
||||
// GetTLSConfig gets a TLS config corresponding to the given transport name.
|
||||
func (t *TransportManager) GetTLSConfig(name string) (*tls.Config, error) {
|
||||
if len(name) == 0 {
|
||||
name = "default@internal"
|
||||
}
|
||||
|
||||
t.rtLock.RLock()
|
||||
defer t.rtLock.RUnlock()
|
||||
|
||||
if rt, ok := t.tlsConfigs[name]; ok {
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("tls config not found %s", name)
|
||||
}
|
||||
|
||||
func (t *TransportManager) createTLSConfig(cfg *dynamic.ServersTransport) (*tls.Config, error) {
|
||||
var config *tls.Config
|
||||
if cfg.Spiffe != nil {
|
||||
if t.spiffeX509Source == nil {
|
||||
return nil, errors.New("SPIFFE is enabled for this transport, but not configured")
|
||||
}
|
||||
|
||||
spiffeAuthorizer, err := buildSpiffeAuthorizer(cfg.Spiffe)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to build SPIFFE authorizer: %w", err)
|
||||
}
|
||||
|
||||
config = tlsconfig.MTLSClientConfig(t.spiffeX509Source, t.spiffeX509Source, spiffeAuthorizer)
|
||||
}
|
||||
|
||||
if cfg.InsecureSkipVerify || len(cfg.RootCAs) > 0 || len(cfg.ServerName) > 0 || len(cfg.Certificates) > 0 || cfg.PeerCertURI != "" {
|
||||
if config != nil {
|
||||
return nil, errors.New("TLS and SPIFFE configuration cannot be defined at the same time")
|
||||
}
|
||||
|
||||
config = &tls.Config{
|
||||
ServerName: cfg.ServerName,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
RootCAs: createRootCACertPool(cfg.RootCAs),
|
||||
Certificates: cfg.Certificates.GetCertificates(),
|
||||
}
|
||||
|
||||
if cfg.PeerCertURI != "" {
|
||||
config.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
return traefiktls.VerifyPeerCertificate(cfg.PeerCertURI, config, rawCerts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// createRoundTripper creates an http.RoundTripper configured with the Transport configuration settings.
|
||||
// For the settings that can't be configured in Traefik it uses the default http.Transport settings.
|
||||
// An exception to this is the MaxIdleConns setting as we only provide the option MaxIdleConnsPerHost in Traefik at this point in time.
|
||||
// Setting this value to the default of 100 could lead to confusing behavior and backwards compatibility issues.
|
||||
func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport) (http.RoundTripper, error) {
|
||||
func (t *TransportManager) createRoundTripper(cfg *dynamic.ServersTransport, tlsConfig *tls.Config) (http.RoundTripper, error) {
|
||||
if cfg == nil {
|
||||
return nil, errors.New("no transport configuration given")
|
||||
}
|
||||
|
@ -142,6 +218,7 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport)
|
|||
ExpectContinueTimeout: 1 * time.Second,
|
||||
ReadBufferSize: 64 * 1024,
|
||||
WriteBufferSize: 64 * 1024,
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
|
||||
if cfg.ForwardingTimeouts != nil {
|
||||
|
@ -149,41 +226,9 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport)
|
|||
transport.IdleConnTimeout = time.Duration(cfg.ForwardingTimeouts.IdleConnTimeout)
|
||||
}
|
||||
|
||||
if cfg.Spiffe != nil {
|
||||
if r.spiffeX509Source == nil {
|
||||
return nil, errors.New("SPIFFE is enabled for this transport, but not configured")
|
||||
}
|
||||
|
||||
spiffeAuthorizer, err := buildSpiffeAuthorizer(cfg.Spiffe)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to build SPIFFE authorizer: %w", err)
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = tlsconfig.MTLSClientConfig(r.spiffeX509Source, r.spiffeX509Source, spiffeAuthorizer)
|
||||
}
|
||||
|
||||
if cfg.InsecureSkipVerify || len(cfg.RootCAs) > 0 || len(cfg.ServerName) > 0 || len(cfg.Certificates) > 0 || cfg.PeerCertURI != "" {
|
||||
if transport.TLSClientConfig != nil {
|
||||
return nil, errors.New("TLS and SPIFFE configuration cannot be defined at the same time")
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = &tls.Config{
|
||||
ServerName: cfg.ServerName,
|
||||
InsecureSkipVerify: cfg.InsecureSkipVerify,
|
||||
RootCAs: createRootCACertPool(cfg.RootCAs),
|
||||
Certificates: cfg.Certificates.GetCertificates(),
|
||||
}
|
||||
|
||||
if cfg.PeerCertURI != "" {
|
||||
transport.TLSClientConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
return traefiktls.VerifyPeerCertificate(cfg.PeerCertURI, transport.TLSClientConfig, rawCerts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return directly HTTP/1.1 transport when HTTP/2 is disabled
|
||||
if cfg.DisableHTTP2 {
|
||||
return &KerberosRoundTripper{
|
||||
return &kerberosRoundTripper{
|
||||
OriginalRoundTripper: transport,
|
||||
new: func() http.RoundTripper {
|
||||
return transport.Clone()
|
||||
|
@ -195,7 +240,7 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &KerberosRoundTripper{
|
||||
return &kerberosRoundTripper{
|
||||
OriginalRoundTripper: rt,
|
||||
new: func() http.RoundTripper {
|
||||
return rt.Clone()
|
||||
|
@ -203,11 +248,6 @@ func (r *RoundTripperManager) createRoundTripper(cfg *dynamic.ServersTransport)
|
|||
}, nil
|
||||
}
|
||||
|
||||
type KerberosRoundTripper struct {
|
||||
new func() http.RoundTripper
|
||||
OriginalRoundTripper http.RoundTripper
|
||||
}
|
||||
|
||||
type stickyRoundTripper struct {
|
||||
RoundTripper http.RoundTripper
|
||||
}
|
||||
|
@ -220,7 +260,12 @@ func AddTransportOnContext(ctx context.Context) context.Context {
|
|||
return context.WithValue(ctx, transportKey, &stickyRoundTripper{})
|
||||
}
|
||||
|
||||
func (k *KerberosRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
type kerberosRoundTripper struct {
|
||||
new func() http.RoundTripper
|
||||
OriginalRoundTripper http.RoundTripper
|
||||
}
|
||||
|
||||
func (k *kerberosRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
value, ok := request.Context().Value(transportKey).(*stickyRoundTripper)
|
||||
if !ok {
|
||||
return k.OriginalRoundTripper.RoundTrip(request)
|
|
@ -141,7 +141,7 @@ func TestKeepConnectionWhenSameConfiguration(t *testing.T) {
|
|||
srv.TLS = &tls.Config{Certificates: []tls.Certificate{cert}}
|
||||
srv.StartTLS()
|
||||
|
||||
rtManager := NewRoundTripperManager(nil)
|
||||
transportManager := NewTransportManager(nil)
|
||||
|
||||
dynamicConf := map[string]*dynamic.ServersTransport{
|
||||
"test": {
|
||||
|
@ -151,9 +151,9 @@ func TestKeepConnectionWhenSameConfiguration(t *testing.T) {
|
|||
}
|
||||
|
||||
for range 10 {
|
||||
rtManager.Update(dynamicConf)
|
||||
transportManager.Update(dynamicConf)
|
||||
|
||||
tr, err := rtManager.Get("test")
|
||||
tr, err := transportManager.GetRoundTripper("test")
|
||||
require.NoError(t, err)
|
||||
|
||||
client := http.Client{Transport: tr}
|
||||
|
@ -173,9 +173,9 @@ func TestKeepConnectionWhenSameConfiguration(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
rtManager.Update(dynamicConf)
|
||||
transportManager.Update(dynamicConf)
|
||||
|
||||
tr, err := rtManager.Get("test")
|
||||
tr, err := transportManager.GetRoundTripper("test")
|
||||
require.NoError(t, err)
|
||||
|
||||
client := http.Client{Transport: tr}
|
||||
|
@ -209,7 +209,7 @@ func TestMTLS(t *testing.T) {
|
|||
}
|
||||
srv.StartTLS()
|
||||
|
||||
rtManager := NewRoundTripperManager(nil)
|
||||
transportManager := NewTransportManager(nil)
|
||||
|
||||
dynamicConf := map[string]*dynamic.ServersTransport{
|
||||
"test": {
|
||||
|
@ -227,9 +227,9 @@ func TestMTLS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
rtManager.Update(dynamicConf)
|
||||
transportManager.Update(dynamicConf)
|
||||
|
||||
tr, err := rtManager.Get("test")
|
||||
tr, err := transportManager.GetRoundTripper("test")
|
||||
require.NoError(t, err)
|
||||
|
||||
client := http.Client{Transport: tr}
|
||||
|
@ -348,7 +348,7 @@ func TestSpiffeMTLS(t *testing.T) {
|
|||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
rtManager := NewRoundTripperManager(test.clientSource)
|
||||
transportManager := NewTransportManager(test.clientSource)
|
||||
|
||||
dynamicConf := map[string]*dynamic.ServersTransport{
|
||||
"test": {
|
||||
|
@ -356,9 +356,9 @@ func TestSpiffeMTLS(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
rtManager.Update(dynamicConf)
|
||||
transportManager.Update(dynamicConf)
|
||||
|
||||
tr, err := rtManager.Get("test")
|
||||
tr, err := transportManager.GetRoundTripper("test")
|
||||
require.NoError(t, err)
|
||||
|
||||
client := http.Client{Transport: tr}
|
||||
|
@ -415,7 +415,7 @@ func TestDisableHTTP2(t *testing.T) {
|
|||
srv.EnableHTTP2 = test.serverHTTP2
|
||||
srv.StartTLS()
|
||||
|
||||
rtManager := NewRoundTripperManager(nil)
|
||||
transportManager := NewTransportManager(nil)
|
||||
|
||||
dynamicConf := map[string]*dynamic.ServersTransport{
|
||||
"test": {
|
||||
|
@ -424,9 +424,9 @@ func TestDisableHTTP2(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
rtManager.Update(dynamicConf)
|
||||
transportManager.Update(dynamicConf)
|
||||
|
||||
tr, err := rtManager.Get("test")
|
||||
tr, err := transportManager.GetRoundTripper("test")
|
||||
require.NoError(t, err)
|
||||
|
||||
client := http.Client{Transport: tr}
|
||||
|
@ -593,7 +593,7 @@ func TestKerberosRoundTripper(t *testing.T) {
|
|||
|
||||
origCount := 0
|
||||
dedicatedCount := 0
|
||||
rt := KerberosRoundTripper{
|
||||
rt := kerberosRoundTripper{
|
||||
new: func() http.RoundTripper {
|
||||
return roundTripperFn(func(req *http.Request) (*http.Response, error) {
|
||||
dedicatedCount++
|
Loading…
Add table
Add a link
Reference in a new issue