1
0
Fork 0

Merge current v2.11 into v3.1

This commit is contained in:
romain 2024-09-30 14:59:38 +02:00
commit b641d5cf2a
17 changed files with 50 additions and 128 deletions

View file

@ -13,22 +13,21 @@ const (
upgradeHeader = "Upgrade"
)
// Remover removes hop-by-hop headers listed in the "Connection" header.
// RemoveConnectionHeaders removes hop-by-hop headers listed in the "Connection" header.
// See RFC 7230, section 6.1.
func Remover(next http.Handler) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
next.ServeHTTP(rw, Remove(req))
}
}
// Remove removes hop-by-hop header on the request.
func Remove(req *http.Request) *http.Request {
func RemoveConnectionHeaders(req *http.Request) {
var reqUpType string
if httpguts.HeaderValuesContainsToken(req.Header[connectionHeader], upgradeHeader) {
reqUpType = req.Header.Get(upgradeHeader)
}
removeConnectionHeaders(req.Header)
for _, f := range req.Header[connectionHeader] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
req.Header.Del(sf)
}
}
}
if reqUpType != "" {
req.Header.Set(connectionHeader, upgradeHeader)
@ -36,16 +35,4 @@ func Remove(req *http.Request) *http.Request {
} else {
req.Header.Del(connectionHeader)
}
return req
}
func removeConnectionHeaders(h http.Header) {
for _, f := range h[connectionHeader] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
}
}
}
}

View file

@ -50,19 +50,13 @@ func TestRemover(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
h := Remover(next)
req := httptest.NewRequest(http.MethodGet, "https://localhost", nil)
for k, v := range test.reqHeaders {
req.Header.Set(k, v)
}
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
RemoveConnectionHeaders(req)
assert.Equal(t, test.expected, req.Header)
})

View file

@ -120,8 +120,6 @@ func (fa *forwardAuth) GetTracingInformation() (string, string, trace.SpanKind)
func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), fa.name, typeNameForward)
req = Remove(req)
forwardReq, err := http.NewRequestWithContext(req.Context(), http.MethodGet, fa.address, nil)
if err != nil {
logger.Debug().Msgf("Error calling %s. Cause %s", fa.address, err)
@ -262,6 +260,8 @@ func (fa *forwardAuth) buildModifier(authCookies []*http.Cookie) func(res *http.
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool, allowedHeaders []string) {
utils.CopyHeaders(forwardReq.Header, req.Header)
RemoveConnectionHeaders(forwardReq)
utils.RemoveHeaders(forwardReq.Header, hopHeaders...)
forwardReq.Header = filterForwardRequestHeaders(forwardReq.Header, allowedHeaders)

View file

@ -554,8 +554,11 @@ func (p *Provider) resolveDefaultCertificate(ctx context.Context, domains []stri
p.resolvingDomainsMutex.Lock()
sort.Strings(domains)
domainKey := strings.Join(domains, ",")
sortedDomains := make([]string, len(domains))
copy(sortedDomains, domains)
sort.Strings(sortedDomains)
domainKey := strings.Join(sortedDomains, ",")
if _, ok := p.resolvingDomains[domainKey]; ok {
p.resolvingDomainsMutex.Unlock()
@ -955,12 +958,14 @@ func (p *Provider) certExists(validDomains []string) bool {
p.certificatesMu.RLock()
defer p.certificatesMu.RUnlock()
sort.Strings(validDomains)
sortedDomains := make([]string, len(validDomains))
copy(sortedDomains, validDomains)
sort.Strings(sortedDomains)
for _, cert := range p.certificates {
domains := cert.Certificate.Domain.ToStrArray()
sort.Strings(domains)
if reflect.DeepEqual(domains, validDomains) {
if reflect.DeepEqual(domains, sortedDomains) {
return true
}
}