1
0
Fork 0

Send proxy protocol header before TLS handshake

Co-authored-by: Kevin Pollet <pollet.kevin@gmail.com>
This commit is contained in:
Romain 2025-08-29 12:30:04 +02:00 committed by GitHub
parent 30b0666219
commit f9fbcfbb42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
28 changed files with 566 additions and 416 deletions

View file

@ -9,11 +9,13 @@ import (
"sync"
"time"
"github.com/pires/go-proxyproto"
"github.com/rs/zerolog/log"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
traefiktls "github.com/traefik/traefik/v3/pkg/tls"
"github.com/traefik/traefik/v3/pkg/types"
@ -27,15 +29,54 @@ type Dialer interface {
}
type tcpDialer struct {
proxy.Dialer
dialer *net.Dialer
terminationDelay time.Duration
proxyProtocol *dynamic.ProxyProtocol
}
func (d tcpDialer) TerminationDelay() time.Duration {
return d.terminationDelay
}
// SpiffeX509Source allows to retrieve a x509 SVID and bundle.
func (d tcpDialer) Dial(network, addr string) (net.Conn, error) {
conn, err := d.dialer.Dial(network, addr)
if err != nil {
return nil, err
}
if d.proxyProtocol != nil && d.proxyProtocol.Version > 0 && d.proxyProtocol.Version < 3 {
header := proxyproto.HeaderProxyFromAddrs(byte(d.proxyProtocol.Version), conn.RemoteAddr(), conn.LocalAddr())
if _, err := header.WriteTo(conn); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("writing PROXY Protocol header: %w", err)
}
}
return conn, nil
}
type tcpTLSDialer struct {
tcpDialer
tlsConfig *tls.Config
}
func (d tcpTLSDialer) Dial(network, addr string) (net.Conn, error) {
conn, err := d.tcpDialer.Dial(network, addr)
if err != nil {
return nil, err
}
// Now perform TLS handshake on the connection
tlsConn := tls.Client(conn, d.tlsConfig)
if err := tlsConn.Handshake(); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("TLS handshake failed: %w", err)
}
return tlsConn, nil
}
// SpiffeX509Source allows retrieving a x509 SVID and bundle.
type SpiffeX509Source interface {
x509svid.Source
x509bundle.Source
@ -43,118 +84,106 @@ type SpiffeX509Source interface {
// DialerManager handles dialer for the reverse proxy.
type DialerManager struct {
rtLock sync.RWMutex
dialers map[string]Dialer
dialersTLS map[string]Dialer
spiffeX509Source SpiffeX509Source
serversTransportsMu sync.RWMutex
serversTransports map[string]*dynamic.TCPServersTransport
spiffeX509Source SpiffeX509Source
}
// NewDialerManager creates a new DialerManager.
func NewDialerManager(spiffeX509Source SpiffeX509Source) *DialerManager {
return &DialerManager{
dialers: make(map[string]Dialer),
dialersTLS: make(map[string]Dialer),
spiffeX509Source: spiffeX509Source,
serversTransports: make(map[string]*dynamic.TCPServersTransport),
spiffeX509Source: spiffeX509Source,
}
}
// Update updates the dialers configurations.
// Update updates the TCP serversTransport configurations.
func (d *DialerManager) Update(configs map[string]*dynamic.TCPServersTransport) {
d.rtLock.Lock()
defer d.rtLock.Unlock()
d.serversTransportsMu.Lock()
defer d.serversTransportsMu.Unlock()
d.dialers = make(map[string]Dialer)
d.dialersTLS = make(map[string]Dialer)
for configName, config := range configs {
if err := d.createDialers(configName, config); err != nil {
log.Debug().
Str("dialer", configName).
Err(err).
Msg("Create TCP Dialer")
}
}
d.serversTransports = configs
}
// Get gets a dialer by name.
func (d *DialerManager) Get(name string, tls bool) (Dialer, error) {
if len(name) == 0 {
name = "default@internal"
// Build builds a dialer by name.
func (d *DialerManager) Build(config *dynamic.TCPServersLoadBalancer, isTLS bool) (Dialer, error) {
name := "default@internal"
if config.ServersTransport != "" {
name = config.ServersTransport
}
d.rtLock.RLock()
defer d.rtLock.RUnlock()
if tls {
if rt, ok := d.dialersTLS[name]; ok {
return rt, nil
}
return nil, fmt.Errorf("TCP dialer not found %s", name)
var st *dynamic.TCPServersTransport
d.serversTransportsMu.RLock()
st, ok := d.serversTransports[name]
d.serversTransportsMu.RUnlock()
if !ok || st == nil {
return nil, fmt.Errorf("no transport configuration found for %q", name)
}
if rt, ok := d.dialers[name]; ok {
return rt, nil
// Handle TerminationDelay and ProxyProtocol deprecated options.
var terminationDelay ptypes.Duration
if config.TerminationDelay != nil {
terminationDelay = ptypes.Duration(*config.TerminationDelay)
}
proxyProtocol := config.ProxyProtocol
if config.ServersTransport != "" {
terminationDelay = st.TerminationDelay
proxyProtocol = st.ProxyProtocol
}
return nil, fmt.Errorf("TCP dialer not found %s", name)
}
// createDialers creates the dialers according to the TCPServersTransport configuration.
func (d *DialerManager) createDialers(name string, cfg *dynamic.TCPServersTransport) error {
if cfg == nil {
return errors.New("no transport configuration given")
}
dialer := &net.Dialer{
Timeout: time.Duration(cfg.DialTimeout),
KeepAlive: time.Duration(cfg.DialKeepAlive),
if proxyProtocol != nil && (proxyProtocol.Version < 1 || proxyProtocol.Version > 2) {
return nil, fmt.Errorf("unknown proxyProtocol version: %d", proxyProtocol.Version)
}
var tlsConfig *tls.Config
if cfg.TLS != nil {
if cfg.TLS.Spiffe != nil {
if st.TLS != nil {
if st.TLS.Spiffe != nil {
if d.spiffeX509Source == nil {
return errors.New("SPIFFE is enabled for this transport, but not configured")
return nil, errors.New("SPIFFE is enabled for this transport, but not configured")
}
authorizer, err := buildSpiffeAuthorizer(cfg.TLS.Spiffe)
authorizer, err := buildSpiffeAuthorizer(st.TLS.Spiffe)
if err != nil {
return fmt.Errorf("unable to build SPIFFE authorizer: %w", err)
return nil, fmt.Errorf("unable to build SPIFFE authorizer: %w", err)
}
tlsConfig = tlsconfig.MTLSClientConfig(d.spiffeX509Source, d.spiffeX509Source, authorizer)
}
if cfg.TLS.InsecureSkipVerify || len(cfg.TLS.RootCAs) > 0 || len(cfg.TLS.ServerName) > 0 || len(cfg.TLS.Certificates) > 0 || cfg.TLS.PeerCertURI != "" {
if st.TLS.InsecureSkipVerify || len(st.TLS.RootCAs) > 0 || len(st.TLS.ServerName) > 0 || len(st.TLS.Certificates) > 0 || st.TLS.PeerCertURI != "" {
if tlsConfig != nil {
return errors.New("TLS and SPIFFE configuration cannot be defined at the same time")
return nil, errors.New("TLS and SPIFFE configuration cannot be defined at the same time")
}
tlsConfig = &tls.Config{
ServerName: cfg.TLS.ServerName,
InsecureSkipVerify: cfg.TLS.InsecureSkipVerify,
RootCAs: createRootCACertPool(cfg.TLS.RootCAs),
Certificates: cfg.TLS.Certificates.GetCertificates(),
ServerName: st.TLS.ServerName,
InsecureSkipVerify: st.TLS.InsecureSkipVerify,
RootCAs: createRootCACertPool(st.TLS.RootCAs),
Certificates: st.TLS.Certificates.GetCertificates(),
}
if cfg.TLS.PeerCertURI != "" {
if st.TLS.PeerCertURI != "" {
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
return traefiktls.VerifyPeerCertificate(cfg.TLS.PeerCertURI, tlsConfig, rawCerts)
return traefiktls.VerifyPeerCertificate(st.TLS.PeerCertURI, tlsConfig, rawCerts)
}
}
}
}
tlsDialer := &tls.Dialer{
NetDialer: dialer,
Config: tlsConfig,
dialer := tcpDialer{
dialer: &net.Dialer{
Timeout: time.Duration(st.DialTimeout),
KeepAlive: time.Duration(st.DialKeepAlive),
},
terminationDelay: time.Duration(terminationDelay),
proxyProtocol: proxyProtocol,
}
d.dialers[name] = tcpDialer{dialer, time.Duration(cfg.TerminationDelay)}
d.dialersTLS[name] = tcpDialer{tlsDialer, time.Duration(cfg.TerminationDelay)}
return nil
if !isTLS {
return dialer, nil
}
return tcpTLSDialer{dialer, tlsConfig}, nil
}
func createRootCACertPool(rootCAs []types.FileOrContent) *x509.CertPool {

View file

@ -7,6 +7,7 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors"
"io"
"math/big"
"net"
@ -14,6 +15,7 @@ import (
"testing"
"time"
"github.com/pires/go-proxyproto"
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
"github.com/spiffe/go-spiffe/v2/spiffeid"
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
@ -131,7 +133,7 @@ func TestConflictingConfig(t *testing.T) {
dialerManager.Update(dynamicConf)
_, err := dialerManager.Get("test", false)
_, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
require.Error(t, err)
}
@ -140,7 +142,7 @@ func TestNoTLS(t *testing.T) {
require.NoError(t, err)
defer backendListener.Close()
go fakeRedis(t, backendListener)
go fakeServer(t, backendListener)
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
@ -155,7 +157,7 @@ func TestNoTLS(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", false)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -186,7 +188,7 @@ func TestTLS(t *testing.T) {
tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}})
defer tlsListener.Close()
go fakeRedis(t, tlsListener)
go fakeServer(t, tlsListener)
_, port, err := net.SplitHostPort(tlsListener.Addr().String())
require.NoError(t, err)
@ -204,7 +206,7 @@ func TestTLS(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", true)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -236,7 +238,7 @@ func TestTLSWithInsecureSkipVerify(t *testing.T) {
tlsListener := tls.NewListener(backendListener, &tls.Config{Certificates: []tls.Certificate{cert}})
defer tlsListener.Close()
go fakeRedis(t, tlsListener)
go fakeServer(t, tlsListener)
_, port, err := net.SplitHostPort(tlsListener.Addr().String())
require.NoError(t, err)
@ -255,7 +257,7 @@ func TestTLSWithInsecureSkipVerify(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", true)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -297,7 +299,7 @@ func TestMTLS(t *testing.T) {
})
defer tlsListener.Close()
go fakeRedis(t, tlsListener)
go fakeServer(t, tlsListener)
_, port, err := net.SplitHostPort(tlsListener.Addr().String())
require.NoError(t, err)
@ -324,7 +326,7 @@ func TestMTLS(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", true)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -444,7 +446,7 @@ func TestSpiffeMTLS(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
go fakeRedis(t, tlsListener)
go fakeServer(t, tlsListener)
dialerManager := NewDialerManager(test.clientSource)
@ -458,7 +460,7 @@ func TestSpiffeMTLS(t *testing.T) {
dialerManager.Update(dynamicConf)
dialer, err := dialerManager.Get("test", true)
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
@ -487,6 +489,226 @@ func TestSpiffeMTLS(t *testing.T) {
}
}
func TestProxyProtocol(t *testing.T) {
testCases := []struct {
desc string
version int
}{
{
desc: "proxy protocol v1",
version: 1,
},
{
desc: "proxy protocol v2",
version: 2,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
var version int
proxyBackendListener := proxyproto.Listener{
Listener: backendListener,
ValidateHeader: func(h *proxyproto.Header) error {
version = int(h.Version)
return nil
},
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
switch test.version {
case 1, 2:
return proxyproto.USE, nil
default:
return proxyproto.REQUIRE, errors.New("unsupported version")
}
},
}
defer proxyBackendListener.Close()
go fakeServer(t, &proxyBackendListener)
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
dialerManager := NewDialerManager(nil)
dialerManager.Update(map[string]*dynamic.TCPServersTransport{
"test": {
ProxyProtocol: &dynamic.ProxyProtocol{
Version: test.version,
},
},
})
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("ping"))
require.NoError(t, err)
buf := make([]byte, 64)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, 4, n)
assert.Equal(t, "PONG", string(buf[:4]))
assert.Equal(t, test.version, version)
})
}
}
func TestProxyProtocolWithTLS(t *testing.T) {
testCases := []struct {
desc string
version int
}{
{
desc: "proxy protocol v1",
version: 1,
},
{
desc: "proxy protocol v2",
version: 2,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
cert, err := tls.X509KeyPair(LocalhostCert, LocalhostKey)
require.NoError(t, err)
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
var version int
proxyBackendListener := proxyproto.Listener{
Listener: backendListener,
ValidateHeader: func(h *proxyproto.Header) error {
version = int(h.Version)
return nil
},
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
switch test.version {
case 1, 2:
return proxyproto.USE, nil
default:
return proxyproto.REQUIRE, errors.New("unsupported version")
}
},
}
defer proxyBackendListener.Close()
go func() {
conn, err := proxyBackendListener.Accept()
require.NoError(t, err)
defer conn.Close()
// Now wrap with TLS and perform handshake
tlsConn := tls.Server(conn, &tls.Config{Certificates: []tls.Certificate{cert}})
defer tlsConn.Close()
err = tlsConn.Handshake()
require.NoError(t, err)
buf := make([]byte, 64)
n, err := tlsConn.Read(buf)
require.NoError(t, err)
if bytes.Equal(buf[:n], []byte("ping")) {
_, _ = tlsConn.Write([]byte("PONG"))
}
}()
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
dialerManager := NewDialerManager(nil)
dialerManager.Update(map[string]*dynamic.TCPServersTransport{
"test": {
TLS: &dynamic.TLSClientConfig{
ServerName: "example.com",
RootCAs: []types.FileOrContent{types.FileOrContent(LocalhostCert)},
InsecureSkipVerify: true,
},
ProxyProtocol: &dynamic.ProxyProtocol{
Version: test.version,
},
},
})
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{
ServersTransport: "test",
}, true)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
require.NoError(t, err)
defer conn.Close()
_, err = conn.Write([]byte("ping"))
require.NoError(t, err)
buf := make([]byte, 64)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, 4, n)
assert.Equal(t, "PONG", string(buf[:4]))
assert.Equal(t, test.version, version)
})
}
}
func TestProxyProtocolDisabled(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
defer backendListener.Close()
go func() {
conn, err := backendListener.Accept()
require.NoError(t, err)
defer conn.Close()
buf := make([]byte, 64)
n, err := conn.Read(buf)
require.NoError(t, err)
if bytes.Equal(buf[:n], []byte("ping")) {
_, _ = conn.Write([]byte("PONG"))
}
}()
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
// No proxy protocol configuration.
dialerManager := NewDialerManager(nil)
dialerManager.Update(map[string]*dynamic.TCPServersTransport{
"test": {},
})
dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{ServersTransport: "test"}, false)
require.NoError(t, err)
conn, err := dialer.Dial("tcp", ":"+port)
require.NoError(t, err)
_, err = conn.Write([]byte("ping"))
require.NoError(t, err)
buf := make([]byte, 64)
n, err := conn.Read(buf)
require.NoError(t, err)
assert.Equal(t, 4, n)
assert.Equal(t, "PONG", string(buf[:4]))
}
// fakeSpiffePKI simulates a SPIFFE aware PKI and allows generating multiple valid SVIDs.
type fakeSpiffePKI struct {
caPrivateKey *rsa.PrivateKey

View file

@ -2,34 +2,25 @@ package tcp
import (
"errors"
"fmt"
"io"
"net"
"syscall"
"time"
"github.com/pires/go-proxyproto"
"github.com/rs/zerolog/log"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
)
// Proxy forwards a TCP request to a TCP service.
type Proxy struct {
address string
proxyProtocol *dynamic.ProxyProtocol
dialer Dialer
address string
dialer Dialer
}
// NewProxy creates a new Proxy.
func NewProxy(address string, proxyProtocol *dynamic.ProxyProtocol, dialer Dialer) (*Proxy, error) {
if proxyProtocol != nil && (proxyProtocol.Version < 1 || proxyProtocol.Version > 2) {
return nil, fmt.Errorf("unknown proxyProtocol version: %d", proxyProtocol.Version)
}
func NewProxy(address string, dialer Dialer) (*Proxy, error) {
return &Proxy{
address: address,
proxyProtocol: proxyProtocol,
dialer: dialer,
address: address,
dialer: dialer,
}, nil
}
@ -53,14 +44,6 @@ func (p *Proxy) ServeTCP(conn WriteCloser) {
defer connBackend.Close()
errChan := make(chan error)
if p.proxyProtocol != nil && p.proxyProtocol.Version > 0 && p.proxyProtocol.Version < 3 {
header := proxyproto.HeaderProxyFromAddrs(byte(p.proxyProtocol.Version), conn.RemoteAddr(), conn.LocalAddr())
if _, err := header.WriteTo(connBackend); err != nil {
log.Error().Err(err).Msg("Error while writing TCP proxy protocol headers to backend connection")
return
}
}
go p.connCopy(conn, connBackend, errChan)
go p.connCopy(connBackend, conn, errChan)

View file

@ -2,59 +2,25 @@ package tcp
import (
"bytes"
"errors"
"io"
"net"
"testing"
"time"
"github.com/pires/go-proxyproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
)
func fakeRedis(t *testing.T, listener net.Listener) {
t.Helper()
for {
conn, err := listener.Accept()
require.NoError(t, err)
for {
withErr := false
buf := make([]byte, 64)
if _, err := conn.Read(buf); err != nil {
withErr = true
}
if string(buf[:4]) == "ping" {
time.Sleep(1 * time.Millisecond)
if _, err := conn.Write([]byte("PONG")); err != nil {
_ = conn.Close()
return
}
}
if withErr {
_ = conn.Close()
return
}
}
}
}
func TestCloseWrite(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
go fakeRedis(t, backendListener)
go fakeServer(t, backendListener)
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
dialer := tcpDialer{&net.Dialer{}, 10 * time.Millisecond}
dialer := tcpDialer{&net.Dialer{}, 10 * time.Millisecond, nil}
proxy, err := NewProxy(":"+port, nil, dialer)
proxy, err := NewProxy(":"+port, dialer)
require.NoError(t, err)
proxyListener, err := net.Listen("tcp", ":0")
@ -84,90 +50,37 @@ func TestCloseWrite(t *testing.T) {
buffer := bytes.NewBuffer(buf)
n, err := io.Copy(buffer, conn)
require.NoError(t, err)
require.Equal(t, int64(4), n)
require.Equal(t, "PONG", buffer.String())
}
func TestProxyProtocol(t *testing.T) {
testCases := []struct {
desc string
version int
}{
{
desc: "PROXY protocol v1",
version: 1,
},
{
desc: "PROXY protocol v2",
version: 2,
},
}
func fakeServer(t *testing.T, listener net.Listener) {
t.Helper()
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
for {
conn, err := listener.Accept()
require.NoError(t, err)
var version int
proxyBackendListener := proxyproto.Listener{
Listener: backendListener,
ValidateHeader: func(h *proxyproto.Header) error {
version = int(h.Version)
return nil
},
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
switch test.version {
case 1, 2:
return proxyproto.USE, nil
default:
return proxyproto.REQUIRE, errors.New("unsupported version")
}
},
for {
withErr := false
buf := make([]byte, 64)
if _, err := conn.Read(buf); err != nil {
withErr = true
}
defer proxyBackendListener.Close()
go fakeRedis(t, &proxyBackendListener)
_, port, err := net.SplitHostPort(proxyBackendListener.Addr().String())
require.NoError(t, err)
dialer := tcpDialer{&net.Dialer{}, 10 * time.Millisecond}
proxy, err := NewProxy(":"+port, &dynamic.ProxyProtocol{Version: test.version}, dialer)
require.NoError(t, err)
proxyListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
go func() {
for {
conn, err := proxyListener.Accept()
require.NoError(t, err)
proxy.ServeTCP(conn.(*net.TCPConn))
if string(buf[:4]) == "ping" {
time.Sleep(1 * time.Millisecond)
if _, err := conn.Write([]byte("PONG")); err != nil {
_ = conn.Close()
return
}
}()
}
_, port, err = net.SplitHostPort(proxyListener.Addr().String())
require.NoError(t, err)
conn, err := net.Dial("tcp", ":"+port)
require.NoError(t, err)
_, err = conn.Write([]byte("ping\n"))
require.NoError(t, err)
err = conn.(*net.TCPConn).CloseWrite()
require.NoError(t, err)
var buf []byte
buffer := bytes.NewBuffer(buf)
n, err := io.Copy(buffer, conn)
require.NoError(t, err)
assert.Equal(t, int64(4), n)
assert.Equal(t, "PONG", buffer.String())
assert.Equal(t, test.version, version)
})
if withErr {
_ = conn.Close()
return
}
}
}
}