1
0
Fork 0

Merge branch v3.2 into master

This commit is contained in:
kevinpollet 2024-10-29 09:29:27 +01:00
commit 7004f0e750
No known key found for this signature in database
GPG key ID: 0C9A5DDD1B292453
65 changed files with 2393 additions and 1222 deletions

View file

@ -244,10 +244,11 @@ func (r *ResponseForwarding) SetDefaults() {
// Server holds the server configuration.
type Server struct {
URL string `json:"url,omitempty" toml:"url,omitempty" yaml:"url,omitempty" label:"-"`
Weight *int `json:"weight,omitempty" toml:"weight,omitempty" yaml:"weight,omitempty" label:"weight"`
Scheme string `json:"-" toml:"-" yaml:"-" file:"-"`
Port string `json:"-" toml:"-" yaml:"-" file:"-"`
URL string `json:"url,omitempty" toml:"url,omitempty" yaml:"url,omitempty" label:"-"`
Weight *int `json:"weight,omitempty" toml:"weight,omitempty" yaml:"weight,omitempty" label:"weight" export:"true"`
PreservePath bool `json:"preservePath,omitempty" toml:"preservePath,omitempty" yaml:"preservePath,omitempty" label:"-" export:"true"`
Scheme string `json:"-" toml:"-" yaml:"-" file:"-"`
Port string `json:"-" toml:"-" yaml:"-" file:"-"`
}
// SetDefaults Default values for a Server.

View file

@ -25,7 +25,6 @@ import (
"github.com/stretchr/testify/require"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v3/pkg/middlewares/capture"
"github.com/traefik/traefik/v3/pkg/middlewares/recovery"
"github.com/traefik/traefik/v3/pkg/types"
)
@ -954,8 +953,14 @@ func doLoggingWithAbortedStream(t *testing.T, config *types.AccessLog) {
req = req.WithContext(reqContext)
chain := alice.New()
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
return recovery.New(context.Background(), next)
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
defer func() {
_ = recover() // ignore the stream backend panic to avoid the test to fail.
}()
next.ServeHTTP(rw, req)
}), nil
})
chain = chain.Append(capture.Wrap)
chain = chain.Append(WrapHandler(logger))

View file

@ -8,7 +8,9 @@ import (
"net/http"
"slices"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/gzhttp"
"github.com/klauspost/compress/zstd"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
"go.opentelemetry.io/otel/trace"
@ -94,12 +96,12 @@ func New(ctx context.Context, next http.Handler, conf dynamic.Compress, name str
var err error
c.zstdHandler, err = c.newCompressionHandler(zstdName, name)
c.zstdHandler, err = c.newZstdHandler(name)
if err != nil {
return nil, err
}
c.brotliHandler, err = c.newCompressionHandler(brotliName, name)
c.brotliHandler, err = c.newBrotliHandler(name)
if err != nil {
return nil, err
}
@ -190,13 +192,34 @@ func (c *compress) newGzipHandler() (http.Handler, error) {
return wrapper(c.next), nil
}
func (c *compress) newCompressionHandler(algo string, middlewareName string) (http.Handler, error) {
cfg := Config{MinSize: c.minSize, Algorithm: algo, MiddlewareName: middlewareName}
func (c *compress) newBrotliHandler(middlewareName string) (http.Handler, error) {
cfg := Config{MinSize: c.minSize, MiddlewareName: middlewareName}
if len(c.includes) > 0 {
cfg.IncludedContentTypes = c.includes
} else {
cfg.ExcludedContentTypes = c.excludes
}
return NewCompressionHandler(cfg, c.next)
newBrotliWriter := func(rw http.ResponseWriter) (CompressionWriter, string, error) {
return brotli.NewWriter(rw), brotliName, nil
}
return NewCompressionHandler(cfg, newBrotliWriter, c.next)
}
func (c *compress) newZstdHandler(middlewareName string) (http.Handler, error) {
cfg := Config{MinSize: c.minSize, MiddlewareName: middlewareName}
if len(c.includes) > 0 {
cfg.IncludedContentTypes = c.includes
} else {
cfg.ExcludedContentTypes = c.excludes
}
newZstdWriter := func(rw http.ResponseWriter) (CompressionWriter, string, error) {
writer, err := zstd.NewWriter(rw)
if err != nil {
return nil, "", fmt.Errorf("creating zstd writer: %w", err)
}
return writer, zstdName, nil
}
return NewCompressionHandler(cfg, newZstdWriter, c.next)
}

View file

@ -10,8 +10,6 @@ import (
"net/http"
"sync"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
)
@ -24,6 +22,30 @@ const (
contentType = "Content-Type"
)
// CompressionWriter compresses the written bytes.
type CompressionWriter interface {
// Write data to the encoder.
// Input data will be buffered and as the buffer fills up
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write CRC if requested.
Write(p []byte) (n int, err error)
// Flush will send the currently written data to output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
Flush() error
// Close closes the underlying writers if/when appropriate.
// Note that the compressed writer should not be closed if we never used it,
// as it would otherwise send some extra "end of compression" bytes.
// Close also makes sure to flush whatever was left to write from the buffer.
Close() error
// Reset reinitializes the state of the encoder, allowing it to be reused.
Reset(w io.Writer)
}
// NewCompressionWriter returns a new CompressionWriter with its corresponding algorithm.
type NewCompressionWriter func(rw http.ResponseWriter) (CompressionWriter, string, error)
// Config is the Brotli handler configuration.
type Config struct {
// ExcludedContentTypes is the list of content types for which we should not compress.
@ -34,8 +56,6 @@ type Config struct {
IncludedContentTypes []string
// MinSize is the minimum size (in bytes) required to enable compression.
MinSize int
// Algorithm used for the compression (currently Brotli and Zstandard)
Algorithm string
// MiddlewareName use for logging purposes
MiddlewareName string
}
@ -46,15 +66,13 @@ type CompressionHandler struct {
excludedContentTypes []parsedContentType
includedContentTypes []parsedContentType
next http.Handler
writerPool sync.Pool
writerPool sync.Pool
newWriter NewCompressionWriter
}
// NewCompressionHandler returns a new compressing handler.
func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error) {
if cfg.Algorithm == "" {
return nil, errors.New("compression algorithm undefined")
}
func NewCompressionHandler(cfg Config, newWriter NewCompressionWriter, next http.Handler) (http.Handler, error) {
if cfg.MinSize < 0 {
return nil, errors.New("minimum size must be greater than or equal to zero")
}
@ -88,6 +106,7 @@ func NewCompressionHandler(cfg Config, next http.Handler) (http.Handler, error)
excludedContentTypes: excludedContentTypes,
includedContentTypes: includedContentTypes,
next: next,
newWriter: newWriter,
}, nil
}
@ -117,70 +136,38 @@ func (c *CompressionHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request)
c.next.ServeHTTP(responseWriter, r)
}
type compression interface {
// Write data to the encoder.
// Input data will be buffered and as the buffer fills up
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write CRC if requested.
Write(p []byte) (n int, err error)
// Flush will send the currently written data to output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
Flush() error
// Close closes the underlying writers if/when appropriate.
// Note that the compressed writer should not be closed if we never used it,
// as it would otherwise send some extra "end of compression" bytes.
// Close also makes sure to flush whatever was left to write from the buffer.
Close() error
// Reset reinitializes the state of the encoder, allowing it to be reused.
Reset(w io.Writer)
}
type compressionWriter struct {
compression
alg string
}
func (c *CompressionHandler) getCompressionWriter(rw io.Writer) (*compressionWriter, error) {
if writer, ok := c.writerPool.Get().(*compressionWriter); ok {
writer.compression.Reset(rw)
func (c *CompressionHandler) getCompressionWriter(rw http.ResponseWriter) (*compressionWriterWrapper, error) {
if writer, ok := c.writerPool.Get().(*compressionWriterWrapper); ok {
writer.Reset(rw)
return writer, nil
}
return newCompressionWriter(c.cfg.Algorithm, rw)
writer, algo, err := c.newWriter(rw)
if err != nil {
return nil, fmt.Errorf("creating compression writer: %w", err)
}
return &compressionWriterWrapper{CompressionWriter: writer, algo: algo}, nil
}
func (c *CompressionHandler) putCompressionWriter(writer *compressionWriter) {
func (c *CompressionHandler) putCompressionWriter(writer *compressionWriterWrapper) {
writer.Reset(nil)
c.writerPool.Put(writer)
}
func newCompressionWriter(algo string, in io.Writer) (*compressionWriter, error) {
switch algo {
case brotliName:
return &compressionWriter{compression: brotli.NewWriter(in), alg: algo}, nil
case zstdName:
writer, err := zstd.NewWriter(in)
if err != nil {
return nil, fmt.Errorf("creating zstd writer: %w", err)
}
return &compressionWriter{compression: writer, alg: algo}, nil
default:
return nil, fmt.Errorf("unknown compression algo: %s", algo)
}
type compressionWriterWrapper struct {
CompressionWriter
algo string
}
func (c *compressionWriter) ContentEncoding() string {
return c.alg
func (c *compressionWriterWrapper) ContentEncoding() string {
return c.algo
}
// TODO: check whether we want to implement content-type sniffing (as gzip does)
// TODO: check whether we should support Accept-Ranges (as gzip does, see https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Ranges)
type responseWriter struct {
rw http.ResponseWriter
compressionWriter *compressionWriter
compressionWriter *compressionWriterWrapper
minSize int
excludedContentTypes []parsedContentType

View file

@ -162,7 +162,7 @@ func Test_NoBody(t *testing.T) {
require.NoError(t, err)
})
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(acceptEncoding, "zstd")
@ -181,8 +181,7 @@ func Test_NoBody(t *testing.T) {
func Test_MinSize(t *testing.T) {
cfg := Config{
MinSize: 128,
Algorithm: zstdName,
MinSize: 128,
}
var bodySize int
@ -197,7 +196,7 @@ func Test_MinSize(t *testing.T) {
}
})
h := mustNewCompressionHandler(t, cfg, next)
h := mustNewCompressionHandler(t, cfg, zstdName, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", &bytes.Buffer{})
req.Header.Add(acceptEncoding, "zstd")
@ -224,7 +223,7 @@ func Test_MultipleWriteHeader(t *testing.T) {
rw.WriteHeader(http.StatusNotFound)
})
h := mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName}, next)
h := mustNewCompressionHandler(t, Config{MinSize: 1024}, zstdName, next)
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(acceptEncoding, "zstd")
@ -239,12 +238,14 @@ func Test_FlushBeforeWrite(t *testing.T) {
testCases := []struct {
desc string
cfg Config
algo string
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
algo: brotliName,
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
@ -252,7 +253,8 @@ func Test_FlushBeforeWrite(t *testing.T) {
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
algo: zstdName,
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
@ -272,7 +274,7 @@ func Test_FlushBeforeWrite(t *testing.T) {
require.NoError(t, err)
})
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
@ -302,12 +304,14 @@ func Test_FlushAfterWrite(t *testing.T) {
testCases := []struct {
desc string
cfg Config
algo string
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
algo: brotliName,
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
@ -315,7 +319,8 @@ func Test_FlushAfterWrite(t *testing.T) {
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
algo: zstdName,
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
@ -338,7 +343,7 @@ func Test_FlushAfterWrite(t *testing.T) {
}
})
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
@ -368,12 +373,14 @@ func Test_FlushAfterWriteNil(t *testing.T) {
testCases := []struct {
desc string
cfg Config
algo string
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
algo: brotliName,
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
@ -381,7 +388,8 @@ func Test_FlushAfterWriteNil(t *testing.T) {
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
algo: zstdName,
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
@ -400,7 +408,7 @@ func Test_FlushAfterWriteNil(t *testing.T) {
rw.(http.Flusher).Flush()
})
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
@ -430,12 +438,14 @@ func Test_FlushAfterAllWrites(t *testing.T) {
testCases := []struct {
desc string
cfg Config
algo string
readerBuilder func(io.Reader) (io.Reader, error)
acceptEncoding string
}{
{
desc: "brotli",
cfg: Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Test"},
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
algo: brotliName,
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return brotli.NewReader(reader), nil
},
@ -443,7 +453,8 @@ func Test_FlushAfterAllWrites(t *testing.T) {
},
{
desc: "zstd",
cfg: Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Test"},
cfg: Config{MinSize: 1024, MiddlewareName: "Test"},
algo: zstdName,
readerBuilder: func(reader io.Reader) (io.Reader, error) {
return zstd.NewReader(reader)
},
@ -461,7 +472,7 @@ func Test_FlushAfterAllWrites(t *testing.T) {
rw.(http.Flusher).Flush()
})
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, next))
srv := httptest.NewServer(mustNewCompressionHandler(t, test.cfg, test.algo, next))
defer srv.Close()
req, err := http.NewRequest(http.MethodGet, srv.URL, http.NoBody)
@ -556,7 +567,6 @@ func Test_ExcludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
ExcludedContentTypes: test.excludedContentTypes,
Algorithm: zstdName,
}
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -568,7 +578,7 @@ func Test_ExcludedContentTypes(t *testing.T) {
require.NoError(t, err)
})
h := mustNewCompressionHandler(t, cfg, next)
h := mustNewCompressionHandler(t, cfg, zstdName, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, zstdName)
@ -667,7 +677,6 @@ func Test_IncludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
IncludedContentTypes: test.includedContentTypes,
Algorithm: zstdName,
}
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -679,7 +688,7 @@ func Test_IncludedContentTypes(t *testing.T) {
require.NoError(t, err)
})
h := mustNewCompressionHandler(t, cfg, next)
h := mustNewCompressionHandler(t, cfg, zstdName, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, zstdName)
@ -778,7 +787,6 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
ExcludedContentTypes: test.excludedContentTypes,
Algorithm: zstdName,
}
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -803,7 +811,7 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
}
})
h := mustNewCompressionHandler(t, cfg, next)
h := mustNewCompressionHandler(t, cfg, zstdName, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, zstdName)
@ -903,7 +911,6 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
cfg := Config{
MinSize: 1024,
IncludedContentTypes: test.includedContentTypes,
Algorithm: zstdName,
}
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -928,7 +935,7 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
}
})
h := mustNewCompressionHandler(t, cfg, next)
h := mustNewCompressionHandler(t, cfg, zstdName, next)
req, _ := http.NewRequest(http.MethodGet, "/whatever", nil)
req.Header.Set(acceptEncoding, zstdName)
@ -959,10 +966,26 @@ func Test_FlushIncludedContentTypes(t *testing.T) {
}
}
func mustNewCompressionHandler(t *testing.T, cfg Config, next http.Handler) http.Handler {
func mustNewCompressionHandler(t *testing.T, cfg Config, algo string, next http.Handler) http.Handler {
t.Helper()
w, err := NewCompressionHandler(cfg, next)
var writer NewCompressionWriter
switch algo {
case zstdName:
writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
writer, err := zstd.NewWriter(rw)
require.NoError(t, err)
return writer, zstdName, nil
}
case brotliName:
writer = func(rw http.ResponseWriter) (CompressionWriter, string, error) {
return brotli.NewWriter(rw), brotliName, nil
}
default:
assert.Failf(t, "unknown compression algorithm: %s", algo)
}
w, err := NewCompressionHandler(cfg, writer, next)
require.NoError(t, err)
return w
@ -981,7 +1004,7 @@ func newTestBrotliHandler(t *testing.T, body []byte) http.Handler {
require.NoError(t, err)
})
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: brotliName, MiddlewareName: "Compress"}, next)
return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, brotliName, next)
}
func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
@ -997,7 +1020,7 @@ func newTestZstandardHandler(t *testing.T, body []byte) http.Handler {
require.NoError(t, err)
})
return mustNewCompressionHandler(t, Config{MinSize: 1024, Algorithm: zstdName, MiddlewareName: "Compress"}, next)
return mustNewCompressionHandler(t, Config{MinSize: 1024, MiddlewareName: "Compress"}, zstdName, next)
}
func Test_ParseContentType_equals(t *testing.T) {

View file

@ -1,7 +1,10 @@
package recovery
import (
"bufio"
"context"
"fmt"
"net"
"net/http"
"runtime"
@ -27,12 +30,16 @@ func New(ctx context.Context, next http.Handler) (http.Handler, error) {
}
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
defer recoverFunc(rw, req)
re.next.ServeHTTP(rw, req)
recoveryRW := newRecoveryResponseWriter(rw)
defer recoverFunc(recoveryRW, req)
re.next.ServeHTTP(recoveryRW, req)
}
func recoverFunc(rw http.ResponseWriter, req *http.Request) {
func recoverFunc(rw recoveryResponseWriter, req *http.Request) {
if err := recover(); err != nil {
defer rw.finalizeResponse()
logger := middlewares.GetLogger(req.Context(), middlewareName, typeName)
if !shouldLogPanic(err) {
logger.Debug().Msgf("Request has been aborted [%s - %s]: %v", req.RemoteAddr, req.URL, err)
@ -44,8 +51,6 @@ func recoverFunc(rw http.ResponseWriter, req *http.Request) {
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
logger.Error().Msgf("Stack: %s", buf)
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}
@ -55,3 +60,81 @@ func shouldLogPanic(panicValue interface{}) bool {
//nolint:errorlint // false-positive because panicValue is an interface.
return panicValue != nil && panicValue != http.ErrAbortHandler
}
type recoveryResponseWriter interface {
http.ResponseWriter
finalizeResponse()
}
func newRecoveryResponseWriter(rw http.ResponseWriter) recoveryResponseWriter {
wrapper := &responseWriterWrapper{rw: rw}
if _, ok := rw.(http.CloseNotifier); !ok {
return wrapper
}
return &responseWriterWrapperWithCloseNotify{wrapper}
}
type responseWriterWrapper struct {
rw http.ResponseWriter
headersSent bool
}
func (r *responseWriterWrapper) Header() http.Header {
return r.rw.Header()
}
func (r *responseWriterWrapper) Write(bytes []byte) (int, error) {
r.headersSent = true
return r.rw.Write(bytes)
}
func (r *responseWriterWrapper) WriteHeader(code int) {
if r.headersSent {
return
}
// Handling informational headers.
if code >= 100 && code <= 199 {
r.rw.WriteHeader(code)
return
}
r.headersSent = true
r.rw.WriteHeader(code)
}
func (r *responseWriterWrapper) Flush() {
if f, ok := r.rw.(http.Flusher); ok {
f.Flush()
}
}
func (r *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := r.rw.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("not a hijacker: %T", r.rw)
}
func (r *responseWriterWrapper) finalizeResponse() {
// If headers have been sent this is not possible to respond with an HTTP error,
// and we let the server abort the response silently thanks to the http.ErrAbortHandler sentinel panic value.
if r.headersSent {
panic(http.ErrAbortHandler)
}
// The response has not yet started to be written,
// we can safely return a fresh new error response.
http.Error(r.rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
type responseWriterWrapperWithCloseNotify struct {
*responseWriterWrapper
}
func (r *responseWriterWrapperWithCloseNotify) CloseNotify() <-chan bool {
return r.rw.(http.CloseNotifier).CloseNotify()
}

View file

@ -2,6 +2,8 @@ package recovery
import (
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"testing"
@ -11,17 +13,54 @@ import (
)
func TestRecoverHandler(t *testing.T) {
fn := func(w http.ResponseWriter, r *http.Request) {
panic("I love panicking!")
tests := []struct {
desc string
panicErr error
headersSent bool
}{
{
desc: "headers sent and custom panic error",
panicErr: errors.New("foo"),
headersSent: true,
},
{
desc: "headers sent and error abort handler",
panicErr: http.ErrAbortHandler,
headersSent: true,
},
{
desc: "custom panic error",
panicErr: errors.New("foo"),
},
{
desc: "error abort handler",
panicErr: http.ErrAbortHandler,
},
}
recovery, err := New(context.Background(), http.HandlerFunc(fn))
require.NoError(t, err)
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(recovery)
defer server.Close()
fn := func(rw http.ResponseWriter, req *http.Request) {
if test.headersSent {
rw.WriteHeader(http.StatusTeapot)
}
panic(test.panicErr)
}
recovery, err := New(context.Background(), http.HandlerFunc(fn))
require.NoError(t, err)
resp, err := http.Get(server.URL)
require.NoError(t, err)
server := httptest.NewServer(recovery)
t.Cleanup(server.Close)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
res, err := http.Get(server.URL)
if test.headersSent {
require.Nil(t, res)
assert.ErrorIs(t, err, io.EOF)
} else {
require.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
}
})
}
}

View file

@ -312,19 +312,9 @@ func (p *Provider) getClient() (*lego.Client, error) {
}
err = client.Challenge.SetDNS01Provider(provider,
dns01.CondOption(len(p.DNSChallenge.Resolvers) > 0, dns01.AddRecursiveNameservers(p.DNSChallenge.Resolvers)),
dns01.WrapPreCheck(func(domain, fqdn, value string, check dns01.PreCheckFunc) (bool, error) {
if p.DNSChallenge.DelayBeforeCheck > 0 {
logger.Debug().Msgf("Delaying %d rather than validating DNS propagation now.", p.DNSChallenge.DelayBeforeCheck)
time.Sleep(time.Duration(p.DNSChallenge.DelayBeforeCheck))
}
if p.DNSChallenge.DisablePropagationCheck {
return true, nil
}
return check(fqdn, value)
}),
dns01.CondOption(len(p.DNSChallenge.Resolvers) > 0,
dns01.AddRecursiveNameservers(p.DNSChallenge.Resolvers)),
dns01.PropagationWait(time.Duration(p.DNSChallenge.DelayBeforeCheck), p.DNSChallenge.DisablePropagationCheck),
)
if err != nil {
return nil, err

View file

@ -0,0 +1,56 @@
---
kind: GatewayClass
apiVersion: gateway.networking.k8s.io/v1
metadata:
name: my-gateway-class
spec:
controllerName: traefik.io/gateway-controller
---
kind: Gateway
apiVersion: gateway.networking.k8s.io/v1
metadata:
name: my-gateway
namespace: default
spec:
gatewayClassName: my-gateway-class
listeners: # Use GatewayClass defaults for listener definition.
- name: http
protocol: HTTP
port: 80
allowedRoutes:
kinds:
- kind: GRPCRoute
group: gateway.networking.k8s.io
namespaces:
from: Same
---
kind: GRPCRoute
apiVersion: gateway.networking.k8s.io/v1
metadata:
name: grpc-app-1
namespace: default
spec:
parentRefs:
- name: my-gateway
kind: Gateway
group: gateway.networking.k8s.io
hostnames:
- foo.com
rules:
- backendRefs:
- name: whoami
port: 80
weight: 1
filters:
- type: ExtensionRef
extensionRef:
group: traefik.io
kind: Middleware
name: my-first-middleware
- type: ExtensionRef
extensionRef:
group: traefik.io
kind: Middleware
name: my-second-middleware

View file

@ -54,4 +54,9 @@ spec:
extensionRef:
group: traefik.io
kind: Middleware
name: my-middleware
name: my-first-middleware
- type: ExtensionRef
extensionRef:
group: traefik.io
kind: Middleware
name: my-second-middleware

View file

@ -120,7 +120,7 @@ func (p *Provider) loadGRPCRoute(ctx context.Context, listener gatewayListener,
for ri, routeRule := range route.Spec.Rules {
// Adding the gateway desc and the entryPoint desc prevents overlapping of routers build from the same routes.
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-%s-%d", route.Namespace, route.Name, listener.GWName, listener.EPName, ri))
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-gw-%s-%s-ep-%s-%d", strings.ToLower(kindGRPCRoute), route.Namespace, route.Name, listener.GWNamespace, listener.GWName, listener.EPName, ri))
matches := routeRule.Matches
if len(matches) == 0 {
@ -166,7 +166,7 @@ func (p *Provider) loadGRPCRoute(ctx context.Context, listener gatewayListener,
default:
var serviceCondition *metav1.Condition
router.Service, serviceCondition = p.loadGRPCService(conf, routeKey, routeRule, route)
router.Service, serviceCondition = p.loadGRPCService(conf, routerName, routeRule, route)
if serviceCondition != nil {
condition = *serviceCondition
}
@ -267,7 +267,7 @@ func (p *Provider) loadGRPCBackendRef(route *gatev1.GRPCRoute, backendRef gatev1
}
portStr := strconv.FormatInt(int64(port), 10)
serviceName = provider.Normalize(serviceName + "-" + portStr)
serviceName = provider.Normalize(serviceName + "-" + portStr + "-grpc")
lb, errCondition := p.loadGRPCServers(namespace, route, backendRef)
if errCondition != nil {
@ -278,20 +278,30 @@ func (p *Provider) loadGRPCBackendRef(route *gatev1.GRPCRoute, backendRef gatev1
}
func (p *Provider) loadGRPCMiddlewares(conf *dynamic.Configuration, namespace, routerName string, filters []gatev1.GRPCRouteFilter) ([]string, error) {
middlewares := make(map[string]*dynamic.Middleware)
type namedMiddleware struct {
Name string
Config *dynamic.Middleware
}
var middlewares []namedMiddleware
for i, filter := range filters {
name := fmt.Sprintf("%s-%s-%d", routerName, strings.ToLower(string(filter.Type)), i)
switch filter.Type {
case gatev1.GRPCRouteFilterRequestHeaderModifier:
middlewares[name] = createRequestHeaderModifier(filter.RequestHeaderModifier)
middlewares = append(middlewares, namedMiddleware{
name,
createRequestHeaderModifier(filter.RequestHeaderModifier),
})
case gatev1.GRPCRouteFilterExtensionRef:
name, middleware, err := p.loadHTTPRouteFilterExtensionRef(namespace, filter.ExtensionRef)
if err != nil {
return nil, fmt.Errorf("loading ExtensionRef filter %s: %w", filter.Type, err)
}
middlewares[name] = middleware
middlewares = append(middlewares, namedMiddleware{
name,
middleware,
})
default:
// As per the spec: https://gateway-api.sigs.k8s.io/api-types/httproute/#filters-optional
@ -303,12 +313,11 @@ func (p *Provider) loadGRPCMiddlewares(conf *dynamic.Configuration, namespace, r
}
var middlewareNames []string
for name, middleware := range middlewares {
if middleware != nil {
conf.HTTP.Middlewares[name] = middleware
for _, m := range middlewares {
if m.Config != nil {
conf.HTTP.Middlewares[m.Name] = m.Config
}
middlewareNames = append(middlewareNames, name)
middlewareNames = append(middlewareNames, m.Name)
}
return middlewareNames, nil

View file

@ -123,7 +123,7 @@ func (p *Provider) loadHTTPRoute(ctx context.Context, listener gatewayListener,
for ri, routeRule := range route.Spec.Rules {
// Adding the gateway desc and the entryPoint desc prevents overlapping of routers build from the same routes.
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-%s-%d", route.Namespace, route.Name, listener.GWName, listener.EPName, ri))
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-gw-%s-%s-ep-%s-%d", strings.ToLower(kindHTTPRoute), route.Namespace, route.Name, listener.GWNamespace, listener.GWName, listener.EPName, ri))
for _, match := range routeRule.Matches {
rule, priority := buildMatchRule(hostnames, match)
@ -224,7 +224,7 @@ func (p *Provider) loadService(ctx context.Context, listener gatewayListener, co
namespace = string(*backendRef.Namespace)
}
serviceName := provider.Normalize(namespace + "-" + string(backendRef.Name))
serviceName := provider.Normalize(namespace + "-" + string(backendRef.Name) + "-http")
if err := p.isReferenceGranted(kindHTTPRoute, route.Namespace, group, string(kind), string(backendRef.Name), namespace); err != nil {
return serviceName, &metav1.Condition{
@ -384,39 +384,58 @@ func (p *Provider) loadHTTPBackendRef(namespace string, backendRef gatev1.HTTPBa
}
func (p *Provider) loadMiddlewares(conf *dynamic.Configuration, namespace, routerName string, filters []gatev1.HTTPRouteFilter, pathMatch *gatev1.HTTPPathMatch) ([]string, error) {
type namedMiddleware struct {
Name string
Config *dynamic.Middleware
}
pm := ptr.Deref(pathMatch, gatev1.HTTPPathMatch{
Type: ptr.To(gatev1.PathMatchPathPrefix),
Value: ptr.To("/"),
})
middlewares := make(map[string]*dynamic.Middleware)
var middlewares []namedMiddleware
for i, filter := range filters {
name := fmt.Sprintf("%s-%s-%d", routerName, strings.ToLower(string(filter.Type)), i)
switch filter.Type {
case gatev1.HTTPRouteFilterRequestRedirect:
middlewares[name] = createRequestRedirect(filter.RequestRedirect, pm)
middlewares = append(middlewares, namedMiddleware{
name,
createRequestRedirect(filter.RequestRedirect, pm),
})
case gatev1.HTTPRouteFilterRequestHeaderModifier:
middlewares[name] = createRequestHeaderModifier(filter.RequestHeaderModifier)
middlewares = append(middlewares, namedMiddleware{
name,
createRequestHeaderModifier(filter.RequestHeaderModifier),
})
case gatev1.HTTPRouteFilterResponseHeaderModifier:
middlewares[name] = createResponseHeaderModifier(filter.ResponseHeaderModifier)
middlewares = append(middlewares, namedMiddleware{
name,
createResponseHeaderModifier(filter.ResponseHeaderModifier),
})
case gatev1.HTTPRouteFilterExtensionRef:
name, middleware, err := p.loadHTTPRouteFilterExtensionRef(namespace, filter.ExtensionRef)
if err != nil {
return nil, fmt.Errorf("loading ExtensionRef filter %s: %w", filter.Type, err)
}
middlewares[name] = middleware
middlewares = append(middlewares, namedMiddleware{
name,
middleware,
})
case gatev1.HTTPRouteFilterURLRewrite:
var err error
middleware, err := createURLRewrite(filter.URLRewrite, pm)
if err != nil {
return nil, fmt.Errorf("invalid filter %s: %w", filter.Type, err)
}
middlewares[name] = middleware
middlewares = append(middlewares, namedMiddleware{
name,
middleware,
})
default:
// As per the spec: https://gateway-api.sigs.k8s.io/api-types/httproute/#filters-optional
@ -428,12 +447,11 @@ func (p *Provider) loadMiddlewares(conf *dynamic.Configuration, namespace, route
}
var middlewareNames []string
for name, middleware := range middlewares {
if middleware != nil {
conf.HTTP.Middlewares[name] = middleware
for _, m := range middlewares {
if m.Config != nil {
conf.HTTP.Middlewares[m.Name] = m.Config
}
middlewareNames = append(middlewareNames, name)
middlewareNames = append(middlewareNames, m.Name)
}
return middlewareNames, nil

File diff suppressed because it is too large Load diff

View file

@ -130,7 +130,7 @@ func (p *Provider) loadTCPRoute(listener gatewayListener, route *gatev1alpha2.TC
}
// Adding the gateway desc and the entryPoint desc prevents overlapping of routers build from the same routes.
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-%s-%d", route.Namespace, route.Name, listener.GWName, listener.EPName, ri))
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-gw-%s-%s-ep-%s-%d", strings.ToLower(kindTCPRoute), route.Namespace, route.Name, listener.GWNamespace, listener.GWName, listener.EPName, ri))
// Routing criteria should be introduced at some point.
routerName := makeRouterName("", routeKey)

View file

@ -132,7 +132,7 @@ func (p *Provider) loadTLSRoute(listener gatewayListener, route *gatev1alpha2.TL
}
// Adding the gateway desc and the entryPoint desc prevents overlapping of routers build from the same routes.
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-%s-%d", route.Namespace, route.Name, listener.GWName, listener.EPName, ri))
routeKey := provider.Normalize(fmt.Sprintf("%s-%s-%s-gw-%s-%s-ep-%s-%d", strings.ToLower(kindTLSRoute), route.Namespace, route.Name, listener.GWNamespace, listener.GWName, listener.EPName, ri))
// Routing criteria should be introduced at some point.
routerName := makeRouterName("", routeKey)

View file

@ -12,7 +12,7 @@ import (
// MustParseYaml parses a YAML to objects.
func MustParseYaml(content []byte) []runtime.Object {
acceptedK8sTypes := regexp.MustCompile(`^(Namespace|Deployment|EndpointSlice|Node|Service|ConfigMap|Ingress|IngressRoute|IngressRouteTCP|IngressRouteUDP|Middleware|MiddlewareTCP|Secret|TLSOption|TLSStore|TraefikService|IngressClass|ServersTransport|ServersTransportTCP|GatewayClass|Gateway|HTTPRoute|TCPRoute|TLSRoute|ReferenceGrant|BackendTLSPolicy)$`)
acceptedK8sTypes := regexp.MustCompile(`^(Namespace|Deployment|EndpointSlice|Node|Service|ConfigMap|Ingress|IngressRoute|IngressRouteTCP|IngressRouteUDP|Middleware|MiddlewareTCP|Secret|TLSOption|TLSStore|TraefikService|IngressClass|ServersTransport|ServersTransportTCP|GatewayClass|Gateway|GRPCRoute|HTTPRoute|TCPRoute|TLSRoute|ReferenceGrant|BackendTLSPolicy)$`)
files := strings.Split(string(content), "---\n")
retVal := make([]runtime.Object, 0, len(files))

View file

@ -68,7 +68,7 @@ func (r *ProxyBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) {
}
// Build builds a new ReverseProxy with the given configuration.
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader bool) (http.Handler, error) {
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader, preservePath bool) (http.Handler, error) {
proxyURL, err := r.proxy(&http.Request{URL: targetURL})
if err != nil {
return nil, fmt.Errorf("getting proxy: %w", err)
@ -79,18 +79,13 @@ func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, passHostHeader
return nil, fmt.Errorf("getting ServersTransport: %w", err)
}
var responseHeaderTimeout time.Duration
if cfg.ForwardingTimeouts != nil {
responseHeaderTimeout = time.Duration(cfg.ForwardingTimeouts.ResponseHeaderTimeout)
}
tlsConfig, err := r.transportManager.GetTLSConfig(cfgName)
if err != nil {
return nil, fmt.Errorf("getting TLS config: %w", err)
}
pool := r.getPool(cfgName, cfg, tlsConfig, targetURL, proxyURL)
return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, responseHeaderTimeout, pool)
return NewReverseProxy(targetURL, proxyURL, r.debug, passHostHeader, preservePath, pool)
}
func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport, tlsConfig *tls.Config, targetURL *url.URL, proxyURL *url.URL) *connPool {
@ -106,9 +101,11 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport,
idleConnTimeout := 90 * time.Second
dialTimeout := 30 * time.Second
var responseHeaderTimeout time.Duration
if config.ForwardingTimeouts != nil {
idleConnTimeout = time.Duration(config.ForwardingTimeouts.IdleConnTimeout)
dialTimeout = time.Duration(config.ForwardingTimeouts.DialTimeout)
responseHeaderTimeout = time.Duration(config.ForwardingTimeouts.ResponseHeaderTimeout)
}
proxyDialer := newDialer(dialerConfig{
@ -119,7 +116,7 @@ func (r *ProxyBuilder) getPool(cfgName string, config *dynamic.ServersTransport,
ProxyURL: proxyURL,
}, tlsConfig)
connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, func() (net.Conn, error) {
connPool := newConnPool(config.MaxIdleConnsPerHost, idleConnTimeout, responseHeaderTimeout, func() (net.Conn, error) {
return proxyDialer.Dial("tcp", addrFromURL(targetURL))
})

View file

@ -1,42 +1,309 @@
package fast
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog/log"
"github.com/valyala/fasthttp"
)
// rwWithUpgrade contains a ResponseWriter and an upgradeHandler,
// used to upgrade the connection (e.g. Websockets).
type rwWithUpgrade struct {
RW http.ResponseWriter
Upgrade upgradeHandler
}
// conn is an enriched net.Conn.
type conn struct {
net.Conn
RWCh chan rwWithUpgrade
ErrCh chan error
br *bufio.Reader
idleAt time.Time // the last time it was marked as idle.
idleTimeout time.Duration
responseHeaderTimeout time.Duration
expectedResponse atomic.Bool
broken atomic.Bool
upgraded atomic.Bool
closeMu sync.Mutex
closed bool
closeErr error
bufferPool *pool[[]byte]
limitedReaderPool *pool[*io.LimitedReader]
}
func (c *conn) isExpired() bool {
// Read reads data from the connection.
// Overrides conn Read to use the buffered reader.
func (c *conn) Read(b []byte) (n int, err error) {
return c.br.Read(b)
}
// Close closes the connection.
// Ensures that connection is closed only once,
// to avoid duplicate close error.
func (c *conn) Close() error {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if c.closed {
return c.closeErr
}
c.closed = true
c.closeErr = c.Conn.Close()
return c.closeErr
}
// isStale returns whether the connection is in an invalid state (i.e. expired/broken).
func (c *conn) isStale() bool {
expTime := c.idleAt.Add(c.idleTimeout)
return c.idleTimeout > 0 && time.Now().After(expTime)
return c.idleTimeout > 0 && time.Now().After(expTime) || c.broken.Load()
}
// isUpgraded returns whether this connection has been upgraded (e.g. Websocket).
// An upgraded connection should not be reused and putted back in the connection pool.
func (c *conn) isUpgraded() bool {
return c.upgraded.Load()
}
// readLoop handles the successive HTTP response read operations on the connection,
// and watches for unsolicited bytes or connection errors when idle.
func (c *conn) readLoop() {
defer c.Close()
for {
_, err := c.br.Peek(1)
if err != nil {
select {
// An error occurred while a response was expected to be handled.
case <-c.RWCh:
c.ErrCh <- err
// An error occurred on an idle connection.
default:
c.broken.Store(true)
}
return
}
// Unsolicited response received on an idle connection.
if !c.expectedResponse.Load() {
c.broken.Store(true)
return
}
r := <-c.RWCh
if err = c.handleResponse(r); err != nil {
c.ErrCh <- err
return
}
c.expectedResponse.Store(false)
c.ErrCh <- nil
}
}
func (c *conn) handleResponse(r rwWithUpgrade) error {
res := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(res)
res.Header.SetNoDefaultContentType(true)
for {
var (
timer *time.Timer
errTimeout atomic.Pointer[timeoutError]
)
if c.responseHeaderTimeout > 0 {
timer = time.AfterFunc(c.responseHeaderTimeout, func() {
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
c.Close() // This close call is needed to interrupt the read operation below when the timeout is over.
})
}
res.Header.SetNoDefaultContentType(true)
if err := res.Header.Read(c.br); err != nil {
if c.responseHeaderTimeout > 0 {
if errT := errTimeout.Load(); errT != nil {
return errT
}
}
return err
}
if timer != nil {
timer.Stop()
}
fixPragmaCacheControl(&res.Header)
resCode := res.StatusCode()
is1xx := 100 <= resCode && resCode <= 199
// treat 101 as a terminal status, see issue 26161
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
if is1xxNonTerminal {
removeConnectionHeaders(&res.Header)
h := r.RW.Header()
for _, header := range hopHeaders {
res.Header.Del(header)
}
res.Header.VisitAll(func(key, value []byte) {
r.RW.Header().Add(string(key), string(value))
})
r.RW.WriteHeader(res.StatusCode())
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
for k := range h {
delete(h, k)
}
res.Reset()
res.Header.Reset()
res.Header.SetNoDefaultContentType(true)
continue
}
break
}
announcedTrailers := res.Header.Peek("Trailer")
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode() == http.StatusSwitchingProtocols {
r.Upgrade(r.RW, res, c)
c.upgraded.Store(true) // As the connection has been upgraded, it cannot be added back to the pool.
return nil
}
removeConnectionHeaders(&res.Header)
for _, header := range hopHeaders {
res.Header.Del(header)
}
if len(announcedTrailers) > 0 {
res.Header.Add("Trailer", string(announcedTrailers))
}
res.Header.VisitAll(func(key, value []byte) {
r.RW.Header().Add(string(key), string(value))
})
r.RW.WriteHeader(res.StatusCode())
if res.Header.ContentLength() == 0 {
return nil
}
// When a body is not allowed for a given status code the body is ignored.
// The connection will be marked as broken by the next Peek in the readloop.
if !isBodyAllowedForStatus(res.StatusCode()) {
return nil
}
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
if res.Header.ContentLength() == -1 {
cbr := httputil.NewChunkedReader(c.br)
b := c.bufferPool.Get()
if b == nil {
b = make([]byte, bufferSize)
}
defer c.bufferPool.Put(b)
if _, err := io.CopyBuffer(&writeFlusher{r.RW}, cbr, b); err != nil {
return err
}
res.Header.Reset()
res.Header.SetNoDefaultContentType(true)
if err := res.Header.ReadTrailer(c.br); err != nil {
return err
}
if res.Header.Len() > 0 {
var announcedTrailersKey []string
if len(announcedTrailers) > 0 {
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
}
res.Header.VisitAll(func(key, value []byte) {
for _, s := range announcedTrailersKey {
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
r.RW.Header().Add(string(key), string(value))
return
}
}
r.RW.Header().Add(http.TrailerPrefix+string(key), string(value))
})
}
return nil
}
brl := c.limitedReaderPool.Get()
if brl == nil {
brl = &io.LimitedReader{}
}
defer c.limitedReaderPool.Put(brl)
brl.R = c.br
brl.N = int64(res.Header.ContentLength())
b := c.bufferPool.Get()
if b == nil {
b = make([]byte, bufferSize)
}
defer c.bufferPool.Put(b)
if _, err := io.CopyBuffer(r.RW, brl, b); err != nil {
return err
}
return nil
}
// connPool is a net.Conn pool implementation using channels.
type connPool struct {
dialer func() (net.Conn, error)
idleConns chan *conn
idleConnTimeout time.Duration
ticker *time.Ticker
doneCh chan struct{}
dialer func() (net.Conn, error)
idleConns chan *conn
idleConnTimeout time.Duration
responseHeaderTimeout time.Duration
ticker *time.Ticker
bufferPool pool[[]byte]
limitedReaderPool pool[*io.LimitedReader]
doneCh chan struct{}
}
// newConnPool creates a new connPool.
func newConnPool(maxIdleConn int, idleConnTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
func newConnPool(maxIdleConn int, idleConnTimeout, responseHeaderTimeout time.Duration, dialer func() (net.Conn, error)) *connPool {
c := &connPool{
dialer: dialer,
idleConns: make(chan *conn, maxIdleConn),
idleConnTimeout: idleConnTimeout,
doneCh: make(chan struct{}),
dialer: dialer,
idleConns: make(chan *conn, maxIdleConn),
idleConnTimeout: idleConnTimeout,
responseHeaderTimeout: responseHeaderTimeout,
doneCh: make(chan struct{}),
}
if idleConnTimeout > 0 {
@ -72,22 +339,28 @@ func (c *connPool) AcquireConn() (*conn, error) {
return nil, err
}
if !co.isExpired() {
if !co.isStale() {
return co, nil
}
// As the acquired conn is expired we can close it
// As the acquired conn is stale we can close it
// without putting it again into the pool.
if err := co.Close(); err != nil {
log.Debug().
Err(err).
Msg("Unexpected error while releasing the connection")
Msg("Unexpected error while closing the connection")
}
}
}
// ReleaseConn releases the given net.Conn to the pool.
func (c *connPool) ReleaseConn(co *conn) {
// An upgraded connection cannot be safely reused for another roundTrip,
// thus we are not putting it back to the pool.
if co.isUpgraded() {
return
}
co.idleAt = time.Now()
c.releaseConn(co)
}
@ -97,7 +370,7 @@ func (c *connPool) cleanIdleConns() {
for {
select {
case co := <-c.idleConns:
if !co.isExpired() {
if !co.isStale() {
c.releaseConn(co)
return
}
@ -105,7 +378,7 @@ func (c *connPool) cleanIdleConns() {
if err := co.Close(); err != nil {
log.Debug().
Err(err).
Msg("Unexpected error while releasing the connection")
Msg("Unexpected error while closing the connection")
}
default:
@ -155,9 +428,33 @@ func (c *connPool) askForNewConn(errCh chan<- error) {
return
}
c.releaseConn(&conn{
Conn: co,
idleAt: time.Now(),
idleTimeout: c.idleConnTimeout,
})
newConn := &conn{
Conn: co,
br: bufio.NewReaderSize(co, bufioSize),
idleAt: time.Now(),
idleTimeout: c.idleConnTimeout,
responseHeaderTimeout: c.responseHeaderTimeout,
RWCh: make(chan rwWithUpgrade),
ErrCh: make(chan error),
bufferPool: &c.bufferPool,
limitedReaderPool: &c.limitedReaderPool,
}
go newConn.readLoop()
c.releaseConn(newConn)
}
// isBodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC 7230, section 3.3.
// From https://github.com/golang/go/blame/master/src/net/http/transfer.go#L459
func isBodyAllowedForStatus(status int) bool {
switch {
case status >= 100 && status <= 199:
return false
case status == 204:
return false
case status == 304:
return false
}
return true
}

View file

@ -58,7 +58,7 @@ func TestConnPool_ConnReuse(t *testing.T) {
return &net.TCPConn{}, nil
}
pool := newConnPool(2, 0, dialer)
pool := newConnPool(2, 0, 0, dialer)
test.poolFn(pool)
assert.Equal(t, test.expected, connAlloc)
@ -102,13 +102,16 @@ func TestConnPool_MaxIdleConn(t *testing.T) {
var keepOpenedConn int
dialer := func() (net.Conn, error) {
keepOpenedConn++
return &mockConn{closeFn: func() error {
keepOpenedConn--
return nil
}}, nil
return &mockConn{
doneCh: make(chan struct{}),
closeFn: func() error {
keepOpenedConn--
return nil
},
}, nil
}
pool := newConnPool(test.maxIdleConn, 0, dialer)
pool := newConnPool(test.maxIdleConn, 0, 0, dialer)
test.poolFn(pool)
assert.Equal(t, test.expected, keepOpenedConn)
@ -129,7 +132,7 @@ func TestGC(t *testing.T) {
return c, nil
}
pools["test"] = newConnPool(10, 1*time.Second, dialer)
pools["test"] = newConnPool(10, 1*time.Second, 0, dialer)
runtime.SetFinalizer(pools["test"], func(p *connPool) {
isDestroyed = true
})
@ -149,10 +152,12 @@ func TestGC(t *testing.T) {
type mockConn struct {
closeFn func() error
doneCh chan struct{} // makes sure that the readLoop is blocking avoiding close.
}
func (m *mockConn) Read(_ []byte) (n int, err error) {
panic("implement me")
<-m.doneCh
return 0, nil
}
func (m *mockConn) Write(_ []byte) (n int, err error) {
@ -160,6 +165,7 @@ func (m *mockConn) Write(_ []byte) (n int, err error) {
}
func (m *mockConn) Close() error {
defer close(m.doneCh)
if m.closeFn != nil {
return m.closeFn()
}

View file

@ -4,18 +4,14 @@ import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/http/httputil"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog/log"
proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"
@ -57,15 +53,6 @@ func (p *pool[T]) Put(x T) {
p.pool.Put(x)
}
type buffConn struct {
*bufio.Reader
net.Conn
}
func (b buffConn) Read(p []byte) (int, error) {
return b.Reader.Read(p)
}
type writeDetector struct {
net.Conn
@ -112,20 +99,17 @@ type ReverseProxy struct {
connPool *connPool
bufferPool pool[[]byte]
readerPool pool[*bufio.Reader]
writerPool pool[*bufio.Writer]
limitReaderPool pool[*io.LimitedReader]
writerPool pool[*bufio.Writer]
proxyAuth string
targetURL *url.URL
passHostHeader bool
responseHeaderTimeout time.Duration
targetURL *url.URL
passHostHeader bool
preservePath bool
}
// NewReverseProxy creates a new ReverseProxy.
func NewReverseProxy(targetURL *url.URL, proxyURL *url.URL, debug, passHostHeader bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) {
func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, connPool *connPool) (*ReverseProxy, error) {
var proxyAuth string
if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
username := proxyURL.User.Username()
@ -134,12 +118,12 @@ func NewReverseProxy(targetURL *url.URL, proxyURL *url.URL, debug, passHostHeade
}
return &ReverseProxy{
debug: debug,
passHostHeader: passHostHeader,
targetURL: targetURL,
proxyAuth: proxyAuth,
connPool: connPool,
responseHeaderTimeout: responseHeaderTimeout,
debug: debug,
passHostHeader: passHostHeader,
preservePath: preservePath,
targetURL: targetURL,
proxyAuth: proxyAuth,
connPool: connPool,
}, nil
}
@ -207,6 +191,11 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
u2.Path = u.Path
u2.RawPath = u.RawPath
if p.preservePath {
u2.Path, u2.RawPath = proxyhttputil.JoinURLPath(p.targetURL, u)
}
u2.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
outReq.SetHost(u2.Host)
@ -266,8 +255,15 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
return fmt.Errorf("acquire connection: %w", err)
}
// Before writing the request,
// we mark the conn as expecting to handle a response.
co.expectedResponse.Store(true)
wd := &writeDetector{Conn: co}
// TODO: do not wait to write the full request before reading the response (to handle "100 Continue").
// TODO: this is currently impossible with fasthttp to write the request partially (headers only).
// Currently, writing the request fully is a mandatory step before handling the response.
err = p.writeRequest(wd, outReq)
if wd.written && trace != nil && trace.WroteRequest != nil {
// WroteRequest hook is used by the tracing middleware to detect if the request has been written.
@ -286,169 +282,17 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
}
}
br := p.readerPool.Get()
if br == nil {
br = bufio.NewReaderSize(co, bufioSize)
}
defer p.readerPool.Put(br)
br.Reset(co)
res := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(res)
res.Header.SetNoDefaultContentType(true)
for {
var timer *time.Timer
errTimeout := atomic.Pointer[timeoutError]{}
if p.responseHeaderTimeout > 0 {
timer = time.AfterFunc(p.responseHeaderTimeout, func() {
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
co.Close()
})
}
res.Header.SetNoDefaultContentType(true)
if err := res.Header.Read(br); err != nil {
if p.responseHeaderTimeout > 0 {
if errT := errTimeout.Load(); errT != nil {
return errT
}
}
co.Close()
return err
}
if timer != nil {
timer.Stop()
}
fixPragmaCacheControl(&res.Header)
resCode := res.StatusCode()
is1xx := 100 <= resCode && resCode <= 199
// treat 101 as a terminal status, see issue 26161
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
if is1xxNonTerminal {
removeConnectionHeaders(&res.Header)
h := rw.Header()
for _, header := range hopHeaders {
res.Header.Del(header)
}
res.Header.VisitAll(func(key, value []byte) {
rw.Header().Add(string(key), string(value))
})
rw.WriteHeader(res.StatusCode())
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
for k := range h {
delete(h, k)
}
res.Reset()
res.Header.Reset()
res.Header.SetNoDefaultContentType(true)
continue
}
break
// Sending the responseWriter unlocks the connection readLoop, to handle the response.
co.RWCh <- rwWithUpgrade{
RW: rw,
Upgrade: upgradeResponseHandler(req.Context(), reqUpType),
}
announcedTrailers := res.Header.Peek("Trailer")
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode() == http.StatusSwitchingProtocols {
// As the connection has been hijacked, it cannot be added back to the pool.
handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br})
return nil
}
removeConnectionHeaders(&res.Header)
for _, header := range hopHeaders {
res.Header.Del(header)
}
if len(announcedTrailers) > 0 {
res.Header.Add("Trailer", string(announcedTrailers))
}
res.Header.VisitAll(func(key, value []byte) {
rw.Header().Add(string(key), string(value))
})
rw.WriteHeader(res.StatusCode())
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
if res.Header.ContentLength() == -1 {
cbr := httputil.NewChunkedReader(br)
b := p.bufferPool.Get()
if b == nil {
b = make([]byte, bufferSize)
}
defer p.bufferPool.Put(b)
if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil {
co.Close()
return err
}
res.Header.Reset()
res.Header.SetNoDefaultContentType(true)
if err := res.Header.ReadTrailer(br); err != nil {
co.Close()
return err
}
if res.Header.Len() > 0 {
var announcedTrailersKey []string
if len(announcedTrailers) > 0 {
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
}
res.Header.VisitAll(func(key, value []byte) {
for _, s := range announcedTrailersKey {
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
rw.Header().Add(string(key), string(value))
return
}
}
rw.Header().Add(http.TrailerPrefix+string(key), string(value))
})
}
p.connPool.ReleaseConn(co)
return nil
}
brl := p.limitReaderPool.Get()
if brl == nil {
brl = &io.LimitedReader{}
}
defer p.limitReaderPool.Put(brl)
brl.R = br
brl.N = int64(res.Header.ContentLength())
b := p.bufferPool.Get()
if b == nil {
b = make([]byte, bufferSize)
}
defer p.bufferPool.Put(b)
if _, err := io.CopyBuffer(rw, brl, b); err != nil {
co.Close()
if err := <-co.ErrCh; err != nil {
return err
}
p.connPool.ReleaseConn(co)
return nil
}

View file

@ -230,7 +230,7 @@ func TestProxyFromEnvironment(t *testing.T) {
return u, nil
}
reverseProxy, err := builder.Build("foo", testhelpers.MustParseURL(backendURL), false)
reverseProxy, err := builder.Build("foo", testhelpers.MustParseURL(backendURL), false, false)
require.NoError(t, err)
reverseProxyServer := httptest.NewServer(reverseProxy)
@ -252,6 +252,32 @@ func TestProxyFromEnvironment(t *testing.T) {
}
}
func TestPreservePath(t *testing.T) {
var callCount int
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
callCount++
assert.Equal(t, "/base/foo/bar", req.URL.Path)
assert.Equal(t, "/base/foo%2Fbar", req.URL.RawPath)
}))
t.Cleanup(server.Close)
builder := NewProxyBuilder(&transportManagerMock{}, static.FastProxyConfig{})
serverURL, err := url.JoinPath(server.URL, "base")
require.NoError(t, err)
proxyHandler, err := builder.Build("", testhelpers.MustParseURL(serverURL), true, true)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "/foo%2Fbar", http.NoBody)
res := httptest.NewRecorder()
proxyHandler.ServeHTTP(res, req)
assert.Equal(t, 1, callCount)
assert.Equal(t, http.StatusOK, res.Code)
}
func newCertificate(t *testing.T, domain string) *tls.Certificate {
t.Helper()

View file

@ -362,7 +362,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
u := parseURI(t, srv.URL)
f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
return net.Dial("tcp", u.Host)
}))
require.NoError(t, err)
@ -434,7 +434,7 @@ func TestWebSocketUpgradeFailed(t *testing.T) {
defer srv.Close()
u := parseURI(t, srv.URL)
f, err := NewReverseProxy(u, nil, true, false, 0, newConnPool(1, 0, func() (net.Conn, error) {
f, err := NewReverseProxy(u, nil, true, false, false, newConnPool(1, 0, 0, func() (net.Conn, error) {
return net.Dial("tcp", u.Host)
}))
require.NoError(t, err)
@ -663,7 +663,7 @@ func parseURI(t *testing.T, uri string) *url.URL {
func createConnectionPool(target string, tlsConfig *tls.Config) *connPool {
u := testhelpers.MustParseURL(target)
return newConnPool(200, 0, func() (net.Conn, error) {
return newConnPool(200, 0, 0, func() (net.Conn, error) {
if tlsConfig != nil {
return tls.Dial("tcp", u.Host, tlsConfig)
}
@ -676,7 +676,7 @@ func createProxyWithForwarder(t *testing.T, uri string, pool *connPool) *httptes
t.Helper()
u := parseURI(t, uri)
proxy, err := NewReverseProxy(u, nil, false, true, 0, pool)
proxy, err := NewReverseProxy(u, nil, false, true, false, pool)
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {

View file

@ -2,6 +2,7 @@ package fast
import (
"bytes"
"context"
"fmt"
"io"
"net"
@ -19,72 +20,75 @@ type switchProtocolCopier struct {
user, backend io.ReadWriter
}
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
func (c switchProtocolCopier) copyFromBackend(errCh chan<- error) {
_, err := io.Copy(c.user, c.backend)
errc <- err
errCh <- err
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
func (c switchProtocolCopier) copyToBackend(errCh chan<- error) {
_, err := io.Copy(c.backend, c.user)
errc <- err
errCh <- err
}
func handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, reqUpType string, res *fasthttp.Response, backConn net.Conn) {
defer backConn.Close()
type upgradeHandler func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn)
resUpType := upgradeTypeFastHTTP(&res.Header)
func upgradeResponseHandler(ctx context.Context, reqUpType string) upgradeHandler {
return func(rw http.ResponseWriter, res *fasthttp.Response, backConn net.Conn) {
resUpType := upgradeTypeFastHTTP(&res.Header)
if !strings.EqualFold(reqUpType, resUpType) {
httputil.ErrorHandler(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
return
}
hj, ok := rw.(http.Hijacker)
if !ok {
httputil.ErrorHandler(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
return
}
backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-req.Context().Done():
case <-backConnCloseCh:
if !strings.EqualFold(reqUpType, resUpType) {
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
backConn.Close()
return
}
_ = backConn.Close()
}()
defer close(backConnCloseCh)
conn, brw, err := hj.Hijack()
if err != nil {
httputil.ErrorHandler(rw, req, fmt.Errorf("hijack failed on protocol switch: %w", err))
return
}
defer conn.Close()
for k, values := range rw.Header() {
for _, v := range values {
res.Header.Add(k, v)
hj, ok := rw.(http.Hijacker)
if !ok {
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
backConn.Close()
return
}
}
backConnCloseCh := make(chan bool)
go func() {
// Ensure that the cancellation of a request closes the backend.
// See issue https://golang.org/issue/35559.
select {
case <-ctx.Done():
case <-backConnCloseCh:
}
_ = backConn.Close()
}()
defer close(backConnCloseCh)
if err := res.Header.Write(brw.Writer); err != nil {
httputil.ErrorHandler(rw, req, fmt.Errorf("response write: %w", err))
return
}
conn, brw, err := hj.Hijack()
if err != nil {
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("hijack failed on protocol switch: %w", err))
return
}
defer conn.Close()
if err := brw.Flush(); err != nil {
httputil.ErrorHandler(rw, req, fmt.Errorf("response flush: %w", err))
return
}
for k, values := range rw.Header() {
for _, v := range values {
res.Header.Add(k, v)
}
}
errc := make(chan error, 1)
spc := switchProtocolCopier{user: conn, backend: backConn}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
<-errc
if err := res.Header.Write(brw.Writer); err != nil {
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response write: %w", err))
return
}
if err := brw.Flush(); err != nil {
httputil.ErrorHandlerWithContext(ctx, rw, fmt.Errorf("response flush: %w", err))
return
}
errCh := make(chan error, 1)
spc := switchProtocolCopier{user: conn, backend: backConn}
go spc.copyToBackend(errCh)
go spc.copyFromBackend(errCh)
<-errCh
}
}
func upgradeType(h http.Header) string {

View file

@ -38,7 +38,7 @@ func NewProxyBuilder(transportManager TransportManager, semConvMetricsRegistry *
func (r *ProxyBuilder) Update(_ map[string]*dynamic.ServersTransport) {}
// Build builds a new httputil.ReverseProxy with the given configuration.
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) {
func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader, preservePath bool, flushInterval time.Duration) (http.Handler, error) {
roundTripper, err := r.transportManager.GetRoundTripper(cfgName)
if err != nil {
return nil, fmt.Errorf("getting RoundTripper: %w", err)
@ -50,5 +50,5 @@ func (r *ProxyBuilder) Build(cfgName string, targetURL *url.URL, shouldObserve,
roundTripper = newObservabilityRoundTripper(r.semConvMetricsRegistry, roundTripper)
}
return buildSingleHostProxy(targetURL, passHostHeader, flushInterval, roundTripper, r.bufferPool), nil
return buildSingleHostProxy(targetURL, passHostHeader, preservePath, flushInterval, roundTripper, r.bufferPool), nil
}

View file

@ -23,7 +23,7 @@ func TestEscapedPath(t *testing.T) {
roundTrippers: map[string]http.RoundTripper{"default": &http.Transport{}},
}
p, err := NewProxyBuilder(transportManager, nil).Build("default", testhelpers.MustParseURL(srv.URL), false, true, 0)
p, err := NewProxyBuilder(transportManager, nil).Build("default", testhelpers.MustParseURL(srv.URL), false, true, false, 0)
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(p.ServeHTTP))

View file

@ -15,15 +15,17 @@ import (
"golang.org/x/net/http/httpguts"
)
// StatusClientClosedRequest non-standard HTTP status code for client disconnection.
const StatusClientClosedRequest = 499
const (
// StatusClientClosedRequest non-standard HTTP status code for client disconnection.
StatusClientClosedRequest = 499
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
const StatusClientClosedRequestText = "Client Closed Request"
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
StatusClientClosedRequestText = "Client Closed Request"
)
func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler {
func buildSingleHostProxy(target *url.URL, passHostHeader bool, preservePath bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler {
return &httputil.ReverseProxy{
Director: directorBuilder(target, passHostHeader),
Director: directorBuilder(target, passHostHeader, preservePath),
Transport: roundTripper,
FlushInterval: flushInterval,
BufferPool: bufferPool,
@ -31,7 +33,7 @@ func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval ti
}
}
func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Request) {
func directorBuilder(target *url.URL, passHostHeader bool, preservePath bool) func(req *http.Request) {
return func(outReq *http.Request) {
outReq.URL.Scheme = target.Scheme
outReq.URL.Host = target.Host
@ -46,6 +48,11 @@ func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Reques
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
if preservePath {
outReq.URL.Path, outReq.URL.RawPath = JoinURLPath(target, u)
}
// If a plugin/middleware adds semicolons in query params, they should be urlEncoded.
outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
outReq.RequestURI = "" // Outgoing request should not have RequestURI
@ -54,7 +61,7 @@ func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Reques
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
// Do not pass client Host header unless optsetter PassHostHeader is set.
// Do not pass client Host header unless option PassHostHeader is set.
if !passHostHeader {
outReq.Host = outReq.URL.Host
}
@ -95,9 +102,14 @@ func isWebSocketUpgrade(req *http.Request) bool {
// ErrorHandler is the http.Handler called when something goes wrong when forwarding the request.
func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
ErrorHandlerWithContext(req.Context(), w, err)
}
// ErrorHandlerWithContext is the http.Handler called when something goes wrong when forwarding the request.
func ErrorHandlerWithContext(ctx context.Context, w http.ResponseWriter, err error) {
statusCode := ComputeStatusCode(err)
logger := log.Ctx(req.Context())
logger := log.Ctx(ctx)
logger.Debug().Err(err).Msgf("%d %s", statusCode, statusText(statusCode))
w.WriteHeader(statusCode)
@ -106,6 +118,13 @@ func ErrorHandler(w http.ResponseWriter, req *http.Request, err error) {
}
}
func statusText(statusCode int) string {
if statusCode == StatusClientClosedRequest {
return StatusClientClosedRequestText
}
return http.StatusText(statusCode)
}
// ComputeStatusCode computes the HTTP status code according to the given error.
func ComputeStatusCode(err error) int {
switch {
@ -127,9 +146,38 @@ func ComputeStatusCode(err error) int {
return http.StatusInternalServerError
}
func statusText(statusCode int) string {
if statusCode == StatusClientClosedRequest {
return StatusClientClosedRequestText
// JoinURLPath computes the joined path and raw path of the given URLs.
// From https://github.com/golang/go/blob/b521ebb55a9b26c8824b219376c7f91f7cda6ec2/src/net/http/httputil/reverseproxy.go#L221
func JoinURLPath(a, b *url.URL) (path, rawpath string) {
if a.RawPath == "" && b.RawPath == "" {
return singleJoiningSlash(a.Path, b.Path), ""
}
return http.StatusText(statusCode)
// Same as singleJoiningSlash, but uses EscapedPath to determine
// whether a slash should be added
apath := a.EscapedPath()
bpath := b.EscapedPath()
aslash := strings.HasSuffix(apath, "/")
bslash := strings.HasPrefix(bpath, "/")
switch {
case aslash && bslash:
return a.Path + b.Path[1:], apath + bpath[1:]
case !aslash && !bslash:
return a.Path + "/" + b.Path, apath + "/" + bpath
}
return a.Path + b.Path, apath + bpath
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}

View file

@ -0,0 +1,102 @@
package httputil
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/traefik/traefik/v3/pkg/testhelpers"
)
func Test_directorBuilder(t *testing.T) {
tests := []struct {
name string
target *url.URL
passHostHeader bool
preservePath bool
incomingURL string
expectedScheme string
expectedHost string
expectedPath string
expectedRawPath string
expectedQuery string
}{
{
name: "Basic proxy",
target: testhelpers.MustParseURL("http://example.com"),
passHostHeader: false,
preservePath: false,
incomingURL: "http://localhost/test?param=value",
expectedScheme: "http",
expectedHost: "example.com",
expectedPath: "/test",
expectedQuery: "param=value",
},
{
name: "HTTPS target",
target: testhelpers.MustParseURL("https://secure.example.com"),
passHostHeader: false,
preservePath: false,
incomingURL: "http://localhost/secure",
expectedScheme: "https",
expectedHost: "secure.example.com",
expectedPath: "/secure",
},
{
name: "PassHostHeader",
target: testhelpers.MustParseURL("http://example.com"),
passHostHeader: true,
preservePath: false,
incomingURL: "http://original.host/test",
expectedScheme: "http",
expectedHost: "original.host",
expectedPath: "/test",
},
{
name: "Preserve path",
target: testhelpers.MustParseURL("http://example.com/base"),
passHostHeader: false,
preservePath: true,
incomingURL: "http://localhost/foo%2Fbar",
expectedScheme: "http",
expectedHost: "example.com",
expectedPath: "/base/foo/bar",
expectedRawPath: "/base/foo%2Fbar",
},
{
name: "Handle semicolons in query",
target: testhelpers.MustParseURL("http://example.com"),
passHostHeader: false,
preservePath: false,
incomingURL: "http://localhost/test?param1=value1;param2=value2",
expectedScheme: "http",
expectedHost: "example.com",
expectedPath: "/test",
expectedQuery: "param1=value1&param2=value2",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
director := directorBuilder(test.target, test.passHostHeader, test.preservePath)
req := httptest.NewRequest(http.MethodGet, test.incomingURL, http.NoBody)
director(req)
assert.Equal(t, test.expectedScheme, req.URL.Scheme)
assert.Equal(t, test.expectedHost, req.Host)
assert.Equal(t, test.expectedPath, req.URL.Path)
assert.Equal(t, test.expectedRawPath, req.URL.RawPath)
assert.Equal(t, test.expectedQuery, req.URL.RawQuery)
assert.Empty(t, req.RequestURI)
assert.Equal(t, "HTTP/1.1", req.Proto)
assert.Equal(t, 1, req.ProtoMajor)
assert.Equal(t, 1, req.ProtoMinor)
assert.False(t, !test.passHostHeader && req.Host != req.URL.Host)
})
}
}

View file

@ -298,9 +298,8 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
},
}
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, false, 0)
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = testhelpers.MustParseURL(srv.URL)
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
@ -355,9 +354,8 @@ func TestWebSocketUpgradeFailed(t *testing.T) {
},
}
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, 0)
p, err := NewProxyBuilder(transportManager, nil).Build("default@internal", testhelpers.MustParseURL(srv.URL), false, true, false, 0)
require.NoError(t, err)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
@ -588,7 +586,7 @@ func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTrip
roundTrippers: map[string]http.RoundTripper{"fwd": transport},
}
p, err := NewProxyBuilder(transportManager, nil).Build("fwd", u, false, true, 0)
p, err := NewProxyBuilder(transportManager, nil).Build("fwd", u, false, true, false, 0)
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {

View file

@ -45,7 +45,7 @@ func (b *SmartBuilder) Update(newConfigs map[string]*dynamic.ServersTransport) {
}
// Build builds an HTTP proxy for the given URL using the ServersTransport with the given name.
func (b *SmartBuilder) Build(configName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error) {
func (b *SmartBuilder) Build(configName string, targetURL *url.URL, shouldObserve, passHostHeader, preservePath bool, flushInterval time.Duration) (http.Handler, error) {
serversTransport, err := b.transportManager.Get(configName)
if err != nil {
return nil, fmt.Errorf("getting ServersTransport: %w", err)
@ -55,7 +55,7 @@ func (b *SmartBuilder) Build(configName string, targetURL *url.URL, shouldObserv
// For the https scheme we cannot guess if the backend communication will use HTTP2,
// thus we check if HTTP/2 is disabled to use the fast proxy implementation when this is possible.
if targetURL.Scheme == "h2c" || (targetURL.Scheme == "https" && !serversTransport.DisableHTTP2) {
return b.proxyBuilder.Build(configName, targetURL, shouldObserve, passHostHeader, flushInterval)
return b.proxyBuilder.Build(configName, targetURL, shouldObserve, passHostHeader, preservePath, flushInterval)
}
return b.fastProxyBuilder.Build(configName, targetURL, passHostHeader)
return b.fastProxyBuilder.Build(configName, targetURL, passHostHeader, preservePath)
}

View file

@ -101,7 +101,7 @@ func TestSmartBuilder_Build(t *testing.T) {
httpProxyBuilder := httputil.NewProxyBuilder(transportManager, nil)
proxyBuilder := NewSmartBuilder(transportManager, httpProxyBuilder, test.fastProxyConfig)
proxyHandler, err := proxyBuilder.Build("test", targetURL, false, false, time.Second)
proxyHandler, err := proxyBuilder.Build("test", targetURL, false, false, false, time.Second)
require.NoError(t, err)
rw := httptest.NewRecorder()

View file

@ -897,7 +897,7 @@ func BenchmarkService(b *testing.B) {
type proxyBuilderMock struct{}
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) {
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _, _ bool, _ time.Duration) (http.Handler, error) {
return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil
}

View file

@ -254,7 +254,7 @@ func TestInternalServices(t *testing.T) {
type proxyBuilderMock struct{}
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _ bool, _ time.Duration) (http.Handler, error) {
func (p proxyBuilderMock) Build(_ string, _ *url.URL, _, _, _ bool, _ time.Duration) (http.Handler, error) {
return http.HandlerFunc(func(responseWriter http.ResponseWriter, req *http.Request) {}), nil
}

View file

@ -42,7 +42,7 @@ const (
// ProxyBuilder builds reverse proxy handlers.
type ProxyBuilder interface {
Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader bool, flushInterval time.Duration) (http.Handler, error)
Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader, preservePath bool, flushInterval time.Duration) (http.Handler, error)
Update(configs map[string]*dynamic.ServersTransport)
}
@ -338,7 +338,7 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
qualifiedSvcName := provider.GetQualifiedName(ctx, serviceName)
shouldObserve := m.observabilityMgr.ShouldAddTracing(qualifiedSvcName) || m.observabilityMgr.ShouldAddMetrics(qualifiedSvcName)
proxy, err := m.proxyBuilder.Build(service.ServersTransport, target, shouldObserve, passHostHeader, flushInterval)
proxy, err := m.proxyBuilder.Build(service.ServersTransport, target, shouldObserve, passHostHeader, server.PreservePath, flushInterval)
if err != nil {
return nil, fmt.Errorf("error building proxy for server URL %s: %w", server.URL, err)
}