Improve TLS Handshake
This commit is contained in:
parent
2303301d38
commit
689f120410
20 changed files with 819 additions and 60 deletions
|
@ -15,7 +15,6 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -79,7 +78,7 @@ type serverEntryPoint struct {
|
|||
httpServer *h2c.Server
|
||||
listener net.Listener
|
||||
httpRouter *middlewares.HandlerSwitcher
|
||||
certs *safe.Safe
|
||||
certs *traefiktls.CertificateStore
|
||||
onDemandListener func(string) (*tls.Certificate, error)
|
||||
tlsALPNGetter func(string) (*tls.Certificate, error)
|
||||
}
|
||||
|
@ -276,19 +275,13 @@ func (s *Server) AddListener(listener func(types.Configuration)) {
|
|||
|
||||
// getCertificate allows to customize tlsConfig.GetCertificate behaviour to get the certificates inserted dynamically
|
||||
func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
domainToCheck := types.CanonicalDomain(clientHello.ServerName)
|
||||
|
||||
if s.certs.Get() != nil {
|
||||
for domains, cert := range s.certs.Get().(map[string]*tls.Certificate) {
|
||||
for _, certDomain := range strings.Split(domains, ",") {
|
||||
if types.MatchDomain(domainToCheck, certDomain) {
|
||||
return cert, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
log.Debugf("No certificate provided dynamically can check the domain %q, a per default certificate will be used.", domainToCheck)
|
||||
bestCertificate := s.certs.GetBestCertificate(clientHello)
|
||||
if bestCertificate != nil {
|
||||
return bestCertificate, nil
|
||||
}
|
||||
|
||||
domainToCheck := types.CanonicalDomain(clientHello.ServerName)
|
||||
|
||||
if s.tlsALPNGetter != nil {
|
||||
cert, err := s.tlsALPNGetter(domainToCheck)
|
||||
if err != nil {
|
||||
|
@ -300,11 +293,17 @@ func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tl
|
|||
}
|
||||
}
|
||||
|
||||
if s.onDemandListener != nil {
|
||||
if s.onDemandListener != nil && len(domainToCheck) > 0 {
|
||||
// Only check for an onDemandCert if there is a domain name
|
||||
return s.onDemandListener(domainToCheck)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
if s.certs.SniStrict {
|
||||
return nil, fmt.Errorf("strict SNI enabled - No certificate found for domain: %q, closing connection", domainToCheck)
|
||||
}
|
||||
|
||||
log.Debugf("Serving default cert for request: %q", domainToCheck)
|
||||
return s.certs.DefaultCertificate, nil
|
||||
}
|
||||
|
||||
func (s *Server) startProvider() {
|
||||
|
@ -335,7 +334,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
|
|||
return nil, err
|
||||
}
|
||||
|
||||
s.serverEntryPoints[entryPointName].certs.Set(make(map[string]*tls.Certificate))
|
||||
s.serverEntryPoints[entryPointName].certs.DynamicCerts.Set(make(map[string]*tls.Certificate))
|
||||
|
||||
// ensure http2 enabled
|
||||
config.NextProtos = []string{"h2", "http/1.1", acme.ACMETLS1Protocol}
|
||||
|
@ -345,6 +344,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
|
|||
tlsOption.ClientCA.Files = tlsOption.ClientCAFiles
|
||||
tlsOption.ClientCA.Optional = false
|
||||
}
|
||||
|
||||
if len(tlsOption.ClientCA.Files) > 0 {
|
||||
pool := x509.NewCertPool()
|
||||
for _, caFile := range tlsOption.ClientCA.Files {
|
||||
|
@ -376,7 +376,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
|
|||
return false
|
||||
}
|
||||
|
||||
err := s.globalConfiguration.ACME.CreateClusterConfig(s.leadership, config, s.serverEntryPoints[entryPointName].certs, checkOnDemandDomain)
|
||||
err := s.globalConfiguration.ACME.CreateClusterConfig(s.leadership, config, s.serverEntryPoints[entryPointName].certs.DynamicCerts, checkOnDemandDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -385,17 +385,16 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
|
|||
config.GetCertificate = s.serverEntryPoints[entryPointName].getCertificate
|
||||
}
|
||||
|
||||
if len(config.Certificates) == 0 {
|
||||
return nil, fmt.Errorf("no certificates found for TLS entrypoint %s", entryPointName)
|
||||
if len(config.Certificates) != 0 {
|
||||
certMap := s.buildNameOrIPToCertificate(config.Certificates)
|
||||
|
||||
if s.entryPoints[entryPointName].CertificateStore != nil {
|
||||
s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(certMap)
|
||||
}
|
||||
}
|
||||
|
||||
// BuildNameToCertificate parses the CommonName and SubjectAlternateName fields
|
||||
// in each certificate and populates the config.NameToCertificate map.
|
||||
config.BuildNameToCertificate()
|
||||
|
||||
if s.entryPoints[entryPointName].CertificateStore != nil {
|
||||
s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(config.NameToCertificate)
|
||||
}
|
||||
// Remove certs from the TLS config object
|
||||
config.Certificates = []tls.Certificate{}
|
||||
|
||||
// Set the minimum TLS version if set in the config TOML
|
||||
if minConst, exists := traefiktls.MinVersion[s.entryPoints[entryPointName].Configuration.TLS.MinVersion]; exists {
|
||||
|
@ -593,3 +592,24 @@ func stopMetricsClients() {
|
|||
metrics.StopStatsd()
|
||||
metrics.StopInfluxDB()
|
||||
}
|
||||
|
||||
func (s *Server) buildNameOrIPToCertificate(certs []tls.Certificate) map[string]*tls.Certificate {
|
||||
certMap := make(map[string]*tls.Certificate)
|
||||
for i := range certs {
|
||||
cert := &certs[i]
|
||||
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if len(x509Cert.Subject.CommonName) > 0 {
|
||||
certMap[x509Cert.Subject.CommonName] = cert
|
||||
}
|
||||
for _, san := range x509Cert.DNSNames {
|
||||
certMap[san] = cert
|
||||
}
|
||||
for _, ipSan := range x509Cert.IPAddresses {
|
||||
certMap[ipSan.String()] = cert
|
||||
}
|
||||
}
|
||||
return certMap
|
||||
}
|
||||
|
|
|
@ -19,8 +19,8 @@ import (
|
|||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/pipelining"
|
||||
"github.com/containous/traefik/rules"
|
||||
"github.com/containous/traefik/safe"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/tls/generate"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/eapache/channels"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -55,11 +55,12 @@ func (s *Server) loadConfiguration(configMsg types.ConfigMessage) {
|
|||
s.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler())
|
||||
|
||||
if s.entryPoints[newServerEntryPointName].Configuration.TLS == nil {
|
||||
if newServerEntryPoint.certs.Get() != nil {
|
||||
if newServerEntryPoint.certs.ContainsCertificates() {
|
||||
log.Debugf("Certificates not added to non-TLS entryPoint %s.", newServerEntryPointName)
|
||||
}
|
||||
} else {
|
||||
s.serverEntryPoints[newServerEntryPointName].certs.Set(newServerEntryPoint.certs.Get())
|
||||
s.serverEntryPoints[newServerEntryPointName].certs.DynamicCerts.Set(newServerEntryPoint.certs.DynamicCerts.Get())
|
||||
s.serverEntryPoints[newServerEntryPointName].certs.ResetCache()
|
||||
}
|
||||
log.Infof("Server configuration reloaded on %s", s.serverEntryPoints[newServerEntryPointName].httpServer.Addr)
|
||||
}
|
||||
|
@ -123,7 +124,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
|
|||
for serverEntryPointName, serverEntryPoint := range serverEntryPoints {
|
||||
serverEntryPoint.httpRouter.GetHandler().SortRoutes()
|
||||
if _, exists := entryPointsCertificates[serverEntryPointName]; exists {
|
||||
serverEntryPoint.certs.Set(entryPointsCertificates[serverEntryPointName])
|
||||
serverEntryPoint.certs.DynamicCerts.Set(entryPointsCertificates[serverEntryPointName])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -559,10 +560,33 @@ func (s *Server) buildServerEntryPoints() map[string]*serverEntryPoint {
|
|||
onDemandListener: entryPoint.OnDemandListener,
|
||||
tlsALPNGetter: entryPoint.TLSALPNGetter,
|
||||
}
|
||||
|
||||
if entryPoint.CertificateStore != nil {
|
||||
serverEntryPoints[entryPointName].certs = entryPoint.CertificateStore.DynamicCerts
|
||||
serverEntryPoints[entryPointName].certs = entryPoint.CertificateStore
|
||||
} else {
|
||||
serverEntryPoints[entryPointName].certs = &safe.Safe{}
|
||||
serverEntryPoints[entryPointName].certs = traefiktls.NewCertificateStore()
|
||||
}
|
||||
|
||||
if entryPoint.Configuration.TLS != nil {
|
||||
serverEntryPoints[entryPointName].certs.SniStrict = entryPoint.Configuration.TLS.SniStrict
|
||||
|
||||
if entryPoint.Configuration.TLS.DefaultCertificate != nil {
|
||||
cert, err := tls.LoadX509KeyPair(entryPoint.Configuration.TLS.DefaultCertificate.CertFile.String(), entryPoint.Configuration.TLS.DefaultCertificate.KeyFile.String())
|
||||
if err != nil {
|
||||
}
|
||||
serverEntryPoints[entryPointName].certs.DefaultCertificate = &cert
|
||||
} else {
|
||||
cert, err := generate.DefaultCertificate()
|
||||
if err != nil {
|
||||
}
|
||||
serverEntryPoints[entryPointName].certs.DefaultCertificate = cert
|
||||
}
|
||||
if len(entryPoint.Configuration.TLS.Certificates) > 0 {
|
||||
config, _ := entryPoint.Configuration.TLS.Certificates.CreateTLSConfig(entryPointName)
|
||||
certMap := s.buildNameOrIPToCertificate(config.Certificates)
|
||||
serverEntryPoints[entryPointName].certs.StaticCerts.Set(certMap)
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
return serverEntryPoints
|
||||
|
|
|
@ -215,7 +215,7 @@ func TestServerLoadCertificateWithDefaultEntryPoint(t *testing.T) {
|
|||
srv := NewServer(globalConfig, nil, entryPoints)
|
||||
if mapEntryPoints, err := srv.loadConfig(dynamicConfigs, globalConfig); err != nil {
|
||||
t.Fatalf("got error: %s", err)
|
||||
} else if mapEntryPoints["https"].certs.Get() == nil {
|
||||
} else if !mapEntryPoints["https"].certs.ContainsCertificates() {
|
||||
t.Fatal("got error: https entryPoint must have TLS certificates.")
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue