diff --git a/pkg/tcp/dialer.go b/pkg/tcp/dialer.go index 2a754d3b4..bc4855b5b 100644 --- a/pkg/tcp/dialer.go +++ b/pkg/tcp/dialer.go @@ -19,12 +19,20 @@ import ( "github.com/traefik/traefik/v3/pkg/config/dynamic" traefiktls "github.com/traefik/traefik/v3/pkg/tls" "github.com/traefik/traefik/v3/pkg/types" - "golang.org/x/net/proxy" ) -type Dialer interface { - proxy.Dialer +// ClientConn is the interface that provides information about the client connection. +type ClientConn interface { + // LocalAddr returns the local network address, if known. + LocalAddr() net.Addr + // RemoteAddr returns the remote network address, if known. + RemoteAddr() net.Addr +} + +// Dialer is an interface to dial a network connection, with support for PROXY protocol and termination delay. +type Dialer interface { + Dial(network, addr string, clientConn ClientConn) (c net.Conn, err error) TerminationDelay() time.Duration } @@ -34,18 +42,20 @@ type tcpDialer struct { proxyProtocol *dynamic.ProxyProtocol } +// TerminationDelay returns the termination delay duration. func (d tcpDialer) TerminationDelay() time.Duration { return d.terminationDelay } -func (d tcpDialer) Dial(network, addr string) (net.Conn, error) { +// Dial dials a network connection and optionally sends a PROXY protocol header. +func (d tcpDialer) Dial(network, addr string, clientConn ClientConn) (net.Conn, error) { conn, err := d.dialer.Dial(network, addr) if err != nil { return nil, err } - if d.proxyProtocol != nil && d.proxyProtocol.Version > 0 && d.proxyProtocol.Version < 3 { - header := proxyproto.HeaderProxyFromAddrs(byte(d.proxyProtocol.Version), conn.RemoteAddr(), conn.LocalAddr()) + if d.proxyProtocol != nil && clientConn != nil && d.proxyProtocol.Version > 0 && d.proxyProtocol.Version < 3 { + header := proxyproto.HeaderProxyFromAddrs(byte(d.proxyProtocol.Version), clientConn.RemoteAddr(), clientConn.LocalAddr()) if _, err := header.WriteTo(conn); err != nil { _ = conn.Close() return nil, fmt.Errorf("writing PROXY Protocol header: %w", err) @@ -60,8 +70,9 @@ type tcpTLSDialer struct { tlsConfig *tls.Config } -func (d tcpTLSDialer) Dial(network, addr string) (net.Conn, error) { - conn, err := d.tcpDialer.Dial(network, addr) +// Dial dials a network connection with the wrapped tcpDialer and performs a TLS handshake. +func (d tcpTLSDialer) Dial(network, addr string, clientConn ClientConn) (net.Conn, error) { + conn, err := d.tcpDialer.Dial(network, addr, clientConn) if err != nil { return nil, err } diff --git a/pkg/tcp/dialer_test.go b/pkg/tcp/dialer_test.go index 62af030db..07d8a7b12 100644 --- a/pkg/tcp/dialer_test.go +++ b/pkg/tcp/dialer_test.go @@ -160,7 +160,7 @@ func TestNoTLS(t *testing.T) { dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false) require.NoError(t, err) - conn, err := dialer.Dial("tcp", ":"+port) + conn, err := dialer.Dial("tcp", ":"+port, nil) require.NoError(t, err) _, err = conn.Write([]byte("ping\n")) @@ -209,7 +209,7 @@ func TestTLS(t *testing.T) { dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true) require.NoError(t, err) - conn, err := dialer.Dial("tcp", ":"+port) + conn, err := dialer.Dial("tcp", ":"+port, nil) require.NoError(t, err) _, err = conn.Write([]byte("ping\n")) @@ -260,7 +260,7 @@ func TestTLSWithInsecureSkipVerify(t *testing.T) { dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true) require.NoError(t, err) - conn, err := dialer.Dial("tcp", ":"+port) + conn, err := dialer.Dial("tcp", ":"+port, nil) require.NoError(t, err) _, err = conn.Write([]byte("ping\n")) @@ -329,7 +329,7 @@ func TestMTLS(t *testing.T) { dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true) require.NoError(t, err) - conn, err := dialer.Dial("tcp", ":"+port) + conn, err := dialer.Dial("tcp", ":"+port, nil) require.NoError(t, err) _, err = conn.Write([]byte("ping\n")) @@ -463,7 +463,7 @@ func TestSpiffeMTLS(t *testing.T) { dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true) require.NoError(t, err) - conn, err := dialer.Dial("tcp", ":"+port) + conn, err := dialer.Dial("tcp", ":"+port, nil) if test.wantError { require.Error(t, err) @@ -510,10 +510,13 @@ func TestProxyProtocol(t *testing.T) { require.NoError(t, err) var version int + var localAddr, remoteAddr string proxyBackendListener := proxyproto.Listener{ Listener: backendListener, ValidateHeader: func(h *proxyproto.Header) error { version = int(h.Version) + localAddr = h.DestinationAddr.String() + remoteAddr = h.SourceAddr.String() return nil }, Policy: func(upstream net.Addr) (proxyproto.Policy, error) { @@ -544,7 +547,18 @@ func TestProxyProtocol(t *testing.T) { dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false) require.NoError(t, err) - conn, err := dialer.Dial("tcp", ":"+port) + clientConn := &fakeClientConn{ + localAddr: &net.TCPAddr{ + IP: net.ParseIP("2.2.2.2"), + Port: 12345, + }, + remoteAddr: &net.TCPAddr{ + IP: net.ParseIP("1.1.1.1"), + Port: 12345, + }, + } + + conn, err := dialer.Dial("tcp", ":"+port, clientConn) require.NoError(t, err) defer conn.Close() @@ -558,6 +572,8 @@ func TestProxyProtocol(t *testing.T) { assert.Equal(t, 4, n) assert.Equal(t, "PONG", string(buf[:4])) assert.Equal(t, test.version, version) + assert.Equal(t, "2.2.2.2:12345", localAddr) + assert.Equal(t, "1.1.1.1:12345", remoteAddr) }) } } @@ -586,10 +602,13 @@ func TestProxyProtocolWithTLS(t *testing.T) { require.NoError(t, err) var version int + var localAddr, remoteAddr string proxyBackendListener := proxyproto.Listener{ Listener: backendListener, ValidateHeader: func(h *proxyproto.Header) error { version = int(h.Version) + localAddr = h.DestinationAddr.String() + remoteAddr = h.SourceAddr.String() return nil }, Policy: func(upstream net.Addr) (proxyproto.Policy, error) { @@ -646,7 +665,18 @@ func TestProxyProtocolWithTLS(t *testing.T) { }, true) require.NoError(t, err) - conn, err := dialer.Dial("tcp", ":"+port) + clientConn := &fakeClientConn{ + localAddr: &net.TCPAddr{ + IP: net.ParseIP("2.2.2.2"), + Port: 12345, + }, + remoteAddr: &net.TCPAddr{ + IP: net.ParseIP("1.1.1.1"), + Port: 12345, + }, + } + + conn, err := dialer.Dial("tcp", ":"+port, clientConn) require.NoError(t, err) defer conn.Close() @@ -660,6 +690,8 @@ func TestProxyProtocolWithTLS(t *testing.T) { assert.Equal(t, 4, n) assert.Equal(t, "PONG", string(buf[:4])) assert.Equal(t, test.version, version) + assert.Equal(t, "2.2.2.2:12345", localAddr) + assert.Equal(t, "1.1.1.1:12345", remoteAddr) }) } } @@ -695,7 +727,7 @@ func TestProxyProtocolDisabled(t *testing.T) { dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false) require.NoError(t, err) - conn, err := dialer.Dial("tcp", ":"+port) + conn, err := dialer.Dial("tcp", ":"+port, nil) require.NoError(t, err) _, err = conn.Write([]byte("ping")) @@ -709,6 +741,19 @@ func TestProxyProtocolDisabled(t *testing.T) { assert.Equal(t, "PONG", string(buf[:4])) } +type fakeClientConn struct { + remoteAddr *net.TCPAddr + localAddr *net.TCPAddr +} + +func (f fakeClientConn) LocalAddr() net.Addr { + return f.localAddr +} + +func (f fakeClientConn) RemoteAddr() net.Addr { + return f.remoteAddr +} + // fakeSpiffePKI simulates a SPIFFE aware PKI and allows generating multiple valid SVIDs. type fakeSpiffePKI struct { caPrivateKey *rsa.PrivateKey diff --git a/pkg/tcp/proxy.go b/pkg/tcp/proxy.go index 75bdfcd4d..aa979d303 100644 --- a/pkg/tcp/proxy.go +++ b/pkg/tcp/proxy.go @@ -34,7 +34,7 @@ func (p *Proxy) ServeTCP(conn WriteCloser) { // needed because of e.g. server.trackedConnection defer conn.Close() - connBackend, err := p.dialBackend() + connBackend, err := p.dialBackend(conn) if err != nil { log.Error().Err(err).Msg("Error while dialing backend") return @@ -62,8 +62,10 @@ func (p *Proxy) ServeTCP(conn WriteCloser) { <-errChan } -func (p *Proxy) dialBackend() (WriteCloser, error) { - conn, err := p.dialer.Dial("tcp", p.address) +func (p *Proxy) dialBackend(clientConn net.Conn) (WriteCloser, error) { + // The clientConn is passed to the dialer so that it can use information from it if needed, + // to build a PROXY protocol header. + conn, err := p.dialer.Dial("tcp", p.address, clientConn) if err != nil { return nil, err }