Add support for Brotli
Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com> Co-authored-by: Tom Moulard <tom.moulard@traefik.io> Co-authored-by: Romain <rtribotte@users.noreply.github.com> Co-authored-by: Kevin Pollet <pollet.kevin@gmail.com>
This commit is contained in:
parent
1a1cfd1adc
commit
67d9c8da0b
11 changed files with 1201 additions and 39 deletions
338
pkg/middlewares/compress/brotli/brotli.go
Normal file
338
pkg/middlewares/compress/brotli/brotli.go
Normal file
|
@ -0,0 +1,338 @@
|
|||
package brotli
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
)
|
||||
|
||||
const (
|
||||
vary = "Vary"
|
||||
acceptEncoding = "Accept-Encoding"
|
||||
contentEncoding = "Content-Encoding"
|
||||
contentLength = "Content-Length"
|
||||
contentType = "Content-Type"
|
||||
)
|
||||
|
||||
// Config is the Brotli handler configuration.
|
||||
type Config struct {
|
||||
// ExcludedContentTypes is the list of content types for which we should not compress.
|
||||
ExcludedContentTypes []string
|
||||
// MinSize is the minimum size (in bytes) required to enable compression.
|
||||
MinSize int
|
||||
}
|
||||
|
||||
// NewWrapper returns a new Brotli compressing wrapper.
|
||||
func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
|
||||
if cfg.MinSize < 0 {
|
||||
return nil, fmt.Errorf("minimum size must be greater than or equal to zero")
|
||||
}
|
||||
|
||||
var contentTypes []parsedContentType
|
||||
for _, v := range cfg.ExcludedContentTypes {
|
||||
mediaType, params, err := mime.ParseMediaType(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing media type: %w", err)
|
||||
}
|
||||
|
||||
contentTypes = append(contentTypes, parsedContentType{mediaType, params})
|
||||
}
|
||||
|
||||
return func(h http.Handler) http.HandlerFunc {
|
||||
return func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(vary, acceptEncoding)
|
||||
|
||||
brw := &responseWriter{
|
||||
rw: rw,
|
||||
bw: brotli.NewWriter(rw),
|
||||
minSize: cfg.MinSize,
|
||||
statusCode: http.StatusOK,
|
||||
excludedContentTypes: contentTypes,
|
||||
}
|
||||
defer brw.close()
|
||||
|
||||
h.ServeHTTP(brw, r)
|
||||
}
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO: check whether we want to implement content-type sniffing (as gzip does)
|
||||
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
|
||||
type responseWriter struct {
|
||||
rw http.ResponseWriter
|
||||
bw *brotli.Writer
|
||||
|
||||
minSize int
|
||||
excludedContentTypes []parsedContentType
|
||||
|
||||
buf []byte
|
||||
hijacked bool
|
||||
compressionStarted bool
|
||||
compressionDisabled bool
|
||||
headersSent bool
|
||||
|
||||
// Mostly needed to avoid calling bw.Flush/bw.Close when no data was
|
||||
// written in bw.
|
||||
seenData bool
|
||||
|
||||
statusCodeSet bool
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (r *responseWriter) Header() http.Header {
|
||||
return r.rw.Header()
|
||||
}
|
||||
|
||||
func (r *responseWriter) WriteHeader(statusCode int) {
|
||||
if r.statusCodeSet {
|
||||
return
|
||||
}
|
||||
|
||||
r.statusCode = statusCode
|
||||
r.statusCodeSet = true
|
||||
}
|
||||
|
||||
func (r *responseWriter) Write(p []byte) (int, error) {
|
||||
// i.e. has write ever been called at least once with non nil data.
|
||||
if !r.seenData && len(p) > 0 {
|
||||
r.seenData = true
|
||||
}
|
||||
|
||||
// We do not compress, either for contentEncoding or contentType reasons.
|
||||
if r.compressionDisabled {
|
||||
return r.rw.Write(p)
|
||||
}
|
||||
|
||||
// We have already buffered more than minSize,
|
||||
// We are now in compression cruise mode until the end of times.
|
||||
if r.compressionStarted {
|
||||
// If compressionStarted we assume we have sent headers already
|
||||
return r.bw.Write(p)
|
||||
}
|
||||
|
||||
// If we detect a contentEncoding, we know we are never going to compress.
|
||||
if r.rw.Header().Get(contentEncoding) != "" {
|
||||
r.compressionDisabled = true
|
||||
return r.rw.Write(p)
|
||||
}
|
||||
|
||||
// Disable compression according to user wishes in excludedContentTypes.
|
||||
if ct := r.rw.Header().Get(contentType); ct != "" {
|
||||
mediaType, params, err := mime.ParseMediaType(ct)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing media type: %w", err)
|
||||
}
|
||||
|
||||
for _, excludedContentType := range r.excludedContentTypes {
|
||||
if excludedContentType.equals(mediaType, params) {
|
||||
r.compressionDisabled = true
|
||||
return r.rw.Write(p)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We buffer until we know whether to compress (i.e. when we reach minSize received).
|
||||
if len(r.buf)+len(p) < r.minSize {
|
||||
r.buf = append(r.buf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// If we ever make it here, we have received at least minSize, which means we want to compress,
|
||||
// and we are going to send headers right away.
|
||||
r.compressionStarted = true
|
||||
|
||||
// Since we know we are going to compress we will never be able to know the actual length.
|
||||
r.rw.Header().Del(contentLength)
|
||||
|
||||
r.rw.Header().Set(contentEncoding, "br")
|
||||
r.rw.WriteHeader(r.statusCode)
|
||||
r.headersSent = true
|
||||
|
||||
// Start with sending what we have previously buffered, before actually writing
|
||||
// the bytes in argument.
|
||||
n, err := r.bw.Write(r.buf)
|
||||
if err != nil {
|
||||
r.buf = r.buf[n:]
|
||||
// Return zero because we haven't taken care of the bytes in argument yet.
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// If we wrote less than what we wanted, we need to reclaim the leftovers + the bytes in argument,
|
||||
// and keep them for a subsequent Write.
|
||||
if n < len(r.buf) {
|
||||
r.buf = r.buf[n:]
|
||||
r.buf = append(r.buf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// Otherwise just reset the buffer.
|
||||
r.buf = r.buf[:0]
|
||||
|
||||
// Now that we emptied the buffer, we can actually write the given bytes.
|
||||
return r.bw.Write(p)
|
||||
}
|
||||
|
||||
// Flush flushes data to the appropriate underlying writer(s), although it does
|
||||
// not guarantee that all buffered data will be sent.
|
||||
// If not enough bytes have been written to determine whether to enable compression,
|
||||
// no flushing will take place.
|
||||
func (r *responseWriter) Flush() {
|
||||
if !r.seenData {
|
||||
// we should not flush if there never was any data, because flushing the bw
|
||||
// (just like closing) would send some extra end of compressionStarted stream bytes.
|
||||
return
|
||||
}
|
||||
|
||||
// It was already established by Write that compression is disabled, we only
|
||||
// have to flush the uncompressed writer.
|
||||
if r.compressionDisabled {
|
||||
if rw, ok := r.rw.(http.Flusher); ok {
|
||||
rw.Flush()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Here, nothing was ever written either to rw or to bw (since we're still
|
||||
// waiting to decide whether to compress), so we do not need to flush anything.
|
||||
// Note that we diverge with klauspost's gzip behavior, where they instead
|
||||
// force compression and flush whatever was in the buffer in this case.
|
||||
if !r.compressionStarted {
|
||||
return
|
||||
}
|
||||
|
||||
// Conversely, we here know that something was already written to bw (or is
|
||||
// going to be written right after anyway), so bw will have to be flushed.
|
||||
// Also, since we know that bw writes to rw, but (apparently) never flushes it,
|
||||
// we have to do it ourselves.
|
||||
defer func() {
|
||||
// because we also ignore the error returned by Write anyway
|
||||
_ = r.bw.Flush()
|
||||
|
||||
if rw, ok := r.rw.(http.Flusher); ok {
|
||||
rw.Flush()
|
||||
}
|
||||
}()
|
||||
|
||||
// We empty whatever is left of the buffer that Write never took care of.
|
||||
n, err := r.bw.Write(r.buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// And just like in Write we also handle "short writes".
|
||||
if n < len(r.buf) {
|
||||
r.buf = r.buf[n:]
|
||||
return
|
||||
}
|
||||
|
||||
r.buf = r.buf[:0]
|
||||
}
|
||||
|
||||
func (r *responseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacker, ok := r.rw.(http.Hijacker); ok {
|
||||
// We only make use of r.hijacked in close (and not in Write/WriteHeader)
|
||||
// because we want to let the stdlib catch the error on writes, as
|
||||
// they already do a good job of logging it.
|
||||
r.hijacked = true
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", r.rw)
|
||||
}
|
||||
|
||||
// close closes the underlying writers if/when appropriate.
|
||||
// Note that the compressed writer should not be closed if we never used it,
|
||||
// as it would otherwise send some extra "end of compression" bytes.
|
||||
// Close also makes sure to flush whatever was left to write from the buffer.
|
||||
func (r *responseWriter) close() error {
|
||||
if r.hijacked {
|
||||
return nil
|
||||
}
|
||||
|
||||
// We have to take care of statusCode ourselves (in case there was never any
|
||||
// call to Write or WriteHeader before us) as it's the only header we buffer.
|
||||
if !r.headersSent {
|
||||
r.rw.WriteHeader(r.statusCode)
|
||||
r.headersSent = true
|
||||
}
|
||||
|
||||
// Nothing was ever written anywhere, nothing to flush.
|
||||
if !r.seenData {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If compression was disabled, there never was anything in the buffer to flush,
|
||||
// and nothing was ever written to bw.
|
||||
if r.compressionDisabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(r.buf) == 0 {
|
||||
// If we got here we know compression has started, so we can safely flush on bw.
|
||||
return r.bw.Close()
|
||||
}
|
||||
|
||||
// There is still data in the buffer, because we never reached minSize (to
|
||||
// determine whether to compress). We therefore flush it uncompressed.
|
||||
if !r.compressionStarted {
|
||||
n, err := r.rw.Write(r.buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n < len(r.buf) {
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// There is still data in the buffer, simply because Write did not take care of it all.
|
||||
// We flush it to the compressed writer.
|
||||
n, err := r.bw.Write(r.buf)
|
||||
if err != nil {
|
||||
r.bw.Close()
|
||||
return err
|
||||
}
|
||||
if n < len(r.buf) {
|
||||
r.bw.Close()
|
||||
return io.ErrShortWrite
|
||||
}
|
||||
return r.bw.Close()
|
||||
}
|
||||
|
||||
// parsedContentType is the parsed representation of one of the inputs to ContentTypes.
|
||||
// From https://github.com/klauspost/compress/blob/master/gzhttp/compress.go#L401.
|
||||
type parsedContentType struct {
|
||||
mediaType string
|
||||
params map[string]string
|
||||
}
|
||||
|
||||
// equals returns whether this content type matches another content type.
|
||||
func (p parsedContentType) equals(mediaType string, params map[string]string) bool {
|
||||
if p.mediaType != mediaType {
|
||||
return false
|
||||
}
|
||||
|
||||
// if p has no params, don't care about other's params
|
||||
if len(p.params) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// if p has any params, they must be identical to other's.
|
||||
if len(p.params) != len(params) {
|
||||
return false
|
||||
}
|
||||
|
||||
for k, v := range p.params {
|
||||
if w, ok := params[k]; !ok || v != w {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
618
pkg/middlewares/compress/brotli/brotli_test.go
Normal file
618
pkg/middlewares/compress/brotli/brotli_test.go
Normal file
|
@ -0,0 +1,618 @@
|
|||
package brotli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
smallTestBody = []byte("aaabbc" + strings.Repeat("aaabbbccc", 9) + "aaabbbc")
|
||||
bigTestBody = []byte(strings.Repeat(strings.Repeat("aaabbbccc", 66)+" ", 6) + strings.Repeat("aaabbbccc", 66))
|
||||
)
|
||||
|
||||
func Test_Vary(t *testing.T) {
|
||||
h := newTestHandler(t, smallTestBody)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rw.Code)
|
||||
assert.Equal(t, acceptEncoding, rw.Header().Get(vary))
|
||||
}
|
||||
|
||||
func Test_SmallBodyNoCompression(t *testing.T) {
|
||||
h := newTestHandler(t, smallTestBody)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
// With less than 1024 bytes the response should not be compressed.
|
||||
assert.Equal(t, http.StatusOK, rw.Code)
|
||||
assert.Empty(t, rw.Header().Get(contentEncoding))
|
||||
assert.Equal(t, smallTestBody, rw.Body.Bytes())
|
||||
}
|
||||
|
||||
func Test_AlreadyCompressed(t *testing.T) {
|
||||
h := newTestHandler(t, bigTestBody)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/compressed", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, bigTestBody, rw.Body.Bytes())
|
||||
}
|
||||
|
||||
func Test_NoBody(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
statusCode int
|
||||
body []byte
|
||||
}{
|
||||
{
|
||||
desc: "status no content",
|
||||
statusCode: http.StatusNoContent,
|
||||
body: nil,
|
||||
},
|
||||
{
|
||||
desc: "status not modified",
|
||||
statusCode: http.StatusNotModified,
|
||||
body: nil,
|
||||
},
|
||||
{
|
||||
desc: "status OK with empty body",
|
||||
statusCode: http.StatusOK,
|
||||
body: []byte{},
|
||||
},
|
||||
{
|
||||
desc: "status OK with nil body",
|
||||
statusCode: http.StatusOK,
|
||||
body: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(test.statusCode)
|
||||
|
||||
_, err := rw.Write(test.body)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
body, err := io.ReadAll(rw.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, rw.Header().Get(contentEncoding))
|
||||
assert.Empty(t, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_MinSize(t *testing.T) {
|
||||
cfg := Config{
|
||||
MinSize: 128,
|
||||
}
|
||||
|
||||
var bodySize int
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(
|
||||
func(rw http.ResponseWriter, req *http.Request) {
|
||||
for i := 0; i < bodySize; i++ {
|
||||
// We make sure to Write at least once less than minSize so that both
|
||||
// cases below go through the same algo: i.e. they start buffering
|
||||
// because they haven't reached minSize.
|
||||
_, err := rw.Write([]byte{'x'})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
},
|
||||
))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
|
||||
req.Header.Add(acceptEncoding, "br")
|
||||
|
||||
// Short response is not compressed
|
||||
bodySize = cfg.MinSize - 1
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Empty(t, rw.Result().Header.Get(contentEncoding))
|
||||
|
||||
// Long response is compressed
|
||||
bodySize = cfg.MinSize
|
||||
rw = httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, "br", rw.Result().Header.Get(contentEncoding))
|
||||
}
|
||||
|
||||
func Test_MultipleWriteHeader(t *testing.T) {
|
||||
h := mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
// We ensure that the subsequent call to WriteHeader is a noop.
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, rw.Code)
|
||||
}
|
||||
|
||||
func Test_FlushBeforeWrite(t *testing.T) {
|
||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
rw.(http.Flusher).Flush()
|
||||
|
||||
_, err := rw.Write(bigTestBody)
|
||||
require.NoError(t, err)
|
||||
})))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "br", res.Header.Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
}
|
||||
|
||||
func Test_FlushAfterWrite(t *testing.T) {
|
||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err := rw.Write(bigTestBody[0:1])
|
||||
require.NoError(t, err)
|
||||
|
||||
rw.(http.Flusher).Flush()
|
||||
for _, b := range bigTestBody[1:] {
|
||||
_, err := rw.Write([]byte{b})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "br", res.Header.Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
}
|
||||
|
||||
func Test_FlushAfterWriteNil(t *testing.T) {
|
||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err := rw.Write(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
rw.(http.Flusher).Flush()
|
||||
})))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Empty(t, res.Header.Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, got)
|
||||
}
|
||||
|
||||
func Test_FlushAfterAllWrites(t *testing.T) {
|
||||
srv := httptest.NewServer(mustNewWrapper(t, Config{MinSize: 1024})(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
for i := range bigTestBody {
|
||||
_, err := rw.Write(bigTestBody[i : i+1])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
rw.(http.Flusher).Flush()
|
||||
})))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "br", res.Header.Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(res.Body))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
}
|
||||
|
||||
func Test_ExcludedContentTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
contentType string
|
||||
excludedContentTypes []string
|
||||
expCompression bool
|
||||
}{
|
||||
{
|
||||
desc: "Always compress when content types are empty",
|
||||
contentType: "",
|
||||
excludedContentTypes: []string{},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match",
|
||||
contentType: "application/json",
|
||||
excludedContentTypes: []string{"application/json"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME no match",
|
||||
contentType: "text/xml",
|
||||
excludedContentTypes: []string{"application/json"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with no other directive ignores non-MIME directives",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
excludedContentTypes: []string{"application/json"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, different charset",
|
||||
contentType: "application/json; charset=ascii",
|
||||
excludedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, same charset",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
excludedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, missing charset",
|
||||
contentType: "application/json",
|
||||
excludedContentTypes: []string{"application/json; charset=ascii"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match case insensitive",
|
||||
contentType: "Application/Json",
|
||||
excludedContentTypes: []string{"application/json"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match ignore whitespace",
|
||||
contentType: "application/json;charset=utf-8",
|
||||
excludedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
ExcludedContentTypes: test.excludedContentTypes,
|
||||
}
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set(contentType, test.contentType)
|
||||
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err := rw.Write(bigTestBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
if test.expCompression {
|
||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
} else {
|
||||
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
|
||||
|
||||
got, err := io.ReadAll(rw.Body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FlushExcludedContentTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
contentType string
|
||||
excludedContentTypes []string
|
||||
expCompression bool
|
||||
}{
|
||||
{
|
||||
desc: "Always compress when content types are empty",
|
||||
contentType: "",
|
||||
excludedContentTypes: []string{},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match",
|
||||
contentType: "application/json",
|
||||
excludedContentTypes: []string{"application/json"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME no match",
|
||||
contentType: "text/xml",
|
||||
excludedContentTypes: []string{"application/json"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with no other directive ignores non-MIME directives",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
excludedContentTypes: []string{"application/json"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, different charset",
|
||||
contentType: "application/json; charset=ascii",
|
||||
excludedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, same charset",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
excludedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, missing charset",
|
||||
contentType: "application/json",
|
||||
excludedContentTypes: []string{"application/json; charset=ascii"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match case insensitive",
|
||||
contentType: "Application/Json",
|
||||
excludedContentTypes: []string{"application/json"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match ignore whitespace",
|
||||
contentType: "application/json;charset=utf-8",
|
||||
excludedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
ExcludedContentTypes: test.excludedContentTypes,
|
||||
}
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set(contentType, test.contentType)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
tb := bigTestBody
|
||||
for len(tb) > 0 {
|
||||
// Write 100 bytes per run
|
||||
// Detection should not be affected (we send 100 bytes)
|
||||
toWrite := 100
|
||||
if toWrite > len(tb) {
|
||||
toWrite = len(tb)
|
||||
}
|
||||
|
||||
_, err := rw.Write(tb[:toWrite])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Flush between each write
|
||||
rw.(http.Flusher).Flush()
|
||||
tb = tb[toWrite:]
|
||||
}
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
// This doesn't allow checking flushes, but we validate if content is correct.
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
if test.expCompression {
|
||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
} else {
|
||||
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(rw.Body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustNewWrapper(t *testing.T, cfg Config) func(http.Handler) http.HandlerFunc {
|
||||
t.Helper()
|
||||
|
||||
w, err := NewWrapper(cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
return w
|
||||
}
|
||||
|
||||
func newTestHandler(t *testing.T, body []byte) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
return mustNewWrapper(t, Config{MinSize: 1024})(
|
||||
http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.URL.Path == "/compressed" {
|
||||
rw.Header().Set("Content-Encoding", "br")
|
||||
}
|
||||
|
||||
_, err := rw.Write(body)
|
||||
require.NoError(t, err)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
func TestParseContentType_equals(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
pct parsedContentType
|
||||
mediaType string
|
||||
params map[string]string
|
||||
expect assert.BoolAssertionFunc
|
||||
}{
|
||||
{
|
||||
desc: "empty parsed content type",
|
||||
expect: assert.True,
|
||||
},
|
||||
{
|
||||
desc: "simple content type",
|
||||
pct: parsedContentType{
|
||||
mediaType: "plain/text",
|
||||
},
|
||||
mediaType: "plain/text",
|
||||
expect: assert.True,
|
||||
},
|
||||
{
|
||||
desc: "content type with params",
|
||||
pct: parsedContentType{
|
||||
mediaType: "plain/text",
|
||||
params: map[string]string{
|
||||
"charset": "utf8",
|
||||
},
|
||||
},
|
||||
mediaType: "plain/text",
|
||||
params: map[string]string{
|
||||
"charset": "utf8",
|
||||
},
|
||||
expect: assert.True,
|
||||
},
|
||||
{
|
||||
desc: "different content type",
|
||||
pct: parsedContentType{
|
||||
mediaType: "plain/text",
|
||||
},
|
||||
mediaType: "application/json",
|
||||
expect: assert.False,
|
||||
},
|
||||
{
|
||||
desc: "content type with params",
|
||||
pct: parsedContentType{
|
||||
mediaType: "plain/text",
|
||||
params: map[string]string{
|
||||
"charset": "utf8",
|
||||
},
|
||||
},
|
||||
mediaType: "plain/text",
|
||||
params: map[string]string{
|
||||
"charset": "latin-1",
|
||||
},
|
||||
expect: assert.False,
|
||||
},
|
||||
{
|
||||
desc: "different number of parameters",
|
||||
pct: parsedContentType{
|
||||
mediaType: "plain/text",
|
||||
params: map[string]string{
|
||||
"charset": "utf8",
|
||||
},
|
||||
},
|
||||
mediaType: "plain/text",
|
||||
params: map[string]string{
|
||||
"charset": "utf8",
|
||||
"q": "0.8",
|
||||
},
|
||||
expect: assert.False,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
test.expect(t, test.pct.equals(test.mediaType, test.params))
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue