Introduce a fast proxy mode to improve HTTP/1.1 performances with backends
Co-authored-by: Romain <rtribotte@users.noreply.github.com> Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
parent
a6db1cac37
commit
f8a78b3b25
39 changed files with 3173 additions and 378 deletions
129
pkg/proxy/fast/builder.go
Normal file
129
pkg/proxy/fast/builder.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/config/static"
|
||||
)
|
||||
|
||||
// TransportManager manages transport used for backend communications.
|
||||
type TransportManager interface {
|
||||
Get(name string) (*dynamic.ServersTransport, error)
|
||||
GetTLSConfig(name string) (*tls.Config, error)
|
||||
}
|
||||
|
||||
// ProxyBuilder handles the connection pools for the FastProxy proxies.
|
||||
type ProxyBuilder struct {
|
||||
debug bool
|
||||
transportManager TransportManager
|
||||
|
||||
// lock isn't needed because ProxyBuilder is not called concurrently.
|
||||
pools map[string]map[string]*connPool
|
||||
proxy func(*http.Request) (*url.URL, error)
|
||||
|
||||
// not goroutine safe.
|
||||
configs map[string]*dynamic.ServersTransport
|
||||
}
|
||||
|
||||
// NewProxyBuilder creates a new ProxyBuilder.
|
||||
func NewProxyBuilder(transportManager TransportManager, config static.FastProxyConfig) *ProxyBuilder {
|
||||
return &ProxyBuilder{
|
||||
debug: config.Debug,
|
||||
transportManager: transportManager,
|
||||
pools: make(map[string]map[string]*connPool),
|
||||
proxy: http.ProxyFromEnvironment,
|
||||
configs: make(map[string]*dynamic.ServersTransport),
|
||||
}
|
||||
}
|
||||
|
||||
// Update updates all the round-tripper corresponding to the given configs.
|
||||
// This method must not be used concurrently.
|
||||
func (r *ProxyBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) {
|
||||
for configName := range r.configs {
|
||||
if _, ok := newConfigs[configName]; !ok {
|
||||
for _, c := range r.pools[configName] {
|
||||
c.Close()
|
||||
}
|
||||
delete(r.pools, configName)
|
||||
}
|
||||
}
|
||||
|
||||
for newConfigName, newConfig := range newConfigs {
|
||||
if !reflect.DeepEqual(newConfig, r.configs[newConfigName]) {
|
||||
for _, c := range r.pools[newConfigName] {
|
||||
c.Close()
|
||||
}
|
||||
delete(r.pools, newConfigName)
|
||||
}
|
||||
}
|
||||
|
||||
r.configs = newConfigs
|
||||
}
|
||||
|
||||
// Build builds a new ReverseProxy with the given configuration.
|
||||
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader bool) (http.Handler, error) {
|
||||
proxyURL, err := r.proxy(&http.Request{URL: targetURL})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting proxy: %w", err)
|
||||
}
|
||||
|
||||
cfg, err := r.transportManager.Get(cfgName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ServersTransport: %w", err)
|
||||
}
|
||||
|
||||
var responseHeaderTimeout time.Duration
|
||||
if cfg.ForwardingTimeouts != nil {
|
||||
responseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
tlsConfig, err := r.transportManager.GetTLSConfig(cfgName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting TLS config: %w", err)
|
||||
}
|
||||
|
||||
pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL)
|
||||
return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, responseHeaderTimeout, pool)
|
||||
}
|
||||
|
||||
func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool {
|
||||
pool, ok := r.pools[cfgName]
|
||||
if !ok {
|
||||
pool = make(map[string]*connPool)
|
||||
r.pools[cfgName] = pool
|
||||
}
|
||||
|
||||
if connPool, ok := pool[targetURL.String()]; ok {
|
||||
return connPool
|
||||
}
|
||||
|
||||
idleConnTimeout := 90 * time.Second
|
||||
dialTimeout := 30 * time.Second
|
||||
if config.ForwardingTimeouts != nil {
|
||||
idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout)
|
||||
dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout)
|
||||
}
|
||||
|
||||
proxyDialer := newDialer(dialerConfig{
|
||||
DialKeepAlive: 0,
|
||||
DialTimeout: dialTimeout,
|
||||
HTTP: true,
|
||||
TLS: targetURL.Scheme == "https",
|
||||
ProxyURL: proxyURL,
|
||||
}, tlsConfig)
|
||||
|
||||
connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, func() (net.Conn, error) {
|
||||
return proxyDialer.Dial("tcp", addrFromURL(targetURL))
|
||||
})
|
||||
|
||||
r.pools[cfgName][targetURL.String()] = connPool
|
||||
|
||||
return connPool
|
||||
}
|
163
pkg/proxy/fast/connpool.go
Normal file
163
pkg/proxy/fast/connpool.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// conn is an enriched net.Conn.
|
||||
type conn struct {
|
||||
net.Conn
|
||||
|
||||
idleAt time.Time // the last time it was marked as idle.
|
||||
idleTimeout time.Duration
|
||||
}
|
||||
|
||||
func (c *conn) isExpired() bool {
|
||||
expTime := c.idleAt.Add(c.idleTimeout)
|
||||
return c.idleTimeout > 0 && time.Now().After(expTime)
|
||||
}
|
||||
|
||||
// connPool is a net.Conn pool implementation using channels.
|
||||
type connPool struct {
|
||||
dialer func() (net.Conn, error)
|
||||
idleConns chan *conn
|
||||
idleConnTimeout time.Duration
|
||||
ticker *time.Ticker
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
// newConnPool creates a new connPool.
|
||||
func newConnPool(maxIdleConn int, idleConnTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
|
||||
c := &connPool{
|
||||
dialer: dialer,
|
||||
idleConns: make(chan *conn, maxIdleConn),
|
||||
idleConnTimeout: idleConnTimeout,
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
if idleConnTimeout > 0 {
|
||||
c.ticker = time.NewTicker(c.idleConnTimeout / 2)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ticker.C:
|
||||
c.cleanIdleConns()
|
||||
case <-c.doneCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Close closes stop the cleanIdleConn goroutine.
|
||||
func (c *connPool) Close() {
|
||||
if c.idleConnTimeout > 0 {
|
||||
close(c.doneCh)
|
||||
c.ticker.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// AcquireConn returns an idle net.Conn from the pool.
|
||||
func (c *connPool) AcquireConn() (*conn, error) {
|
||||
for {
|
||||
co, err := c.acquireConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !co.isExpired() {
|
||||
return co, nil
|
||||
}
|
||||
|
||||
// As the acquired conn is expired we can close it
|
||||
// without putting it again into the pool.
|
||||
if err := co.Close(); err != nil {
|
||||
log.Debug().
|
||||
Err(err).
|
||||
Msg("Unexpected error while releasing the connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReleaseConn releases the given net.Conn to the pool.
|
||||
func (c *connPool) ReleaseConn(co *conn) {
|
||||
co.idleAt = time.Now()
|
||||
c.releaseConn(co)
|
||||
}
|
||||
|
||||
// cleanIdleConns is a routine cleaning the expired connections at a regular basis.
|
||||
func (c *connPool) cleanIdleConns() {
|
||||
for {
|
||||
select {
|
||||
case co := <-c.idleConns:
|
||||
if !co.isExpired() {
|
||||
c.releaseConn(co)
|
||||
return
|
||||
}
|
||||
|
||||
if err := co.Close(); err != nil {
|
||||
log.Debug().
|
||||
Err(err).
|
||||
Msg("Unexpected error while releasing the connection")
|
||||
}
|
||||
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPool) acquireConn() (*conn, error) {
|
||||
select {
|
||||
case co := <-c.idleConns:
|
||||
return co, nil
|
||||
|
||||
default:
|
||||
errCh := make(chan error, 1)
|
||||
go c.askForNewConn(errCh)
|
||||
|
||||
select {
|
||||
case co := <-c.idleConns:
|
||||
return co, nil
|
||||
|
||||
case err := <-errCh:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPool) releaseConn(co *conn) {
|
||||
select {
|
||||
case c.idleConns <- co:
|
||||
|
||||
// Hitting the default case means that we have reached the maximum number of idle
|
||||
// connections, so we can close it.
|
||||
default:
|
||||
if err := co.Close(); err != nil {
|
||||
log.Debug().
|
||||
Err(err).
|
||||
Msg("Unexpected error while releasing the connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *connPool) askForNewConn(errCh chan<- error) {
|
||||
co, err := c.dialer()
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("create conn: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.releaseConn(&conn{
|
||||
Conn: co,
|
||||
idleAt: time.Now(),
|
||||
idleTimeout: c.idleConnTimeout,
|
||||
})
|
||||
}
|
184
pkg/proxy/fast/connpool_test.go
Normal file
184
pkg/proxy/fast/connpool_test.go
Normal file
|
@ -0,0 +1,184 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnPool_ConnReuse(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
poolFn func(pool *connPool)
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "One connection",
|
||||
poolFn: func(pool *connPool) {
|
||||
c1, _ := pool.AcquireConn()
|
||||
pool.ReleaseConn(c1)
|
||||
},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
desc: "Two connections with release",
|
||||
poolFn: func(pool *connPool) {
|
||||
c1, _ := pool.AcquireConn()
|
||||
pool.ReleaseConn(c1)
|
||||
|
||||
c2, _ := pool.AcquireConn()
|
||||
pool.ReleaseConn(c2)
|
||||
},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
desc: "Two concurrent connections",
|
||||
poolFn: func(pool *connPool) {
|
||||
c1, _ := pool.AcquireConn()
|
||||
c2, _ := pool.AcquireConn()
|
||||
|
||||
pool.ReleaseConn(c1)
|
||||
pool.ReleaseConn(c2)
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var connAlloc int
|
||||
dialer := func() (net.Conn, error) {
|
||||
connAlloc++
|
||||
return &net.TCPConn{}, nil
|
||||
}
|
||||
|
||||
pool := newConnPool(2, 0, dialer)
|
||||
test.poolFn(pool)
|
||||
|
||||
assert.Equal(t, test.expected, connAlloc)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnPool_MaxIdleConn(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
poolFn func(pool *connPool)
|
||||
maxIdleConn int
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "One connection",
|
||||
poolFn: func(pool *connPool) {
|
||||
c1, _ := pool.AcquireConn()
|
||||
pool.ReleaseConn(c1)
|
||||
},
|
||||
maxIdleConn: 1,
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
desc: "Multiple connections with defered release",
|
||||
poolFn: func(pool *connPool) {
|
||||
for range 7 {
|
||||
c, _ := pool.AcquireConn()
|
||||
defer pool.ReleaseConn(c)
|
||||
}
|
||||
},
|
||||
maxIdleConn: 5,
|
||||
expected: 5,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var keepOpenedConn int
|
||||
dialer := func() (net.Conn, error) {
|
||||
keepOpenedConn++
|
||||
return &mockConn{closeFn: func() error {
|
||||
keepOpenedConn--
|
||||
return nil
|
||||
}}, nil
|
||||
}
|
||||
|
||||
pool := newConnPool(test.maxIdleConn, 0, dialer)
|
||||
test.poolFn(pool)
|
||||
|
||||
assert.Equal(t, test.expected, keepOpenedConn)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGC(t *testing.T) {
|
||||
var isDestroyed bool
|
||||
pools := map[string]*connPool{}
|
||||
dialer := func() (net.Conn, error) {
|
||||
c := &mockConn{closeFn: func() error {
|
||||
return nil
|
||||
}}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
pools["test"] = newConnPool(10, 1*time.Second, dialer)
|
||||
runtime.SetFinalizer(pools["test"], func(p *connPool) {
|
||||
isDestroyed = true
|
||||
})
|
||||
c, err := pools["test"].AcquireConn()
|
||||
require.NoError(t, err)
|
||||
|
||||
pools["test"].ReleaseConn(c)
|
||||
|
||||
pools["test"].Close()
|
||||
|
||||
delete(pools, "test")
|
||||
|
||||
runtime.GC()
|
||||
|
||||
require.True(t, isDestroyed)
|
||||
}
|
||||
|
||||
type mockConn struct {
|
||||
closeFn func() error
|
||||
}
|
||||
|
||||
func (m *mockConn) Read(_ []byte) (n int, err error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) Write(_ []byte) (n int, err error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) Close() error {
|
||||
if m.closeFn != nil {
|
||||
return m.closeFn()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConn) LocalAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) RemoteAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) SetDeadline(_ time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) SetReadDeadline(_ time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockConn) SetWriteDeadline(_ time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
195
pkg/proxy/fast/dialer.go
Normal file
195
pkg/proxy/fast/dialer.go
Normal file
|
@ -0,0 +1,195 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
const (
|
||||
schemeHTTP = "http"
|
||||
schemeHTTPS = "https"
|
||||
schemeSocks5 = "socks5"
|
||||
)
|
||||
|
||||
type dialer interface {
|
||||
Dial(network, addr string) (c net.Conn, err error)
|
||||
}
|
||||
|
||||
type dialerFunc func(network, addr string) (net.Conn, error)
|
||||
|
||||
func (d dialerFunc) Dial(network, addr string) (net.Conn, error) {
|
||||
return d(network, addr)
|
||||
}
|
||||
|
||||
type dialerConfig struct {
|
||||
DialKeepAlive time.Duration
|
||||
DialTimeout time.Duration
|
||||
ProxyURL *url.URL
|
||||
HTTP bool
|
||||
TLS bool
|
||||
}
|
||||
|
||||
func newDialer(cfg dialerConfig, tlsConfig *tls.Config) dialer {
|
||||
if cfg.ProxyURL == nil {
|
||||
return buildDialer(cfg, tlsConfig, cfg.TLS)
|
||||
}
|
||||
|
||||
proxyDialer := buildDialer(cfg, tlsConfig, cfg.ProxyURL.Scheme == "https")
|
||||
proxyAddr := addrFromURL(cfg.ProxyURL)
|
||||
|
||||
switch {
|
||||
case cfg.ProxyURL.Scheme == schemeSocks5:
|
||||
var auth *proxy.Auth
|
||||
if u := cfg.ProxyURL.User; u != nil {
|
||||
auth = &proxy.Auth{User: u.Username()}
|
||||
auth.Password, _ = u.Password()
|
||||
}
|
||||
|
||||
// SOCKS5 implementation do not return errors.
|
||||
socksDialer, _ := proxy.SOCKS5("tcp", proxyAddr, auth, proxyDialer)
|
||||
return dialerFunc(func(network, targetAddr string) (net.Conn, error) {
|
||||
co, err := socksDialer.Dial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.TLS {
|
||||
c := &tls.Config{}
|
||||
if tlsConfig != nil {
|
||||
c = tlsConfig.Clone()
|
||||
}
|
||||
|
||||
if c.ServerName == "" {
|
||||
host, _, _ := net.SplitHostPort(targetAddr)
|
||||
c.ServerName = host
|
||||
}
|
||||
|
||||
return tls.Client(co, c), nil
|
||||
}
|
||||
|
||||
return co, nil
|
||||
})
|
||||
case cfg.HTTP && !cfg.TLS:
|
||||
// Nothing to do the Proxy-Authorization header will be added by the ReverseProxy.
|
||||
|
||||
default:
|
||||
hdr := make(http.Header)
|
||||
if u := cfg.ProxyURL.User; u != nil {
|
||||
username := u.Username()
|
||||
password, _ := u.Password()
|
||||
auth := username + ":" + password
|
||||
hdr.Set("Proxy-Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(auth)))
|
||||
}
|
||||
|
||||
return dialerFunc(func(network, targetAddr string) (net.Conn, error) {
|
||||
conn, err := proxyDialer.Dial("tcp", proxyAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
connectReq := &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Opaque: targetAddr},
|
||||
Host: targetAddr,
|
||||
Header: hdr,
|
||||
}
|
||||
|
||||
connectCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
|
||||
var resp *http.Response
|
||||
|
||||
// Write the CONNECT request & read the response.
|
||||
go func() {
|
||||
defer close(didReadResponse)
|
||||
err = connectReq.Write(conn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// Okay to use and discard buffered reader here, because
|
||||
// TLS server will not speak until spoken to.
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err = http.ReadResponse(br, connectReq)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-connectCtx.Done():
|
||||
conn.Close()
|
||||
<-didReadResponse
|
||||
return nil, connectCtx.Err()
|
||||
case <-didReadResponse:
|
||||
// resp or err now set
|
||||
}
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
_, statusText, ok := strings.Cut(resp.Status, " ")
|
||||
conn.Close()
|
||||
if !ok {
|
||||
return nil, errors.New("unknown status code")
|
||||
}
|
||||
|
||||
return nil, errors.New(statusText)
|
||||
}
|
||||
|
||||
c := &tls.Config{}
|
||||
if tlsConfig != nil {
|
||||
c = tlsConfig.Clone()
|
||||
}
|
||||
if c.ServerName == "" {
|
||||
host, _, _ := net.SplitHostPort(targetAddr)
|
||||
c.ServerName = host
|
||||
}
|
||||
|
||||
return tls.Client(conn, c), nil
|
||||
})
|
||||
}
|
||||
return dialerFunc(func(network, addr string) (net.Conn, error) {
|
||||
return proxyDialer.Dial("tcp", proxyAddr)
|
||||
})
|
||||
}
|
||||
|
||||
func buildDialer(cfg dialerConfig, tlsConfig *tls.Config, isTLS bool) dialer {
|
||||
dialer := &net.Dialer{
|
||||
Timeout: cfg.DialTimeout,
|
||||
KeepAlive: cfg.DialKeepAlive,
|
||||
}
|
||||
|
||||
if !isTLS {
|
||||
return dialer
|
||||
}
|
||||
|
||||
return &tls.Dialer{
|
||||
NetDialer: dialer,
|
||||
Config: tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func addrFromURL(u *url.URL) string {
|
||||
addr := u.Host
|
||||
|
||||
if u.Port() == "" {
|
||||
switch u.Scheme {
|
||||
case schemeHTTP:
|
||||
return addr + ":80"
|
||||
case schemeHTTPS:
|
||||
return addr + ":443"
|
||||
}
|
||||
}
|
||||
|
||||
return addr
|
||||
}
|
553
pkg/proxy/fast/proxy.go
Normal file
553
pkg/proxy/fast/proxy.go
Normal file
|
@ -0,0 +1,553 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 32 * 1024
|
||||
bufioSize = 64 * 1024
|
||||
)
|
||||
|
||||
var hopHeaders = []string{
|
||||
"Connection",
|
||||
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
|
||||
"Keep-Alive",
|
||||
"Proxy-Authenticate",
|
||||
"Proxy-Authorization",
|
||||
"Te", // canonicalized version of "TE"
|
||||
"Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
}
|
||||
|
||||
type pool[T any] struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func (p *pool[T]) Get() T {
|
||||
if tmp := p.pool.Get(); tmp != nil {
|
||||
return tmp.(T)
|
||||
}
|
||||
|
||||
var res T
|
||||
return res
|
||||
}
|
||||
|
||||
func (p *pool[T]) Put(x T) {
|
||||
p.pool.Put(x)
|
||||
}
|
||||
|
||||
type buffConn struct {
|
||||
*bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (b buffConn) Read(p []byte) (int, error) {
|
||||
return b.Reader.Read(p)
|
||||
}
|
||||
|
||||
type writeDetector struct {
|
||||
net.Conn
|
||||
|
||||
written bool
|
||||
}
|
||||
|
||||
func (w *writeDetector) Write(p []byte) (int, error) {
|
||||
n, err := w.Conn.Write(p)
|
||||
if n > 0 {
|
||||
w.written = true
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
type writeFlusher struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (w *writeFlusher) Write(b []byte) (int, error) {
|
||||
n, err := w.Writer.Write(b)
|
||||
if f, ok := w.Writer.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
type timeoutError struct {
|
||||
error
|
||||
}
|
||||
|
||||
func (t timeoutError) Timeout() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (t timeoutError) Temporary() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ReverseProxy is the FastProxy reverse proxy implementation.
|
||||
type ReverseProxy struct {
|
||||
debug bool
|
||||
|
||||
connPool *connPool
|
||||
|
||||
bufferPool pool[[]byte]
|
||||
readerPool pool[*bufio.Reader]
|
||||
writerPool pool[*bufio.Writer]
|
||||
limitReaderPool pool[*io.LimitedReader]
|
||||
|
||||
proxyAuth string
|
||||
|
||||
targetURL *url.URL
|
||||
passHostHeader bool
|
||||
responseHeaderTimeout time.Duration
|
||||
}
|
||||
|
||||
// NewReverseProxy creates a new ReverseProxy.
|
||||
func NewReverseProxy(targetURL *url.URL, proxyURL *url.URL, debug, passHostHeader bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) {
|
||||
var proxyAuth string
|
||||
if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth = "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
}
|
||||
|
||||
return &ReverseProxy{
|
||||
debug: debug,
|
||||
passHostHeader: passHostHeader,
|
||||
targetURL: targetURL,
|
||||
proxyAuth: proxyAuth,
|
||||
connPool: connPool,
|
||||
responseHeaderTimeout: responseHeaderTimeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Body != nil {
|
||||
defer req.Body.Close()
|
||||
}
|
||||
|
||||
outReq := fasthttp.AcquireRequest()
|
||||
defer fasthttp.ReleaseRequest(outReq)
|
||||
|
||||
// This is not required as the headers are already normalized by net/http.
|
||||
outReq.Header.DisableNormalizing()
|
||||
|
||||
for k, v := range req.Header {
|
||||
for _, s := range v {
|
||||
outReq.Header.Add(k, s)
|
||||
}
|
||||
}
|
||||
|
||||
removeConnectionHeaders(&outReq.Header)
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
outReq.Header.Del(header)
|
||||
}
|
||||
|
||||
if p.proxyAuth != "" {
|
||||
outReq.Header.Set("Proxy-Authorization", p.proxyAuth)
|
||||
}
|
||||
|
||||
if httpguts.HeaderValuesContainsToken(req.Header["Te"], "trailers") {
|
||||
outReq.Header.Set("Te", "trailers")
|
||||
}
|
||||
|
||||
if p.debug {
|
||||
outReq.Header.Set("X-Traefik-Fast-Proxy", "enabled")
|
||||
}
|
||||
|
||||
reqUpType := upgradeType(req.Header)
|
||||
if !isGraphic(reqUpType) {
|
||||
proxyhttputil.ErrorHandler(rw, req, fmt.Errorf("client tried to switch to invalid protocol %q", reqUpType))
|
||||
return
|
||||
}
|
||||
|
||||
if reqUpType != "" {
|
||||
outReq.Header.Set("Connection", "Upgrade")
|
||||
outReq.Header.Set("Upgrade", reqUpType)
|
||||
if reqUpType == "websocket" {
|
||||
cleanWebSocketHeaders(&outReq.Header)
|
||||
}
|
||||
}
|
||||
|
||||
u2 := new(url.URL)
|
||||
*u2 = *req.URL
|
||||
u2.Scheme = p.targetURL.Scheme
|
||||
u2.Host = p.targetURL.Host
|
||||
|
||||
u := req.URL
|
||||
if req.RequestURI != "" {
|
||||
parsedURL, err := url.ParseRequestURI(req.RequestURI)
|
||||
if err == nil {
|
||||
u = parsedURL
|
||||
}
|
||||
}
|
||||
|
||||
u2.Path = u.Path
|
||||
u2.RawPath = u.RawPath
|
||||
u2.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
|
||||
|
||||
outReq.SetHost(u2.Host)
|
||||
outReq.Header.SetHost(u2.Host)
|
||||
|
||||
if p.passHostHeader {
|
||||
outReq.Header.SetHost(req.Host)
|
||||
}
|
||||
|
||||
outReq.SetRequestURI(u2.RequestURI())
|
||||
|
||||
outReq.SetBodyStream(req.Body, int(req.ContentLength))
|
||||
|
||||
outReq.Header.SetMethod(req.Method)
|
||||
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
// If we aren't the first proxy retain prior
|
||||
// X-Forwarded-For information as a comma+space
|
||||
// separated list and fold multiple headers into one.
|
||||
prior, ok := req.Header["X-Forwarded-For"]
|
||||
if len(prior) > 0 {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
|
||||
omit := ok && prior == nil // Go Issue 38079: nil now means don't populate the header
|
||||
if !omit {
|
||||
outReq.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.roundTrip(rw, req, outReq, reqUpType); err != nil {
|
||||
proxyhttputil.ErrorHandler(rw, req, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Note that unlike the net/http RoundTrip:
|
||||
// - we are not supporting "100 Continue" response to forward them as-is to the client.
|
||||
// - we are not asking for compressed response automatically. That is because this will add an extra cost when the
|
||||
// client is asking for an uncompressed response, as we will have to un-compress it, and nowadays most clients are
|
||||
// already asking for compressed response (allowing "passthrough" compression).
|
||||
func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outReq *fasthttp.Request, reqUpType string) error {
|
||||
ctx := req.Context()
|
||||
trace := httptrace.ContextClientTrace(ctx)
|
||||
|
||||
var co *conn
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
|
||||
default:
|
||||
}
|
||||
|
||||
var err error
|
||||
co, err = p.connPool.AcquireConn()
|
||||
if err != nil {
|
||||
return fmt.Errorf("acquire connection: %w", err)
|
||||
}
|
||||
|
||||
wd := &writeDetector{Conn: co}
|
||||
|
||||
err = p.writeRequest(wd, outReq)
|
||||
if wd.written && trace != nil && trace.WroteRequest != nil {
|
||||
// WroteRequest hook is used by the tracing middleware to detect if the request has been written.
|
||||
trace.WroteRequest(httptrace.WroteRequestInfo{})
|
||||
}
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
log.Ctx(ctx).Debug().Err(err).Msg("Error while writing request")
|
||||
|
||||
co.Close()
|
||||
|
||||
if wd.written && !isReplayable(req) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
br := p.readerPool.Get()
|
||||
if br == nil {
|
||||
br = bufio.NewReaderSize(co, bufioSize)
|
||||
}
|
||||
defer p.readerPool.Put(br)
|
||||
|
||||
br.Reset(co)
|
||||
|
||||
res := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseResponse(res)
|
||||
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
|
||||
for {
|
||||
var timer *time.Timer
|
||||
errTimeout := atomic.Pointer[timeoutError]{}
|
||||
if p.responseHeaderTimeout > 0 {
|
||||
timer = time.AfterFunc(p.responseHeaderTimeout, func() {
|
||||
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
|
||||
co.Close()
|
||||
})
|
||||
}
|
||||
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
if err := res.Header.Read(br); err != nil {
|
||||
if p.responseHeaderTimeout > 0 {
|
||||
if errT := errTimeout.Load(); errT != nil {
|
||||
return errT
|
||||
}
|
||||
}
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
|
||||
fixPragmaCacheControl(&res.Header)
|
||||
|
||||
resCode := res.StatusCode()
|
||||
is1xx := 100 <= resCode && resCode <= 199
|
||||
// treat 101 as a terminal status, see issue 26161
|
||||
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
|
||||
if is1xxNonTerminal {
|
||||
removeConnectionHeaders(&res.Header)
|
||||
h := rw.Header()
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
res.Header.Del(header)
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
rw.Header().Add(string(key), string(value))
|
||||
})
|
||||
|
||||
rw.WriteHeader(res.StatusCode())
|
||||
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
||||
for k := range h {
|
||||
delete(h, k)
|
||||
}
|
||||
|
||||
res.Reset()
|
||||
res.Header.Reset()
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
announcedTrailers := res.Header.Peek("Trailer")
|
||||
|
||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode() == http.StatusSwitchingProtocols {
|
||||
// As the connection has been hijacked, it cannot be added back to the pool.
|
||||
handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br})
|
||||
return nil
|
||||
}
|
||||
|
||||
removeConnectionHeaders(&res.Header)
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
res.Header.Del(header)
|
||||
}
|
||||
|
||||
if len(announcedTrailers) > 0 {
|
||||
res.Header.Add("Trailer", string(announcedTrailers))
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
rw.Header().Add(string(key), string(value))
|
||||
})
|
||||
|
||||
rw.WriteHeader(res.StatusCode())
|
||||
|
||||
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
|
||||
if res.Header.ContentLength() == -1 {
|
||||
cbr := httputil.NewChunkedReader(br)
|
||||
|
||||
b := p.bufferPool.Get()
|
||||
if b == nil {
|
||||
b = make([]byte, bufferSize)
|
||||
}
|
||||
defer p.bufferPool.Put(b)
|
||||
|
||||
if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil {
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
res.Header.Reset()
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
if err := res.Header.ReadTrailer(br); err != nil {
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Header.Len() > 0 {
|
||||
var announcedTrailersKey []string
|
||||
if len(announcedTrailers) > 0 {
|
||||
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
for _, s := range announcedTrailersKey {
|
||||
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
|
||||
rw.Header().Add(string(key), string(value))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
rw.Header().Add(http.TrailerPrefix+string(key), string(value))
|
||||
})
|
||||
}
|
||||
|
||||
p.connPool.ReleaseConn(co)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
brl := p.limitReaderPool.Get()
|
||||
if brl == nil {
|
||||
brl = &io.LimitedReader{}
|
||||
}
|
||||
defer p.limitReaderPool.Put(brl)
|
||||
|
||||
brl.R = br
|
||||
brl.N = int64(res.Header.ContentLength())
|
||||
|
||||
b := p.bufferPool.Get()
|
||||
if b == nil {
|
||||
b = make([]byte, bufferSize)
|
||||
}
|
||||
defer p.bufferPool.Put(b)
|
||||
|
||||
if _, err := io.CopyBuffer(rw, brl, b); err != nil {
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
p.connPool.ReleaseConn(co)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) writeRequest(co net.Conn, outReq *fasthttp.Request) error {
|
||||
bw := p.writerPool.Get()
|
||||
if bw == nil {
|
||||
bw = bufio.NewWriterSize(co, bufioSize)
|
||||
}
|
||||
defer p.writerPool.Put(bw)
|
||||
|
||||
bw.Reset(co)
|
||||
|
||||
if err := outReq.Write(bw); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return bw.Flush()
|
||||
}
|
||||
|
||||
// isReplayable returns whether the request is replayable.
|
||||
func isReplayable(req *http.Request) bool {
|
||||
if req.Body == nil || req.Body == http.NoBody {
|
||||
switch req.Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace:
|
||||
return true
|
||||
}
|
||||
|
||||
// The Idempotency-Key, while non-standard, is widely used to
|
||||
// mean a POST or other request is idempotent. See
|
||||
// https://golang.org/issue/19943#issuecomment-421092421
|
||||
if _, ok := req.Header["Idempotency-Key"]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
if _, ok := req.Header["X-Idempotency-Key"]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isGraphic returns whether s is ASCII and printable according to
|
||||
// https://tools.ietf.org/html/rfc20#section-4.2.
|
||||
func isGraphic(s string) bool {
|
||||
for i := range len(s) {
|
||||
if s[i] < ' ' || s[i] > '~' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
type fasthttpHeader interface {
|
||||
Peek(key string) []byte
|
||||
Set(key string, value string)
|
||||
SetBytesV(key string, value []byte)
|
||||
DelBytes(key []byte)
|
||||
Del(key string)
|
||||
}
|
||||
|
||||
// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
|
||||
// See RFC 7230, section 6.1.
|
||||
func removeConnectionHeaders(h fasthttpHeader) {
|
||||
f := h.Peek(fasthttp.HeaderConnection)
|
||||
for _, sf := range bytes.Split(f, []byte{','}) {
|
||||
if sf = bytes.TrimSpace(sf); len(sf) > 0 {
|
||||
h.DelBytes(sf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RFC 7234, section 5.4: Should treat Pragma: no-cache like Cache-Control: no-cache.
|
||||
func fixPragmaCacheControl(header fasthttpHeader) {
|
||||
if pragma := header.Peek("Pragma"); bytes.Equal(pragma, []byte("no-cache")) {
|
||||
if len(header.Peek("Cache-Control")) == 0 {
|
||||
header.Set("Cache-Control", "no-cache")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive,
|
||||
// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
|
||||
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
|
||||
// https://tools.ietf.org/html/rfc6455#page-20
|
||||
func cleanWebSocketHeaders(headers fasthttpHeader) {
|
||||
headers.SetBytesV("Sec-WebSocket-Key", headers.Peek("Sec-Websocket-Key"))
|
||||
headers.Del("Sec-Websocket-Key")
|
||||
|
||||
headers.SetBytesV("Sec-WebSocket-Extensions", headers.Peek("Sec-Websocket-Extensions"))
|
||||
headers.Del("Sec-Websocket-Extensions")
|
||||
|
||||
headers.SetBytesV("Sec-WebSocket-Accept", headers.Peek("Sec-Websocket-Accept"))
|
||||
headers.Del("Sec-Websocket-Accept")
|
||||
|
||||
headers.SetBytesV("Sec-WebSocket-Protocol", headers.Peek("Sec-Websocket-Protocol"))
|
||||
headers.Del("Sec-Websocket-Protocol")
|
||||
|
||||
headers.SetBytesV("Sec-WebSocket-Version", headers.Peek("Sec-Websocket-Version"))
|
||||
headers.DelBytes([]byte("Sec-Websocket-Version"))
|
||||
}
|
311
pkg/proxy/fast/proxy_test.go
Normal file
311
pkg/proxy/fast/proxy_test.go
Normal file
|
@ -0,0 +1,311 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/armon/go-socks5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/config/static"
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
"github.com/traefik/traefik/v3/pkg/tls/generate"
|
||||
)
|
||||
|
||||
const (
|
||||
proxyHTTP = "http"
|
||||
proxyHTTPS = "https"
|
||||
proxySocks5 = "socks"
|
||||
)
|
||||
|
||||
type authCreds struct {
|
||||
user string
|
||||
password string
|
||||
}
|
||||
|
||||
func TestProxyFromEnvironment(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
proxyType string
|
||||
tls bool
|
||||
auth *authCreds
|
||||
}{
|
||||
{
|
||||
desc: "Proxy HTTP with HTTP Backend",
|
||||
proxyType: proxyHTTP,
|
||||
},
|
||||
{
|
||||
desc: "Proxy HTTP with HTTP backend and proxy auth",
|
||||
proxyType: proxyHTTP,
|
||||
tls: false,
|
||||
auth: &authCreds{
|
||||
user: "user",
|
||||
password: "password",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Proxy HTTP with HTTPS backend",
|
||||
proxyType: proxyHTTP,
|
||||
tls: true,
|
||||
},
|
||||
{
|
||||
desc: "Proxy HTTP with HTTPS backend and proxy auth",
|
||||
proxyType: proxyHTTP,
|
||||
tls: true,
|
||||
auth: &authCreds{
|
||||
user: "user",
|
||||
password: "password",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Proxy HTTPS with HTTP backend",
|
||||
proxyType: proxyHTTPS,
|
||||
},
|
||||
{
|
||||
desc: "Proxy HTTPS with HTTP backend and proxy auth",
|
||||
proxyType: proxyHTTPS,
|
||||
tls: false,
|
||||
auth: &authCreds{
|
||||
user: "user",
|
||||
password: "password",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Proxy HTTPS with HTTPS backend",
|
||||
proxyType: proxyHTTPS,
|
||||
tls: true,
|
||||
},
|
||||
{
|
||||
desc: "Proxy HTTPS with HTTPS backend and proxy auth",
|
||||
proxyType: proxyHTTPS,
|
||||
tls: true,
|
||||
auth: &authCreds{
|
||||
user: "user",
|
||||
password: "password",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Proxy Socks5 with HTTP backend",
|
||||
proxyType: proxySocks5,
|
||||
},
|
||||
{
|
||||
desc: "Proxy Socks5 with HTTP backend and proxy auth",
|
||||
proxyType: proxySocks5,
|
||||
auth: &authCreds{
|
||||
user: "user",
|
||||
password: "password",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Proxy Socks5 with HTTPS backend",
|
||||
proxyType: proxySocks5,
|
||||
tls: true,
|
||||
},
|
||||
{
|
||||
desc: "Proxy Socks5 with HTTPS backend and proxy auth",
|
||||
proxyType: proxySocks5,
|
||||
tls: true,
|
||||
auth: &authCreds{
|
||||
user: "user",
|
||||
password: "password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
backendURL, backendCert := newBackendServer(t, test.tls, http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte("backend"))
|
||||
}))
|
||||
|
||||
var proxyCalled bool
|
||||
proxyHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
proxyCalled = true
|
||||
|
||||
if test.auth != nil {
|
||||
proxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(test.auth.user+":"+test.auth.password))
|
||||
require.Equal(t, proxyAuth, req.Header.Get("Proxy-Authorization"))
|
||||
}
|
||||
|
||||
if req.Method != http.MethodConnect {
|
||||
proxy := httputil.NewSingleHostReverseProxy(testhelpers.MustParseURL("http://" + req.Host))
|
||||
proxy.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// CONNECT method
|
||||
conn, err := net.Dial("tcp", req.Host)
|
||||
require.NoError(t, err)
|
||||
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
require.True(t, ok)
|
||||
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
connHj, _, err := hj.Hijack()
|
||||
require.NoError(t, err)
|
||||
|
||||
go func() { _, _ = io.Copy(connHj, conn) }()
|
||||
_, _ = io.Copy(conn, connHj)
|
||||
})
|
||||
|
||||
var proxyURL string
|
||||
var proxyCert *x509.Certificate
|
||||
|
||||
switch test.proxyType {
|
||||
case proxySocks5:
|
||||
ln, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyURL = fmt.Sprintf("socks5://%s", ln.Addr())
|
||||
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyCalled = true
|
||||
|
||||
conf := &socks5.Config{}
|
||||
if test.auth != nil {
|
||||
conf.Credentials = socks5.StaticCredentials{test.auth.user: test.auth.password}
|
||||
}
|
||||
|
||||
server, err := socks5.New(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
// We are not checking the error, because ServeConn is blocked until the client or the backend
|
||||
// connection is closed which, in some cases, raises a connection reset by peer error.
|
||||
_ = server.ServeConn(conn)
|
||||
|
||||
err = ln.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
case proxyHTTP:
|
||||
proxyServer := httptest.NewServer(proxyHandler)
|
||||
t.Cleanup(proxyServer.Close)
|
||||
|
||||
proxyURL = proxyServer.URL
|
||||
|
||||
case proxyHTTPS:
|
||||
proxyServer := httptest.NewServer(proxyHandler)
|
||||
t.Cleanup(proxyServer.Close)
|
||||
|
||||
proxyURL = proxyServer.URL
|
||||
proxyCert = proxyServer.Certificate()
|
||||
}
|
||||
|
||||
certPool := x509.NewCertPool()
|
||||
if proxyCert != nil {
|
||||
certPool.AddCert(proxyCert)
|
||||
}
|
||||
if backendCert != nil {
|
||||
cert, err := x509.ParseCertificate(backendCert.Certificate[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
certPool.AddCert(cert)
|
||||
}
|
||||
|
||||
builder := NewProxyBuilder(&transportManagerMock{tlsConfig: &tls.Config{RootCAs: certPool}}, static.FastProxyConfig{})
|
||||
builder.proxy = func(req *http.Request) (*url.URL, error) {
|
||||
u, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if test.auth != nil {
|
||||
u.User = url.UserPassword(test.auth.user, test.auth.password)
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
||||
|
||||
reverseProxy, err := builder.Build("foo", testhelpers.MustParseURL(backendURL), false)
|
||||
require.NoError(t, err)
|
||||
|
||||
reverseProxyServer := httptest.NewServer(reverseProxy)
|
||||
t.Cleanup(reverseProxyServer.Close)
|
||||
|
||||
client := http.Client{Timeout: 5 * time.Second}
|
||||
|
||||
resp, err := client.Get(reverseProxyServer.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "backend", string(body))
|
||||
assert.True(t, proxyCalled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newCertificate(t *testing.T, domain string) *tls.Certificate {
|
||||
t.Helper()
|
||||
|
||||
certPEM, keyPEM, err := generate.KeyPair(domain, time.Time{})
|
||||
require.NoError(t, err)
|
||||
|
||||
certificate, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||
require.NoError(t, err)
|
||||
|
||||
return &certificate
|
||||
}
|
||||
|
||||
func newBackendServer(t *testing.T, isTLS bool, handler http.Handler) (string, *tls.Certificate) {
|
||||
t.Helper()
|
||||
|
||||
var ln net.Listener
|
||||
var err error
|
||||
var cert *tls.Certificate
|
||||
|
||||
scheme := "http"
|
||||
domain := "backend.localhost"
|
||||
if isTLS {
|
||||
scheme = "https"
|
||||
|
||||
cert = newCertificate(t, domain)
|
||||
|
||||
ln, err = tls.Listen("tcp", ":0", &tls.Config{Certificates: []tls.Certificate{*cert}})
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
ln, err = net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
srv := &http.Server{Handler: handler}
|
||||
go func() { _ = srv.Serve(ln) }()
|
||||
|
||||
t.Cleanup(func() { _ = srv.Close() })
|
||||
|
||||
_, port, err := net.SplitHostPort(ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
backendURL := fmt.Sprintf("%s://%s:%s", scheme, domain, port)
|
||||
|
||||
return backendURL, cert
|
||||
}
|
||||
|
||||
type transportManagerMock struct {
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
func (r *transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) {
|
||||
return r.tlsConfig, nil
|
||||
}
|
||||
|
||||
func (r *transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) {
|
||||
return &dynamic.ServersTransport{}, nil
|
||||
}
|
693
pkg/proxy/fast/proxy_websocket_test.go
Normal file
693
pkg/proxy/fast/proxy_websocket_test.go
Normal file
|
@ -0,0 +1,693 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gorillawebsocket "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func TestWebSocketTCPClose(t *testing.T) {
|
||||
errChan := make(chan error, 1)
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
for {
|
||||
_, _, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, conn, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
).open()
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
serverErr := <-errChan
|
||||
|
||||
var wsErr *gorillawebsocket.CloseError
|
||||
require.ErrorAs(t, serverErr, &wsErr)
|
||||
assert.Equal(t, 1006, wsErr.Code)
|
||||
}
|
||||
|
||||
func TestWebSocketPingPong(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
CheckOrigin: func(*http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
|
||||
ws, err := upgrader.Upgrade(writer, request, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ws.SetPingHandler(func(appData string) error {
|
||||
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
|
||||
require.NoError(t, err)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, _ = ws.ReadMessage()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
defer conn.Close()
|
||||
|
||||
goodErr := fmt.Errorf("signal: %s", "Good data")
|
||||
badErr := fmt.Errorf("signal: %s", "Bad data")
|
||||
conn.SetPongHandler(func(data string) error {
|
||||
if data == "PingPong" {
|
||||
return goodErr
|
||||
}
|
||||
|
||||
return badErr
|
||||
})
|
||||
|
||||
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = conn.ReadMessage()
|
||||
|
||||
if !errors.Is(err, goodErr) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketEcho(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
msg := make([]byte, 4)
|
||||
|
||||
n, err := conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Write(msg[:n])
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, msg, err := conn.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "OK", string(msg))
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWebSocketPassHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
passHost bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "PassHost false",
|
||||
passHost: false,
|
||||
},
|
||||
{
|
||||
desc: "PassHost true",
|
||||
passHost: true,
|
||||
expected: "example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
req := conn.Request()
|
||||
|
||||
if test.passHost {
|
||||
require.Equal(t, test.expected, req.Host)
|
||||
} else {
|
||||
require.NotEqual(t, test.expected, req.Host)
|
||||
}
|
||||
|
||||
msg := make([]byte, 4)
|
||||
|
||||
n, err := conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Write(msg[:n])
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
headers.Add("Host", "example.com")
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, msg, err := conn.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "OK", string(msg))
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
}}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithOrigin(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithQueryParams(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
assert.Equal(t, "test", r.URL.Query().Get("query"))
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws?query=test"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
_ = conn.Close()
|
||||
}))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u := parseURI(t, srv.URL)
|
||||
|
||||
f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
|
||||
return net.Dial("tcp", u.Host)
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/%3A%2F%2F"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketUpgradeFailed(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
u := parseURI(t, srv.URL)
|
||||
f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
|
||||
return net.Dial("tcp", u.Host)
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
|
||||
if path != "/ws" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
// Set new backend URL
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
req.URL.Path = path
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("upgrade", "websocket")
|
||||
req.Header.Add("Connection", "upgrade")
|
||||
|
||||
err = req.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First request works with 400
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
_, err := conn.Write([]byte("ok"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
srv := createTLSWebsocketServer()
|
||||
defer srv.Close()
|
||||
|
||||
proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, createConnectionPool(srv.URL, nil))
|
||||
defer proxyWithoutTLSConfig.Close()
|
||||
|
||||
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
|
||||
|
||||
_, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
pool := createConnectionPool(srv.URL, &tls.Config{InsecureSkipVerify: true})
|
||||
|
||||
proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, pool)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
const dialTimeout = time.Second
|
||||
|
||||
type websocketRequestOpt func(w *websocketRequest)
|
||||
|
||||
func withServer(server string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.ServerAddr = server
|
||||
}
|
||||
}
|
||||
|
||||
func withPath(path string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Path = path
|
||||
}
|
||||
}
|
||||
|
||||
func withData(data string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Data = data
|
||||
}
|
||||
}
|
||||
|
||||
func withOrigin(origin string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Origin = origin
|
||||
}
|
||||
}
|
||||
|
||||
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
|
||||
wsrequest := &websocketRequest{}
|
||||
for _, opt := range opts {
|
||||
opt(wsrequest)
|
||||
}
|
||||
|
||||
if wsrequest.Origin == "" {
|
||||
wsrequest.Origin = "http://" + wsrequest.ServerAddr
|
||||
}
|
||||
|
||||
if wsrequest.Config == nil {
|
||||
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
|
||||
}
|
||||
|
||||
return wsrequest
|
||||
}
|
||||
|
||||
type websocketRequest struct {
|
||||
ServerAddr string
|
||||
Path string
|
||||
Data string
|
||||
Origin string
|
||||
Config *websocket.Config
|
||||
}
|
||||
|
||||
func (w *websocketRequest) send() (string, error) {
|
||||
conn, _, err := w.open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if _, err := conn.Write([]byte(w.Data)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
msg := make([]byte, 512)
|
||||
|
||||
var n int
|
||||
n, err = conn.Read(msg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
received := string(msg[:n])
|
||||
return received, nil
|
||||
}
|
||||
|
||||
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
|
||||
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
conn, err := websocket.NewClient(w.Config, client)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return conn, client, err
|
||||
}
|
||||
|
||||
func parseURI(t *testing.T, uri string) *url.URL {
|
||||
t.Helper()
|
||||
|
||||
out, err := url.ParseRequestURI(uri)
|
||||
require.NoError(t, err)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func createConnectionPool(target string, tlsConfig *tls.Config) *connPool {
|
||||
u := testhelpers.MustParseURL(target)
|
||||
return newConnPool(200, 0, func() (net.Conn, error) {
|
||||
if tlsConfig != nil {
|
||||
return tls.Dial("tcp", u.Host, tlsConfig)
|
||||
}
|
||||
|
||||
return net.Dial("tcp", u.Host)
|
||||
})
|
||||
}
|
||||
|
||||
func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
u := parseURI(t, uri)
|
||||
proxy, err := NewReverseProxy(u, nil, false, true, 0, pool)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
// Set new backend URL
|
||||
req.URL = u
|
||||
req.URL.Path = path
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
return srv
|
||||
}
|
104
pkg/proxy/fast/upgrade.go
Normal file
104
pkg/proxy/fast/upgrade.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||
"github.com/valyala/fasthttp"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// switchProtocolCopier exists so goroutines proxying data back and
|
||||
// forth have nice names in stacks.
|
||||
type switchProtocolCopier struct {
|
||||
user, backend io.ReadWriter
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||
_, err := io.Copy(c.user, c.backend)
|
||||
errc <- err
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||
_, err := io.Copy(c.backend, c.user)
|
||||
errc <- err
|
||||
}
|
||||
|
||||
func handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, reqUpType string, res *fasthttp.Response, backConn net.Conn) {
|
||||
defer backConn.Close()
|
||||
|
||||
resUpType := upgradeTypeFastHTTP(&res.Header)
|
||||
|
||||
if !strings.EqualFold(reqUpType, resUpType) {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||
return
|
||||
}
|
||||
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||
return
|
||||
}
|
||||
backConnCloseCh := make(chan bool)
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
_ = backConn.Close()
|
||||
}()
|
||||
|
||||
defer close(backConnCloseCh)
|
||||
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", err))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for k, values := range rw.Header() {
|
||||
for _, v := range values {
|
||||
res.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
if err := res.Header.Write(brw.Writer); err != nil {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("response write: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := brw.Flush(); err != nil {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("response flush: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
errc := make(chan error, 1)
|
||||
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||
go spc.copyToBackend(errc)
|
||||
go spc.copyFromBackend(errc)
|
||||
<-errc
|
||||
}
|
||||
|
||||
func upgradeType(h http.Header) string {
|
||||
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
|
||||
return ""
|
||||
}
|
||||
|
||||
return h.Get("Upgrade")
|
||||
}
|
||||
|
||||
func upgradeTypeFastHTTP(h fasthttpHeader) string {
|
||||
if !bytes.Contains(h.Peek("Connection"), []byte("Upgrade")) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return string(h.Peek("Upgrade"))
|
||||
}
|
29
pkg/proxy/httputil/bufferpool.go
Normal file
29
pkg/proxy/httputil/bufferpool.go
Normal file
|
@ -0,0 +1,29 @@
|
|||
package httputil
|
||||
|
||||
import "sync"
|
||||
|
||||
const bufferSize = 32 * 1024
|
||||
|
||||
type bufferPool struct {
|
||||
pool sync.Pool
|
||||
}
|
||||
|
||||
func newBufferPool() *bufferPool {
|
||||
b := &bufferPool{
|
||||
pool: sync.Pool{},
|
||||
}
|
||||
|
||||
b.pool.New = func() interface{} {
|
||||
return make([]byte, bufferSize)
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *bufferPool) Get() []byte {
|
||||
return b.pool.Get().([]byte)
|
||||
}
|
||||
|
||||
func (b *bufferPool) Put(bytes []byte) {
|
||||
b.pool.Put(bytes)
|
||||
}
|
54
pkg/proxy/httputil/builder.go
Normal file
54
pkg/proxy/httputil/builder.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/metrics"
|
||||
)
|
||||
|
||||
// TransportManager manages transport used for backend communications.
|
||||
type TransportManager interface {
|
||||
Get(name string) (*dynamic.ServersTransport, error)
|
||||
GetRoundTripper(name string) (http.RoundTripper, error)
|
||||
GetTLSConfig(name string) (*tls.Config, error)
|
||||
}
|
||||
|
||||
// ProxyBuilder handles the http.RoundTripper for httputil reverse proxies.
|
||||
type ProxyBuilder struct {
|
||||
bufferPool *bufferPool
|
||||
transportManager TransportManager
|
||||
semConvMetricsRegistry *metrics.SemConvMetricsRegistry
|
||||
}
|
||||
|
||||
// NewProxyBuilder creates a new ProxyBuilder.
|
||||
func NewProxyBuilder(transportManager TransportManager, semConvMetricsRegistry *metrics.SemConvMetricsRegistry) *ProxyBuilder {
|
||||
return &ProxyBuilder{
|
||||
bufferPool: newBufferPool(),
|
||||
transportManager: transportManager,
|
||||
semConvMetricsRegistry: semConvMetricsRegistry,
|
||||
}
|
||||
}
|
||||
|
||||
// Update does nothing.
|
||||
func (r *ProxyBuilder) Update(_ map[string]*dynamic.ServersTransport) {}
|
||||
|
||||
// Build builds a new httputil.ReverseProxy with the given configuration.
|
||||
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) {
|
||||
roundTripper, err := r.transportManager.GetRoundTripper(cfgName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting RoundTripper: %w", err)
|
||||
}
|
||||
|
||||
if shouldObserve {
|
||||
// Wrapping the roundTripper with the Tracing roundTripper,
|
||||
// to handle the reverseProxy client span creation.
|
||||
roundTripper = newObservabilityRoundTripper(r.semConvMetricsRegistry, roundTripper)
|
||||
}
|
||||
|
||||
return buildSingleHostProxy(targetURL, passHostHeader, flushInterval, roundTripper, r.bufferPool), nil
|
||||
}
|
56
pkg/proxy/httputil/builder_test.go
Normal file
56
pkg/proxy/httputil/builder_test.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
)
|
||||
|
||||
func TestEscapedPath(t *testing.T) {
|
||||
var gotEscapedPath string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
gotEscapedPath = req.URL.EscapedPath()
|
||||
}))
|
||||
|
||||
transportManager := &transportManagerMock{
|
||||
roundTrippers: map[string]http.RoundTripper{"default": &http.Transport{}},
|
||||
}
|
||||
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("default", testhelpers.MustParseURL(srv.URL), false, true, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(p.ServeHTTP))
|
||||
|
||||
_, err = http.Get(proxy.URL + "/%3A%2F%2F")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "/%3A%2F%2F", gotEscapedPath)
|
||||
}
|
||||
|
||||
type transportManagerMock struct {
|
||||
roundTrippers map[string]http.RoundTripper
|
||||
}
|
||||
|
||||
func (t *transportManagerMock) GetRoundTripper(name string) (http.RoundTripper, error) {
|
||||
roundTripper, ok := t.roundTrippers[name]
|
||||
if !ok {
|
||||
return nil, errors.New("no transport for " + name)
|
||||
}
|
||||
|
||||
return roundTripper, nil
|
||||
}
|
||||
|
||||
func (t *transportManagerMock) GetTLSConfig(_ string) (*tls.Config, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (t *transportManagerMock) Get(_ string) (*dynamic.ServersTransport, error) {
|
||||
panic("implement me")
|
||||
}
|
105
pkg/proxy/httputil/observability.go
Normal file
105
pkg/proxy/httputil/observability.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/metrics"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||
"github.com/traefik/traefik/v3/pkg/tracing"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
type wrapper struct {
|
||||
semConvMetricRegistry *metrics.SemConvMetricsRegistry
|
||||
rt http.RoundTripper
|
||||
}
|
||||
|
||||
func newObservabilityRoundTripper(semConvMetricRegistry *metrics.SemConvMetricsRegistry, rt http.RoundTripper) http.RoundTripper {
|
||||
return &wrapper{
|
||||
semConvMetricRegistry: semConvMetricRegistry,
|
||||
rt: rt,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *wrapper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
start := time.Now()
|
||||
var span trace.Span
|
||||
var tracingCtx context.Context
|
||||
var tracer *tracing.Tracer
|
||||
if tracer = tracing.TracerFromContext(req.Context()); tracer != nil {
|
||||
tracingCtx, span = tracer.Start(req.Context(), "ReverseProxy", trace.WithSpanKind(trace.SpanKindClient))
|
||||
defer span.End()
|
||||
|
||||
req = req.WithContext(tracingCtx)
|
||||
|
||||
tracer.CaptureClientRequest(span, req)
|
||||
tracing.InjectContextIntoCarrier(req)
|
||||
}
|
||||
|
||||
var statusCode int
|
||||
var headers http.Header
|
||||
response, err := t.rt.RoundTrip(req)
|
||||
if err != nil {
|
||||
statusCode = ComputeStatusCode(err)
|
||||
}
|
||||
if response != nil {
|
||||
statusCode = response.StatusCode
|
||||
headers = response.Header
|
||||
}
|
||||
|
||||
if tracer != nil {
|
||||
tracer.CaptureResponse(span, headers, statusCode, trace.SpanKindClient)
|
||||
}
|
||||
|
||||
end := time.Now()
|
||||
|
||||
// Ending the span as soon as the response is handled because we want to use the same end time for the trace and the metric.
|
||||
// If any errors happen earlier, this span will be close by the defer instruction.
|
||||
if span != nil {
|
||||
span.End(trace.WithTimestamp(end))
|
||||
}
|
||||
|
||||
if t.semConvMetricRegistry != nil && t.semConvMetricRegistry.HTTPClientRequestDuration() != nil {
|
||||
var attrs []attribute.KeyValue
|
||||
|
||||
if statusCode < 100 || statusCode >= 600 {
|
||||
attrs = append(attrs, attribute.Key("error.type").String(fmt.Sprintf("Invalid HTTP status code %d", statusCode)))
|
||||
} else if statusCode >= 400 {
|
||||
attrs = append(attrs, attribute.Key("error.type").String(strconv.Itoa(statusCode)))
|
||||
}
|
||||
|
||||
attrs = append(attrs, semconv.HTTPRequestMethodKey.String(req.Method))
|
||||
attrs = append(attrs, semconv.HTTPResponseStatusCode(statusCode))
|
||||
attrs = append(attrs, semconv.NetworkProtocolName(strings.ToLower(req.Proto)))
|
||||
attrs = append(attrs, semconv.NetworkProtocolVersion(observability.Proto(req.Proto)))
|
||||
attrs = append(attrs, semconv.ServerAddress(req.URL.Host))
|
||||
|
||||
_, port, err := net.SplitHostPort(req.URL.Host)
|
||||
if err != nil {
|
||||
switch req.URL.Scheme {
|
||||
case "http":
|
||||
attrs = append(attrs, semconv.ServerPort(80))
|
||||
case "https":
|
||||
attrs = append(attrs, semconv.ServerPort(443))
|
||||
}
|
||||
} else {
|
||||
intPort, _ := strconv.Atoi(port)
|
||||
attrs = append(attrs, semconv.ServerPort(intPort))
|
||||
}
|
||||
|
||||
attrs = append(attrs, semconv.URLScheme(req.Header.Get("X-Forwarded-Proto")))
|
||||
|
||||
t.semConvMetricRegistry.HTTPClientRequestDuration().Record(req.Context(), end.Sub(start).Seconds(), metric.WithAttributes(attrs...))
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
122
pkg/proxy/httputil/observability_test.go
Normal file
122
pkg/proxy/httputil/observability_test.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
ptypes "github.com/traefik/paerser/types"
|
||||
"github.com/traefik/traefik/v3/pkg/metrics"
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
|
||||
"go.opentelemetry.io/otel/sdk/metric/metricdata"
|
||||
"go.opentelemetry.io/otel/sdk/metric/metricdata/metricdatatest"
|
||||
)
|
||||
|
||||
func TestObservabilityRoundTripper_metrics(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
serverURL string
|
||||
statusCode int
|
||||
wantAttributes attribute.Set
|
||||
}{
|
||||
{
|
||||
desc: "not found status",
|
||||
serverURL: "http://www.test.com",
|
||||
statusCode: http.StatusNotFound,
|
||||
wantAttributes: attribute.NewSet(
|
||||
attribute.Key("error.type").String("404"),
|
||||
attribute.Key("http.request.method").String("GET"),
|
||||
attribute.Key("http.response.status_code").Int(404),
|
||||
attribute.Key("network.protocol.name").String("http/1.1"),
|
||||
attribute.Key("network.protocol.version").String("1.1"),
|
||||
attribute.Key("server.address").String("www.test.com"),
|
||||
attribute.Key("server.port").Int(80),
|
||||
attribute.Key("url.scheme").String("http"),
|
||||
),
|
||||
},
|
||||
{
|
||||
desc: "created status",
|
||||
serverURL: "https://www.test.com",
|
||||
statusCode: http.StatusCreated,
|
||||
wantAttributes: attribute.NewSet(
|
||||
attribute.Key("http.request.method").String("GET"),
|
||||
attribute.Key("http.response.status_code").Int(201),
|
||||
attribute.Key("network.protocol.name").String("http/1.1"),
|
||||
attribute.Key("network.protocol.version").String("1.1"),
|
||||
attribute.Key("server.address").String("www.test.com"),
|
||||
attribute.Key("server.port").Int(443),
|
||||
attribute.Key("url.scheme").String("http"),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var cfg types.OTLP
|
||||
(&cfg).SetDefaults()
|
||||
cfg.AddRoutersLabels = true
|
||||
cfg.PushInterval = ptypes.Duration(10 * time.Millisecond)
|
||||
rdr := sdkmetric.NewManualReader()
|
||||
|
||||
meterProvider := sdkmetric.NewMeterProvider(sdkmetric.WithReader(rdr))
|
||||
// force the meter provider with manual reader to collect metrics for the test.
|
||||
metrics.SetMeterProvider(meterProvider)
|
||||
|
||||
semConvMetricRegistry, err := metrics.NewSemConvMetricRegistry(context.Background(), &cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, semConvMetricRegistry)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, test.serverURL+"/search?q=Opentelemetry", nil)
|
||||
req.RemoteAddr = "10.0.0.1:1234"
|
||||
req.Header.Set("User-Agent", "rt-test")
|
||||
req.Header.Set("X-Forwarded-Proto", "http")
|
||||
|
||||
ort := newObservabilityRoundTripper(semConvMetricRegistry, mockRoundTripper{statusCode: test.statusCode})
|
||||
_, err = ort.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
got := metricdata.ResourceMetrics{}
|
||||
err = rdr.Collect(context.Background(), &got)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, got.ScopeMetrics, 1)
|
||||
|
||||
expected := metricdata.Metrics{
|
||||
Name: "http.client.request.duration",
|
||||
Description: "Duration of HTTP client requests.",
|
||||
Unit: "s",
|
||||
Data: metricdata.Histogram[float64]{
|
||||
DataPoints: []metricdata.HistogramDataPoint[float64]{
|
||||
{
|
||||
Attributes: test.wantAttributes,
|
||||
Count: 1,
|
||||
Bounds: []float64{0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1, 2.5, 5, 7.5, 10},
|
||||
BucketCounts: []uint64{0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
|
||||
Min: metricdata.NewExtrema[float64](1),
|
||||
Max: metricdata.NewExtrema[float64](1),
|
||||
Sum: 1,
|
||||
},
|
||||
},
|
||||
Temporality: metricdata.CumulativeTemporality,
|
||||
},
|
||||
}
|
||||
|
||||
metricdatatest.AssertEqual[metricdata.Metrics](t, expected, got.ScopeMetrics[0].Metrics[0], metricdatatest.IgnoreTimestamp(), metricdatatest.IgnoreValue())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockRoundTripper struct {
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (m mockRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
|
||||
return &http.Response{StatusCode: m.statusCode}, nil
|
||||
}
|
135
pkg/proxy/httputil/proxy.go
Normal file
135
pkg/proxy/httputil/proxy.go
Normal file
|
@ -0,0 +1,135 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// StatusClientClosedRequest non-standard HTTP status code for client disconnection.
|
||||
const StatusClientClosedRequest = 499
|
||||
|
||||
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
|
||||
const StatusClientClosedRequestText = "Client Closed Request"
|
||||
|
||||
func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler {
|
||||
return &httputil.ReverseProxy{
|
||||
Director: directorBuilder(target, passHostHeader),
|
||||
Transport: roundTripper,
|
||||
FlushInterval: flushInterval,
|
||||
BufferPool: bufferPool,
|
||||
ErrorHandler: ErrorHandler,
|
||||
}
|
||||
}
|
||||
|
||||
func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Request) {
|
||||
return func(outReq *http.Request) {
|
||||
outReq.URL.Scheme = target.Scheme
|
||||
outReq.URL.Host = target.Host
|
||||
|
||||
u := outReq.URL
|
||||
if outReq.RequestURI != "" {
|
||||
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
|
||||
if err == nil {
|
||||
u = parsedURL
|
||||
}
|
||||
}
|
||||
|
||||
outReq.URL.Path = u.Path
|
||||
outReq.URL.RawPath = u.RawPath
|
||||
// If a plugin/middleware adds semicolons in query params, they should be urlEncoded.
|
||||
outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
|
||||
outReq.RequestURI = "" // Outgoing request should not have RequestURI
|
||||
|
||||
outReq.Proto = "HTTP/1.1"
|
||||
outReq.ProtoMajor = 1
|
||||
outReq.ProtoMinor = 1
|
||||
|
||||
// Do not pass client Host header unless optsetter PassHostHeader is set.
|
||||
if !passHostHeader {
|
||||
outReq.Host = outReq.URL.Host
|
||||
}
|
||||
|
||||
cleanWebSocketHeaders(outReq)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive,
|
||||
// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
|
||||
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
|
||||
// https://tools.ietf.org/html/rfc6455#page-20
|
||||
func cleanWebSocketHeaders(req *http.Request) {
|
||||
if !isWebSocketUpgrade(req) {
|
||||
return
|
||||
}
|
||||
|
||||
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
|
||||
delete(req.Header, "Sec-Websocket-Key")
|
||||
|
||||
req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"]
|
||||
delete(req.Header, "Sec-Websocket-Extensions")
|
||||
|
||||
req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"]
|
||||
delete(req.Header, "Sec-Websocket-Accept")
|
||||
|
||||
req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"]
|
||||
delete(req.Header, "Sec-Websocket-Protocol")
|
||||
|
||||
req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"]
|
||||
delete(req.Header, "Sec-Websocket-Version")
|
||||
}
|
||||
|
||||
func isWebSocketUpgrade(req *http.Request) bool {
|
||||
return httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") &&
|
||||
strings.EqualFold(req.Header.Get("Upgrade"), "websocket")
|
||||
}
|
||||
|
||||
// ErrorHandler is the http.Handler called when something goes wrong when forwarding the request.
|
||||
func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
|
||||
statusCode := ComputeStatusCode(err)
|
||||
|
||||
logger := log.Ctx(req.Context())
|
||||
logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode))
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
if _, werr := w.Write([]byte(statusText(statusCode))); werr != nil {
|
||||
logger.Debug().Err(werr).Msg("Error while writing status code")
|
||||
}
|
||||
}
|
||||
|
||||
// ComputeStatusCode computes the HTTP status code according to the given error.
|
||||
func ComputeStatusCode(err error) int {
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
return http.StatusBadGateway
|
||||
case errors.Is(err, context.Canceled):
|
||||
return StatusClientClosedRequest
|
||||
default:
|
||||
var netErr net.Error
|
||||
if errors.As(err, &netErr) {
|
||||
if netErr.Timeout() {
|
||||
return http.StatusGatewayTimeout
|
||||
}
|
||||
|
||||
return http.StatusBadGateway
|
||||
}
|
||||
}
|
||||
|
||||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func statusText(statusCode int) string {
|
||||
if statusCode == StatusClientClosedRequest {
|
||||
return StatusClientClosedRequestText
|
||||
}
|
||||
return http.StatusText(statusCode)
|
||||
}
|
638
pkg/proxy/httputil/proxy_websocket_test.go
Normal file
638
pkg/proxy/httputil/proxy_websocket_test.go
Normal file
|
@ -0,0 +1,638 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gorillawebsocket "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func TestWebSocketTCPClose(t *testing.T) {
|
||||
errChan := make(chan error, 1)
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
for {
|
||||
_, _, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, conn, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
).open()
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.Close()
|
||||
|
||||
serverErr := <-errChan
|
||||
|
||||
var wsErr *gorillawebsocket.CloseError
|
||||
require.ErrorAs(t, serverErr, &wsErr)
|
||||
assert.Equal(t, 1006, wsErr.Code)
|
||||
}
|
||||
|
||||
func TestWebSocketPingPong(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
CheckOrigin: func(*http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
|
||||
ws, err := upgrader.Upgrade(writer, request, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ws.SetPingHandler(func(appData string) error {
|
||||
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
|
||||
require.NoError(t, err)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, _ = ws.ReadMessage()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
defer conn.Close()
|
||||
|
||||
goodErr := fmt.Errorf("signal: %s", "Good data")
|
||||
badErr := fmt.Errorf("signal: %s", "Bad data")
|
||||
conn.SetPongHandler(func(data string) error {
|
||||
if data == "PingPong" {
|
||||
return goodErr
|
||||
}
|
||||
|
||||
return badErr
|
||||
})
|
||||
|
||||
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = conn.ReadMessage()
|
||||
if !errors.Is(err, goodErr) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketEcho(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
msg := make([]byte, 4)
|
||||
n, err := conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Write(msg[:n])
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, msg, err := conn.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "OK", string(msg))
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestWebSocketPassHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
passHost bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "PassHost false",
|
||||
passHost: false,
|
||||
},
|
||||
{
|
||||
desc: "PassHost true",
|
||||
passHost: true,
|
||||
expected: "example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
req := conn.Request()
|
||||
|
||||
if test.passHost {
|
||||
require.Equal(t, test.expected, req.Host)
|
||||
} else {
|
||||
require.NotEqual(t, test.expected, req.Host)
|
||||
}
|
||||
|
||||
msg := make([]byte, 4)
|
||||
n, err := conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Write(msg[:n])
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
headers.Add("Host", "example.com")
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, msg, err := conn.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "OK", string(msg))
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
|
||||
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }}
|
||||
srv := createServer(t, upgrader, func(*http.Request) {})
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithOrigin(t *testing.T) {
|
||||
srv := createServer(t, gorillawebsocket.Upgrader{}, func(*http.Request) {})
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithQueryParams(t *testing.T) {
|
||||
srv := createServer(t, gorillawebsocket.Upgrader{}, func(r *http.Request) {
|
||||
assert.Equal(t, "test", r.URL.Query().Get("query"))
|
||||
})
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws?query=test"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
conn.Close()
|
||||
}))
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
transportManager := &transportManagerMock{
|
||||
roundTrippers: map[string]http.RoundTripper{
|
||||
"default@internal": &http.Transport{},
|
||||
},
|
||||
}
|
||||
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = testhelpers.MustParseURL(srv.URL)
|
||||
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
|
||||
p.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
|
||||
srv := createServer(t, gorillawebsocket.Upgrader{}, func(r *http.Request) {
|
||||
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
|
||||
})
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/%3A%2F%2F"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketUpgradeFailed(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
transportManager := &transportManagerMock{
|
||||
roundTrippers: map[string]http.RoundTripper{
|
||||
"default@internal": &http.Transport{},
|
||||
},
|
||||
}
|
||||
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
|
||||
if path == "/ws" {
|
||||
// Set new backend URL
|
||||
req.URL = testhelpers.MustParseURL(srv.URL)
|
||||
req.URL.Path = path
|
||||
p.ServeHTTP(w, req)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("upgrade", "websocket")
|
||||
req.Header.Add("Connection", "upgrade")
|
||||
|
||||
err = req.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First request works with 400
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
_, err := conn.Write([]byte("ok"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
srv := createTLSWebsocketServer()
|
||||
defer srv.Close()
|
||||
|
||||
proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
|
||||
defer proxyWithoutTLSConfig.Close()
|
||||
|
||||
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
|
||||
|
||||
_, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
|
||||
proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, transport)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
|
||||
// Don't alter default transport to prevent side effects on other tests.
|
||||
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
|
||||
defaultTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, srv.URL, defaultTransport)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
|
||||
|
||||
resp, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
const dialTimeout = time.Second
|
||||
|
||||
type websocketRequestOpt func(w *websocketRequest)
|
||||
|
||||
func withServer(server string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.ServerAddr = server
|
||||
}
|
||||
}
|
||||
|
||||
func withPath(path string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Path = path
|
||||
}
|
||||
}
|
||||
|
||||
func withData(data string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Data = data
|
||||
}
|
||||
}
|
||||
|
||||
func withOrigin(origin string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Origin = origin
|
||||
}
|
||||
}
|
||||
|
||||
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
|
||||
wsrequest := &websocketRequest{}
|
||||
for _, opt := range opts {
|
||||
opt(wsrequest)
|
||||
}
|
||||
if wsrequest.Origin == "" {
|
||||
wsrequest.Origin = "http://" + wsrequest.ServerAddr
|
||||
}
|
||||
if wsrequest.Config == nil {
|
||||
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
|
||||
}
|
||||
return wsrequest
|
||||
}
|
||||
|
||||
type websocketRequest struct {
|
||||
ServerAddr string
|
||||
Path string
|
||||
Data string
|
||||
Origin string
|
||||
Config *websocket.Config
|
||||
}
|
||||
|
||||
func (w *websocketRequest) send() (string, error) {
|
||||
conn, _, err := w.open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
if _, err := conn.Write([]byte(w.Data)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
msg := make([]byte, 512)
|
||||
var n int
|
||||
n, err = conn.Read(msg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
received := string(msg[:n])
|
||||
return received, nil
|
||||
}
|
||||
|
||||
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
|
||||
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn, err := websocket.NewClient(w.Config, client)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, client, err
|
||||
}
|
||||
|
||||
func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTripper) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
u := testhelpers.MustParseURL(uri)
|
||||
|
||||
transportManager := &transportManagerMock{
|
||||
roundTrippers: map[string]http.RoundTripper{"fwd": transport},
|
||||
}
|
||||
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("fwd", u, false, true, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
// keep the original path
|
||||
path := req.URL.Path
|
||||
|
||||
// Set new backend URL
|
||||
req.URL = u
|
||||
req.URL.Path = path
|
||||
|
||||
p.ServeHTTP(w, req)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func createServer(t *testing.T, upgrader gorillawebsocket.Upgrader, check func(*http.Request)) *httptest.Server {
|
||||
t.Helper()
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
t.Logf("Error during upgrade: %v", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
check(r)
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
t.Logf("Error during read: %v", err)
|
||||
break
|
||||
}
|
||||
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
t.Logf("Error during write: %v", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
return srv
|
||||
}
|
61
pkg/proxy/smart_builder.go
Normal file
61
pkg/proxy/smart_builder.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/config/static"
|
||||
"github.com/traefik/traefik/v3/pkg/proxy/fast"
|
||||
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||
"github.com/traefik/traefik/v3/pkg/server/service"
|
||||
)
|
||||
|
||||
// TransportManager manages transport used for backend communications.
|
||||
type TransportManager interface {
|
||||
Get(name string) (*dynamic.ServersTransport, error)
|
||||
GetRoundTripper(name string) (http.RoundTripper, error)
|
||||
GetTLSConfig(name string) (*tls.Config, error)
|
||||
}
|
||||
|
||||
// SmartBuilder is a proxy builder which returns a fast proxy or httputil proxy corresponding
|
||||
// to the ServersTransport configuration.
|
||||
type SmartBuilder struct {
|
||||
fastProxyBuilder *fast.ProxyBuilder
|
||||
proxyBuilder service.ProxyBuilder
|
||||
|
||||
transportManager httputil.TransportManager
|
||||
}
|
||||
|
||||
// NewSmartBuilder creates and returns a new SmartBuilder instance.
|
||||
func NewSmartBuilder(transportManager TransportManager, proxyBuilder service.ProxyBuilder, fastProxyConfig static.FastProxyConfig) *SmartBuilder {
|
||||
return &SmartBuilder{
|
||||
fastProxyBuilder: fast.NewProxyBuilder(transportManager, fastProxyConfig),
|
||||
proxyBuilder: proxyBuilder,
|
||||
transportManager: transportManager,
|
||||
}
|
||||
}
|
||||
|
||||
// Update is the handler called when the dynamic configuration is updated.
|
||||
func (b *SmartBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) {
|
||||
b.fastProxyBuilder.Update(newConfigs)
|
||||
}
|
||||
|
||||
// Build builds an HTTP proxy for the given URL using the ServersTransport with the given name.
|
||||
func (b *SmartBuilder) Build(configName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) {
|
||||
serversTransport, err := b.transportManager.Get(configName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ServersTransport: %w", err)
|
||||
}
|
||||
|
||||
// The fast proxy implementation cannot handle HTTP/2 requests for now.
|
||||
// For the https scheme we cannot guess if the backend communication will use HTTP2,
|
||||
// thus we check if HTTP/2 is disabled to use the fast proxy implementation when this is possible.
|
||||
if targetURL.Scheme == "h2c" || (targetURL.Scheme == "https" && !serversTransport.DisableHTTP2) {
|
||||
return b.proxyBuilder.Build(configName, targetURL, shouldObserve, passHostHeader, flushInterval)
|
||||
}
|
||||
return b.fastProxyBuilder.Build(configName, targetURL, passHostHeader)
|
||||
}
|
113
pkg/proxy/smart_builder_test.go
Normal file
113
pkg/proxy/smart_builder_test.go
Normal file
|
@ -0,0 +1,113 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/pem"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/config/static"
|
||||
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||
"github.com/traefik/traefik/v3/pkg/server/service"
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
func TestSmartBuilder_Build(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
serversTransport dynamic.ServersTransport
|
||||
fastProxyConfig static.FastProxyConfig
|
||||
https bool
|
||||
h2c bool
|
||||
wantFastProxy bool
|
||||
}{
|
||||
{
|
||||
desc: "fastproxy",
|
||||
fastProxyConfig: static.FastProxyConfig{Debug: true},
|
||||
wantFastProxy: true,
|
||||
},
|
||||
{
|
||||
desc: "fastproxy with https and without DisableHTTP2",
|
||||
https: true,
|
||||
fastProxyConfig: static.FastProxyConfig{Debug: true},
|
||||
wantFastProxy: false,
|
||||
},
|
||||
{
|
||||
desc: "fastproxy with https and DisableHTTP2",
|
||||
https: true,
|
||||
serversTransport: dynamic.ServersTransport{DisableHTTP2: true},
|
||||
fastProxyConfig: static.FastProxyConfig{Debug: true},
|
||||
wantFastProxy: true,
|
||||
},
|
||||
{
|
||||
desc: "fastproxy with h2c",
|
||||
h2c: true,
|
||||
fastProxyConfig: static.FastProxyConfig{Debug: true},
|
||||
wantFastProxy: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var callCount int
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
callCount++
|
||||
if test.wantFastProxy {
|
||||
assert.Contains(t, r.Header, "X-Traefik-Fast-Proxy")
|
||||
} else {
|
||||
assert.NotContains(t, r.Header, "X-Traefik-Fast-Proxy")
|
||||
}
|
||||
})
|
||||
|
||||
var server *httptest.Server
|
||||
|
||||
if test.https {
|
||||
server = httptest.NewUnstartedServer(handler)
|
||||
server.EnableHTTP2 = false
|
||||
server.StartTLS()
|
||||
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: server.TLS.Certificates[0].Certificate[0]})
|
||||
test.serversTransport.RootCAs = []types.FileOrContent{
|
||||
types.FileOrContent(certPEM),
|
||||
}
|
||||
} else {
|
||||
server = httptest.NewServer(h2c.NewHandler(handler, &http2.Server{}))
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
server.Close()
|
||||
})
|
||||
|
||||
targetURL := testhelpers.MustParseURL(server.URL)
|
||||
if test.h2c {
|
||||
targetURL.Scheme = "h2c"
|
||||
}
|
||||
|
||||
serversTransports := map[string]*dynamic.ServersTransport{
|
||||
"test": &test.serversTransport,
|
||||
}
|
||||
|
||||
transportManager := service.NewTransportManager(nil)
|
||||
transportManager.Update(serversTransports)
|
||||
|
||||
httpProxyBuilder := httputil.NewProxyBuilder(transportManager, nil)
|
||||
proxyBuilder := NewSmartBuilder(transportManager, httpProxyBuilder, test.fastProxyConfig)
|
||||
|
||||
proxyHandler, err := proxyBuilder.Build("test", targetURL, false, false, time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
proxyHandler.ServeHTTP(rw, httptest.NewRequest(http.MethodGet, "/", http.NoBody))
|
||||
|
||||
assert.Equal(t, 1, callCount)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue