1
0
Fork 0

Merge v2.10 into v3.0

This commit is contained in:
romain 2023-11-29 12:20:57 +01:00
commit e29a142f6a
118 changed files with 1324 additions and 1290 deletions

View file

@ -14,6 +14,7 @@ import (
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"testing"
"time"
@ -323,7 +324,7 @@ func TestLoggerJSON(t *testing.T) {
ServiceURL: assertString(testServiceName),
ClientUsername: assertString(testUsername),
ClientHost: assertString(testHostname),
ClientPort: assertString(fmt.Sprintf("%d", testPort)),
ClientPort: assertString(strconv.Itoa(testPort)),
ClientAddr: assertString(fmt.Sprintf("%s:%d", testHostname, testPort)),
"level": assertString("info"),
"msg": assertString(""),
@ -363,7 +364,7 @@ func TestLoggerJSON(t *testing.T) {
ServiceURL: assertString(testServiceName),
ClientUsername: assertString(testUsername),
ClientHost: assertString(testHostname),
ClientPort: assertString(fmt.Sprintf("%d", testPort)),
ClientPort: assertString(strconv.Itoa(testPort)),
ClientAddr: assertString(fmt.Sprintf("%s:%d", testHostname, testPort)),
"level": assertString("info"),
"msg": assertString(""),

View file

@ -137,5 +137,5 @@ func (r *digestRequest) makeAuthorization(req *http.Request, parts map[string]st
func generateRandom(n int) string {
b := make([]byte, 8)
_, _ = io.ReadFull(rand.Reader, b)
return fmt.Sprintf("%x", b)[:n]
return hex.EncodeToString(b)[:n]
}

View file

@ -165,7 +165,7 @@ func runBenchmark(b *testing.B, size int, req *http.Request, handler http.Handle
b.Fatalf("Expected 200 but got %d", code)
}
assert.Equal(b, size, len(recorder.Body.String()))
assert.Len(b, recorder.Body.String(), size)
}
func generateBytes(length int) []byte {

View file

@ -376,13 +376,13 @@ func Test_ExcludedContentTypes(t *testing.T) {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get("Content-Encoding"))
got, err := io.ReadAll(rw.Body)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
}
})
@ -496,13 +496,13 @@ func Test_FlushExcludedContentTypes(t *testing.T) {
assert.Equal(t, "br", rw.Header().Get(contentEncoding))
got, err := io.ReadAll(brotli.NewReader(rw.Body))
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
} else {
assert.NotEqual(t, "br", rw.Header().Get(contentEncoding))
got, err := io.ReadAll(rw.Body)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, bigTestBody, got)
}
})

View file

@ -582,7 +582,7 @@ func Test1xxResponses(t *testing.T) {
req.Header.Add(acceptEncodingHeader, gzipValue)
res, err := frontendClient.Do(req)
assert.Nil(t, err)
assert.NoError(t, err)
defer res.Body.Close()

View file

@ -253,7 +253,7 @@ func Test1xxResponses(t *testing.T) {
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil)
res, err := frontendClient.Do(req)
assert.Nil(t, err)
assert.NoError(t, err)
defer res.Body.Close()

View file

@ -0,0 +1,64 @@
package denyrouterrecursion
import (
"errors"
"hash/fnv"
"net/http"
"strconv"
"github.com/containous/alice"
"github.com/rs/zerolog/log"
"github.com/traefik/traefik/v3/pkg/logs"
)
const xTraefikRouter = "X-Traefik-Router"
type DenyRouterRecursion struct {
routerName string
routerNameHash string
next http.Handler
}
// WrapHandler Wraps router to alice.Constructor.
func WrapHandler(routerName string) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
return New(routerName, next)
}
}
// New creates a new DenyRouterRecursion.
// DenyRouterRecursion middleware is an internal middleware used to avoid infinite requests loop on the same router.
func New(routerName string, next http.Handler) (*DenyRouterRecursion, error) {
if routerName == "" {
return nil, errors.New("routerName cannot be empty")
}
return &DenyRouterRecursion{
routerName: routerName,
routerNameHash: makeHash(routerName),
next: next,
}, nil
}
// ServeHTTP implements http.Handler.
func (l *DenyRouterRecursion) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.Header.Get(xTraefikRouter) == l.routerNameHash {
logger := log.With().Str(logs.MiddlewareType, "DenyRouterRecursion").Logger()
logger.Debug().Msgf("Rejecting request in provenance of the same router (%q) to stop potential infinite loop.", l.routerName)
rw.WriteHeader(http.StatusBadRequest)
return
}
req.Header.Set(xTraefikRouter, l.routerNameHash)
l.next.ServeHTTP(rw, req)
}
func makeHash(routerName string) string {
hasher := fnv.New64()
// purposely ignoring the error, as no error can be returned from the implementation.
_, _ = hasher.Write([]byte(routerName))
return strconv.FormatUint(hasher.Sum64(), 16)
}

View file

@ -0,0 +1,38 @@
package denyrouterrecursion
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServeHTTP(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)
_, err = New("", http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
require.Error(t, err)
next := 0
m, err := New("myRouter", http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
next++
}))
require.NoError(t, err)
recorder := httptest.NewRecorder()
m.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusOK, recorder.Code)
assert.Equal(t, "995d26092d19a224", m.routerNameHash)
assert.Equal(t, m.routerNameHash, req.Header.Get(xTraefikRouter))
assert.Equal(t, 1, next)
recorder = httptest.NewRecorder()
m.ServeHTTP(recorder, req)
assert.Equal(t, 1, next)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
}

View file

@ -182,7 +182,7 @@ func Test1xxResponses(t *testing.T) {
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil)
res, err := frontendClient.Do(req)
assert.Nil(t, err)
assert.NoError(t, err)
defer res.Body.Close()

View file

@ -18,6 +18,8 @@ import (
"golang.org/x/time/rate"
)
const delta float64 = 1e-10
func TestNewRateLimiter(t *testing.T) {
testCases := []struct {
desc string
@ -131,7 +133,7 @@ func TestNewRateLimiter(t *testing.T) {
assert.Equal(t, test.requestHeader, hd)
}
if test.expectedRTL != 0 {
assert.Equal(t, test.expectedRTL, rtl.rate)
assert.InDelta(t, float64(test.expectedRTL), float64(rtl.rate), delta)
}
})
}

View file

@ -94,7 +94,7 @@ func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
newCtx := httptrace.WithClientTrace(req.Context(), trace)
r.next.ServeHTTP(retryResponseWriter, req.WithContext(newCtx))
r.next.ServeHTTP(retryResponseWriter, req.Clone(newCtx))
if !retryResponseWriter.ShouldRetry() {
return nil

View file

@ -373,7 +373,7 @@ func Test1xxResponses(t *testing.T) {
req, _ := http.NewRequestWithContext(httptrace.WithClientTrace(context.Background(), trace), http.MethodGet, server.URL, nil)
res, err := frontendClient.Do(req)
assert.Nil(t, err)
assert.NoError(t, err)
defer res.Body.Close()

View file

@ -123,7 +123,7 @@ func TestNewForwarder(t *testing.T) {
tags := span.Tags
assert.Equal(t, test.expected.Tags, tags)
assert.True(t, len(test.expected.OperationName) <= test.spanNameLimit,
assert.LessOrEqual(t, len(test.expected.OperationName), test.spanNameLimit,
"the len of the operation name %q [len: %d] doesn't respect limit %d",
test.expected.OperationName, len(test.expected.OperationName), test.spanNameLimit)
assert.Equal(t, test.expected.OperationName, span.OpName)