Cleanup Connection headers before passing the middleware chain
Co-authored-by: Romain <rtribotte@users.noreply.github.com>
This commit is contained in:
parent
0cf2032c15
commit
5841441005
15 changed files with 475 additions and 31 deletions
46
pkg/middlewares/auth/connectionheader.go
Normal file
46
pkg/middlewares/auth/connectionheader.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
const (
|
||||
connectionHeader = "Connection"
|
||||
upgradeHeader = "Upgrade"
|
||||
)
|
||||
|
||||
// Remover removes hop-by-hop headers listed in the "Connection" header.
|
||||
// See RFC 7230, section 6.1.
|
||||
func Remover(next http.Handler) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, req *http.Request) {
|
||||
var reqUpType string
|
||||
if httpguts.HeaderValuesContainsToken(req.Header[connectionHeader], upgradeHeader) {
|
||||
reqUpType = req.Header.Get(upgradeHeader)
|
||||
}
|
||||
|
||||
removeConnectionHeaders(req.Header)
|
||||
|
||||
if reqUpType != "" {
|
||||
req.Header.Set(connectionHeader, upgradeHeader)
|
||||
req.Header.Set(upgradeHeader, reqUpType)
|
||||
} else {
|
||||
req.Header.Del(connectionHeader)
|
||||
}
|
||||
|
||||
next.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func removeConnectionHeaders(h http.Header) {
|
||||
for _, f := range h[connectionHeader] {
|
||||
for _, sf := range strings.Split(f, ",") {
|
||||
if sf = textproto.TrimString(sf); sf != "" {
|
||||
h.Del(sf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
70
pkg/middlewares/auth/connectionheader_test.go
Normal file
70
pkg/middlewares/auth/connectionheader_test.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRemover(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
reqHeaders map[string]string
|
||||
expected http.Header
|
||||
}{
|
||||
{
|
||||
desc: "simple remove",
|
||||
reqHeaders: map[string]string{
|
||||
"Foo": "bar",
|
||||
connectionHeader: "foo",
|
||||
},
|
||||
expected: http.Header{},
|
||||
},
|
||||
{
|
||||
desc: "remove and Upgrade",
|
||||
reqHeaders: map[string]string{
|
||||
upgradeHeader: "test",
|
||||
"Foo": "bar",
|
||||
connectionHeader: "Upgrade,foo",
|
||||
},
|
||||
expected: http.Header{
|
||||
upgradeHeader: []string{"test"},
|
||||
connectionHeader: []string{"Upgrade"},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no remove",
|
||||
reqHeaders: map[string]string{
|
||||
"Foo": "bar",
|
||||
connectionHeader: "fii",
|
||||
},
|
||||
expected: http.Header{
|
||||
"Foo": []string{"bar"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})
|
||||
|
||||
h := Remover(next)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "https://localhost", nil)
|
||||
|
||||
for k, v := range test.reqHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, test.expected, req.Header)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -15,7 +15,6 @@ import (
|
|||
"github.com/traefik/traefik/v2/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v2/pkg/log"
|
||||
"github.com/traefik/traefik/v2/pkg/middlewares"
|
||||
"github.com/traefik/traefik/v2/pkg/middlewares/connectionheader"
|
||||
"github.com/traefik/traefik/v2/pkg/tracing"
|
||||
"github.com/vulcand/oxy/v2/forward"
|
||||
"github.com/vulcand/oxy/v2/utils"
|
||||
|
@ -90,7 +89,7 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
|
|||
fa.authResponseHeadersRegex = re
|
||||
}
|
||||
|
||||
return connectionheader.Remover(fa), nil
|
||||
return Remover(fa), nil
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue