diff --git a/middlewares/auth/forward.go b/middlewares/auth/forward.go index 12e885fda..86c365e8c 100644 --- a/middlewares/auth/forward.go +++ b/middlewares/auth/forward.go @@ -73,6 +73,7 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode) utils.CopyHeaders(w.Header(), forwardResponse.Header) + utils.RemoveHeaders(w.Header(), forward.HopHeaders...) // Grab the location header, if any. redirectURL, err := forwardResponse.Location() diff --git a/middlewares/auth/forward_test.go b/middlewares/auth/forward_test.go index 72ce246b7..bf0ada6fb 100644 --- a/middlewares/auth/forward_test.go +++ b/middlewares/auth/forward_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/urfave/negroni" + "github.com/vulcand/oxy/forward" ) func TestForwardAuthFail(t *testing.T) { @@ -122,6 +123,59 @@ func TestForwardAuthRedirect(t *testing.T) { assert.NotEmpty(t, string(body), "there should be something in the body") } +func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) { + authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headers := w.Header() + for _, header := range forward.HopHeaders { + if header == forward.TransferEncoding { + headers.Add(header, "identity") + } else { + headers.Add(header, "test") + } + } + + http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound) + })) + defer authTs.Close() + + authMiddleware, err := NewAuthenticator(&types.Auth{ + Forward: &types.Forward{ + Address: authTs.URL, + }, + }, &tracing.Tracing{}) + assert.NoError(t, err, "there should be no error") + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "traefik") + }) + n := negroni.New(authMiddleware) + n.UseHandler(handler) + ts := httptest.NewServer(n) + defer ts.Close() + + client := &http.Client{ + CheckRedirect: func(r *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil) + res, err := client.Do(req) + assert.NoError(t, err, "there should be no error") + assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal") + + for _, header := range forward.HopHeaders { + assert.Equal(t, "", res.Header.Get(header), "hop-by-hop header '%s' mustn't be set", header) + } + + location, err := res.Location() + assert.NoError(t, err, "there should be no error") + assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal") + + body, err := ioutil.ReadAll(res.Body) + assert.NoError(t, err, "there should be no error") + assert.NotEmpty(t, string(body), "there should be something in the body") +} + func TestForwardAuthFailResponseHeaders(t *testing.T) { authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}