From 2949995abc523a6edff6f4b695dbf5b91ef5a05e Mon Sep 17 00:00:00 2001 From: Ben <52796471+bengentree@users.noreply.github.com> Date: Wed, 4 Jun 2025 07:38:04 -0600 Subject: [PATCH] Handle context canceled in ForwardAuth middleware --- pkg/middlewares/auth/forward.go | 8 +++- pkg/middlewares/auth/forward_test.go | 71 ++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/pkg/middlewares/auth/forward.go b/pkg/middlewares/auth/forward.go index 0ade40b7d..8520e2357 100644 --- a/pkg/middlewares/auth/forward.go +++ b/pkg/middlewares/auth/forward.go @@ -17,6 +17,7 @@ import ( "github.com/traefik/traefik/v3/pkg/middlewares" "github.com/traefik/traefik/v3/pkg/middlewares/accesslog" "github.com/traefik/traefik/v3/pkg/middlewares/observability" + "github.com/traefik/traefik/v3/pkg/proxy/httputil" "github.com/traefik/traefik/v3/pkg/tracing" "github.com/traefik/traefik/v3/pkg/types" "github.com/vulcand/oxy/v2/forward" @@ -195,7 +196,12 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) { logger.Debug().Err(forwardErr).Msgf("Error calling %s", fa.address) observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause: %s", fa.address, forwardErr) - rw.WriteHeader(http.StatusInternalServerError) + statusCode := http.StatusInternalServerError + if errors.Is(forwardErr, context.Canceled) { + statusCode = httputil.StatusClientClosedRequest + } + + rw.WriteHeader(statusCode) return } defer forwardResponse.Body.Close() diff --git a/pkg/middlewares/auth/forward_test.go b/pkg/middlewares/auth/forward_test.go index aa4503381..2bd062370 100644 --- a/pkg/middlewares/auth/forward_test.go +++ b/pkg/middlewares/auth/forward_test.go @@ -11,10 +11,12 @@ import ( "net/url" "strconv" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/traefik/traefik/v3/pkg/config/dynamic" + "github.com/traefik/traefik/v3/pkg/proxy/httputil" "github.com/traefik/traefik/v3/pkg/testhelpers" "github.com/traefik/traefik/v3/pkg/tracing" "github.com/vulcand/oxy/v2/forward" @@ -408,6 +410,75 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) { assert.Equal(t, "Forbidden\n", string(body)) } +func TestForwardAuthClientClosedRequest(t *testing.T) { + requestStarted := make(chan struct{}) + requestCancelled := make(chan struct{}) + responseComplete := make(chan struct{}) + + authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + close(requestStarted) + <-requestCancelled + })) + t.Cleanup(authTs.Close) + + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // next should not be called. + t.Fail() + }) + + auth := dynamic.ForwardAuth{ + Address: authTs.URL, + } + authMiddleware, err := NewForward(t.Context(), next, auth, "authTest") + require.NoError(t, err) + + ctx, cancel := context.WithCancel(t.Context()) + req := httptest.NewRequestWithContext(ctx, "GET", "http://foo", http.NoBody) + + recorder := httptest.NewRecorder() + go func() { + authMiddleware.ServeHTTP(recorder, req) + close(responseComplete) + }() + + <-requestStarted + + cancel() + close(requestCancelled) + + <-responseComplete + + assert.Equal(t, httputil.StatusClientClosedRequest, recorder.Result().StatusCode) +} + +func TestForwardAuthForwardError(t *testing.T) { + next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // next should not be called. + t.Fail() + }) + + auth := dynamic.ForwardAuth{ + Address: "http://non-existing-server", + } + authMiddleware, err := NewForward(t.Context(), next, auth, "authTest") + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(t.Context(), 1*time.Microsecond) + defer cancel() + req := httptest.NewRequestWithContext(ctx, http.MethodGet, "http://foo", nil) + + recorder := httptest.NewRecorder() + responseComplete := make(chan struct{}) + go func() { + authMiddleware.ServeHTTP(recorder, req) + close(responseComplete) + }() + + <-responseComplete + + assert.Equal(t, http.StatusInternalServerError, recorder.Result().StatusCode) +} + func Test_writeHeader(t *testing.T) { testCases := []struct { name string