added support for tcp proxyProtocol v1&v2 to backend

This commit is contained in:
Matthias Schneider 2020-11-17 13:04:04 +01:00 committed by GitHub
parent 520fcf82ae
commit 84b125bdde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 388 additions and 83 deletions

View file

@ -1,10 +1,13 @@
package tcp
import (
"fmt"
"io"
"net"
"time"
"github.com/pires/go-proxyproto"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/log"
)
@ -12,16 +15,21 @@ import (
type Proxy struct {
target *net.TCPAddr
terminationDelay time.Duration
proxyProtocol *dynamic.ProxyProtocol
}
// NewProxy creates a new Proxy.
func NewProxy(address string, terminationDelay time.Duration) (*Proxy, error) {
func NewProxy(address string, terminationDelay time.Duration, proxyProtocol *dynamic.ProxyProtocol) (*Proxy, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return nil, err
}
return &Proxy{target: tcpAddr, terminationDelay: terminationDelay}, nil
if proxyProtocol != nil && (proxyProtocol.Version < 1 || proxyProtocol.Version > 2) {
return nil, fmt.Errorf("unknown proxyProtocol version: %d", proxyProtocol.Version)
}
return &Proxy{target: tcpAddr, terminationDelay: terminationDelay, proxyProtocol: proxyProtocol}, nil
}
// ServeTCP forwards the connection to a service.
@ -39,8 +47,16 @@ func (p *Proxy) ServeTCP(conn WriteCloser) {
// maybe not needed, but just in case
defer connBackend.Close()
errChan := make(chan error)
if p.proxyProtocol != nil && p.proxyProtocol.Version > 0 && p.proxyProtocol.Version < 3 {
header := proxyproto.HeaderProxyFromAddrs(byte(p.proxyProtocol.Version), conn.RemoteAddr(), conn.LocalAddr())
if _, err := header.WriteTo(connBackend); err != nil {
log.WithoutContext().Errorf("Error while writing proxy protocol headers to backend connection: %v", err)
return
}
}
go p.connCopy(conn, connBackend, errChan)
go p.connCopy(connBackend, conn, errChan)

View file

@ -2,13 +2,17 @@ package tcp
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
"github.com/pires/go-proxyproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
)
func fakeRedis(t *testing.T, listener net.Listener) {
@ -16,6 +20,7 @@ func fakeRedis(t *testing.T, listener net.Listener) {
conn, err := listener.Accept()
fmt.Println("Accept on server")
require.NoError(t, err)
for {
withErr := false
buf := make([]byte, 64)
@ -26,12 +31,13 @@ func fakeRedis(t *testing.T, listener net.Listener) {
if string(buf[:4]) == "ping" {
time.Sleep(1 * time.Millisecond)
if _, err := conn.Write([]byte("PONG")); err != nil {
conn.Close()
_ = conn.Close()
return
}
}
if withErr {
conn.Close()
_ = conn.Close()
return
}
}
@ -46,7 +52,7 @@ func TestCloseWrite(t *testing.T) {
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
proxy, err := NewProxy(":"+port, 10*time.Millisecond)
proxy, err := NewProxy(":"+port, 10*time.Millisecond, nil)
require.NoError(t, err)
proxyListener, err := net.Listen("tcp", ":0")
@ -79,3 +85,87 @@ func TestCloseWrite(t *testing.T) {
require.Equal(t, int64(4), n)
require.Equal(t, "PONG", buffer.String())
}
func TestProxyProtocol(t *testing.T) {
testCases := []struct {
desc string
version int
}{
{
desc: "PROXY protocol v1",
version: 1,
},
{
desc: "PROXY protocol v2",
version: 2,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
var version int
proxyBackendListener := proxyproto.Listener{
Listener: backendListener,
ValidateHeader: func(h *proxyproto.Header) error {
version = int(h.Version)
return nil
},
Policy: func(upstream net.Addr) (proxyproto.Policy, error) {
switch test.version {
case 1, 2:
return proxyproto.USE, nil
default:
return proxyproto.REQUIRE, errors.New("unsupported version")
}
},
}
defer proxyBackendListener.Close()
go fakeRedis(t, &proxyBackendListener)
_, port, err := net.SplitHostPort(proxyBackendListener.Addr().String())
require.NoError(t, err)
proxy, err := NewProxy(":"+port, 10*time.Millisecond, &dynamic.ProxyProtocol{Version: test.version})
require.NoError(t, err)
proxyListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
go func() {
for {
conn, err := proxyListener.Accept()
require.NoError(t, err)
proxy.ServeTCP(conn.(*net.TCPConn))
}
}()
_, port, err = net.SplitHostPort(proxyListener.Addr().String())
require.NoError(t, err)
conn, err := net.Dial("tcp", ":"+port)
require.NoError(t, err)
_, err = conn.Write([]byte("ping\n"))
require.NoError(t, err)
err = conn.(*net.TCPConn).CloseWrite()
require.NoError(t, err)
var buf []byte
buffer := bytes.NewBuffer(buf)
n, err := io.Copy(buffer, conn)
require.NoError(t, err)
assert.Equal(t, int64(4), n)
assert.Equal(t, "PONG", buffer.String())
assert.Equal(t, test.version, version)
})
}
}