fix: PassClientTLSCert middleware separators and formatting
This commit is contained in:
parent
89db08eb93
commit
39a3cefc21
3 changed files with 303 additions and 293 deletions
|
@ -18,10 +18,17 @@ import (
|
|||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const typeName = "PassClientTLSCert"
|
||||
|
||||
const (
|
||||
xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
|
||||
xForwardedTLSClientCertInfo = "X-Forwarded-Tls-Client-Cert-Info"
|
||||
typeName = "PassClientTLSCert"
|
||||
)
|
||||
|
||||
const (
|
||||
certSeparator = ","
|
||||
fieldSeparator = ";"
|
||||
subFieldSeparator = ","
|
||||
)
|
||||
|
||||
var attributeTypeNames = map[string]string{
|
||||
|
@ -55,6 +62,29 @@ func newDistinguishedNameOptions(info *dynamic.TLSCLientCertificateDNInfo) *Dist
|
|||
}
|
||||
}
|
||||
|
||||
// tlsClientCertificateInfo is a struct for specifying the configuration for the passTLSClientCert middleware.
|
||||
type tlsClientCertificateInfo struct {
|
||||
notAfter bool
|
||||
notBefore bool
|
||||
sans bool
|
||||
subject *DistinguishedNameOptions
|
||||
issuer *DistinguishedNameOptions
|
||||
}
|
||||
|
||||
func newTLSClientCertificateInfo(info *dynamic.TLSClientCertificateInfo) *tlsClientCertificateInfo {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &tlsClientCertificateInfo{
|
||||
issuer: newDistinguishedNameOptions(info.Issuer),
|
||||
notAfter: info.NotAfter,
|
||||
notBefore: info.NotBefore,
|
||||
subject: newDistinguishedNameOptions(info.Subject),
|
||||
sans: info.Sans,
|
||||
}
|
||||
}
|
||||
|
||||
// passTLSClientCert is a middleware that helps setup a few tls info features.
|
||||
type passTLSClientCert struct {
|
||||
next http.Handler
|
||||
|
@ -71,45 +101,84 @@ func New(ctx context.Context, next http.Handler, config dynamic.PassTLSClientCer
|
|||
next: next,
|
||||
name: name,
|
||||
pem: config.PEM,
|
||||
info: newTLSClientInfo(config.Info),
|
||||
info: newTLSClientCertificateInfo(config.Info),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tlsClientCertificateInfo is a struct for specifying the configuration for the passTLSClientCert middleware.
|
||||
type tlsClientCertificateInfo struct {
|
||||
notAfter bool
|
||||
notBefore bool
|
||||
sans bool
|
||||
subject *DistinguishedNameOptions
|
||||
issuer *DistinguishedNameOptions
|
||||
}
|
||||
|
||||
func newTLSClientInfo(info *dynamic.TLSClientCertificateInfo) *tlsClientCertificateInfo {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &tlsClientCertificateInfo{
|
||||
issuer: newDistinguishedNameOptions(info.Issuer),
|
||||
notAfter: info.NotAfter,
|
||||
notBefore: info.NotBefore,
|
||||
subject: newDistinguishedNameOptions(info.Subject),
|
||||
sans: info.Sans,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *passTLSClientCert) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return p.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (p *passTLSClientCert) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
ctx := middlewares.GetLoggerCtx(req.Context(), p.name, typeName)
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
if p.pem {
|
||||
if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 {
|
||||
req.Header.Set(xForwardedTLSClientCert, getCertificates(ctx, req.TLS.PeerCertificates))
|
||||
} else {
|
||||
logger.Warn("Tried to extract a certificate on a request without mutual TLS")
|
||||
}
|
||||
}
|
||||
|
||||
if p.info != nil {
|
||||
if req.TLS != nil && len(req.TLS.PeerCertificates) > 0 {
|
||||
headerContent := p.getCertInfo(ctx, req.TLS.PeerCertificates)
|
||||
req.Header.Set(xForwardedTLSClientCertInfo, url.QueryEscape(headerContent))
|
||||
} else {
|
||||
logger.Warn("Tried to extract a certificate on a request without mutual TLS")
|
||||
}
|
||||
}
|
||||
|
||||
p.modifyRequestHeaders(ctx, req)
|
||||
p.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func getDNInfo(ctx context.Context, prefix string, options *DistinguishedNameOptions, cs *pkix.Name) string {
|
||||
// getCertInfo Build a string with the wanted client certificates information
|
||||
// - the `,` is used to separate certificates
|
||||
// - the `;` is used to separate root fields
|
||||
// - the value of root fields is always wrapped by double quote
|
||||
// - if a field is empty, the field is ignored
|
||||
func (p *passTLSClientCert) getCertInfo(ctx context.Context, certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
var values []string
|
||||
|
||||
if p.info != nil {
|
||||
subject := getDNInfo(ctx, p.info.subject, &peerCert.Subject)
|
||||
if subject != "" {
|
||||
values = append(values, fmt.Sprintf(`Subject="%s"`, strings.TrimSuffix(subject, subFieldSeparator)))
|
||||
}
|
||||
|
||||
issuer := getDNInfo(ctx, p.info.issuer, &peerCert.Issuer)
|
||||
if issuer != "" {
|
||||
values = append(values, fmt.Sprintf(`Issuer="%s"`, strings.TrimSuffix(issuer, subFieldSeparator)))
|
||||
}
|
||||
|
||||
if p.info.notBefore {
|
||||
values = append(values, fmt.Sprintf(`NB="%d"`, uint64(peerCert.NotBefore.Unix())))
|
||||
}
|
||||
|
||||
if p.info.notAfter {
|
||||
values = append(values, fmt.Sprintf(`NA="%d"`, uint64(peerCert.NotAfter.Unix())))
|
||||
}
|
||||
|
||||
if p.info.sans {
|
||||
sans := getSANs(peerCert)
|
||||
if len(sans) > 0 {
|
||||
values = append(values, fmt.Sprintf(`SAN="%s"`, strings.Join(sans, subFieldSeparator)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
value := strings.Join(values, fieldSeparator)
|
||||
headerValues = append(headerValues, value)
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, certSeparator)
|
||||
}
|
||||
|
||||
func getDNInfo(ctx context.Context, options *DistinguishedNameOptions, cs *pkix.Name) string {
|
||||
if options == nil {
|
||||
return ""
|
||||
}
|
||||
|
@ -120,7 +189,7 @@ func getDNInfo(ctx context.Context, prefix string, options *DistinguishedNameOpt
|
|||
for _, name := range cs.Names {
|
||||
// Domain Component - RFC 2247
|
||||
if options.DomainComponent && attributeTypeNames[name.Type.String()] == "DC" {
|
||||
content.WriteString(fmt.Sprintf("DC=%s,", name.Value))
|
||||
content.WriteString(fmt.Sprintf("DC=%s%s", name.Value, subFieldSeparator))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -148,11 +217,7 @@ func getDNInfo(ctx context.Context, prefix string, options *DistinguishedNameOpt
|
|||
writePart(ctx, content, cs.CommonName, "CN")
|
||||
}
|
||||
|
||||
if content.Len() > 0 {
|
||||
return prefix + `="` + strings.TrimSuffix(content.String(), ",") + `"`
|
||||
}
|
||||
|
||||
return ""
|
||||
return content.String()
|
||||
}
|
||||
|
||||
func writeParts(ctx context.Context, content io.StringWriter, entries []string, prefix string) {
|
||||
|
@ -163,135 +228,63 @@ func writeParts(ctx context.Context, content io.StringWriter, entries []string,
|
|||
|
||||
func writePart(ctx context.Context, content io.StringWriter, entry string, prefix string) {
|
||||
if len(entry) > 0 {
|
||||
_, err := content.WriteString(fmt.Sprintf("%s=%s,", prefix, entry))
|
||||
_, err := content.WriteString(fmt.Sprintf("%s=%s%s", prefix, entry, subFieldSeparator))
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCertInfo Build a string with the wanted client certificates information
|
||||
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
|
||||
func (p *passTLSClientCert) getXForwardedTLSClientCertInfo(ctx context.Context, certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
var values []string
|
||||
var sans string
|
||||
var nb string
|
||||
var na string
|
||||
|
||||
if p.info != nil {
|
||||
subject := getDNInfo(ctx, "Subject", p.info.subject, &peerCert.Subject)
|
||||
if len(subject) > 0 {
|
||||
values = append(values, subject)
|
||||
}
|
||||
|
||||
issuer := getDNInfo(ctx, "Issuer", p.info.issuer, &peerCert.Issuer)
|
||||
if len(issuer) > 0 {
|
||||
values = append(values, issuer)
|
||||
}
|
||||
}
|
||||
|
||||
ci := p.info
|
||||
if ci != nil {
|
||||
if ci.notBefore {
|
||||
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
|
||||
values = append(values, nb)
|
||||
}
|
||||
if ci.notAfter {
|
||||
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
|
||||
values = append(values, na)
|
||||
}
|
||||
|
||||
if ci.sans {
|
||||
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
|
||||
values = append(values, sans)
|
||||
}
|
||||
}
|
||||
|
||||
value := strings.Join(values, ",")
|
||||
headerValues = append(headerValues, value)
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ";")
|
||||
}
|
||||
|
||||
// modifyRequestHeaders set the wanted headers with the certificates information.
|
||||
func (p *passTLSClientCert) modifyRequestHeaders(ctx context.Context, r *http.Request) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
if p.pem {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(ctx, r.TLS.PeerCertificates))
|
||||
} else {
|
||||
logger.Warn("Tried to extract a certificate on a request without mutual TLS")
|
||||
}
|
||||
}
|
||||
|
||||
if p.info != nil {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
headerContent := p.getXForwardedTLSClientCertInfo(ctx, r.TLS.PeerCertificates)
|
||||
r.Header.Set(xForwardedTLSClientCertInfo, url.QueryEscape(headerContent))
|
||||
} else {
|
||||
logger.Warn("Tried to extract a certificate on a request without mutual TLS")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sanitize As we pass the raw certificates, remove the useless data and make it http request compliant.
|
||||
func sanitize(cert []byte) string {
|
||||
s := string(cert)
|
||||
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
|
||||
cleaned := strings.NewReplacer(
|
||||
"-----BEGIN CERTIFICATE-----", "",
|
||||
"-----END CERTIFICATE-----", "",
|
||||
"\n", "")
|
||||
cleaned := r.Replace(s)
|
||||
"\n", "",
|
||||
).Replace(string(cert))
|
||||
|
||||
return url.QueryEscape(cleaned)
|
||||
}
|
||||
|
||||
// extractCertificate extract the certificate from the request.
|
||||
func extractCertificate(ctx context.Context, cert *x509.Certificate) string {
|
||||
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
|
||||
certPEM := pem.EncodeToMemory(&b)
|
||||
if certPEM == nil {
|
||||
log.FromContext(ctx).Error("Cannot extract the certificate content")
|
||||
return ""
|
||||
}
|
||||
return sanitize(certPEM)
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCert Build a string with the client certificates.
|
||||
func getXForwardedTLSClientCert(ctx context.Context, certs []*x509.Certificate) string {
|
||||
// getCertificates Build a string with the client certificates.
|
||||
func getCertificates(ctx context.Context, certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
headerValues = append(headerValues, extractCertificate(ctx, peerCert))
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ",")
|
||||
return strings.Join(headerValues, certSeparator)
|
||||
}
|
||||
|
||||
// extractCertificate extract the certificate from the request.
|
||||
func extractCertificate(ctx context.Context, cert *x509.Certificate) string {
|
||||
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
|
||||
if certPEM == nil {
|
||||
log.FromContext(ctx).Error("Cannot extract the certificate content")
|
||||
return ""
|
||||
}
|
||||
|
||||
return sanitize(certPEM)
|
||||
}
|
||||
|
||||
// getSANs get the Subject Alternate Name values.
|
||||
func getSANs(cert *x509.Certificate) []string {
|
||||
var sans []string
|
||||
if cert == nil {
|
||||
return sans
|
||||
return nil
|
||||
}
|
||||
|
||||
var sans []string
|
||||
sans = append(sans, cert.DNSNames...)
|
||||
sans = append(sans, cert.EmailAddresses...)
|
||||
|
||||
var ips []string
|
||||
for _, ip := range cert.IPAddresses {
|
||||
ips = append(ips, ip.String())
|
||||
sans = append(sans, ip.String())
|
||||
}
|
||||
sans = append(sans, ips...)
|
||||
|
||||
var uris []string
|
||||
for _, uri := range cert.URIs {
|
||||
uris = append(uris, uri.String())
|
||||
sans = append(sans, uri.String())
|
||||
}
|
||||
|
||||
return append(sans, uris...)
|
||||
return sans
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue