fix: logger and context.
This commit is contained in:
parent
b4c7b90c9e
commit
8e18d37b3d
52 changed files with 231 additions and 183 deletions
|
@ -16,7 +16,6 @@ import (
|
|||
"github.com/containous/traefik/v2/pkg/middlewares"
|
||||
"github.com/containous/traefik/v2/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -66,7 +65,7 @@ type passTLSClientCert struct {
|
|||
|
||||
// New constructs a new PassTLSClientCert instance from supplied frontend header struct.
|
||||
func New(ctx context.Context, next http.Handler, config dynamic.PassTLSClientCert, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName)).Debug("Creating middleware")
|
||||
|
||||
return &passTLSClientCert{
|
||||
next: next,
|
||||
|
@ -104,11 +103,13 @@ func (p *passTLSClientCert) GetTracingInformation() (string, ext.SpanKindEnum) {
|
|||
}
|
||||
|
||||
func (p *passTLSClientCert) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), p.name, typeName)
|
||||
p.modifyRequestHeaders(logger, req)
|
||||
ctx := middlewares.GetLoggerCtx(req.Context(), p.name, typeName)
|
||||
|
||||
p.modifyRequestHeaders(ctx, req)
|
||||
p.next.ServeHTTP(rw, req)
|
||||
}
|
||||
func getDNInfo(prefix string, options *DistinguishedNameOptions, cs *pkix.Name) string {
|
||||
|
||||
func getDNInfo(ctx context.Context, prefix string, options *DistinguishedNameOptions, cs *pkix.Name) string {
|
||||
if options == nil {
|
||||
return ""
|
||||
}
|
||||
|
@ -124,27 +125,27 @@ func getDNInfo(prefix string, options *DistinguishedNameOptions, cs *pkix.Name)
|
|||
}
|
||||
|
||||
if options.CountryName {
|
||||
writeParts(content, cs.Country, "C")
|
||||
writeParts(ctx, content, cs.Country, "C")
|
||||
}
|
||||
|
||||
if options.StateOrProvinceName {
|
||||
writeParts(content, cs.Province, "ST")
|
||||
writeParts(ctx, content, cs.Province, "ST")
|
||||
}
|
||||
|
||||
if options.LocalityName {
|
||||
writeParts(content, cs.Locality, "L")
|
||||
writeParts(ctx, content, cs.Locality, "L")
|
||||
}
|
||||
|
||||
if options.OrganizationName {
|
||||
writeParts(content, cs.Organization, "O")
|
||||
writeParts(ctx, content, cs.Organization, "O")
|
||||
}
|
||||
|
||||
if options.SerialNumber {
|
||||
writePart(content, cs.SerialNumber, "SN")
|
||||
writePart(ctx, content, cs.SerialNumber, "SN")
|
||||
}
|
||||
|
||||
if options.CommonName {
|
||||
writePart(content, cs.CommonName, "CN")
|
||||
writePart(ctx, content, cs.CommonName, "CN")
|
||||
}
|
||||
|
||||
if content.Len() > 0 {
|
||||
|
@ -154,24 +155,24 @@ func getDNInfo(prefix string, options *DistinguishedNameOptions, cs *pkix.Name)
|
|||
return ""
|
||||
}
|
||||
|
||||
func writeParts(content io.StringWriter, entries []string, prefix string) {
|
||||
func writeParts(ctx context.Context, content io.StringWriter, entries []string, prefix string) {
|
||||
for _, entry := range entries {
|
||||
writePart(content, entry, prefix)
|
||||
writePart(ctx, content, entry, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func writePart(content io.StringWriter, entry string, prefix string) {
|
||||
func writePart(ctx context.Context, content io.StringWriter, entry string, prefix string) {
|
||||
if len(entry) > 0 {
|
||||
_, err := content.WriteString(fmt.Sprintf("%s=%s,", prefix, entry))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
log.FromContext(ctx).Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCertInfo Build a string with the wanted client certificates information
|
||||
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
|
||||
func (p *passTLSClientCert) getXForwardedTLSClientCertInfo(certs []*x509.Certificate) string {
|
||||
func (p *passTLSClientCert) getXForwardedTLSClientCertInfo(ctx context.Context, certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
|
@ -181,12 +182,12 @@ func (p *passTLSClientCert) getXForwardedTLSClientCertInfo(certs []*x509.Certifi
|
|||
var na string
|
||||
|
||||
if p.info != nil {
|
||||
subject := getDNInfo("Subject", p.info.subject, &peerCert.Subject)
|
||||
subject := getDNInfo(ctx, "Subject", p.info.subject, &peerCert.Subject)
|
||||
if len(subject) > 0 {
|
||||
values = append(values, subject)
|
||||
}
|
||||
|
||||
issuer := getDNInfo("Issuer", p.info.issuer, &peerCert.Issuer)
|
||||
issuer := getDNInfo(ctx, "Issuer", p.info.issuer, &peerCert.Issuer)
|
||||
if len(issuer) > 0 {
|
||||
values = append(values, issuer)
|
||||
}
|
||||
|
@ -217,10 +218,12 @@ func (p *passTLSClientCert) getXForwardedTLSClientCertInfo(certs []*x509.Certifi
|
|||
}
|
||||
|
||||
// modifyRequestHeaders set the wanted headers with the certificates information.
|
||||
func (p *passTLSClientCert) modifyRequestHeaders(logger logrus.FieldLogger, r *http.Request) {
|
||||
func (p *passTLSClientCert) modifyRequestHeaders(ctx context.Context, r *http.Request) {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
if p.pem {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(logger, r.TLS.PeerCertificates))
|
||||
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(ctx, r.TLS.PeerCertificates))
|
||||
} else {
|
||||
logger.Warn("Tried to extract a certificate on a request without mutual TLS")
|
||||
}
|
||||
|
@ -228,7 +231,7 @@ func (p *passTLSClientCert) modifyRequestHeaders(logger logrus.FieldLogger, r *h
|
|||
|
||||
if p.info != nil {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
headerContent := p.getXForwardedTLSClientCertInfo(r.TLS.PeerCertificates)
|
||||
headerContent := p.getXForwardedTLSClientCertInfo(ctx, r.TLS.PeerCertificates)
|
||||
r.Header.Set(xForwardedTLSClientCertInfo, url.QueryEscape(headerContent))
|
||||
} else {
|
||||
logger.Warn("Tried to extract a certificate on a request without mutual TLS")
|
||||
|
@ -248,22 +251,22 @@ func sanitize(cert []byte) string {
|
|||
}
|
||||
|
||||
// extractCertificate extract the certificate from the request.
|
||||
func extractCertificate(logger logrus.FieldLogger, cert *x509.Certificate) string {
|
||||
func extractCertificate(ctx context.Context, cert *x509.Certificate) string {
|
||||
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
|
||||
certPEM := pem.EncodeToMemory(&b)
|
||||
if certPEM == nil {
|
||||
logger.Error("Cannot extract the certificate content")
|
||||
log.FromContext(ctx).Error("Cannot extract the certificate content")
|
||||
return ""
|
||||
}
|
||||
return sanitize(certPEM)
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCert Build a string with the client certificates.
|
||||
func getXForwardedTLSClientCert(logger logrus.FieldLogger, certs []*x509.Certificate) string {
|
||||
func getXForwardedTLSClientCert(ctx context.Context, certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
headerValues = append(headerValues, extractCertificate(logger, peerCert))
|
||||
headerValues = append(headerValues, extractCertificate(ctx, peerCert))
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ",")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue