Support Accept-Encoding header weights with Compress middleware

This commit is contained in:
Ludovic Fernandez 2024-06-06 16:42:04 +02:00 committed by GitHub
parent 359477c583
commit 778dc22e14
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 398 additions and 66 deletions

View file

@ -18,12 +18,9 @@ import (
)
const (
acceptEncodingHeader = "Accept-Encoding"
contentEncodingHeader = "Content-Encoding"
contentTypeHeader = "Content-Type"
varyHeader = "Vary"
gzipValue = "gzip"
brotliValue = "br"
)
func TestNegotiation(t *testing.T) {
@ -62,9 +59,9 @@ func TestNegotiation(t *testing.T) {
expEncoding: "br",
},
{
desc: "multi accept header, prefer br",
desc: "multi accept header, prefer gzip",
acceptEncHeader: "gzip;q=1.0, br;q=0.8",
expEncoding: "br",
expEncoding: "gzip",
},
{
desc: "multi accept header list, prefer br",
@ -98,7 +95,7 @@ func TestNegotiation(t *testing.T) {
func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
baseBody := generateBytes(gzhttp.DefaultMinSize)
@ -112,7 +109,7 @@ func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, gzipName, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
gr, err := gzip.NewReader(rw.Body)
@ -125,11 +122,11 @@ func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
fakeCompressedBody := generateBytes(gzhttp.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(contentEncodingHeader, gzipName)
rw.Header().Add(varyHeader, acceptEncodingHeader)
_, err := rw.Write(fakeCompressedBody)
if err != nil {
@ -142,7 +139,7 @@ func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, gzipName, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
assert.EqualValues(t, rw.Body.Bytes(), fakeCompressedBody)
@ -225,7 +222,7 @@ func TestShouldNotCompressWhenEmptyAcceptEncodingHeader(t *testing.T) {
func TestShouldNotCompressHeadRequest(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodHead, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
fakeBody := generateBytes(gzhttp.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
@ -301,7 +298,7 @@ func TestShouldNotCompressWhenSpecificContentType(t *testing.T) {
t.Parallel()
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
if test.reqContentType != "" {
req.Header.Add(contentTypeHeader, test.reqContentType)
}
@ -352,7 +349,7 @@ func TestShouldCompressWhenSpecificContentType(t *testing.T) {
t.Parallel()
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set(contentTypeHeader, test.respContentType)
@ -368,7 +365,7 @@ func TestShouldCompressWhenSpecificContentType(t *testing.T) {
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, gzipName, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
assert.NotEqualValues(t, rw.Body.Bytes(), baseBody)
})
@ -386,7 +383,7 @@ func TestIntegrationShouldNotCompress(t *testing.T) {
{
name: "when content already compressed",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(contentEncodingHeader, gzipName)
rw.Header().Add(varyHeader, acceptEncodingHeader)
_, err := rw.Write(fakeCompressedBody)
if err != nil {
@ -398,7 +395,7 @@ func TestIntegrationShouldNotCompress(t *testing.T) {
{
name: "when content already compressed and status code Created",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(contentEncodingHeader, gzipName)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusCreated)
_, err := rw.Write(fakeCompressedBody)
@ -419,14 +416,14 @@ func TestIntegrationShouldNotCompress(t *testing.T) {
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, gzipName, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
body, err := io.ReadAll(resp.Body)
@ -438,7 +435,7 @@ func TestIntegrationShouldNotCompress(t *testing.T) {
func TestShouldWriteHeaderWhenFlush(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(contentEncodingHeader, gzipName)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusUnauthorized)
rw.(http.Flusher).Flush()
@ -454,14 +451,14 @@ func TestShouldWriteHeaderWhenFlush(t *testing.T) {
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, gzipName, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
}
@ -505,14 +502,14 @@ func TestIntegrationShouldCompress(t *testing.T) {
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, gzipName, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
body, err := io.ReadAll(resp.Body)
@ -547,7 +544,7 @@ func TestMinResponseBodyBytes(t *testing.T) {
t.Parallel()
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
if _, err := rw.Write(fakeBody); err != nil {
@ -562,7 +559,7 @@ func TestMinResponseBodyBytes(t *testing.T) {
handler.ServeHTTP(rw, req)
if test.expectedCompression {
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, gzipName, rw.Header().Get(contentEncodingHeader))
assert.NotEqualValues(t, rw.Body.Bytes(), fakeBody)
return
}
@ -636,7 +633,7 @@ func Test1xxResponses(t *testing.T) {
},
}
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(acceptEncodingHeader, gzipName)
res, err := frontendClient.Do(req)
assert.NoError(t, err)
@ -648,7 +645,7 @@ func Test1xxResponses(t *testing.T) {
}
checkLinkHeaders(t, []string{"</style.css>; rel=preload; as=style", "</script.js>; rel=preload; as=script", "</foo.js>; rel=preload; as=script"}, res.Header["Link"])
assert.Equal(t, gzipValue, res.Header.Get(contentEncodingHeader))
assert.Equal(t, gzipName, res.Header.Get(contentEncodingHeader))
body, _ := io.ReadAll(res.Body)
assert.NotEqualValues(t, body, fakeBody)
}
@ -730,7 +727,7 @@ func runBenchmark(b *testing.B, req *http.Request, handler http.Handler) {
b.Fatalf("Expected 200 but got %d", code)
}
assert.Equal(b, gzipValue, res.Header().Get(contentEncodingHeader))
assert.Equal(b, gzipName, res.Header().Get(contentEncodingHeader))
}
func generateBytes(length int) []byte {