Make the TLS certificates management dynamic.

This commit is contained in:
NicoMen 2017-11-09 12:16:03 +01:00 committed by Traefiker
parent f6aa147c78
commit c469e669fd
36 changed files with 1257 additions and 513 deletions

View file

@ -33,6 +33,7 @@ import (
"github.com/containous/traefik/provider"
"github.com/containous/traefik/safe"
"github.com/containous/traefik/server/cookie"
traefikTls "github.com/containous/traefik/tls"
"github.com/containous/traefik/types"
"github.com/containous/traefik/whitelist"
"github.com/streamrail/concurrent-map"
@ -66,6 +67,8 @@ type Server struct {
leadership *cluster.Leadership
defaultForwardingRoundTripper http.RoundTripper
metricsRegistry metrics.Registry
lastReceivedConfiguration *safe.Safe
lastConfigs cmap.ConcurrentMap
}
type serverEntryPoints map[string]*serverEntryPoint
@ -74,6 +77,7 @@ type serverEntryPoint struct {
httpServer *http.Server
listener net.Listener
httpRouter *middlewares.HandlerSwitcher
certs safe.Safe
}
type serverRoute struct {
@ -101,6 +105,8 @@ func NewServer(globalConfiguration configuration.GlobalConfiguration) *Server {
server.globalConfiguration = globalConfiguration
server.routinesPool = safe.NewPool(context.Background())
server.defaultForwardingRoundTripper = createHTTPTransport(globalConfiguration)
server.lastReceivedConfiguration = safe.New(time.Unix(0, 0))
server.lastConfigs = cmap.New()
server.metricsRegistry = metrics.NewVoidRegistry()
if globalConfiguration.Web != nil && globalConfiguration.Web.Metrics != nil {
@ -165,7 +171,7 @@ func createHTTPTransport(globalConfiguration configuration.GlobalConfiguration)
return transport
}
func createRootCACertPool(rootCAs configuration.RootCAs) *x509.CertPool {
func createRootCACertPool(rootCAs traefikTls.RootCAs) *x509.CertPool {
roots := x509.NewCertPool()
for _, cert := range rootCAs {
@ -317,51 +323,54 @@ func (server *Server) setupServerEntryPoint(newServerEntryPointName string, newS
}
func (server *Server) listenProviders(stop chan bool) {
lastReceivedConfiguration := safe.New(time.Unix(0, 0))
lastConfigs := cmap.New()
for {
select {
case <-stop:
return
case configMsg, ok := <-server.configurationChan:
if !ok {
if !ok || configMsg.Configuration == nil {
return
}
server.defaultConfigurationValues(configMsg.Configuration)
currentConfigurations := server.currentConfigurations.Get().(types.Configurations)
jsonConf, _ := json.Marshal(configMsg.Configuration)
log.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
if configMsg.Configuration == nil || configMsg.Configuration.Backends == nil && configMsg.Configuration.Frontends == nil {
log.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
} else if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) {
log.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
} else {
lastConfigs.Set(configMsg.ProviderName, &configMsg)
lastReceivedConfigurationValue := lastReceivedConfiguration.Get().(time.Time)
providersThrottleDuration := time.Duration(server.globalConfiguration.ProvidersThrottleDuration)
if time.Now().After(lastReceivedConfigurationValue.Add(providersThrottleDuration)) {
log.Debugf("Last %s config received more than %s, OK", configMsg.ProviderName, server.globalConfiguration.ProvidersThrottleDuration.String())
// last config received more than n s ago
server.configurationValidatedChan <- configMsg
} else {
log.Debugf("Last %s config received less than %s, waiting...", configMsg.ProviderName, server.globalConfiguration.ProvidersThrottleDuration)
safe.Go(func() {
<-time.After(providersThrottleDuration)
lastReceivedConfigurationValue := lastReceivedConfiguration.Get().(time.Time)
if time.Now().After(lastReceivedConfigurationValue.Add(time.Duration(providersThrottleDuration))) {
log.Debugf("Waited for %s config, OK", configMsg.ProviderName)
if lastConfig, ok := lastConfigs.Get(configMsg.ProviderName); ok {
server.configurationValidatedChan <- *lastConfig.(*types.ConfigMessage)
}
}
})
}
lastReceivedConfiguration.Set(time.Now())
}
server.preLoadConfiguration(configMsg)
}
}
}
func (server *Server) preLoadConfiguration(configMsg types.ConfigMessage) {
server.defaultConfigurationValues(configMsg.Configuration)
currentConfigurations := server.currentConfigurations.Get().(types.Configurations)
jsonConf, _ := json.Marshal(configMsg.Configuration)
log.Debugf("Configuration received from provider %s: %s", configMsg.ProviderName, string(jsonConf))
if configMsg.Configuration == nil || configMsg.Configuration.Backends == nil && configMsg.Configuration.Frontends == nil && configMsg.Configuration.TLSConfiguration == nil {
log.Infof("Skipping empty Configuration for provider %s", configMsg.ProviderName)
} else if reflect.DeepEqual(currentConfigurations[configMsg.ProviderName], configMsg.Configuration) {
log.Infof("Skipping same configuration for provider %s", configMsg.ProviderName)
} else {
server.lastConfigs.Set(configMsg.ProviderName, &configMsg)
lastReceivedConfigurationValue := server.lastReceivedConfiguration.Get().(time.Time)
providersThrottleDuration := time.Duration(server.globalConfiguration.ProvidersThrottleDuration)
if time.Now().After(lastReceivedConfigurationValue.Add(providersThrottleDuration)) {
log.Debugf("Last %s configuration received more than %s, OK", configMsg.ProviderName, server.globalConfiguration.ProvidersThrottleDuration.String())
// last config received more than n server ago
server.configurationValidatedChan <- configMsg
} else {
log.Debugf("Last %s configuration received less than %s, waiting...", configMsg.ProviderName, server.globalConfiguration.ProvidersThrottleDuration.String())
safe.Go(func() {
<-time.After(providersThrottleDuration)
lastReceivedConfigurationValue := server.lastReceivedConfiguration.Get().(time.Time)
if time.Now().After(lastReceivedConfigurationValue.Add(time.Duration(providersThrottleDuration))) {
log.Debugf("Waited for %s configuration, OK", configMsg.ProviderName)
if lastConfig, ok := server.lastConfigs.Get(configMsg.ProviderName); ok {
server.configurationValidatedChan <- *lastConfig.(*types.ConfigMessage)
}
}
})
}
// Update the last configuration loading time
server.lastReceivedConfiguration.Set(time.Now())
}
}
func (server *Server) defaultConfigurationValues(configuration *types.Configuration) {
if configuration == nil || configuration.Frontends == nil {
return
@ -376,34 +385,73 @@ func (server *Server) listenConfigurations(stop chan bool) {
case <-stop:
return
case configMsg, ok := <-server.configurationValidatedChan:
if !ok {
if !ok || configMsg.Configuration == nil {
return
}
currentConfigurations := server.currentConfigurations.Get().(types.Configurations)
// Copy configurations to new map so we don't change current if LoadConfig fails
newConfigurations := make(types.Configurations)
for k, v := range currentConfigurations {
newConfigurations[k] = v
}
newConfigurations[configMsg.ProviderName] = configMsg.Configuration
newServerEntryPoints, err := server.loadConfig(newConfigurations, server.globalConfiguration)
if err == nil {
for newServerEntryPointName, newServerEntryPoint := range newServerEntryPoints {
server.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler())
log.Infof("Server configuration reloaded on %s", server.serverEntryPoints[newServerEntryPointName].httpServer.Addr)
}
server.currentConfigurations.Set(newConfigurations)
server.postLoadConfig()
} else {
log.Error("Error loading new configuration, aborted ", err)
}
server.loadConfiguration(configMsg)
}
}
}
func (server *Server) postLoadConfig() {
// loadConfiguration manages dynamically frontends, backends and TLS configurations
func (server *Server) loadConfiguration(configMsg types.ConfigMessage) {
currentConfigurations := server.currentConfigurations.Get().(types.Configurations)
// Copy configurations to new map so we don't change current if LoadConfig fails
newConfigurations := make(types.Configurations)
for k, v := range currentConfigurations {
newConfigurations[k] = v
}
newConfigurations[configMsg.ProviderName] = configMsg.Configuration
newServerEntryPoints, err := server.loadConfig(newConfigurations, server.globalConfiguration)
if err == nil {
for newServerEntryPointName, newServerEntryPoint := range newServerEntryPoints {
server.serverEntryPoints[newServerEntryPointName].httpRouter.UpdateHandler(newServerEntryPoint.httpRouter.GetHandler())
if &newServerEntryPoint.certs != nil {
server.serverEntryPoints[newServerEntryPointName].certs.Set(newServerEntryPoint.certs.Get())
}
log.Infof("Server configuration reloaded on %s", server.serverEntryPoints[newServerEntryPointName].httpServer.Addr)
}
server.currentConfigurations.Set(newConfigurations)
server.postLoadConfiguration()
} else {
log.Error("Error loading new configuration, aborted ", err)
}
}
// loadHTTPSConfiguration add/delete HTTPS certificate managed dynamically
func (server *Server) loadHTTPSConfiguration(configurations types.Configurations) (map[string]*traefikTls.DomainsCertificates, error) {
newEPCertificates := make(map[string]*traefikTls.DomainsCertificates)
// Get all certificates
for _, configuration := range configurations {
if configuration.TLSConfiguration != nil && len(configuration.TLSConfiguration) > 0 {
if err := traefikTls.SortTLSConfigurationPerEntryPoints(configuration.TLSConfiguration, newEPCertificates); err != nil {
return nil, err
}
}
}
return newEPCertificates, nil
}
// getCertificate allows to customize tlsConfig.Getcertificate behaviour to get the certificates inserted dynamically
func (s *serverEntryPoint) getCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
if s.certs.Get() != nil {
domainToCheck := types.CanonicalDomain(clientHello.ServerName)
for domains, cert := range *s.certs.Get().(*traefikTls.DomainsCertificates) {
for _, domain := range strings.Split(domains, ",") {
selector := "^" + strings.Replace(domain, "*.", "[^\\.]*\\.?", -1) + "$"
domainCheck, _ := regexp.MatchString(selector, domainToCheck)
if domainCheck {
return cert, nil
}
}
}
}
return nil, nil
}
func (server *Server) postLoadConfiguration() {
if server.globalConfiguration.ACME == nil {
return
}
@ -418,8 +466,8 @@ func (server *Server) postLoadConfig() {
// check if one of the frontend entrypoints is configured with TLS
// and is configured with ACME
ACMEEnabled := false
for _, entrypoint := range frontend.EntryPoints {
if server.globalConfiguration.ACME.EntryPoint == entrypoint && server.globalConfiguration.EntryPoints[entrypoint].TLS != nil {
for _, entryPoint := range frontend.EntryPoints {
if server.globalConfiguration.ACME.EntryPoint == entryPoint && server.globalConfiguration.EntryPoints[entryPoint].TLS != nil {
ACMEEnabled = true
break
}
@ -508,12 +556,12 @@ func (server *Server) startProviders() {
}
}
func createClientTLSConfig(tlsOption *configuration.TLS) (*tls.Config, error) {
func createClientTLSConfig(entryPointName string, tlsOption *traefikTls.TLS) (*tls.Config, error) {
if tlsOption == nil {
return nil, errors.New("no TLS provided")
}
config, err := tlsOption.Certificates.CreateTLSConfig()
config, _, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
if err != nil {
return nil, err
}
@ -536,16 +584,22 @@ func createClientTLSConfig(tlsOption *configuration.TLS) (*tls.Config, error) {
}
// creates a TLS config that allows terminating HTTPS for multiple domains using SNI
func (server *Server) createTLSConfig(entryPointName string, tlsOption *configuration.TLS, router *middlewares.HandlerSwitcher) (*tls.Config, error) {
func (server *Server) createTLSConfig(entryPointName string, tlsOption *traefikTls.TLS, router *middlewares.HandlerSwitcher) (*tls.Config, error) {
if tlsOption == nil {
return nil, nil
}
config, err := tlsOption.Certificates.CreateTLSConfig()
config, epDomainsCertificates, err := tlsOption.Certificates.CreateTLSConfig(entryPointName)
if err != nil {
return nil, err
}
epDomainsCertificatesTmp := new(traefikTls.DomainsCertificates)
if epDomainsCertificates[entryPointName] != nil {
epDomainsCertificatesTmp = epDomainsCertificates[entryPointName]
} else {
*epDomainsCertificatesTmp = make(map[string]*tls.Certificate)
}
server.serverEntryPoints[entryPointName].certs.Set(epDomainsCertificatesTmp)
// ensure http2 enabled
config.NextProtos = []string{"h2", "http/1.1"}
@ -578,12 +632,12 @@ func (server *Server) createTLSConfig(entryPointName string, tlsOption *configur
return false
}
if server.leadership == nil {
err := server.globalConfiguration.ACME.CreateLocalConfig(config, checkOnDemandDomain)
err := server.globalConfiguration.ACME.CreateLocalConfig(config, &server.serverEntryPoints[entryPointName].certs, checkOnDemandDomain)
if err != nil {
return nil, err
}
} else {
err := server.globalConfiguration.ACME.CreateClusterConfig(server.leadership, config, checkOnDemandDomain)
err := server.globalConfiguration.ACME.CreateClusterConfig(server.leadership, config, &server.serverEntryPoints[entryPointName].certs, checkOnDemandDomain)
if err != nil {
return nil, err
}
@ -592,6 +646,8 @@ func (server *Server) createTLSConfig(entryPointName string, tlsOption *configur
} else {
return nil, errors.New("Unknown entrypoint " + server.globalConfiguration.ACME.EntryPoint + " for ACME configuration")
}
} else {
config.GetCertificate = server.serverEntryPoints[entryPointName].getCertificate
}
if len(config.Certificates) == 0 {
return nil, errors.New("No certificates found for TLS entrypoint " + entryPointName)
@ -600,7 +656,7 @@ func (server *Server) createTLSConfig(entryPointName string, tlsOption *configur
// in each certificate and populates the config.NameToCertificate map.
config.BuildNameToCertificate()
//Set the minimum TLS version if set in the config TOML
if minConst, exists := configuration.MinVersion[server.globalConfiguration.EntryPoints[entryPointName].TLS.MinVersion]; exists {
if minConst, exists := traefikTls.MinVersion[server.globalConfiguration.EntryPoints[entryPointName].TLS.MinVersion]; exists {
config.PreferServerCipherSuites = true
config.MinVersion = minConst
}
@ -609,7 +665,7 @@ func (server *Server) createTLSConfig(entryPointName string, tlsOption *configur
//if our list of CipherSuites is defined in the entrypoint config, we can re-initilize the suites list as empty
config.CipherSuites = make([]uint16, 0)
for _, cipher := range server.globalConfiguration.EntryPoints[entryPointName].TLS.CipherSuites {
if cipherConst, exists := configuration.CipherSuites[cipher]; exists {
if cipherConst, exists := traefikTls.CipherSuites[cipher]; exists {
config.CipherSuites = append(config.CipherSuites, cipherConst)
} else {
//CipherSuite listed in the toml does not exist in our listed
@ -617,7 +673,6 @@ func (server *Server) createTLSConfig(entryPointName string, tlsOption *configur
}
}
}
return config, nil
}
@ -721,9 +776,9 @@ func (server *Server) buildEntryPoints(globalConfiguration configuration.GlobalC
// getRoundTripper will either use server.defaultForwardingRoundTripper or create a new one
// given a custom TLS configuration is passed and the passTLSCert option is set to true.
func (server *Server) getRoundTripper(globalConfiguration configuration.GlobalConfiguration, passTLSCert bool, tls *configuration.TLS) (http.RoundTripper, error) {
func (server *Server) getRoundTripper(entryPointName string, globalConfiguration configuration.GlobalConfiguration, passTLSCert bool, tls *traefikTls.TLS) (http.RoundTripper, error) {
if passTLSCert {
tlsConfig, err := createClientTLSConfig(tls)
tlsConfig, err := createClientTLSConfig(entryPointName, tls)
if err != nil {
log.Errorf("Failed to create TLSClientConfig: %s", err)
return nil, err
@ -802,7 +857,7 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf
if backends[entryPointName+frontend.Backend] == nil {
log.Debugf("Creating backend %s", frontend.Backend)
roundTripper, err := server.getRoundTripper(globalConfiguration, frontend.PassTLSCert, entryPoint.TLS)
roundTripper, err := server.getRoundTripper(entryPointName, globalConfiguration, frontend.PassTLSCert, entryPoint.TLS)
if err != nil {
log.Errorf("Failed to create RoundTripper for frontend %s: %v", frontendName, err)
log.Errorf("Skipping frontend %s...", frontendName)
@ -1019,11 +1074,19 @@ func (server *Server) loadConfig(configurations types.Configurations, globalConf
}
}
healthcheck.GetHealthCheck().SetBackendsConfiguration(server.routinesPool.Ctx(), backendsHealthCheck)
//sort routes
for _, serverEntryPoint := range serverEntryPoints {
// Get new certificates list sorted per entrypoints
// Update certificates
entryPointsCertificates, err := server.loadHTTPSConfiguration(configurations)
//sort routes and update certificates
for serverEntryPointName, serverEntryPoint := range serverEntryPoints {
serverEntryPoint.httpRouter.GetHandler().SortRoutes()
_, exists := entryPointsCertificates[serverEntryPointName]
if exists {
serverEntryPoint.certs.Set(entryPointsCertificates[serverEntryPointName])
}
}
return serverEntryPoints, nil
return serverEntryPoints, err
}
func configureLBServers(lb healthcheck.LoadBalancer, config *types.Configuration, frontend *types.Frontend) error {