Allow handling ACME challenges with custom routers

This commit is contained in:
Romain 2024-09-13 15:54:04 +02:00 committed by GitHub
parent d547b943df
commit 0cf2032c15
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 142 additions and 24 deletions

View file

@ -21,6 +21,8 @@ const defaultBufSize = 4096
// Router is a TCP router.
type Router struct {
acmeTLSPassthrough bool
// Contains TCP routes.
muxerTCP tcpmuxer.Muxer
// Contains TCP TLS routes.
@ -148,7 +150,7 @@ func (r *Router) ServeTCP(conn tcp.WriteCloser) {
}
// Handling ACME-TLS/1 challenges.
if slices.Contains(hello.protos, tlsalpn01.ACMETLS1Protocol) {
if !r.acmeTLSPassthrough && slices.Contains(hello.protos, tlsalpn01.ACMETLS1Protocol) {
r.acmeTLSALPNHandler().ServeTCP(r.GetConn(conn, hello.peeked))
return
}
@ -303,6 +305,10 @@ func (r *Router) SetHTTPSHandler(handler http.Handler, config *tls.Config) {
r.httpsTLSConfig = config
}
func (r *Router) EnableACMETLSPassthrough() {
r.acmeTLSPassthrough = true
}
// Conn is a connection proxy that handles Peeked bytes.
type Conn struct {
// Peeked are the bytes that have been read from Conn for the purposes of route matching,

View file

@ -209,9 +209,10 @@ func Test_Routing(t *testing.T) {
}
testCases := []struct {
desc string
routers []applyRouter
checks []checkCase
desc string
routers []applyRouter
checks []checkCase
allowACMETLSPassthrough bool
}{
{
desc: "No routers",
@ -268,6 +269,18 @@ func Test_Routing(t *testing.T) {
},
},
},
{
desc: "TCP TLS passthrough catches ACME TLS",
allowACMETLSPassthrough: true,
routers: []applyRouter{routerTCPTLSCatchAllPassthrough},
checks: []checkCase{
{
desc: "ACME TLS Challenge",
checkRouter: checkACMETLS,
expectedError: "tls: first record does not look like a TLS handshake",
},
},
},
{
desc: "Single TCP CatchAll router",
routers: []applyRouter{routerTCPCatchAll},
@ -578,6 +591,10 @@ func Test_Routing(t *testing.T) {
router, err := manager.buildEntryPointHandler(context.Background(), dynConf.TCPRouters, dynConf.Routers, nil, nil)
require.NoError(t, err)
if test.allowACMETLSPassthrough {
router.EnableACMETLSPassthrough()
}
epListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@ -699,7 +716,7 @@ func routerTCPTLSCatchAll(conf *runtime.Configuration) {
}
}
// routerTCPTLSCatchAllPassthrough a TCP TLS CatchAll Passthrough - HostSNI(`*`) router with TLS 1.0 config.
// routerTCPTLSCatchAllPassthrough a TCP TLS CatchAll Passthrough - HostSNI(`*`) router with TLS 1.2 config.
func routerTCPTLSCatchAllPassthrough(conf *runtime.Configuration) {
conf.TCPRouters["tcp-tls-catchall-passthrough"] = &runtime.TCPRouterInfo{
TCPRouter: &dynamic.TCPRouter{

View file

@ -21,25 +21,37 @@ import (
// RouterFactory the factory of TCP/UDP routers.
type RouterFactory struct {
entryPointsTCP []string
entryPointsUDP []string
entryPointsTCP []string
entryPointsUDP []string
allowACMEByPass map[string]bool
managerFactory *service.ManagerFactory
managerFactory *service.ManagerFactory
metricsRegistry metrics.Registry
pluginBuilder middleware.PluginsBuilder
chainBuilder *middleware.ChainBuilder
tlsManager *tls.Manager
chainBuilder *middleware.ChainBuilder
tlsManager *tls.Manager
}
// NewRouterFactory creates a new RouterFactory.
func NewRouterFactory(staticConfiguration static.Configuration, managerFactory *service.ManagerFactory, tlsManager *tls.Manager,
chainBuilder *middleware.ChainBuilder, pluginBuilder middleware.PluginsBuilder, metricsRegistry metrics.Registry,
) *RouterFactory {
handlesTLSChallenge := false
for _, resolver := range staticConfiguration.CertificatesResolvers {
if resolver.ACME.TLSChallenge != nil {
handlesTLSChallenge = true
break
}
}
allowACMEByPass := map[string]bool{}
var entryPointsTCP, entryPointsUDP []string
for name, cfg := range staticConfiguration.EntryPoints {
protocol, err := cfg.GetProtocol()
for name, ep := range staticConfiguration.EntryPoints {
allowACMEByPass[name] = ep.AllowACMEByPass || !handlesTLSChallenge
protocol, err := ep.GetProtocol()
if err != nil {
// Should never happen because Traefik should not start if protocol is invalid.
log.WithoutContext().Errorf("Invalid protocol: %v", err)
@ -60,6 +72,7 @@ func NewRouterFactory(staticConfiguration static.Configuration, managerFactory *
tlsManager: tlsManager,
chainBuilder: chainBuilder,
pluginBuilder: pluginBuilder,
allowACMEByPass: allowACMEByPass,
}
}
@ -87,6 +100,12 @@ func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string
rtTCPManager := tcprouter.NewManager(rtConf, svcTCPManager, middlewaresTCPBuilder, handlersNonTLS, handlersTLS, f.tlsManager)
routersTCP := rtTCPManager.BuildHandlers(ctx, f.entryPointsTCP)
for ep, r := range routersTCP {
if allowACMEByPass, ok := f.allowACMEByPass[ep]; ok && allowACMEByPass {
r.EnableACMETLSPassthrough()
}
}
// UDP
svcUDPManager := udp.NewManager(rtConf)
rtUDPManager := udprouter.NewManager(rtConf, svcUDPManager)

View file

@ -172,7 +172,10 @@ func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint, hos
return nil, fmt.Errorf("error preparing server: %w", err)
}
rt := &tcprouter.Router{}
rt, err := tcprouter.NewRouter()
if err != nil {
return nil, fmt.Errorf("error preparing tcp router: %w", err)
}
reqDecorator := requestdecorator.New(hostResolverConfig)

View file

@ -20,7 +20,9 @@ import (
)
func TestShutdownHijacked(t *testing.T) {
router := &tcprouter.Router{}
router, err := tcprouter.NewRouter()
require.NoError(t, err)
router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
conn, _, err := rw.(http.Hijacker).Hijack()
require.NoError(t, err)
@ -34,7 +36,9 @@ func TestShutdownHijacked(t *testing.T) {
}
func TestShutdownHTTP(t *testing.T) {
router := &tcprouter.Router{}
router, err := tcprouter.NewRouter()
require.NoError(t, err)
router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
time.Sleep(time.Second)
@ -167,7 +171,9 @@ func TestReadTimeoutWithoutFirstByte(t *testing.T) {
}, nil)
require.NoError(t, err)
router := &tcprouter.Router{}
router, err := tcprouter.NewRouter()
require.NoError(t, err)
router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
@ -204,7 +210,9 @@ func TestReadTimeoutWithFirstByte(t *testing.T) {
}, nil)
require.NoError(t, err)
router := &tcprouter.Router{}
router, err := tcprouter.NewRouter()
require.NoError(t, err)
router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
@ -244,7 +252,9 @@ func TestKeepAliveMaxRequests(t *testing.T) {
}, nil)
require.NoError(t, err)
router := &tcprouter.Router{}
router, err := tcprouter.NewRouter()
require.NoError(t, err)
router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
@ -290,7 +300,9 @@ func TestKeepAliveMaxTime(t *testing.T) {
}, nil)
require.NoError(t, err)
router := &tcprouter.Router{}
router, err := tcprouter.NewRouter()
require.NoError(t, err)
router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))