diff --git a/pkg/middlewares/accesslog/logger.go b/pkg/middlewares/accesslog/logger.go index ed85a8888..b5a533b8c 100644 --- a/pkg/middlewares/accesslog/logger.go +++ b/pkg/middlewares/accesslog/logger.go @@ -196,16 +196,6 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http }, } - defer func() { - if h.config.BufferingSize > 0 { - h.logHandlerChan <- handlerParams{ - logDataTable: logDataTable, - } - return - } - h.logTheRoundTrip(logDataTable) - }() - reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable)) core[RequestCount] = nextRequestCount() @@ -249,19 +239,30 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http return } + defer func() { + logDataTable.DownstreamResponse = downstreamResponse{ + headers: rw.Header().Clone(), + } + + logDataTable.DownstreamResponse.status = capt.StatusCode() + logDataTable.DownstreamResponse.size = capt.ResponseSize() + logDataTable.Request.size = capt.RequestSize() + + if _, ok := core[ClientUsername]; !ok { + core[ClientUsername] = usernameIfPresent(reqWithDataTable.URL) + } + + if h.config.BufferingSize > 0 { + h.logHandlerChan <- handlerParams{ + logDataTable: logDataTable, + } + return + } + + h.logTheRoundTrip(logDataTable) + }() + next.ServeHTTP(rw, reqWithDataTable) - - if _, ok := core[ClientUsername]; !ok { - core[ClientUsername] = usernameIfPresent(reqWithDataTable.URL) - } - - logDataTable.DownstreamResponse = downstreamResponse{ - headers: rw.Header().Clone(), - } - - logDataTable.DownstreamResponse.status = capt.StatusCode() - logDataTable.DownstreamResponse.size = capt.ResponseSize() - logDataTable.Request.size = capt.RequestSize() } // Close closes the Logger (i.e. the file, drain logHandlerChan, etc). diff --git a/pkg/middlewares/accesslog/logger_test.go b/pkg/middlewares/accesslog/logger_test.go index d205bb9d7..01433c5f6 100644 --- a/pkg/middlewares/accesslog/logger_test.go +++ b/pkg/middlewares/accesslog/logger_test.go @@ -2,6 +2,7 @@ package accesslog import ( "bytes" + "context" "crypto/tls" "encoding/json" "fmt" @@ -22,6 +23,7 @@ import ( "github.com/stretchr/testify/require" ptypes "github.com/traefik/paerser/types" "github.com/traefik/traefik/v2/pkg/middlewares/capture" + "github.com/traefik/traefik/v2/pkg/middlewares/recovery" "github.com/traefik/traefik/v2/pkg/types" ) @@ -519,6 +521,64 @@ func TestLoggerJSON(t *testing.T) { } } +func TestLogger_AbortedRequest(t *testing.T) { + expected := map[string]func(t *testing.T, value interface{}){ + RequestContentSize: assertFloat64(0), + RequestHost: assertString(testHostname), + RequestAddr: assertString(testHostname), + RequestMethod: assertString(testMethod), + RequestPath: assertString(""), + RequestProtocol: assertString(testProto), + RequestScheme: assertString(testScheme), + RequestPort: assertString("-"), + DownstreamStatus: assertFloat64(float64(200)), + DownstreamContentSize: assertFloat64(float64(40)), + RequestRefererHeader: assertString(testReferer), + RequestUserAgentHeader: assertString(testUserAgent), + ServiceURL: assertString("http://stream"), + ServiceAddr: assertString("127.0.0.1"), + ServiceName: assertString("stream"), + ClientUsername: assertString(testUsername), + ClientHost: assertString(testHostname), + ClientPort: assertString(strconv.Itoa(testPort)), + ClientAddr: assertString(fmt.Sprintf("%s:%d", testHostname, testPort)), + "level": assertString("info"), + "msg": assertString(""), + RequestCount: assertFloat64NotZero(), + Duration: assertFloat64NotZero(), + Overhead: assertFloat64NotZero(), + RetryAttempts: assertFloat64(float64(0)), + "time": assertNotEmpty(), + StartLocal: assertNotEmpty(), + StartUTC: assertNotEmpty(), + "downstream_Content-Type": assertString("text/plain"), + "downstream_Transfer-Encoding": assertString("chunked"), + "downstream_Cache-Control": assertString("no-cache"), + } + + config := &types.AccessLog{ + FilePath: filepath.Join(t.TempDir(), logFileNameSuffix), + Format: JSONFormat, + } + doLoggingWithAbortedStream(t, config) + + logData, err := os.ReadFile(config.FilePath) + require.NoError(t, err) + + jsonData := make(map[string]interface{}) + err = json.Unmarshal(logData, &jsonData) + require.NoError(t, err) + + assert.Equal(t, len(expected), len(jsonData)) + + for field, assertion := range expected { + assertion(t, jsonData[field]) + if t.Failed() { + return + } + } +} + func TestNewLogHandlerOutputStdout(t *testing.T) { testCases := []struct { desc string @@ -852,3 +912,89 @@ func logWriterTestHandlerFunc(rw http.ResponseWriter, r *http.Request) { rw.WriteHeader(testStatus) } + +func doLoggingWithAbortedStream(t *testing.T, config *types.AccessLog) { + t.Helper() + + logger, err := NewHandler(config) + require.NoError(t, err) + t.Cleanup(func() { + err := logger.Close() + require.NoError(t, err) + }) + + if config.FilePath != "" { + _, err = os.Stat(config.FilePath) + require.NoError(t, err, "logger should create "+config.FilePath) + } + + reqContext, cancelRequest := context.WithCancel(context.Background()) + + req := &http.Request{ + Header: map[string][]string{ + "User-Agent": {testUserAgent}, + "Referer": {testReferer}, + }, + Proto: testProto, + Host: testHostname, + Method: testMethod, + RemoteAddr: fmt.Sprintf("%s:%d", testHostname, testPort), + URL: &url.URL{ + User: url.UserPassword(testUsername, ""), + }, + Body: nil, + } + + req = req.WithContext(reqContext) + + chain := alice.New() + chain = chain.Append(func(next http.Handler) (http.Handler, error) { + return recovery.New(context.Background(), next) + }) + chain = chain.Append(capture.Wrap) + chain = chain.Append(WrapHandler(logger)) + + service := NewFieldHandler(http.HandlerFunc(streamBackend), ServiceURL, "http://stream", nil) + service = NewFieldHandler(service, ServiceAddr, "127.0.0.1", nil) + service = NewFieldHandler(service, ServiceName, "stream", AddServiceFields) + + handler, err := chain.Then(service) + require.NoError(t, err) + + go func() { + time.Sleep(499 * time.Millisecond) + cancelRequest() + }() + + handler.ServeHTTP(httptest.NewRecorder(), req) +} + +func streamBackend(rw http.ResponseWriter, r *http.Request) { + // Get the Flusher to flush the response to the client + flusher, ok := rw.(http.Flusher) + if !ok { + http.Error(rw, "Streaming unsupported!", http.StatusInternalServerError) + return + } + + // Set the headers for streaming + rw.Header().Set("Content-Type", "text/plain") + rw.Header().Set("Transfer-Encoding", "chunked") + rw.Header().Set("Cache-Control", "no-cache") + + for { + time.Sleep(100 * time.Millisecond) + + select { + case <-r.Context().Done(): + panic(http.ErrAbortHandler) + + default: + if _, err := fmt.Fprint(rw, "FOOBAR!!!!"); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + flusher.Flush() + } + } +}