Merge branch v3.2 into master
This commit is contained in:
commit
7004f0e750
65 changed files with 2393 additions and 1222 deletions
|
@ -244,10 +244,11 @@ func (r *ResponseForwarding) SetDefaults() {
|
|||
|
||||
// Server holds the server configuration.
|
||||
type Server struct {
|
||||
URL string `json:"url,omitempty" toml:"url,omitempty" yaml:"url,omitempty" label:"-"`
|
||||
Weight *int `json:"weight,omitempty" toml:"weight,omitempty" yaml:"weight,omitempty" label:"weight"`
|
||||
Scheme string `json:"-" toml:"-" yaml:"-" file:"-"`
|
||||
Port string `json:"-" toml:"-" yaml:"-" file:"-"`
|
||||
URL string `json:"url,omitempty" toml:"url,omitempty" yaml:"url,omitempty" label:"-"`
|
||||
Weight *int `json:"weight,omitempty" toml:"weight,omitempty" yaml:"weight,omitempty" label:"weight" export:"true"`
|
||||
PreservePath bool `json:"preservePath,omitempty" toml:"preservePath,omitempty" yaml:"preservePath,omitempty" label:"-" export:"true"`
|
||||
Scheme string `json:"-" toml:"-" yaml:"-" file:"-"`
|
||||
Port string `json:"-" toml:"-" yaml:"-" file:"-"`
|
||||
}
|
||||
|
||||
// SetDefaults Default values for a Server.
|
||||
|
|
|
@ -25,7 +25,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
ptypes "github.com/traefik/paerser/types"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/capture"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/recovery"
|
||||
"github.com/traefik/traefik/v3/pkg/types"
|
||||
)
|
||||
|
||||
|
@ -954,8 +953,14 @@ func doLoggingWithAbortedStream(t *testing.T, config *types.AccessLog) {
|
|||
req = req.WithContext(reqContext)
|
||||
|
||||
chain := alice.New()
|
||||
|
||||
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
|
||||
return recovery.New(context.Background(), next)
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
defer func() {
|
||||
_ = recover() // ignore the stream backend panic to avoid the test to fail.
|
||||
}()
|
||||
next.ServeHTTP(rw, req)
|
||||
}), nil
|
||||
})
|
||||
chain = chain.Append(capture.Wrap)
|
||||
chain = chain.Append(WrapHandler(logger))
|
||||
|
|
|
@ -8,7 +8,9 @@ import (
|
|||
"net/http"
|
||||
"slices"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/gzhttp"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/traefik/traefik/v3/pkg/config/dynamic"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
@ -94,12 +96,12 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str
|
|||
|
||||
var err error
|
||||
|
||||
c.zstdHandler, err = c.newCompressionHandler(zstdName, name)
|
||||
c.zstdHandler, err = c.newZstdHandler(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.brotliHandler, err = c.newCompressionHandler(brotliName, name)
|
||||
c.brotliHandler, err = c.newBrotliHandler(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -190,13 +192,34 @@ func (c *compress) newGzipHandler() (http.Handler, error) {
|
|||
return wrapper(c.next), nil
|
||||
}
|
||||
|
||||
func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) {
|
||||
cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName}
|
||||
func (c *compress) newBrotliHandler(middlewareName string) (http.Handler, error) {
|
||||
cfg := Config{MinSize: c.minSize, MiddlewareName: middlewareName}
|
||||
if len(c.includes) > 0 {
|
||||
cfg.IncludedContentTypes = c.includes
|
||||
} else {
|
||||
cfg.ExcludedContentTypes = c.excludes
|
||||
}
|
||||
|
||||
return NewCompressionHandler(cfg, c.next)
|
||||
newBrotliWriter := func(rw http.ResponseWriter) (CompressionWriter, string, error) {
|
||||
return brotli.NewWriter(rw), brotliName, nil
|
||||
}
|
||||
return NewCompressionHandler(cfg, newBrotliWriter, c.next)
|
||||
}
|
||||
|
||||
func (c *compress) newZstdHandler(middlewareName string) (http.Handler, error) {
|
||||
cfg := Config{MinSize: c.minSize, MiddlewareName: middlewareName}
|
||||
if len(c.includes) > 0 {
|
||||
cfg.IncludedContentTypes = c.includes
|
||||
} else {
|
||||
cfg.ExcludedContentTypes = c.excludes
|
||||
}
|
||||
|
||||
newZstdWriter := func(rw http.ResponseWriter) (CompressionWriter, string, error) {
|
||||
writer, err := zstd.NewWriter(rw)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("creating zstd writer: %w", err)
|
||||
}
|
||||
return writer, zstdName, nil
|
||||
}
|
||||
return NewCompressionHandler(cfg, newZstdWriter, c.next)
|
||||
}
|
||||
|
|
|
@ -10,8 +10,6 @@ import (
|
|||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares"
|
||||
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
|
||||
)
|
||||
|
@ -24,6 +22,30 @@ const (
|
|||
contentType = "Content-Type"
|
||||
)
|
||||
|
||||
// CompressionWriter compresses the written bytes.
|
||||
type CompressionWriter interface {
|
||||
// Write data to the encoder.
|
||||
// Input data will be buffered and as the buffer fills up
|
||||
// content will be compressed and written to the output.
|
||||
// When done writing, use Close to flush the remaining output
|
||||
// and write CRC if requested.
|
||||
Write(p []byte) (n int, err error)
|
||||
// Flush will send the currently written data to output
|
||||
// and block until everything has been written.
|
||||
// This should only be used on rare occasions where pushing the currently queued data is critical.
|
||||
Flush() error
|
||||
// Close closes the underlying writers if/when appropriate.
|
||||
// Note that the compressed writer should not be closed if we never used it,
|
||||
// as it would otherwise send some extra "end of compression" bytes.
|
||||
// Close also makes sure to flush whatever was left to write from the buffer.
|
||||
Close() error
|
||||
// Reset reinitializes the state of the encoder, allowing it to be reused.
|
||||
Reset(w io.Writer)
|
||||
}
|
||||
|
||||
// NewCompressionWriter returns a new CompressionWriter with its corresponding algorithm.
|
||||
type NewCompressionWriter func(rw http.ResponseWriter) (CompressionWriter, string, error)
|
||||
|
||||
// Config is the Brotli handler configuration.
|
||||
type Config struct {
|
||||
// ExcludedContentTypes is the list of content types for which we should not compress.
|
||||
|
@ -34,8 +56,6 @@ type Config struct {
|
|||
IncludedContentTypes []string
|
||||
// MinSize is the minimum size (in bytes) required to enable compression.
|
||||
MinSize int
|
||||
// Algorithm used for the compression (currently Brotli and Zstandard)
|
||||
Algorithm string
|
||||
// MiddlewareName use for logging purposes
|
||||
MiddlewareName string
|
||||
}
|
||||
|
@ -46,15 +66,13 @@ type CompressionHandler struct {
|
|||
excludedContentTypes []parsedContentType
|
||||
includedContentTypes []parsedContentType
|
||||
next http.Handler
|
||||
writerPool sync.Pool
|
||||
|
||||
writerPool sync.Pool
|
||||
newWriter NewCompressionWriter
|
||||
}
|
||||
|
||||
// NewCompressionHandler returns a new compressing handler.
|
||||
func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) {
|
||||
if cfg.Algorithm == "" {
|
||||
return nil, errors.New("compression algorithm undefined")
|
||||
}
|
||||
|
||||
func NewCompressionHandler(cfg Config, newWriter NewCompressionWriter, next http.Handler) (http.Handler, error) {
|
||||
if cfg.MinSize < 0 {
|
||||
return nil, errors.New("minimum size must be greater than or equal to zero")
|
||||
}
|
||||
|
@ -88,6 +106,7 @@ func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error)
|
|||
excludedContentTypes: excludedContentTypes,
|
||||
includedContentTypes: includedContentTypes,
|
||||
next: next,
|
||||
newWriter: newWriter,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -117,70 +136,38 @@ func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request)
|
|||
c.next.ServeHTTP(responseWriter, r)
|
||||
}
|
||||
|
||||
type compression interface {
|
||||
// Write data to the encoder.
|
||||
// Input data will be buffered and as the buffer fills up
|
||||
// content will be compressed and written to the output.
|
||||
// When done writing, use Close to flush the remaining output
|
||||
// and write CRC if requested.
|
||||
Write(p []byte) (n int, err error)
|
||||
// Flush will send the currently written data to output
|
||||
// and block until everything has been written.
|
||||
// This should only be used on rare occasions where pushing the currently queued data is critical.
|
||||
Flush() error
|
||||
// Close closes the underlying writers if/when appropriate.
|
||||
// Note that the compressed writer should not be closed if we never used it,
|
||||
// as it would otherwise send some extra "end of compression" bytes.
|
||||
// Close also makes sure to flush whatever was left to write from the buffer.
|
||||
Close() error
|
||||
// Reset reinitializes the state of the encoder, allowing it to be reused.
|
||||
Reset(w io.Writer)
|
||||
}
|
||||
|
||||
type compressionWriter struct {
|
||||
compression
|
||||
alg string
|
||||
}
|
||||
|
||||
func (c *CompressionHandler) getCompressionWriter(rw io.Writer) (*compressionWriter, error) {
|
||||
if writer, ok := c.writerPool.Get().(*compressionWriter); ok {
|
||||
writer.compression.Reset(rw)
|
||||
func (c *CompressionHandler) getCompressionWriter(rw http.ResponseWriter) (*compressionWriterWrapper, error) {
|
||||
if writer, ok := c.writerPool.Get().(*compressionWriterWrapper); ok {
|
||||
writer.Reset(rw)
|
||||
return writer, nil
|
||||
}
|
||||
return newCompressionWriter(c.cfg.Algorithm, rw)
|
||||
|
||||
writer, algo, err := c.newWriter(rw)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating compression writer: %w", err)
|
||||
}
|
||||
return &compressionWriterWrapper{CompressionWriter: writer, algo: algo}, nil
|
||||
}
|
||||
|
||||
func (c *CompressionHandler) putCompressionWriter(writer *compressionWriter) {
|
||||
func (c *CompressionHandler) putCompressionWriter(writer *compressionWriterWrapper) {
|
||||
writer.Reset(nil)
|
||||
c.writerPool.Put(writer)
|
||||
}
|
||||
|
||||
func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) {
|
||||
switch algo {
|
||||
case brotliName:
|
||||
return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil
|
||||
|
||||
case zstdName:
|
||||
writer, err := zstd.NewWriter(in)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating zstd writer: %w", err)
|
||||
}
|
||||
return &compressionWriter{compression: writer, alg: algo}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown compression algo: %s", algo)
|
||||
}
|
||||
type compressionWriterWrapper struct {
|
||||
CompressionWriter
|
||||
algo string
|
||||
}
|
||||
|
||||
func (c *compressionWriter) ContentEncoding() string {
|
||||
return c.alg
|
||||
func (c *compressionWriterWrapper) ContentEncoding() string {
|
||||
return c.algo
|
||||
}
|
||||
|
||||
// TODO: check whether we want to implement content-type sniffing (as gzip does)
|
||||
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
|
||||
type responseWriter struct {
|
||||
rw http.ResponseWriter
|
||||
compressionWriter *compressionWriter
|
||||
compressionWriter *compressionWriterWrapper
|
||||
|
||||
minSize int
|
||||
excludedContentTypes []parsedContentType
|
||||
|
|
|
@ -162,7 +162,7 @@ func Test_NoBody(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(acceptEncoding, "zstd")
|
||||
|
@ -181,8 +181,7 @@ func Test_NoBody(t *testing.T) {
|
|||
|
||||
func Test_MinSize(t *testing.T) {
|
||||
cfg := Config{
|
||||
MinSize: 128,
|
||||
Algorithm: zstdName,
|
||||
MinSize: 128,
|
||||
}
|
||||
|
||||
var bodySize int
|
||||
|
@ -197,7 +196,7 @@ func Test_MinSize(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
|
||||
req.Header.Add(acceptEncoding, "zstd")
|
||||
|
@ -224,7 +223,7 @@ func Test_MultipleWriteHeader(t *testing.T) {
|
|||
rw.WriteHeader(http.StatusNotFound)
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
|
||||
h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.Header.Set(acceptEncoding, "zstd")
|
||||
|
@ -239,12 +238,14 @@ func Test_FlushBeforeWrite(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
algo string
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: brotliName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
|
@ -252,7 +253,8 @@ func Test_FlushBeforeWrite(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: zstdName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
|
@ -272,7 +274,7 @@ func Test_FlushBeforeWrite(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
|
@ -302,12 +304,14 @@ func Test_FlushAfterWrite(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
algo string
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: brotliName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
|
@ -315,7 +319,8 @@ func Test_FlushAfterWrite(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: zstdName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
|
@ -338,7 +343,7 @@ func Test_FlushAfterWrite(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
|
@ -368,12 +373,14 @@ func Test_FlushAfterWriteNil(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
algo string
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: brotliName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
|
@ -381,7 +388,8 @@ func Test_FlushAfterWriteNil(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: zstdName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
|
@ -400,7 +408,7 @@ func Test_FlushAfterWriteNil(t *testing.T) {
|
|||
rw.(http.Flusher).Flush()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
|
@ -430,12 +438,14 @@ func Test_FlushAfterAllWrites(t *testing.T) {
|
|||
testCases := []struct {
|
||||
desc string
|
||||
cfg Config
|
||||
algo string
|
||||
readerBuilder func(io.Reader) (io.Reader, error)
|
||||
acceptEncoding string
|
||||
}{
|
||||
{
|
||||
desc: "brotli",
|
||||
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: brotliName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return brotli.NewReader(reader), nil
|
||||
},
|
||||
|
@ -443,7 +453,8 @@ func Test_FlushAfterAllWrites(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "zstd",
|
||||
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
|
||||
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
|
||||
algo: zstdName,
|
||||
readerBuilder: func(reader io.Reader) (io.Reader, error) {
|
||||
return zstd.NewReader(reader)
|
||||
},
|
||||
|
@ -461,7 +472,7 @@ func Test_FlushAfterAllWrites(t *testing.T) {
|
|||
rw.(http.Flusher).Flush()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
|
||||
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
|
||||
defer srv.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
|
||||
|
@ -556,7 +567,6 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
ExcludedContentTypes: test.excludedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -568,7 +578,7 @@ func Test_ExcludedContentTypes(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
@ -667,7 +677,6 @@ func Test_IncludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
IncludedContentTypes: test.includedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -679,7 +688,7 @@ func Test_IncludedContentTypes(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
@ -778,7 +787,6 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
ExcludedContentTypes: test.excludedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -803,7 +811,7 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
@ -903,7 +911,6 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
cfg := Config{
|
||||
MinSize: 1024,
|
||||
IncludedContentTypes: test.includedContentTypes,
|
||||
Algorithm: zstdName,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -928,7 +935,7 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
h := mustNewCompressionHandler(t, cfg, next)
|
||||
h := mustNewCompressionHandler(t, cfg, zstdName, next)
|
||||
|
||||
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
|
||||
req.Header.Set(acceptEncoding, zstdName)
|
||||
|
@ -959,10 +966,26 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler {
|
||||
func mustNewCompressionHandler(t *testing.T, cfg Config, algo string, next http.Handler) http.Handler {
|
||||
t.Helper()
|
||||
|
||||
w, err := NewCompressionHandler(cfg, next)
|
||||
var writer NewCompressionWriter
|
||||
switch algo {
|
||||
case zstdName:
|
||||
writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
|
||||
writer, err := zstd.NewWriter(rw)
|
||||
require.NoError(t, err)
|
||||
return writer, zstdName, nil
|
||||
}
|
||||
case brotliName:
|
||||
writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
|
||||
return brotli.NewWriter(rw), brotliName, nil
|
||||
}
|
||||
default:
|
||||
assert.Failf(t, "unknown compression algorithm: %s", algo)
|
||||
}
|
||||
|
||||
w, err := NewCompressionHandler(cfg, writer, next)
|
||||
require.NoError(t, err)
|
||||
|
||||
return w
|
||||
|
@ -981,7 +1004,7 @@ func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next)
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, brotliName, next)
|
||||
}
|
||||
|
||||
func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
|
||||
|
@ -997,7 +1020,7 @@ func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
|
|||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next)
|
||||
return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, zstdName, next)
|
||||
}
|
||||
|
||||
func Test_ParseContentType_equals(t *testing.T) {
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
package recovery
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
|
@ -27,12 +30,16 @@ func New(ctx context.Context, next http.Handler) (http.Handler, error) {
|
|||
}
|
||||
|
||||
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
defer recoverFunc(rw, req)
|
||||
re.next.ServeHTTP(rw, req)
|
||||
recoveryRW := newRecoveryResponseWriter(rw)
|
||||
defer recoverFunc(recoveryRW, req)
|
||||
|
||||
re.next.ServeHTTP(recoveryRW, req)
|
||||
}
|
||||
|
||||
func recoverFunc(rw http.ResponseWriter, req *http.Request) {
|
||||
func recoverFunc(rw recoveryResponseWriter, req *http.Request) {
|
||||
if err := recover(); err != nil {
|
||||
defer rw.finalizeResponse()
|
||||
|
||||
logger := middlewares.GetLogger(req.Context(), middlewareName, typeName)
|
||||
if !shouldLogPanic(err) {
|
||||
logger.Debug().Msgf("Request has been aborted [%s - %s]: %v", req.RemoteAddr, req.URL, err)
|
||||
|
@ -44,8 +51,6 @@ func recoverFunc(rw http.ResponseWriter, req *http.Request) {
|
|||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
logger.Error().Msgf("Stack: %s", buf)
|
||||
|
||||
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,3 +60,81 @@ func shouldLogPanic(panicValue interface{}) bool {
|
|||
//nolint:errorlint // false-positive because panicValue is an interface.
|
||||
return panicValue != nil && panicValue != http.ErrAbortHandler
|
||||
}
|
||||
|
||||
type recoveryResponseWriter interface {
|
||||
http.ResponseWriter
|
||||
|
||||
finalizeResponse()
|
||||
}
|
||||
|
||||
func newRecoveryResponseWriter(rw http.ResponseWriter) recoveryResponseWriter {
|
||||
wrapper := &responseWriterWrapper{rw: rw}
|
||||
if _, ok := rw.(http.CloseNotifier); !ok {
|
||||
return wrapper
|
||||
}
|
||||
|
||||
return &responseWriterWrapperWithCloseNotify{wrapper}
|
||||
}
|
||||
|
||||
type responseWriterWrapper struct {
|
||||
rw http.ResponseWriter
|
||||
headersSent bool
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) Header() http.Header {
|
||||
return r.rw.Header()
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) Write(bytes []byte) (int, error) {
|
||||
r.headersSent = true
|
||||
return r.rw.Write(bytes)
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) WriteHeader(code int) {
|
||||
if r.headersSent {
|
||||
return
|
||||
}
|
||||
|
||||
// Handling informational headers.
|
||||
if code >= 100 && code <= 199 {
|
||||
r.rw.WriteHeader(code)
|
||||
return
|
||||
}
|
||||
|
||||
r.headersSent = true
|
||||
r.rw.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) Flush() {
|
||||
if f, ok := r.rw.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := r.rw.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, fmt.Errorf("not a hijacker: %T", r.rw)
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapper) finalizeResponse() {
|
||||
// If headers have been sent this is not possible to respond with an HTTP error,
|
||||
// and we let the server abort the response silently thanks to the http.ErrAbortHandler sentinel panic value.
|
||||
if r.headersSent {
|
||||
panic(http.ErrAbortHandler)
|
||||
}
|
||||
|
||||
// The response has not yet started to be written,
|
||||
// we can safely return a fresh new error response.
|
||||
http.Error(r.rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
type responseWriterWrapperWithCloseNotify struct {
|
||||
*responseWriterWrapper
|
||||
}
|
||||
|
||||
func (r *responseWriterWrapperWithCloseNotify) CloseNotify() <-chan bool {
|
||||
return r.rw.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
|
|
@ -2,6 +2,8 @@ package recovery
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
@ -11,17 +13,54 @@ import (
|
|||
)
|
||||
|
||||
func TestRecoverHandler(t *testing.T) {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("I love panicking!")
|
||||
tests := []struct {
|
||||
desc string
|
||||
panicErr error
|
||||
headersSent bool
|
||||
}{
|
||||
{
|
||||
desc: "headers sent and custom panic error",
|
||||
panicErr: errors.New("foo"),
|
||||
headersSent: true,
|
||||
},
|
||||
{
|
||||
desc: "headers sent and error abort handler",
|
||||
panicErr: http.ErrAbortHandler,
|
||||
headersSent: true,
|
||||
},
|
||||
{
|
||||
desc: "custom panic error",
|
||||
panicErr: errors.New("foo"),
|
||||
},
|
||||
{
|
||||
desc: "error abort handler",
|
||||
panicErr: http.ErrAbortHandler,
|
||||
},
|
||||
}
|
||||
recovery, err := New(context.Background(), http.HandlerFunc(fn))
|
||||
require.NoError(t, err)
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(recovery)
|
||||
defer server.Close()
|
||||
fn := func(rw http.ResponseWriter, req *http.Request) {
|
||||
if test.headersSent {
|
||||
rw.WriteHeader(http.StatusTeapot)
|
||||
}
|
||||
panic(test.panicErr)
|
||||
}
|
||||
recovery, err := New(context.Background(), http.HandlerFunc(fn))
|
||||
require.NoError(t, err)
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
require.NoError(t, err)
|
||||
server := httptest.NewServer(recovery)
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
res, err := http.Get(server.URL)
|
||||
if test.headersSent {
|
||||
require.Nil(t, res)
|
||||
assert.ErrorIs(t, err, io.EOF)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -312,19 +312,9 @@ func (p *Provider) getClient() (*lego.Client, error) {
|
|||
}
|
||||
|
||||
err = client.Challenge.SetDNS01Provider(provider,
|
||||
dns01.CondOption(len(p.DNSChallenge.Resolvers) > 0, dns01.AddRecursiveNameservers(p.DNSChallenge.Resolvers)),
|
||||
dns01.WrapPreCheck(func(domain, fqdn, value string, check dns01.PreCheckFunc) (bool, error) {
|
||||
if p.DNSChallenge.DelayBeforeCheck > 0 {
|
||||
logger.Debug().Msgf("Delaying %d rather than validating DNS propagation now.", p.DNSChallenge.DelayBeforeCheck)
|
||||
time.Sleep(time.Duration(p.DNSChallenge.DelayBeforeCheck))
|
||||
}
|
||||
|
||||
if p.DNSChallenge.DisablePropagationCheck {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return check(fqdn, value)
|
||||
}),
|
||||
dns01.CondOption(len(p.DNSChallenge.Resolvers) > 0,
|
||||
dns01.AddRecursiveNameservers(p.DNSChallenge.Resolvers)),
|
||||
dns01.PropagationWait(time.Duration(p.DNSChallenge.DelayBeforeCheck), p.DNSChallenge.DisablePropagationCheck),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
---
|
||||
kind: GatewayClass
|
||||
apiVersion: gateway.networking.k8s.io/v1
|
||||
metadata:
|
||||
name: my-gateway-class
|
||||
spec:
|
||||
controllerName: traefik.io/gateway-controller
|
||||
|
||||
---
|
||||
kind: Gateway
|
||||
apiVersion: gateway.networking.k8s.io/v1
|
||||
metadata:
|
||||
name: my-gateway
|
||||
namespace: default
|
||||
spec:
|
||||
gatewayClassName: my-gateway-class
|
||||
listeners: # Use GatewayClass defaults for listener definition.
|
||||
- name: http
|
||||
protocol: HTTP
|
||||
port: 80
|
||||
allowedRoutes:
|
||||
kinds:
|
||||
- kind: GRPCRoute
|
||||
group: gateway.networking.k8s.io
|
||||
namespaces:
|
||||
from: Same
|
||||
|
||||
---
|
||||
kind: GRPCRoute
|
||||
apiVersion: gateway.networking.k8s.io/v1
|
||||
metadata:
|
||||
name: grpc-app-1
|
||||
namespace: default
|
||||
spec:
|
||||
parentRefs:
|
||||
- name: my-gateway
|
||||
kind: Gateway
|
||||
group: gateway.networking.k8s.io
|
||||
hostnames:
|
||||
- foo.com
|
||||
rules:
|
||||
- backendRefs:
|
||||
- name: whoami
|
||||
port: 80
|
||||
weight: 1
|
||||
filters:
|
||||
- type: ExtensionRef
|
||||
extensionRef:
|
||||
group: traefik.io
|
||||
kind: Middleware
|
||||
name: my-first-middleware
|
||||
- type: ExtensionRef
|
||||
extensionRef:
|
||||
group: traefik.io
|
||||
kind: Middleware
|
||||
name: my-second-middleware
|
|
@ -54,4 +54,9 @@ spec:
|
|||
extensionRef:
|
||||
group: traefik.io
|
||||
kind: Middleware
|
||||
name: my-middleware
|
||||
name: my-first-middleware
|
||||
- type: ExtensionRef
|
||||
extensionRef:
|
||||
group: traefik.io
|
||||
kind: Middleware
|
||||
name: my-second-middleware
|
||||
|
|
|
@ -120,7 +120,7 @@ func (p *Provider) loadGRPCRoute(ctx context.Context, listener gatewayListener,
|
|||
|
||||
for ri, routeRule := range route.Spec.Rules {
|
||||
// Adding the gateway desc and the entryPoint desc prevents overlapping of routers build from the same routes.
|
||||
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-%s-%d", route.Namespace, route.Name, listener.GWName, listener.EPName, ri))
|
||||
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-gw-%s-%s-ep-%s-%d", strings.ToLower(kindGRPCRoute), route.Namespace, route.Name, listener.GWNamespace, listener.GWName, listener.EPName, ri))
|
||||
|
||||
matches := routeRule.Matches
|
||||
if len(matches) == 0 {
|
||||
|
@ -166,7 +166,7 @@ func (p *Provider) loadGRPCRoute(ctx context.Context, listener gatewayListener,
|
|||
|
||||
default:
|
||||
var serviceCondition *metav1.Condition
|
||||
router.Service, serviceCondition = p.loadGRPCService(conf, routeKey, routeRule, route)
|
||||
router.Service, serviceCondition = p.loadGRPCService(conf, routerName, routeRule, route)
|
||||
if serviceCondition != nil {
|
||||
condition = *serviceCondition
|
||||
}
|
||||
|
@ -267,7 +267,7 @@ func (p *Provider) loadGRPCBackendRef(route *gatev1.GRPCRoute, backendRef gatev1
|
|||
}
|
||||
|
||||
portStr := strconv.FormatInt(int64(port), 10)
|
||||
serviceName = provider.Normalize(serviceName + "-" + portStr)
|
||||
serviceName = provider.Normalize(serviceName + "-" + portStr + "-grpc")
|
||||
|
||||
lb, errCondition := p.loadGRPCServers(namespace, route, backendRef)
|
||||
if errCondition != nil {
|
||||
|
@ -278,20 +278,30 @@ func (p *Provider) loadGRPCBackendRef(route *gatev1.GRPCRoute, backendRef gatev1
|
|||
}
|
||||
|
||||
func (p *Provider) loadGRPCMiddlewares(conf *dynamic.Configuration, namespace, routerName string, filters []gatev1.GRPCRouteFilter) ([]string, error) {
|
||||
middlewares := make(map[string]*dynamic.Middleware)
|
||||
type namedMiddleware struct {
|
||||
Name string
|
||||
Config *dynamic.Middleware
|
||||
}
|
||||
|
||||
var middlewares []namedMiddleware
|
||||
for i, filter := range filters {
|
||||
name := fmt.Sprintf("%s-%s-%d", routerName, strings.ToLower(string(filter.Type)), i)
|
||||
switch filter.Type {
|
||||
case gatev1.GRPCRouteFilterRequestHeaderModifier:
|
||||
middlewares[name] = createRequestHeaderModifier(filter.RequestHeaderModifier)
|
||||
middlewares = append(middlewares, namedMiddleware{
|
||||
name,
|
||||
createRequestHeaderModifier(filter.RequestHeaderModifier),
|
||||
})
|
||||
|
||||
case gatev1.GRPCRouteFilterExtensionRef:
|
||||
name, middleware, err := p.loadHTTPRouteFilterExtensionRef(namespace, filter.ExtensionRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading ExtensionRef filter %s: %w", filter.Type, err)
|
||||
}
|
||||
|
||||
middlewares[name] = middleware
|
||||
middlewares = append(middlewares, namedMiddleware{
|
||||
name,
|
||||
middleware,
|
||||
})
|
||||
|
||||
default:
|
||||
// As per the spec: https://gateway-api.sigs.k8s.io/api-types/httproute/#filters-optional
|
||||
|
@ -303,12 +313,11 @@ func (p *Provider) loadGRPCMiddlewares(conf *dynamic.Configuration, namespace, r
|
|||
}
|
||||
|
||||
var middlewareNames []string
|
||||
for name, middleware := range middlewares {
|
||||
if middleware != nil {
|
||||
conf.HTTP.Middlewares[name] = middleware
|
||||
for _, m := range middlewares {
|
||||
if m.Config != nil {
|
||||
conf.HTTP.Middlewares[m.Name] = m.Config
|
||||
}
|
||||
|
||||
middlewareNames = append(middlewareNames, name)
|
||||
middlewareNames = append(middlewareNames, m.Name)
|
||||
}
|
||||
|
||||
return middlewareNames, nil
|
||||
|
|
|
@ -123,7 +123,7 @@ func (p *Provider) loadHTTPRoute(ctx context.Context, listener gatewayListener,
|
|||
|
||||
for ri, routeRule := range route.Spec.Rules {
|
||||
// Adding the gateway desc and the entryPoint desc prevents overlapping of routers build from the same routes.
|
||||
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-%s-%d", route.Namespace, route.Name, listener.GWName, listener.EPName, ri))
|
||||
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-gw-%s-%s-ep-%s-%d", strings.ToLower(kindHTTPRoute), route.Namespace, route.Name, listener.GWNamespace, listener.GWName, listener.EPName, ri))
|
||||
|
||||
for _, match := range routeRule.Matches {
|
||||
rule, priority := buildMatchRule(hostnames, match)
|
||||
|
@ -224,7 +224,7 @@ func (p *Provider) loadService(ctx context.Context, listener gatewayListener, co
|
|||
namespace = string(*backendRef.Namespace)
|
||||
}
|
||||
|
||||
serviceName := provider.Normalize(namespace + "-" + string(backendRef.Name))
|
||||
serviceName := provider.Normalize(namespace + "-" + string(backendRef.Name) + "-http")
|
||||
|
||||
if err := p.isReferenceGranted(kindHTTPRoute, route.Namespace, group, string(kind), string(backendRef.Name), namespace); err != nil {
|
||||
return serviceName, &metav1.Condition{
|
||||
|
@ -384,39 +384,58 @@ func (p *Provider) loadHTTPBackendRef(namespace string, backendRef gatev1.HTTPBa
|
|||
}
|
||||
|
||||
func (p *Provider) loadMiddlewares(conf *dynamic.Configuration, namespace, routerName string, filters []gatev1.HTTPRouteFilter, pathMatch *gatev1.HTTPPathMatch) ([]string, error) {
|
||||
type namedMiddleware struct {
|
||||
Name string
|
||||
Config *dynamic.Middleware
|
||||
}
|
||||
|
||||
pm := ptr.Deref(pathMatch, gatev1.HTTPPathMatch{
|
||||
Type: ptr.To(gatev1.PathMatchPathPrefix),
|
||||
Value: ptr.To("/"),
|
||||
})
|
||||
|
||||
middlewares := make(map[string]*dynamic.Middleware)
|
||||
var middlewares []namedMiddleware
|
||||
for i, filter := range filters {
|
||||
name := fmt.Sprintf("%s-%s-%d", routerName, strings.ToLower(string(filter.Type)), i)
|
||||
|
||||
switch filter.Type {
|
||||
case gatev1.HTTPRouteFilterRequestRedirect:
|
||||
middlewares[name] = createRequestRedirect(filter.RequestRedirect, pm)
|
||||
middlewares = append(middlewares, namedMiddleware{
|
||||
name,
|
||||
createRequestRedirect(filter.RequestRedirect, pm),
|
||||
})
|
||||
|
||||
case gatev1.HTTPRouteFilterRequestHeaderModifier:
|
||||
middlewares[name] = createRequestHeaderModifier(filter.RequestHeaderModifier)
|
||||
middlewares = append(middlewares, namedMiddleware{
|
||||
name,
|
||||
createRequestHeaderModifier(filter.RequestHeaderModifier),
|
||||
})
|
||||
|
||||
case gatev1.HTTPRouteFilterResponseHeaderModifier:
|
||||
middlewares[name] = createResponseHeaderModifier(filter.ResponseHeaderModifier)
|
||||
middlewares = append(middlewares, namedMiddleware{
|
||||
name,
|
||||
createResponseHeaderModifier(filter.ResponseHeaderModifier),
|
||||
})
|
||||
|
||||
case gatev1.HTTPRouteFilterExtensionRef:
|
||||
name, middleware, err := p.loadHTTPRouteFilterExtensionRef(namespace, filter.ExtensionRef)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading ExtensionRef filter %s: %w", filter.Type, err)
|
||||
}
|
||||
|
||||
middlewares[name] = middleware
|
||||
middlewares = append(middlewares, namedMiddleware{
|
||||
name,
|
||||
middleware,
|
||||
})
|
||||
|
||||
case gatev1.HTTPRouteFilterURLRewrite:
|
||||
var err error
|
||||
middleware, err := createURLRewrite(filter.URLRewrite, pm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid filter %s: %w", filter.Type, err)
|
||||
}
|
||||
middlewares[name] = middleware
|
||||
middlewares = append(middlewares, namedMiddleware{
|
||||
name,
|
||||
middleware,
|
||||
})
|
||||
|
||||
default:
|
||||
// As per the spec: https://gateway-api.sigs.k8s.io/api-types/httproute/#filters-optional
|
||||
|
@ -428,12 +447,11 @@ func (p *Provider) loadMiddlewares(conf *dynamic.Configuration, namespace, route
|
|||
}
|
||||
|
||||
var middlewareNames []string
|
||||
for name, middleware := range middlewares {
|
||||
if middleware != nil {
|
||||
conf.HTTP.Middlewares[name] = middleware
|
||||
for _, m := range middlewares {
|
||||
if m.Config != nil {
|
||||
conf.HTTP.Middlewares[m.Name] = m.Config
|
||||
}
|
||||
|
||||
middlewareNames = append(middlewareNames, name)
|
||||
middlewareNames = append(middlewareNames, m.Name)
|
||||
}
|
||||
|
||||
return middlewareNames, nil
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -130,7 +130,7 @@ func (p *Provider) loadTCPRoute(listener gatewayListener, route *gatev1alpha2.TC
|
|||
}
|
||||
|
||||
// Adding the gateway desc and the entryPoint desc prevents overlapping of routers build from the same routes.
|
||||
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-%s-%d", route.Namespace, route.Name, listener.GWName, listener.EPName, ri))
|
||||
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-gw-%s-%s-ep-%s-%d", strings.ToLower(kindTCPRoute), route.Namespace, route.Name, listener.GWNamespace, listener.GWName, listener.EPName, ri))
|
||||
// Routing criteria should be introduced at some point.
|
||||
routerName := makeRouterName("", routeKey)
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ func (p *Provider) loadTLSRoute(listener gatewayListener, route *gatev1alpha2.TL
|
|||
}
|
||||
|
||||
// Adding the gateway desc and the entryPoint desc prevents overlapping of routers build from the same routes.
|
||||
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-%s-%d", route.Namespace, route.Name, listener.GWName, listener.EPName, ri))
|
||||
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-gw-%s-%s-ep-%s-%d", strings.ToLower(kindTLSRoute), route.Namespace, route.Name, listener.GWNamespace, listener.GWName, listener.EPName, ri))
|
||||
// Routing criteria should be introduced at some point.
|
||||
routerName := makeRouterName("", routeKey)
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
// MustParseYaml parses a YAML to objects.
|
||||
func MustParseYaml(content []byte) []runtime.Object {
|
||||
acceptedK8sTypes := regexp.MustCompile(`^(Namespace|Deployment|EndpointSlice|Node|Service|ConfigMap|Ingress|IngressRoute|IngressRouteTCP|IngressRouteUDP|Middleware|MiddlewareTCP|Secret|TLSOption|TLSStore|TraefikService|IngressClass|ServersTransport|ServersTransportTCP|GatewayClass|Gateway|HTTPRoute|TCPRoute|TLSRoute|ReferenceGrant|BackendTLSPolicy)$`)
|
||||
acceptedK8sTypes := regexp.MustCompile(`^(Namespace|Deployment|EndpointSlice|Node|Service|ConfigMap|Ingress|IngressRoute|IngressRouteTCP|IngressRouteUDP|Middleware|MiddlewareTCP|Secret|TLSOption|TLSStore|TraefikService|IngressClass|ServersTransport|ServersTransportTCP|GatewayClass|Gateway|GRPCRoute|HTTPRoute|TCPRoute|TLSRoute|ReferenceGrant|BackendTLSPolicy)$`)
|
||||
|
||||
files := strings.Split(string(content), "---\n")
|
||||
retVal := make([]runtime.Object, 0, len(files))
|
||||
|
|
|
@ -68,7 +68,7 @@ func (r *ProxyBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) {
|
|||
}
|
||||
|
||||
// Build builds a new ReverseProxy with the given configuration.
|
||||
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader bool) (http.Handler, error) {
|
||||
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader, preservePath bool) (http.Handler, error) {
|
||||
proxyURL, err := r.proxy(&http.Request{URL: targetURL})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting proxy: %w", err)
|
||||
|
@ -79,18 +79,13 @@ func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader
|
|||
return nil, fmt.Errorf("getting ServersTransport: %w", err)
|
||||
}
|
||||
|
||||
var responseHeaderTimeout time.Duration
|
||||
if cfg.ForwardingTimeouts != nil {
|
||||
responseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
tlsConfig, err := r.transportManager.GetTLSConfig(cfgName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting TLS config: %w", err)
|
||||
}
|
||||
|
||||
pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL)
|
||||
return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, responseHeaderTimeout, pool)
|
||||
return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, preservePath, pool)
|
||||
}
|
||||
|
||||
func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool {
|
||||
|
@ -106,9 +101,11 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport,
|
|||
|
||||
idleConnTimeout := 90 * time.Second
|
||||
dialTimeout := 30 * time.Second
|
||||
var responseHeaderTimeout time.Duration
|
||||
if config.ForwardingTimeouts != nil {
|
||||
idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout)
|
||||
dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout)
|
||||
responseHeaderTimeout = time.Duration(config.ForwardingTimeouts.ResponseHeaderTimeout)
|
||||
}
|
||||
|
||||
proxyDialer := newDialer(dialerConfig{
|
||||
|
@ -119,7 +116,7 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport,
|
|||
ProxyURL: proxyURL,
|
||||
}, tlsConfig)
|
||||
|
||||
connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, func() (net.Conn, error) {
|
||||
connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, responseHeaderTimeout, func() (net.Conn, error) {
|
||||
return proxyDialer.Dial("tcp", addrFromURL(targetURL))
|
||||
})
|
||||
|
||||
|
|
|
@ -1,42 +1,309 @@
|
|||
package fast
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
|
||||
// rwWithUpgrade contains a ResponseWriter and an upgradeHandler,
|
||||
// used to upgrade the connection (e.g. Websockets).
|
||||
type rwWithUpgrade struct {
|
||||
RW http.ResponseWriter
|
||||
Upgrade upgradeHandler
|
||||
}
|
||||
|
||||
// conn is an enriched net.Conn.
|
||||
type conn struct {
|
||||
net.Conn
|
||||
|
||||
RWCh chan rwWithUpgrade
|
||||
ErrCh chan error
|
||||
|
||||
br *bufio.Reader
|
||||
|
||||
idleAt time.Time // the last time it was marked as idle.
|
||||
idleTimeout time.Duration
|
||||
|
||||
responseHeaderTimeout time.Duration
|
||||
|
||||
expectedResponse atomic.Bool
|
||||
broken atomic.Bool
|
||||
upgraded atomic.Bool
|
||||
|
||||
closeMu sync.Mutex
|
||||
closed bool
|
||||
closeErr error
|
||||
|
||||
bufferPool *pool[[]byte]
|
||||
limitedReaderPool *pool[*io.LimitedReader]
|
||||
}
|
||||
|
||||
func (c *conn) isExpired() bool {
|
||||
// Read reads data from the connection.
|
||||
// Overrides conn Read to use the buffered reader.
|
||||
func (c *conn) Read(b []byte) (n int, err error) {
|
||||
return c.br.Read(b)
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
// Ensures that connection is closed only once,
|
||||
// to avoid duplicate close error.
|
||||
func (c *conn) Close() error {
|
||||
c.closeMu.Lock()
|
||||
defer c.closeMu.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
c.closed = true
|
||||
c.closeErr = c.Conn.Close()
|
||||
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
// isStale returns whether the connection is in an invalid state (i.e. expired/broken).
|
||||
func (c *conn) isStale() bool {
|
||||
expTime := c.idleAt.Add(c.idleTimeout)
|
||||
return c.idleTimeout > 0 && time.Now().After(expTime)
|
||||
return c.idleTimeout > 0 && time.Now().After(expTime) || c.broken.Load()
|
||||
}
|
||||
|
||||
// isUpgraded returns whether this connection has been upgraded (e.g. Websocket).
|
||||
// An upgraded connection should not be reused and putted back in the connection pool.
|
||||
func (c *conn) isUpgraded() bool {
|
||||
return c.upgraded.Load()
|
||||
}
|
||||
|
||||
// readLoop handles the successive HTTP response read operations on the connection,
|
||||
// and watches for unsolicited bytes or connection errors when idle.
|
||||
func (c *conn) readLoop() {
|
||||
defer c.Close()
|
||||
|
||||
for {
|
||||
_, err := c.br.Peek(1)
|
||||
if err != nil {
|
||||
select {
|
||||
// An error occurred while a response was expected to be handled.
|
||||
case <-c.RWCh:
|
||||
c.ErrCh <- err
|
||||
// An error occurred on an idle connection.
|
||||
default:
|
||||
c.broken.Store(true)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Unsolicited response received on an idle connection.
|
||||
if !c.expectedResponse.Load() {
|
||||
c.broken.Store(true)
|
||||
return
|
||||
}
|
||||
|
||||
r := <-c.RWCh
|
||||
if err = c.handleResponse(r); err != nil {
|
||||
c.ErrCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
c.expectedResponse.Store(false)
|
||||
c.ErrCh <- nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *conn) handleResponse(r rwWithUpgrade) error {
|
||||
res := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseResponse(res)
|
||||
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
|
||||
for {
|
||||
var (
|
||||
timer *time.Timer
|
||||
errTimeout atomic.Pointer[timeoutError]
|
||||
)
|
||||
if c.responseHeaderTimeout > 0 {
|
||||
timer = time.AfterFunc(c.responseHeaderTimeout, func() {
|
||||
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
|
||||
c.Close() // This close call is needed to interrupt the read operation below when the timeout is over.
|
||||
})
|
||||
}
|
||||
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
if err := res.Header.Read(c.br); err != nil {
|
||||
if c.responseHeaderTimeout > 0 {
|
||||
if errT := errTimeout.Load(); errT != nil {
|
||||
return errT
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
|
||||
fixPragmaCacheControl(&res.Header)
|
||||
|
||||
resCode := res.StatusCode()
|
||||
is1xx := 100 <= resCode && resCode <= 199
|
||||
// treat 101 as a terminal status, see issue 26161
|
||||
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
|
||||
if is1xxNonTerminal {
|
||||
removeConnectionHeaders(&res.Header)
|
||||
h := r.RW.Header()
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
res.Header.Del(header)
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
r.RW.Header().Add(string(key), string(value))
|
||||
})
|
||||
|
||||
r.RW.WriteHeader(res.StatusCode())
|
||||
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
||||
for k := range h {
|
||||
delete(h, k)
|
||||
}
|
||||
|
||||
res.Reset()
|
||||
res.Header.Reset()
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
announcedTrailers := res.Header.Peek("Trailer")
|
||||
|
||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode() == http.StatusSwitchingProtocols {
|
||||
r.Upgrade(r.RW, res, c)
|
||||
c.upgraded.Store(true) // As the connection has been upgraded, it cannot be added back to the pool.
|
||||
return nil
|
||||
}
|
||||
|
||||
removeConnectionHeaders(&res.Header)
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
res.Header.Del(header)
|
||||
}
|
||||
|
||||
if len(announcedTrailers) > 0 {
|
||||
res.Header.Add("Trailer", string(announcedTrailers))
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
r.RW.Header().Add(string(key), string(value))
|
||||
})
|
||||
|
||||
r.RW.WriteHeader(res.StatusCode())
|
||||
|
||||
if res.Header.ContentLength() == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// When a body is not allowed for a given status code the body is ignored.
|
||||
// The connection will be marked as broken by the next Peek in the readloop.
|
||||
if !isBodyAllowedForStatus(res.StatusCode()) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
|
||||
if res.Header.ContentLength() == -1 {
|
||||
cbr := httputil.NewChunkedReader(c.br)
|
||||
|
||||
b := c.bufferPool.Get()
|
||||
if b == nil {
|
||||
b = make([]byte, bufferSize)
|
||||
}
|
||||
defer c.bufferPool.Put(b)
|
||||
|
||||
if _, err := io.CopyBuffer(&writeFlusher{r.RW}, cbr, b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res.Header.Reset()
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
if err := res.Header.ReadTrailer(c.br); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Header.Len() > 0 {
|
||||
var announcedTrailersKey []string
|
||||
if len(announcedTrailers) > 0 {
|
||||
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
for _, s := range announcedTrailersKey {
|
||||
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
|
||||
r.RW.Header().Add(string(key), string(value))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r.RW.Header().Add(http.TrailerPrefix+string(key), string(value))
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
brl := c.limitedReaderPool.Get()
|
||||
if brl == nil {
|
||||
brl = &io.LimitedReader{}
|
||||
}
|
||||
defer c.limitedReaderPool.Put(brl)
|
||||
|
||||
brl.R = c.br
|
||||
brl.N = int64(res.Header.ContentLength())
|
||||
|
||||
b := c.bufferPool.Get()
|
||||
if b == nil {
|
||||
b = make([]byte, bufferSize)
|
||||
}
|
||||
defer c.bufferPool.Put(b)
|
||||
|
||||
if _, err := io.CopyBuffer(r.RW, brl, b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// connPool is a net.Conn pool implementation using channels.
|
||||
type connPool struct {
|
||||
dialer func() (net.Conn, error)
|
||||
idleConns chan *conn
|
||||
idleConnTimeout time.Duration
|
||||
ticker *time.Ticker
|
||||
doneCh chan struct{}
|
||||
dialer func() (net.Conn, error)
|
||||
idleConns chan *conn
|
||||
idleConnTimeout time.Duration
|
||||
responseHeaderTimeout time.Duration
|
||||
ticker *time.Ticker
|
||||
bufferPool pool[[]byte]
|
||||
limitedReaderPool pool[*io.LimitedReader]
|
||||
doneCh chan struct{}
|
||||
}
|
||||
|
||||
// newConnPool creates a new connPool.
|
||||
func newConnPool(maxIdleConn int, idleConnTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
|
||||
func newConnPool(maxIdleConn int, idleConnTimeout, responseHeaderTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
|
||||
c := &connPool{
|
||||
dialer: dialer,
|
||||
idleConns: make(chan *conn, maxIdleConn),
|
||||
idleConnTimeout: idleConnTimeout,
|
||||
doneCh: make(chan struct{}),
|
||||
dialer: dialer,
|
||||
idleConns: make(chan *conn, maxIdleConn),
|
||||
idleConnTimeout: idleConnTimeout,
|
||||
responseHeaderTimeout: responseHeaderTimeout,
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
if idleConnTimeout > 0 {
|
||||
|
@ -72,22 +339,28 @@ func (c *connPool) AcquireConn() (*conn, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if !co.isExpired() {
|
||||
if !co.isStale() {
|
||||
return co, nil
|
||||
}
|
||||
|
||||
// As the acquired conn is expired we can close it
|
||||
// As the acquired conn is stale we can close it
|
||||
// without putting it again into the pool.
|
||||
if err := co.Close(); err != nil {
|
||||
log.Debug().
|
||||
Err(err).
|
||||
Msg("Unexpected error while releasing the connection")
|
||||
Msg("Unexpected error while closing the connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ReleaseConn releases the given net.Conn to the pool.
|
||||
func (c *connPool) ReleaseConn(co *conn) {
|
||||
// An upgraded connection cannot be safely reused for another roundTrip,
|
||||
// thus we are not putting it back to the pool.
|
||||
if co.isUpgraded() {
|
||||
return
|
||||
}
|
||||
|
||||
co.idleAt = time.Now()
|
||||
c.releaseConn(co)
|
||||
}
|
||||
|
@ -97,7 +370,7 @@ func (c *connPool) cleanIdleConns() {
|
|||
for {
|
||||
select {
|
||||
case co := <-c.idleConns:
|
||||
if !co.isExpired() {
|
||||
if !co.isStale() {
|
||||
c.releaseConn(co)
|
||||
return
|
||||
}
|
||||
|
@ -105,7 +378,7 @@ func (c *connPool) cleanIdleConns() {
|
|||
if err := co.Close(); err != nil {
|
||||
log.Debug().
|
||||
Err(err).
|
||||
Msg("Unexpected error while releasing the connection")
|
||||
Msg("Unexpected error while closing the connection")
|
||||
}
|
||||
|
||||
default:
|
||||
|
@ -155,9 +428,33 @@ func (c *connPool) askForNewConn(errCh chan<- error) {
|
|||
return
|
||||
}
|
||||
|
||||
c.releaseConn(&conn{
|
||||
Conn: co,
|
||||
idleAt: time.Now(),
|
||||
idleTimeout: c.idleConnTimeout,
|
||||
})
|
||||
newConn := &conn{
|
||||
Conn: co,
|
||||
br: bufio.NewReaderSize(co, bufioSize),
|
||||
idleAt: time.Now(),
|
||||
idleTimeout: c.idleConnTimeout,
|
||||
responseHeaderTimeout: c.responseHeaderTimeout,
|
||||
RWCh: make(chan rwWithUpgrade),
|
||||
ErrCh: make(chan error),
|
||||
bufferPool: &c.bufferPool,
|
||||
limitedReaderPool: &c.limitedReaderPool,
|
||||
}
|
||||
go newConn.readLoop()
|
||||
|
||||
c.releaseConn(newConn)
|
||||
}
|
||||
|
||||
// isBodyAllowedForStatus reports whether a given response status code
|
||||
// permits a body. See RFC 7230, section 3.3.
|
||||
// From https://github.com/golang/go/blame/master/src/net/http/transfer.go#L459
|
||||
func isBodyAllowedForStatus(status int) bool {
|
||||
switch {
|
||||
case status >= 100 && status <= 199:
|
||||
return false
|
||||
case status == 204:
|
||||
return false
|
||||
case status == 304:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ func TestConnPool_ConnReuse(t *testing.T) {
|
|||
return &net.TCPConn{}, nil
|
||||
}
|
||||
|
||||
pool := newConnPool(2, 0, dialer)
|
||||
pool := newConnPool(2, 0, 0, dialer)
|
||||
test.poolFn(pool)
|
||||
|
||||
assert.Equal(t, test.expected, connAlloc)
|
||||
|
@ -102,13 +102,16 @@ func TestConnPool_MaxIdleConn(t *testing.T) {
|
|||
var keepOpenedConn int
|
||||
dialer := func() (net.Conn, error) {
|
||||
keepOpenedConn++
|
||||
return &mockConn{closeFn: func() error {
|
||||
keepOpenedConn--
|
||||
return nil
|
||||
}}, nil
|
||||
return &mockConn{
|
||||
doneCh: make(chan struct{}),
|
||||
closeFn: func() error {
|
||||
keepOpenedConn--
|
||||
return nil
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
pool := newConnPool(test.maxIdleConn, 0, dialer)
|
||||
pool := newConnPool(test.maxIdleConn, 0, 0, dialer)
|
||||
test.poolFn(pool)
|
||||
|
||||
assert.Equal(t, test.expected, keepOpenedConn)
|
||||
|
@ -129,7 +132,7 @@ func TestGC(t *testing.T) {
|
|||
return c, nil
|
||||
}
|
||||
|
||||
pools["test"] = newConnPool(10, 1*time.Second, dialer)
|
||||
pools["test"] = newConnPool(10, 1*time.Second, 0, dialer)
|
||||
runtime.SetFinalizer(pools["test"], func(p *connPool) {
|
||||
isDestroyed = true
|
||||
})
|
||||
|
@ -149,10 +152,12 @@ func TestGC(t *testing.T) {
|
|||
|
||||
type mockConn struct {
|
||||
closeFn func() error
|
||||
doneCh chan struct{} // makes sure that the readLoop is blocking avoiding close.
|
||||
}
|
||||
|
||||
func (m *mockConn) Read(_ []byte) (n int, err error) {
|
||||
panic("implement me")
|
||||
<-m.doneCh
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConn) Write(_ []byte) (n int, err error) {
|
||||
|
@ -160,6 +165,7 @@ func (m *mockConn) Write(_ []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
func (m *mockConn) Close() error {
|
||||
defer close(m.doneCh)
|
||||
if m.closeFn != nil {
|
||||
return m.closeFn()
|
||||
}
|
||||
|
|
|
@ -4,18 +4,14 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"
|
||||
|
@ -57,15 +53,6 @@ func (p *pool[T]) Put(x T) {
|
|||
p.pool.Put(x)
|
||||
}
|
||||
|
||||
type buffConn struct {
|
||||
*bufio.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (b buffConn) Read(p []byte) (int, error) {
|
||||
return b.Reader.Read(p)
|
||||
}
|
||||
|
||||
type writeDetector struct {
|
||||
net.Conn
|
||||
|
||||
|
@ -112,20 +99,17 @@ type ReverseProxy struct {
|
|||
|
||||
connPool *connPool
|
||||
|
||||
bufferPool pool[[]byte]
|
||||
readerPool pool[*bufio.Reader]
|
||||
writerPool pool[*bufio.Writer]
|
||||
limitReaderPool pool[*io.LimitedReader]
|
||||
writerPool pool[*bufio.Writer]
|
||||
|
||||
proxyAuth string
|
||||
|
||||
targetURL *url.URL
|
||||
passHostHeader bool
|
||||
responseHeaderTimeout time.Duration
|
||||
targetURL *url.URL
|
||||
passHostHeader bool
|
||||
preservePath bool
|
||||
}
|
||||
|
||||
// NewReverseProxy creates a new ReverseProxy.
|
||||
func NewReverseProxy(targetURL *url.URL, proxyURL *url.URL, debug, passHostHeader bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) {
|
||||
func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, connPool *connPool) (*ReverseProxy, error) {
|
||||
var proxyAuth string
|
||||
if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
|
||||
username := proxyURL.User.Username()
|
||||
|
@ -134,12 +118,12 @@ func NewReverseProxy(targetURL *url.URL, proxyURL *url.URL, debug, passHostHeade
|
|||
}
|
||||
|
||||
return &ReverseProxy{
|
||||
debug: debug,
|
||||
passHostHeader: passHostHeader,
|
||||
targetURL: targetURL,
|
||||
proxyAuth: proxyAuth,
|
||||
connPool: connPool,
|
||||
responseHeaderTimeout: responseHeaderTimeout,
|
||||
debug: debug,
|
||||
passHostHeader: passHostHeader,
|
||||
preservePath: preservePath,
|
||||
targetURL: targetURL,
|
||||
proxyAuth: proxyAuth,
|
||||
connPool: connPool,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -207,6 +191,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
u2.Path = u.Path
|
||||
u2.RawPath = u.RawPath
|
||||
|
||||
if p.preservePath {
|
||||
u2.Path, u2.RawPath = proxyhttputil.JoinURLPath(p.targetURL, u)
|
||||
}
|
||||
|
||||
u2.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
|
||||
|
||||
outReq.SetHost(u2.Host)
|
||||
|
@ -266,8 +255,15 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
|
|||
return fmt.Errorf("acquire connection: %w", err)
|
||||
}
|
||||
|
||||
// Before writing the request,
|
||||
// we mark the conn as expecting to handle a response.
|
||||
co.expectedResponse.Store(true)
|
||||
|
||||
wd := &writeDetector{Conn: co}
|
||||
|
||||
// TODO: do not wait to write the full request before reading the response (to handle "100 Continue").
|
||||
// TODO: this is currently impossible with fasthttp to write the request partially (headers only).
|
||||
// Currently, writing the request fully is a mandatory step before handling the response.
|
||||
err = p.writeRequest(wd, outReq)
|
||||
if wd.written && trace != nil && trace.WroteRequest != nil {
|
||||
// WroteRequest hook is used by the tracing middleware to detect if the request has been written.
|
||||
|
@ -286,169 +282,17 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
|
|||
}
|
||||
}
|
||||
|
||||
br := p.readerPool.Get()
|
||||
if br == nil {
|
||||
br = bufio.NewReaderSize(co, bufioSize)
|
||||
}
|
||||
defer p.readerPool.Put(br)
|
||||
|
||||
br.Reset(co)
|
||||
|
||||
res := fasthttp.AcquireResponse()
|
||||
defer fasthttp.ReleaseResponse(res)
|
||||
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
|
||||
for {
|
||||
var timer *time.Timer
|
||||
errTimeout := atomic.Pointer[timeoutError]{}
|
||||
if p.responseHeaderTimeout > 0 {
|
||||
timer = time.AfterFunc(p.responseHeaderTimeout, func() {
|
||||
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
|
||||
co.Close()
|
||||
})
|
||||
}
|
||||
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
if err := res.Header.Read(br); err != nil {
|
||||
if p.responseHeaderTimeout > 0 {
|
||||
if errT := errTimeout.Load(); errT != nil {
|
||||
return errT
|
||||
}
|
||||
}
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
|
||||
fixPragmaCacheControl(&res.Header)
|
||||
|
||||
resCode := res.StatusCode()
|
||||
is1xx := 100 <= resCode && resCode <= 199
|
||||
// treat 101 as a terminal status, see issue 26161
|
||||
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
|
||||
if is1xxNonTerminal {
|
||||
removeConnectionHeaders(&res.Header)
|
||||
h := rw.Header()
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
res.Header.Del(header)
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
rw.Header().Add(string(key), string(value))
|
||||
})
|
||||
|
||||
rw.WriteHeader(res.StatusCode())
|
||||
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
|
||||
for k := range h {
|
||||
delete(h, k)
|
||||
}
|
||||
|
||||
res.Reset()
|
||||
res.Header.Reset()
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
|
||||
continue
|
||||
}
|
||||
break
|
||||
// Sending the responseWriter unlocks the connection readLoop, to handle the response.
|
||||
co.RWCh <- rwWithUpgrade{
|
||||
RW: rw,
|
||||
Upgrade: upgradeResponseHandler(req.Context(), reqUpType),
|
||||
}
|
||||
|
||||
announcedTrailers := res.Header.Peek("Trailer")
|
||||
|
||||
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
|
||||
if res.StatusCode() == http.StatusSwitchingProtocols {
|
||||
// As the connection has been hijacked, it cannot be added back to the pool.
|
||||
handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br})
|
||||
return nil
|
||||
}
|
||||
|
||||
removeConnectionHeaders(&res.Header)
|
||||
|
||||
for _, header := range hopHeaders {
|
||||
res.Header.Del(header)
|
||||
}
|
||||
|
||||
if len(announcedTrailers) > 0 {
|
||||
res.Header.Add("Trailer", string(announcedTrailers))
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
rw.Header().Add(string(key), string(value))
|
||||
})
|
||||
|
||||
rw.WriteHeader(res.StatusCode())
|
||||
|
||||
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
|
||||
if res.Header.ContentLength() == -1 {
|
||||
cbr := httputil.NewChunkedReader(br)
|
||||
|
||||
b := p.bufferPool.Get()
|
||||
if b == nil {
|
||||
b = make([]byte, bufferSize)
|
||||
}
|
||||
defer p.bufferPool.Put(b)
|
||||
|
||||
if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil {
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
res.Header.Reset()
|
||||
res.Header.SetNoDefaultContentType(true)
|
||||
if err := res.Header.ReadTrailer(br); err != nil {
|
||||
co.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if res.Header.Len() > 0 {
|
||||
var announcedTrailersKey []string
|
||||
if len(announcedTrailers) > 0 {
|
||||
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
|
||||
}
|
||||
|
||||
res.Header.VisitAll(func(key, value []byte) {
|
||||
for _, s := range announcedTrailersKey {
|
||||
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
|
||||
rw.Header().Add(string(key), string(value))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
rw.Header().Add(http.TrailerPrefix+string(key), string(value))
|
||||
})
|
||||
}
|
||||
|
||||
p.connPool.ReleaseConn(co)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
brl := p.limitReaderPool.Get()
|
||||
if brl == nil {
|
||||
brl = &io.LimitedReader{}
|
||||
}
|
||||
defer p.limitReaderPool.Put(brl)
|
||||
|
||||
brl.R = br
|
||||
brl.N = int64(res.Header.ContentLength())
|
||||
|
||||
b := p.bufferPool.Get()
|
||||
if b == nil {
|
||||
b = make([]byte, bufferSize)
|
||||
}
|
||||
defer p.bufferPool.Put(b)
|
||||
|
||||
if _, err := io.CopyBuffer(rw, brl, b); err != nil {
|
||||
co.Close()
|
||||
if err := <-co.ErrCh; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.connPool.ReleaseConn(co)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -230,7 +230,7 @@ func TestProxyFromEnvironment(t *testing.T) {
|
|||
return u, nil
|
||||
}
|
||||
|
||||
reverseProxy, err := builder.Build("foo", testhelpers.MustParseURL(backendURL), false)
|
||||
reverseProxy, err := builder.Build("foo", testhelpers.MustParseURL(backendURL), false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
reverseProxyServer := httptest.NewServer(reverseProxy)
|
||||
|
@ -252,6 +252,32 @@ func TestProxyFromEnvironment(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestPreservePath(t *testing.T) {
|
||||
var callCount int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
callCount++
|
||||
assert.Equal(t, "/base/foo/bar", req.URL.Path)
|
||||
assert.Equal(t, "/base/foo%2Fbar", req.URL.RawPath)
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
builder := NewProxyBuilder(&transportManagerMock{}, static.FastProxyConfig{})
|
||||
|
||||
serverURL, err := url.JoinPath(server.URL, "base")
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyHandler, err := builder.Build("", testhelpers.MustParseURL(serverURL), true, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/foo%2Fbar", http.NoBody)
|
||||
res := httptest.NewRecorder()
|
||||
|
||||
proxyHandler.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, 1, callCount)
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
}
|
||||
|
||||
func newCertificate(t *testing.T, domain string) *tls.Certificate {
|
||||
t.Helper()
|
||||
|
||||
|
|
|
@ -362,7 +362,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
|||
|
||||
u := parseURI(t, srv.URL)
|
||||
|
||||
f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
|
||||
f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
|
||||
return net.Dial("tcp", u.Host)
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
@ -434,7 +434,7 @@ func TestWebSocketUpgradeFailed(t *testing.T) {
|
|||
defer srv.Close()
|
||||
|
||||
u := parseURI(t, srv.URL)
|
||||
f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
|
||||
f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
|
||||
return net.Dial("tcp", u.Host)
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
@ -663,7 +663,7 @@ func parseURI(t *testing.T, uri string) *url.URL {
|
|||
|
||||
func createConnectionPool(target string, tlsConfig *tls.Config) *connPool {
|
||||
u := testhelpers.MustParseURL(target)
|
||||
return newConnPool(200, 0, func() (net.Conn, error) {
|
||||
return newConnPool(200, 0, 0, func() (net.Conn, error) {
|
||||
if tlsConfig != nil {
|
||||
return tls.Dial("tcp", u.Host, tlsConfig)
|
||||
}
|
||||
|
@ -676,7 +676,7 @@ func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptes
|
|||
t.Helper()
|
||||
|
||||
u := parseURI(t, uri)
|
||||
proxy, err := NewReverseProxy(u, nil, false, true, 0, pool)
|
||||
proxy, err := NewReverseProxy(u, nil, false, true, false, pool)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
|
|
|
@ -2,6 +2,7 @@ package fast
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -19,72 +20,75 @@ type switchProtocolCopier struct {
|
|||
user, backend io.ReadWriter
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
|
||||
func (c switchProtocolCopier) copyFromBackend(errCh chan<- error) {
|
||||
_, err := io.Copy(c.user, c.backend)
|
||||
errc <- err
|
||||
errCh <- err
|
||||
}
|
||||
|
||||
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
|
||||
func (c switchProtocolCopier) copyToBackend(errCh chan<- error) {
|
||||
_, err := io.Copy(c.backend, c.user)
|
||||
errc <- err
|
||||
errCh <- err
|
||||
}
|
||||
|
||||
func handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, reqUpType string, res *fasthttp.Response, backConn net.Conn) {
|
||||
defer backConn.Close()
|
||||
type upgradeHandler func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn)
|
||||
|
||||
resUpType := upgradeTypeFastHTTP(&res.Header)
|
||||
func upgradeResponseHandler(ctx context.Context, reqUpType string) upgradeHandler {
|
||||
return func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn) {
|
||||
resUpType := upgradeTypeFastHTTP(&res.Header)
|
||||
|
||||
if !strings.EqualFold(reqUpType, resUpType) {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||
return
|
||||
}
|
||||
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||
return
|
||||
}
|
||||
backConnCloseCh := make(chan bool)
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-backConnCloseCh:
|
||||
if !strings.EqualFold(reqUpType, resUpType) {
|
||||
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
|
||||
backConn.Close()
|
||||
return
|
||||
}
|
||||
_ = backConn.Close()
|
||||
}()
|
||||
|
||||
defer close(backConnCloseCh)
|
||||
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", err))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
for k, values := range rw.Header() {
|
||||
for _, v := range values {
|
||||
res.Header.Add(k, v)
|
||||
hj, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
|
||||
backConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
backConnCloseCh := make(chan bool)
|
||||
go func() {
|
||||
// Ensure that the cancellation of a request closes the backend.
|
||||
// See issue https://golang.org/issue/35559.
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-backConnCloseCh:
|
||||
}
|
||||
_ = backConn.Close()
|
||||
}()
|
||||
defer close(backConnCloseCh)
|
||||
|
||||
if err := res.Header.Write(brw.Writer); err != nil {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("response write: %w", err))
|
||||
return
|
||||
}
|
||||
conn, brw, err := hj.Hijack()
|
||||
if err != nil {
|
||||
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("hijack failed on protocol switch: %w", err))
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := brw.Flush(); err != nil {
|
||||
httputil.ErrorHandler(rw, req, fmt.Errorf("response flush: %w", err))
|
||||
return
|
||||
}
|
||||
for k, values := range rw.Header() {
|
||||
for _, v := range values {
|
||||
res.Header.Add(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
errc := make(chan error, 1)
|
||||
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||
go spc.copyToBackend(errc)
|
||||
go spc.copyFromBackend(errc)
|
||||
<-errc
|
||||
if err := res.Header.Write(brw.Writer); err != nil {
|
||||
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response write: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := brw.Flush(); err != nil {
|
||||
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response flush: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
spc := switchProtocolCopier{user: conn, backend: backConn}
|
||||
go spc.copyToBackend(errCh)
|
||||
go spc.copyFromBackend(errCh)
|
||||
<-errCh
|
||||
}
|
||||
}
|
||||
|
||||
func upgradeType(h http.Header) string {
|
||||
|
|
|
@ -38,7 +38,7 @@ func NewProxyBuilder(transportManager TransportManager, semConvMetricsRegistry *
|
|||
func (r *ProxyBuilder) Update(_ map[string]*dynamic.ServersTransport) {}
|
||||
|
||||
// Build builds a new httputil.ReverseProxy with the given configuration.
|
||||
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) {
|
||||
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader, preservePath bool, flushInterval time.Duration) (http.Handler, error) {
|
||||
roundTripper, err := r.transportManager.GetRoundTripper(cfgName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting RoundTripper: %w", err)
|
||||
|
@ -50,5 +50,5 @@ func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve,
|
|||
roundTripper = newObservabilityRoundTripper(r.semConvMetricsRegistry, roundTripper)
|
||||
}
|
||||
|
||||
return buildSingleHostProxy(targetURL, passHostHeader, flushInterval, roundTripper, r.bufferPool), nil
|
||||
return buildSingleHostProxy(targetURL, passHostHeader, preservePath, flushInterval, roundTripper, r.bufferPool), nil
|
||||
}
|
||||
|
|
|
@ -23,7 +23,7 @@ func TestEscapedPath(t *testing.T) {
|
|||
roundTrippers: map[string]http.RoundTripper{"default": &http.Transport{}},
|
||||
}
|
||||
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("default", testhelpers.MustParseURL(srv.URL), false, true, 0)
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("default", testhelpers.MustParseURL(srv.URL), false, true, false, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(p.ServeHTTP))
|
||||
|
|
|
@ -15,15 +15,17 @@ import (
|
|||
"golang.org/x/net/http/httpguts"
|
||||
)
|
||||
|
||||
// StatusClientClosedRequest non-standard HTTP status code for client disconnection.
|
||||
const StatusClientClosedRequest = 499
|
||||
const (
|
||||
// StatusClientClosedRequest non-standard HTTP status code for client disconnection.
|
||||
StatusClientClosedRequest = 499
|
||||
|
||||
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
|
||||
const StatusClientClosedRequestText = "Client Closed Request"
|
||||
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
|
||||
StatusClientClosedRequestText = "Client Closed Request"
|
||||
)
|
||||
|
||||
func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler {
|
||||
func buildSingleHostProxy(target *url.URL, passHostHeader bool, preservePath bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler {
|
||||
return &httputil.ReverseProxy{
|
||||
Director: directorBuilder(target, passHostHeader),
|
||||
Director: directorBuilder(target, passHostHeader, preservePath),
|
||||
Transport: roundTripper,
|
||||
FlushInterval: flushInterval,
|
||||
BufferPool: bufferPool,
|
||||
|
@ -31,7 +33,7 @@ func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval ti
|
|||
}
|
||||
}
|
||||
|
||||
func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Request) {
|
||||
func directorBuilder(target *url.URL, passHostHeader bool, preservePath bool) func(req *http.Request) {
|
||||
return func(outReq *http.Request) {
|
||||
outReq.URL.Scheme = target.Scheme
|
||||
outReq.URL.Host = target.Host
|
||||
|
@ -46,6 +48,11 @@ func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Reques
|
|||
|
||||
outReq.URL.Path = u.Path
|
||||
outReq.URL.RawPath = u.RawPath
|
||||
|
||||
if preservePath {
|
||||
outReq.URL.Path, outReq.URL.RawPath = JoinURLPath(target, u)
|
||||
}
|
||||
|
||||
// If a plugin/middleware adds semicolons in query params, they should be urlEncoded.
|
||||
outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
|
||||
outReq.RequestURI = "" // Outgoing request should not have RequestURI
|
||||
|
@ -54,7 +61,7 @@ func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Reques
|
|||
outReq.ProtoMajor = 1
|
||||
outReq.ProtoMinor = 1
|
||||
|
||||
// Do not pass client Host header unless optsetter PassHostHeader is set.
|
||||
// Do not pass client Host header unless option PassHostHeader is set.
|
||||
if !passHostHeader {
|
||||
outReq.Host = outReq.URL.Host
|
||||
}
|
||||
|
@ -95,9 +102,14 @@ func isWebSocketUpgrade(req *http.Request) bool {
|
|||
|
||||
// ErrorHandler is the http.Handler called when something goes wrong when forwarding the request.
|
||||
func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
|
||||
ErrorHandlerWithContext(req.Context(), w, err)
|
||||
}
|
||||
|
||||
// ErrorHandlerWithContext is the http.Handler called when something goes wrong when forwarding the request.
|
||||
func ErrorHandlerWithContext(ctx context.Context, w http.ResponseWriter, err error) {
|
||||
statusCode := ComputeStatusCode(err)
|
||||
|
||||
logger := log.Ctx(req.Context())
|
||||
logger := log.Ctx(ctx)
|
||||
logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode))
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
|
@ -106,6 +118,13 @@ func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func statusText(statusCode int) string {
|
||||
if statusCode == StatusClientClosedRequest {
|
||||
return StatusClientClosedRequestText
|
||||
}
|
||||
return http.StatusText(statusCode)
|
||||
}
|
||||
|
||||
// ComputeStatusCode computes the HTTP status code according to the given error.
|
||||
func ComputeStatusCode(err error) int {
|
||||
switch {
|
||||
|
@ -127,9 +146,38 @@ func ComputeStatusCode(err error) int {
|
|||
return http.StatusInternalServerError
|
||||
}
|
||||
|
||||
func statusText(statusCode int) string {
|
||||
if statusCode == StatusClientClosedRequest {
|
||||
return StatusClientClosedRequestText
|
||||
// JoinURLPath computes the joined path and raw path of the given URLs.
|
||||
// From https://github.com/golang/go/blob/b521ebb55a9b26c8824b219376c7f91f7cda6ec2/src/net/http/httputil/reverseproxy.go#L221
|
||||
func JoinURLPath(a, b *url.URL) (path, rawpath string) {
|
||||
if a.RawPath == "" && b.RawPath == "" {
|
||||
return singleJoiningSlash(a.Path, b.Path), ""
|
||||
}
|
||||
return http.StatusText(statusCode)
|
||||
|
||||
// Same as singleJoiningSlash, but uses EscapedPath to determine
|
||||
// whether a slash should be added
|
||||
apath := a.EscapedPath()
|
||||
bpath := b.EscapedPath()
|
||||
|
||||
aslash := strings.HasSuffix(apath, "/")
|
||||
bslash := strings.HasPrefix(bpath, "/")
|
||||
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a.Path + b.Path[1:], apath + bpath[1:]
|
||||
case !aslash && !bslash:
|
||||
return a.Path + "/" + b.Path, apath + "/" + bpath
|
||||
}
|
||||
return a.Path + b.Path, apath + bpath
|
||||
}
|
||||
|
||||
func singleJoiningSlash(a, b string) string {
|
||||
aslash := strings.HasSuffix(a, "/")
|
||||
bslash := strings.HasPrefix(b, "/")
|
||||
switch {
|
||||
case aslash && bslash:
|
||||
return a + b[1:]
|
||||
case !aslash && !bslash:
|
||||
return a + "/" + b
|
||||
}
|
||||
return a + b
|
||||
}
|
||||
|
|
102
pkg/proxy/httputil/proxy_test.go
Normal file
102
pkg/proxy/httputil/proxy_test.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||
)
|
||||
|
||||
func Test_directorBuilder(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
target *url.URL
|
||||
passHostHeader bool
|
||||
preservePath bool
|
||||
incomingURL string
|
||||
expectedScheme string
|
||||
expectedHost string
|
||||
expectedPath string
|
||||
expectedRawPath string
|
||||
expectedQuery string
|
||||
}{
|
||||
{
|
||||
name: "Basic proxy",
|
||||
target: testhelpers.MustParseURL("http://example.com"),
|
||||
passHostHeader: false,
|
||||
preservePath: false,
|
||||
incomingURL: "http://localhost/test?param=value",
|
||||
expectedScheme: "http",
|
||||
expectedHost: "example.com",
|
||||
expectedPath: "/test",
|
||||
expectedQuery: "param=value",
|
||||
},
|
||||
{
|
||||
name: "HTTPS target",
|
||||
target: testhelpers.MustParseURL("https://secure.example.com"),
|
||||
passHostHeader: false,
|
||||
preservePath: false,
|
||||
incomingURL: "http://localhost/secure",
|
||||
expectedScheme: "https",
|
||||
expectedHost: "secure.example.com",
|
||||
expectedPath: "/secure",
|
||||
},
|
||||
{
|
||||
name: "PassHostHeader",
|
||||
target: testhelpers.MustParseURL("http://example.com"),
|
||||
passHostHeader: true,
|
||||
preservePath: false,
|
||||
incomingURL: "http://original.host/test",
|
||||
expectedScheme: "http",
|
||||
expectedHost: "original.host",
|
||||
expectedPath: "/test",
|
||||
},
|
||||
{
|
||||
name: "Preserve path",
|
||||
target: testhelpers.MustParseURL("http://example.com/base"),
|
||||
passHostHeader: false,
|
||||
preservePath: true,
|
||||
incomingURL: "http://localhost/foo%2Fbar",
|
||||
expectedScheme: "http",
|
||||
expectedHost: "example.com",
|
||||
expectedPath: "/base/foo/bar",
|
||||
expectedRawPath: "/base/foo%2Fbar",
|
||||
},
|
||||
{
|
||||
name: "Handle semicolons in query",
|
||||
target: testhelpers.MustParseURL("http://example.com"),
|
||||
passHostHeader: false,
|
||||
preservePath: false,
|
||||
incomingURL: "http://localhost/test?param1=value1;param2=value2",
|
||||
expectedScheme: "http",
|
||||
expectedHost: "example.com",
|
||||
expectedPath: "/test",
|
||||
expectedQuery: "param1=value1¶m2=value2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
director := directorBuilder(test.target, test.passHostHeader, test.preservePath)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, test.incomingURL, http.NoBody)
|
||||
director(req)
|
||||
|
||||
assert.Equal(t, test.expectedScheme, req.URL.Scheme)
|
||||
assert.Equal(t, test.expectedHost, req.Host)
|
||||
assert.Equal(t, test.expectedPath, req.URL.Path)
|
||||
assert.Equal(t, test.expectedRawPath, req.URL.RawPath)
|
||||
assert.Equal(t, test.expectedQuery, req.URL.RawQuery)
|
||||
assert.Empty(t, req.RequestURI)
|
||||
assert.Equal(t, "HTTP/1.1", req.Proto)
|
||||
assert.Equal(t, 1, req.ProtoMajor)
|
||||
assert.Equal(t, 1, req.ProtoMinor)
|
||||
assert.False(t, !test.passHostHeader && req.Host != req.URL.Host)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -298,9 +298,8 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, false, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = testhelpers.MustParseURL(srv.URL)
|
||||
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
|
||||
|
@ -355,9 +354,8 @@ func TestWebSocketUpgradeFailed(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, false, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
|
||||
|
@ -588,7 +586,7 @@ func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTrip
|
|||
roundTrippers: map[string]http.RoundTripper{"fwd": transport},
|
||||
}
|
||||
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("fwd", u, false, true, 0)
|
||||
p, err := NewProxyBuilder(transportManager, nil).Build("fwd", u, false, true, false, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
|
|
|
@ -45,7 +45,7 @@ func (b *SmartBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) {
|
|||
}
|
||||
|
||||
// Build builds an HTTP proxy for the given URL using the ServersTransport with the given name.
|
||||
func (b *SmartBuilder) Build(configName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) {
|
||||
func (b *SmartBuilder) Build(configName string, targetURL *url.URL, shouldObserve, passHostHeader, preservePath bool, flushInterval time.Duration) (http.Handler, error) {
|
||||
serversTransport, err := b.transportManager.Get(configName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting ServersTransport: %w", err)
|
||||
|
@ -55,7 +55,7 @@ func (b *SmartBuilder) Build(configName string, targetURL *url.URL, shouldObserv
|
|||
// For the https scheme we cannot guess if the backend communication will use HTTP2,
|
||||
// thus we check if HTTP/2 is disabled to use the fast proxy implementation when this is possible.
|
||||
if targetURL.Scheme == "h2c" || (targetURL.Scheme == "https" && !serversTransport.DisableHTTP2) {
|
||||
return b.proxyBuilder.Build(configName, targetURL, shouldObserve, passHostHeader, flushInterval)
|
||||
return b.proxyBuilder.Build(configName, targetURL, shouldObserve, passHostHeader, preservePath, flushInterval)
|
||||
}
|
||||
return b.fastProxyBuilder.Build(configName, targetURL, passHostHeader)
|
||||
return b.fastProxyBuilder.Build(configName, targetURL, passHostHeader, preservePath)
|
||||
}
|
||||
|
|
|
@ -101,7 +101,7 @@ func TestSmartBuilder_Build(t *testing.T) {
|
|||
httpProxyBuilder := httputil.NewProxyBuilder(transportManager, nil)
|
||||
proxyBuilder := NewSmartBuilder(transportManager, httpProxyBuilder, test.fastProxyConfig)
|
||||
|
||||
proxyHandler, err := proxyBuilder.Build("test", targetURL, false, false, time.Second)
|
||||
proxyHandler, err := proxyBuilder.Build("test", targetURL, false, false, false, time.Second)
|
||||
require.NoError(t, err)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
|
|
@ -897,7 +897,7 @@ func BenchmarkService(b *testing.B) {
|
|||
|
||||
type proxyBuilderMock struct{}
|
||||
|
||||
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) {
|
||||
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _, _ bool, _ time.Duration) (http.Handler, error) {
|
||||
return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -254,7 +254,7 @@ func TestInternalServices(t *testing.T) {
|
|||
|
||||
type proxyBuilderMock struct{}
|
||||
|
||||
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) {
|
||||
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _, _ bool, _ time.Duration) (http.Handler, error) {
|
||||
return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -42,7 +42,7 @@ const (
|
|||
|
||||
// ProxyBuilder builds reverse proxy handlers.
|
||||
type ProxyBuilder interface {
|
||||
Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error)
|
||||
Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader, preservePath bool, flushInterval time.Duration) (http.Handler, error)
|
||||
Update(configs map[string]*dynamic.ServersTransport)
|
||||
}
|
||||
|
||||
|
@ -338,7 +338,7 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
|
|||
qualifiedSvcName := provider.GetQualifiedName(ctx, serviceName)
|
||||
|
||||
shouldObserve := m.observabilityMgr.ShouldAddTracing(qualifiedSvcName) || m.observabilityMgr.ShouldAddMetrics(qualifiedSvcName)
|
||||
proxy, err := m.proxyBuilder.Build(service.ServersTransport, target, shouldObserve, passHostHeader, flushInterval)
|
||||
proxy, err := m.proxyBuilder.Build(service.ServersTransport, target, shouldObserve, passHostHeader, server.PreservePath, flushInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error building proxy for server URL %s: %w", server.URL, err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue