Handle context canceled in ForwardAuth middleware
This commit is contained in:
parent
bf72b9768c
commit
2949995abc
2 changed files with 78 additions and 1 deletions
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/traefik/traefik/v3/pkg/middlewares"
|
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||||
"github.com/traefik/traefik/v3/pkg/middlewares/accesslog"
|
"github.com/traefik/traefik/v3/pkg/middlewares/accesslog"
|
||||||
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||||
|
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||||
"github.com/traefik/traefik/v3/pkg/tracing"
|
"github.com/traefik/traefik/v3/pkg/tracing"
|
||||||
"github.com/traefik/traefik/v3/pkg/types"
|
"github.com/traefik/traefik/v3/pkg/types"
|
||||||
"github.com/vulcand/oxy/v2/forward"
|
"github.com/vulcand/oxy/v2/forward"
|
||||||
|
|
@ -195,7 +196,12 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
logger.Debug().Err(forwardErr).Msgf("Error calling %s", fa.address)
|
logger.Debug().Err(forwardErr).Msgf("Error calling %s", fa.address)
|
||||||
observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause: %s", fa.address, forwardErr)
|
observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause: %s", fa.address, forwardErr)
|
||||||
|
|
||||||
rw.WriteHeader(http.StatusInternalServerError)
|
statusCode := http.StatusInternalServerError
|
||||||
|
if errors.Is(forwardErr, context.Canceled) {
|
||||||
|
statusCode = httputil.StatusClientClosedRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
rw.WriteHeader(statusCode)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer forwardResponse.Body.Close()
|
defer forwardResponse.Body.Close()
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,12 @@ import (
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||||
|
"github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||||
"github.com/traefik/traefik/v3/pkg/tracing"
|
"github.com/traefik/traefik/v3/pkg/tracing"
|
||||||
"github.com/vulcand/oxy/v2/forward"
|
"github.com/vulcand/oxy/v2/forward"
|
||||||
|
|
@ -408,6 +410,75 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) {
|
||||||
assert.Equal(t, "Forbidden\n", string(body))
|
assert.Equal(t, "Forbidden\n", string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestForwardAuthClientClosedRequest(t *testing.T) {
|
||||||
|
requestStarted := make(chan struct{})
|
||||||
|
requestCancelled := make(chan struct{})
|
||||||
|
responseComplete := make(chan struct{})
|
||||||
|
|
||||||
|
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
close(requestStarted)
|
||||||
|
<-requestCancelled
|
||||||
|
}))
|
||||||
|
t.Cleanup(authTs.Close)
|
||||||
|
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// next should not be called.
|
||||||
|
t.Fail()
|
||||||
|
})
|
||||||
|
|
||||||
|
auth := dynamic.ForwardAuth{
|
||||||
|
Address: authTs.URL,
|
||||||
|
}
|
||||||
|
authMiddleware, err := NewForward(t.Context(), next, auth, "authTest")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(t.Context())
|
||||||
|
req := httptest.NewRequestWithContext(ctx, "GET", "http://foo", http.NoBody)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
go func() {
|
||||||
|
authMiddleware.ServeHTTP(recorder, req)
|
||||||
|
close(responseComplete)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-requestStarted
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
close(requestCancelled)
|
||||||
|
|
||||||
|
<-responseComplete
|
||||||
|
|
||||||
|
assert.Equal(t, httputil.StatusClientClosedRequest, recorder.Result().StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardAuthForwardError(t *testing.T) {
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// next should not be called.
|
||||||
|
t.Fail()
|
||||||
|
})
|
||||||
|
|
||||||
|
auth := dynamic.ForwardAuth{
|
||||||
|
Address: "http://non-existing-server",
|
||||||
|
}
|
||||||
|
authMiddleware, err := NewForward(t.Context(), next, auth, "authTest")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(t.Context(), 1*time.Microsecond)
|
||||||
|
defer cancel()
|
||||||
|
req := httptest.NewRequestWithContext(ctx, http.MethodGet, "http://foo", nil)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
responseComplete := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
authMiddleware.ServeHTTP(recorder, req)
|
||||||
|
close(responseComplete)
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-responseComplete
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, recorder.Result().StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
func Test_writeHeader(t *testing.T) {
|
func Test_writeHeader(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue