Merge branch v3.2 into master
This commit is contained in:
commit
7004f0e750
65 changed files with 2393 additions and 1222 deletions
|
@ -25,7 +25,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
ptypes "github.com/traefik/paerser/types"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/capture"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/recovery"
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
)
|
||||
|
||||
|
@ -954,8 +953,14 @@ func doLoggingWithAbortedStream(t *testing.T, config *types.AccessLog) {
|
|||
req = req.WithContext(reqContext)
|
||||
|
||||
chain := alice.New()
|
||||
|
||||
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
||||
return recovery.New(context.Background(), next)
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
defer func() {
|
||||
_ = recover() // ignore the stream backend panic to avoid the test to fail.
|
||||
}()
|
||||
next.ServeHTTP(rw, req)
|
||||
}), nil
|
||||
})
|
||||
chain = chain.Append(capture.Wrap)
|
||||
chain = chain.Append(WrapHandler(logger))
|
||||
|
|
|
@ -8,7 +8,9 @@ import (
|
|||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/gzhttp"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
@ -94,12 +96,12 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str
|
|||
|
||||
var err error
|
||||
|
||||
c.zstdHandler, err = c.newCompressionHandler(zstdName, name)
|
||||
c.zstdHandler, err = c.newZstdHandler(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.brotliHandler, err = c.newCompressionHandler(brotliName, name)
|
||||
c.brotliHandler, err = c.newBrotliHandler(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -190,13 +192,34 @@ func (c *compress) newGzipHandler() (http.Handler, error) {
|
|||
return wrapper(c.next), nil
|
||||
}
|
||||
|
||||
func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) {
|
||||
cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName}
|
||||
func (c *compress) newBrotliHandler(middlewareName string) (http.Handler, error) {
|
||||
cfg := Config{MinSize: c.minSize, MiddlewareName: middlewareName}
|
||||
if len(c.includes) > 0 {
|
||||
cfg.IncludedContentTypes = c.includes
|
||||
} else {
|
||||
cfg.ExcludedContentTypes = c.excludes
|
||||
}
|
||||
|
||||
return NewCompressionHandler(cfg, c.next)
|
||||
newBrotliWriter := func(rw http.ResponseWriter) (CompressionWriter, string, error) {
|
||||
return brotli.NewWriter(rw), brotliName, nil
|
||||
}
|
||||
return NewCompressionHandler(cfg, newBrotliWriter, c.next)
|
||||
}
|
||||
|
||||
func (c *compress) newZstdHandler(middlewareName string) (http.Handler, error) {
|
||||
cfg := Config{MinSize: c.minSize, MiddlewareName: middlewareName}
|
||||
if len(c.includes) > 0 {
|
||||
cfg.IncludedContentTypes = c.includes
|
||||
} else {
|
||||
cfg.ExcludedContentTypes = c.excludes
|
||||
}
|
||||
|
||||
newZstdWriter := func(rw http.ResponseWriter) (CompressionWriter, string, error) {
|
||||
writer, err := zstd.NewWriter(rw)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("creating zstd writer: %w", err)
|
||||
}
|
||||
return writer, zstdName, nil
|
||||
}
|
||||
return NewCompressionHandler(cfg, newZstdWriter, c.next)
|
||||
}
|
||||
|
|
|
@ -10,8 +10,6 @@ import (
|
|||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||
)
|
||||
|
@ -24,6 +22,30 @@ const (
|
|||
contentType = "Content-Type"
|
||||
)
|
||||
|
||||
// CompressionWriter compresses the written bytes.
|
||||
type CompressionWriter interface {
|
||||
// Write data to the encoder.
|
||||
// Input data will be buffered and as the buffer fills up
|
||||
// content will be compressed and written to the output.
|
||||
// When done writing, use Close to flush the remaining output
|
||||
// and write CRC if requested.
|
||||
Write(p []byte) (n int, err error)
|
||||
// Flush will send the currently written data to output
|
||||
// and block until everything has been written.
|
||||
// This should only be used on rare occasions where pushing the currently queued data is critical.
|
||||
Flush() error
|
||||
// Close closes the underlying writers if/when appropriate.
|
||||
// Note that the compressed writer should not be closed if we never used it,
|
||||
// as it would otherwise send some extra "end of compression" bytes.
|
||||
// Close also makes sure to flush whatever was left to write from the buffer.
|
||||
Close() error
|
||||
// Reset reinitializes the state of the encoder, allowing it to be reused.
|
||||
Reset(w io.Writer)
|
||||
}
|
||||
|
||||
// NewCompressionWriter returns a new CompressionWriter with its corresponding algorithm.
|
||||
type NewCompressionWriter func(rw http.ResponseWriter) (CompressionWriter, string, error)
|
||||
|
||||
// Config is the Brotli handler configuration.
|
||||
type Config struct {
|
||||
// ExcludedContentTypes is the list of content types for which we should not compress.
|
||||
|
@ -34,8 +56,6 @@ type Config struct {
|
|||
IncludedContentTypes []string
|
||||
// MinSize is the minimum size (in bytes) required to enable compression.
|
||||
MinSize int
|
||||
// Algorithm used for the compression (currently Brotli and Zstandard)
|
||||
Algorithm string
|
||||
// MiddlewareName use for logging purposes
|
||||
MiddlewareName string
|
||||
}
|
||||
|
@ -46,15 +66,13 @@ type CompressionHandler struct {
|
|||
excludedContentTypes []parsedContentType
|
||||
includedContentTypes []parsedContentType
|
||||
next http.Handler
|
||||
writerPool sync.Pool
|
||||
|
||||
writerPool sync.Pool
|
||||
newWriter NewCompressionWriter
|
||||
}
|
||||
|
||||
// NewCompressionHandler returns a new compressing handler.
|
||||
func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) {
|
||||
if cfg.Algorithm == "" {
|
||||
return nil, errors.New("compression algorithm undefined")
|
||||
}
|
||||
|
||||
func NewCompressionHandler(cfg Config, newWriter NewCompressionWriter, next http.Handler) (http.Handler, error) {
|
||||
if cfg.MinSize < 0 {
|
||||
return nil, errors.New("minimum size must be greater than or equal to zero")
|
||||
}
|
||||
|
@ -88,6 +106,7 @@ func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error)
|
|||
excludedContentTypes: excludedContentTypes,
|
||||
includedContentTypes: includedContentTypes,
|
||||
next: next,
|
||||
newWriter: newWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -117,70 +136,38 @@ func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||
c.next.ServeHTTP(responseWriter, r)
|
||||
}
|
||||
|
||||
type compression interface {
|
||||
// Write data to the encoder.
|
||||
// Input data will be buffered and as the buffer fills up
|
||||
// content will be compressed and written to the output.
|
||||
// When done writing, use Close to flush the remaining output
|
||||
// and write CRC if requested.
|
||||
Write(p []byte) (n int, err error)
|
||||
// Flush will send the currently written data to output
|
||||
// and block until everything has been written.
|
||||
// This should only be used on rare occasions where pushing the currently queued data is critical.
|
||||
Flush() error
|
||||
// Close closes the underlying writers if/when appropriate.
|
||||
// Note that the compressed writer should not be closed if we never used it,
|
||||
// as it would otherwise send some extra "end of compression" bytes.
|
||||
// Close also makes sure to flush whatever was left to write from the buffer.
|
||||
Close() error
|
||||
// Reset reinitializes the state of the encoder, allowing it to be reused.
|
||||
Reset(w io.Writer)
|
||||
}
|
||||
|
||||
type compressionWriter struct {
|
||||
compression
|
||||
alg string
|
||||
}
|
||||
|
||||
func (c *CompressionHandler) getCompressionWriter(rw io.Writer) (*compressionWriter, error) {
|
||||
if writer, ok := c.writerPool.Get().(*compressionWriter); ok {
|
||||
writer.compression.Reset(rw)
|
||||
func (c *CompressionHandler) getCompressionWriter(rw http.ResponseWriter) (*compressionWriterWrapper, error) {
|
||||
if writer, ok := c.writerPool.Get().(*compressionWriterWrapper); ok {
|
||||
writer.Reset(rw)
|
||||
return writer, nil
|
||||
}
|
||||
return newCompressionWriter(c.cfg.Algorithm, rw)
|
||||
|
||||
writer, algo, err := c.newWriter(rw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating compression writer: %w", err)
|
||||
}
|
||||
return &compressionWriterWrapper{CompressionWriter: writer, algo: algo}, nil
|
||||
}
|
||||
|
||||
func (c *CompressionHandler) putCompressionWriter(writer *compressionWriter) {
|
||||
func (c *CompressionHandler) putCompressionWriter(writer *compressionWriterWrapper) {
|
||||
writer.Reset(nil)
|
||||
c.writerPool.Put(writer)
|
||||
}
|
||||
|
||||
func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) {
|
||||
switch algo {
|
||||
case brotliName:
|
||||
return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil
|
||||
|
||||
case zstdName:
|
||||
writer, err := zstd.NewWriter(in)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating zstd writer: %w", err)
|
||||
}
|
||||
return &compressionWriter{compression: writer, alg: algo}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown compression algo: %s", algo)
|
||||
}
|
||||
type compressionWriterWrapper struct {
|
||||
CompressionWriter
|
||||
algo string
|
||||
}
|
||||
|
||||
func (c *compressionWriter) ContentEncoding() string {
|
||||
return c.alg
|
||||
func (c *compressionWriterWrapper) ContentEncoding() string {
|
||||
return c.algo
|
||||
}
|
||||
|
||||
// TODO: check whether we want to implement content-type sniffing (as gzip does)
|
||||
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
|
||||
type responseWriter struct {
|
||||
rw http.ResponseWriter
|
||||
compressionWriter *compressionWriter
|
||||
compressionWriter *compressionWriterWrapper
|
||||
|
||||
minSize int
|
||||
excludedContentTypes []parsedContentType
|
||||
|
|
|
@ -162,7 +162,7 @@ func Test_NoBody(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(acceptEncoding, "zstd")
|
||||
|
@ -181,8 +181,7 @@ func Test_NoBody(t *testing.T) {
|
|||
|
||||
func Test_MinSize(t *testing.T) {
|
||||
cfg := Config{
|
||||
MinSize: 128,
|
||||
Algorithm: zstdName,
|
||||
MinSize: 128,
|
||||
}
|
||||
|
||||
var bodySize int
|
||||
|
@ -197,7 +196,7 @@ func Test_MinSize(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
|
||||
req.Header.Add(acceptEncoding, "zstd")
|
||||
|
@ -224,7 +223,7 @@ func Test_MultipleWriteHeader(t *testing.T) {
|
|||
rw.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(acceptEncoding, "zstd")
|
||||
|
@ -239,12 +238,14 @@ func Test_FlushBeforeWrite(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
algo string
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: brotliName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
|
@ -252,7 +253,8 @@ func Test_FlushBeforeWrite(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: zstdName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
|
@ -272,7 +274,7 @@ func Test_FlushBeforeWrite(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
|
@ -302,12 +304,14 @@ func Test_FlushAfterWrite(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
algo string
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: brotliName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
|
@ -315,7 +319,8 @@ func Test_FlushAfterWrite(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: zstdName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
|
@ -338,7 +343,7 @@ func Test_FlushAfterWrite(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
|
@ -368,12 +373,14 @@ func Test_FlushAfterWriteNil(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
algo string
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: brotliName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
|
@ -381,7 +388,8 @@ func Test_FlushAfterWriteNil(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: zstdName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
|
@ -400,7 +408,7 @@ func Test_FlushAfterWriteNil(t *testing.T) {
|
|||
rw.(http.Flusher).Flush()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
|
@ -430,12 +438,14 @@ func Test_FlushAfterAllWrites(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
algo string
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: brotliName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
|
@ -443,7 +453,8 @@ func Test_FlushAfterAllWrites(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: zstdName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
|
@ -461,7 +472,7 @@ func Test_FlushAfterAllWrites(t *testing.T) {
|
|||
rw.(http.Flusher).Flush()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
|
@ -556,7 +567,6 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
ExcludedContentTypes: test.excludedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -568,7 +578,7 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
@ -667,7 +677,6 @@ func Test_IncludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
IncludedContentTypes: test.includedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -679,7 +688,7 @@ func Test_IncludedContentTypes(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
@ -778,7 +787,6 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
ExcludedContentTypes: test.excludedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -803,7 +811,7 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
@ -903,7 +911,6 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
IncludedContentTypes: test.includedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -928,7 +935,7 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
@ -959,10 +966,26 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler {
|
||||
func mustNewCompressionHandler(t *testing.T, cfg Config, algo string, next http.Handler) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
w, err := NewCompressionHandler(cfg, next)
|
||||
var writer NewCompressionWriter
|
||||
switch algo {
|
||||
case zstdName:
|
||||
writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
|
||||
writer, err := zstd.NewWriter(rw)
|
||||
require.NoError(t, err)
|
||||
return writer, zstdName, nil
|
||||
}
|
||||
case brotliName:
|
||||
writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
|
||||
return brotli.NewWriter(rw), brotliName, nil
|
||||
}
|
||||
default:
|
||||
assert.Failf(t, "unknown compression algorithm: %s", algo)
|
||||
}
|
||||
|
||||
w, err := NewCompressionHandler(cfg, writer, next)
|
||||
require.NoError(t, err)
|
||||
|
||||
return w
|
||||
|
@ -981,7 +1004,7 @@ func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next)
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, brotliName, next)
|
||||
}
|
||||
|
||||
func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
|
||||
|
@ -997,7 +1020,7 @@ func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next)
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, zstdName, next)
|
||||
}
|
||||
|
||||
func Test_ParseContentType_equals(t *testing.T) {
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
package recovery
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
|
@ -27,12 +30,16 @@ func New(ctx context.Context, next http.Handler) (http.Handler, error) {
|
|||
}
|
||||
|
||||
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
defer recoverFunc(rw, req)
|
||||
re.next.ServeHTTP(rw, req)
|
||||
recoveryRW := newRecoveryResponseWriter(rw)
|
||||
defer recoverFunc(recoveryRW, req)
|
||||
|
||||
re.next.ServeHTTP(recoveryRW, req)
|
||||
}
|
||||
|
||||
func recoverFunc(rw http.ResponseWriter, req *http.Request) {
|
||||
func recoverFunc(rw recoveryResponseWriter, req *http.Request) {
|
||||
if err := recover(); err != nil {
|
||||
defer rw.finalizeResponse()
|
||||
|
||||
logger := middlewares.GetLogger(req.Context(), middlewareName, typeName)
|
||||
if !shouldLogPanic(err) {
|
||||
logger.Debug().Msgf("Request has been aborted [%s - %s]: %v", req.RemoteAddr, req.URL, err)
|
||||
|
@ -44,8 +51,6 @@ func recoverFunc(rw http.ResponseWriter, req *http.Request) {
|
|||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
logger.Error().Msgf("Stack: %s", buf)
|
||||
|
||||
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,3 +60,81 @@ func shouldLogPanic(panicValue interface{}) bool {
|
|||
//nolint:errorlint // false-positive because panicValue is an interface.
|
||||
return panicValue != nil && panicValue != http.ErrAbortHandler
|
||||
}
|
||||
|
||||
type recoveryResponseWriter interface {
|
||||
http.ResponseWriter
|
||||
|
||||
finalizeResponse()
|
||||
}
|
||||
|
||||
func newRecoveryResponseWriter(rw http.ResponseWriter) recoveryResponseWriter {
|
||||
wrapper := &responseWriterWrapper{rw: rw}
|
||||
if _, ok := rw.(http.CloseNotifier); !ok {
|
||||
return wrapper
|
||||
}
|
||||
|
||||
return &responseWriterWrapperWithCloseNotify{wrapper}
|
||||
}
|
||||
|
||||
type responseWriterWrapper struct {
|
||||
rw http.ResponseWriter
|
||||
headersSent bool
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) Header() http.Header {
|
||||
return r.rw.Header()
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) Write(bytes []byte) (int, error) {
|
||||
r.headersSent = true
|
||||
return r.rw.Write(bytes)
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) WriteHeader(code int) {
|
||||
if r.headersSent {
|
||||
return
|
||||
}
|
||||
|
||||
// Handling informational headers.
|
||||
if code >= 100 && code <= 199 {
|
||||
r.rw.WriteHeader(code)
|
||||
return
|
||||
}
|
||||
|
||||
r.headersSent = true
|
||||
r.rw.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) Flush() {
|
||||
if f, ok := r.rw.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := r.rw.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("not a hijacker: %T", r.rw)
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) finalizeResponse() {
|
||||
// If headers have been sent this is not possible to respond with an HTTP error,
|
||||
// and we let the server abort the response silently thanks to the http.ErrAbortHandler sentinel panic value.
|
||||
if r.headersSent {
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
// The response has not yet started to be written,
|
||||
// we can safely return a fresh new error response.
|
||||
http.Error(r.rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
type responseWriterWrapperWithCloseNotify struct {
|
||||
*responseWriterWrapper
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapperWithCloseNotify) CloseNotify() <-chan bool {
|
||||
return r.rw.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package recovery
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
@ -11,17 +13,54 @@ import (
|
|||
)
|
||||
|
||||
func TestRecoverHandler(t *testing.T) {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("I love panicking!")
|
||||
tests := []struct {
|
||||
desc string
|
||||
panicErr error
|
||||
headersSent bool
|
||||
}{
|
||||
{
|
||||
desc: "headers sent and custom panic error",
|
||||
panicErr: errors.New("foo"),
|
||||
headersSent: true,
|
||||
},
|
||||
{
|
||||
desc: "headers sent and error abort handler",
|
||||
panicErr: http.ErrAbortHandler,
|
||||
headersSent: true,
|
||||
},
|
||||
{
|
||||
desc: "custom panic error",
|
||||
panicErr: errors.New("foo"),
|
||||
},
|
||||
{
|
||||
desc: "error abort handler",
|
||||
panicErr: http.ErrAbortHandler,
|
||||
},
|
||||
}
|
||||
recovery, err := New(context.Background(), http.HandlerFunc(fn))
|
||||
require.NoError(t, err)
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(recovery)
|
||||
defer server.Close()
|
||||
fn := func(rw http.ResponseWriter, req *http.Request) {
|
||||
if test.headersSent {
|
||||
rw.WriteHeader(http.StatusTeapot)
|
||||
}
|
||||
panic(test.panicErr)
|
||||
}
|
||||
recovery, err := New(context.Background(), http.HandlerFunc(fn))
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
require.NoError(t, err)
|
||||
server := httptest.NewServer(recovery)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
res, err := http.Get(server.URL)
|
||||
if test.headersSent {
|
||||
require.Nil(t, res)
|
||||
assert.ErrorIs(t, err, io.EOF)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue