avoid retries when any data was written to the backend
This commit is contained in:
parent
586ba31120
commit
e31c85aace
5 changed files with 161 additions and 250 deletions
|
@ -1,40 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/middlewares"
|
||||
)
|
||||
|
||||
// RecordingErrorHandler is an error handler, implementing the vulcand/oxy
|
||||
// error handler interface, which is recording network errors by using the netErrorRecorder.
|
||||
// In addition it sets a proper HTTP status code and body, depending on the type of error occurred.
|
||||
type RecordingErrorHandler struct {
|
||||
netErrorRecorder middlewares.NetErrorRecorder
|
||||
}
|
||||
|
||||
// NewRecordingErrorHandler creates and returns a new instance of RecordingErrorHandler.
|
||||
func NewRecordingErrorHandler(recorder middlewares.NetErrorRecorder) *RecordingErrorHandler {
|
||||
return &RecordingErrorHandler{recorder}
|
||||
}
|
||||
|
||||
func (eh *RecordingErrorHandler) ServeHTTP(w http.ResponseWriter, req *http.Request, err error) {
|
||||
statusCode := http.StatusInternalServerError
|
||||
|
||||
if e, ok := err.(net.Error); ok {
|
||||
eh.netErrorRecorder.Record(req.Context())
|
||||
if e.Timeout() {
|
||||
statusCode = http.StatusGatewayTimeout
|
||||
} else {
|
||||
statusCode = http.StatusBadGateway
|
||||
}
|
||||
} else if err == io.EOF {
|
||||
eh.netErrorRecorder.Record(req.Context())
|
||||
statusCode = http.StatusBadGateway
|
||||
}
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
w.Write([]byte(http.StatusText(statusCode)))
|
||||
}
|
|
@ -1,88 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type timeoutError struct{}
|
||||
|
||||
func (e *timeoutError) Error() string { return "i/o timeout" }
|
||||
func (e *timeoutError) Timeout() bool { return true }
|
||||
func (e *timeoutError) Temporary() bool { return true }
|
||||
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantHTTPStatus int
|
||||
wantNetErrRecorded bool
|
||||
}{
|
||||
{
|
||||
name: "net.Error",
|
||||
err: net.UnknownNetworkError("any network error"),
|
||||
wantHTTPStatus: http.StatusBadGateway,
|
||||
wantNetErrRecorded: true,
|
||||
},
|
||||
{
|
||||
name: "net.Error with Timeout",
|
||||
err: &timeoutError{},
|
||||
wantHTTPStatus: http.StatusGatewayTimeout,
|
||||
wantNetErrRecorded: true,
|
||||
},
|
||||
{
|
||||
name: "io.EOF",
|
||||
err: io.EOF,
|
||||
wantHTTPStatus: http.StatusBadGateway,
|
||||
wantNetErrRecorded: true,
|
||||
},
|
||||
{
|
||||
name: "custom error",
|
||||
err: errors.New("any error"),
|
||||
wantHTTPStatus: http.StatusInternalServerError,
|
||||
wantNetErrRecorded: false,
|
||||
},
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
wantHTTPStatus: http.StatusInternalServerError,
|
||||
wantNetErrRecorded: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
errorRecorder := &netErrorRecorder{}
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/any", nil)
|
||||
|
||||
recordingErrorHandler := NewRecordingErrorHandler(errorRecorder)
|
||||
recordingErrorHandler.ServeHTTP(recorder, req, test.err)
|
||||
|
||||
if recorder.Code != test.wantHTTPStatus {
|
||||
t.Errorf("got HTTP status code %v, wanted %v", recorder.Code, test.wantHTTPStatus)
|
||||
}
|
||||
if errorRecorder.netErrorWasRecorded != test.wantNetErrRecorded {
|
||||
t.Errorf("net error recording wrong, got %v wanted %v", errorRecorder.netErrorWasRecorded, test.wantNetErrRecorded)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type netErrorRecorder struct {
|
||||
netErrorWasRecorded bool
|
||||
}
|
||||
|
||||
func (recorder *netErrorRecorder) Record(ctx context.Context) {
|
||||
recorder.netErrorWasRecorded = true
|
||||
}
|
|
@ -23,7 +23,6 @@ import (
|
|||
"github.com/eapache/channels"
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// loadConfiguration manages dynamically frontends, backends and TLS configurations
|
||||
|
@ -80,7 +79,6 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
|
|||
}
|
||||
|
||||
serverEntryPoints := s.buildServerEntryPoints()
|
||||
errorHandler := NewRecordingErrorHandler(middlewares.DefaultNetErrorRecorder{})
|
||||
|
||||
backendsHandlers := map[string]http.Handler{}
|
||||
backendsHealthCheck := map[string]*healthcheck.BackendConfig{}
|
||||
|
@ -92,7 +90,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
|
|||
|
||||
for _, frontendName := range frontendNames {
|
||||
frontendPostConfigs, err := s.loadFrontendConfig(providerName, frontendName, config,
|
||||
redirectHandlers, serverEntryPoints, errorHandler,
|
||||
redirectHandlers, serverEntryPoints,
|
||||
backendsHandlers, backendsHealthCheck)
|
||||
if err != nil {
|
||||
log.Errorf("%v. Skipping frontend %s...", err, frontendName)
|
||||
|
@ -131,7 +129,7 @@ func (s *Server) loadConfig(configurations types.Configurations, globalConfigura
|
|||
|
||||
func (s *Server) loadFrontendConfig(
|
||||
providerName string, frontendName string, config *types.Configuration,
|
||||
redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint, errorHandler *RecordingErrorHandler,
|
||||
redirectHandlers map[string]negroni.Handler, serverEntryPoints map[string]*serverEntryPoint,
|
||||
backendsHandlers map[string]http.Handler, backendsHealthCheck map[string]*healthcheck.BackendConfig,
|
||||
) ([]handlerPostConfig, error) {
|
||||
|
||||
|
@ -170,7 +168,7 @@ func (s *Server) loadFrontendConfig(
|
|||
postConfigs = append(postConfigs, postConfig)
|
||||
}
|
||||
|
||||
fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, errorHandler, responseModifier)
|
||||
fwd, err := s.buildForwarder(entryPointName, entryPoint, frontendName, frontend, responseModifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create the forwarder for frontend %s: %v", frontendName, err)
|
||||
}
|
||||
|
@ -222,7 +220,7 @@ func (s *Server) loadFrontendConfig(
|
|||
|
||||
func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration.EntryPoint,
|
||||
frontendName string, frontend *types.Frontend,
|
||||
errorHandler utils.ErrorHandler, responseModifier modifyResponse) (http.Handler, error) {
|
||||
responseModifier modifyResponse) (http.Handler, error) {
|
||||
|
||||
roundTripper, err := s.getRoundTripper(entryPointName, frontend.PassTLSCert, entryPoint.TLS)
|
||||
if err != nil {
|
||||
|
@ -239,7 +237,6 @@ func (s *Server) buildForwarder(entryPointName string, entryPoint *configuration
|
|||
forward.Stream(true),
|
||||
forward.PassHostHeader(frontend.PassHostHeader),
|
||||
forward.RoundTripper(roundTripper),
|
||||
forward.ErrorHandler(errorHandler),
|
||||
forward.Rewriter(rewriter),
|
||||
forward.ResponseModifier(responseModifier),
|
||||
forward.BufferPool(s.bufferPool),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue