Send proxy protocol header before TLS handshake
Co-authored-by: Kevin Pollet <pollet.kevin@gmail.com>
This commit is contained in:
parent
30b0666219
commit
f9fbcfbb42
28 changed files with 566 additions and 416 deletions
|
|
@ -9,11 +9,13 @@ import (
|
|||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spiffe/go-spiffe/v2/bundle/x509bundle"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffeid"
|
||||
"github.com/spiffe/go-spiffe/v2/spiffetls/tlsconfig"
|
||||
"github.com/spiffe/go-spiffe/v2/svid/x509svid"
|
||||
ptypes "github.com/traefik/paerser/types"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
traefiktls "github.com/traefik/traefik/v3/pkg/tls"
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
|
|
@ -27,15 +29,54 @@ type Dialer interface {
|
|||
}
|
||||
|
||||
type tcpDialer struct {
|
||||
proxy.Dialer
|
||||
dialer *net.Dialer
|
||||
terminationDelay time.Duration
|
||||
proxyProtocol *dynamic.ProxyProtocol
|
||||
}
|
||||
|
||||
func (d tcpDialer) TerminationDelay() time.Duration {
|
||||
return d.terminationDelay
|
||||
}
|
||||
|
||||
// SpiffeX509Source allows to retrieve a x509 SVID and bundle.
|
||||
func (d tcpDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
conn, err := d.dialer.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if d.proxyProtocol != nil && d.proxyProtocol.Version > 0 && d.proxyProtocol.Version < 3 {
|
||||
header := proxyproto.HeaderProxyFromAddrs(byte(d.proxyProtocol.Version), conn.RemoteAddr(), conn.LocalAddr())
|
||||
if _, err := header.WriteTo(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("writing PROXY Protocol header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
type tcpTLSDialer struct {
|
||||
tcpDialer
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
func (d tcpTLSDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
conn, err := d.tcpDialer.Dial(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Now perform TLS handshake on the connection
|
||||
tlsConn := tls.Client(conn, d.tlsConfig)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("TLS handshake failed: %w", err)
|
||||
}
|
||||
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
// SpiffeX509Source allows retrieving a x509 SVID and bundle.
|
||||
type SpiffeX509Source interface {
|
||||
x509svid.Source
|
||||
x509bundle.Source
|
||||
|
|
@ -43,118 +84,106 @@ type SpiffeX509Source interface {
|
|||
|
||||
// DialerManager handles dialer for the reverse proxy.
|
||||
type DialerManager struct {
|
||||
rtLock sync.RWMutex
|
||||
dialers map[string]Dialer
|
||||
dialersTLS map[string]Dialer
|
||||
spiffeX509Source SpiffeX509Source
|
||||
serversTransportsMu sync.RWMutex
|
||||
serversTransports map[string]*dynamic.TCPServersTransport
|
||||
spiffeX509Source SpiffeX509Source
|
||||
}
|
||||
|
||||
// NewDialerManager creates a new DialerManager.
|
||||
func NewDialerManager(spiffeX509Source SpiffeX509Source) *DialerManager {
|
||||
return &DialerManager{
|
||||
dialers: make(map[string]Dialer),
|
||||
dialersTLS: make(map[string]Dialer),
|
||||
spiffeX509Source: spiffeX509Source,
|
||||
serversTransports: make(map[string]*dynamic.TCPServersTransport),
|
||||
spiffeX509Source: spiffeX509Source,
|
||||
}
|
||||
}
|
||||
|
||||
// Update updates the dialers configurations.
|
||||
// Update updates the TCP serversTransport configurations.
|
||||
func (d *DialerManager) Update(configs map[string]*dynamic.TCPServersTransport) {
|
||||
d.rtLock.Lock()
|
||||
defer d.rtLock.Unlock()
|
||||
d.serversTransportsMu.Lock()
|
||||
defer d.serversTransportsMu.Unlock()
|
||||
|
||||
d.dialers = make(map[string]Dialer)
|
||||
d.dialersTLS = make(map[string]Dialer)
|
||||
for configName, config := range configs {
|
||||
if err := d.createDialers(configName, config); err != nil {
|
||||
log.Debug().
|
||||
Str("dialer", configName).
|
||||
Err(err).
|
||||
Msg("Create TCP Dialer")
|
||||
}
|
||||
}
|
||||
d.serversTransports = configs
|
||||
}
|
||||
|
||||
// Get gets a dialer by name.
|
||||
func (d *DialerManager) Get(name string, tls bool) (Dialer, error) {
|
||||
if len(name) == 0 {
|
||||
name = "default@internal"
|
||||
// Build builds a dialer by name.
|
||||
func (d *DialerManager) Build(config *dynamic.TCPServersLoadBalancer, isTLS bool) (Dialer, error) {
|
||||
name := "default@internal"
|
||||
if config.ServersTransport != "" {
|
||||
name = config.ServersTransport
|
||||
}
|
||||
|
||||
d.rtLock.RLock()
|
||||
defer d.rtLock.RUnlock()
|
||||
|
||||
if tls {
|
||||
if rt, ok := d.dialersTLS[name]; ok {
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("TCP dialer not found %s", name)
|
||||
var st *dynamic.TCPServersTransport
|
||||
d.serversTransportsMu.RLock()
|
||||
st, ok := d.serversTransports[name]
|
||||
d.serversTransportsMu.RUnlock()
|
||||
if !ok || st == nil {
|
||||
return nil, fmt.Errorf("no transport configuration found for %q", name)
|
||||
}
|
||||
|
||||
if rt, ok := d.dialers[name]; ok {
|
||||
return rt, nil
|
||||
// Handle TerminationDelay and ProxyProtocol deprecated options.
|
||||
var terminationDelay ptypes.Duration
|
||||
if config.TerminationDelay != nil {
|
||||
terminationDelay = ptypes.Duration(*config.TerminationDelay)
|
||||
}
|
||||
proxyProtocol := config.ProxyProtocol
|
||||
|
||||
if config.ServersTransport != "" {
|
||||
terminationDelay = st.TerminationDelay
|
||||
proxyProtocol = st.ProxyProtocol
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("TCP dialer not found %s", name)
|
||||
}
|
||||
|
||||
// createDialers creates the dialers according to the TCPServersTransport configuration.
|
||||
func (d *DialerManager) createDialers(name string, cfg *dynamic.TCPServersTransport) error {
|
||||
if cfg == nil {
|
||||
return errors.New("no transport configuration given")
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: time.Duration(cfg.DialTimeout),
|
||||
KeepAlive: time.Duration(cfg.DialKeepAlive),
|
||||
if proxyProtocol != nil && (proxyProtocol.Version < 1 || proxyProtocol.Version > 2) {
|
||||
return nil, fmt.Errorf("unknown proxyProtocol version: %d", proxyProtocol.Version)
|
||||
}
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
|
||||
if cfg.TLS != nil {
|
||||
if cfg.TLS.Spiffe != nil {
|
||||
if st.TLS != nil {
|
||||
if st.TLS.Spiffe != nil {
|
||||
if d.spiffeX509Source == nil {
|
||||
return errors.New("SPIFFE is enabled for this transport, but not configured")
|
||||
return nil, errors.New("SPIFFE is enabled for this transport, but not configured")
|
||||
}
|
||||
|
||||
authorizer, err := buildSpiffeAuthorizer(cfg.TLS.Spiffe)
|
||||
authorizer, err := buildSpiffeAuthorizer(st.TLS.Spiffe)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to build SPIFFE authorizer: %w", err)
|
||||
return nil, fmt.Errorf("unable to build SPIFFE authorizer: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig = tlsconfig.MTLSClientConfig(d.spiffeX509Source, d.spiffeX509Source, authorizer)
|
||||
}
|
||||
|
||||
if cfg.TLS.InsecureSkipVerify || len(cfg.TLS.RootCAs) > 0 || len(cfg.TLS.ServerName) > 0 || len(cfg.TLS.Certificates) > 0 || cfg.TLS.PeerCertURI != "" {
|
||||
if st.TLS.InsecureSkipVerify || len(st.TLS.RootCAs) > 0 || len(st.TLS.ServerName) > 0 || len(st.TLS.Certificates) > 0 || st.TLS.PeerCertURI != "" {
|
||||
if tlsConfig != nil {
|
||||
return errors.New("TLS and SPIFFE configuration cannot be defined at the same time")
|
||||
return nil, errors.New("TLS and SPIFFE configuration cannot be defined at the same time")
|
||||
}
|
||||
|
||||
tlsConfig = &tls.Config{
|
||||
ServerName: cfg.TLS.ServerName,
|
||||
InsecureSkipVerify: cfg.TLS.InsecureSkipVerify,
|
||||
RootCAs: createRootCACertPool(cfg.TLS.RootCAs),
|
||||
Certificates: cfg.TLS.Certificates.GetCertificates(),
|
||||
ServerName: st.TLS.ServerName,
|
||||
InsecureSkipVerify: st.TLS.InsecureSkipVerify,
|
||||
RootCAs: createRootCACertPool(st.TLS.RootCAs),
|
||||
Certificates: st.TLS.Certificates.GetCertificates(),
|
||||
}
|
||||
|
||||
if cfg.TLS.PeerCertURI != "" {
|
||||
if st.TLS.PeerCertURI != "" {
|
||||
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
|
||||
return traefiktls.VerifyPeerCertificate(cfg.TLS.PeerCertURI, tlsConfig, rawCerts)
|
||||
return traefiktls.VerifyPeerCertificate(st.TLS.PeerCertURI, tlsConfig, rawCerts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tlsDialer := &tls.Dialer{
|
||||
NetDialer: dialer,
|
||||
Config: tlsConfig,
|
||||
dialer := tcpDialer{
|
||||
dialer: &net.Dialer{
|
||||
Timeout: time.Duration(st.DialTimeout),
|
||||
KeepAlive: time.Duration(st.DialKeepAlive),
|
||||
},
|
||||
terminationDelay: time.Duration(terminationDelay),
|
||||
proxyProtocol: proxyProtocol,
|
||||
}
|
||||
|
||||
d.dialers[name] = tcpDialer{dialer, time.Duration(cfg.TerminationDelay)}
|
||||
d.dialersTLS[name] = tcpDialer{tlsDialer, time.Duration(cfg.TerminationDelay)}
|
||||
|
||||
return nil
|
||||
if !isTLS {
|
||||
return dialer, nil
|
||||
}
|
||||
return tcpTLSDialer{dialer, tlsConfig}, nil
|
||||
}
|
||||
|
||||
func createRootCACertPool(rootCAs []types.FileOrContent) *x509.CertPool {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue