diff --git a/pkg/proxy/fast/connpool.go b/pkg/proxy/fast/connpool.go index f6357c59d..7ef17e459 100644 --- a/pkg/proxy/fast/connpool.go +++ b/pkg/proxy/fast/connpool.go @@ -215,34 +215,20 @@ func (c *conn) handleResponse(r rwWithUpgrade) error { return nil } - hasContentLength := len(res.Header.Peek("Content-Length")) > 0 - - if hasContentLength && res.Header.ContentLength() == 0 { - return nil - } - // When a body is not allowed for a given status code the body is ignored. // The connection will be marked as broken by the next Peek in the readloop. if !isBodyAllowedForStatus(res.StatusCode()) { return nil } - if !hasContentLength { - b := c.bufferPool.Get() - if b == nil { - b = make([]byte, bufferSize) - } - defer c.bufferPool.Put(b) - - if _, err := io.CopyBuffer(r.RW, c.br, b); err != nil { - return err - } + contentLength := res.Header.ContentLength() + if contentLength == 0 { return nil } // Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received. - if res.Header.ContentLength() == -1 { + if contentLength == -1 { cbr := httputil.NewChunkedReader(c.br) b := c.bufferPool.Get() @@ -282,6 +268,23 @@ func (c *conn) handleResponse(r rwWithUpgrade) error { return nil } + // Response without Content-Length header. + // The message body length is determined by the number of bytes received prior to the server closing the connection. + if contentLength == -2 { + b := c.bufferPool.Get() + if b == nil { + b = make([]byte, bufferSize) + } + defer c.bufferPool.Put(b) + + if _, err := io.CopyBuffer(r.RW, c.br, b); err != nil { + return err + } + + return nil + } + + // Response with a valid Content-Length header. brl := c.limitedReaderPool.Get() if brl == nil { brl = &io.LimitedReader{} diff --git a/pkg/proxy/fast/proxy_test.go b/pkg/proxy/fast/proxy_test.go index 4c4c158b1..143039cc6 100644 --- a/pkg/proxy/fast/proxy_test.go +++ b/pkg/proxy/fast/proxy_test.go @@ -306,7 +306,7 @@ func TestHeadRequest(t *testing.T) { assert.Equal(t, http.StatusOK, res.Code) } -func TestNoContentLengthResponse(t *testing.T) { +func TestNoContentLength(t *testing.T) { backendListener, err := net.Listen("tcp", ":0") require.NoError(t, err) @@ -346,6 +346,45 @@ func TestNoContentLengthResponse(t *testing.T) { assert.Equal(t, "foo", res.Body.String()) } +func TestTransferEncodingChunked(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + flusher, ok := rw.(http.Flusher) + require.True(t, ok) + + for i := range 3 { + _, err := rw.Write([]byte(fmt.Sprintf("chunk %d\n", i))) + require.NoError(t, err) + + flusher.Flush() + } + })) + t.Cleanup(backendServer.Close) + + builder := NewProxyBuilder(&transportManagerMock{}, static.FastProxyConfig{}) + + proxyHandler, err := builder.Build("", testhelpers.MustParseURL(backendServer.URL), true, true) + require.NoError(t, err) + + proxyServer := httptest.NewServer(proxyHandler) + t.Cleanup(proxyServer.Close) + + req, err := http.NewRequest(http.MethodGet, proxyServer.URL, http.NoBody) + require.NoError(t, err) + + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + + t.Cleanup(func() { _ = res.Body.Close() }) + + assert.Equal(t, http.StatusOK, res.StatusCode) + assert.Equal(t, []string{"chunked"}, res.TransferEncoding) + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, "chunk 0\nchunk 1\nchunk 2\n", string(body)) +} + func newCertificate(t *testing.T, domain string) *tls.Certificate { t.Helper()