1
0
Fork 0

Add support for Zstandard to the Compression middleware

This commit is contained in:
Antoine Aflalo 2024-06-12 05:38:04 -04:00 committed by GitHub
parent 3f48e6f8ef
commit b795f128d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 576 additions and 213 deletions

View file

@ -11,6 +11,7 @@ const acceptEncodingHeader = "Accept-Encoding"
const (
brotliName = "br"
gzipName = "gzip"
zstdName = "zstd"
identityName = "identity"
wildcardName = "*"
notAcceptable = "not_acceptable"
@ -51,7 +52,7 @@ func getCompressionType(acceptEncoding []string, defaultType string) string {
return encoding.Type
}
for _, dt := range []string{brotliName, gzipName} {
for _, dt := range []string{zstdName, brotliName, gzipName} {
if slices.ContainsFunc(encodings, func(e Encoding) bool { return e.Type == dt }) {
return dt
}
@ -76,7 +77,7 @@ func parseAcceptEncoding(acceptEncoding []string) ([]Encoding, bool) {
}
switch parsed[0] {
case brotliName, gzipName, identityName, wildcardName:
case zstdName, brotliName, gzipName, identityName, wildcardName:
// supported encoding
default:
continue

View file

@ -18,6 +18,11 @@ func Test_getCompressionType(t *testing.T) {
values: []string{"gzip, br"},
expected: brotliName,
},
{
desc: "zstd > br > gzip (no weight)",
values: []string{"zstd, gzip, br"},
expected: zstdName,
},
{
desc: "known compression type (no weight)",
values: []string{"compress, gzip"},
@ -49,6 +54,11 @@ func Test_getCompressionType(t *testing.T) {
values: []string{"compress;q=1.0, gzip;q=0.5"},
expected: gzipName,
},
{
desc: "fallback on non-zero compression type",
values: []string{"compress;q=1.0, gzip, identity;q=0"},
expected: gzipName,
},
{
desc: "not acceptable (identity)",
values: []string{"compress;q=1.0, identity;q=0"},
@ -86,9 +96,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
}{
{
desc: "weight",
values: []string{"br;q=1.0, gzip;q=0.8, *;q=0.1"},
values: []string{"br;q=1.0, zstd;q=0.9, gzip;q=0.8, *;q=0.1"},
expected: []Encoding{
{Type: brotliName, Weight: ptr[float64](1)},
{Type: zstdName, Weight: ptr(0.9)},
{Type: gzipName, Weight: ptr(0.8)},
{Type: wildcardName, Weight: ptr(0.1)},
},
@ -96,9 +107,10 @@ func Test_parseAcceptEncoding(t *testing.T) {
},
{
desc: "mixed",
values: []string{"gzip, br;q=1.0, *;q=0"},
values: []string{"zstd,gzip, br;q=1.0, *;q=0"},
expected: []Encoding{
{Type: brotliName, Weight: ptr[float64](1)},
{Type: zstdName},
{Type: gzipName},
{Type: wildcardName, Weight: ptr[float64](0)},
},
@ -106,8 +118,9 @@ func Test_parseAcceptEncoding(t *testing.T) {
},
{
desc: "no weight",
values: []string{"gzip, br, *"},
values: []string{"zstd, gzip, br, *"},
expected: []Encoding{
{Type: zstdName},
{Type: gzipName},
{Type: brotliName},
{Type: wildcardName},

View file

@ -11,7 +11,6 @@ import (
"github.com/klauspost/compress/gzhttp"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/compress/brotli"
"go.opentelemetry.io/otel/trace"
)
@ -32,6 +31,7 @@ type compress struct {
brotliHandler http.Handler
gzipHandler http.Handler
zstdHandler http.Handler
}
// New creates a new compress middleware.
@ -77,7 +77,13 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str
}
var err error
c.brotliHandler, err = c.newBrotliHandler()
c.zstdHandler, err = c.newCompressionHandler(zstdName, name)
if err != nil {
return nil, err
}
c.brotliHandler, err = c.newCompressionHandler(brotliName, name)
if err != nil {
return nil, err
}
@ -130,6 +136,8 @@ func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
func (c *compress) chooseHandler(typ string, rw http.ResponseWriter, req *http.Request) {
switch typ {
case zstdName:
c.zstdHandler.ServeHTTP(rw, req)
case brotliName:
c.brotliHandler.ServeHTTP(rw, req)
case gzipName:
@ -166,18 +174,13 @@ func (c *compress) newGzipHandler() (http.Handler, error) {
return wrapper(c.next), nil
}
func (c *compress) newBrotliHandler() (http.Handler, error) {
cfg := brotli.Config{MinSize: c.minSize}
func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) {
cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName}
if len(c.includes) > 0 {
cfg.IncludedContentTypes = c.includes
} else {
cfg.ExcludedContentTypes = c.excludes
}
wrapper, err := brotli.NewWrapper(cfg)
if err != nil {
return nil, fmt.Errorf("new brotli wrapper: %w", err)
}
return wrapper(c.next), nil
return NewCompressionHandler(cfg, c.next)
}

View file

@ -41,32 +41,52 @@ func TestNegotiation(t *testing.T) {
{
desc: "accept any header",
acceptEncHeader: "*",
expEncoding: "br",
expEncoding: brotliName,
},
{
desc: "gzip accept header",
acceptEncHeader: "gzip",
expEncoding: "gzip",
expEncoding: gzipName,
},
{
desc: "br accept header",
acceptEncHeader: "br",
expEncoding: "br",
expEncoding: brotliName,
},
{
desc: "multi accept header, prefer br",
acceptEncHeader: "br;q=0.8, gzip;q=0.6",
expEncoding: "br",
expEncoding: brotliName,
},
{
desc: "multi accept header, prefer gzip",
acceptEncHeader: "gzip;q=1.0, br;q=0.8",
expEncoding: "gzip",
expEncoding: gzipName,
},
{
desc: "multi accept header list, prefer br",
acceptEncHeader: "gzip, br",
expEncoding: "br",
expEncoding: brotliName,
},
{
desc: "zstd accept header",
acceptEncHeader: "zstd",
expEncoding: zstdName,
},
{
desc: "multi accept header, prefer zstd",
acceptEncHeader: "zstd;q=0.9, br;q=0.8, gzip;q=0.6",
expEncoding: zstdName,
},
{
desc: "multi accept header, prefer gzip",
acceptEncHeader: "gzip;q=1.0, br;q=0.8, zstd;q=0.7",
expEncoding: gzipName,
},
{
desc: "multi accept header list, prefer zstd",
acceptEncHeader: "gzip, br, zstd",
expEncoding: zstdName,
},
}

View file

@ -1,4 +1,4 @@
package brotli
package compress
import (
"bufio"
@ -10,6 +10,9 @@ import (
"net/http"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
)
const (
@ -30,10 +33,26 @@ type Config struct {
IncludedContentTypes []string
// MinSize is the minimum size (in bytes) required to enable compression.
MinSize int
// Algorithm used for the compression (currently Brotli and Zstandard)
Algorithm string
// MiddlewareName use for logging purposes
MiddlewareName string
}
// NewWrapper returns a new Brotli compressing wrapper.
func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
// CompressionHandler handles Brolti and Zstd compression.
type CompressionHandler struct {
cfg Config
excludedContentTypes []parsedContentType
includedContentTypes []parsedContentType
next http.Handler
}
// NewCompressionHandler returns a new compressing handler.
func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) {
if cfg.Algorithm == "" {
return nil, errors.New("compression algorithm undefined")
}
if cfg.MinSize < 0 {
return nil, errors.New("minimum size must be greater than or equal to zero")
}
@ -62,30 +81,89 @@ func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
includedContentTypes = append(includedContentTypes, parsedContentType{mediaType, params})
}
return func(h http.Handler) http.HandlerFunc {
return func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(vary, acceptEncoding)
brw := &responseWriter{
rw: rw,
bw: brotli.NewWriter(rw),
minSize: cfg.MinSize,
statusCode: http.StatusOK,
excludedContentTypes: excludedContentTypes,
includedContentTypes: includedContentTypes,
}
defer brw.close()
h.ServeHTTP(brw, r)
}
return &CompressionHandler{
cfg: cfg,
excludedContentTypes: excludedContentTypes,
includedContentTypes: includedContentTypes,
next: next,
}, nil
}
func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(vary, acceptEncoding)
compressionWriter, err := newCompressionWriter(c.cfg.Algorithm, rw)
if err != nil {
logger := middlewares.GetLogger(r.Context(), c.cfg.MiddlewareName, typeName)
logMessage := fmt.Sprintf("create compression handler: %v", err)
logger.Debug().Msg(logMessage)
observability.SetStatusErrorf(r.Context(), logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
responseWriter := &responseWriter{
rw: rw,
compressionWriter: compressionWriter,
minSize: c.cfg.MinSize,
statusCode: http.StatusOK,
excludedContentTypes: c.excludedContentTypes,
includedContentTypes: c.includedContentTypes,
}
defer responseWriter.close()
c.next.ServeHTTP(responseWriter, r)
}
type compression interface {
// Write data to the encoder.
// Input data will be buffered and as the buffer fills up
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write CRC if requested.
Write(p []byte) (n int, err error)
// Flush will send the currently written data to output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
Flush() error
// Close closes the underlying writers if/when appropriate.
// Note that the compressed writer should not be closed if we never used it,
// as it would otherwise send some extra "end of compression" bytes.
// Close also makes sure to flush whatever was left to write from the buffer.
Close() error
}
type compressionWriter struct {
compression
alg string
}
func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) {
switch algo {
case brotliName:
return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil
case zstdName:
writer, err := zstd.NewWriter(in)
if err != nil {
return nil, fmt.Errorf("creating zstd writer: %w", err)
}
return &compressionWriter{compression: writer, alg: algo}, nil
default:
return nil, fmt.Errorf("unknown compression algo: %s", algo)
}
}
func (c *compressionWriter) ContentEncoding() string {
return c.alg
}
// TODO: check whether we want to implement content-type sniffing (as gzip does)
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
type responseWriter struct {
rw http.ResponseWriter
bw *brotli.Writer
rw http.ResponseWriter
compressionWriter *compressionWriter
minSize int
excludedContentTypes []parsedContentType
@ -133,7 +211,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
// We are now in compression cruise mode until the end of times.
if r.compressionStarted {
// If compressionStarted we assume we have sent headers already
return r.bw.Write(p)
return r.compressionWriter.Write(p)
}
// If we detect a contentEncoding, we know we are never going to compress.
@ -187,13 +265,13 @@ func (r *responseWriter) Write(p []byte) (int, error) {
// Since we know we are going to compress we will never be able to know the actual length.
r.rw.Header().Del(contentLength)
r.rw.Header().Set(contentEncoding, "br")
r.rw.Header().Set(contentEncoding, r.compressionWriter.ContentEncoding())
r.rw.WriteHeader(r.statusCode)
r.headersSent = true
// Start with sending what we have previously buffered, before actually writing
// the bytes in argument.
n, err := r.bw.Write(r.buf)
n, err := r.compressionWriter.Write(r.buf)
if err != nil {
r.buf = r.buf[n:]
// Return zero because we haven't taken care of the bytes in argument yet.
@ -212,7 +290,7 @@ func (r *responseWriter) Write(p []byte) (int, error) {
r.buf = r.buf[:0]
// Now that we emptied the buffer, we can actually write the given bytes.
return r.bw.Write(p)
return r.compressionWriter.Write(p)
}
// Flush flushes data to the appropriate underlying writer(s), although it does
@ -250,7 +328,7 @@ func (r *responseWriter) Flush() {
// we have to do it ourselves.
defer func() {
// because we also ignore the error returned by Write anyway
_ = r.bw.Flush()
_ = r.compressionWriter.Flush()
if rw, ok := r.rw.(http.Flusher); ok {
rw.Flush()
@ -258,7 +336,7 @@ func (r *responseWriter) Flush() {
}()
// We empty whatever is left of the buffer that Write never took care of.
n, err := r.bw.Write(r.buf)
n, err := r.compressionWriter.Write(r.buf)
if err != nil {
return
}
@ -313,7 +391,7 @@ func (r *responseWriter) close() error {
if len(r.buf) == 0 {
// If we got here we know compression has started, so we can safely flush on bw.
return r.bw.Close()
return r.compressionWriter.Close()
}
// There is still data in the buffer, because we never reached minSize (to
@ -331,16 +409,16 @@ func (r *responseWriter) close() error {
// There is still data in the buffer, simply because Write did not take care of it all.
// We flush it to the compressed writer.
n, err := r.bw.Write(r.buf)
n, err := r.compressionWriter.Write(r.buf)
if err != nil {
r.bw.Close()
r.compressionWriter.Close()
return err
}
if n < len(r.buf) {
r.bw.Close()
r.compressionWriter.Close()
return io.ErrShortWrite
}
return r.bw.Close()
return r.compressionWriter.Close()
}
// parsedContentType is the parsed representation of one of the inputs to ContentTypes.

View file

@ -1,4 +1,4 @@
package brotli
package compress
import (
"bytes"
@ -9,6 +9,7 @@ import (
"testing"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -19,44 +20,107 @@ var (
)
func Test_Vary(t *testing.T) {
h := newTestHandler(t, smallTestBody)
testCases := []struct {
desc string
h http.Handler
acceptEncoding string
}{
{
desc: "brotli",
h: newTestBrotliHandler(t, smallTestBody),
acceptEncoding: "br",
},
{
desc: "zstd",
h: newTestZstandardHandler(t, smallTestBody),
acceptEncoding: "zstd",
},
}
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, test.acceptEncoding)
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
rw := httptest.NewRecorder()
test.h.ServeHTTP(rw, req)
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
})
}
}
func Test_SmallBodyNoCompression(t *testing.T) {
h := newTestHandler(t, smallTestBody)
testCases := []struct {
desc string
h http.Handler
acceptEncoding string
}{
{
desc: "brotli",
h: newTestBrotliHandler(t, smallTestBody),
acceptEncoding: "br",
},
{
desc: "zstd",
h: newTestZstandardHandler(t, smallTestBody),
acceptEncoding: "zstd",
},
}
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, test.acceptEncoding)
// With less than 1024 bytes the response should not be compressed.
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Empty(t, rw.Header().Get(contentEncoding))
assert.Equal(t, smallTestBody, rw.Body.Bytes())
rw := httptest.NewRecorder()
test.h.ServeHTTP(rw, req)
// With less than 1024 bytes the response should not be compressed.
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Empty(t, rw.Header().Get(contentEncoding))
assert.Equal(t, smallTestBody, rw.Body.Bytes())
})
}
}
func Test_AlreadyCompressed(t *testing.T) {
h := newTestHandler(t, bigTestBody)
testCases := []struct {
desc string
h http.Handler
acceptEncoding string
}{
{
desc: "brotli",
h: newTestBrotliHandler(t, bigTestBody),
acceptEncoding: "br",
},
{
desc: "zstd",
h: newTestZstandardHandler(t, bigTestBody),
acceptEncoding: "zstd",
},
}
req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
req.Header.Set(acceptEncoding, "br")
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
req.Header.Set(acceptEncoding, test.acceptEncoding)
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, bigTestBody, rw.Body.Bytes())
rw := httptest.NewRecorder()
test.h.ServeHTTP(rw, req)
assert.Equal(t, http.StatusAccepted, rw.Code)
assert.Equal(t, bigTestBody, rw.Body.Bytes())
})
}
}
func Test_NoBody(t *testing.T) {
@ -91,15 +155,17 @@ func Test_NoBody(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(test.statusCode)
_, err := rw.Write(test.body)
require.NoError(t, err)
}))
})
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, "zstd")
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
@ -115,24 +181,26 @@ func Test_NoBody(t *testing.T) {
func Test_MinSize(t *testing.T) {
cfg := Config{
MinSize: 128,
MinSize: 128,
Algorithm: zstdName,
}
var bodySize int
h := mustNewWrapper(t, cfg)(http.HandlerFunc(
func(rw http.ResponseWriter, req *http.Request) {
for range bodySize {
// We make sure to Write at least once less than minSize so that both
// cases below go through the same algo: i.e. they start buffering
// because they haven't reached minSize.
_, err := rw.Write([]byte{'x'})
require.NoError(t, err)
}
},
))
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for range bodySize {
// We make sure to Write at least once less than minSize so that both
// cases below go through the same algo: i.e. they start buffering
// because they haven't reached minSize.
_, err := rw.Write([]byte{'x'})
require.NoError(t, err)
}
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
req.Header.Add(acceptEncoding, "br")
req.Header.Add(acceptEncoding, "zstd")
// Short response is not compressed
bodySize = cfg.MinSize - 1
@ -146,18 +214,20 @@ func Test_MinSize(t *testing.T) {
rw = httptest.NewRecorder()
h.ServeHTTP(rw, req)
assert.Equal(t, "br", rw.Result().Header.Get(contentEncoding))
assert.Equal(t, "zstd", rw.Result().Header.Get(contentEncoding))
}
func Test_MultipleWriteHeader(t *testing.T) {
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
// We ensure that the subsequent call to WriteHeader is a noop.
rw.WriteHeader(http.StatusInternalServerError)
rw.WriteHeader(http.StatusNotFound)
}))
})
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, "zstd")
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
@ -166,121 +236,255 @@ func Test_MultipleWriteHeader(t *testing.T) {
}
func Test_FlushBeforeWrite(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.(http.Flusher).Flush()
testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
_, err := rw.Write(bigTestBody)
require.NoError(t, err)
})))
defer srv.Close()
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.(http.Flusher).Flush()
req.Header.Set(acceptEncoding, "br")
_, err := rw.Write(bigTestBody)
require.NoError(t, err)
})
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close()
defer res.Body.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "br", res.Header.Get(contentEncoding))
req.Header.Set(acceptEncoding, test.acceptEncoding)
got, err := io.ReadAll(brotli.NewReader(res.Body))
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
})
}
}
func Test_FlushAfterWrite(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
_, err := rw.Write(bigTestBody[0:1])
require.NoError(t, err)
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.(http.Flusher).Flush()
for _, b := range bigTestBody[1:] {
_, err := rw.Write([]byte{b})
_, err := rw.Write(bigTestBody[0:1])
require.NoError(t, err)
rw.(http.Flusher).Flush()
for _, b := range bigTestBody[1:] {
_, err := rw.Write([]byte{b})
require.NoError(t, err)
}
})
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
}
})))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
req.Header.Set(acceptEncoding, test.acceptEncoding)
req.Header.Set(acceptEncoding, "br")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "br", res.Header.Get(contentEncoding))
reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(brotli.NewReader(res.Body))
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
got, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
})
}
}
func Test_FlushAfterWriteNil(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
_, err := rw.Write(nil)
require.NoError(t, err)
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.(http.Flusher).Flush()
})))
defer srv.Close()
_, err := rw.Write(nil)
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
rw.(http.Flusher).Flush()
})
req.Header.Set(acceptEncoding, "br")
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close()
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
defer res.Body.Close()
req.Header.Set(acceptEncoding, test.acceptEncoding)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Empty(t, res.Header.Get(contentEncoding))
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
got, err := io.ReadAll(brotli.NewReader(res.Body))
require.NoError(t, err)
assert.Empty(t, got)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Empty(t, res.Header.Get(contentEncoding))
reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Empty(t, got)
})
}
}
func Test_FlushAfterAllWrites(t *testing.T) {
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for i := range bigTestBody {
_, err := rw.Write(bigTestBody[i : i+1])
testCases := []struct {
desc string
cfg Config
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
acceptEncoding: "br",
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
acceptEncoding: "zstd",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
for i := range bigTestBody {
_, err := rw.Write(bigTestBody[i : i+1])
require.NoError(t, err)
}
rw.(http.Flusher).Flush()
})
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
}
rw.(http.Flusher).Flush()
})))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
require.NoError(t, err)
req.Header.Set(acceptEncoding, test.acceptEncoding)
req.Header.Set(acceptEncoding, "br")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer res.Body.Close()
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, test.acceptEncoding, res.Header.Get(contentEncoding))
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "br", res.Header.Get(contentEncoding))
reader, err := test.readerBuilder(res.Body)
require.NoError(t, err)
got, err := io.ReadAll(brotli.NewReader(res.Body))
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
got, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, bigTestBody, got)
})
}
}
func Test_ExcludedContentTypes(t *testing.T) {
@ -352,18 +556,22 @@ func Test_ExcludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
ExcludedContentTypes: test.excludedContentTypes,
Algorithm: zstdName,
}
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(bigTestBody)
require.NoError(t, err)
}))
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, zstdName)
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
@ -371,13 +579,16 @@ func Test_ExcludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusAccepted, rw.Code)
if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
got, err := io.ReadAll(rw.Body)
assert.NoError(t, err)
@ -456,18 +667,22 @@ func Test_IncludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
IncludedContentTypes: test.includedContentTypes,
Algorithm: zstdName,
}
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(bigTestBody)
require.NoError(t, err)
}))
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, zstdName)
rw := httptest.NewRecorder()
h.ServeHTTP(rw, req)
@ -475,13 +690,16 @@ func Test_IncludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusAccepted, rw.Code)
if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
assert.NotEqual(t, zstdName, rw.Header().Get("Content-Encoding"))
got, err := io.ReadAll(rw.Body)
assert.NoError(t, err)
@ -560,8 +778,10 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
ExcludedContentTypes: test.excludedContentTypes,
Algorithm: zstdName,
}
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusOK)
@ -581,10 +801,12 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
rw.(http.Flusher).Flush()
tb = tb[toWrite:]
}
}))
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, zstdName)
// This doesn't allow checking flushes, but we validate if content is correct.
rw := httptest.NewRecorder()
@ -593,13 +815,16 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusOK, rw.Code)
if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(rw.Body)
assert.NoError(t, err)
@ -678,8 +903,10 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
IncludedContentTypes: test.includedContentTypes,
Algorithm: zstdName,
}
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set(contentType, test.contentType)
rw.WriteHeader(http.StatusOK)
@ -699,10 +926,12 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
rw.(http.Flusher).Flush()
tb = tb[toWrite:]
}
}))
})
h := mustNewCompressionHandler(t, cfg, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, "br")
req.Header.Set(acceptEncoding, zstdName)
// This doesn't allow checking flushes, but we validate if content is correct.
rw := httptest.NewRecorder()
@ -711,13 +940,16 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
assert.Equal(t, http.StatusOK, rw.Code)
if test.expCompression {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
assert.Equal(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
reader, err := zstd.NewReader(rw.Body)
require.NoError(t, err)
got, err := io.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
assert.NotEqual(t, zstdName, rw.Header().Get(contentEncoding))
got, err := io.ReadAll(rw.Body)
assert.NoError(t, err)
@ -727,32 +959,48 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
}
}
func mustNewWrapper(t *testing.T, cfg Config) func(http.Handler) http.HandlerFunc {
func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler {
t.Helper()
w, err := NewWrapper(cfg)
w, err := NewCompressionHandler(cfg, next)
require.NoError(t, err)
return w
}
func newTestHandler(t *testing.T, body []byte) http.Handler {
func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
t.Helper()
return mustNewWrapper(t, Config{MinSize: 1024})(
http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", "br")
}
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", brotliName)
}
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body)
require.NoError(t, err)
}),
)
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body)
require.NoError(t, err)
})
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next)
}
func TestParseContentType_equals(t *testing.T) {
func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
t.Helper()
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.Path == "/compressed" {
rw.Header().Set("Content-Encoding", zstdName)
}
rw.WriteHeader(http.StatusAccepted)
_, err := rw.Write(body)
require.NoError(t, err)
})
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next)
}
func Test_ParseContentType_equals(t *testing.T) {
testCases := []struct {
desc string
pct parsedContentType