Implements the includedContentTypes option for the compress middleware
This commit is contained in:
parent
319517adef
commit
4e0a05406b
15 changed files with 469 additions and 24 deletions
|
@ -22,7 +22,11 @@ const (
|
|||
// Config is the Brotli handler configuration.
|
||||
type Config struct {
|
||||
// ExcludedContentTypes is the list of content types for which we should not compress.
|
||||
// Mutually exclusive with the IncludedContentTypes option.
|
||||
ExcludedContentTypes []string
|
||||
// IncludedContentTypes is the list of content types for which compression should be exclusively enabled.
|
||||
// Mutually exclusive with the ExcludedContentTypes option.
|
||||
IncludedContentTypes []string
|
||||
// MinSize is the minimum size (in bytes) required to enable compression.
|
||||
MinSize int
|
||||
}
|
||||
|
@ -33,14 +37,28 @@ func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
|
|||
return nil, fmt.Errorf("minimum size must be greater than or equal to zero")
|
||||
}
|
||||
|
||||
var contentTypes []parsedContentType
|
||||
if len(cfg.ExcludedContentTypes) > 0 && len(cfg.IncludedContentTypes) > 0 {
|
||||
return nil, fmt.Errorf("excludedContentTypes and includedContentTypes options are mutually exclusive")
|
||||
}
|
||||
|
||||
var excludedContentTypes []parsedContentType
|
||||
for _, v := range cfg.ExcludedContentTypes {
|
||||
mediaType, params, err := mime.ParseMediaType(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing media type: %w", err)
|
||||
return nil, fmt.Errorf("parsing excluded media type: %w", err)
|
||||
}
|
||||
|
||||
contentTypes = append(contentTypes, parsedContentType{mediaType, params})
|
||||
excludedContentTypes = append(excludedContentTypes, parsedContentType{mediaType, params})
|
||||
}
|
||||
|
||||
var includedContentTypes []parsedContentType
|
||||
for _, v := range cfg.IncludedContentTypes {
|
||||
mediaType, params, err := mime.ParseMediaType(v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing included media type: %w", err)
|
||||
}
|
||||
|
||||
includedContentTypes = append(includedContentTypes, parsedContentType{mediaType, params})
|
||||
}
|
||||
|
||||
return func(h http.Handler) http.HandlerFunc {
|
||||
|
@ -52,7 +70,8 @@ func NewWrapper(cfg Config) (func(http.Handler) http.HandlerFunc, error) {
|
|||
bw: brotli.NewWriter(rw),
|
||||
minSize: cfg.MinSize,
|
||||
statusCode: http.StatusOK,
|
||||
excludedContentTypes: contentTypes,
|
||||
excludedContentTypes: excludedContentTypes,
|
||||
includedContentTypes: includedContentTypes,
|
||||
}
|
||||
defer brw.close()
|
||||
|
||||
|
@ -69,6 +88,7 @@ type responseWriter struct {
|
|||
|
||||
minSize int
|
||||
excludedContentTypes []parsedContentType
|
||||
includedContentTypes []parsedContentType
|
||||
|
||||
buf []byte
|
||||
hijacked bool
|
||||
|
@ -121,11 +141,25 @@ func (r *responseWriter) Write(p []byte) (int, error) {
|
|||
return r.rw.Write(p)
|
||||
}
|
||||
|
||||
// Disable compression according to user wishes in excludedContentTypes.
|
||||
// Disable compression according to user wishes in excludedContentTypes or includedContentTypes.
|
||||
if ct := r.rw.Header().Get(contentType); ct != "" {
|
||||
mediaType, params, err := mime.ParseMediaType(ct)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("parsing media type: %w", err)
|
||||
return 0, fmt.Errorf("parsing content-type media type: %w", err)
|
||||
}
|
||||
|
||||
if len(r.includedContentTypes) > 0 {
|
||||
var found bool
|
||||
for _, includedContentType := range r.includedContentTypes {
|
||||
if includedContentType.equals(mediaType, params) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
r.compressionDisabled = true
|
||||
return r.rw.Write(p)
|
||||
}
|
||||
}
|
||||
|
||||
for _, excludedContentType := range r.excludedContentTypes {
|
||||
|
|
|
@ -291,10 +291,9 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
|||
expCompression bool
|
||||
}{
|
||||
{
|
||||
desc: "Always compress when content types are empty",
|
||||
contentType: "",
|
||||
excludedContentTypes: []string{},
|
||||
expCompression: true,
|
||||
desc: "Always compress when content types are empty",
|
||||
contentType: "",
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match",
|
||||
|
@ -389,6 +388,111 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_IncludedContentTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
contentType string
|
||||
includedContentTypes []string
|
||||
expCompression bool
|
||||
}{
|
||||
{
|
||||
desc: "Always compress when content types are empty",
|
||||
contentType: "",
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match",
|
||||
contentType: "application/json",
|
||||
includedContentTypes: []string{"application/json"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME no match",
|
||||
contentType: "text/xml",
|
||||
includedContentTypes: []string{"application/json"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with no other directive ignores non-MIME directives",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
includedContentTypes: []string{"application/json"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, different charset",
|
||||
contentType: "application/json; charset=ascii",
|
||||
includedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, same charset",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
includedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, missing charset",
|
||||
contentType: "application/json",
|
||||
includedContentTypes: []string{"application/json; charset=ascii"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match case insensitive",
|
||||
contentType: "Application/Json",
|
||||
includedContentTypes: []string{"application/json"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match ignore whitespace",
|
||||
contentType: "application/json;charset=utf-8",
|
||||
includedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
IncludedContentTypes: test.includedContentTypes,
|
||||
}
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set(contentType, test.contentType)
|
||||
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err := rw.Write(bigTestBody)
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
if test.expCompression {
|
||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
} else {
|
||||
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
|
||||
|
||||
got, err := io.ReadAll(rw.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FlushExcludedContentTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
|
@ -397,10 +501,9 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
|||
expCompression bool
|
||||
}{
|
||||
{
|
||||
desc: "Always compress when content types are empty",
|
||||
contentType: "",
|
||||
excludedContentTypes: []string{},
|
||||
expCompression: true,
|
||||
desc: "Always compress when content types are empty",
|
||||
contentType: "",
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match",
|
||||
|
@ -509,6 +612,125 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_FlushIncludedContentTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
contentType string
|
||||
includedContentTypes []string
|
||||
expCompression bool
|
||||
}{
|
||||
{
|
||||
desc: "Always compress when content types are empty",
|
||||
contentType: "",
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match",
|
||||
contentType: "application/json",
|
||||
includedContentTypes: []string{"application/json"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME no match",
|
||||
contentType: "text/xml",
|
||||
includedContentTypes: []string{"application/json"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with no other directive ignores non-MIME directives",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
includedContentTypes: []string{"application/json"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, different charset",
|
||||
contentType: "application/json; charset=ascii",
|
||||
includedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, same charset",
|
||||
contentType: "application/json; charset=utf-8",
|
||||
includedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match with other directives requires all directives be equal, missing charset",
|
||||
contentType: "application/json",
|
||||
includedContentTypes: []string{"application/json; charset=ascii"},
|
||||
expCompression: false,
|
||||
},
|
||||
{
|
||||
desc: "MIME match case insensitive",
|
||||
contentType: "Application/Json",
|
||||
includedContentTypes: []string{"application/json"},
|
||||
expCompression: true,
|
||||
},
|
||||
{
|
||||
desc: "MIME match ignore whitespace",
|
||||
contentType: "application/json;charset=utf-8",
|
||||
includedContentTypes: []string{"application/json; charset=utf-8"},
|
||||
expCompression: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
IncludedContentTypes: test.includedContentTypes,
|
||||
}
|
||||
h := mustNewWrapper(t, cfg)(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set(contentType, test.contentType)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
|
||||
tb := bigTestBody
|
||||
for len(tb) > 0 {
|
||||
// Write 100 bytes per run
|
||||
// Detection should not be affected (we send 100 bytes)
|
||||
toWrite := 100
|
||||
if toWrite > len(tb) {
|
||||
toWrite = len(tb)
|
||||
}
|
||||
|
||||
_, err := rw.Write(tb[:toWrite])
|
||||
require.NoError(t, err)
|
||||
|
||||
// Flush between each write
|
||||
rw.(http.Flusher).Flush()
|
||||
tb = tb[toWrite:]
|
||||
}
|
||||
}))
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, "br")
|
||||
|
||||
// This doesn't allow checking flushes, but we validate if content is correct.
|
||||
rw := httptest.NewRecorder()
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rw.Code)
|
||||
|
||||
if test.expCompression {
|
||||
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(brotli.NewReader(rw.Body))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
} else {
|
||||
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
|
||||
|
||||
got, err := io.ReadAll(rw.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bigTestBody, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func mustNewWrapper(t *testing.T, cfg Config) func(http.Handler) http.HandlerFunc {
|
||||
t.Helper()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue