1
0
Fork 0

Introduce a fast proxy mode to improve HTTP/1.1 performances with backends

Co-authored-by: Romain <rtribotte@users.noreply.github.com>
Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
Kevin Pollet 2024-09-26 11:00:05 +02:00 committed by GitHub
parent a6db1cac37
commit f8a78b3b25
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
39 changed files with 3173 additions and 378 deletions

View file

@ -0,0 +1,29 @@
package httputil
import "sync"
const bufferSize = 32 * 1024
type bufferPool struct {
pool sync.Pool
}
func newBufferPool() *bufferPool {
b := &bufferPool{
pool: sync.Pool{},
}
b.pool.New = func() interface{} {
return make([]byte, bufferSize)
}
return b
}
func (b *bufferPool) Get() []byte {
return b.pool.Get().([]byte)
}
func (b *bufferPool) Put(bytes []byte) {
b.pool.Put(bytes)
}

View file

@ -0,0 +1,54 @@
package httputil
import (
"crypto/tls"
"fmt"
"net/http"
"net/url"
"time"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/metrics"
)
// TransportManager manages transport used for backend communications.
type TransportManager interface {
Get(name string) (*dynamic.ServersTransport, error)
GetRoundTripper(name string) (http.RoundTripper, error)
GetTLSConfig(name string) (*tls.Config, error)
}
// ProxyBuilder handles the http.RoundTripper for httputil reverse proxies.
type ProxyBuilder struct {
bufferPool *bufferPool
transportManager TransportManager
semConvMetricsRegistry *metrics.SemConvMetricsRegistry
}
// NewProxyBuilder creates a new ProxyBuilder.
func NewProxyBuilder(transportManager TransportManager, semConvMetricsRegistry *metrics.SemConvMetricsRegistry) *ProxyBuilder {
return &ProxyBuilder{
bufferPool: newBufferPool(),
transportManager: transportManager,
semConvMetricsRegistry: semConvMetricsRegistry,
}
}
// Update does nothing.
func (r *ProxyBuilder) Update(_ map[string]*dynamic.ServersTransport) {}
// Build builds a new httputil.ReverseProxy with the given configuration.
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) {
roundTripper, err := r.transportManager.GetRoundTripper(cfgName)
if err != nil {
return nil, fmt.Errorf("getting RoundTripper: %w", err)
}
if shouldObserve {
// Wrapping the roundTripper with the Tracing roundTripper,
// to handle the reverseProxy client span creation.
roundTripper = newObservabilityRoundTripper(r.semConvMetricsRegistry, roundTripper)
}
return buildSingleHostProxy(targetURL, passHostHeader, flushInterval, roundTripper, r.bufferPool), nil
}

View file

@ -0,0 +1,56 @@
package httputil
import (
"crypto/tls"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/testhelpers"
)
func TestEscapedPath(t *testing.T) {
var gotEscapedPath string
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
gotEscapedPath = req.URL.EscapedPath()
}))
transportManager := &transportManagerMock{
roundTrippers: map[string]http.RoundTripper{"default": &http.Transport{}},
}
p, err := NewProxyBuilder(transportManager, nil).Build("default", testhelpers.MustParseURL(srv.URL), false, true, 0)
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(p.ServeHTTP))
_, err = http.Get(proxy.URL + "/%3A%2F%2F")
require.NoError(t, err)
assert.Equal(t, "/%3A%2F%2F", gotEscapedPath)
}
type transportManagerMock struct {
roundTrippers map[string]http.RoundTripper
}
func (t *transportManagerMock) GetRoundTripper(name string) (http.RoundTripper, error) {
roundTripper, ok := t.roundTrippers[name]
if !ok {
return nil, errors.New("no transport for " + name)
}
return roundTripper, nil
}
func (t *transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) {
panic("implement me")
}
func (t *transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) {
panic("implement me")
}

View file

@ -0,0 +1,105 @@
package httputil
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"time"
"github.com/traefik/traefik/v3/pkg/metrics"
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
"github.com/traefik/traefik/v3/pkg/tracing"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.opentelemetry.io/otel/trace"
)
type wrapper struct {
semConvMetricRegistry *metrics.SemConvMetricsRegistry
rt http.RoundTripper
}
func newObservabilityRoundTripper(semConvMetricRegistry *metrics.SemConvMetricsRegistry, rt http.RoundTripper) http.RoundTripper {
return &wrapper{
semConvMetricRegistry: semConvMetricRegistry,
rt: rt,
}
}
func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) {
start := time.Now()
var span trace.Span
var tracingCtx context.Context
var tracer *tracing.Tracer
if tracer = tracing.TracerFromContext(req.Context()); tracer != nil {
tracingCtx, span = tracer.Start(req.Context(), "ReverseProxy", trace.WithSpanKind(trace.SpanKindClient))
defer span.End()
req = req.WithContext(tracingCtx)
tracer.CaptureClientRequest(span, req)
tracing.InjectContextIntoCarrier(req)
}
var statusCode int
var headers http.Header
response, err := t.rt.RoundTrip(req)
if err != nil {
statusCode = ComputeStatusCode(err)
}
if response != nil {
statusCode = response.StatusCode
headers = response.Header
}
if tracer != nil {
tracer.CaptureResponse(span, headers, statusCode, trace.SpanKindClient)
}
end := time.Now()
// Ending the span as soon as the response is handled because we want to use the same end time for the trace and the metric.
// If any errors happen earlier, this span will be close by the defer instruction.
if span != nil {
span.End(trace.WithTimestamp(end))
}
if t.semConvMetricRegistry != nil && t.semConvMetricRegistry.HTTPClientRequestDuration() != nil {
var attrs []attribute.KeyValue
if statusCode < 100 || statusCode >= 600 {
attrs = append(attrs, attribute.Key("error.type").String(fmt.Sprintf("Invalid HTTP status code %d", statusCode)))
} else if statusCode >= 400 {
attrs = append(attrs, attribute.Key("error.type").String(strconv.Itoa(statusCode)))
}
attrs = append(attrs, semconv.HTTPRequestMethodKey.String(req.Method))
attrs = append(attrs, semconv.HTTPResponseStatusCode(statusCode))
attrs = append(attrs, semconv.NetworkProtocolName(strings.ToLower(req.Proto)))
attrs = append(attrs, semconv.NetworkProtocolVersion(observability.Proto(req.Proto)))
attrs = append(attrs, semconv.ServerAddress(req.URL.Host))
_, port, err := net.SplitHostPort(req.URL.Host)
if err != nil {
switch req.URL.Scheme {
case "http":
attrs = append(attrs, semconv.ServerPort(80))
case "https":
attrs = append(attrs, semconv.ServerPort(443))
}
} else {
intPort, _ := strconv.Atoi(port)
attrs = append(attrs, semconv.ServerPort(intPort))
}
attrs = append(attrs, semconv.URLScheme(req.Header.Get("X-Forwarded-Proto")))
t.semConvMetricRegistry.HTTPClientRequestDuration().Record(req.Context(), end.Sub(start).Seconds(), metric.WithAttributes(attrs...))
}
return response, err
}

View file

@ -0,0 +1,122 @@
package httputil
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/require"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v3/pkg/metrics"
"github.com/traefik/traefik/v3/pkg/types"
"go.opentelemetry.io/otel/attribute"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest"
)
func TestObservabilityRoundTripper_metrics(t *testing.T) {
tests := []struct {
desc string
serverURL string
statusCode int
wantAttributes attribute.Set
}{
{
desc: "not found status",
serverURL: "http://www.test.com",
statusCode: http.StatusNotFound,
wantAttributes: attribute.NewSet(
attribute.Key("error.type").String("404"),
attribute.Key("http.request.method").String("GET"),
attribute.Key("http.response.status_code").Int(404),
attribute.Key("network.protocol.name").String("http/1.1"),
attribute.Key("network.protocol.version").String("1.1"),
attribute.Key("server.address").String("www.test.com"),
attribute.Key("server.port").Int(80),
attribute.Key("url.scheme").String("http"),
),
},
{
desc: "created status",
serverURL: "https://www.test.com",
statusCode: http.StatusCreated,
wantAttributes: attribute.NewSet(
attribute.Key("http.request.method").String("GET"),
attribute.Key("http.response.status_code").Int(201),
attribute.Key("network.protocol.name").String("http/1.1"),
attribute.Key("network.protocol.version").String("1.1"),
attribute.Key("server.address").String("www.test.com"),
attribute.Key("server.port").Int(443),
attribute.Key("url.scheme").String("http"),
),
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var cfg types.OTLP
(&cfg).SetDefaults()
cfg.AddRoutersLabels = true
cfg.PushInterval = ptypes.Duration(10 * time.Millisecond)
rdr := sdkmetric.NewManualReader()
meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(rdr))
// force the meter provider with manual reader to collect metrics for the test.
metrics.SetMeterProvider(meterProvider)
semConvMetricRegistry, err := metrics.NewSemConvMetricRegistry(context.Background(), &cfg)
require.NoError(t, err)
require.NotNil(t, semConvMetricRegistry)
req := httptest.NewRequest(http.MethodGet, test.serverURL+"/search?q=Opentelemetry", nil)
req.RemoteAddr = "10.0.0.1:1234"
req.Header.Set("User-Agent", "rt-test")
req.Header.Set("X-Forwarded-Proto", "http")
ort := newObservabilityRoundTripper(semConvMetricRegistry, mockRoundTripper{statusCode: test.statusCode})
_, err = ort.RoundTrip(req)
require.NoError(t, err)
got := metricdata.ResourceMetrics{}
err = rdr.Collect(context.Background(), &got)
require.NoError(t, err)
require.Len(t, got.ScopeMetrics, 1)
expected := metricdata.Metrics{
Name: "http.client.request.duration",
Description: "Duration of HTTP client requests.",
Unit: "s",
Data: metricdata.Histogram[float64]{
DataPoints: []metricdata.HistogramDataPoint[float64]{
{
Attributes: test.wantAttributes,
Count: 1,
Bounds: []float64{0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1, 2.5, 5, 7.5, 10},
BucketCounts: []uint64{0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
Min: metricdata.NewExtrema[float64](1),
Max: metricdata.NewExtrema[float64](1),
Sum: 1,
},
},
Temporality: metricdata.CumulativeTemporality,
},
}
metricdatatest.AssertEqual[metricdata.Metrics](t, expected, got.ScopeMetrics[0].Metrics[0], metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreValue())
})
}
}
type mockRoundTripper struct {
statusCode int
}
func (m mockRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
return &http.Response{StatusCode: m.statusCode}, nil
}

135
pkg/proxy/httputil/proxy.go Normal file
View file

@ -0,0 +1,135 @@
package httputil
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"time"
"github.com/rs/zerolog/log"
"golang.org/x/net/http/httpguts"
)
// StatusClientClosedRequest non-standard HTTP status code for client disconnection.
const StatusClientClosedRequest = 499
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
const StatusClientClosedRequestText = "Client Closed Request"
func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler {
return &httputil.ReverseProxy{
Director: directorBuilder(target, passHostHeader),
Transport: roundTripper,
FlushInterval: flushInterval,
BufferPool: bufferPool,
ErrorHandler: ErrorHandler,
}
}
func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Request) {
return func(outReq *http.Request) {
outReq.URL.Scheme = target.Scheme
outReq.URL.Host = target.Host
u := outReq.URL
if outReq.RequestURI != "" {
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
if err == nil {
u = parsedURL
}
}
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
// If a plugin/middleware adds semicolons in query params, they should be urlEncoded.
outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.Proto = "HTTP/1.1"
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
// Do not pass client Host header unless optsetter PassHostHeader is set.
if !passHostHeader {
outReq.Host = outReq.URL.Host
}
cleanWebSocketHeaders(outReq)
}
}
// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive,
// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
// https://tools.ietf.org/html/rfc6455#page-20
func cleanWebSocketHeaders(req *http.Request) {
if !isWebSocketUpgrade(req) {
return
}
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
delete(req.Header, "Sec-Websocket-Key")
req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"]
delete(req.Header, "Sec-Websocket-Extensions")
req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"]
delete(req.Header, "Sec-Websocket-Accept")
req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"]
delete(req.Header, "Sec-Websocket-Protocol")
req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"]
delete(req.Header, "Sec-Websocket-Version")
}
func isWebSocketUpgrade(req *http.Request) bool {
return httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") &&
strings.EqualFold(req.Header.Get("Upgrade"), "websocket")
}
// ErrorHandler is the http.Handler called when something goes wrong when forwarding the request.
func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
statusCode := ComputeStatusCode(err)
logger := log.Ctx(req.Context())
logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode))
w.WriteHeader(statusCode)
if _, werr := w.Write([]byte(statusText(statusCode))); werr != nil {
logger.Debug().Err(werr).Msg("Error while writing status code")
}
}
// ComputeStatusCode computes the HTTP status code according to the given error.
func ComputeStatusCode(err error) int {
switch {
case errors.Is(err, io.EOF):
return http.StatusBadGateway
case errors.Is(err, context.Canceled):
return StatusClientClosedRequest
default:
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
return http.StatusGatewayTimeout
}
return http.StatusBadGateway
}
}
return http.StatusInternalServerError
}
func statusText(statusCode int) string {
if statusCode == StatusClientClosedRequest {
return StatusClientClosedRequestText
}
return http.StatusText(statusCode)
}

View file

@ -0,0 +1,638 @@
package httputil
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
gorillawebsocket "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/testhelpers"
"golang.org/x/net/websocket"
)
func TestWebSocketTCPClose(t *testing.T) {
errChan := make(chan error, 1)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
_, _, err := c.ReadMessage()
if err != nil {
errChan <- err
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
proxyAddr := proxy.Listener.Addr().String()
_, conn, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
).open()
require.NoError(t, err)
conn.Close()
serverErr := <-errChan
var wsErr *gorillawebsocket.CloseError
require.ErrorAs(t, serverErr, &wsErr)
assert.Equal(t, 1006, wsErr.Code)
}
func TestWebSocketPingPong(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{
HandshakeTimeout: 10 * time.Second,
CheckOrigin: func(*http.Request) bool {
return true
},
}
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
ws, err := upgrader.Upgrade(writer, request, nil)
require.NoError(t, err)
ws.SetPingHandler(func(appData string) error {
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
require.NoError(t, err)
return nil
})
_, _, _ = ws.ReadMessage()
})
srv := httptest.NewServer(mux)
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
defer conn.Close()
goodErr := fmt.Errorf("signal: %s", "Good data")
badErr := fmt.Errorf("signal: %s", "Bad data")
conn.SetPongHandler(func(data string) error {
if data == "PingPong" {
return goodErr
}
return badErr
})
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
require.NoError(t, err)
_, _, err = conn.ReadMessage()
if !errors.Is(err, goodErr) {
require.NoError(t, err)
}
}
func TestWebSocketEcho(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
msg := make([]byte, 4)
n, err := conn.Read(msg)
require.NoError(t, err)
_, err = conn.Write(msg[:n])
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(mux)
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
_, msg, err := conn.ReadMessage()
require.NoError(t, err)
assert.Equal(t, "OK", string(msg))
err = conn.Close()
require.NoError(t, err)
}
func TestWebSocketPassHost(t *testing.T) {
testCases := []struct {
desc string
passHost bool
expected string
}{
{
desc: "PassHost false",
passHost: false,
},
{
desc: "PassHost true",
passHost: true,
expected: "example.com",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
req := conn.Request()
if test.passHost {
require.Equal(t, test.expected, req.Host)
} else {
require.NotEqual(t, test.expected, req.Host)
}
msg := make([]byte, 4)
n, err := conn.Read(msg)
require.NoError(t, err)
_, err = conn.Write(msg[:n])
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(mux)
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
headers.Add("Host", "example.com")
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
_, msg, err := conn.ReadMessage()
require.NoError(t, err)
assert.Equal(t, "OK", string(msg))
err = conn.Close()
require.NoError(t, err)
})
}
}
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
srv := createServer(t, upgrader, func(*http.Request) {})
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
withOrigin("http://127.0.0.2"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithOrigin(t *testing.T) {
srv := createServer(t, gorillawebsocket.Upgrader{}, func(*http.Request) {})
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
_, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
withOrigin("http://127.0.0.2"),
).send()
require.EqualError(t, err, "bad status")
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithQueryParams(t *testing.T) {
srv := createServer(t, gorillawebsocket.Upgrader{}, func(r *http.Request) {
assert.Equal(t, "test", r.URL.Query().Get("query"))
})
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws?query=test"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
conn.Close()
}))
srv := httptest.NewServer(mux)
defer srv.Close()
transportManager := &transportManagerMock{
roundTrippers: map[string]http.RoundTripper{
"default@internal": &http.Transport{},
},
}
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = testhelpers.MustParseURL(srv.URL)
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
p.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
defer conn.Close()
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
}
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
srv := createServer(t, gorillawebsocket.Upgrader{}, func(r *http.Request) {
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
})
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/%3A%2F%2F"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketUpgradeFailed(t *testing.T) {
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusBadRequest)
})
srv := httptest.NewServer(mux)
defer srv.Close()
transportManager := &transportManagerMock{
roundTrippers: map[string]http.RoundTripper{
"default@internal": &http.Transport{},
},
}
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
if path == "/ws" {
// Set new backend URL
req.URL = testhelpers.MustParseURL(srv.URL)
req.URL.Path = path
p.ServeHTTP(w, req)
} else {
w.WriteHeader(http.StatusOK)
}
}))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
require.NoError(t, err)
defer conn.Close()
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
require.NoError(t, err)
req.Header.Add("upgrade", "websocket")
req.Header.Add("Connection", "upgrade")
err = req.Write(conn)
require.NoError(t, err)
// First request works with 400
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, req)
require.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode)
}
func TestForwardsWebsocketTraffic(t *testing.T) {
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
_, err := conn.Write([]byte("ok"))
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(mux)
defer srv.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
func TestWebSocketTransferTLSConfig(t *testing.T) {
srv := createTLSWebsocketServer()
defer srv.Close()
proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxyWithoutTLSConfig.Close()
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
_, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.EqualError(t, err, "bad status")
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, transport)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
// Don't alter default transport to prevent side effects on other tests.
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
defaultTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, srv.URL, defaultTransport)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
resp, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
const dialTimeout = time.Second
type websocketRequestOpt func(w *websocketRequest)
func withServer(server string) websocketRequestOpt {
return func(w *websocketRequest) {
w.ServerAddr = server
}
}
func withPath(path string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Path = path
}
}
func withData(data string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Data = data
}
}
func withOrigin(origin string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Origin = origin
}
}
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
wsrequest := &websocketRequest{}
for _, opt := range opts {
opt(wsrequest)
}
if wsrequest.Origin == "" {
wsrequest.Origin = "http://" + wsrequest.ServerAddr
}
if wsrequest.Config == nil {
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
}
return wsrequest
}
type websocketRequest struct {
ServerAddr string
Path string
Data string
Origin string
Config *websocket.Config
}
func (w *websocketRequest) send() (string, error) {
conn, _, err := w.open()
if err != nil {
return "", err
}
defer conn.Close()
if _, err := conn.Write([]byte(w.Data)); err != nil {
return "", err
}
msg := make([]byte, 512)
var n int
n, err = conn.Read(msg)
if err != nil {
return "", err
}
received := string(msg[:n])
return received, nil
}
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
if err != nil {
return nil, nil, err
}
conn, err := websocket.NewClient(w.Config, client)
if err != nil {
return nil, nil, err
}
return conn, client, err
}
func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTripper) *httptest.Server {
t.Helper()
u := testhelpers.MustParseURL(uri)
transportManager := &transportManagerMock{
roundTrippers: map[string]http.RoundTripper{"fwd": transport},
}
p, err := NewProxyBuilder(transportManager, nil).Build("fwd", u, false, true, 0)
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
// keep the original path
path := req.URL.Path
// Set new backend URL
req.URL = u
req.URL.Path = path
p.ServeHTTP(w, req)
}))
t.Cleanup(srv.Close)
return srv
}
func createServer(t *testing.T, upgrader gorillawebsocket.Upgrader, check func(*http.Request)) *httptest.Server {
t.Helper()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
t.Logf("Error during upgrade: %v", err)
return
}
defer conn.Close()
check(r)
for {
mt, message, err := conn.ReadMessage()
if err != nil {
t.Logf("Error during read: %v", err)
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
t.Logf("Error during write: %v", err)
break
}
}
}))
t.Cleanup(srv.Close)
return srv
}