Improve TLS Handshake

This commit is contained in:
Daniel Tomcej 2018-07-06 02:30:03 -06:00 committed by Traefiker Bot
parent 2303301d38
commit 689f120410
20 changed files with 819 additions and 60 deletions

View file

@ -15,7 +15,6 @@ import (
"os"
"os/signal"
"reflect"
"strings"
"sync"
"time"
@ -79,7 +78,7 @@ type serverEntryPoint struct {
httpServer *h2c.Server
listener net.Listener
httpRouter *middlewares.HandlerSwitcher
certs *safe.Safe
certs *traefiktls.CertificateStore
onDemandListener func(string) (*tls.Certificate, error)
tlsALPNGetter func(string) (*tls.Certificate, error)
}
@ -276,19 +275,13 @@ func (s *Server) AddListener(listener func(types.Configuration)) {
// getCertificate allows to customize tlsConfig.GetCertificate behaviour to get the certificates inserted dynamically
func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
domainToCheck := types.CanonicalDomain(clientHello.ServerName)
if s.certs.Get() != nil {
for domains, cert := range s.certs.Get().(map[string]*tls.Certificate) {
for _, certDomain := range strings.Split(domains, ",") {
if types.MatchDomain(domainToCheck, certDomain) {
return cert, nil
}
}
}
log.Debugf("No certificate provided dynamically can check the domain %q, a per default certificate will be used.", domainToCheck)
bestCertificate := s.certs.GetBestCertificate(clientHello)
if bestCertificate != nil {
return bestCertificate, nil
}
domainToCheck := types.CanonicalDomain(clientHello.ServerName)
if s.tlsALPNGetter != nil {
cert, err := s.tlsALPNGetter(domainToCheck)
if err != nil {
@ -300,11 +293,17 @@ func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tl
}
}
if s.onDemandListener != nil {
if s.onDemandListener != nil && len(domainToCheck) > 0 {
// Only check for an onDemandCert if there is a domain name
return s.onDemandListener(domainToCheck)
}
return nil, nil
if s.certs.SniStrict {
return nil, fmt.Errorf("strict SNI enabled - No certificate found for domain: %q, closing connection", domainToCheck)
}
log.Debugf("Serving default cert for request: %q", domainToCheck)
return s.certs.DefaultCertificate, nil
}
func (s *Server) startProvider() {
@ -335,7 +334,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
return nil, err
}
s.serverEntryPoints[entryPointName].certs.Set(make(map[string]*tls.Certificate))
s.serverEntryPoints[entryPointName].certs.DynamicCerts.Set(make(map[string]*tls.Certificate))
// ensure http2 enabled
config.NextProtos = []string{"h2", "http/1.1", acme.ACMETLS1Protocol}
@ -345,6 +344,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
tlsOption.ClientCA.Files = tlsOption.ClientCAFiles
tlsOption.ClientCA.Optional = false
}
if len(tlsOption.ClientCA.Files) > 0 {
pool := x509.NewCertPool()
for _, caFile := range tlsOption.ClientCA.Files {
@ -376,7 +376,7 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
return false
}
err := s.globalConfiguration.ACME.CreateClusterConfig(s.leadership, config, s.serverEntryPoints[entryPointName].certs, checkOnDemandDomain)
err := s.globalConfiguration.ACME.CreateClusterConfig(s.leadership, config, s.serverEntryPoints[entryPointName].certs.DynamicCerts, checkOnDemandDomain)
if err != nil {
return nil, err
}
@ -385,17 +385,16 @@ func (s *Server) createTLSConfig(entryPointName string, tlsOption *traefiktls.TL
config.GetCertificate = s.serverEntryPoints[entryPointName].getCertificate
}
if len(config.Certificates) == 0 {
return nil, fmt.Errorf("no certificates found for TLS entrypoint %s", entryPointName)
if len(config.Certificates) != 0 {
certMap := s.buildNameOrIPToCertificate(config.Certificates)
if s.entryPoints[entryPointName].CertificateStore != nil {
s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(certMap)
}
}
// BuildNameToCertificate parses the CommonName and SubjectAlternateName fields
// in each certificate and populates the config.NameToCertificate map.
config.BuildNameToCertificate()
if s.entryPoints[entryPointName].CertificateStore != nil {
s.entryPoints[entryPointName].CertificateStore.StaticCerts.Set(config.NameToCertificate)
}
// Remove certs from the TLS config object
config.Certificates = []tls.Certificate{}
// Set the minimum TLS version if set in the config TOML
if minConst, exists := traefiktls.MinVersion[s.entryPoints[entryPointName].Configuration.TLS.MinVersion]; exists {
@ -593,3 +592,24 @@ func stopMetricsClients() {
metrics.StopStatsd()
metrics.StopInfluxDB()
}
func (s *Server) buildNameOrIPToCertificate(certs []tls.Certificate) map[string]*tls.Certificate {
certMap := make(map[string]*tls.Certificate)
for i := range certs {
cert := &certs[i]
x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
if err != nil {
continue
}
if len(x509Cert.Subject.CommonName) > 0 {
certMap[x509Cert.Subject.CommonName] = cert
}
for _, san := range x509Cert.DNSNames {
certMap[san] = cert
}
for _, ipSan := range x509Cert.IPAddresses {
certMap[ipSan.String()] = cert
}
}
return certMap
}