Retry should send headers on Write

Co-authored-by: Romain <rtribotte@users.noreply.github.com>
This commit is contained in:
Kevin Pollet 2025-02-21 10:52:04 +01:00 committed by GitHub
parent 8e5d4c6ae9
commit c2a294c872
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 40 additions and 18 deletions

View file

@ -196,6 +196,9 @@ func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
if r.ShouldRetry() { if r.ShouldRetry() {
return len(buf), nil return len(buf), nil
} }
if !r.written {
r.WriteHeader(http.StatusOK)
}
return r.responseWriter.Write(buf) return r.responseWriter.Write(buf)
} }

View file

@ -169,7 +169,6 @@ func TestRetryListeners(t *testing.T) {
func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) { func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) {
attempt := 0 attempt := 0
expectedHeaderName := "X-Foo-Test-2"
expectedHeaderValue := "bar" expectedHeaderValue := "bar"
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -180,44 +179,55 @@ func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) {
return return
} }
// Request has been successfully written to backend // Request has been successfully written to backend.
trace := httptrace.ContextClientTrace(req.Context()) trace := httptrace.ContextClientTrace(req.Context())
trace.WroteHeaders() trace.WroteHeaders()
// And we decide to answer to client // And we decide to answer to client.
rw.WriteHeader(http.StatusNoContent) rw.WriteHeader(http.StatusNoContent)
}) })
retry, err := New(context.Background(), next, dynamic.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest") retry, err := New(context.Background(), next, dynamic.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest")
require.NoError(t, err) require.NoError(t, err)
responseRecorder := httptest.NewRecorder() res := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody)) retry.ServeHTTP(res, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody))
headerValue := responseRecorder.Header().Get(expectedHeaderName) // The third header attempt is kept.
headerValue := res.Header().Get("X-Foo-Test-2")
// Validate if we have the correct header assert.Equal(t, expectedHeaderValue, headerValue)
if headerValue != expectedHeaderValue {
t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue)
}
// Validate that we don't have headers from previous attempts // Validate that we don't have headers from previous attempts
for i := range attempt { for i := range attempt {
headerName := fmt.Sprintf("X-Foo-Test-%d", i) headerName := fmt.Sprintf("X-Foo-Test-%d", i)
headerValue = responseRecorder.Header().Get("headerName") headerValue = res.Header().Get(headerName)
if headerValue != "" { if headerValue != "" {
t.Errorf("Expected no value for header %s, got %s", headerName, headerValue) t.Errorf("Expected no value for header %s, got %s", headerName, headerValue)
} }
} }
} }
// countingRetryListener is a Listener implementation to count the times the Retried fn is called. func TestRetryShouldNotLooseHeadersOnWrite(t *testing.T) {
type countingRetryListener struct { next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
timesCalled int rw.Header().Add("X-Foo-Test", "bar")
}
func (l *countingRetryListener) Retried(req *http.Request, attempt int) { // Request has been successfully written to backend.
l.timesCalled++ trace := httptrace.ContextClientTrace(req.Context())
trace.WroteHeaders()
// And we decide to answer to client without calling WriteHeader.
_, err := rw.Write([]byte("bar"))
require.NoError(t, err)
})
retry, err := New(context.Background(), next, dynamic.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest")
require.NoError(t, err)
res := httptest.NewRecorder()
retry.ServeHTTP(res, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody))
headerValue := res.Header().Get("X-Foo-Test")
assert.Equal(t, "bar", headerValue)
} }
func TestRetryWithFlush(t *testing.T) { func TestRetryWithFlush(t *testing.T) {
@ -387,3 +397,12 @@ func Test1xxResponses(t *testing.T) {
assert.Equal(t, 0, retryListener.timesCalled) assert.Equal(t, 0, retryListener.timesCalled)
} }
// countingRetryListener is a Listener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
timesCalled int
}
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
l.timesCalled++
}