1
0
Fork 0

Clean connection headers for forward auth request only

Co-authored-by: Kevin Pollet <pollet.kevin@gmail.com>
This commit is contained in:
Romain 2024-09-27 11:18:05 +02:00 committed by GitHub
parent b00f640d72
commit a42d396ed2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 30 additions and 30 deletions

View file

@ -13,34 +13,26 @@ const (
upgradeHeader = "Upgrade"
)
// Remover removes hop-by-hop headers listed in the "Connection" header.
// RemoveConnectionHeaders removes hop-by-hop headers listed in the "Connection" header.
// See RFC 7230, section 6.1.
func Remover(next http.Handler) http.HandlerFunc {
return func(rw http.ResponseWriter, req *http.Request) {
var reqUpType string
if httpguts.HeaderValuesContainsToken(req.Header[connectionHeader], upgradeHeader) {
reqUpType = req.Header.Get(upgradeHeader)
}
removeConnectionHeaders(req.Header)
if reqUpType != "" {
req.Header.Set(connectionHeader, upgradeHeader)
req.Header.Set(upgradeHeader, reqUpType)
} else {
req.Header.Del(connectionHeader)
}
next.ServeHTTP(rw, req)
func RemoveConnectionHeaders(req *http.Request) {
var reqUpType string
if httpguts.HeaderValuesContainsToken(req.Header[connectionHeader], upgradeHeader) {
reqUpType = req.Header.Get(upgradeHeader)
}
}
func removeConnectionHeaders(h http.Header) {
for _, f := range h[connectionHeader] {
for _, f := range req.Header[connectionHeader] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
h.Del(sf)
req.Header.Del(sf)
}
}
}
if reqUpType != "" {
req.Header.Set(connectionHeader, upgradeHeader)
req.Header.Set(upgradeHeader, reqUpType)
} else {
req.Header.Del(connectionHeader)
}
}

View file

@ -50,19 +50,13 @@ func TestRemover(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
h := Remover(next)
req := httptest.NewRequest(http.MethodGet, "https://localhost", nil)
for k, v := range test.reqHeaders {
req.Header.Set(k, v)
}
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
RemoveConnectionHeaders(req)
assert.Equal(t, test.expected, req.Header)
})

View file

@ -89,7 +89,7 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
fa.authResponseHeadersRegex = re
}
return Remover(fa), nil
return fa, nil
}
func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
@ -195,6 +195,8 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool, allowedHeaders []string) {
utils.CopyHeaders(forwardReq.Header, req.Header)
RemoveConnectionHeaders(forwardReq)
utils.RemoveHeaders(forwardReq.Header, hopHeaders...)
forwardReq.Header = filterForwardRequestHeaders(forwardReq.Header, allowedHeaders)