diff --git a/pkg/middlewares/customerrors/custom_errors.go b/pkg/middlewares/customerrors/custom_errors.go index a2c53935f..3cc0f494d 100644 --- a/pkg/middlewares/customerrors/custom_errors.go +++ b/pkg/middlewares/customerrors/custom_errors.go @@ -2,7 +2,6 @@ package customerrors import ( "bufio" - "bytes" "context" "fmt" "net" @@ -12,7 +11,6 @@ import ( "strings" "github.com/opentracing/opentracing-go/ext" - "github.com/sirupsen/logrus" "github.com/traefik/traefik/v2/pkg/config/dynamic" "github.com/traefik/traefik/v2/pkg/log" "github.com/traefik/traefik/v2/pkg/middlewares" @@ -23,7 +21,8 @@ import ( // Compile time validation that the response recorder implements http interfaces correctly. var ( - _ middlewares.Stateful = &responseRecorderWithCloseNotify{} + // TODO: maybe remove at least for codeModifierWithCloseNotify. + _ middlewares.Stateful = &codeModifierWithCloseNotify{} _ middlewares.Stateful = &codeCatcherWithCloseNotify{} ) @@ -88,44 +87,25 @@ func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // check the recorder code against the configured http status code ranges code := catcher.getCode() - for _, block := range c.httpCodeRanges { - if code < block[0] || code > block[1] { - continue - } + logger.Debugf("Caught HTTP Status Code %d, returning error page", code) - logger.Debugf("Caught HTTP Status Code %d, returning error page", code) - - var query string - if len(c.backendQuery) > 0 { - query = "/" + strings.TrimPrefix(c.backendQuery, "/") - query = strings.ReplaceAll(query, "{status}", strconv.Itoa(code)) - } - - pageReq, err := newRequest("http://" + req.Host + query) - if err != nil { - logger.Error(err) - rw.WriteHeader(code) - _, err = fmt.Fprint(rw, http.StatusText(code)) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } - return - } - - recorderErrorPage := newResponseRecorder(ctx, rw) - utils.CopyHeaders(pageReq.Header, req.Header) - - c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context())) - - utils.CopyHeaders(rw.Header(), recorderErrorPage.Header()) - rw.WriteHeader(code) - - if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil { - logger.Error(err) - } + var query string + if len(c.backendQuery) > 0 { + query = "/" + strings.TrimPrefix(c.backendQuery, "/") + query = strings.ReplaceAll(query, "{status}", strconv.Itoa(code)) + } + pageReq, err := newRequest("http://" + req.Host + query) + if err != nil { + logger.Error(err) + http.Error(rw, http.StatusText(code), code) return } + + utils.CopyHeaders(pageReq.Header, req.Header) + + c.backendHandler.ServeHTTP(newCodeModifier(rw, code), + pageReq.WithContext(req.Context())) } func newRequest(baseURL string) (*http.Request, error) { @@ -269,106 +249,85 @@ func (cc *codeCatcher) Flush() { } } -type responseRecorder interface { +// codeModifier forwards a response back to the client, +// while enforcing a given response code. +type codeModifier interface { http.ResponseWriter - http.Flusher - GetCode() int - GetBody() *bytes.Buffer - IsStreamingResponseStarted() bool } -// newResponseRecorder returns an initialized responseRecorder. -func newResponseRecorder(ctx context.Context, rw http.ResponseWriter) responseRecorder { - recorder := &responseRecorderWithoutCloseNotify{ - HeaderMap: make(http.Header), - Body: new(bytes.Buffer), - Code: http.StatusOK, +// newCodeModifier returns a codeModifier that enforces the given code. +func newCodeModifier(rw http.ResponseWriter, code int) codeModifier { + codeMod := &codeModifierWithoutCloseNotify{ + headerMap: make(http.Header), + code: code, responseWriter: rw, - logger: log.FromContext(ctx), } if _, ok := rw.(http.CloseNotifier); ok { - return &responseRecorderWithCloseNotify{recorder} + return &codeModifierWithCloseNotify{codeMod} } - return recorder + return codeMod } -// responseRecorderWithoutCloseNotify is an implementation of http.ResponseWriter that -// records its mutations for later inspection. -type responseRecorderWithoutCloseNotify struct { - Code int // the HTTP response code from WriteHeader - HeaderMap http.Header // the HTTP response headers - Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to +type codeModifierWithoutCloseNotify struct { + code int // the code enforced in the response. - responseWriter http.ResponseWriter - err error - streamingResponseStarted bool - logger logrus.FieldLogger + // headerSent is whether the headers have already been sent, + // either through Write or WriteHeader. + headerSent bool + headerMap http.Header // the HTTP response headers from the backend. + + responseWriter http.ResponseWriter } -type responseRecorderWithCloseNotify struct { - *responseRecorderWithoutCloseNotify +type codeModifierWithCloseNotify struct { + *codeModifierWithoutCloseNotify } // CloseNotify returns a channel that receives at most a // single value (true) when the client connection has gone away. -func (r *responseRecorderWithCloseNotify) CloseNotify() <-chan bool { +func (r *codeModifierWithCloseNotify) CloseNotify() <-chan bool { return r.responseWriter.(http.CloseNotifier).CloseNotify() } // Header returns the response headers. -func (r *responseRecorderWithoutCloseNotify) Header() http.Header { - if r.HeaderMap == nil { - r.HeaderMap = make(http.Header) +func (r *codeModifierWithoutCloseNotify) Header() http.Header { + if r.headerMap == nil { + r.headerMap = make(http.Header) } - return r.HeaderMap + return r.headerMap } -func (r *responseRecorderWithoutCloseNotify) GetCode() int { - return r.Code +// Write calls WriteHeader to send the enforced code, +// then writes the data directly to r.responseWriter. +func (r *codeModifierWithoutCloseNotify) Write(buf []byte) (int, error) { + r.WriteHeader(r.code) + return r.responseWriter.Write(buf) } -func (r *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer { - return r.Body -} - -func (r *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool { - return r.streamingResponseStarted -} - -// Write always succeeds and writes to rw.Body, if not nil. -func (r *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) { - if r.err != nil { - return 0, r.err +// WriteHeader sends the headers, with the enforced code (the code in argument +// is always ignored), if it hasn't already been done. +func (r *codeModifierWithoutCloseNotify) WriteHeader(_ int) { + if r.headerSent { + return } - return r.Body.Write(buf) -} -// WriteHeader sets rw.Code. -func (r *responseRecorderWithoutCloseNotify) WriteHeader(code int) { - r.Code = code + utils.CopyHeaders(r.responseWriter.Header(), r.Header()) + r.responseWriter.WriteHeader(r.code) + r.headerSent = true } // Hijack hijacks the connection. -func (r *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return r.responseWriter.(http.Hijacker).Hijack() +func (r *codeModifierWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := r.responseWriter.(http.Hijacker) + if !ok { + return nil, nil, fmt.Errorf("%T is not a http.Hijacker", r.responseWriter) + } + return hijacker.Hijack() } // Flush sends any buffered data to the client. -func (r *responseRecorderWithoutCloseNotify) Flush() { - if !r.streamingResponseStarted { - utils.CopyHeaders(r.responseWriter.Header(), r.Header()) - r.responseWriter.WriteHeader(r.Code) - r.streamingResponseStarted = true - } - - _, err := r.responseWriter.Write(r.Body.Bytes()) - if err != nil { - r.logger.Errorf("Error writing response in responseRecorder: %v", err) - r.err = err - } - r.Body.Reset() - +func (r *codeModifierWithoutCloseNotify) Flush() { if flusher, ok := r.responseWriter.(http.Flusher); ok { flusher.Flush() } diff --git a/pkg/middlewares/customerrors/custom_errors_test.go b/pkg/middlewares/customerrors/custom_errors_test.go index 1e263d686..7b0a060ec 100644 --- a/pkg/middlewares/customerrors/custom_errors_test.go +++ b/pkg/middlewares/customerrors/custom_errors_test.go @@ -180,12 +180,12 @@ func TestNewResponseRecorder(t *testing.T) { { desc: "Without Close Notify", rw: httptest.NewRecorder(), - expected: &responseRecorderWithoutCloseNotify{}, + expected: &codeModifierWithoutCloseNotify{}, }, { desc: "With Close Notify", rw: &mockRWCloseNotify{}, - expected: &responseRecorderWithCloseNotify{}, + expected: &codeModifierWithCloseNotify{}, }, } @@ -194,7 +194,7 @@ func TestNewResponseRecorder(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - rec := newResponseRecorder(context.Background(), test.rw) + rec := newCodeModifier(test.rw, 0) assert.IsType(t, rec, test.expected) }) }