Improve TLS Handshake
This commit is contained in:
parent
2303301d38
commit
689f120410
20 changed files with 819 additions and 60 deletions
|
@ -152,16 +152,28 @@ func (c *Certificate) AppendCertificates(certs map[string]map[string]*tls.Certif
|
|||
|
||||
parsedCert, _ := x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
|
||||
certKey := parsedCert.Subject.CommonName
|
||||
var SANs []string
|
||||
if parsedCert.Subject.CommonName != "" {
|
||||
SANs = append(SANs, parsedCert.Subject.CommonName)
|
||||
}
|
||||
if parsedCert.DNSNames != nil {
|
||||
sort.Strings(parsedCert.DNSNames)
|
||||
for _, dnsName := range parsedCert.DNSNames {
|
||||
if dnsName != parsedCert.Subject.CommonName {
|
||||
certKey += fmt.Sprintf(",%s", dnsName)
|
||||
SANs = append(SANs, dnsName)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if parsedCert.IPAddresses != nil {
|
||||
for _, ip := range parsedCert.IPAddresses {
|
||||
if ip.String() != parsedCert.Subject.CommonName {
|
||||
SANs = append(SANs, ip.String())
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
certKey := strings.Join(SANs, ",")
|
||||
|
||||
certExists := false
|
||||
if certs[ep] == nil {
|
||||
|
|
|
@ -2,14 +2,32 @@ package tls
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
// CertificateStore store for dynamic and static certificates
|
||||
type CertificateStore struct {
|
||||
DynamicCerts *safe.Safe
|
||||
StaticCerts *safe.Safe
|
||||
DynamicCerts *safe.Safe
|
||||
StaticCerts *safe.Safe
|
||||
DefaultCertificate *tls.Certificate
|
||||
CertCache *cache.Cache
|
||||
SniStrict bool
|
||||
}
|
||||
|
||||
// NewCertificateStore create a store for dynamic and static certificates
|
||||
func NewCertificateStore() *CertificateStore {
|
||||
return &CertificateStore{
|
||||
StaticCerts: &safe.Safe{},
|
||||
DynamicCerts: &safe.Safe{},
|
||||
CertCache: cache.New(1*time.Hour, 10*time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
// GetAllDomains return a slice with all the certificate domain
|
||||
|
@ -31,3 +49,89 @@ func (c CertificateStore) GetAllDomains() []string {
|
|||
}
|
||||
return allCerts
|
||||
}
|
||||
|
||||
// GetBestCertificate returns the best match certificate, and caches the response
|
||||
func (c CertificateStore) GetBestCertificate(clientHello *tls.ClientHelloInfo) *tls.Certificate {
|
||||
domainToCheck := strings.ToLower(strings.TrimSpace(clientHello.ServerName))
|
||||
if len(domainToCheck) == 0 {
|
||||
// If no ServerName is provided, Check for local IP address matches
|
||||
host, _, err := net.SplitHostPort(clientHello.Conn.LocalAddr().String())
|
||||
if err != nil {
|
||||
log.Debugf("Could not split host/port: %v", err)
|
||||
}
|
||||
domainToCheck = strings.TrimSpace(host)
|
||||
}
|
||||
|
||||
if cert, ok := c.CertCache.Get(domainToCheck); ok {
|
||||
return cert.(*tls.Certificate)
|
||||
}
|
||||
|
||||
matchedCerts := map[string]*tls.Certificate{}
|
||||
if c.DynamicCerts != nil && c.DynamicCerts.Get() != nil {
|
||||
for domains, cert := range c.DynamicCerts.Get().(map[string]*tls.Certificate) {
|
||||
for _, certDomain := range strings.Split(domains, ",") {
|
||||
if MatchDomain(domainToCheck, certDomain) {
|
||||
matchedCerts[certDomain] = cert
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if c.StaticCerts != nil && c.StaticCerts.Get() != nil {
|
||||
for domains, cert := range c.StaticCerts.Get().(map[string]*tls.Certificate) {
|
||||
for _, certDomain := range strings.Split(domains, ",") {
|
||||
if MatchDomain(domainToCheck, certDomain) {
|
||||
matchedCerts[certDomain] = cert
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(matchedCerts) > 0 {
|
||||
// sort map by keys
|
||||
keys := make([]string, 0, len(matchedCerts))
|
||||
for k := range matchedCerts {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
// cache best match
|
||||
c.CertCache.SetDefault(domainToCheck, matchedCerts[keys[len(keys)-1]])
|
||||
return matchedCerts[keys[len(keys)-1]]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ContainsCertificates checks if there are any certs in the store
|
||||
func (c CertificateStore) ContainsCertificates() bool {
|
||||
return c.StaticCerts.Get() != nil || c.DynamicCerts.Get() != nil
|
||||
}
|
||||
|
||||
// ResetCache clears the cache in the store
|
||||
func (c CertificateStore) ResetCache() {
|
||||
if c.CertCache != nil {
|
||||
c.CertCache.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// MatchDomain return true if a domain match the cert domain
|
||||
func MatchDomain(domain string, certDomain string) bool {
|
||||
if domain == certDomain {
|
||||
return true
|
||||
}
|
||||
|
||||
for len(certDomain) > 0 && certDomain[len(certDomain)-1] == '.' {
|
||||
certDomain = certDomain[:len(certDomain)-1]
|
||||
}
|
||||
|
||||
labels := strings.Split(domain, ".")
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if certDomain == candidate {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
|
134
tls/certificate_store_test.go
Normal file
134
tls/certificate_store_test.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package tls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetBestCertificate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
domainToCheck string
|
||||
staticCert string
|
||||
dynamicCert string
|
||||
expectedCert string
|
||||
}{
|
||||
{
|
||||
desc: "Empty Store, returns no certs",
|
||||
domainToCheck: "snitest.com",
|
||||
staticCert: "",
|
||||
dynamicCert: "",
|
||||
expectedCert: "",
|
||||
},
|
||||
{
|
||||
desc: "Empty static cert store",
|
||||
domainToCheck: "snitest.com",
|
||||
staticCert: "",
|
||||
dynamicCert: "snitest.com",
|
||||
expectedCert: "snitest.com",
|
||||
},
|
||||
{
|
||||
desc: "Empty dynamic cert store",
|
||||
domainToCheck: "snitest.com",
|
||||
staticCert: "snitest.com",
|
||||
dynamicCert: "",
|
||||
expectedCert: "snitest.com",
|
||||
},
|
||||
{
|
||||
desc: "Best Match",
|
||||
domainToCheck: "snitest.com",
|
||||
staticCert: "snitest.com",
|
||||
dynamicCert: "snitest.org",
|
||||
expectedCert: "snitest.com",
|
||||
},
|
||||
{
|
||||
desc: "Best Match with wildcard dynamic and exact static",
|
||||
domainToCheck: "www.snitest.com",
|
||||
staticCert: "www.snitest.com",
|
||||
dynamicCert: "*.snitest.com",
|
||||
expectedCert: "www.snitest.com",
|
||||
},
|
||||
{
|
||||
desc: "Best Match with wildcard static and exact dynamic",
|
||||
domainToCheck: "www.snitest.com",
|
||||
staticCert: "*.snitest.com",
|
||||
dynamicCert: "www.snitest.com",
|
||||
expectedCert: "www.snitest.com",
|
||||
},
|
||||
{
|
||||
desc: "Best Match with static wildcard only",
|
||||
domainToCheck: "www.snitest.com",
|
||||
staticCert: "*.snitest.com",
|
||||
dynamicCert: "",
|
||||
expectedCert: "*.snitest.com",
|
||||
},
|
||||
{
|
||||
desc: "Best Match with dynamic wildcard only",
|
||||
domainToCheck: "www.snitest.com",
|
||||
staticCert: "",
|
||||
dynamicCert: "*.snitest.com",
|
||||
expectedCert: "*.snitest.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
staticMap := map[string]*tls.Certificate{}
|
||||
dynamicMap := map[string]*tls.Certificate{}
|
||||
|
||||
if test.staticCert != "" {
|
||||
cert, err := loadTestCert(test.staticCert)
|
||||
require.NoError(t, err)
|
||||
staticMap[test.staticCert] = cert
|
||||
}
|
||||
|
||||
if test.dynamicCert != "" {
|
||||
cert, err := loadTestCert(test.dynamicCert)
|
||||
require.NoError(t, err)
|
||||
dynamicMap[test.dynamicCert] = cert
|
||||
}
|
||||
|
||||
store := &CertificateStore{
|
||||
DynamicCerts: safe.New(dynamicMap),
|
||||
StaticCerts: safe.New(staticMap),
|
||||
CertCache: cache.New(1*time.Hour, 10*time.Minute),
|
||||
}
|
||||
|
||||
var expected *tls.Certificate
|
||||
if test.expectedCert != "" {
|
||||
cert, err := loadTestCert(test.expectedCert)
|
||||
require.NoError(t, err)
|
||||
expected = cert
|
||||
}
|
||||
|
||||
clientHello := &tls.ClientHelloInfo{
|
||||
ServerName: test.domainToCheck,
|
||||
}
|
||||
|
||||
actual := store.GetBestCertificate(clientHello)
|
||||
assert.Equal(t, expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func loadTestCert(certName string) (*tls.Certificate, error) {
|
||||
staticCert, err := tls.LoadX509KeyPair(
|
||||
fmt.Sprintf("../integration/fixtures/https/%s.cert", strings.Replace(certName, "*", "wildcard", -1)),
|
||||
fmt.Sprintf("../integration/fixtures/https/%s.key", strings.Replace(certName, "*", "wildcard", -1)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &staticCert, nil
|
||||
}
|
12
tls/tls.go
12
tls/tls.go
|
@ -22,11 +22,13 @@ type ClientCA struct {
|
|||
|
||||
// TLS configures TLS for an entry point
|
||||
type TLS struct {
|
||||
MinVersion string `export:"true"`
|
||||
CipherSuites []string
|
||||
Certificates Certificates
|
||||
ClientCAFiles []string // Deprecated
|
||||
ClientCA ClientCA
|
||||
MinVersion string `export:"true"`
|
||||
CipherSuites []string
|
||||
Certificates Certificates
|
||||
ClientCAFiles []string // Deprecated
|
||||
ClientCA ClientCA
|
||||
DefaultCertificate *Certificate
|
||||
SniStrict bool `export:"true"`
|
||||
}
|
||||
|
||||
// RootCAs hold the CA we want to have in root
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue