Use client conn to build the proxy protocol header
Co-authored-by: Simon Delicata <simon.delicata@free.fr>
This commit is contained in:
parent
660acf3b42
commit
5df4c270a7
3 changed files with 77 additions and 19 deletions
|
|
@ -19,12 +19,20 @@ import (
|
||||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||||
traefiktls "github.com/traefik/traefik/v3/pkg/tls"
|
traefiktls "github.com/traefik/traefik/v3/pkg/tls"
|
||||||
"github.com/traefik/traefik/v3/pkg/types"
|
"github.com/traefik/traefik/v3/pkg/types"
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Dialer interface {
|
// ClientConn is the interface that provides information about the client connection.
|
||||||
proxy.Dialer
|
type ClientConn interface {
|
||||||
|
// LocalAddr returns the local network address, if known.
|
||||||
|
LocalAddr() net.Addr
|
||||||
|
|
||||||
|
// RemoteAddr returns the remote network address, if known.
|
||||||
|
RemoteAddr() net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dialer is an interface to dial a network connection, with support for PROXY protocol and termination delay.
|
||||||
|
type Dialer interface {
|
||||||
|
Dial(network, addr string, clientConn ClientConn) (c net.Conn, err error)
|
||||||
TerminationDelay() time.Duration
|
TerminationDelay() time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -34,18 +42,20 @@ type tcpDialer struct {
|
||||||
proxyProtocol *dynamic.ProxyProtocol
|
proxyProtocol *dynamic.ProxyProtocol
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TerminationDelay returns the termination delay duration.
|
||||||
func (d tcpDialer) TerminationDelay() time.Duration {
|
func (d tcpDialer) TerminationDelay() time.Duration {
|
||||||
return d.terminationDelay
|
return d.terminationDelay
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d tcpDialer) Dial(network, addr string) (net.Conn, error) {
|
// Dial dials a network connection and optionally sends a PROXY protocol header.
|
||||||
|
func (d tcpDialer) Dial(network, addr string, clientConn ClientConn) (net.Conn, error) {
|
||||||
conn, err := d.dialer.Dial(network, addr)
|
conn, err := d.dialer.Dial(network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.proxyProtocol != nil && d.proxyProtocol.Version > 0 && d.proxyProtocol.Version < 3 {
|
if d.proxyProtocol != nil && clientConn != nil && d.proxyProtocol.Version > 0 && d.proxyProtocol.Version < 3 {
|
||||||
header := proxyproto.HeaderProxyFromAddrs(byte(d.proxyProtocol.Version), conn.RemoteAddr(), conn.LocalAddr())
|
header := proxyproto.HeaderProxyFromAddrs(byte(d.proxyProtocol.Version), clientConn.RemoteAddr(), clientConn.LocalAddr())
|
||||||
if _, err := header.WriteTo(conn); err != nil {
|
if _, err := header.WriteTo(conn); err != nil {
|
||||||
_ = conn.Close()
|
_ = conn.Close()
|
||||||
return nil, fmt.Errorf("writing PROXY Protocol header: %w", err)
|
return nil, fmt.Errorf("writing PROXY Protocol header: %w", err)
|
||||||
|
|
@ -60,8 +70,9 @@ type tcpTLSDialer struct {
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d tcpTLSDialer) Dial(network, addr string) (net.Conn, error) {
|
// Dial dials a network connection with the wrapped tcpDialer and performs a TLS handshake.
|
||||||
conn, err := d.tcpDialer.Dial(network, addr)
|
func (d tcpTLSDialer) Dial(network, addr string, clientConn ClientConn) (net.Conn, error) {
|
||||||
|
conn, err := d.tcpDialer.Dial(network, addr, clientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -160,7 +160,7 @@ func TestNoTLS(t *testing.T) {
|
||||||
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
|
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conn, err := dialer.Dial("tcp", ":"+port)
|
conn, err := dialer.Dial("tcp", ":"+port, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = conn.Write([]byte("ping\n"))
|
_, err = conn.Write([]byte("ping\n"))
|
||||||
|
|
@ -209,7 +209,7 @@ func TestTLS(t *testing.T) {
|
||||||
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
|
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conn, err := dialer.Dial("tcp", ":"+port)
|
conn, err := dialer.Dial("tcp", ":"+port, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = conn.Write([]byte("ping\n"))
|
_, err = conn.Write([]byte("ping\n"))
|
||||||
|
|
@ -260,7 +260,7 @@ func TestTLSWithInsecureSkipVerify(t *testing.T) {
|
||||||
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
|
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conn, err := dialer.Dial("tcp", ":"+port)
|
conn, err := dialer.Dial("tcp", ":"+port, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = conn.Write([]byte("ping\n"))
|
_, err = conn.Write([]byte("ping\n"))
|
||||||
|
|
@ -329,7 +329,7 @@ func TestMTLS(t *testing.T) {
|
||||||
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
|
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conn, err := dialer.Dial("tcp", ":"+port)
|
conn, err := dialer.Dial("tcp", ":"+port, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = conn.Write([]byte("ping\n"))
|
_, err = conn.Write([]byte("ping\n"))
|
||||||
|
|
@ -463,7 +463,7 @@ func TestSpiffeMTLS(t *testing.T) {
|
||||||
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
|
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conn, err := dialer.Dial("tcp", ":"+port)
|
conn, err := dialer.Dial("tcp", ":"+port, nil)
|
||||||
|
|
||||||
if test.wantError {
|
if test.wantError {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
|
@ -510,10 +510,13 @@ func TestProxyProtocol(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var version int
|
var version int
|
||||||
|
var localAddr, remoteAddr string
|
||||||
proxyBackendListener := proxyproto.Listener{
|
proxyBackendListener := proxyproto.Listener{
|
||||||
Listener: backendListener,
|
Listener: backendListener,
|
||||||
ValidateHeader: func(h *proxyproto.Header) error {
|
ValidateHeader: func(h *proxyproto.Header) error {
|
||||||
version = int(h.Version)
|
version = int(h.Version)
|
||||||
|
localAddr = h.DestinationAddr.String()
|
||||||
|
remoteAddr = h.SourceAddr.String()
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
|
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
|
||||||
|
|
@ -544,7 +547,18 @@ func TestProxyProtocol(t *testing.T) {
|
||||||
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
|
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conn, err := dialer.Dial("tcp", ":"+port)
|
clientConn := &fakeClientConn{
|
||||||
|
localAddr: &net.TCPAddr{
|
||||||
|
IP: net.ParseIP("2.2.2.2"),
|
||||||
|
Port: 12345,
|
||||||
|
},
|
||||||
|
remoteAddr: &net.TCPAddr{
|
||||||
|
IP: net.ParseIP("1.1.1.1"),
|
||||||
|
Port: 12345,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialer.Dial("tcp", ":"+port, clientConn)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
|
@ -558,6 +572,8 @@ func TestProxyProtocol(t *testing.T) {
|
||||||
assert.Equal(t, 4, n)
|
assert.Equal(t, 4, n)
|
||||||
assert.Equal(t, "PONG", string(buf[:4]))
|
assert.Equal(t, "PONG", string(buf[:4]))
|
||||||
assert.Equal(t, test.version, version)
|
assert.Equal(t, test.version, version)
|
||||||
|
assert.Equal(t, "2.2.2.2:12345", localAddr)
|
||||||
|
assert.Equal(t, "1.1.1.1:12345", remoteAddr)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -586,10 +602,13 @@ func TestProxyProtocolWithTLS(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
var version int
|
var version int
|
||||||
|
var localAddr, remoteAddr string
|
||||||
proxyBackendListener := proxyproto.Listener{
|
proxyBackendListener := proxyproto.Listener{
|
||||||
Listener: backendListener,
|
Listener: backendListener,
|
||||||
ValidateHeader: func(h *proxyproto.Header) error {
|
ValidateHeader: func(h *proxyproto.Header) error {
|
||||||
version = int(h.Version)
|
version = int(h.Version)
|
||||||
|
localAddr = h.DestinationAddr.String()
|
||||||
|
remoteAddr = h.SourceAddr.String()
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
|
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
|
||||||
|
|
@ -646,7 +665,18 @@ func TestProxyProtocolWithTLS(t *testing.T) {
|
||||||
}, true)
|
}, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conn, err := dialer.Dial("tcp", ":"+port)
|
clientConn := &fakeClientConn{
|
||||||
|
localAddr: &net.TCPAddr{
|
||||||
|
IP: net.ParseIP("2.2.2.2"),
|
||||||
|
Port: 12345,
|
||||||
|
},
|
||||||
|
remoteAddr: &net.TCPAddr{
|
||||||
|
IP: net.ParseIP("1.1.1.1"),
|
||||||
|
Port: 12345,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialer.Dial("tcp", ":"+port, clientConn)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
|
@ -660,6 +690,8 @@ func TestProxyProtocolWithTLS(t *testing.T) {
|
||||||
assert.Equal(t, 4, n)
|
assert.Equal(t, 4, n)
|
||||||
assert.Equal(t, "PONG", string(buf[:4]))
|
assert.Equal(t, "PONG", string(buf[:4]))
|
||||||
assert.Equal(t, test.version, version)
|
assert.Equal(t, test.version, version)
|
||||||
|
assert.Equal(t, "2.2.2.2:12345", localAddr)
|
||||||
|
assert.Equal(t, "1.1.1.1:12345", remoteAddr)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -695,7 +727,7 @@ func TestProxyProtocolDisabled(t *testing.T) {
|
||||||
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
|
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
conn, err := dialer.Dial("tcp", ":"+port)
|
conn, err := dialer.Dial("tcp", ":"+port, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
_, err = conn.Write([]byte("ping"))
|
_, err = conn.Write([]byte("ping"))
|
||||||
|
|
@ -709,6 +741,19 @@ func TestProxyProtocolDisabled(t *testing.T) {
|
||||||
assert.Equal(t, "PONG", string(buf[:4]))
|
assert.Equal(t, "PONG", string(buf[:4]))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type fakeClientConn struct {
|
||||||
|
remoteAddr *net.TCPAddr
|
||||||
|
localAddr *net.TCPAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeClientConn) LocalAddr() net.Addr {
|
||||||
|
return f.localAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeClientConn) RemoteAddr() net.Addr {
|
||||||
|
return f.remoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
// fakeSpiffePKI simulates a SPIFFE aware PKI and allows generating multiple valid SVIDs.
|
// fakeSpiffePKI simulates a SPIFFE aware PKI and allows generating multiple valid SVIDs.
|
||||||
type fakeSpiffePKI struct {
|
type fakeSpiffePKI struct {
|
||||||
caPrivateKey *rsa.PrivateKey
|
caPrivateKey *rsa.PrivateKey
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ func (p *Proxy) ServeTCP(conn WriteCloser) {
|
||||||
// needed because of e.g. server.trackedConnection
|
// needed because of e.g. server.trackedConnection
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
connBackend, err := p.dialBackend()
|
connBackend, err := p.dialBackend(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Error while dialing backend")
|
log.Error().Err(err).Msg("Error while dialing backend")
|
||||||
return
|
return
|
||||||
|
|
@ -62,8 +62,10 @@ func (p *Proxy) ServeTCP(conn WriteCloser) {
|
||||||
<-errChan
|
<-errChan
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) dialBackend() (WriteCloser, error) {
|
func (p *Proxy) dialBackend(clientConn net.Conn) (WriteCloser, error) {
|
||||||
conn, err := p.dialer.Dial("tcp", p.address)
|
// The clientConn is passed to the dialer so that it can use information from it if needed,
|
||||||
|
// to build a PROXY protocol header.
|
||||||
|
conn, err := p.dialer.Dial("tcp", p.address, clientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue