1
0
Fork 0

Merge branch v2.11 into v3.1

This commit is contained in:
kevinpollet 2024-09-16 16:24:08 +02:00
commit 093989fc14
No known key found for this signature in database
GPG key ID: 0C9A5DDD1B292453
66 changed files with 904 additions and 136 deletions

View file

@ -106,15 +106,28 @@ func NewHandler(config *types.AccessLog) (*Handler, error) {
Level: logrus.InfoLevel,
}
// Transform headers names in config to a canonical form, to be used as is without further transformations.
if config.Fields != nil && config.Fields.Headers != nil && len(config.Fields.Headers.Names) > 0 {
fields := map[string]string{}
// Transform header names to a canonical form, to be used as is without further transformations,
// and transform field names to lower case, to enable case-insensitive lookup.
if config.Fields != nil {
if len(config.Fields.Names) > 0 {
fields := map[string]string{}
for h, v := range config.Fields.Headers.Names {
fields[textproto.CanonicalMIMEHeaderKey(h)] = v
for h, v := range config.Fields.Names {
fields[strings.ToLower(h)] = v
}
config.Fields.Names = fields
}
config.Fields.Headers.Names = fields
if config.Fields.Headers != nil && len(config.Fields.Headers.Names) > 0 {
fields := map[string]string{}
for h, v := range config.Fields.Headers.Names {
fields[textproto.CanonicalMIMEHeaderKey(h)] = v
}
config.Fields.Headers.Names = fields
}
}
logHandler := &Handler{
@ -184,16 +197,6 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http
},
}
defer func() {
if h.config.BufferingSize > 0 {
h.logHandlerChan <- handlerParams{
logDataTable: logDataTable,
}
return
}
h.logTheRoundTrip(logDataTable)
}()
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
core[RequestCount] = nextRequestCount()
@ -238,19 +241,30 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http
return
}
defer func() {
logDataTable.DownstreamResponse = downstreamResponse{
headers: rw.Header().Clone(),
}
logDataTable.DownstreamResponse.status = capt.StatusCode()
logDataTable.DownstreamResponse.size = capt.ResponseSize()
logDataTable.Request.size = capt.RequestSize()
if _, ok := core[ClientUsername]; !ok {
core[ClientUsername] = usernameIfPresent(reqWithDataTable.URL)
}
if h.config.BufferingSize > 0 {
h.logHandlerChan <- handlerParams{
logDataTable: logDataTable,
}
return
}
h.logTheRoundTrip(logDataTable)
}()
next.ServeHTTP(rw, reqWithDataTable)
if _, ok := core[ClientUsername]; !ok {
core[ClientUsername] = usernameIfPresent(reqWithDataTable.URL)
}
logDataTable.DownstreamResponse = downstreamResponse{
headers: rw.Header().Clone(),
}
logDataTable.DownstreamResponse.status = capt.StatusCode()
logDataTable.DownstreamResponse.size = capt.ResponseSize()
logDataTable.Request.size = capt.RequestSize()
}
// Close closes the Logger (i.e. the file, drain logHandlerChan, etc).
@ -334,7 +348,7 @@ func (h *Handler) logTheRoundTrip(logDataTable *LogData) {
fields := logrus.Fields{}
for k, v := range logDataTable.Core {
if h.config.Fields.Keep(k) {
if h.config.Fields.Keep(strings.ToLower(k)) {
fields[k] = v
}
}

View file

@ -2,6 +2,7 @@ package accesslog
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
@ -24,6 +25,7 @@ import (
"github.com/stretchr/testify/require"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v3/pkg/middlewares/capture"
"github.com/traefik/traefik/v3/pkg/middlewares/recovery"
"github.com/traefik/traefik/v3/pkg/types"
)
@ -164,7 +166,7 @@ func TestLoggerHeaderFields(t *testing.T) {
},
},
{
desc: "with case insensitive match on header name",
desc: "with case-insensitive match on header name",
header: "User-Agent",
expected: types.AccessLogKeep,
accessLogFields: types.AccessLogFields{
@ -465,6 +467,32 @@ func TestLoggerJSON(t *testing.T) {
RequestRefererHeader: assertString(testReferer),
},
},
{
desc: "fields and headers with unconventional letter case",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: map[string]string{
"rEqUeStHoSt": "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "drop",
Names: map[string]string{
"ReFeReR": "keep",
},
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
RequestHost: assertString(testHostname),
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEmpty(),
RequestRefererHeader: assertString(testReferer),
},
},
}
for _, test := range testCases {
@ -496,6 +524,64 @@ func TestLoggerJSON(t *testing.T) {
}
}
func TestLogger_AbortedRequest(t *testing.T) {
expected := map[string]func(t *testing.T, value interface{}){
RequestContentSize: assertFloat64(0),
RequestHost: assertString(testHostname),
RequestAddr: assertString(testHostname),
RequestMethod: assertString(testMethod),
RequestPath: assertString(""),
RequestProtocol: assertString(testProto),
RequestScheme: assertString(testScheme),
RequestPort: assertString("-"),
DownstreamStatus: assertFloat64(float64(200)),
DownstreamContentSize: assertFloat64(float64(40)),
RequestRefererHeader: assertString(testReferer),
RequestUserAgentHeader: assertString(testUserAgent),
ServiceURL: assertString("http://stream"),
ServiceAddr: assertString("127.0.0.1"),
ServiceName: assertString("stream"),
ClientUsername: assertString(testUsername),
ClientHost: assertString(testHostname),
ClientPort: assertString(strconv.Itoa(testPort)),
ClientAddr: assertString(fmt.Sprintf("%s:%d", testHostname, testPort)),
"level": assertString("info"),
"msg": assertString(""),
RequestCount: assertFloat64NotZero(),
Duration: assertFloat64NotZero(),
Overhead: assertFloat64NotZero(),
RetryAttempts: assertFloat64(float64(0)),
"time": assertNotEmpty(),
StartLocal: assertNotEmpty(),
StartUTC: assertNotEmpty(),
"downstream_Content-Type": assertString("text/plain"),
"downstream_Transfer-Encoding": assertString("chunked"),
"downstream_Cache-Control": assertString("no-cache"),
}
config := &types.AccessLog{
FilePath: filepath.Join(t.TempDir(), logFileNameSuffix),
Format: JSONFormat,
}
doLoggingWithAbortedStream(t, config)
logData, err := os.ReadFile(config.FilePath)
require.NoError(t, err)
jsonData := make(map[string]interface{})
err = json.Unmarshal(logData, &jsonData)
require.NoError(t, err)
assert.Equal(t, len(expected), len(jsonData))
for field, assertion := range expected {
assertion(t, jsonData[field])
if t.Failed() {
return
}
}
}
func TestNewLogHandlerOutputStdout(t *testing.T) {
testCases := []struct {
desc string
@ -832,3 +918,89 @@ func logWriterTestHandlerFunc(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(testStatus)
}
func doLoggingWithAbortedStream(t *testing.T, config *types.AccessLog) {
t.Helper()
logger, err := NewHandler(config)
require.NoError(t, err)
t.Cleanup(func() {
err := logger.Close()
require.NoError(t, err)
})
if config.FilePath != "" {
_, err = os.Stat(config.FilePath)
require.NoError(t, err, "logger should create "+config.FilePath)
}
reqContext, cancelRequest := context.WithCancel(context.Background())
req := &http.Request{
Header: map[string][]string{
"User-Agent": {testUserAgent},
"Referer": {testReferer},
},
Proto: testProto,
Host: testHostname,
Method: testMethod,
RemoteAddr: fmt.Sprintf("%s:%d", testHostname, testPort),
URL: &url.URL{
User: url.UserPassword(testUsername, ""),
},
Body: nil,
}
req = req.WithContext(reqContext)
chain := alice.New()
chain = chain.Append(func(next http.Handler) (http.Handler, error) {
return recovery.New(context.Background(), next)
})
chain = chain.Append(capture.Wrap)
chain = chain.Append(WrapHandler(logger))
service := NewFieldHandler(http.HandlerFunc(streamBackend), ServiceURL, "http://stream", nil)
service = NewFieldHandler(service, ServiceAddr, "127.0.0.1", nil)
service = NewFieldHandler(service, ServiceName, "stream", AddServiceFields)
handler, err := chain.Then(service)
require.NoError(t, err)
go func() {
time.Sleep(499 * time.Millisecond)
cancelRequest()
}()
handler.ServeHTTP(httptest.NewRecorder(), req)
}
func streamBackend(rw http.ResponseWriter, r *http.Request) {
// Get the Flusher to flush the response to the client
flusher, ok := rw.(http.Flusher)
if !ok {
http.Error(rw, "Streaming unsupported!", http.StatusInternalServerError)
return
}
// Set the headers for streaming
rw.Header().Set("Content-Type", "text/plain")
rw.Header().Set("Transfer-Encoding", "chunked")
rw.Header().Set("Cache-Control", "no-cache")
for {
time.Sleep(100 * time.Millisecond)
select {
case <-r.Context().Done():
panic(http.ErrAbortHandler)
default:
if _, err := fmt.Fprint(rw, "FOOBAR!!!!"); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
flusher.Flush()
}
}
}

View file

@ -1,4 +1,4 @@
package connectionheader
package auth
import (
"net/http"

View file

@ -1,4 +1,4 @@
package connectionheader
package auth
import (
"net/http"

View file

@ -13,7 +13,6 @@ import (
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/connectionheader"
"github.com/traefik/traefik/v3/pkg/middlewares/observability"
"github.com/traefik/traefik/v3/pkg/tracing"
"github.com/traefik/traefik/v3/pkg/types"
@ -121,7 +120,7 @@ func (fa *forwardAuth) GetTracingInformation() (string, string, trace.SpanKind)
func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), fa.name, typeNameForward)
req = connectionheader.Remove(req)
req = Remove(req)
forwardReq, err := http.NewRequestWithContext(req.Context(), http.MethodGet, fa.address, nil)
if err != nil {

View file

@ -235,7 +235,7 @@ func (cc *codeCatcher) Flush() {
// since we want to serve the ones from the error page,
// so we just don't flush.
// (e.g., To prevent superfluous WriteHeader on request with a
// `Transfert-Encoding: chunked` header).
// `Transfer-Encoding: chunked` header).
if cc.caughtFilteredCode {
return
}

View file

@ -3,10 +3,13 @@ package forwardedheaders
import (
"net"
"net/http"
"net/textproto"
"os"
"slices"
"strings"
"github.com/traefik/traefik/v3/pkg/ip"
"golang.org/x/net/http/httpguts"
)
const (
@ -42,19 +45,20 @@ var xHeaders = []string{
// Unless insecure is set,
// it first removes all the existing values for those headers if the remote address is not one of the trusted ones.
type XForwarded struct {
insecure bool
trustedIps []string
ipChecker *ip.Checker
next http.Handler
hostname string
insecure bool
trustedIPs []string
connectionHeaders []string
ipChecker *ip.Checker
next http.Handler
hostname string
}
// NewXForwarded creates a new XForwarded.
func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XForwarded, error) {
func NewXForwarded(insecure bool, trustedIPs []string, connectionHeaders []string, next http.Handler) (*XForwarded, error) {
var ipChecker *ip.Checker
if len(trustedIps) > 0 {
if len(trustedIPs) > 0 {
var err error
ipChecker, err = ip.NewChecker(trustedIps)
ipChecker, err = ip.NewChecker(trustedIPs)
if err != nil {
return nil, err
}
@ -66,11 +70,12 @@ func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XFor
}
return &XForwarded{
insecure: insecure,
trustedIps: trustedIps,
ipChecker: ipChecker,
next: next,
hostname: hostname,
insecure: insecure,
trustedIPs: trustedIPs,
connectionHeaders: connectionHeaders,
ipChecker: ipChecker,
next: next,
hostname: hostname,
}, nil
}
@ -189,9 +194,53 @@ func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
x.rewrite(r)
x.removeConnectionHeaders(r)
x.next.ServeHTTP(w, r)
}
func (x *XForwarded) removeConnectionHeaders(req *http.Request) {
var reqUpType string
if httpguts.HeaderValuesContainsToken(req.Header[connection], upgrade) {
reqUpType = unsafeHeader(req.Header).Get(upgrade)
}
var connectionHopByHopHeaders []string
for _, f := range req.Header[connection] {
for _, sf := range strings.Split(f, ",") {
if sf = textproto.TrimString(sf); sf != "" {
// Connection header cannot dictate to remove X- headers managed by Traefik,
// as per rfc7230 https://datatracker.ietf.org/doc/html/rfc7230#section-6.1,
// A proxy or gateway MUST ... and then remove the Connection header field itself
// (or replace it with the intermediary's own connection options for the forwarded message).
if slices.Contains(xHeaders, sf) {
continue
}
// Keep headers allowed through the middleware chain.
if slices.Contains(x.connectionHeaders, sf) {
connectionHopByHopHeaders = append(connectionHopByHopHeaders, sf)
continue
}
// Apply Connection header option.
req.Header.Del(sf)
}
}
}
if reqUpType != "" {
connectionHopByHopHeaders = append(connectionHopByHopHeaders, upgrade)
unsafeHeader(req.Header).Set(upgrade, reqUpType)
}
if len(connectionHopByHopHeaders) > 0 {
unsafeHeader(req.Header).Set(connection, strings.Join(connectionHopByHopHeaders, ","))
return
}
unsafeHeader(req.Header).Del(connection)
}
// unsafeHeader allows to manage Header values.
// Must be used only when the header name is already a canonical key.
type unsafeHeader map[string][]string

View file

@ -12,15 +12,16 @@ import (
func TestServeHTTP(t *testing.T) {
testCases := []struct {
desc string
insecure bool
trustedIps []string
incomingHeaders map[string][]string
remoteAddr string
expectedHeaders map[string]string
tls bool
websocket bool
host string
desc string
insecure bool
trustedIps []string
connectionHeaders []string
incomingHeaders map[string][]string
remoteAddr string
expectedHeaders map[string]string
tls bool
websocket bool
host string
}{
{
desc: "all Empty",
@ -269,6 +270,196 @@ func TestServeHTTP(t *testing.T) {
xForwardedServer: "foo.com:8080",
},
},
{
desc: "Untrusted: Connection header has no effect on X- forwarded headers",
insecure: false,
incomingHeaders: map[string][]string{
connection: {
xForwardedProto,
xForwardedFor,
xForwardedURI,
xForwardedMethod,
xForwardedHost,
xForwardedPort,
xForwardedTLSClientCert,
xForwardedTLSClientCertInfo,
xRealIP,
},
xForwardedProto: {"foo"},
xForwardedFor: {"foo"},
xForwardedURI: {"foo"},
xForwardedMethod: {"foo"},
xForwardedHost: {"foo"},
xForwardedPort: {"foo"},
xForwardedTLSClientCert: {"foo"},
xForwardedTLSClientCertInfo: {"foo"},
xRealIP: {"foo"},
},
expectedHeaders: map[string]string{
xForwardedProto: "http",
xForwardedFor: "",
xForwardedURI: "",
xForwardedMethod: "",
xForwardedHost: "",
xForwardedPort: "80",
xForwardedTLSClientCert: "",
xForwardedTLSClientCertInfo: "",
xRealIP: "",
connection: "",
},
},
{
desc: "Trusted (insecure): Connection header has no effect on X- forwarded headers",
insecure: true,
incomingHeaders: map[string][]string{
connection: {
xForwardedProto,
xForwardedFor,
xForwardedURI,
xForwardedMethod,
xForwardedHost,
xForwardedPort,
xForwardedTLSClientCert,
xForwardedTLSClientCertInfo,
xRealIP,
},
xForwardedProto: {"foo"},
xForwardedFor: {"foo"},
xForwardedURI: {"foo"},
xForwardedMethod: {"foo"},
xForwardedHost: {"foo"},
xForwardedPort: {"foo"},
xForwardedTLSClientCert: {"foo"},
xForwardedTLSClientCertInfo: {"foo"},
xRealIP: {"foo"},
},
expectedHeaders: map[string]string{
xForwardedProto: "foo",
xForwardedFor: "foo",
xForwardedURI: "foo",
xForwardedMethod: "foo",
xForwardedHost: "foo",
xForwardedPort: "foo",
xForwardedTLSClientCert: "foo",
xForwardedTLSClientCertInfo: "foo",
xRealIP: "foo",
connection: "",
},
},
{
desc: "Untrusted and Connection: Connection header has no effect on X- forwarded headers",
insecure: false,
connectionHeaders: []string{
xForwardedProto,
xForwardedFor,
xForwardedURI,
xForwardedMethod,
xForwardedHost,
xForwardedPort,
xForwardedTLSClientCert,
xForwardedTLSClientCertInfo,
xRealIP,
},
incomingHeaders: map[string][]string{
connection: {
xForwardedProto,
xForwardedFor,
xForwardedURI,
xForwardedMethod,
xForwardedHost,
xForwardedPort,
xForwardedTLSClientCert,
xForwardedTLSClientCertInfo,
xRealIP,
},
xForwardedProto: {"foo"},
xForwardedFor: {"foo"},
xForwardedURI: {"foo"},
xForwardedMethod: {"foo"},
xForwardedHost: {"foo"},
xForwardedPort: {"foo"},
xForwardedTLSClientCert: {"foo"},
xForwardedTLSClientCertInfo: {"foo"},
xRealIP: {"foo"},
},
expectedHeaders: map[string]string{
xForwardedProto: "http",
xForwardedFor: "",
xForwardedURI: "",
xForwardedMethod: "",
xForwardedHost: "",
xForwardedPort: "80",
xForwardedTLSClientCert: "",
xForwardedTLSClientCertInfo: "",
xRealIP: "",
connection: "",
},
},
{
desc: "Trusted (insecure) and Connection: Connection header has no effect on X- forwarded headers",
insecure: true,
connectionHeaders: []string{
xForwardedProto,
xForwardedFor,
xForwardedURI,
xForwardedMethod,
xForwardedHost,
xForwardedPort,
xForwardedTLSClientCert,
xForwardedTLSClientCertInfo,
xRealIP,
},
incomingHeaders: map[string][]string{
connection: {
xForwardedProto,
xForwardedFor,
xForwardedURI,
xForwardedMethod,
xForwardedHost,
xForwardedPort,
xForwardedTLSClientCert,
xForwardedTLSClientCertInfo,
xRealIP,
},
xForwardedProto: {"foo"},
xForwardedFor: {"foo"},
xForwardedURI: {"foo"},
xForwardedMethod: {"foo"},
xForwardedHost: {"foo"},
xForwardedPort: {"foo"},
xForwardedTLSClientCert: {"foo"},
xForwardedTLSClientCertInfo: {"foo"},
xRealIP: {"foo"},
},
expectedHeaders: map[string]string{
xForwardedProto: "foo",
xForwardedFor: "foo",
xForwardedURI: "foo",
xForwardedMethod: "foo",
xForwardedHost: "foo",
xForwardedPort: "foo",
xForwardedTLSClientCert: "foo",
xForwardedTLSClientCertInfo: "foo",
xRealIP: "foo",
connection: "",
},
},
{
desc: "Connection: one remove, and one passthrough header",
connectionHeaders: []string{
"foo",
},
incomingHeaders: map[string][]string{
connection: {
"foo",
},
"Foo": {"bar"},
"Bar": {"foo"},
},
expectedHeaders: map[string]string{
"Bar": "foo",
},
},
}
for _, test := range testCases {
@ -299,7 +490,7 @@ func TestServeHTTP(t *testing.T) {
}
}
m, err := NewXForwarded(test.insecure, test.trustedIps,
m, err := NewXForwarded(test.insecure, test.trustedIps, test.connectionHeaders,
http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
require.NoError(t, err)
@ -382,3 +573,74 @@ func Test_isWebsocketRequest(t *testing.T) {
})
}
}
func TestConnection(t *testing.T) {
testCases := []struct {
desc string
reqHeaders map[string]string
connectionHeaders []string
expected http.Header
}{
{
desc: "simple remove",
reqHeaders: map[string]string{
"Foo": "bar",
connection: "foo",
},
expected: http.Header{},
},
{
desc: "remove and upgrade",
reqHeaders: map[string]string{
upgrade: "test",
"Foo": "bar",
connection: "upgrade,foo",
},
expected: http.Header{
upgrade: []string{"test"},
connection: []string{"Upgrade"},
},
},
{
desc: "no remove",
reqHeaders: map[string]string{
"Foo": "bar",
connection: "fii",
},
expected: http.Header{
"Foo": []string{"bar"},
},
},
{
desc: "no remove because connection header pass through",
reqHeaders: map[string]string{
"Foo": "bar",
connection: "Foo",
},
connectionHeaders: []string{"Foo"},
expected: http.Header{
"Foo": []string{"bar"},
connection: []string{"Foo"},
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
forwarded, err := NewXForwarded(true, nil, test.connectionHeaders, nil)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "https://localhost", nil)
for k, v := range test.reqHeaders {
req.Header.Set(k, v)
}
forwarded.removeConnectionHeaders(req)
assert.Equal(t, test.expected, req.Header)
})
}
}

View file

@ -8,7 +8,6 @@ import (
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/traefik/traefik/v3/pkg/middlewares/connectionheader"
"go.opentelemetry.io/otel/trace"
)
@ -46,12 +45,11 @@ func New(ctx context.Context, next http.Handler, cfg dynamic.Headers, name strin
if hasCustomHeaders || hasCorsHeaders {
logger.Debug().Msgf("Setting up customHeaders/Cors from %v", cfg)
h, err := NewHeader(nextHandler, cfg)
var err error
handler, err = NewHeader(nextHandler, cfg)
if err != nil {
return nil, err
}
handler = connectionheader.Remover(h)
}
return &headers{

View file

@ -149,7 +149,7 @@ func (rl *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
// We Set even in the case where the source already exists,
// because we want to update the expiryTime everytime we get the source,
// because we want to update the expiryTime every time we get the source,
// as the expiryTime is supposed to reflect the activity (or lack thereof) on that source.
if err := rl.buckets.Set(source, bucket, rl.ttl); err != nil {
logger.Error().Err(err).Msg("Could not insert/update bucket")