1
0
Fork 0

Send proxy protocol header before TLS handshake

Co-authored-by: Kevin Pollet <pollet.kevin@gmail.com>
This commit is contained in:
Romain 2025-08-29 12:30:04 +02:00 committed by GitHub
parent 30b0666219
commit f9fbcfbb42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 566 additions and 416 deletions

View file

@ -7,6 +7,7 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"io"
"math/big"
"net"
@ -14,6 +15,7 @@ import (
"testing"
"time"
"github.com/pires/go-proxyproto"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
@ -131,7 +133,7 @@ func TestConflictingConfig(t *testing.T) {
dialerManager.Update(dynamicConf)
_, err := dialerManager.Get("test", false)
_, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
require.Error(t, err)
}
@ -140,7 +142,7 @@ func TestNoTLS(t *testing.T) {
require.NoError(t, err)
defer backendListener.Close()
go fakeRedis(t, backendListener)
go fakeServer(t, backendListener)
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
@ -155,7 +157,7 @@ func TestNoTLS(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", false)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -186,7 +188,7 @@ func TestTLS(t *testing.T) {
tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}})
defer tlsListener.Close()
go fakeRedis(t, tlsListener)
go fakeServer(t, tlsListener)
_, port, err := net.SplitHostPort(tlsListener.Addr().String())
require.NoError(t, err)
@ -204,7 +206,7 @@ func TestTLS(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", true)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -236,7 +238,7 @@ func TestTLSWithInsecureSkipVerify(t *testing.T) {
tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}})
defer tlsListener.Close()
go fakeRedis(t, tlsListener)
go fakeServer(t, tlsListener)
_, port, err := net.SplitHostPort(tlsListener.Addr().String())
require.NoError(t, err)
@ -255,7 +257,7 @@ func TestTLSWithInsecureSkipVerify(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", true)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -297,7 +299,7 @@ func TestMTLS(t *testing.T) {
})
defer tlsListener.Close()
go fakeRedis(t, tlsListener)
go fakeServer(t, tlsListener)
_, port, err := net.SplitHostPort(tlsListener.Addr().String())
require.NoError(t, err)
@ -324,7 +326,7 @@ func TestMTLS(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", true)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -444,7 +446,7 @@ func TestSpiffeMTLS(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
go fakeRedis(t, tlsListener)
go fakeServer(t, tlsListener)
dialerManager := NewDialerManager(test.clientSource)
@ -458,7 +460,7 @@ func TestSpiffeMTLS(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", true)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -487,6 +489,226 @@ func TestSpiffeMTLS(t *testing.T) {
}
}
func TestProxyProtocol(t *testing.T) {
testCases := []struct {
desc string
version int
}{
{
desc: "proxy protocol v1",
version: 1,
},
{
desc: "proxy protocol v2",
version: 2,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
var version int
proxyBackendListener := proxyproto.Listener{
Listener: backendListener,
ValidateHeader: func(h *proxyproto.Header) error {
version = int(h.Version)
return nil
},
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
switch test.version {
case 1, 2:
return proxyproto.USE, nil
default:
return proxyproto.REQUIRE, errors.New("unsupported version")
}
},
}
defer proxyBackendListener.Close()
go fakeServer(t, &proxyBackendListener)
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
dialerManager := NewDialerManager(nil)
dialerManager.Update(map[string]*dynamic.TCPServersTransport{
"test": {
ProxyProtocol: &dynamic.ProxyProtocol{
Version: test.version,
},
},
})
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("ping"))
require.NoError(t, err)
buf := make([]byte, 64)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, 4, n)
assert.Equal(t, "PONG", string(buf[:4]))
assert.Equal(t, test.version, version)
})
}
}
func TestProxyProtocolWithTLS(t *testing.T) {
testCases := []struct {
desc string
version int
}{
{
desc: "proxy protocol v1",
version: 1,
},
{
desc: "proxy protocol v2",
version: 2,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
cert, err := tls.X509KeyPair(LocalhostCert, LocalhostKey)
require.NoError(t, err)
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
var version int
proxyBackendListener := proxyproto.Listener{
Listener: backendListener,
ValidateHeader: func(h *proxyproto.Header) error {
version = int(h.Version)
return nil
},
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
switch test.version {
case 1, 2:
return proxyproto.USE, nil
default:
return proxyproto.REQUIRE, errors.New("unsupported version")
}
},
}
defer proxyBackendListener.Close()
go func() {
conn, err := proxyBackendListener.Accept()
require.NoError(t, err)
defer conn.Close()
// Now wrap with TLS and perform handshake
tlsConn := tls.Server(conn, &tls.Config{Certificates: []tls.Certificate{cert}})
defer tlsConn.Close()
err = tlsConn.Handshake()
require.NoError(t, err)
buf := make([]byte, 64)
n, err := tlsConn.Read(buf)
require.NoError(t, err)
if bytes.Equal(buf[:n], []byte("ping")) {
_, _ = tlsConn.Write([]byte("PONG"))
}
}()
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
dialerManager := NewDialerManager(nil)
dialerManager.Update(map[string]*dynamic.TCPServersTransport{
"test": {
TLS: &dynamic.TLSClientConfig{
ServerName: "example.com",
RootCAs: []types.FileOrContent{types.FileOrContent(LocalhostCert)},
InsecureSkipVerify: true,
},
ProxyProtocol: &dynamic.ProxyProtocol{
Version: test.version,
},
},
})
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{
ServersTransport: "test",
}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("ping"))
require.NoError(t, err)
buf := make([]byte, 64)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, 4, n)
assert.Equal(t, "PONG", string(buf[:4]))
assert.Equal(t, test.version, version)
})
}
}
func TestProxyProtocolDisabled(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
defer backendListener.Close()
go func() {
conn, err := backendListener.Accept()
require.NoError(t, err)
defer conn.Close()
buf := make([]byte, 64)
n, err := conn.Read(buf)
require.NoError(t, err)
if bytes.Equal(buf[:n], []byte("ping")) {
_, _ = conn.Write([]byte("PONG"))
}
}()
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
// No proxy protocol configuration.
dialerManager := NewDialerManager(nil)
dialerManager.Update(map[string]*dynamic.TCPServersTransport{
"test": {},
})
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
require.NoError(t, err)
_, err = conn.Write([]byte("ping"))
require.NoError(t, err)
buf := make([]byte, 64)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, 4, n)
assert.Equal(t, "PONG", string(buf[:4]))
}
// fakeSpiffePKI simulates a SPIFFE aware PKI and allows generating multiple valid SVIDs.
type fakeSpiffePKI struct {
caPrivateKey *rsa.PrivateKey