1
0
Fork 0

Move code to pkg

This commit is contained in:
Ludovic Fernandez 2019-03-15 09:42:03 +01:00 committed by Traefiker Bot
parent bd4c822670
commit f1b085fa36
465 changed files with 656 additions and 680 deletions

View file

@ -0,0 +1,18 @@
package accesslog
import "io"
type captureRequestReader struct {
source io.ReadCloser
count int64
}
func (r *captureRequestReader) Read(p []byte) (int, error) {
n, err := r.source.Read(p)
r.count += int64(n)
return n, err
}
func (r *captureRequestReader) Close() error {
return r.source.Close()
}

View file

@ -0,0 +1,68 @@
package accesslog
import (
"bufio"
"fmt"
"net"
"net/http"
"github.com/containous/traefik/old/middlewares"
)
var (
_ middlewares.Stateful = &captureResponseWriter{}
)
// captureResponseWriter is a wrapper of type http.ResponseWriter
// that tracks request status and size
type captureResponseWriter struct {
rw http.ResponseWriter
status int
size int64
}
func (crw *captureResponseWriter) Header() http.Header {
return crw.rw.Header()
}
func (crw *captureResponseWriter) Write(b []byte) (int, error) {
if crw.status == 0 {
crw.status = http.StatusOK
}
size, err := crw.rw.Write(b)
crw.size += int64(size)
return size, err
}
func (crw *captureResponseWriter) WriteHeader(s int) {
crw.rw.WriteHeader(s)
crw.status = s
}
func (crw *captureResponseWriter) Flush() {
if f, ok := crw.rw.(http.Flusher); ok {
f.Flush()
}
}
func (crw *captureResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := crw.rw.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("not a hijacker: %T", crw.rw)
}
func (crw *captureResponseWriter) CloseNotify() <-chan bool {
if c, ok := crw.rw.(http.CloseNotifier); ok {
return c.CloseNotify()
}
return nil
}
func (crw *captureResponseWriter) Status() int {
return crw.status
}
func (crw *captureResponseWriter) Size() int64 {
return crw.size
}

View file

@ -0,0 +1,65 @@
package accesslog
import (
"net/http"
"time"
"github.com/vulcand/oxy/utils"
)
// FieldApply function hook to add data in accesslog
type FieldApply func(rw http.ResponseWriter, r *http.Request, next http.Handler, data *LogData)
// FieldHandler sends a new field to the logger.
type FieldHandler struct {
next http.Handler
name string
value string
applyFn FieldApply
}
// NewFieldHandler creates a Field handler.
func NewFieldHandler(next http.Handler, name string, value string, applyFn FieldApply) http.Handler {
return &FieldHandler{next: next, name: name, value: value, applyFn: applyFn}
}
func (f *FieldHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
table := GetLogData(req)
if table == nil {
f.next.ServeHTTP(rw, req)
return
}
table.Core[f.name] = f.value
if f.applyFn != nil {
f.applyFn(rw, req, f.next, table)
} else {
f.next.ServeHTTP(rw, req)
}
}
// AddServiceFields add service fields
func AddServiceFields(rw http.ResponseWriter, req *http.Request, next http.Handler, data *LogData) {
data.Core[ServiceURL] = req.URL // note that this is *not* the original incoming URL
data.Core[ServiceAddr] = req.URL.Host
next.ServeHTTP(rw, req)
}
// AddOriginFields add origin fields
func AddOriginFields(rw http.ResponseWriter, req *http.Request, next http.Handler, data *LogData) {
crw := &captureResponseWriter{rw: rw}
start := time.Now().UTC()
next.ServeHTTP(crw, req)
// use UTC to handle switchover of daylight saving correctly
data.Core[OriginDuration] = time.Now().UTC().Sub(start)
data.Core[OriginStatus] = crw.Status()
// make copy of headers so we can ensure there is no subsequent mutation during response processing
data.OriginResponse = make(http.Header)
utils.CopyHeaders(data.OriginResponse, crw.Header())
data.Core[OriginContentSize] = crw.Size()
}

View file

@ -0,0 +1,122 @@
package accesslog
import (
"net/http"
)
const (
// StartUTC is the map key used for the time at which request processing started.
StartUTC = "StartUTC"
// StartLocal is the map key used for the local time at which request processing started.
StartLocal = "StartLocal"
// Duration is the map key used for the total time taken by processing the response, including the origin server's time but
// not the log writing time.
Duration = "Duration"
// RouterName is the map key used for the name of the Traefik router.
RouterName = "RouterName"
// ServiceName is the map key used for the name of the Traefik backend.
ServiceName = "ServiceName"
// ServiceURL is the map key used for the URL of the Traefik backend.
ServiceURL = "ServiceURL"
// ServiceAddr is the map key used for the IP:port of the Traefik backend (extracted from BackendURL)
ServiceAddr = "ServiceAddr"
// ClientAddr is the map key used for the remote address in its original form (usually IP:port).
ClientAddr = "ClientAddr"
// ClientHost is the map key used for the remote IP address from which the client request was received.
ClientHost = "ClientHost"
// ClientPort is the map key used for the remote TCP port from which the client request was received.
ClientPort = "ClientPort"
// ClientUsername is the map key used for the username provided in the URL, if present.
ClientUsername = "ClientUsername"
// RequestAddr is the map key used for the HTTP Host header (usually IP:port). This is treated as not a header by the Go API.
RequestAddr = "RequestAddr"
// RequestHost is the map key used for the HTTP Host server name (not including port).
RequestHost = "RequestHost"
// RequestPort is the map key used for the TCP port from the HTTP Host.
RequestPort = "RequestPort"
// RequestMethod is the map key used for the HTTP method.
RequestMethod = "RequestMethod"
// RequestPath is the map key used for the HTTP request URI, not including the scheme, host or port.
RequestPath = "RequestPath"
// RequestProtocol is the map key used for the version of HTTP requested.
RequestProtocol = "RequestProtocol"
// RequestContentSize is the map key used for the number of bytes in the request entity (a.k.a. body) sent by the client.
RequestContentSize = "RequestContentSize"
// RequestRefererHeader is the Referer header in the request
RequestRefererHeader = "request_Referer"
// RequestUserAgentHeader is the User-Agent header in the request
RequestUserAgentHeader = "request_User-Agent"
// OriginDuration is the map key used for the time taken by the origin server ('upstream') to return its response.
OriginDuration = "OriginDuration"
// OriginContentSize is the map key used for the content length specified by the origin server, or 0 if unspecified.
OriginContentSize = "OriginContentSize"
// OriginStatus is the map key used for the HTTP status code returned by the origin server.
// If the request was handled by this Traefik instance (e.g. with a redirect), then this value will be absent.
OriginStatus = "OriginStatus"
// DownstreamStatus is the map key used for the HTTP status code returned to the client.
DownstreamStatus = "DownstreamStatus"
// DownstreamContentSize is the map key used for the number of bytes in the response entity returned to the client.
// This is in addition to the "Content-Length" header, which may be present in the origin response.
DownstreamContentSize = "DownstreamContentSize"
// RequestCount is the map key used for the number of requests received since the Traefik instance started.
RequestCount = "RequestCount"
// GzipRatio is the map key used for the response body compression ratio achieved.
GzipRatio = "GzipRatio"
// Overhead is the map key used for the processing time overhead caused by Traefik.
Overhead = "Overhead"
// RetryAttempts is the map key used for the amount of attempts the request was retried.
RetryAttempts = "RetryAttempts"
)
// These are written out in the default case when no config is provided to specify keys of interest.
var defaultCoreKeys = [...]string{
StartUTC,
Duration,
RouterName,
ServiceName,
ServiceURL,
ClientHost,
ClientPort,
ClientUsername,
RequestHost,
RequestPort,
RequestMethod,
RequestPath,
RequestProtocol,
RequestContentSize,
OriginDuration,
OriginContentSize,
OriginStatus,
DownstreamStatus,
DownstreamContentSize,
RequestCount,
}
// This contains the set of all keys, i.e. all the default keys plus all non-default keys.
var allCoreKeys = make(map[string]struct{})
func init() {
for _, k := range defaultCoreKeys {
allCoreKeys[k] = struct{}{}
}
allCoreKeys[ServiceAddr] = struct{}{}
allCoreKeys[ClientAddr] = struct{}{}
allCoreKeys[RequestAddr] = struct{}{}
allCoreKeys[GzipRatio] = struct{}{}
allCoreKeys[StartLocal] = struct{}{}
allCoreKeys[Overhead] = struct{}{}
allCoreKeys[RetryAttempts] = struct{}{}
}
// CoreLogData holds the fields computed from the request/response.
type CoreLogData map[string]interface{}
// LogData is the data captured by the middleware so that it can be logged.
type LogData struct {
Core CoreLogData
Request http.Header
OriginResponse http.Header
DownstreamResponse http.Header
}

View file

@ -0,0 +1,344 @@
package accesslog
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
"github.com/containous/alice"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/types"
"github.com/sirupsen/logrus"
)
type key string
const (
// DataTableKey is the key within the request context used to store the Log Data Table.
DataTableKey key = "LogDataTable"
// CommonFormat is the common logging format (CLF).
CommonFormat string = "common"
// JSONFormat is the JSON logging format.
JSONFormat string = "json"
)
type handlerParams struct {
logDataTable *LogData
crr *captureRequestReader
crw *captureResponseWriter
}
// Handler will write each request and its response to the access log.
type Handler struct {
config *types.AccessLog
logger *logrus.Logger
file *os.File
mu sync.Mutex
httpCodeRanges types.HTTPCodeRanges
logHandlerChan chan handlerParams
wg sync.WaitGroup
}
// WrapHandler Wraps access log handler into an Alice Constructor.
func WrapHandler(handler *Handler) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(rw, req, next.ServeHTTP)
}), nil
}
}
// NewHandler creates a new Handler.
func NewHandler(config *types.AccessLog) (*Handler, error) {
file := os.Stdout
if len(config.FilePath) > 0 {
f, err := openAccessLogFile(config.FilePath)
if err != nil {
return nil, fmt.Errorf("error opening access log file: %s", err)
}
file = f
}
logHandlerChan := make(chan handlerParams, config.BufferingSize)
var formatter logrus.Formatter
switch config.Format {
case CommonFormat:
formatter = new(CommonLogFormatter)
case JSONFormat:
formatter = new(logrus.JSONFormatter)
default:
return nil, fmt.Errorf("unsupported access log format: %s", config.Format)
}
logger := &logrus.Logger{
Out: file,
Formatter: formatter,
Hooks: make(logrus.LevelHooks),
Level: logrus.InfoLevel,
}
logHandler := &Handler{
config: config,
logger: logger,
file: file,
logHandlerChan: logHandlerChan,
}
if config.Filters != nil {
if httpCodeRanges, err := types.NewHTTPCodeRanges(config.Filters.StatusCodes); err != nil {
log.WithoutContext().Errorf("Failed to create new HTTP code ranges: %s", err)
} else {
logHandler.httpCodeRanges = httpCodeRanges
}
}
if config.BufferingSize > 0 {
logHandler.wg.Add(1)
go func() {
defer logHandler.wg.Done()
for handlerParams := range logHandler.logHandlerChan {
logHandler.logTheRoundTrip(handlerParams.logDataTable, handlerParams.crr, handlerParams.crw)
}
}()
}
return logHandler, nil
}
func openAccessLogFile(filePath string) (*os.File, error) {
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create log path %s: %s", dir, err)
}
file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
if err != nil {
return nil, fmt.Errorf("error opening file %s: %s", filePath, err)
}
return file, nil
}
// GetLogData gets the request context object that contains logging data.
// This creates data as the request passes through the middleware chain.
func GetLogData(req *http.Request) *LogData {
if ld, ok := req.Context().Value(DataTableKey).(*LogData); ok {
return ld
}
return nil
}
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
now := time.Now().UTC()
core := CoreLogData{
StartUTC: now,
StartLocal: now.Local(),
}
logDataTable := &LogData{Core: core, Request: req.Header}
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
var crr *captureRequestReader
if req.Body != nil {
crr = &captureRequestReader{source: req.Body, count: 0}
reqWithDataTable.Body = crr
}
core[RequestCount] = nextRequestCount()
if req.Host != "" {
core[RequestAddr] = req.Host
core[RequestHost], core[RequestPort] = silentSplitHostPort(req.Host)
}
// copy the URL without the scheme, hostname etc
urlCopy := &url.URL{
Path: req.URL.Path,
RawPath: req.URL.RawPath,
RawQuery: req.URL.RawQuery,
ForceQuery: req.URL.ForceQuery,
Fragment: req.URL.Fragment,
}
urlCopyString := urlCopy.String()
core[RequestMethod] = req.Method
core[RequestPath] = urlCopyString
core[RequestProtocol] = req.Proto
core[ClientAddr] = req.RemoteAddr
core[ClientHost], core[ClientPort] = silentSplitHostPort(req.RemoteAddr)
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
core[ClientHost] = forwardedFor
}
crw := &captureResponseWriter{rw: rw}
next.ServeHTTP(crw, reqWithDataTable)
if _, ok := core[ClientUsername]; !ok {
core[ClientUsername] = usernameIfPresent(reqWithDataTable.URL)
}
logDataTable.DownstreamResponse = crw.Header()
if h.config.BufferingSize > 0 {
h.logHandlerChan <- handlerParams{
logDataTable: logDataTable,
crr: crr,
crw: crw,
}
} else {
h.logTheRoundTrip(logDataTable, crr, crw)
}
}
// Close closes the Logger (i.e. the file, drain logHandlerChan, etc).
func (h *Handler) Close() error {
close(h.logHandlerChan)
h.wg.Wait()
return h.file.Close()
}
// Rotate closes and reopens the log file to allow for rotation by an external source.
func (h *Handler) Rotate() error {
var err error
if h.file != nil {
defer func(f *os.File) {
f.Close()
}(h.file)
}
h.file, err = os.OpenFile(h.config.FilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
if err != nil {
return err
}
h.mu.Lock()
defer h.mu.Unlock()
h.logger.Out = h.file
return nil
}
func silentSplitHostPort(value string) (host string, port string) {
host, port, err := net.SplitHostPort(value)
if err != nil {
return value, "-"
}
return host, port
}
func usernameIfPresent(theURL *url.URL) string {
if theURL.User != nil {
if name := theURL.User.Username(); name != "" {
return name
}
}
return "-"
}
// Logging handler to log frontend name, backend name, and elapsed time.
func (h *Handler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestReader, crw *captureResponseWriter) {
core := logDataTable.Core
retryAttempts, ok := core[RetryAttempts].(int)
if !ok {
retryAttempts = 0
}
core[RetryAttempts] = retryAttempts
if crr != nil {
core[RequestContentSize] = crr.count
}
core[DownstreamStatus] = crw.Status()
// n.b. take care to perform time arithmetic using UTC to avoid errors at DST boundaries.
totalDuration := time.Now().UTC().Sub(core[StartUTC].(time.Time))
core[Duration] = totalDuration
if h.keepAccessLog(crw.Status(), retryAttempts, totalDuration) {
core[DownstreamContentSize] = crw.Size()
if original, ok := core[OriginContentSize]; ok {
o64 := original.(int64)
if crw.Size() != o64 && crw.Size() != 0 {
core[GzipRatio] = float64(o64) / float64(crw.Size())
}
}
core[Overhead] = totalDuration
if origin, ok := core[OriginDuration]; ok {
core[Overhead] = totalDuration - origin.(time.Duration)
}
fields := logrus.Fields{}
for k, v := range logDataTable.Core {
if h.config.Fields.Keep(k) {
fields[k] = v
}
}
h.redactHeaders(logDataTable.Request, fields, "request_")
h.redactHeaders(logDataTable.OriginResponse, fields, "origin_")
h.redactHeaders(logDataTable.DownstreamResponse, fields, "downstream_")
h.mu.Lock()
defer h.mu.Unlock()
h.logger.WithFields(fields).Println()
}
}
func (h *Handler) redactHeaders(headers http.Header, fields logrus.Fields, prefix string) {
for k := range headers {
v := h.config.Fields.KeepHeader(k)
if v == types.AccessLogKeep {
fields[prefix+k] = headers.Get(k)
} else if v == types.AccessLogRedact {
fields[prefix+k] = "REDACTED"
}
}
}
func (h *Handler) keepAccessLog(statusCode, retryAttempts int, duration time.Duration) bool {
if h.config.Filters == nil {
// no filters were specified
return true
}
if len(h.httpCodeRanges) == 0 && !h.config.Filters.RetryAttempts && h.config.Filters.MinDuration == 0 {
// empty filters were specified, e.g. by passing --accessLog.filters only (without other filter options)
return true
}
if h.httpCodeRanges.Contains(statusCode) {
return true
}
if h.config.Filters.RetryAttempts && retryAttempts > 0 {
return true
}
if h.config.Filters.MinDuration > 0 && (parse.Duration(duration) > h.config.Filters.MinDuration) {
return true
}
return false
}
var requestCounter uint64 // Request ID
func nextRequestCount() uint64 {
return atomic.AddUint64(&requestCounter, 1)
}

View file

@ -0,0 +1,83 @@
package accesslog
import (
"bytes"
"fmt"
"time"
"github.com/sirupsen/logrus"
)
// default format for time presentation.
const (
commonLogTimeFormat = "02/Jan/2006:15:04:05 -0700"
defaultValue = "-"
)
// CommonLogFormatter provides formatting in the Traefik common log format.
type CommonLogFormatter struct{}
// Format formats the log entry in the Traefik common log format.
func (f *CommonLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
b := &bytes.Buffer{}
var timestamp = defaultValue
if v, ok := entry.Data[StartUTC]; ok {
timestamp = v.(time.Time).Format(commonLogTimeFormat)
}
var elapsedMillis int64
if v, ok := entry.Data[Duration]; ok {
elapsedMillis = v.(time.Duration).Nanoseconds() / 1000000
}
_, err := fmt.Fprintf(b, "%s - %s [%s] \"%s %s %s\" %v %v %s %s %v %s %s %dms\n",
toLog(entry.Data, ClientHost, defaultValue, false),
toLog(entry.Data, ClientUsername, defaultValue, false),
timestamp,
toLog(entry.Data, RequestMethod, defaultValue, false),
toLog(entry.Data, RequestPath, defaultValue, false),
toLog(entry.Data, RequestProtocol, defaultValue, false),
toLog(entry.Data, OriginStatus, defaultValue, true),
toLog(entry.Data, OriginContentSize, defaultValue, true),
toLog(entry.Data, "request_Referer", `"-"`, true),
toLog(entry.Data, "request_User-Agent", `"-"`, true),
toLog(entry.Data, RequestCount, defaultValue, true),
toLog(entry.Data, RouterName, defaultValue, true),
toLog(entry.Data, ServiceURL, defaultValue, true),
elapsedMillis)
return b.Bytes(), err
}
func toLog(fields logrus.Fields, key string, defaultValue string, quoted bool) interface{} {
if v, ok := fields[key]; ok {
if v == nil {
return defaultValue
}
switch s := v.(type) {
case string:
return toLogEntry(s, defaultValue, quoted)
case fmt.Stringer:
return toLogEntry(s.String(), defaultValue, quoted)
default:
return v
}
}
return defaultValue
}
func toLogEntry(s string, defaultValue string, quote bool) string {
if len(s) == 0 {
return defaultValue
}
if quote {
return `"` + s + `"`
}
return s
}

View file

@ -0,0 +1,140 @@
package accesslog
import (
"net/http"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func TestCommonLogFormatter_Format(t *testing.T) {
clf := CommonLogFormatter{}
testCases := []struct {
name string
data map[string]interface{}
expectedLog string
}{
{
name: "OriginStatus & OriginContentSize are nil",
data: map[string]interface{}{
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
Duration: 123 * time.Second,
ClientHost: "10.0.0.1",
ClientUsername: "Client",
RequestMethod: http.MethodGet,
RequestPath: "/foo",
RequestProtocol: "http",
OriginStatus: nil,
OriginContentSize: nil,
RequestRefererHeader: "",
RequestUserAgentHeader: "",
RequestCount: 0,
RouterName: "",
ServiceURL: "",
},
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" - - "-" "-" 0 - - 123000ms
`,
},
{
name: "all data",
data: map[string]interface{}{
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
Duration: 123 * time.Second,
ClientHost: "10.0.0.1",
ClientUsername: "Client",
RequestMethod: http.MethodGet,
RequestPath: "/foo",
RequestProtocol: "http",
OriginStatus: 123,
OriginContentSize: 132,
RequestRefererHeader: "referer",
RequestUserAgentHeader: "agent",
RequestCount: nil,
RouterName: "foo",
ServiceURL: "http://10.0.0.2/toto",
},
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" 123 132 "referer" "agent" - "foo" "http://10.0.0.2/toto" 123000ms
`,
},
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
entry := &logrus.Entry{Data: test.data}
raw, err := clf.Format(entry)
assert.NoError(t, err)
assert.Equal(t, test.expectedLog, string(raw))
})
}
}
func Test_toLog(t *testing.T) {
testCases := []struct {
desc string
fields logrus.Fields
fieldName string
defaultValue string
quoted bool
expectedLog interface{}
}{
{
desc: "Should return int 1",
fields: logrus.Fields{
"Powpow": 1,
},
fieldName: "Powpow",
defaultValue: defaultValue,
quoted: false,
expectedLog: 1,
},
{
desc: "Should return string foo",
fields: logrus.Fields{
"Powpow": "foo",
},
fieldName: "Powpow",
defaultValue: defaultValue,
quoted: true,
expectedLog: `"foo"`,
},
{
desc: "Should return defaultValue if fieldName does not exist",
fields: logrus.Fields{
"Powpow": "foo",
},
fieldName: "",
defaultValue: defaultValue,
quoted: false,
expectedLog: "-",
},
{
desc: "Should return defaultValue if fields is nil",
fields: nil,
fieldName: "",
defaultValue: defaultValue,
quoted: false,
expectedLog: "-",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
lg := toLog(test.fields, test.fieldName, defaultValue, test.quoted)
assert.Equal(t, test.expectedLog, lg)
})
}
}

View file

@ -0,0 +1,649 @@
package accesslog
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
logFileNameSuffix = "/traefik/logger/test.log"
testContent = "Hello, World"
testServiceName = "http://127.0.0.1/testService"
testRouterName = "testRouter"
testStatus = 123
testContentSize int64 = 12
testHostname = "TestHost"
testUsername = "TestUser"
testPath = "testpath"
testPort = 8181
testProto = "HTTP/0.0"
testMethod = http.MethodPost
testReferer = "testReferer"
testUserAgent = "testUserAgent"
testRetryAttempts = 2
testStart = time.Now()
)
func TestLogRotation(t *testing.T) {
tempDir, err := ioutil.TempDir("", "traefik_")
if err != nil {
t.Fatalf("Error setting up temporary directory: %s", err)
}
fileName := tempDir + "traefik.log"
rotatedFileName := fileName + ".rotated"
config := &types.AccessLog{FilePath: fileName, Format: CommonFormat}
logHandler, err := NewHandler(config)
if err != nil {
t.Fatalf("Error creating new log handler: %s", err)
}
defer logHandler.Close()
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
next := func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}
iterations := 20
halfDone := make(chan bool)
writeDone := make(chan bool)
go func() {
for i := 0; i < iterations; i++ {
logHandler.ServeHTTP(recorder, req, next)
if i == iterations/2 {
halfDone <- true
}
}
writeDone <- true
}()
<-halfDone
err = os.Rename(fileName, rotatedFileName)
if err != nil {
t.Fatalf("Error renaming file: %s", err)
}
err = logHandler.Rotate()
if err != nil {
t.Fatalf("Error rotating file: %s", err)
}
select {
case <-writeDone:
gotLineCount := lineCount(t, fileName) + lineCount(t, rotatedFileName)
if iterations != gotLineCount {
t.Errorf("Wanted %d written log lines, got %d", iterations, gotLineCount)
}
case <-time.After(500 * time.Millisecond):
t.Fatalf("test timed out")
}
close(halfDone)
close(writeDone)
}
func lineCount(t *testing.T, fileName string) int {
t.Helper()
fileContents, err := ioutil.ReadFile(fileName)
if err != nil {
t.Fatalf("Error reading from file %s: %s", fileName, err)
}
count := 0
for _, line := range strings.Split(string(fileContents), "\n") {
if strings.TrimSpace(line) == "" {
continue
}
count++
}
return count
}
func TestLoggerCLF(t *testing.T) {
tmpDir := createTempDir(t, CommonFormat)
defer os.RemoveAll(tmpDir)
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat}
doLogging(t, config)
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
assertValidLogData(t, expectedLog, logData)
}
func TestAsyncLoggerCLF(t *testing.T) {
tmpDir := createTempDir(t, CommonFormat)
defer os.RemoveAll(tmpDir)
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat, BufferingSize: 1024}
doLogging(t, config)
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
assertValidLogData(t, expectedLog, logData)
}
func assertString(exp string) func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.Equal(t, exp, actual)
}
}
func assertNotEmpty() func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.NotEqual(t, "", actual)
}
}
func assertFloat64(exp float64) func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.Equal(t, exp, actual)
}
}
func assertFloat64NotZero() func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.NotZero(t, actual)
}
}
func TestLoggerJSON(t *testing.T) {
testCases := []struct {
desc string
config *types.AccessLog
expected map[string]func(t *testing.T, value interface{})
}{
{
desc: "default config",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
},
expected: map[string]func(t *testing.T, value interface{}){
RequestHost: assertString(testHostname),
RequestAddr: assertString(testHostname),
RequestMethod: assertString(testMethod),
RequestPath: assertString(testPath),
RequestProtocol: assertString(testProto),
RequestPort: assertString("-"),
DownstreamStatus: assertFloat64(float64(testStatus)),
DownstreamContentSize: assertFloat64(float64(len(testContent))),
OriginContentSize: assertFloat64(float64(len(testContent))),
OriginStatus: assertFloat64(float64(testStatus)),
RequestRefererHeader: assertString(testReferer),
RequestUserAgentHeader: assertString(testUserAgent),
RouterName: assertString(testRouterName),
ServiceURL: assertString(testServiceName),
ClientUsername: assertString(testUsername),
ClientHost: assertString(testHostname),
ClientPort: assertString(fmt.Sprintf("%d", testPort)),
ClientAddr: assertString(fmt.Sprintf("%s:%d", testHostname, testPort)),
"level": assertString("info"),
"msg": assertString(""),
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
RequestCount: assertFloat64NotZero(),
Duration: assertFloat64NotZero(),
Overhead: assertFloat64NotZero(),
RetryAttempts: assertFloat64(float64(testRetryAttempts)),
"time": assertNotEmpty(),
"StartLocal": assertNotEmpty(),
"StartUTC": assertNotEmpty(),
},
},
{
desc: "default config drop all fields",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
},
},
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEmpty(),
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
RequestRefererHeader: assertString(testReferer),
RequestUserAgentHeader: assertString(testUserAgent),
},
},
{
desc: "default config drop all fields and headers",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Headers: &types.FieldHeaders{
DefaultMode: "drop",
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEmpty(),
},
},
{
desc: "default config drop all fields and redact headers",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Headers: &types.FieldHeaders{
DefaultMode: "redact",
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEmpty(),
"downstream_Content-Type": assertString("REDACTED"),
RequestRefererHeader: assertString("REDACTED"),
RequestUserAgentHeader: assertString("REDACTED"),
},
},
{
desc: "default config drop all fields and headers but kept someone",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
RequestHost: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "drop",
Names: types.FieldHeaderNames{
"Referer": "keep",
},
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
RequestHost: assertString(testHostname),
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEmpty(),
RequestRefererHeader: assertString(testReferer),
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
tmpDir := createTempDir(t, JSONFormat)
defer os.RemoveAll(tmpDir)
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
test.config.FilePath = logFilePath
doLogging(t, test.config)
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
jsonData := make(map[string]interface{})
err = json.Unmarshal(logData, &jsonData)
require.NoError(t, err)
assert.Equal(t, len(test.expected), len(jsonData))
for field, assertion := range test.expected {
assertion(t, jsonData[field])
}
})
}
}
func TestNewLogHandlerOutputStdout(t *testing.T) {
testCases := []struct {
desc string
config *types.AccessLog
expectedLog string
}{
{
desc: "default config",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "default config with empty filters",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Status code filter not matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
StatusCodes: []string{"200"},
},
},
expectedLog: ``,
},
{
desc: "Status code filter matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
StatusCodes: []string{"123"},
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Duration filter not matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
MinDuration: parse.Duration(1 * time.Hour),
},
},
expectedLog: ``,
},
{
desc: "Duration filter matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
MinDuration: parse.Duration(1 * time.Millisecond),
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Retry attempts filter matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
RetryAttempts: true,
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Default mode keep",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "keep",
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Default mode keep with override",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "keep",
Names: types.FieldNames{
ClientHost: "drop",
},
},
},
expectedLog: `- - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Default mode drop",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
},
},
expectedLog: `- - - [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
},
{
desc: "Default mode drop with override",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
},
{
desc: "Default mode drop with header dropped",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "drop",
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "-" "-" - - - 0ms`,
},
{
desc: "Default mode drop with header redacted",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "redact",
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "REDACTED" - - - 0ms`,
},
{
desc: "Default mode drop with header redacted",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "keep",
Names: types.FieldHeaderNames{
"Referer": "redact",
},
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "testUserAgent" - - - 0ms`,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
// NOTE: It is not possible to run these cases in parallel because we capture Stdout
file, restoreStdout := captureStdout(t)
defer restoreStdout()
doLogging(t, test.config)
written, err := ioutil.ReadFile(file.Name())
require.NoError(t, err, "unable to read captured stdout from file")
assertValidLogData(t, test.expectedLog, written)
})
}
}
func assertValidLogData(t *testing.T, expected string, logData []byte) {
if len(expected) == 0 {
assert.Zero(t, len(logData))
t.Log(string(logData))
return
}
result, err := ParseAccessLog(string(logData))
require.NoError(t, err)
resultExpected, err := ParseAccessLog(expected)
require.NoError(t, err)
formatErrMessage := fmt.Sprintf(`
Expected: %s
Actual: %s`, expected, string(logData))
require.Equal(t, len(resultExpected), len(result), formatErrMessage)
assert.Equal(t, resultExpected[ClientHost], result[ClientHost], formatErrMessage)
assert.Equal(t, resultExpected[ClientUsername], result[ClientUsername], formatErrMessage)
assert.Equal(t, resultExpected[RequestMethod], result[RequestMethod], formatErrMessage)
assert.Equal(t, resultExpected[RequestPath], result[RequestPath], formatErrMessage)
assert.Equal(t, resultExpected[RequestProtocol], result[RequestProtocol], formatErrMessage)
assert.Equal(t, resultExpected[OriginStatus], result[OriginStatus], formatErrMessage)
assert.Equal(t, resultExpected[OriginContentSize], result[OriginContentSize], formatErrMessage)
assert.Equal(t, resultExpected[RequestRefererHeader], result[RequestRefererHeader], formatErrMessage)
assert.Equal(t, resultExpected[RequestUserAgentHeader], result[RequestUserAgentHeader], formatErrMessage)
assert.Regexp(t, regexp.MustCompile("[0-9]*"), result[RequestCount], formatErrMessage)
assert.Equal(t, resultExpected[RouterName], result[RouterName], formatErrMessage)
assert.Equal(t, resultExpected[ServiceURL], result[ServiceURL], formatErrMessage)
assert.Regexp(t, regexp.MustCompile("[0-9]*ms"), result[Duration], formatErrMessage)
}
func captureStdout(t *testing.T) (out *os.File, restoreStdout func()) {
file, err := ioutil.TempFile("", "testlogger")
require.NoError(t, err, "failed to create temp file")
original := os.Stdout
os.Stdout = file
restoreStdout = func() {
os.Stdout = original
}
return file, restoreStdout
}
func createTempDir(t *testing.T, prefix string) string {
tmpDir, err := ioutil.TempDir("", prefix)
require.NoError(t, err, "failed to create temp dir")
return tmpDir
}
func doLogging(t *testing.T, config *types.AccessLog) {
logger, err := NewHandler(config)
require.NoError(t, err)
defer logger.Close()
if config.FilePath != "" {
_, err = os.Stat(config.FilePath)
require.NoError(t, err, fmt.Sprintf("logger should create %s", config.FilePath))
}
req := &http.Request{
Header: map[string][]string{
"User-Agent": {testUserAgent},
"Referer": {testReferer},
},
Proto: testProto,
Host: testHostname,
Method: testMethod,
RemoteAddr: fmt.Sprintf("%s:%d", testHostname, testPort),
URL: &url.URL{
User: url.UserPassword(testUsername, ""),
Path: testPath,
},
}
logger.ServeHTTP(httptest.NewRecorder(), req, logWriterTestHandlerFunc)
}
func logWriterTestHandlerFunc(rw http.ResponseWriter, r *http.Request) {
if _, err := rw.Write([]byte(testContent)); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
logData := GetLogData(r)
if logData != nil {
logData.Core[RouterName] = testRouterName
logData.Core[ServiceURL] = testServiceName
logData.Core[OriginStatus] = testStatus
logData.Core[OriginContentSize] = testContentSize
logData.Core[RetryAttempts] = testRetryAttempts
logData.Core[StartUTC] = testStart.UTC()
logData.Core[StartLocal] = testStart.Local()
} else {
http.Error(rw, "LogData is nil", http.StatusInternalServerError)
return
}
rw.WriteHeader(testStatus)
}

View file

@ -0,0 +1,54 @@
package accesslog
import (
"bytes"
"regexp"
)
// ParseAccessLog parse line of access log and return a map with each fields
func ParseAccessLog(data string) (map[string]string, error) {
var buffer bytes.Buffer
buffer.WriteString(`(\S+)`) // 1 - ClientHost
buffer.WriteString(`\s-\s`) // - - Spaces
buffer.WriteString(`(\S+)\s`) // 2 - ClientUsername
buffer.WriteString(`\[([^]]+)\]\s`) // 3 - StartUTC
buffer.WriteString(`"(\S*)\s?`) // 4 - RequestMethod
buffer.WriteString(`((?:[^"]*(?:\\")?)*)\s`) // 5 - RequestPath
buffer.WriteString(`([^"]*)"\s`) // 6 - RequestProtocol
buffer.WriteString(`(\S+)\s`) // 7 - OriginStatus
buffer.WriteString(`(\S+)\s`) // 8 - OriginContentSize
buffer.WriteString(`("?\S+"?)\s`) // 9 - Referrer
buffer.WriteString(`("\S+")\s`) // 10 - User-Agent
buffer.WriteString(`(\S+)\s`) // 11 - RequestCount
buffer.WriteString(`("[^"]*"|-)\s`) // 12 - FrontendName
buffer.WriteString(`("[^"]*"|-)\s`) // 13 - BackendURL
buffer.WriteString(`(\S+)`) // 14 - Duration
regex, err := regexp.Compile(buffer.String())
if err != nil {
return nil, err
}
submatch := regex.FindStringSubmatch(data)
result := make(map[string]string)
// Need to be > 13 to match CLF format
if len(submatch) > 13 {
result[ClientHost] = submatch[1]
result[ClientUsername] = submatch[2]
result[StartUTC] = submatch[3]
result[RequestMethod] = submatch[4]
result[RequestPath] = submatch[5]
result[RequestProtocol] = submatch[6]
result[OriginStatus] = submatch[7]
result[OriginContentSize] = submatch[8]
result[RequestRefererHeader] = submatch[9]
result[RequestUserAgentHeader] = submatch[10]
result[RequestCount] = submatch[11]
result[RouterName] = submatch[12]
result[ServiceURL] = submatch[13]
result[Duration] = submatch[14]
}
return result, nil
}

View file

@ -0,0 +1,75 @@
package accesslog
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseAccessLog(t *testing.T) {
testCases := []struct {
desc string
value string
expected map[string]string
}{
{
desc: "full log",
value: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`,
expected: map[string]string{
ClientHost: "TestHost",
ClientUsername: "TestUser",
StartUTC: "13/Apr/2016:07:14:19 -0700",
RequestMethod: "POST",
RequestPath: "testpath",
RequestProtocol: "HTTP/0.0",
OriginStatus: "123",
OriginContentSize: "12",
RequestRefererHeader: `"testReferer"`,
RequestUserAgentHeader: `"testUserAgent"`,
RequestCount: "1",
RouterName: `"testRouter"`,
ServiceURL: `"http://127.0.0.1/testService"`,
Duration: "1ms",
},
},
{
desc: "log with space",
value: `127.0.0.1 - - [09/Mar/2018:10:51:32 +0000] "GET / HTTP/1.1" 401 17 "-" "Go-http-client/1.1" 1 "testRouter with space" - 0ms`,
expected: map[string]string{
ClientHost: "127.0.0.1",
ClientUsername: "-",
StartUTC: "09/Mar/2018:10:51:32 +0000",
RequestMethod: "GET",
RequestPath: "/",
RequestProtocol: "HTTP/1.1",
OriginStatus: "401",
OriginContentSize: "17",
RequestRefererHeader: `"-"`,
RequestUserAgentHeader: `"Go-http-client/1.1"`,
RequestCount: "1",
RouterName: `"testRouter with space"`,
ServiceURL: `-`,
Duration: "0ms",
},
},
{
desc: "bad log",
value: `bad`,
expected: map[string]string{},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
result, err := ParseAccessLog(test.value)
assert.NoError(t, err)
assert.Equal(t, len(test.expected), len(result))
for key, value := range test.expected {
assert.Equal(t, value, result[key])
}
})
}
}

View file

@ -0,0 +1,21 @@
package accesslog
import (
"net/http"
)
// SaveRetries is an implementation of RetryListener that stores RetryAttempts in the LogDataTable.
type SaveRetries struct{}
// Retried implements the RetryListener interface and will be called for each retry that happens.
func (s *SaveRetries) Retried(req *http.Request, attempt int) {
// it is the request attempt x, but the retry attempt is x-1
if attempt > 0 {
attempt--
}
table := GetLogData(req)
if table != nil {
table.Core[RetryAttempts] = attempt
}
}

View file

@ -0,0 +1,48 @@
package accesslog
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func TestSaveRetries(t *testing.T) {
tests := []struct {
requestAttempt int
wantRetryAttemptsInLog int
}{
{
requestAttempt: 0,
wantRetryAttemptsInLog: 0,
},
{
requestAttempt: 1,
wantRetryAttemptsInLog: 0,
},
{
requestAttempt: 3,
wantRetryAttemptsInLog: 2,
},
}
for _, test := range tests {
test := test
t.Run(fmt.Sprintf("%d retries", test.requestAttempt), func(t *testing.T) {
t.Parallel()
saveRetries := &SaveRetries{}
logDataTable := &LogData{Core: make(CoreLogData)}
req := httptest.NewRequest(http.MethodGet, "/some/path", nil)
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
saveRetries.Retried(reqWithDataTable, test.requestAttempt)
if logDataTable.Core[RetryAttempts] != test.wantRetryAttemptsInLog {
t.Errorf("got %v in logDataTable, want %v", logDataTable.Core[RetryAttempts], test.wantRetryAttemptsInLog)
}
})
}
}

View file

@ -0,0 +1,62 @@
package addprefix
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
typeName = "AddPrefix"
)
// AddPrefix is a middleware used to add prefix to an URL request.
type addPrefix struct {
next http.Handler
prefix string
name string
}
// New creates a new handler.
func New(ctx context.Context, next http.Handler, config config.AddPrefix, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
var result *addPrefix
if len(config.Prefix) > 0 {
result = &addPrefix{
prefix: config.Prefix,
next: next,
name: name,
}
} else {
return nil, fmt.Errorf("prefix cannot be empty")
}
return result, nil
}
func (ap *addPrefix) GetTracingInformation() (string, ext.SpanKindEnum) {
return ap.name, tracing.SpanKindNoneEnum
}
func (ap *addPrefix) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), ap.name, typeName)
oldURLPath := req.URL.Path
req.URL.Path = ap.prefix + req.URL.Path
logger.Debugf("URL.Path is now %s (was %s).", req.URL.Path, oldURLPath)
if req.URL.RawPath != "" {
oldURLRawPath := req.URL.RawPath
req.URL.RawPath = ap.prefix + req.URL.RawPath
logger.Debugf("URL.RawPath is now %s (was %s).", req.URL.RawPath, oldURLRawPath)
}
req.RequestURI = req.URL.RequestURI()
ap.next.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,104 @@
package addprefix
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAddPrefix(t *testing.T) {
testCases := []struct {
desc string
prefix config.AddPrefix
expectsError bool
}{
{
desc: "Works with a non empty prefix",
prefix: config.AddPrefix{Prefix: "/a"},
},
{
desc: "Fails if prefix is empty",
prefix: config.AddPrefix{Prefix: ""},
expectsError: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
_, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
if test.expectsError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestAddPrefix(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel)
testCases := []struct {
desc string
prefix config.AddPrefix
path string
expectedPath string
expectedRawPath string
}{
{
desc: "Works with a regular path",
prefix: config.AddPrefix{Prefix: "/a"},
path: "/b",
expectedPath: "/a/b",
},
{
desc: "Works with a raw path",
prefix: config.AddPrefix{Prefix: "/a"},
path: "/b%2Fc",
expectedPath: "/a/b/c",
expectedRawPath: "/a/b%2Fc",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, requestURI string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
requestURI = r.RequestURI
})
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
handler, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
require.NoError(t, err)
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath)
assert.Equal(t, test.expectedRawPath, actualRawPath)
expectedURI := test.expectedPath
if test.expectedRawPath != "" {
// go HTTP uses the raw path when existent in the RequestURI
expectedURI = test.expectedRawPath
}
assert.Equal(t, expectedURI, requestURI)
})
}
}

View file

@ -0,0 +1,65 @@
package auth
import (
"io/ioutil"
"strings"
)
// UserParser Parses a string and return a userName/userHash. An error if the format of the string is incorrect.
type UserParser func(user string) (string, string, error)
const (
defaultRealm = "traefik"
authorizationHeader = "Authorization"
)
func getUsers(fileName string, appendUsers []string, parser UserParser) (map[string]string, error) {
users, err := loadUsers(fileName, appendUsers)
if err != nil {
return nil, err
}
userMap := make(map[string]string)
for _, user := range users {
userName, userHash, err := parser(user)
if err != nil {
return nil, err
}
userMap[userName] = userHash
}
return userMap, nil
}
func loadUsers(fileName string, appendUsers []string) ([]string, error) {
var users []string
var err error
if fileName != "" {
users, err = getLinesFromFile(fileName)
if err != nil {
return nil, err
}
}
return append(users, appendUsers...), nil
}
func getLinesFromFile(filename string) ([]string, error) {
dat, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
// Trim lines and filter out blanks
rawLines := strings.Split(string(dat), "\n")
var filteredLines []string
for _, rawLine := range rawLines {
line := strings.TrimSpace(rawLine)
if line != "" && !strings.HasPrefix(line, "#") {
filteredLines = append(filteredLines, line)
}
}
return filteredLines, nil
}

View file

@ -0,0 +1,102 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
basicTypeName = "BasicAuth"
)
type basicAuth struct {
next http.Handler
auth *goauth.BasicAuth
users map[string]string
headerField string
removeHeader bool
name string
}
// NewBasic creates a basicAuth middleware.
func NewBasic(ctx context.Context, next http.Handler, authConfig config.BasicAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, basicTypeName).Debug("Creating middleware")
users, err := getUsers(authConfig.UsersFile, authConfig.Users, basicUserParser)
if err != nil {
return nil, err
}
ba := &basicAuth{
next: next,
users: users,
headerField: authConfig.HeaderField,
removeHeader: authConfig.RemoveHeader,
name: name,
}
realm := defaultRealm
if len(authConfig.Realm) > 0 {
realm = authConfig.Realm
}
ba.auth = goauth.NewBasicAuthenticator(realm, ba.secretBasic)
return ba, nil
}
func (b *basicAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return b.name, tracing.SpanKindNoneEnum
}
func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), b.name, basicTypeName)
if username := b.auth.CheckAuth(req); username == "" {
logger.Debug("Authentication failed")
tracing.SetErrorWithEvent(req, "Authentication failed")
b.auth.RequireAuth(rw, req)
} else {
logger.Debug("Authentication succeeded")
req.URL.User = url.User(username)
logData := accesslog.GetLogData(req)
if logData != nil {
logData.Core[accesslog.ClientUsername] = username
}
if b.headerField != "" {
req.Header[b.headerField] = []string{username}
}
if b.removeHeader {
logger.Debug("Removing authorization header")
req.Header.Del(authorizationHeader)
}
b.next.ServeHTTP(rw, req)
}
}
func (b *basicAuth) secretBasic(user, realm string) string {
if secret, ok := b.users[user]; ok {
return secret
}
return ""
}
func basicUserParser(user string) (string, string, error) {
split := strings.Split(user, ":")
if len(split) != 2 {
return "", "", fmt.Errorf("error parsing BasicUser: %v", user)
}
return split[0], split[1], nil
}

View file

@ -0,0 +1,284 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBasicAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test"},
}
_, err := NewBasic(context.Background(), next, auth, "authName")
require.Error(t, err)
auth2 := config.BasicAuth{
Users: []string{"test:test"},
}
authMiddleware, err := NewBasic(context.Background(), next, auth2, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
}
func TestBasicAuthSuccess(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
authMiddleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicAuthUserHeader(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
HeaderField: "X-Webauth-User",
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthHeaderRemoved(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
RemoveHeader: true,
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthHeaderPresent(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthUsersFromFile(t *testing.T) {
testCases := []struct {
desc string
userFileContent string
expectedUsers map[string]string
givenUsers []string
realm string
}{
{
desc: "Finds the users in the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
},
{
desc: "Merges given users with users from the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\n",
givenUsers: []string{"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0", "test3:$apr1$3rJbDP0q$RfzJiorTk78jQ1EcKqWso0"},
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
},
{
desc: "Given users have priority over users in the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
},
{
desc: "Should authenticate the correct user based on the realm",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
realm: "traefik",
},
{
desc: "Should skip comments",
userFileContent: "#Comment\ntest:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
realm: "traefiker",
},
}
for _, test := range testCases {
if test.desc != "Should skip comments" {
continue
}
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
// Creates the temporary configuration file with the users
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.userFileContent))
require.NoError(t, err)
// Creates the configuration for our Authenticator
authenticatorConfiguration := config.BasicAuth{
Users: test.givenUsers,
UsersFile: usersFile.Name(),
Realm: test.realm,
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
authenticator, err := NewBasic(context.Background(), next, authenticatorConfiguration, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authenticator)
defer ts.Close()
for userName, userPwd := range test.expectedUsers {
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth(userName, userPwd)
var res *http.Response
res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.Equal(t, "traefik\n", string(body))
}
// Checks that user foo doesn't work
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("foo", "foo")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
if len(test.realm) > 0 {
require.Equal(t, `Basic realm="`+test.realm+`"`, res.Header.Get("WWW-Authenticate"))
}
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.NotContains(t, "traefik", string(body))
})
}
}

View file

@ -0,0 +1,102 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
digestTypeName = "digestAuth"
)
type digestAuth struct {
next http.Handler
auth *goauth.DigestAuth
users map[string]string
headerField string
removeHeader bool
name string
}
// NewDigest creates a digest auth middleware.
func NewDigest(ctx context.Context, next http.Handler, authConfig config.DigestAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, digestTypeName).Debug("Creating middleware")
users, err := getUsers(authConfig.UsersFile, authConfig.Users, digestUserParser)
if err != nil {
return nil, err
}
da := &digestAuth{
next: next,
users: users,
headerField: authConfig.HeaderField,
removeHeader: authConfig.RemoveHeader,
name: name,
}
realm := defaultRealm
if len(authConfig.Realm) > 0 {
realm = authConfig.Realm
}
da.auth = goauth.NewDigestAuthenticator(realm, da.secretDigest)
return da, nil
}
func (d *digestAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return d.name, tracing.SpanKindNoneEnum
}
func (d *digestAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), d.name, digestTypeName)
if username, _ := d.auth.CheckAuth(req); username == "" {
logger.Debug("Digest authentication failed")
tracing.SetErrorWithEvent(req, "Digest authentication failed")
d.auth.RequireAuth(rw, req)
} else {
logger.Debug("Digest authentication succeeded")
req.URL.User = url.User(username)
logData := accesslog.GetLogData(req)
if logData != nil {
logData.Core[accesslog.ClientUsername] = username
}
if d.headerField != "" {
req.Header[d.headerField] = []string{username}
}
if d.removeHeader {
logger.Debug("Removing the Authorization header")
req.Header.Del(authorizationHeader)
}
d.next.ServeHTTP(rw, req)
}
}
func (d *digestAuth) secretDigest(user, realm string) string {
if secret, ok := d.users[user+":"+realm]; ok {
return secret
}
return ""
}
func digestUserParser(user string) (string, string, error) {
split := strings.Split(user, ":")
if len(split) != 3 {
return "", "", fmt.Errorf("error parsing DigestUser: %v", user)
}
return split[0] + ":" + split[1], split[2], nil
}

View file

@ -0,0 +1,141 @@
package auth
import (
"crypto/md5"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
)
const (
algorithm = "algorithm"
authorization = "Authorization"
nonce = "nonce"
opaque = "opaque"
qop = "qop"
realm = "realm"
wwwAuthenticate = "Www-Authenticate"
)
// DigestRequest is a client for digest authentication requests
type digestRequest struct {
client *http.Client
username, password string
nonceCount nonceCount
}
type nonceCount int
func (nc nonceCount) String() string {
return fmt.Sprintf("%08x", int(nc))
}
var wanted = []string{algorithm, nonce, opaque, qop, realm}
// New makes a DigestRequest instance
func newDigestRequest(username, password string, client *http.Client) *digestRequest {
return &digestRequest{
client: client,
username: username,
password: password,
}
}
// Do does requests as http.Do does
func (r *digestRequest) Do(req *http.Request) (*http.Response, error) {
parts, err := r.makeParts(req)
if err != nil {
return nil, err
}
if parts != nil {
req.Header.Set(authorization, r.makeAuthorization(req, parts))
}
return r.client.Do(req)
}
func (r *digestRequest) makeParts(req *http.Request) (map[string]string, error) {
authReq, err := http.NewRequest(req.Method, req.URL.String(), nil)
if err != nil {
return nil, err
}
resp, err := r.client.Do(authReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
return nil, nil
}
if len(resp.Header[wwwAuthenticate]) == 0 {
return nil, fmt.Errorf("headers do not have %s", wwwAuthenticate)
}
headers := strings.Split(resp.Header[wwwAuthenticate][0], ",")
parts := make(map[string]string, len(wanted))
for _, r := range headers {
for _, w := range wanted {
if strings.Contains(r, w) {
parts[w] = strings.Split(r, `"`)[1]
}
}
}
if len(parts) != len(wanted) {
return nil, fmt.Errorf("header is invalid: %+v", parts)
}
return parts, nil
}
func getMD5(texts []string) string {
h := md5.New()
_, _ = io.WriteString(h, strings.Join(texts, ":"))
return hex.EncodeToString(h.Sum(nil))
}
func (r *digestRequest) getNonceCount() string {
r.nonceCount++
return r.nonceCount.String()
}
func (r *digestRequest) makeAuthorization(req *http.Request, parts map[string]string) string {
ha1 := getMD5([]string{r.username, parts[realm], r.password})
ha2 := getMD5([]string{req.Method, req.URL.String()})
cnonce := generateRandom(16)
nc := r.getNonceCount()
response := getMD5([]string{
ha1,
parts[nonce],
nc,
cnonce,
parts[qop],
ha2,
})
return fmt.Sprintf(
`Digest username="%s", realm="%s", nonce="%s", uri="%s", algorithm=%s, qop=%s, nc=%s, cnonce="%s", response="%s", opaque="%s"`,
r.username,
parts[realm],
parts[nonce],
req.URL.String(),
parts[algorithm],
parts[qop],
nc,
cnonce,
response,
parts[opaque],
)
}
// GenerateRandom generates random string
func generateRandom(n int) string {
b := make([]byte, 8)
_, _ = io.ReadFull(rand.Reader, b)
return fmt.Sprintf("%x", b)[:n]
}

View file

@ -0,0 +1,156 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestAuthError(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.DigestAuth{
Users: []string{"test"},
}
_, err := NewDigest(context.Background(), next, auth, "authName")
assert.Error(t, err)
}
func TestDigestAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.DigestAuth{
Users: []string{"test:traefik:a2688e031edb4be6a3797f3882655c05"},
}
authMiddleware, err := NewDigest(context.Background(), next, auth, "authName")
require.NoError(t, err)
assert.NotNil(t, authMiddleware, "this should not be nil")
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := http.DefaultClient
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}
func TestDigestAuthUsersFromFile(t *testing.T) {
testCases := []struct {
desc string
userFileContent string
expectedUsers map[string]string
givenUsers []string
realm string
}{
{
desc: "Finds the users in the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
},
{
desc: "Merges given users with users from the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\n",
givenUsers: []string{"test2:traefik:518845800f9e2bfb1f1f740ec24f074e", "test3:traefik:c8e9f57ce58ecb4424407f665a91646c"},
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
},
{
desc: "Given users have priority over users in the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
givenUsers: []string{"test2:traefik:8de60a1c52da68ccf41f0c0ffb7c51a0"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
},
{
desc: "Should authenticate the correct user based on the realm",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest:traefiker:a3d334dff2645b914918de78bec50bf4\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test2"},
realm: "traefiker",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
// Creates the temporary configuration file with the users
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.userFileContent))
require.NoError(t, err)
// Creates the configuration for our Authenticator
authenticatorConfiguration := config.DigestAuth{
Users: test.givenUsers,
UsersFile: usersFile.Name(),
Realm: test.realm,
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
authenticator, err := NewDigest(context.Background(), next, authenticatorConfiguration, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authenticator)
defer ts.Close()
for userName, userPwd := range test.expectedUsers {
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
digestRequest := newDigestRequest(userName, userPwd, http.DefaultClient)
var res *http.Response
res, err = digestRequest.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.Equal(t, "traefik\n", string(body))
}
// Checks that user foo doesn't work
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
digestRequest := newDigestRequest("foo", "foo", http.DefaultClient)
var res *http.Response
res, err = digestRequest.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.NotContains(t, "traefik", string(body))
})
}
}

View file

@ -0,0 +1,213 @@
package auth
import (
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
const (
xForwardedURI = "X-Forwarded-Uri"
xForwardedMethod = "X-Forwarded-Method"
forwardedTypeName = "ForwardedAuthType"
)
type forwardAuth struct {
address string
authResponseHeaders []string
next http.Handler
name string
tlsConfig *tls.Config
trustForwardHeader bool
}
// NewForward creates a forward auth middleware.
func NewForward(ctx context.Context, next http.Handler, config config.ForwardAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, forwardedTypeName).Debug("Creating middleware")
fa := &forwardAuth{
address: config.Address,
authResponseHeaders: config.AuthResponseHeaders,
next: next,
name: name,
trustForwardHeader: config.TrustForwardHeader,
}
if config.TLS != nil {
tlsConfig, err := config.TLS.CreateTLSConfig()
if err != nil {
return nil, err
}
fa.tlsConfig = tlsConfig
}
return fa, nil
}
func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return fa.name, ext.SpanKindRPCClientEnum
}
func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), fa.name, forwardedTypeName)
// Ensure our request client does not follow redirects
httpClient := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
if fa.tlsConfig != nil {
httpClient.Transport = &http.Transport{
TLSClientConfig: fa.tlsConfig,
}
}
forwardReq, err := http.NewRequest(http.MethodGet, fa.address, nil)
tracing.LogRequest(tracing.GetSpan(req), forwardReq)
if err != nil {
logMessage := fmt.Sprintf("Error calling %s. Cause %s", fa.address, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
writeHeader(req, forwardReq, fa.trustForwardHeader)
tracing.InjectRequestHeaders(forwardReq)
forwardResponse, forwardErr := httpClient.Do(forwardReq)
if forwardErr != nil {
logMessage := fmt.Sprintf("Error calling %s. Cause: %s", fa.address, forwardErr)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
body, readError := ioutil.ReadAll(forwardResponse.Body)
if readError != nil {
logMessage := fmt.Sprintf("Error reading body %s. Cause: %s", fa.address, readError)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
defer forwardResponse.Body.Close()
// Pass the forward response's body and selected headers if it
// didn't return a response within the range of [200, 300).
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
logger.Debugf("Remote error %s. StatusCode: %d", fa.address, forwardResponse.StatusCode)
utils.CopyHeaders(rw.Header(), forwardResponse.Header)
utils.RemoveHeaders(rw.Header(), forward.HopHeaders...)
// Grab the location header, if any.
redirectURL, err := forwardResponse.Location()
if err != nil {
if err != http.ErrNoLocation {
logMessage := fmt.Sprintf("Error reading response location header %s. Cause: %s", fa.address, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
} else if redirectURL.String() != "" {
// Set the location in our response if one was sent back.
rw.Header().Set("Location", redirectURL.String())
}
tracing.LogResponseCode(tracing.GetSpan(req), forwardResponse.StatusCode)
rw.WriteHeader(forwardResponse.StatusCode)
if _, err = rw.Write(body); err != nil {
logger.Error(err)
}
return
}
for _, headerName := range fa.authResponseHeaders {
req.Header.Set(headerName, forwardResponse.Header.Get(headerName))
}
req.RequestURI = req.URL.RequestURI()
fa.next.ServeHTTP(rw, req)
}
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
utils.CopyHeaders(forwardReq.Header, req.Header)
utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if trustForwardHeader {
if prior, ok := req.Header[forward.XForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
forwardReq.Header.Set(forward.XForwardedFor, clientIP)
}
xMethod := req.Header.Get(xForwardedMethod)
switch {
case xMethod != "" && trustForwardHeader:
forwardReq.Header.Set(xForwardedMethod, xMethod)
case req.Method != "":
forwardReq.Header.Set(xForwardedMethod, req.Method)
default:
forwardReq.Header.Del(xForwardedMethod)
}
xfp := req.Header.Get(forward.XForwardedProto)
switch {
case xfp != "" && trustForwardHeader:
forwardReq.Header.Set(forward.XForwardedProto, xfp)
case req.TLS != nil:
forwardReq.Header.Set(forward.XForwardedProto, "https")
default:
forwardReq.Header.Set(forward.XForwardedProto, "http")
}
if xfp := req.Header.Get(forward.XForwardedPort); xfp != "" && trustForwardHeader {
forwardReq.Header.Set(forward.XForwardedPort, xfp)
}
xfh := req.Header.Get(forward.XForwardedHost)
switch {
case xfh != "" && trustForwardHeader:
forwardReq.Header.Set(forward.XForwardedHost, xfh)
case req.Host != "":
forwardReq.Header.Set(forward.XForwardedHost, req.Host)
default:
forwardReq.Header.Del(forward.XForwardedHost)
}
xfURI := req.Header.Get(xForwardedURI)
switch {
case xfURI != "" && trustForwardHeader:
forwardReq.Header.Set(xForwardedURI, xfURI)
case req.URL.RequestURI() != "":
forwardReq.Header.Set(xForwardedURI, req.URL.RequestURI())
default:
forwardReq.Header.Del(xForwardedURI)
}
}

View file

@ -0,0 +1,393 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/forward"
)
func TestForwardAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer server.Close()
middleware, err := NewForward(context.Background(), next, config.ForwardAuth{
Address: server.URL,
}, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(body))
}
func TestForwardAuthSuccess(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Auth-User", "user@example.com")
w.Header().Set("X-Auth-Secret", "secret")
fmt.Fprintln(w, "Success")
}))
defer server.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: server.URL,
AuthResponseHeaders: []string{"X-Auth-User"},
}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestForwardAuthRedirect(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
}))
defer authTs.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := &http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusFound, res.StatusCode)
location, err := res.Location()
require.NoError(t, err)
assert.Equal(t, "http://example.com/redirect-test", location.String())
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.NotEmpty(t, string(body))
}
func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := w.Header()
for _, header := range forward.HopHeaders {
if header == forward.TransferEncoding {
headers.Add(header, "identity")
} else {
headers.Add(header, "test")
}
}
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
}))
defer authTs.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
assert.NoError(t, err, "there should be no error")
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := &http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
for _, header := range forward.HopHeaders {
assert.Equal(t, "", res.Header.Get(header), "hop-by-hop header '%s' mustn't be set", header)
}
location, err := res.Location()
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.NotEmpty(t, string(body), "there should be something in the body")
}
func TestForwardAuthFailResponseHeaders(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
http.SetCookie(w, cookie)
w.Header().Add("X-Foo", "bar")
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer authTs.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
require.Len(t, res.Cookies(), 1)
for _, cookie := range res.Cookies() {
assert.Equal(t, "testing", cookie.Value)
}
expectedHeaders := http.Header{
"Content-Length": []string{"10"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"X-Foo": []string{"bar"},
"Set-Cookie": []string{"example=testing; Path=/"},
"X-Content-Type-Options": []string{"nosniff"},
}
assert.Len(t, res.Header, 6)
for key, value := range expectedHeaders {
assert.Equal(t, value, res.Header[key])
}
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(body))
}
func Test_writeHeader(t *testing.T) {
testCases := []struct {
name string
headers map[string]string
trustForwardHeader bool
emptyHost bool
expectedHeaders map[string]string
checkForUnexpectedHeaders bool
}{
{
name: "trust Forward Header",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
},
{
name: "not trust Forward Header",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "foo.bar",
},
},
{
name: "trust Forward Header with empty Host",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: true,
emptyHost: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
},
{
name: "not trust Forward Header with empty Host",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: false,
emptyHost: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "",
},
},
{
name: "trust Forward Header with forwarded URI",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
},
{
name: "not trust Forward Header with forward requested URI",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "foo.bar",
"X-Forwarded-Uri": "/path?q=1",
},
}, {
name: "trust Forward Header with forwarded request Method",
headers: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
},
{
name: "not trust Forward Header with forward request Method",
headers: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"X-Forwarded-Method": "GET",
},
},
{
name: "remove hop-by-hop headers",
headers: map[string]string{
forward.Connection: "Connection",
forward.KeepAlive: "KeepAlive",
forward.ProxyAuthenticate: "ProxyAuthenticate",
forward.ProxyAuthorization: "ProxyAuthorization",
forward.Te: "Te",
forward.Trailers: "Trailers",
forward.TransferEncoding: "TransferEncoding",
forward.Upgrade: "Upgrade",
"X-CustomHeader": "CustomHeader",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"X-CustomHeader": "CustomHeader",
"X-Forwarded-Proto": "http",
"X-Forwarded-Host": "foo.bar",
"X-Forwarded-Uri": "/path?q=1",
"X-Forwarded-Method": "GET",
},
checkForUnexpectedHeaders: true,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
for key, value := range test.headers {
req.Header.Set(key, value)
}
if test.emptyHost {
req.Host = ""
}
forwardReq := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
writeHeader(req, forwardReq, test.trustForwardHeader)
actualHeaders := forwardReq.Header
expectedHeaders := test.expectedHeaders
for key, value := range expectedHeaders {
assert.Equal(t, value, actualHeaders.Get(key))
actualHeaders.Del(key)
}
if test.checkForUnexpectedHeaders {
for key := range actualHeaders {
assert.Fail(t, "Unexpected header found", key)
}
}
})
}
}

View file

@ -0,0 +1,54 @@
package buffering
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
oxybuffer "github.com/vulcand/oxy/buffer"
)
const (
typeName = "Buffer"
)
type buffer struct {
name string
buffer *oxybuffer.Buffer
}
// New creates a buffering middleware.
func New(ctx context.Context, next http.Handler, config config.Buffering, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
logger.Debug("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'",
config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes, config.MaxResponseBodyBytes, config.RetryExpression)
oxyBuffer, err := oxybuffer.New(
next,
oxybuffer.MemRequestBodyBytes(config.MemRequestBodyBytes),
oxybuffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes),
oxybuffer.MemResponseBodyBytes(config.MemResponseBodyBytes),
oxybuffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes),
oxybuffer.CondSetter(len(config.RetryExpression) > 0, oxybuffer.Retry(config.RetryExpression)),
)
if err != nil {
return nil, err
}
return &buffer{
name: name,
buffer: oxyBuffer,
}, nil
}
func (b *buffer) GetTracingInformation() (string, ext.SpanKindEnum) {
return b.name, tracing.SpanKindNoneEnum
}
func (b *buffer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
b.buffer.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,26 @@
package chain
import (
"context"
"net/http"
"github.com/containous/alice"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
)
const (
typeName = "Chain"
)
type chainBuilder interface {
BuildChain(ctx context.Context, middlewares []string) *alice.Chain
}
// New creates a chain middleware
func New(ctx context.Context, next http.Handler, config config.Chain, builder chainBuilder, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
middlewareChain := builder.BuildChain(ctx, config.Middlewares)
return middlewareChain.Then(next)
}

View file

@ -0,0 +1,61 @@
package circuitbreaker
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/cbreaker"
)
const (
typeName = "CircuitBreaker"
)
type circuitBreaker struct {
circuitBreaker *cbreaker.CircuitBreaker
name string
}
// New creates a new circuit breaker middleware.
func New(ctx context.Context, next http.Handler, confCircuitBreaker config.CircuitBreaker, name string) (http.Handler, error) {
expression := confCircuitBreaker.Expression
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
logger.Debug("Setting up with expression: %s", expression)
oxyCircuitBreaker, err := cbreaker.New(next, expression, createCircuitBreakerOptions(expression))
if err != nil {
return nil, err
}
return &circuitBreaker{
circuitBreaker: oxyCircuitBreaker,
name: name,
}, nil
}
// NewCircuitBreakerOptions returns a new CircuitBreakerOption
func createCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
return cbreaker.Fallback(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
tracing.SetErrorWithEvent(req, "blocked by circuit-breaker (%q)", expression)
rw.WriteHeader(http.StatusServiceUnavailable)
if _, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
log.FromContext(req.Context()).Error(err)
}
}))
}
func (c *circuitBreaker) GetTracingInformation() (string, ext.SpanKindEnum) {
return c.name, tracing.SpanKindNoneEnum
}
func (c *circuitBreaker) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
middlewares.GetLogger(req.Context(), c.name, typeName).Debug("Entering middleware")
c.circuitBreaker.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,58 @@
package compress
import (
"compress/gzip"
"context"
"net/http"
"strings"
"github.com/NYTimes/gziphandler"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/sirupsen/logrus"
)
const (
typeName = "Compress"
)
// Compress is a middleware that allows to compress the response.
type compress struct {
next http.Handler
name string
}
// New creates a new compress middleware.
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &compress{
next: next,
name: name,
}, nil
}
func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/grpc") {
c.next.ServeHTTP(rw, req)
} else {
gzipHandler(c.next, middlewares.GetLogger(req.Context(), c.name, typeName)).ServeHTTP(rw, req)
}
}
func (c *compress) GetTracingInformation() (string, ext.SpanKindEnum) {
return c.name, tracing.SpanKindNoneEnum
}
func gzipHandler(h http.Handler, logger logrus.FieldLogger) http.Handler {
wrapper, err := gziphandler.GzipHandlerWithOpts(
gziphandler.CompressionLevel(gzip.DefaultCompression),
gziphandler.MinSize(gziphandler.DefaultMinSize))
if err != nil {
logger.Error(err)
}
return wrapper(h)
}

View file

@ -0,0 +1,260 @@
package compress
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/NYTimes/gziphandler"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
acceptEncodingHeader = "Accept-Encoding"
contentEncodingHeader = "Content-Encoding"
contentTypeHeader = "Content-Type"
varyHeader = "Vary"
gzipValue = "gzip"
)
func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
baseBody := generateBytes(gziphandler.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(baseBody)
assert.NoError(t, err)
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
if assert.ObjectsAreEqualValues(rw.Body.Bytes(), baseBody) {
assert.Fail(t, "expected a compressed body", "got %v", rw.Body.Bytes())
}
}
func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
fakeCompressedBody := generateBytes(gziphandler.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
_, err := rw.Write(fakeCompressedBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
assert.EqualValues(t, rw.Body.Bytes(), fakeCompressedBody)
}
func TestShouldNotCompressWhenNoAcceptEncodingHeader(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
fakeBody := generateBytes(gziphandler.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(fakeBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
assert.EqualValues(t, rw.Body.Bytes(), fakeBody)
}
func TestShouldNotCompressWhenGRPC(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(contentTypeHeader, "application/grpc")
baseBody := generateBytes(gziphandler.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(baseBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Empty(t, rw.Header().Get(acceptEncodingHeader))
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
assert.EqualValues(t, rw.Body.Bytes(), baseBody)
}
func TestIntegrationShouldNotCompress(t *testing.T) {
fakeCompressedBody := generateBytes(100000)
testCases := []struct {
name string
handler http.Handler
expectedStatusCode int
}{
{
name: "when content already compressed",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
_, err := rw.Write(fakeCompressedBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusOK,
},
{
name: "when content already compressed and status code Created",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusCreated)
_, err := rw.Write(fakeCompressedBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusCreated,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
compress := &compress{next: test.handler}
ts := httptest.NewServer(compress)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.EqualValues(t, fakeCompressedBody, body)
})
}
}
func TestShouldWriteHeaderWhenFlush(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusUnauthorized)
rw.(http.Flusher).Flush()
_, err := rw.Write([]byte("short"))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
ts := httptest.NewServer(handler)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
}
func TestIntegrationShouldCompress(t *testing.T) {
fakeBody := generateBytes(100000)
testCases := []struct {
name string
handler http.Handler
expectedStatusCode int
}{
{
name: "when AcceptEncoding header is present",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(fakeBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusOK,
},
{
name: "when AcceptEncoding header is present and status code Created",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusCreated)
_, err := rw.Write(fakeBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusCreated,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
compress := &compress{next: test.handler}
ts := httptest.NewServer(compress)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
if assert.ObjectsAreEqualValues(body, fakeBody) {
assert.Fail(t, "expected a compressed body", "got %v", body)
}
})
}
}
func generateBytes(len int) []byte {
var value []byte
for i := 0; i < len; i++ {
value = append(value, 0x61+byte(i))
}
return value
}

View file

@ -0,0 +1,248 @@
package customerrors
import (
"bufio"
"bytes"
"context"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
// Compile time validation that the response recorder implements http interfaces correctly.
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{}
const (
typeName = "customError"
backendURL = "http://0.0.0.0"
)
type serviceBuilder interface {
BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
}
// customErrors is a middleware that provides the custom error pages..
type customErrors struct {
name string
next http.Handler
backendHandler http.Handler
httpCodeRanges types.HTTPCodeRanges
backendQuery string
}
// New creates a new custom error pages middleware.
func New(ctx context.Context, next http.Handler, config config.ErrorPage, serviceBuilder serviceBuilder, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
httpCodeRanges, err := types.NewHTTPCodeRanges(config.Status)
if err != nil {
return nil, err
}
backend, err := serviceBuilder.BuildHTTP(ctx, config.Service, nil)
if err != nil {
return nil, err
}
return &customErrors{
name: name,
next: next,
backendHandler: backend,
httpCodeRanges: httpCodeRanges,
backendQuery: config.Query,
}, nil
}
func (c *customErrors) GetTracingInformation() (string, ext.SpanKindEnum) {
return c.name, tracing.SpanKindNoneEnum
}
func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), c.name, typeName)
if c.backendHandler == nil {
logger.Error("Error pages: no backend handler.")
tracing.SetErrorWithEvent(req, "Error pages: no backend handler.")
c.next.ServeHTTP(rw, req)
return
}
recorder := newResponseRecorder(rw, middlewares.GetLogger(context.Background(), "test", typeName))
c.next.ServeHTTP(recorder, req)
// check the recorder code against the configured http status code ranges
for _, block := range c.httpCodeRanges {
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
logger.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
var query string
if len(c.backendQuery) > 0 {
query = "/" + strings.TrimPrefix(c.backendQuery, "/")
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
}
pageReq, err := newRequest(backendURL + query)
if err != nil {
logger.Error(err)
rw.WriteHeader(recorder.GetCode())
_, err = fmt.Fprint(rw, http.StatusText(recorder.GetCode()))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
return
}
recorderErrorPage := newResponseRecorder(rw, middlewares.GetLogger(context.Background(), "test", typeName))
utils.CopyHeaders(pageReq.Header, req.Header)
c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
utils.CopyHeaders(rw.Header(), recorderErrorPage.Header())
rw.WriteHeader(recorder.GetCode())
if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
logger.Error(err)
}
return
}
}
// did not catch a configured status code so proceed with the request
utils.CopyHeaders(rw.Header(), recorder.Header())
rw.WriteHeader(recorder.GetCode())
_, err := rw.Write(recorder.GetBody().Bytes())
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func newRequest(baseURL string) (*http.Request, error) {
u, err := url.Parse(baseURL)
if err != nil {
return nil, fmt.Errorf("error pages: error when parse URL: %v", err)
}
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, fmt.Errorf("error pages: error when create query: %v", err)
}
req.RequestURI = u.RequestURI()
return req, nil
}
type responseRecorder interface {
http.ResponseWriter
http.Flusher
GetCode() int
GetBody() *bytes.Buffer
IsStreamingResponseStarted() bool
}
// newResponseRecorder returns an initialized responseRecorder.
func newResponseRecorder(rw http.ResponseWriter, logger logrus.FieldLogger) responseRecorder {
recorder := &responseRecorderWithoutCloseNotify{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
Code: http.StatusOK,
responseWriter: rw,
logger: logger,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &responseRecorderWithCloseNotify{recorder}
}
return recorder
}
// responseRecorderWithoutCloseNotify is an implementation of http.ResponseWriter that
// records its mutations for later inspection.
type responseRecorderWithoutCloseNotify struct {
Code int // the HTTP response code from WriteHeader
HeaderMap http.Header // the HTTP response headers
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
responseWriter http.ResponseWriter
err error
streamingResponseStarted bool
logger logrus.FieldLogger
}
type responseRecorderWithCloseNotify struct {
*responseRecorderWithoutCloseNotify
}
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone away.
func (r *responseRecorderWithCloseNotify) CloseNotify() <-chan bool {
return r.responseWriter.(http.CloseNotifier).CloseNotify()
}
// Header returns the response headers.
func (r *responseRecorderWithoutCloseNotify) Header() http.Header {
if r.HeaderMap == nil {
r.HeaderMap = make(http.Header)
}
return r.HeaderMap
}
func (r *responseRecorderWithoutCloseNotify) GetCode() int {
return r.Code
}
func (r *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer {
return r.Body
}
func (r *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool {
return r.streamingResponseStarted
}
// Write always succeeds and writes to rw.Body, if not nil.
func (r *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) {
if r.err != nil {
return 0, r.err
}
return r.Body.Write(buf)
}
// WriteHeader sets rw.Code.
func (r *responseRecorderWithoutCloseNotify) WriteHeader(code int) {
r.Code = code
}
// Hijack hijacks the connection
func (r *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.responseWriter.(http.Hijacker).Hijack()
}
// Flush sends any buffered data to the client.
func (r *responseRecorderWithoutCloseNotify) Flush() {
if !r.streamingResponseStarted {
utils.CopyHeaders(r.responseWriter.Header(), r.Header())
r.responseWriter.WriteHeader(r.Code)
r.streamingResponseStarted = true
}
_, err := r.responseWriter.Write(r.Body.Bytes())
if err != nil {
r.logger.Errorf("Error writing response in responseRecorder: %v", err)
r.err = err
}
r.Body.Reset()
if flusher, ok := r.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

View file

@ -0,0 +1,176 @@
package customerrors
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHandler(t *testing.T) {
testCases := []struct {
desc string
errorPage *config.ErrorPage
backendCode int
backendErrorHandler http.HandlerFunc
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
desc: "no error",
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusOK,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
},
},
{
desc: "in the range",
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusInternalServerError,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My error page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "not in the range",
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusBadGateway,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
},
},
{
desc: "query replacement",
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503-503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "Single code",
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
serviceBuilderMock := &mockServiceBuilder{handler: test.backendErrorHandler}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
recorder := httptest.NewRecorder()
errorPageHandler.ServeHTTP(recorder, req)
test.validate(t, recorder)
})
}
}
type mockServiceBuilder struct {
handler http.Handler
}
func (m *mockServiceBuilder) BuildHTTP(_ context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
return m.handler, nil
}
func TestNewResponseRecorder(t *testing.T) {
testCases := []struct {
desc string
rw http.ResponseWriter
expected http.ResponseWriter
}{
{
desc: "Without Close Notify",
rw: httptest.NewRecorder(),
expected: &responseRecorderWithoutCloseNotify{},
},
{
desc: "With Close Notify",
rw: &mockRWCloseNotify{},
expected: &responseRecorderWithCloseNotify{},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rec := newResponseRecorder(test.rw, middlewares.GetLogger(context.Background(), "test", typeName))
assert.IsType(t, rec, test.expected)
})
}
}
type mockRWCloseNotify struct{}
func (m *mockRWCloseNotify) CloseNotify() <-chan bool {
panic("implement me")
}
func (m *mockRWCloseNotify) Header() http.Header {
panic("implement me")
}
func (m *mockRWCloseNotify) Write([]byte) (int, error) {
panic("implement me")
}
func (m *mockRWCloseNotify) WriteHeader(int) {
panic("implement me")
}

View file

@ -0,0 +1,33 @@
package emptybackendhandler
import (
"net/http"
"github.com/containous/traefik/pkg/healthcheck"
)
// EmptyBackend is a middleware that checks whether the current Backend
// has at least one active Server in respect to the healthchecks and if this
// is not the case, it will stop the middleware chain and respond with 503.
type emptyBackend struct {
next healthcheck.BalancerHandler
}
// New creates a new EmptyBackend middleware.
func New(lb healthcheck.BalancerHandler) http.Handler {
return &emptyBackend{next: lb}
}
// ServeHTTP responds with 503 when there is no active Server and otherwise
// invokes the next handler in the middleware chain.
func (e *emptyBackend) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if len(e.next.Servers()) == 0 {
rw.WriteHeader(http.StatusServiceUnavailable)
_, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
} else {
e.next.ServeHTTP(rw, req)
}
}

View file

@ -0,0 +1,81 @@
package emptybackendhandler
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/vulcand/oxy/roundrobin"
)
func TestEmptyBackendHandler(t *testing.T) {
testCases := []struct {
amountServer int
expectedStatusCode int
}{
{
amountServer: 0,
expectedStatusCode: http.StatusServiceUnavailable,
},
{
amountServer: 1,
expectedStatusCode: http.StatusOK,
},
}
for _, test := range testCases {
test := test
t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) {
t.Parallel()
handler := New(&healthCheckLoadBalancer{amountServer: test.amountServer})
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler.ServeHTTP(recorder, req)
assert.Equal(t, test.expectedStatusCode, recorder.Result().StatusCode)
})
}
}
type healthCheckLoadBalancer struct {
amountServer int
}
func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func (lb *healthCheckLoadBalancer) Servers() []*url.URL {
servers := make([]*url.URL, lb.amountServer)
for i := 0; i < lb.amountServer; i++ {
servers = append(servers, testhelpers.MustParseURL("http://localhost"))
}
return servers
}
func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error {
return nil
}
func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
return nil
}
func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) {
return 0, false
}
func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) {
return nil, nil
}
func (lb *healthCheckLoadBalancer) Next() http.Handler {
return nil
}

View file

@ -0,0 +1,51 @@
package forwardedheaders
import (
"net/http"
"github.com/containous/traefik/pkg/ip"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
// XForwarded filter for XForwarded headers.
type XForwarded struct {
insecure bool
trustedIps []string
ipChecker *ip.Checker
next http.Handler
}
// NewXForwarded creates a new XForwarded.
func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XForwarded, error) {
var ipChecker *ip.Checker
if len(trustedIps) > 0 {
var err error
ipChecker, err = ip.NewChecker(trustedIps)
if err != nil {
return nil, err
}
}
return &XForwarded{
insecure: insecure,
trustedIps: trustedIps,
ipChecker: ipChecker,
next: next,
}, nil
}
func (x *XForwarded) isTrustedIP(ip string) bool {
if x.ipChecker == nil {
return false
}
return x.ipChecker.IsAuthorized(ip) == nil
}
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
utils.RemoveHeaders(r.Header, forward.XHeaders...)
}
x.next.ServeHTTP(w, r)
}

View file

@ -0,0 +1,128 @@
package forwardedheaders
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServeHTTP(t *testing.T) {
testCases := []struct {
desc string
insecure bool
trustedIps []string
incomingHeaders map[string]string
remoteAddr string
expectedHeaders map[string]string
}{
{
desc: "all Empty",
insecure: true,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure true with incoming X-Forwarded-For",
insecure: true,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For",
insecure: false,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips",
insecure: false,
trustedIps: []string{"10.0.1.100"},
remoteAddr: "10.0.1.100:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips",
insecure: false,
trustedIps: []string{"10.0.1.100"},
remoteAddr: "10.0.1.101:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips CIDR",
insecure: false,
trustedIps: []string{"1.2.3.4/24"},
remoteAddr: "1.2.3.156:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips CIDR",
insecure: false,
trustedIps: []string{"1.2.3.4/24"},
remoteAddr: "10.0.1.101:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)
req.RemoteAddr = test.remoteAddr
for k, v := range test.incomingHeaders {
req.Header.Set(k, v)
}
m, err := NewXForwarded(test.insecure, test.trustedIps, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
require.NoError(t, err)
m.ServeHTTP(nil, req)
for k, v := range test.expectedHeaders {
assert.Equal(t, v, req.Header.Get(k))
}
})
}
}

View file

@ -0,0 +1,35 @@
package middlewares
import (
"net/http"
"github.com/containous/traefik/pkg/safe"
)
// HTTPHandlerSwitcher allows hot switching of http.ServeMux
type HTTPHandlerSwitcher struct {
handler *safe.Safe
}
// NewHandlerSwitcher builds a new instance of HTTPHandlerSwitcher
func NewHandlerSwitcher(newHandler http.Handler) (hs *HTTPHandlerSwitcher) {
return &HTTPHandlerSwitcher{
handler: safe.New(newHandler),
}
}
func (h *HTTPHandlerSwitcher) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
handlerBackup := h.handler.Get().(http.Handler)
handlerBackup.ServeHTTP(rw, req)
}
// GetHandler returns the current http.ServeMux
func (h *HTTPHandlerSwitcher) GetHandler() (newHandler http.Handler) {
handler := h.handler.Get().(http.Handler)
return handler
}
// UpdateHandler safely updates the current http.ServeMux with a new one
func (h *HTTPHandlerSwitcher) UpdateHandler(newHandler http.Handler) {
h.handler.Set(newHandler)
}

View file

@ -0,0 +1,134 @@
// Package headers Middleware based on https://github.com/unrolled/secure.
package headers
import (
"context"
"errors"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/unrolled/secure"
)
const (
typeName = "Headers"
)
type headers struct {
name string
handler http.Handler
}
// New creates a Headers middleware.
func New(ctx context.Context, next http.Handler, config config.Headers, name string) (http.Handler, error) {
// HeaderMiddleware -> SecureMiddleWare -> next
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
if !config.HasSecureHeadersDefined() && !config.HasCustomHeadersDefined() {
return nil, errors.New("headers configuration not valid")
}
var handler http.Handler
nextHandler := next
if config.HasSecureHeadersDefined() {
logger.Debug("Setting up secureHeaders from %v", config)
handler = newSecure(next, config)
nextHandler = handler
}
if config.HasCustomHeadersDefined() {
logger.Debug("Setting up customHeaders from %v", config)
handler = newHeader(nextHandler, config)
}
return &headers{
handler: handler,
name: name,
}, nil
}
func (h *headers) GetTracingInformation() (string, ext.SpanKindEnum) {
return h.name, tracing.SpanKindNoneEnum
}
func (h *headers) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
h.handler.ServeHTTP(rw, req)
}
type secureHeader struct {
next http.Handler
secure *secure.Secure
}
// newSecure constructs a new secure instance with supplied options.
func newSecure(next http.Handler, headers config.Headers) *secureHeader {
opt := secure.Options{
BrowserXssFilter: headers.BrowserXSSFilter,
ContentTypeNosniff: headers.ContentTypeNosniff,
ForceSTSHeader: headers.ForceSTSHeader,
FrameDeny: headers.FrameDeny,
IsDevelopment: headers.IsDevelopment,
SSLRedirect: headers.SSLRedirect,
SSLForceHost: headers.SSLForceHost,
SSLTemporaryRedirect: headers.SSLTemporaryRedirect,
STSIncludeSubdomains: headers.STSIncludeSubdomains,
STSPreload: headers.STSPreload,
ContentSecurityPolicy: headers.ContentSecurityPolicy,
CustomBrowserXssValue: headers.CustomBrowserXSSValue,
CustomFrameOptionsValue: headers.CustomFrameOptionsValue,
PublicKey: headers.PublicKey,
ReferrerPolicy: headers.ReferrerPolicy,
SSLHost: headers.SSLHost,
AllowedHosts: headers.AllowedHosts,
HostsProxyHeaders: headers.HostsProxyHeaders,
SSLProxyHeaders: headers.SSLProxyHeaders,
STSSeconds: headers.STSSeconds,
}
return &secureHeader{
next: next,
secure: secure.New(opt),
}
}
func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP)
}
// Header is a middleware that helps setup a few basic security features. A single headerOptions struct can be
// provided to configure which features should be enabled, and the ability to override a few of the default values.
type header struct {
next http.Handler
// If Custom request headers are set, these will be added to the request
customRequestHeaders map[string]string
}
// NewHeader constructs a new header instance from supplied frontend header struct.
func newHeader(next http.Handler, headers config.Headers) *header {
return &header{
next: next,
customRequestHeaders: headers.CustomRequestHeaders,
}
}
func (s *header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
s.modifyRequestHeaders(req)
s.next.ServeHTTP(rw, req)
}
// modifyRequestHeaders set or delete request headers.
func (s *header) modifyRequestHeaders(req *http.Request) {
// Loop through Custom request headers
for header, value := range s.customRequestHeaders {
if value == "" {
req.Header.Del(header)
} else {
req.Header.Set(header, value)
}
}
}

View file

@ -0,0 +1,190 @@
package headers
// Middleware tests based on https://github.com/unrolled/secure
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCustomRequestHeader(t *testing.T) {
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
header := newHeader(emptyHandler, config.Headers{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
}
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
header := newHeader(emptyHandler, config.Headers{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
header = newHeader(emptyHandler, config.Headers{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "",
},
})
header.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"))
}
func TestSecureHeader(t *testing.T) {
testCases := []struct {
desc string
fromHost string
expected int
}{
{
desc: "Should accept the request when given a host that is in the list",
fromHost: "foo.com",
expected: http.StatusOK,
},
{
desc: "Should refuse the request when no host is given",
fromHost: "",
expected: http.StatusInternalServerError,
},
{
desc: "Should refuse the request when no matching host is given",
fromHost: "boo.com",
expected: http.StatusInternalServerError,
},
}
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
header, err := New(context.Background(), emptyHandler, config.Headers{
AllowedHosts: []string{"foo.com", "bar.com"},
}, "foo")
require.NoError(t, err)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
req.Host = test.fromHost
header.ServeHTTP(res, req)
assert.Equal(t, test.expected, res.Code)
})
}
}
func TestSSLForceHost(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte("OK"))
})
testCases := []struct {
desc string
host string
secureMiddleware *secureHeader
expected int
}{
{
desc: "http should return a 301",
host: "http://powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: true,
SSLHost: "powpow.example.com",
}),
expected: http.StatusMovedPermanently,
},
{
desc: "http sub domain should return a 301",
host: "http://www.powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: true,
SSLHost: "powpow.example.com",
}),
expected: http.StatusMovedPermanently,
},
{
desc: "https should return a 200",
host: "https://powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: true,
SSLHost: "powpow.example.com",
}),
expected: http.StatusOK,
},
{
desc: "https sub domain should return a 301",
host: "https://www.powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: true,
SSLHost: "powpow.example.com",
}),
expected: http.StatusMovedPermanently,
},
{
desc: "http without force host and sub domain should return a 301",
host: "http://www.powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: false,
SSLHost: "powpow.example.com",
}),
expected: http.StatusMovedPermanently,
},
{
desc: "https without force host and sub domain should return a 301",
host: "https://www.powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: false,
SSLHost: "powpow.example.com",
}),
expected: http.StatusOK,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, test.host, nil)
rw := httptest.NewRecorder()
test.secureMiddleware.ServeHTTP(rw, req)
assert.Equal(t, test.expected, rw.Result().StatusCode)
})
}
}

View file

@ -0,0 +1,85 @@
package ipwhitelist
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/ip"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
typeName = "IPWhiteLister"
)
// ipWhiteLister is a middleware that provides Checks of the Requesting IP against a set of Whitelists
type ipWhiteLister struct {
next http.Handler
whiteLister *ip.Checker
strategy ip.Strategy
name string
}
// New builds a new IPWhiteLister given a list of CIDR-Strings to whitelist
func New(ctx context.Context, next http.Handler, config config.IPWhiteList, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
if len(config.SourceRange) == 0 {
return nil, errors.New("sourceRange is empty, IPWhiteLister not created")
}
checker, err := ip.NewChecker(config.SourceRange)
if err != nil {
return nil, fmt.Errorf("cannot parse CIDR whitelist %s: %v", config.SourceRange, err)
}
strategy, err := config.IPStrategy.Get()
if err != nil {
return nil, err
}
logger.Debugf("Setting up IPWhiteLister with sourceRange: %s", config.SourceRange)
return &ipWhiteLister{
strategy: strategy,
whiteLister: checker,
next: next,
name: name,
}, nil
}
func (wl *ipWhiteLister) GetTracingInformation() (string, ext.SpanKindEnum) {
return wl.name, tracing.SpanKindNoneEnum
}
func (wl *ipWhiteLister) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), wl.name, typeName)
err := wl.whiteLister.IsAuthorized(wl.strategy.GetIP(req))
if err != nil {
logMessage := fmt.Sprintf("rejecting request %+v: %v", req, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
reject(logger, rw)
return
}
logger.Debugf("Accept %s: %+v", wl.strategy.GetIP(req), req)
wl.next.ServeHTTP(rw, req)
}
func reject(logger logrus.FieldLogger, rw http.ResponseWriter) {
statusCode := http.StatusForbidden
rw.WriteHeader(statusCode)
_, err := rw.Write([]byte(http.StatusText(statusCode)))
if err != nil {
logger.Error(err)
}
}

View file

@ -0,0 +1,100 @@
package ipwhitelist
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewIPWhiteLister(t *testing.T) {
testCases := []struct {
desc string
whiteList config.IPWhiteList
expectedError bool
}{
{
desc: "invalid IP",
whiteList: config.IPWhiteList{
SourceRange: []string{"foo"},
},
expectedError: true,
},
{
desc: "valid IP",
whiteList: config.IPWhiteList{
SourceRange: []string{"10.10.10.10"},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
if test.expectedError {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.NotNil(t, whiteLister)
}
})
}
}
func TestIPWhiteLister_ServeHTTP(t *testing.T) {
testCases := []struct {
desc string
whiteList config.IPWhiteList
remoteAddr string
expected int
}{
{
desc: "authorized with remote address",
whiteList: config.IPWhiteList{
SourceRange: []string{"20.20.20.20"},
},
remoteAddr: "20.20.20.20:1234",
expected: 200,
},
{
desc: "non authorized with remote address",
whiteList: config.IPWhiteList{
SourceRange: []string{"20.20.20.20"},
},
remoteAddr: "20.20.20.21:1234",
expected: 403,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
require.NoError(t, err)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://10.10.10.10", nil)
if len(test.remoteAddr) > 0 {
req.RemoteAddr = test.remoteAddr
}
whiteLister.ServeHTTP(recorder, req)
assert.Equal(t, test.expected, recorder.Code)
})
}
}

View file

@ -0,0 +1,48 @@
package maxconnection
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/connlimit"
"github.com/vulcand/oxy/utils"
)
const (
typeName = "MaxConnection"
)
type maxConnection struct {
handler http.Handler
name string
}
// New creates a max connection middleware.
func New(ctx context.Context, next http.Handler, maxConns config.MaxConn, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
handler, err := connlimit.New(next, extractFunc, maxConns.Amount)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
return &maxConnection{handler: handler, name: name}, nil
}
func (mc *maxConnection) GetTracingInformation() (string, ext.SpanKindEnum) {
return mc.name, tracing.SpanKindNoneEnum
}
func (mc *maxConnection) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
mc.handler.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,13 @@
package middlewares
import (
"context"
"github.com/containous/traefik/pkg/log"
"github.com/sirupsen/logrus"
)
// GetLogger creates a logger configured with the middleware fields.
func GetLogger(ctx context.Context, middleware string, middlewareType string) logrus.FieldLogger {
return log.FromContext(ctx).WithField(log.MiddlewareName, middleware).WithField(log.MiddlewareType, middlewareType)
}

View file

@ -0,0 +1,294 @@
package passtlsclientcert
import (
"context"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/sirupsen/logrus"
)
const (
xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
xForwardedTLSClientCertInfo = "X-Forwarded-Tls-Client-Cert-Info"
typeName = "PassClientTLSCert"
)
var attributeTypeNames = map[string]string{
"0.9.2342.19200300.100.1.25": "DC", // Domain component OID - RFC 2247
}
// DistinguishedNameOptions is a struct for specifying the configuration for the distinguished name info.
type DistinguishedNameOptions struct {
CommonName bool
CountryName bool
DomainComponent bool
LocalityName bool
OrganizationName bool
SerialNumber bool
StateOrProvinceName bool
}
func newDistinguishedNameOptions(info *config.TLSCLientCertificateDNInfo) *DistinguishedNameOptions {
if info == nil {
return nil
}
return &DistinguishedNameOptions{
CommonName: info.CommonName,
CountryName: info.Country,
DomainComponent: info.DomainComponent,
LocalityName: info.Locality,
OrganizationName: info.Organization,
SerialNumber: info.SerialNumber,
StateOrProvinceName: info.Province,
}
}
// passTLSClientCert is a middleware that helps setup a few tls info features.
type passTLSClientCert struct {
next http.Handler
name string
pem bool // pass the sanitized pem to the backend in a specific header
info *tlsClientCertificateInfo // pass selected information from the client certificate
}
// New constructs a new PassTLSClientCert instance from supplied frontend header struct.
func New(ctx context.Context, next http.Handler, config config.PassTLSClientCert, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &passTLSClientCert{
next: next,
name: name,
pem: config.PEM,
info: newTLSClientInfo(config.Info),
}, nil
}
// tlsClientCertificateInfo is a struct for specifying the configuration for the passTLSClientCert middleware.
type tlsClientCertificateInfo struct {
notAfter bool
notBefore bool
sans bool
subject *DistinguishedNameOptions
issuer *DistinguishedNameOptions
}
func newTLSClientInfo(info *config.TLSClientCertificateInfo) *tlsClientCertificateInfo {
if info == nil {
return nil
}
return &tlsClientCertificateInfo{
issuer: newDistinguishedNameOptions(info.Issuer),
notAfter: info.NotAfter,
notBefore: info.NotBefore,
subject: newDistinguishedNameOptions(info.Subject),
sans: info.Sans,
}
}
func (p *passTLSClientCert) GetTracingInformation() (string, ext.SpanKindEnum) {
return p.name, tracing.SpanKindNoneEnum
}
func (p *passTLSClientCert) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), p.name, typeName)
p.modifyRequestHeaders(logger, req)
p.next.ServeHTTP(rw, req)
}
func getDNInfo(prefix string, options *DistinguishedNameOptions, cs *pkix.Name) string {
if options == nil {
return ""
}
content := &strings.Builder{}
// Manage non standard attributes
for _, name := range cs.Names {
// Domain Component - RFC 2247
if options.DomainComponent && attributeTypeNames[name.Type.String()] == "DC" {
content.WriteString(fmt.Sprintf("DC=%s,", name.Value))
}
}
if options.CountryName {
writeParts(content, cs.Country, "C")
}
if options.StateOrProvinceName {
writeParts(content, cs.Province, "ST")
}
if options.LocalityName {
writeParts(content, cs.Locality, "L")
}
if options.OrganizationName {
writeParts(content, cs.Organization, "O")
}
if options.SerialNumber {
writePart(content, cs.SerialNumber, "SN")
}
if options.CommonName {
writePart(content, cs.CommonName, "CN")
}
if content.Len() > 0 {
return prefix + `="` + strings.TrimSuffix(content.String(), ",") + `"`
}
return ""
}
func writeParts(content io.StringWriter, entries []string, prefix string) {
for _, entry := range entries {
writePart(content, entry, prefix)
}
}
func writePart(content io.StringWriter, entry string, prefix string) {
if len(entry) > 0 {
_, err := content.WriteString(fmt.Sprintf("%s=%s,", prefix, entry))
if err != nil {
log.Error(err)
}
}
}
// getXForwardedTLSClientCertInfo Build a string with the wanted client certificates information
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
func (p *passTLSClientCert) getXForwardedTLSClientCertInfo(certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
var values []string
var sans string
var nb string
var na string
if p.info != nil {
subject := getDNInfo("Subject", p.info.subject, &peerCert.Subject)
if len(subject) > 0 {
values = append(values, subject)
}
issuer := getDNInfo("Issuer", p.info.issuer, &peerCert.Issuer)
if len(issuer) > 0 {
values = append(values, issuer)
}
}
ci := p.info
if ci != nil {
if ci.notBefore {
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
values = append(values, nb)
}
if ci.notAfter {
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
values = append(values, na)
}
if ci.sans {
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
values = append(values, sans)
}
}
value := strings.Join(values, ",")
headerValues = append(headerValues, value)
}
return strings.Join(headerValues, ";")
}
// modifyRequestHeaders set the wanted headers with the certificates information.
func (p *passTLSClientCert) modifyRequestHeaders(logger logrus.FieldLogger, r *http.Request) {
if p.pem {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(logger, r.TLS.PeerCertificates))
} else {
logger.Warn("Try to extract certificate on a request without TLS")
}
}
if p.info != nil {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
headerContent := p.getXForwardedTLSClientCertInfo(r.TLS.PeerCertificates)
r.Header.Set(xForwardedTLSClientCertInfo, url.QueryEscape(headerContent))
} else {
logger.Warn("Try to extract certificate on a request without TLS")
}
}
}
// sanitize As we pass the raw certificates, remove the useless data and make it http request compliant.
func sanitize(cert []byte) string {
s := string(cert)
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
"-----END CERTIFICATE-----", "",
"\n", "")
cleaned := r.Replace(s)
return url.QueryEscape(cleaned)
}
// extractCertificate extract the certificate from the request.
func extractCertificate(logger logrus.FieldLogger, cert *x509.Certificate) string {
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
certPEM := pem.EncodeToMemory(&b)
if certPEM == nil {
logger.Error("Cannot extract the certificate content")
return ""
}
return sanitize(certPEM)
}
// getXForwardedTLSClientCert Build a string with the client certificates.
func getXForwardedTLSClientCert(logger logrus.FieldLogger, certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
headerValues = append(headerValues, extractCertificate(logger, peerCert))
}
return strings.Join(headerValues, ",")
}
// getSANs get the Subject Alternate Name values.
func getSANs(cert *x509.Certificate) []string {
var sans []string
if cert == nil {
return sans
}
sans = append(sans, cert.DNSNames...)
sans = append(sans, cert.EmailAddresses...)
var ips []string
for _, ip := range cert.IPAddresses {
ips = append(ips, ip.String())
}
sans = append(sans, ips...)
var uris []string
for _, uri := range cert.URIs {
uris = append(uris, uri.String())
}
return append(sans, uris...)
}

View file

@ -0,0 +1,663 @@
package passtlsclientcert
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"net"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/require"
)
const (
signingCA = `Certificate:
Data:
Version: 3 (0x2)
Serial Number: 2 (0x2)
Signature Algorithm: sha1WithRSAEncryption
Issuer: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Cheese Section, OU=Cheese Section 2, CN=Simple Root CA, CN=Simple Root CA 2, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Root State, ST=Root State 2/emailAddress=root@signing.com/emailAddress=root2@signing.com
Validity
Not Before: Dec 6 11:10:09 2018 GMT
Not After : Dec 5 11:10:09 2028 GMT
Subject: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Simple Signing Section, OU=Simple Signing Section 2, CN=Simple Signing CA, CN=Simple Signing CA 2, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Signing State, ST=Signing State 2/emailAddress=simple@signing.com/emailAddress=simple2@signing.com
Subject Public Key Info:
Public Key Algorithm: rsaEncryption
RSA Public-Key: (2048 bit)
Modulus:
00:c3:9d:9f:61:15:57:3f:78:cc:e7:5d:20:e2:3e:
2e:79:4a:c3:3a:0c:26:40:18:db:87:08:85:c2:f7:
af:87:13:1a:ff:67:8a:b7:2b:58:a7:cc:89:dd:77:
ff:5e:27:65:11:80:82:8f:af:a0:af:25:86:ec:a2:
4f:20:0e:14:15:16:12:d7:74:5a:c3:99:bd:3b:81:
c8:63:6f:fc:90:14:86:d2:39:ee:87:b2:ff:6d:a5:
69:da:ab:5a:3a:97:cd:23:37:6a:4b:ba:63:cd:a1:
a9:e6:79:aa:37:b8:d1:90:c9:24:b5:e8:70:fc:15:
ad:39:97:28:73:47:66:f6:22:79:5a:b0:03:83:8a:
f1:ca:ae:8b:50:1e:c8:fa:0d:9f:76:2e:00:c2:0e:
75:bc:47:5a:b6:d8:05:ed:5a:bc:6d:50:50:36:6b:
ab:ab:69:f6:9b:1b:6c:7e:a8:9f:b2:33:3a:3c:8c:
6d:5e:83:ce:17:82:9e:10:51:a6:39:ec:98:4e:50:
b7:b1:aa:8b:ac:bb:a1:60:1b:ea:31:3b:b8:0a:ea:
63:41:79:b5:ec:ee:19:e9:85:8e:f3:6d:93:80:da:
98:58:a2:40:93:a5:53:eb:1d:24:b6:66:07:ec:58:
10:63:e7:fa:6e:18:60:74:76:15:39:3c:f4:95:95:
7e:df
Exponent: 65537 (0x10001)
X509v3 extensions:
X509v3 Key Usage: critical
Certificate Sign, CRL Sign
X509v3 Basic Constraints: critical
CA:TRUE, pathlen:0
X509v3 Subject Key Identifier:
1E:52:A2:E8:54:D5:37:EB:D5:A8:1D:E4:C2:04:1D:37:E2:F7:70:03
X509v3 Authority Key Identifier:
keyid:36:70:35:AA:F0:F6:93:B2:86:5D:32:73:F9:41:5A:3F:3B:C8:BC:8B
Signature Algorithm: sha1WithRSAEncryption
76:f3:16:21:27:6d:a2:2e:e8:18:49:aa:54:1e:f8:3b:07:fa:
65:50:d8:1f:a2:cf:64:6c:15:e0:0f:c8:46:b2:d7:b8:0e:cd:
05:3b:06:fb:dd:c6:2f:01:ae:bd:69:d3:bb:55:47:a9:f6:e5:
ba:be:4b:45:fb:2e:3c:33:e0:57:d4:3e:8e:3e:11:f2:0a:f1:
7d:06:ab:04:2e:a5:76:20:c2:db:a4:68:5a:39:00:62:2a:1d:
c2:12:b1:90:66:8c:36:a8:fd:83:d1:1b:da:23:a7:1d:5b:e6:
9b:40:c4:78:25:c7:b7:6b:75:35:cf:bb:37:4a:4f:fc:7e:32:
1f:8c:cf:12:d2:c9:c8:99:d9:4a:55:0a:1e:ac:de:b4:cb:7c:
bf:c4:fb:60:2c:a8:f7:e7:63:5c:b0:1c:62:af:01:3c:fe:4d:
3c:0b:18:37:4c:25:fc:d0:b2:f6:b2:f1:c3:f4:0f:53:d6:1e:
b5:fa:bc:d8:ad:dd:1c:f5:45:9f:af:fe:0a:01:79:92:9a:d8:
71:db:37:f3:1e:bd:fb:c7:1e:0a:0f:97:2a:61:f3:7b:19:93:
9c:a6:8a:69:cd:b0:f5:91:02:a5:1b:10:f4:80:5d:42:af:4e:
82:12:30:3e:d3:a7:11:14:ce:50:91:04:80:d7:2a:03:ef:71:
10:b8:db:a5
-----BEGIN CERTIFICATE-----
MIIFzTCCBLWgAwIBAgIBAjANBgkqhkiG9w0BAQUFADCCAWQxEzARBgoJkiaJk/Is
ZAEZFgNvcmcxFjAUBgoJkiaJk/IsZAEZFgZjaGVlc2UxDzANBgNVBAoMBkNoZWVz
ZTERMA8GA1UECgwIQ2hlZXNlIDIxFzAVBgNVBAsMDkNoZWVzZSBTZWN0aW9uMRkw
FwYDVQQLDBBDaGVlc2UgU2VjdGlvbiAyMRcwFQYDVQQDDA5TaW1wbGUgUm9vdCBD
QTEZMBcGA1UEAwwQU2ltcGxlIFJvb3QgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNV
BAYTAlVTMREwDwYDVQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjETMBEGA1UE
CAwKUm9vdCBTdGF0ZTEVMBMGA1UECAwMUm9vdCBTdGF0ZSAyMR8wHQYJKoZIhvcN
AQkBFhByb290QHNpZ25pbmcuY29tMSAwHgYJKoZIhvcNAQkBFhFyb290MkBzaWdu
aW5nLmNvbTAeFw0xODEyMDYxMTEwMDlaFw0yODEyMDUxMTEwMDlaMIIBhDETMBEG
CgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEPMA0GA1UE
CgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2ltcGxlIFNp
Z25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2VjdGlvbiAy
MRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2ltcGxlIFNp
Z25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYDVQQHDAhU
T1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBTdGF0ZTEY
MBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJzaW1wbGVA
c2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmluZy5jb20w
ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDDnZ9hFVc/eMznXSDiPi55
SsM6DCZAGNuHCIXC96+HExr/Z4q3K1inzIndd/9eJ2URgIKPr6CvJYbsok8gDhQV
FhLXdFrDmb07gchjb/yQFIbSOe6Hsv9tpWnaq1o6l80jN2pLumPNoanmeao3uNGQ
ySS16HD8Fa05lyhzR2b2InlasAODivHKrotQHsj6DZ92LgDCDnW8R1q22AXtWrxt
UFA2a6urafabG2x+qJ+yMzo8jG1eg84Xgp4QUaY57JhOULexqousu6FgG+oxO7gK
6mNBebXs7hnphY7zbZOA2phYokCTpVPrHSS2ZgfsWBBj5/puGGB0dhU5PPSVlX7f
AgMBAAGjZjBkMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEAMB0G
A1UdDgQWBBQeUqLoVNU369WoHeTCBB034vdwAzAfBgNVHSMEGDAWgBQ2cDWq8PaT
soZdMnP5QVo/O8i8izANBgkqhkiG9w0BAQUFAAOCAQEAdvMWISdtoi7oGEmqVB74
Owf6ZVDYH6LPZGwV4A/IRrLXuA7NBTsG+93GLwGuvWnTu1VHqfblur5LRfsuPDPg
V9Q+jj4R8grxfQarBC6ldiDC26RoWjkAYiodwhKxkGaMNqj9g9Eb2iOnHVvmm0DE
eCXHt2t1Nc+7N0pP/H4yH4zPEtLJyJnZSlUKHqzetMt8v8T7YCyo9+djXLAcYq8B
PP5NPAsYN0wl/NCy9rLxw/QPU9Yetfq82K3dHPVFn6/+CgF5kprYcds38x69+8ce
Cg+XKmHzexmTnKaKac2w9ZECpRsQ9IBdQq9OghIwPtOnERTOUJEEgNcqA+9xELjb
pQ==
-----END CERTIFICATE-----
`
minimalCheeseCrt = `-----BEGIN CERTIFICATE-----
MIIEQDCCAygCFFRY0OBk/L5Se0IZRj3CMljawL2UMA0GCSqGSIb3DQEBCwUAMIIB
hDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEP
MA0GA1UECgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2lt
cGxlIFNpZ25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2Vj
dGlvbiAyMRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2lt
cGxlIFNpZ25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYD
VQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBT
dGF0ZTEYMBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJz
aW1wbGVAc2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmlu
Zy5jb20wHhcNMTgxMjA2MTExMDM2WhcNMjEwOTI1MTExMDM2WjAzMQswCQYDVQQG
EwJGUjETMBEGA1UECAwKU29tZS1TdGF0ZTEPMA0GA1UECgwGQ2hlZXNlMIIBIjAN
BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAskX/bUtwFo1gF2BTPNaNcTUMaRFu
FMZozK8IgLjccZ4kZ0R9oFO6Yp8Zl/IvPaf7tE26PI7XP7eHriUdhnQzX7iioDd0
RZa68waIhAGc+xPzRFrP3b3yj3S2a9Rve3c0K+SCV+EtKAwsxMqQDhoo9PcBfo5B
RHfht07uD5MncUcGirwN+/pxHV5xzAGPcc7On0/5L7bq/G+63nhu78zw9XyuLaHC
PM5VbOUvpyIESJHbMMzTdFGL8ob9VKO+Kr1kVGdEA9i8FLGl3xz/GBKuW/JD0xyW
DrU29mri5vYWHmkuv7ZWHGXnpXjTtPHwveE9/0/ArnmpMyR9JtqFr1oEvQIDAQAB
MA0GCSqGSIb3DQEBCwUAA4IBAQBHta+NWXI08UHeOkGzOTGRiWXsOH2dqdX6gTe9
xF1AIjyoQ0gvpoGVvlnChSzmlUj+vnx/nOYGIt1poE3hZA3ZHZD/awsvGyp3GwWD
IfXrEViSCIyF+8tNNKYyUcEO3xdAsAUGgfUwwF/mZ6MBV5+A/ZEEILlTq8zFt9dV
vdKzIt7fZYxYBBHFSarl1x8pDgWXlf3hAufevGJXip9xGYmznF0T5cq1RbWJ4be3
/9K7yuWhuBYC3sbTbCneHBa91M82za+PIISc1ygCYtWSBoZKSAqLk0rkZpHaekDP
WqeUSNGYV//RunTeuRDAf5OxehERb1srzBXhRZ3cZdzXbgR/
-----END CERTIFICATE-----
`
completeCheeseCrt = `Certificate:
Data:
Version: 3 (0x2)
Serial Number: 1 (0x1)
Signature Algorithm: sha1WithRSAEncryption
Issuer: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Simple Signing Section, OU=Simple Signing Section 2, CN=Simple Signing CA, CN=Simple Signing CA 2, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Signing State, ST=Signing State 2/emailAddress=simple@signing.com/emailAddress=simple2@signing.com
Validity
Not Before: Dec 6 11:10:16 2018 GMT
Not After : Dec 5 11:10:16 2020 GMT
Subject: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Simple Signing Section, OU=Simple Signing Section 2, CN=*.cheese.org, CN=*.cheese.com, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Cheese org state, ST=Cheese com state/emailAddress=cert@cheese.org/emailAddress=cert@scheese.com
Subject Public Key Info:
Public Key Algorithm: rsaEncryption
RSA Public-Key: (2048 bit)
Modulus:
00:de:77:fa:8d:03:70:30:39:dd:51:1b:cc:60:db:
a9:5a:13:b1:af:fe:2c:c6:38:9b:88:0a:0f:8e:d9:
1b:a1:1d:af:0d:66:e4:13:5b:bc:5d:36:92:d7:5e:
d0:fa:88:29:d3:78:e1:81:de:98:b2:a9:22:3f:bf:
8a:af:12:92:63:d4:a9:c3:f2:e4:7e:d2:dc:a2:c5:
39:1c:7a:eb:d7:12:70:63:2e:41:47:e0:f0:08:e8:
dc:be:09:01:ec:28:09:af:35:d7:79:9c:50:35:d1:
6b:e5:87:7b:34:f6:d2:31:65:1d:18:42:69:6c:04:
11:83:fe:44:ae:90:92:2d:0b:75:39:57:62:e6:17:
2f:47:2b:c7:53:dd:10:2d:c9:e3:06:13:d2:b9:ba:
63:2e:3c:7d:83:6b:d6:89:c9:cc:9d:4d:bf:9f:e8:
a3:7b:da:c8:99:2b:ba:66:d6:8e:f8:41:41:a0:c9:
d0:5e:c8:11:a4:55:4a:93:83:87:63:04:63:41:9c:
fb:68:04:67:c2:71:2f:f2:65:1d:02:5d:15:db:2c:
d9:04:69:85:c2:7d:0d:ea:3b:ac:85:f8:d4:8f:0f:
c5:70:b2:45:e1:ec:b2:54:0b:e9:f7:82:b4:9b:1b:
2d:b9:25:d4:ab:ca:8f:5b:44:3e:15:dd:b8:7f:b7:
ee:f9
Exponent: 65537 (0x10001)
X509v3 extensions:
X509v3 Key Usage: critical
Digital Signature, Key Encipherment
X509v3 Basic Constraints:
CA:FALSE
X509v3 Extended Key Usage:
TLS Web Server Authentication, TLS Web Client Authentication
X509v3 Subject Key Identifier:
94:BA:73:78:A2:87:FB:58:28:28:CF:98:3B:C2:45:70:16:6E:29:2F
X509v3 Authority Key Identifier:
keyid:1E:52:A2:E8:54:D5:37:EB:D5:A8:1D:E4:C2:04:1D:37:E2:F7:70:03
X509v3 Subject Alternative Name:
DNS:*.cheese.org, DNS:*.cheese.net, DNS:*.cheese.com, IP Address:10.0.1.0, IP Address:10.0.1.2, email:test@cheese.org, email:test@cheese.net
Signature Algorithm: sha1WithRSAEncryption
76:6b:05:b0:0e:34:11:b1:83:99:91:dc:ae:1b:e2:08:15:8b:
16:b2:9b:27:1c:02:ac:b5:df:1b:d0:d0:75:a4:2b:2c:5c:65:
ed:99:ab:f7:cd:fe:38:3f:c3:9a:22:31:1b:ac:8c:1c:c2:f9:
5d:d4:75:7a:2e:72:c7:85:a9:04:af:9f:2a:cc:d3:96:75:f0:
8e:c7:c6:76:48:ac:45:a4:b9:02:1e:2f:c0:15:c4:07:08:92:
cb:27:50:67:a1:c8:05:c5:3a:b3:a6:48:be:eb:d5:59:ab:a2:
1b:95:30:71:13:5b:0a:9a:73:3b:60:cc:10:d0:6a:c7:e5:d7:
8b:2f:f9:2e:98:f2:ff:81:14:24:09:e3:4b:55:57:09:1a:22:
74:f1:f6:40:13:31:43:89:71:0a:96:1a:05:82:1f:83:3a:87:
9b:17:25:ef:5a:55:f2:2d:cd:0d:4d:e4:81:58:b6:e3:8d:09:
62:9a:0c:bd:e4:e5:5c:f0:95:da:cb:c7:34:2c:34:5f:6d:fc:
60:7b:12:5b:86:fd:df:21:89:3b:48:08:30:bf:67:ff:8c:e6:
9b:53:cc:87:36:47:70:40:3b:d9:90:2a:d2:d2:82:c6:9c:f5:
d1:d8:e0:e6:fd:aa:2f:95:7e:39:ac:fc:4e:d4:ce:65:b3:ec:
c6:98:8a:31
-----BEGIN CERTIFICATE-----
MIIGWjCCBUKgAwIBAgIBATANBgkqhkiG9w0BAQUFADCCAYQxEzARBgoJkiaJk/Is
ZAEZFgNvcmcxFjAUBgoJkiaJk/IsZAEZFgZjaGVlc2UxDzANBgNVBAoMBkNoZWVz
ZTERMA8GA1UECgwIQ2hlZXNlIDIxHzAdBgNVBAsMFlNpbXBsZSBTaWduaW5nIFNl
Y3Rpb24xITAfBgNVBAsMGFNpbXBsZSBTaWduaW5nIFNlY3Rpb24gMjEaMBgGA1UE
AwwRU2ltcGxlIFNpZ25pbmcgQ0ExHDAaBgNVBAMME1NpbXBsZSBTaWduaW5nIENB
IDIxCzAJBgNVBAYTAkZSMQswCQYDVQQGEwJVUzERMA8GA1UEBwwIVE9VTE9VU0Ux
DTALBgNVBAcMBExZT04xFjAUBgNVBAgMDVNpZ25pbmcgU3RhdGUxGDAWBgNVBAgM
D1NpZ25pbmcgU3RhdGUgMjEhMB8GCSqGSIb3DQEJARYSc2ltcGxlQHNpZ25pbmcu
Y29tMSIwIAYJKoZIhvcNAQkBFhNzaW1wbGUyQHNpZ25pbmcuY29tMB4XDTE4MTIw
NjExMTAxNloXDTIwMTIwNTExMTAxNlowggF2MRMwEQYKCZImiZPyLGQBGRYDb3Jn
MRYwFAYKCZImiZPyLGQBGRYGY2hlZXNlMQ8wDQYDVQQKDAZDaGVlc2UxETAPBgNV
BAoMCENoZWVzZSAyMR8wHQYDVQQLDBZTaW1wbGUgU2lnbmluZyBTZWN0aW9uMSEw
HwYDVQQLDBhTaW1wbGUgU2lnbmluZyBTZWN0aW9uIDIxFTATBgNVBAMMDCouY2hl
ZXNlLm9yZzEVMBMGA1UEAwwMKi5jaGVlc2UuY29tMQswCQYDVQQGEwJGUjELMAkG
A1UEBhMCVVMxETAPBgNVBAcMCFRPVUxPVVNFMQ0wCwYDVQQHDARMWU9OMRkwFwYD
VQQIDBBDaGVlc2Ugb3JnIHN0YXRlMRkwFwYDVQQIDBBDaGVlc2UgY29tIHN0YXRl
MR4wHAYJKoZIhvcNAQkBFg9jZXJ0QGNoZWVzZS5vcmcxHzAdBgkqhkiG9w0BCQEW
EGNlcnRAc2NoZWVzZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB
AQDed/qNA3AwOd1RG8xg26laE7Gv/izGOJuICg+O2RuhHa8NZuQTW7xdNpLXXtD6
iCnTeOGB3piyqSI/v4qvEpJj1KnD8uR+0tyixTkceuvXEnBjLkFH4PAI6Ny+CQHs
KAmvNdd5nFA10Wvlh3s09tIxZR0YQmlsBBGD/kSukJItC3U5V2LmFy9HK8dT3RAt
yeMGE9K5umMuPH2Da9aJycydTb+f6KN72siZK7pm1o74QUGgydBeyBGkVUqTg4dj
BGNBnPtoBGfCcS/yZR0CXRXbLNkEaYXCfQ3qO6yF+NSPD8VwskXh7LJUC+n3grSb
Gy25JdSryo9bRD4V3bh/t+75AgMBAAGjgeAwgd0wDgYDVR0PAQH/BAQDAgWgMAkG
A1UdEwQCMAAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQW
BBSUunN4oof7WCgoz5g7wkVwFm4pLzAfBgNVHSMEGDAWgBQeUqLoVNU369WoHeTC
BB034vdwAzBhBgNVHREEWjBYggwqLmNoZWVzZS5vcmeCDCouY2hlZXNlLm5ldIIM
Ki5jaGVlc2UuY29thwQKAAEAhwQKAAECgQ90ZXN0QGNoZWVzZS5vcmeBD3Rlc3RA
Y2hlZXNlLm5ldDANBgkqhkiG9w0BAQUFAAOCAQEAdmsFsA40EbGDmZHcrhviCBWL
FrKbJxwCrLXfG9DQdaQrLFxl7Zmr983+OD/DmiIxG6yMHML5XdR1ei5yx4WpBK+f
KszTlnXwjsfGdkisRaS5Ah4vwBXEBwiSyydQZ6HIBcU6s6ZIvuvVWauiG5UwcRNb
CppzO2DMENBqx+XXiy/5Lpjy/4EUJAnjS1VXCRoidPH2QBMxQ4lxCpYaBYIfgzqH
mxcl71pV8i3NDU3kgVi2440JYpoMveTlXPCV2svHNCw0X238YHsSW4b93yGJO0gI
ML9n/4zmm1PMhzZHcEA72ZAq0tKCxpz10djg5v2qL5V+Oaz8TtTOZbPsxpiKMQ==
-----END CERTIFICATE-----
`
minimalCert = `-----BEGIN CERTIFICATE-----
MIIDGTCCAgECCQCqLd75YLi2kDANBgkqhkiG9w0BAQsFADBYMQswCQYDVQQGEwJG
UjETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UEBwwIVG91bG91c2UxITAfBgNV
BAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xODA3MTgwODI4MTZaFw0x
ODA4MTcwODI4MTZaMEUxCzAJBgNVBAYTAkZSMRMwEQYDVQQIDApTb21lLVN0YXRl
MSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3
DQEBAQUAA4IBDwAwggEKAoIBAQC/+frDMMTLQyXG34F68BPhQq0kzK4LIq9Y0/gl
FjySZNn1C0QDWA1ubVCAcA6yY204I9cxcQDPNrhC7JlS5QA8Y5rhIBrqQlzZizAi
Rj3NTrRjtGUtOScnHuJaWjLy03DWD+aMwb7q718xt5SEABmmUvLwQK+EjW2MeDwj
y8/UEIpvrRDmdhGaqv7IFpIDkcIF7FowJ/hwDvx3PMc+z/JWK0ovzpvgbx69AVbw
ZxCimeha65rOqVi+lEetD26le+WnOdYsdJ2IkmpPNTXGdfb15xuAc+gFXfMCh7Iw
3Ynl6dZtZM/Ok2kiA7/OsmVnRKkWrtBfGYkI9HcNGb3zrk6nAgMBAAEwDQYJKoZI
hvcNAQELBQADggEBAC/R+Yvhh1VUhcbK49olWsk/JKqfS3VIDQYZg1Eo+JCPbwgS
I1BSYVfMcGzuJTX6ua3m/AHzGF3Tap4GhF4tX12jeIx4R4utnjj7/YKkTvuEM2f4
xT56YqI7zalGScIB0iMeyNz1QcimRl+M/49au8ow9hNX8C2tcA2cwd/9OIj/6T8q
SBRHc6ojvbqZSJCO0jziGDT1L3D+EDgTjED4nd77v/NRdP+egb0q3P0s4dnQ/5AV
aQlQADUn61j3ScbGJ4NSeZFFvsl38jeRi/MEzp0bGgNBcPj6JHi7qbbauZcZfQ05
jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=
-----END CERTIFICATE-----`
)
func getCleanCertContents(certContents []string) string {
var re = regexp.MustCompile("-----BEGIN CERTIFICATE-----(?s)(.*)")
var cleanedCertContent []string
for _, certContent := range certContents {
cert := re.FindString(certContent)
cleanedCertContent = append(cleanedCertContent, sanitize([]byte(cert)))
}
return strings.Join(cleanedCertContent, ",")
}
func getCertificate(certContent string) *x509.Certificate {
roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM([]byte(signingCA))
if !ok {
panic("failed to parse root certificate")
}
block, _ := pem.Decode([]byte(certContent))
if block == nil {
panic("failed to parse certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
panic("failed to parse certificate: " + err.Error())
}
return cert
}
func buildTLSWith(certContents []string) *tls.ConnectionState {
var peerCertificates []*x509.Certificate
for _, certContent := range certContents {
peerCertificates = append(peerCertificates, getCertificate(certContent))
}
return &tls.ConnectionState{PeerCertificates: peerCertificates}
}
var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("bar"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})
func getExpectedSanitized(s string) string {
return url.QueryEscape(strings.Replace(s, "\n", "", -1))
}
func TestSanitize(t *testing.T) {
testCases := []struct {
desc string
toSanitize []byte
expected string
}{
{
desc: "Empty",
},
{
desc: "With a minimal cert",
toSanitize: []byte(minimalCheeseCrt),
expected: getExpectedSanitized(`MIIEQDCCAygCFFRY0OBk/L5Se0IZRj3CMljawL2UMA0GCSqGSIb3DQEBCwUAMIIB
hDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEP
MA0GA1UECgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2lt
cGxlIFNpZ25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2Vj
dGlvbiAyMRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2lt
cGxlIFNpZ25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYD
VQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBT
dGF0ZTEYMBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJz
aW1wbGVAc2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmlu
Zy5jb20wHhcNMTgxMjA2MTExMDM2WhcNMjEwOTI1MTExMDM2WjAzMQswCQYDVQQG
EwJGUjETMBEGA1UECAwKU29tZS1TdGF0ZTEPMA0GA1UECgwGQ2hlZXNlMIIBIjAN
BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAskX/bUtwFo1gF2BTPNaNcTUMaRFu
FMZozK8IgLjccZ4kZ0R9oFO6Yp8Zl/IvPaf7tE26PI7XP7eHriUdhnQzX7iioDd0
RZa68waIhAGc+xPzRFrP3b3yj3S2a9Rve3c0K+SCV+EtKAwsxMqQDhoo9PcBfo5B
RHfht07uD5MncUcGirwN+/pxHV5xzAGPcc7On0/5L7bq/G+63nhu78zw9XyuLaHC
PM5VbOUvpyIESJHbMMzTdFGL8ob9VKO+Kr1kVGdEA9i8FLGl3xz/GBKuW/JD0xyW
DrU29mri5vYWHmkuv7ZWHGXnpXjTtPHwveE9/0/ArnmpMyR9JtqFr1oEvQIDAQAB
MA0GCSqGSIb3DQEBCwUAA4IBAQBHta+NWXI08UHeOkGzOTGRiWXsOH2dqdX6gTe9
xF1AIjyoQ0gvpoGVvlnChSzmlUj+vnx/nOYGIt1poE3hZA3ZHZD/awsvGyp3GwWD
IfXrEViSCIyF+8tNNKYyUcEO3xdAsAUGgfUwwF/mZ6MBV5+A/ZEEILlTq8zFt9dV
vdKzIt7fZYxYBBHFSarl1x8pDgWXlf3hAufevGJXip9xGYmznF0T5cq1RbWJ4be3
/9K7yuWhuBYC3sbTbCneHBa91M82za+PIISc1ygCYtWSBoZKSAqLk0rkZpHaekDP
WqeUSNGYV//RunTeuRDAf5OxehERb1srzBXhRZ3cZdzXbgR/`),
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, test.expected, sanitize(test.toSanitize), "The sanitized certificates should be equal")
})
}
}
func TestTLSClientHeadersWithPEM(t *testing.T) {
testCases := []struct {
desc string
certContents []string // set the request TLS attribute if defined
config config.PassTLSClientCert
expectedHeader string
}{
{
desc: "No TLS, no option",
},
{
desc: "TLS, no option",
certContents: []string{minimalCheeseCrt},
},
{
desc: "No TLS, with pem option true",
config: config.PassTLSClientCert{PEM: true},
},
{
desc: "TLS with simple certificate, with pem option true",
certContents: []string{minimalCheeseCrt},
config: config.PassTLSClientCert{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert}),
},
{
desc: "TLS with complete certificate, with pem option true",
certContents: []string{minimalCheeseCrt},
config: config.PassTLSClientCert{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCheeseCrt}),
},
{
desc: "TLS with two certificate, with pem option true",
certContents: []string{minimalCert, minimalCheeseCrt},
config: config.PassTLSClientCert{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert, minimalCheeseCrt}),
},
}
for _, test := range testCases {
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
require.NoError(t, err)
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
if test.certContents != nil && len(test.certContents) > 0 {
req.TLS = buildTLSWith(test.certContents)
}
tlsClientHeaders.ServeHTTP(res, req)
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
if test.expectedHeader != "" {
require.Equal(t, getCleanCertContents(test.certContents), req.Header.Get(xForwardedTLSClientCert), "The request header should contain the cleaned certificate")
} else {
require.Empty(t, req.Header.Get(xForwardedTLSClientCert))
}
require.Empty(t, res.Header().Get(xForwardedTLSClientCert), "The response header should be always empty")
})
}
}
func TestGetSans(t *testing.T) {
urlFoo, err := url.Parse("my.foo.com")
require.NoError(t, err)
urlBar, err := url.Parse("my.bar.com")
require.NoError(t, err)
testCases := []struct {
desc string
cert *x509.Certificate // set the request TLS attribute if defined
expected []string
}{
{
desc: "With nil",
},
{
desc: "Certificate without Sans",
cert: &x509.Certificate{},
},
{
desc: "Certificate with all Sans",
cert: &x509.Certificate{
DNSNames: []string{"foo", "bar"},
EmailAddresses: []string{"test@test.com", "test2@test.com"},
IPAddresses: []net.IP{net.IPv4(10, 0, 0, 1), net.IPv4(10, 0, 0, 2)},
URIs: []*url.URL{urlFoo, urlBar},
},
expected: []string{"foo", "bar", "test@test.com", "test2@test.com", "10.0.0.1", "10.0.0.2", urlFoo.String(), urlBar.String()},
},
}
for _, test := range testCases {
sans := getSANs(test.cert)
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
if len(test.expected) > 0 {
for i, expected := range test.expected {
require.Equal(t, expected, sans[i])
}
} else {
require.Empty(t, sans)
}
})
}
}
func TestTLSClientHeadersWithCertInfo(t *testing.T) {
minimalCheeseCertAllInfo := `Subject="C=FR,ST=Some-State,O=Cheese",Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2",NB=1544094636,NA=1632568236,SAN=`
completeCertAllInfo := `Subject="DC=org,DC=cheese,C=FR,C=US,ST=Cheese org state,ST=Cheese com state,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=*.cheese.com",Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2",NB=1544094616,NA=1607166616,SAN=*.cheese.org,*.cheese.net,*.cheese.com,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2`
testCases := []struct {
desc string
certContents []string // set the request TLS attribute if defined
config config.PassTLSClientCert
expectedHeader string
}{
{
desc: "No TLS, no option",
},
{
desc: "TLS, no option",
certContents: []string{minimalCert},
},
{
desc: "No TLS, with subject info",
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
Subject: &config.TLSCLientCertificateDNInfo{
CommonName: true,
Organization: true,
Locality: true,
Province: true,
Country: true,
SerialNumber: true,
},
},
},
},
{
desc: "No TLS, with pem option false with empty subject info",
config: config.PassTLSClientCert{
PEM: false,
Info: &config.TLSClientCertificateInfo{
Subject: &config.TLSCLientCertificateDNInfo{},
},
},
},
{
desc: "TLS with simple certificate, with all info",
certContents: []string{minimalCheeseCrt},
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
NotAfter: true,
NotBefore: true,
Sans: true,
Subject: &config.TLSCLientCertificateDNInfo{
CommonName: true,
Country: true,
DomainComponent: true,
Locality: true,
Organization: true,
Province: true,
SerialNumber: true,
},
Issuer: &config.TLSCLientCertificateDNInfo{
CommonName: true,
Country: true,
DomainComponent: true,
Locality: true,
Organization: true,
Province: true,
SerialNumber: true,
},
},
},
expectedHeader: url.QueryEscape(minimalCheeseCertAllInfo),
},
{
desc: "TLS with simple certificate, with some info",
certContents: []string{minimalCheeseCrt},
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
NotAfter: true,
Sans: true,
Subject: &config.TLSCLientCertificateDNInfo{
Organization: true,
},
Issuer: &config.TLSCLientCertificateDNInfo{
Country: true,
},
},
},
expectedHeader: url.QueryEscape(`Subject="O=Cheese",Issuer="C=FR,C=US",NA=1632568236,SAN=`),
},
{
desc: "TLS with complete certificate, with all info",
certContents: []string{completeCheeseCrt},
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
NotAfter: true,
NotBefore: true,
Sans: true,
Subject: &config.TLSCLientCertificateDNInfo{
Country: true,
Province: true,
Locality: true,
Organization: true,
CommonName: true,
SerialNumber: true,
DomainComponent: true,
},
Issuer: &config.TLSCLientCertificateDNInfo{
Country: true,
Province: true,
Locality: true,
Organization: true,
CommonName: true,
SerialNumber: true,
DomainComponent: true,
},
},
},
expectedHeader: url.QueryEscape(completeCertAllInfo),
},
{
desc: "TLS with 2 certificates, with all info",
certContents: []string{minimalCheeseCrt, completeCheeseCrt},
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
NotAfter: true,
NotBefore: true,
Sans: true,
Subject: &config.TLSCLientCertificateDNInfo{
Country: true,
Province: true,
Locality: true,
Organization: true,
CommonName: true,
SerialNumber: true,
DomainComponent: true,
},
Issuer: &config.TLSCLientCertificateDNInfo{
Country: true,
Province: true,
Locality: true,
Organization: true,
CommonName: true,
SerialNumber: true,
DomainComponent: true,
},
},
},
expectedHeader: url.QueryEscape(strings.Join([]string{minimalCheeseCertAllInfo, completeCertAllInfo}, ";")),
},
}
for _, test := range testCases {
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
require.NoError(t, err)
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
if test.certContents != nil && len(test.certContents) > 0 {
req.TLS = buildTLSWith(test.certContents)
}
tlsClientHeaders.ServeHTTP(res, req)
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
if test.expectedHeader != "" {
require.Equal(t, test.expectedHeader, req.Header.Get(xForwardedTLSClientCertInfo), "The request header should contain the cleaned certificate")
} else {
require.Empty(t, req.Header.Get(xForwardedTLSClientCertInfo))
}
require.Empty(t, res.Header().Get(xForwardedTLSClientCertInfo), "The response header should be always empty")
})
}
}

View file

@ -0,0 +1,54 @@
package ratelimiter
import (
"context"
"net/http"
"time"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/ratelimit"
"github.com/vulcand/oxy/utils"
)
const (
typeName = "RateLimiterType"
)
type rateLimiter struct {
handler http.Handler
name string
}
// New creates rate limiter middleware.
func New(ctx context.Context, next http.Handler, config config.RateLimit, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
extractFunc, err := utils.NewExtractor(config.ExtractorFunc)
if err != nil {
return nil, err
}
rateSet := ratelimit.NewRateSet()
for _, rate := range config.RateSet {
if err = rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
return nil, err
}
}
rl, err := ratelimit.New(next, extractFunc, rateSet)
if err != nil {
return nil, err
}
return &rateLimiter{handler: rl, name: name}, nil
}
func (r *rateLimiter) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
r.handler.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,40 @@
package recovery
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/middlewares"
"github.com/sirupsen/logrus"
)
const (
typeName = "Recovery"
)
type recovery struct {
next http.Handler
name string
}
// New creates recovery middleware.
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &recovery{
next: next,
name: name,
}, nil
}
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
defer recoverFunc(middlewares.GetLogger(req.Context(), re.name, typeName), rw)
re.next.ServeHTTP(rw, req)
}
func recoverFunc(logger logrus.FieldLogger, rw http.ResponseWriter) {
if err := recover(); err != nil {
logger.Errorf("Recovered from panic in http handler: %+v", err)
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}

View file

@ -0,0 +1,27 @@
package recovery
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRecoverHandler(t *testing.T) {
fn := func(w http.ResponseWriter, r *http.Request) {
panic("I love panicing!")
}
recovery, err := New(context.Background(), http.HandlerFunc(fn), "foo-recovery")
require.NoError(t, err)
server := httptest.NewServer(recovery)
defer server.Close()
resp, err := http.Get(server.URL)
require.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
}

View file

@ -0,0 +1,158 @@
package redirect
import (
"bytes"
"context"
"html/template"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/utils"
)
type redirect struct {
next http.Handler
regex *regexp.Regexp
replacement string
permanent bool
errHandler utils.ErrorHandler
name string
}
// New creates a Redirect middleware.
func newRedirect(_ context.Context, next http.Handler, regex string, replacement string, permanent bool, name string) (http.Handler, error) {
re, err := regexp.Compile(regex)
if err != nil {
return nil, err
}
return &redirect{
regex: re,
replacement: replacement,
permanent: permanent,
errHandler: utils.DefaultHandler,
next: next,
name: name,
}, nil
}
func (r *redirect) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
oldURL := rawURL(req)
// If the Regexp doesn't match, skip to the next handler
if !r.regex.MatchString(oldURL) {
r.next.ServeHTTP(rw, req)
return
}
// apply a rewrite regexp to the URL
newURL := r.regex.ReplaceAllString(oldURL, r.replacement)
// replace any variables that may be in there
rewrittenURL := &bytes.Buffer{}
if err := applyString(newURL, rewrittenURL, req); err != nil {
r.errHandler.ServeHTTP(rw, req, err)
return
}
// parse the rewritten URL and replace request URL with it
parsedURL, err := url.Parse(rewrittenURL.String())
if err != nil {
r.errHandler.ServeHTTP(rw, req, err)
return
}
if newURL != oldURL {
handler := &moveHandler{location: parsedURL, permanent: r.permanent}
handler.ServeHTTP(rw, req)
return
}
req.URL = parsedURL
// make sure the request URI corresponds the rewritten URL
req.RequestURI = req.URL.RequestURI()
r.next.ServeHTTP(rw, req)
}
type moveHandler struct {
location *url.URL
permanent bool
}
func (m *moveHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Location", m.location.String())
status := http.StatusFound
if req.Method != http.MethodGet {
status = http.StatusTemporaryRedirect
}
if m.permanent {
status = http.StatusMovedPermanently
if req.Method != http.MethodGet {
status = http.StatusPermanentRedirect
}
}
rw.WriteHeader(status)
_, err := rw.Write([]byte(http.StatusText(status)))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func rawURL(req *http.Request) string {
scheme := "http"
host := req.Host
port := ""
uri := req.RequestURI
schemeRegex := `^(https?):\/\/([\w\._-]+)(:\d+)?(.*)$`
re, _ := regexp.Compile(schemeRegex)
if re.Match([]byte(req.RequestURI)) {
match := re.FindStringSubmatch(req.RequestURI)
scheme = match[1]
if len(match[2]) > 0 {
host = match[2]
}
if len(match[3]) > 0 {
port = match[3]
}
uri = match[4]
}
if req.TLS != nil || isXForwardedHTTPS(req) {
scheme = "https"
}
return strings.Join([]string{scheme, "://", host, port, uri}, "")
}
func isXForwardedHTTPS(request *http.Request) bool {
xForwardedProto := request.Header.Get("X-Forwarded-Proto")
return len(xForwardedProto) > 0 && xForwardedProto == "https"
}
func applyString(in string, out io.Writer, req *http.Request) error {
t, err := template.New("t").Parse(in)
if err != nil {
return err
}
data := struct{ Request *http.Request }{Request: req}
return t.Execute(out, data)
}

View file

@ -0,0 +1,22 @@
package redirect
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
)
const (
typeRegexName = "RedirectRegex"
)
// NewRedirectRegex creates a redirect middleware.
func NewRedirectRegex(ctx context.Context, next http.Handler, conf config.RedirectRegex, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeRegexName)
logger.Debug("Creating middleware")
logger.Debugf("Setting up redirection from %s to %s", conf.Regex, conf.Replacement)
return newRedirect(ctx, next, conf.Regex, conf.Replacement, conf.Permanent, name)
}

View file

@ -0,0 +1,198 @@
package redirect
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRedirectRegexHandler(t *testing.T) {
testCases := []struct {
desc string
config config.RedirectRegex
method string
url string
secured bool
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "simple redirection",
config: config.RedirectRegex{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: "https://${1}bar$2:443$4",
},
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "use request header",
config: config.RedirectRegex{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
},
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "URL doesn't match regex",
config: config.RedirectRegex{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: "https://${1}bar$2:443$4",
},
url: "http://bar.com:80",
expectedStatus: http.StatusOK,
},
{
desc: "invalid rewritten URL",
config: config.RedirectRegex{
Regex: `^(.*)$`,
Replacement: "http://192.168.0.%31/",
},
url: "http://foo.com:80",
expectedStatus: http.StatusBadGateway,
},
{
desc: "invalid regex",
config: config.RedirectRegex{
Regex: `^(.*`,
Replacement: "$1",
},
url: "http://foo.com:80",
errorExpected: true,
},
{
desc: "HTTP to HTTPS permanent",
config: config.RedirectRegex{
Regex: `^http://`,
Replacement: "https://$1",
Permanent: true,
},
url: "http://foo",
expectedURL: "https://foo",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTPS to HTTP permanent",
config: config.RedirectRegex{
Regex: `https://foo`,
Replacement: "http://foo",
Permanent: true,
},
secured: true,
url: "https://foo",
expectedURL: "http://foo",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTP to HTTPS",
config: config.RedirectRegex{
Regex: `http://foo:80`,
Replacement: "https://foo:443",
},
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS to HTTP",
config: config.RedirectRegex{
Regex: `https://foo:443`,
Replacement: "http://foo:80",
},
secured: true,
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP",
config: config.RedirectRegex{
Regex: `http://foo:80`,
Replacement: "http://foo:88",
},
url: "http://foo:80",
expectedURL: "http://foo:88",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP POST",
config: config.RedirectRegex{
Regex: `^http://`,
Replacement: "https://$1",
},
url: "http://foo",
method: http.MethodPost,
expectedURL: "https://foo",
expectedStatus: http.StatusTemporaryRedirect,
},
{
desc: "HTTP to HTTP POST permanent",
config: config.RedirectRegex{
Regex: `^http://`,
Replacement: "https://$1",
Permanent: true,
},
url: "http://foo",
method: http.MethodPost,
expectedURL: "https://foo",
expectedStatus: http.StatusPermanentRedirect,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
handler, err := NewRedirectRegex(context.Background(), next, test.config, "traefikTest")
if test.errorExpected {
require.Error(t, err)
require.Nil(t, handler)
} else {
require.NoError(t, err)
require.NotNil(t, handler)
recorder := httptest.NewRecorder()
method := http.MethodGet
if test.method != "" {
method = test.method
}
r := testhelpers.MustNewRequest(method, test.url, nil)
if test.secured {
r.TLS = &tls.ConnectionState{}
}
r.Header.Set("X-Foo", "bar")
handler.ServeHTTP(recorder, r)
assert.Equal(t, test.expectedStatus, recorder.Code)
if test.expectedStatus == http.StatusMovedPermanently ||
test.expectedStatus == http.StatusFound ||
test.expectedStatus == http.StatusTemporaryRedirect ||
test.expectedStatus == http.StatusPermanentRedirect {
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
} else {
location, err := recorder.Result().Location()
require.Errorf(t, err, "Location %v", location)
}
}
})
}
}

View file

@ -0,0 +1,34 @@
package redirect
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/middlewares"
"github.com/pkg/errors"
"github.com/containous/traefik/pkg/config"
)
const (
typeSchemeName = "RedirectScheme"
schemeRedirectRegex = `^(https?:\/\/)?([\w\._-]+)(:\d+)?(.*)$`
)
// NewRedirectScheme creates a new RedirectScheme middleware.
func NewRedirectScheme(ctx context.Context, next http.Handler, conf config.RedirectScheme, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeSchemeName)
logger.Debug("Creating middleware")
logger.Debugf("Setting up redirection to %s %s", conf.Scheme, conf.Port)
if len(conf.Scheme) == 0 {
return nil, errors.New("you must provide a target scheme")
}
port := ""
if len(conf.Port) > 0 && !(conf.Scheme == "http" && conf.Port == "80" || conf.Scheme == "https" && conf.Port == "443") {
port = ":" + conf.Port
}
return newRedirect(ctx, next, schemeRedirectRegex, conf.Scheme+"://${2}"+port+"${4}", conf.Permanent, name)
}

View file

@ -0,0 +1,250 @@
package redirect
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"regexp"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRedirectSchemeHandler(t *testing.T) {
testCases := []struct {
desc string
config config.RedirectScheme
method string
url string
secured bool
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "Without scheme",
config: config.RedirectScheme{},
url: "http://foo",
errorExpected: true,
},
{
desc: "HTTP to HTTPS",
config: config.RedirectScheme{
Scheme: "https",
},
url: "http://foo",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP with port to HTTPS without port",
config: config.RedirectScheme{
Scheme: "https",
},
url: "http://foo:8080",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP without port to HTTPS with port",
config: config.RedirectScheme{
Scheme: "https",
Port: "8443",
},
url: "http://foo",
expectedURL: "https://foo:8443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP with port to HTTPS with port",
config: config.RedirectScheme{
Scheme: "https",
Port: "8443",
},
url: "http://foo:8000",
expectedURL: "https://foo:8443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS with port to HTTPS with port",
config: config.RedirectScheme{
Scheme: "https",
Port: "8443",
},
url: "https://foo:8000",
expectedURL: "https://foo:8443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS with port to HTTPS without port",
config: config.RedirectScheme{
Scheme: "https",
},
url: "https://foo:8000",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
{
desc: "redirection to HTTPS without port from an URL already in https",
config: config.RedirectScheme{
Scheme: "https",
},
url: "https://foo:8000/theother",
expectedURL: "https://foo/theother",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTPS permanent",
config: config.RedirectScheme{
Scheme: "https",
Port: "8443",
Permanent: true,
},
url: "http://foo",
expectedURL: "https://foo:8443",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "to HTTP 80",
config: config.RedirectScheme{
Scheme: "http",
Port: "80",
},
url: "http://foo:80",
expectedURL: "http://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to wss",
config: config.RedirectScheme{
Scheme: "wss",
Port: "9443",
},
url: "http://foo",
expectedURL: "wss://foo:9443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to wss without port",
config: config.RedirectScheme{
Scheme: "wss",
},
url: "http://foo",
expectedURL: "wss://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP with port to wss without port",
config: config.RedirectScheme{
Scheme: "wss",
},
url: "http://foo:5678",
expectedURL: "wss://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTPS without port",
config: config.RedirectScheme{
Scheme: "https",
},
url: "http://foo:443",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP port redirection",
config: config.RedirectScheme{
Scheme: "http",
Port: "8181",
},
url: "http://foo:8080",
expectedURL: "http://foo:8181",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS with port 80 to HTTPS without port",
config: config.RedirectScheme{
Scheme: "https",
},
url: "https://foo:80",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
handler, err := NewRedirectScheme(context.Background(), next, test.config, "traefikTest")
if test.errorExpected {
require.Error(t, err)
require.Nil(t, handler)
} else {
require.NoError(t, err)
require.NotNil(t, handler)
recorder := httptest.NewRecorder()
method := http.MethodGet
if test.method != "" {
method = test.method
}
r := httptest.NewRequest(method, test.url, nil)
if test.secured {
r.TLS = &tls.ConnectionState{}
}
r.Header.Set("X-Foo", "bar")
handler.ServeHTTP(recorder, r)
assert.Equal(t, test.expectedStatus, recorder.Code)
if test.expectedStatus == http.StatusMovedPermanently ||
test.expectedStatus == http.StatusFound ||
test.expectedStatus == http.StatusTemporaryRedirect ||
test.expectedStatus == http.StatusPermanentRedirect {
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
} else {
location, err := recorder.Result().Location()
require.Errorf(t, err, "Location %v", location)
}
schemeRegex := `^(https?):\/\/([\w\._-]+)(:\d+)?(.*)$`
re, _ := regexp.Compile(schemeRegex)
if re.Match([]byte(test.url)) {
match := re.FindStringSubmatch(test.url)
r.RequestURI = match[4]
handler.ServeHTTP(recorder, r)
assert.Equal(t, test.expectedStatus, recorder.Code)
if test.expectedStatus == http.StatusMovedPermanently ||
test.expectedStatus == http.StatusFound ||
test.expectedStatus == http.StatusTemporaryRedirect ||
test.expectedStatus == http.StatusPermanentRedirect {
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
} else {
location, err := recorder.Result().Location()
require.Errorf(t, err, "Location %v", location)
}
}
}
})
}
}

View file

@ -0,0 +1,46 @@
package replacepath
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
// ReplacedPathHeader is the default header to set the old path to.
ReplacedPathHeader = "X-Replaced-Path"
typeName = "ReplacePath"
)
// ReplacePath is a middleware used to replace the path of a URL request.
type replacePath struct {
next http.Handler
path string
name string
}
// New creates a new replace path middleware.
func New(ctx context.Context, next http.Handler, config config.ReplacePath, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &replacePath{
next: next,
path: config.Path,
name: name,
}, nil
}
func (r *replacePath) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *replacePath) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
req.Header.Add(ReplacedPathHeader, req.URL.Path)
req.URL.Path = r.path
req.RequestURI = req.URL.RequestURI()
r.next.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,46 @@
package replacepath
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReplacePath(t *testing.T) {
var replacementConfig = config.ReplacePath{
Path: "/replacement-path",
}
paths := []string{
"/example",
"/some/really/long/path",
}
for _, path := range paths {
t.Run(path, func(t *testing.T) {
var expectedPath, actualHeader, requestURI string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath = r.URL.Path
actualHeader = r.Header.Get(ReplacedPathHeader)
requestURI = r.RequestURI
})
handler, err := New(context.Background(), next, replacementConfig, "foo-replace-path")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+path, nil)
handler.ServeHTTP(nil, req)
assert.Equal(t, expectedPath, replacementConfig.Path, "Unexpected path.")
assert.Equal(t, path, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
assert.Equal(t, expectedPath, requestURI, "Unexpected request URI.")
})
}
}

View file

@ -0,0 +1,57 @@
package replacepathregex
import (
"context"
"fmt"
"net/http"
"regexp"
"strings"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/replacepath"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
typeName = "ReplacePathRegex"
)
// ReplacePathRegex is a middleware used to replace the path of a URL request with a regular expression.
type replacePathRegex struct {
next http.Handler
regexp *regexp.Regexp
replacement string
name string
}
// New creates a new replace path regex middleware.
func New(ctx context.Context, next http.Handler, config config.ReplacePathRegex, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
exp, err := regexp.Compile(strings.TrimSpace(config.Regex))
if err != nil {
return nil, fmt.Errorf("error compiling regular expression %s: %s", config.Regex, err)
}
return &replacePathRegex{
regexp: exp,
replacement: strings.TrimSpace(config.Replacement),
next: next,
name: name,
}, nil
}
func (rp *replacePathRegex) GetTracingInformation() (string, ext.SpanKindEnum) {
return rp.name, tracing.SpanKindNoneEnum
}
func (rp *replacePathRegex) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if rp.regexp != nil && len(rp.replacement) > 0 && rp.regexp.MatchString(req.URL.Path) {
req.Header.Add(replacepath.ReplacedPathHeader, req.URL.Path)
req.URL.Path = rp.regexp.ReplaceAllString(req.URL.Path, rp.replacement)
req.RequestURI = req.URL.RequestURI()
}
rp.next.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,104 @@
package replacepathregex
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares/replacepath"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReplacePathRegex(t *testing.T) {
testCases := []struct {
desc string
path string
config config.ReplacePathRegex
expectedPath string
expectedHeader string
expectsError bool
}{
{
desc: "simple regex",
path: "/whoami/and/whoami",
config: config.ReplacePathRegex{
Replacement: "/who-am-i/$1",
Regex: `^/whoami/(.*)`,
},
expectedPath: "/who-am-i/and/whoami",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "simple replace (no regex)",
path: "/whoami/and/whoami",
config: config.ReplacePathRegex{
Replacement: "/who-am-i",
Regex: `/whoami`,
},
expectedPath: "/who-am-i/and/who-am-i",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "no match",
path: "/whoami/and/whoami",
config: config.ReplacePathRegex{
Replacement: "/whoami",
Regex: `/no-match`,
},
expectedPath: "/whoami/and/whoami",
},
{
desc: "multiple replacement",
path: "/downloads/src/source.go",
config: config.ReplacePathRegex{
Replacement: "/downloads/$1-$2",
Regex: `^(?i)/downloads/([^/]+)/([^/]+)$`,
},
expectedPath: "/downloads/src-source.go",
expectedHeader: "/downloads/src/source.go",
},
{
desc: "invalid regular expression",
path: "/invalid/regexp/test",
config: config.ReplacePathRegex{
Replacement: "/valid/regexp/$1",
Regex: `^(?err)/invalid/regexp/([^/]+)$`,
},
expectedPath: "/invalid/regexp/test",
expectsError: true,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
var actualPath, actualHeader, requestURI string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualHeader = r.Header.Get(replacepath.ReplacedPathHeader)
requestURI = r.RequestURI
})
handler, err := New(context.Background(), next, test.config, "foo-replace-path-regexp")
if test.expectsError {
require.Error(t, err)
} else {
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
req.RequestURI = test.path
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, actualPath, requestURI, "Unexpected request URI.")
if test.expectedHeader != "" {
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", replacepath.ReplacedPathHeader)
}
}
})
}
}

View file

@ -0,0 +1,124 @@
package requestdecorator
import (
"context"
"fmt"
"net"
"sort"
"strings"
"time"
"github.com/containous/traefik/pkg/log"
"github.com/miekg/dns"
"github.com/patrickmn/go-cache"
)
type cnameResolv struct {
TTL time.Duration
Record string
}
type byTTL []*cnameResolv
func (a byTTL) Len() int { return len(a) }
func (a byTTL) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byTTL) Less(i, j int) bool { return a[i].TTL > a[j].TTL }
// Resolver used for host resolver.
type Resolver struct {
CnameFlattening bool
ResolvConfig string
ResolvDepth int
cache *cache.Cache
}
// CNAMEFlatten check if CNAME record exists, flatten if possible.
func (hr *Resolver) CNAMEFlatten(ctx context.Context, host string) string {
if hr.cache == nil {
hr.cache = cache.New(30*time.Minute, 5*time.Minute)
}
result := host
request := host
value, found := hr.cache.Get(host)
if found {
return value.(string)
}
logger := log.FromContext(ctx)
var cacheDuration = 0 * time.Second
for depth := 0; depth < hr.ResolvDepth; depth++ {
resolv, err := cnameResolve(ctx, request, hr.ResolvConfig)
if err != nil {
logger.Error(err)
break
}
if resolv == nil {
break
}
result = resolv.Record
if depth == 0 {
cacheDuration = resolv.TTL
}
request = resolv.Record
}
if err := hr.cache.Add(host, result, cacheDuration); err != nil {
logger.Error(err)
}
return result
}
// cnameResolve resolves CNAME if exists, and return with the highest TTL.
func cnameResolve(ctx context.Context, host string, resolvPath string) (*cnameResolv, error) {
config, err := dns.ClientConfigFromFile(resolvPath)
if err != nil {
return nil, fmt.Errorf("invalid resolver configuration file: %s", resolvPath)
}
client := &dns.Client{Timeout: 30 * time.Second}
m := &dns.Msg{}
m.SetQuestion(dns.Fqdn(host), dns.TypeCNAME)
var result []*cnameResolv
for _, server := range config.Servers {
tempRecord, err := getRecord(client, m, server, config.Port)
if err != nil {
log.FromContext(ctx).Errorf("Failed to resolve host %s: %v", host, err)
continue
}
result = append(result, tempRecord)
}
if len(result) == 0 {
return nil, nil
}
sort.Sort(byTTL(result))
return result[0], nil
}
func getRecord(client *dns.Client, msg *dns.Msg, server string, port string) (*cnameResolv, error) {
resp, _, err := client.Exchange(msg, net.JoinHostPort(server, port))
if err != nil {
return nil, fmt.Errorf("exchange error for server %s: %v", server, err)
}
if resp == nil || len(resp.Answer) == 0 {
return nil, fmt.Errorf("empty answer for server %s", server)
}
rr, ok := resp.Answer[0].(*dns.CNAME)
if !ok {
return nil, fmt.Errorf("invalid response type for server %s", server)
}
return &cnameResolv{
TTL: time.Duration(rr.Hdr.Ttl) * time.Second,
Record: strings.TrimSuffix(rr.Target, "."),
}, nil
}

View file

@ -0,0 +1,51 @@
package requestdecorator
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCNAMEFlatten(t *testing.T) {
testCases := []struct {
desc string
resolvFile string
domain string
expectedDomain string
}{
{
desc: "host request is CNAME record",
resolvFile: "/etc/resolv.conf",
domain: "www.github.com",
expectedDomain: "github.com",
},
{
desc: "resolve file not found",
resolvFile: "/etc/resolv.oops",
domain: "www.github.com",
expectedDomain: "www.github.com",
},
{
desc: "host request is not CNAME record",
resolvFile: "/etc/resolv.conf",
domain: "github.com",
expectedDomain: "github.com",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
hostResolver := &Resolver{
ResolvConfig: test.resolvFile,
ResolvDepth: 5,
}
flatH := hostResolver.CNAMEFlatten(context.Background(), test.domain)
assert.Equal(t, test.expectedDomain, flatH)
})
}
}

View file

@ -0,0 +1,87 @@
package requestdecorator
import (
"context"
"net"
"net/http"
"strings"
"github.com/containous/alice"
"github.com/containous/traefik/pkg/types"
)
const (
canonicalKey key = "canonical"
flattenKey key = "flatten"
)
type key string
// RequestDecorator is the struct for the middleware that adds the CanonicalDomain of the request Host into a context for later use.
type RequestDecorator struct {
hostResolver *Resolver
}
// New creates a new request host middleware.
func New(hostResolverConfig *types.HostResolverConfig) *RequestDecorator {
requestDecorator := &RequestDecorator{}
if hostResolverConfig != nil {
requestDecorator.hostResolver = &Resolver{
CnameFlattening: hostResolverConfig.CnameFlattening,
ResolvConfig: hostResolverConfig.ResolvConfig,
ResolvDepth: hostResolverConfig.ResolvDepth,
}
}
return requestDecorator
}
func (r *RequestDecorator) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
host := types.CanonicalDomain(parseHost(req.Host))
reqt := req.WithContext(context.WithValue(req.Context(), canonicalKey, host))
if r.hostResolver != nil && r.hostResolver.CnameFlattening {
flatHost := r.hostResolver.CNAMEFlatten(reqt.Context(), host)
reqt = reqt.WithContext(context.WithValue(reqt.Context(), flattenKey, flatHost))
}
next(rw, reqt)
}
func parseHost(addr string) string {
if !strings.Contains(addr, ":") {
return addr
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
return host
}
// GetCanonizedHost retrieves the canonized host from the given context (previously stored in the request context by the middleware).
func GetCanonizedHost(ctx context.Context) string {
if val, ok := ctx.Value(canonicalKey).(string); ok {
return val
}
return ""
}
// GetCNAMEFlatten return the flat name if it is present in the context.
func GetCNAMEFlatten(ctx context.Context) string {
if val, ok := ctx.Value(flattenKey).(string); ok {
return val
}
return ""
}
// WrapHandler Wraps a ServeHTTP with next to an alice.Constructor.
func WrapHandler(handler *RequestDecorator) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(rw, req, next.ServeHTTP)
}), nil
}
}

View file

@ -0,0 +1,146 @@
package requestdecorator
import (
"net/http"
"testing"
"github.com/containous/traefik/pkg/types"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestRequestHost(t *testing.T) {
testCases := []struct {
desc string
url string
expected string
}{
{
desc: "host without :",
url: "http://host",
expected: "host",
},
{
desc: "host with : and without port",
url: "http://host:",
expected: "host",
},
{
desc: "IP host with : and with port",
url: "http://127.0.0.1:123",
expected: "127.0.0.1",
},
{
desc: "IP host with : and without port",
url: "http://127.0.0.1:",
expected: "127.0.0.1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
host := GetCanonizedHost(r.Context())
assert.Equal(t, test.expected, host)
})
rh := New(nil)
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
rh.ServeHTTP(nil, req, next)
})
}
}
func TestRequestFlattening(t *testing.T) {
testCases := []struct {
desc string
url string
expected string
}{
{
desc: "host with flattening",
url: "http://www.github.com",
expected: "github.com",
},
{
desc: "host without flattening",
url: "http://github.com",
expected: "github.com",
},
{
desc: "ip without flattening",
url: "http://127.0.0.1",
expected: "127.0.0.1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
host := GetCNAMEFlatten(r.Context())
assert.Equal(t, test.expected, host)
})
rh := New(
&types.HostResolverConfig{
CnameFlattening: true,
ResolvConfig: "/etc/resolv.conf",
ResolvDepth: 5,
},
)
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
rh.ServeHTTP(nil, req, next)
})
}
}
func TestRequestHostParseHost(t *testing.T) {
testCases := []struct {
desc string
host string
expected string
}{
{
desc: "host without :",
host: "host",
expected: "host",
},
{
desc: "host with : and without port",
host: "host:",
expected: "host",
},
{
desc: "IP host with : and with port",
host: "127.0.0.1:123",
expected: "127.0.0.1",
},
{
desc: "IP host with : and without port",
host: "127.0.0.1:",
expected: "127.0.0.1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := parseHost(test.host)
assert.Equal(t, test.expected, actual)
})
}
}

View file

@ -0,0 +1,207 @@
package retry
import (
"bufio"
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
// Compile time validation that the response writer implements http interfaces correctly.
var _ middlewares.Stateful = &responseWriterWithCloseNotify{}
const (
typeName = "Retry"
)
// Listener is used to inform about retry attempts.
type Listener interface {
// Retried will be called when a retry happens, with the request attempt passed to it.
// For the first retry this will be attempt 2.
Retried(req *http.Request, attempt int)
}
// Listeners is a convenience type to construct a list of Listener and notify
// each of them about a retry attempt.
type Listeners []Listener
// retry is a middleware that retries requests.
type retry struct {
attempts int
next http.Handler
listener Listener
name string
}
// New returns a new retry middleware.
func New(ctx context.Context, next http.Handler, config config.Retry, listener Listener, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
if config.Attempts <= 0 {
return nil, fmt.Errorf("incorrect (or empty) value for attempt (%d)", config.Attempts)
}
return &retry{
attempts: config.Attempts,
next: next,
listener: listener,
name: name,
}, nil
}
func (r *retry) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// if we might make multiple attempts, swap the body for an ioutil.NopCloser
// cf https://github.com/containous/traefik/issues/1008
if r.attempts > 1 {
body := req.Body
defer body.Close()
req.Body = ioutil.NopCloser(body)
}
attempts := 1
for {
shouldRetry := attempts < r.attempts
retryResponseWriter := newResponseWriter(rw, shouldRetry)
// Disable retries when the backend already received request data
trace := &httptrace.ClientTrace{
WroteHeaders: func() {
retryResponseWriter.DisableRetries()
},
WroteRequest: func(httptrace.WroteRequestInfo) {
retryResponseWriter.DisableRetries()
},
}
newCtx := httptrace.WithClientTrace(req.Context(), trace)
r.next.ServeHTTP(retryResponseWriter, req.WithContext(newCtx))
if !retryResponseWriter.ShouldRetry() {
break
}
attempts++
logger := middlewares.GetLogger(req.Context(), r.name, typeName)
logger.Debugf("New attempt %d for request: %v", attempts, req.URL)
r.listener.Retried(req, attempts)
}
}
// Retried exists to implement the Listener interface. It calls Retried on each of its slice entries.
func (l Listeners) Retried(req *http.Request, attempt int) {
for _, listener := range l {
listener.Retried(req, attempt)
}
}
type responseWriter interface {
http.ResponseWriter
http.Flusher
ShouldRetry() bool
DisableRetries()
}
func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter {
responseWriter := &responseWriterWithoutCloseNotify{
responseWriter: rw,
headers: make(http.Header),
shouldRetry: shouldRetry,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &responseWriterWithCloseNotify{
responseWriterWithoutCloseNotify: responseWriter,
}
}
return responseWriter
}
type responseWriterWithoutCloseNotify struct {
responseWriter http.ResponseWriter
headers http.Header
shouldRetry bool
written bool
}
func (r *responseWriterWithoutCloseNotify) ShouldRetry() bool {
return r.shouldRetry
}
func (r *responseWriterWithoutCloseNotify) DisableRetries() {
r.shouldRetry = false
}
func (r *responseWriterWithoutCloseNotify) Header() http.Header {
if r.written {
return r.responseWriter.Header()
}
return r.headers
}
func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
if r.ShouldRetry() {
return len(buf), nil
}
return r.responseWriter.Write(buf)
}
func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) {
if r.ShouldRetry() && code == http.StatusServiceUnavailable {
// We get a 503 HTTP Status Code when there is no backend server in the pool
// to which the request could be sent. Also, note that r.ShouldRetry()
// will never return true in case there was a connection established to
// the backend server and so we can be sure that the 503 was produced
// inside Traefik already and we don't have to retry in this cases.
r.DisableRetries()
}
if r.ShouldRetry() {
return
}
// In that case retry case is set to false which means we at least managed
// to write headers to the backend : we are not going to perform any further retry.
// So it is now safe to alter current response headers with headers collected during
// the latest try before writing headers to client.
headers := r.responseWriter.Header()
for header, value := range r.headers {
headers[header] = value
}
r.responseWriter.WriteHeader(code)
r.written = true
}
func (r *responseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := r.responseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", r.responseWriter)
}
return hijacker.Hijack()
}
func (r *responseWriterWithoutCloseNotify) Flush() {
if flusher, ok := r.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
type responseWriterWithCloseNotify struct {
*responseWriterWithoutCloseNotify
}
func (r *responseWriterWithCloseNotify) CloseNotify() <-chan bool {
return r.responseWriter.(http.CloseNotifier).CloseNotify()
}

View file

@ -0,0 +1,312 @@
package retry
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"strings"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares/emptybackendhandler"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/roundrobin"
)
func TestRetry(t *testing.T) {
testCases := []struct {
desc string
config config.Retry
wantRetryAttempts int
wantResponseStatus int
amountFaultyEndpoints int
}{
{
desc: "no retry on success",
config: config.Retry{Attempts: 1},
wantRetryAttempts: 0,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 0,
},
{
desc: "no retry when max request attempts is one",
config: config.Retry{Attempts: 1},
wantRetryAttempts: 0,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 1,
},
{
desc: "one retry when one server is faulty",
config: config.Retry{Attempts: 2},
wantRetryAttempts: 1,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 1,
},
{
desc: "two retries when two servers are faulty",
config: config.Retry{Attempts: 3},
wantRetryAttempts: 2,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 2,
},
{
desc: "max attempts exhausted delivers the 5xx response",
config: config.Retry{Attempts: 3},
wantRetryAttempts: 2,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 3,
},
}
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
_, err := rw.Write([]byte("OK"))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}))
forwarder, err := forward.New()
require.NoError(t, err)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
loadBalancer, err := roundrobin.New(forwarder)
require.NoError(t, err)
basePort := 33444
for i := 0; i < test.amountFaultyEndpoints; i++ {
// 192.0.2.0 is a non-routable IP for testing purposes.
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
err = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
require.NoError(t, err)
}
// add the functioning server to the end of the load balancer list
err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
require.NoError(t, err)
retryListener := &countingRetryListener{}
retry, err := New(context.Background(), loadBalancer, test.config, retryListener, "traefikTest")
require.NoError(t, err)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
retry.ServeHTTP(recorder, req)
assert.Equal(t, test.wantResponseStatus, recorder.Code)
assert.Equal(t, test.wantRetryAttempts, retryListener.timesCalled)
})
}
}
func TestRetryEmptyServerList(t *testing.T) {
forwarder, err := forward.New()
require.NoError(t, err)
loadBalancer, err := roundrobin.New(forwarder)
require.NoError(t, err)
// The EmptyBackend middleware ensures that there is a 503
// response status set when there is no backend server in the pool.
next := emptybackendhandler.New(loadBalancer)
retryListener := &countingRetryListener{}
retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, retryListener, "traefikTest")
require.NoError(t, err)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
retry.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
assert.Equal(t, 0, retryListener.timesCalled)
}
func TestRetryListeners(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
retryListeners := Listeners{&countingRetryListener{}, &countingRetryListener{}}
retryListeners.Retried(req, 1)
retryListeners.Retried(req, 1)
for _, retryListener := range retryListeners {
listener := retryListener.(*countingRetryListener)
if listener.timesCalled != 2 {
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
}
}
}
func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) {
attempt := 0
expectedHeaderName := "X-Foo-Test-2"
expectedHeaderValue := "bar"
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
headerName := fmt.Sprintf("X-Foo-Test-%d", attempt)
rw.Header().Add(headerName, expectedHeaderValue)
if attempt < 2 {
attempt++
return
}
// Request has been successfully written to backend
trace := httptrace.ContextClientTrace(req.Context())
trace.WroteHeaders()
// And we decide to answer to client
rw.WriteHeader(http.StatusNoContent)
})
retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest")
require.NoError(t, err)
responseRecorder := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody))
headerValue := responseRecorder.Header().Get(expectedHeaderName)
// Validate if we have the correct header
if headerValue != expectedHeaderValue {
t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue)
}
// Validate that we don't have headers from previous attempts
for i := 0; i < attempt; i++ {
headerName := fmt.Sprintf("X-Foo-Test-%d", i)
headerValue = responseRecorder.Header().Get("headerName")
if headerValue != "" {
t.Errorf("Expected no value for header %s, got %s", headerName, headerValue)
}
}
}
// countingRetryListener is a Listener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
timesCalled int
}
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
l.timesCalled++
}
func TestRetryWithFlush(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(200)
_, err := rw.Write([]byte("FULL "))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
rw.(http.Flusher).Flush()
_, err = rw.Write([]byte("DATA"))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
retry, err := New(context.Background(), next, config.Retry{Attempts: 1}, &countingRetryListener{}, "traefikTest")
require.NoError(t, err)
responseRecorder := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, &http.Request{})
assert.Equal(t, "FULL DATA", responseRecorder.Body.String())
}
func TestRetryWebsocket(t *testing.T) {
testCases := []struct {
desc string
maxRequestAttempts int
expectedRetryAttempts int
expectedResponseStatus int
expectedError bool
amountFaultyEndpoints int
}{
{
desc: "Switching ok after 2 retries",
maxRequestAttempts: 3,
expectedRetryAttempts: 2,
amountFaultyEndpoints: 2,
expectedResponseStatus: http.StatusSwitchingProtocols,
},
{
desc: "Switching failed",
maxRequestAttempts: 2,
expectedRetryAttempts: 1,
amountFaultyEndpoints: 2,
expectedResponseStatus: http.StatusBadGateway,
expectedError: true,
},
}
forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %v", err)
}
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
upgrader := websocket.Upgrader{}
_, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}))
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
loadBalancer, err := roundrobin.New(forwarder)
if err != nil {
t.Fatalf("Error creating load balancer: %v", err)
}
basePort := 33444
for i := 0; i < test.amountFaultyEndpoints; i++ {
// 192.0.2.0 is a non-routable IP for testing purposes.
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
_ = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
}
// add the functioning server to the end of the load balancer list
err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
if err != nil {
t.Fatalf("Fail to upsert server: %v", err)
}
retryListener := &countingRetryListener{}
retryH, err := New(context.Background(), loadBalancer, config.Retry{Attempts: test.maxRequestAttempts}, retryListener, "traefikTest")
require.NoError(t, err)
retryServer := httptest.NewServer(retryH)
url := strings.Replace(retryServer.URL, "http", "ws", 1)
_, response, err := websocket.DefaultDialer.Dial(url, nil)
if !test.expectedError {
require.NoError(t, err)
}
assert.Equal(t, test.expectedResponseStatus, response.StatusCode)
assert.Equal(t, test.expectedRetryAttempts, retryListener.timesCalled)
})
}
}

View file

@ -0,0 +1,12 @@
package middlewares
import "net/http"
// Stateful interface groups all http interfaces that must be
// implemented by a stateful middleware (ie: recorders)
type Stateful interface {
http.ResponseWriter
http.Hijacker
http.Flusher
http.CloseNotifier
}

View file

@ -0,0 +1,67 @@
package stripprefix
import (
"context"
"net/http"
"strings"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
// ForwardedPrefixHeader is the default header to set prefix.
ForwardedPrefixHeader = "X-Forwarded-Prefix"
typeName = "StripPrefix"
)
// stripPrefix is a middleware used to strip prefix from an URL request.
type stripPrefix struct {
next http.Handler
prefixes []string
name string
}
// New creates a new strip prefix middleware.
func New(ctx context.Context, next http.Handler, config config.StripPrefix, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &stripPrefix{
prefixes: config.Prefixes,
next: next,
name: name,
}, nil
}
func (s *stripPrefix) GetTracingInformation() (string, ext.SpanKindEnum) {
return s.name, tracing.SpanKindNoneEnum
}
func (s *stripPrefix) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
for _, prefix := range s.prefixes {
if strings.HasPrefix(req.URL.Path, prefix) {
req.URL.Path = getPrefixStripped(req.URL.Path, prefix)
if req.URL.RawPath != "" {
req.URL.RawPath = getPrefixStripped(req.URL.RawPath, prefix)
}
s.serveRequest(rw, req, strings.TrimSpace(prefix))
return
}
}
http.NotFound(rw, req)
}
func (s *stripPrefix) serveRequest(rw http.ResponseWriter, req *http.Request, prefix string) {
req.Header.Add(ForwardedPrefixHeader, prefix)
req.RequestURI = req.URL.RequestURI()
s.next.ServeHTTP(rw, req)
}
func getPrefixStripped(s, prefix string) string {
return ensureLeadingSlash(strings.TrimPrefix(s, prefix))
}
func ensureLeadingSlash(str string) string {
return "/" + strings.TrimPrefix(str, "/")
}

View file

@ -0,0 +1,169 @@
package stripprefix
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStripPrefix(t *testing.T) {
testCases := []struct {
desc string
config config.StripPrefix
path string
expectedStatusCode int
expectedPath string
expectedRawPath string
expectedHeader string
}{
{
desc: "no prefixes configured",
config: config.StripPrefix{
Prefixes: []string{},
},
path: "/noprefixes",
expectedStatusCode: http.StatusNotFound,
},
{
desc: "wildcard (.*) requests",
config: config.StripPrefix{
Prefixes: []string{"/"},
},
path: "/",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/",
},
{
desc: "prefix and path matching",
config: config.StripPrefix{
Prefixes: []string{"/stat"},
},
path: "/stat",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat",
},
{
desc: "path prefix on exactly matching path",
config: config.StripPrefix{
Prefixes: []string{"/stat/"},
},
path: "/stat/",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat/",
},
{
desc: "path prefix on matching longer path",
config: config.StripPrefix{
Prefixes: []string{"/stat/"},
},
path: "/stat/us",
expectedStatusCode: http.StatusOK,
expectedPath: "/us",
expectedHeader: "/stat/",
},
{
desc: "path prefix on mismatching path",
config: config.StripPrefix{
Prefixes: []string{"/stat/"},
},
path: "/status",
expectedStatusCode: http.StatusNotFound,
},
{
desc: "general prefix on matching path",
config: config.StripPrefix{
Prefixes: []string{"/stat"},
},
path: "/stat/",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat",
},
{
desc: "earlier prefix matching",
config: config.StripPrefix{
Prefixes: []string{"/stat", "/stat/us"},
},
path: "/stat/us",
expectedStatusCode: http.StatusOK,
expectedPath: "/us",
expectedHeader: "/stat",
},
{
desc: "later prefix matching",
config: config.StripPrefix{
Prefixes: []string{"/mismatch", "/stat"},
},
path: "/stat",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat",
},
{
desc: "prefix matching within slash boundaries",
config: config.StripPrefix{
Prefixes: []string{"/stat"},
},
path: "/status",
expectedStatusCode: http.StatusOK,
expectedPath: "/us",
expectedHeader: "/stat",
},
{
desc: "raw path is also stripped",
config: config.StripPrefix{
Prefixes: []string{"/stat"},
},
path: "/stat/a%2Fb",
expectedStatusCode: http.StatusOK,
expectedPath: "/a/b",
expectedRawPath: "/a%2Fb",
expectedHeader: "/stat",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, actualHeader, requestURI string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
actualHeader = r.Header.Get(ForwardedPrefixHeader)
requestURI = r.RequestURI
})
handler, err := New(context.Background(), next, test.config, "foo-strip-prefix")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
resp := &httptest.ResponseRecorder{Code: http.StatusOK}
handler.ServeHTTP(resp, req)
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader)
expectedURI := test.expectedPath
if test.expectedRawPath != "" {
// go HTTP uses the raw path when existent in the RequestURI
expectedURI = test.expectedRawPath
}
assert.Equal(t, expectedURI, requestURI, "Unexpected request URI.")
})
}
}

View file

@ -0,0 +1,79 @@
package stripprefixregex
import (
"context"
"net/http"
"strings"
"github.com/containous/mux"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/stripprefix"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
typeName = "StripPrefixRegex"
)
// StripPrefixRegex is a middleware used to strip prefix from an URL request.
type stripPrefixRegex struct {
next http.Handler
router *mux.Router
name string
}
// New builds a new StripPrefixRegex middleware.
func New(ctx context.Context, next http.Handler, config config.StripPrefixRegex, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
stripPrefix := stripPrefixRegex{
next: next,
router: mux.NewRouter(),
name: name,
}
for _, prefix := range config.Regex {
stripPrefix.router.PathPrefix(prefix)
}
return &stripPrefix, nil
}
func (s *stripPrefixRegex) GetTracingInformation() (string, ext.SpanKindEnum) {
return s.name, tracing.SpanKindNoneEnum
}
func (s *stripPrefixRegex) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
var match mux.RouteMatch
if s.router.Match(req, &match) {
params := make([]string, 0, len(match.Vars)*2)
for key, val := range match.Vars {
params = append(params, key)
params = append(params, val)
}
prefix, err := match.Route.URL(params...)
if err != nil || len(prefix.Path) > len(req.URL.Path) {
logger := middlewares.GetLogger(req.Context(), s.name, typeName)
logger.Error("Error in stripPrefix middleware", err)
return
}
req.URL.Path = req.URL.Path[len(prefix.Path):]
if req.URL.RawPath != "" {
req.URL.RawPath = req.URL.RawPath[len(prefix.Path):]
}
req.Header.Add(stripprefix.ForwardedPrefixHeader, prefix.Path)
req.RequestURI = ensureLeadingSlash(req.URL.RequestURI())
s.next.ServeHTTP(rw, req)
return
}
http.NotFound(rw, req)
}
func ensureLeadingSlash(str string) string {
return "/" + strings.TrimPrefix(str, "/")
}

View file

@ -0,0 +1,104 @@
package stripprefixregex
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares/stripprefix"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStripPrefixRegex(t *testing.T) {
testPrefixRegex := config.StripPrefixRegex{
Regex: []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"},
}
testCases := []struct {
path string
expectedStatusCode int
expectedPath string
expectedRawPath string
expectedHeader string
}{
{
path: "/a/test",
expectedStatusCode: http.StatusNotFound,
},
{
path: "/a/api/test",
expectedStatusCode: http.StatusOK,
expectedPath: "test",
expectedHeader: "/a/api/",
},
{
path: "/b/api/",
expectedStatusCode: http.StatusOK,
expectedHeader: "/b/api/",
},
{
path: "/b/api/test1",
expectedStatusCode: http.StatusOK,
expectedPath: "test1",
expectedHeader: "/b/api/",
},
{
path: "/b/api2/test2",
expectedStatusCode: http.StatusOK,
expectedPath: "test2",
expectedHeader: "/b/api2/",
},
{
path: "/c/api/123/",
expectedStatusCode: http.StatusOK,
expectedHeader: "/c/api/123/",
},
{
path: "/c/api/123/test3",
expectedStatusCode: http.StatusOK,
expectedPath: "test3",
expectedHeader: "/c/api/123/",
},
{
path: "/c/api/abc/test4",
expectedStatusCode: http.StatusNotFound,
},
{
path: "/a/api/a%2Fb",
expectedStatusCode: http.StatusOK,
expectedPath: "a/b",
expectedRawPath: "a%2Fb",
expectedHeader: "/a/api/",
},
}
for _, test := range testCases {
test := test
t.Run(test.path, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, actualHeader string
handlerPath := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
actualHeader = r.Header.Get(stripprefix.ForwardedPrefixHeader)
})
handler, err := New(context.Background(), handlerPath, testPrefixRegex, "foo-strip-prefix-regex")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
resp := &httptest.ResponseRecorder{Code: http.StatusOK}
handler.ServeHTTP(resp, req)
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", stripprefix.ForwardedPrefixHeader)
})
}
}

View file

@ -0,0 +1,57 @@
package tracing
import (
"context"
"net/http"
"github.com/containous/alice"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
const (
entryPointTypeName = "TracingEntryPoint"
)
// NewEntryPoint creates a new middleware that the incoming request.
func NewEntryPoint(ctx context.Context, t *tracing.Tracing, entryPointName string, next http.Handler) http.Handler {
middlewares.GetLogger(ctx, "tracing", entryPointTypeName).Debug("Creating middleware")
return &entryPointMiddleware{
entryPoint: entryPointName,
Tracing: t,
next: next,
}
}
type entryPointMiddleware struct {
*tracing.Tracing
entryPoint string
next http.Handler
}
func (e *entryPointMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
spanCtx, _ := e.Extract(opentracing.HTTPHeaders, tracing.HTTPHeadersCarrier(req.Header))
span, req, finish := e.StartSpanf(req, ext.SpanKindRPCServerEnum, "EntryPoint", []string{e.entryPoint, req.Host}, " ", ext.RPCServerOption(spanCtx))
defer finish()
ext.Component.Set(span, e.ServiceName)
tracing.LogRequest(span, req)
req = req.WithContext(tracing.WithTracing(req.Context(), e.Tracing))
recorder := newStatusCodeRecoder(rw, http.StatusOK)
e.next.ServeHTTP(recorder, req)
tracing.LogResponseCode(span, recorder.Status())
}
// WrapEntryPointHandler Wraps tracing to alice.Constructor.
func WrapEntryPointHandler(ctx context.Context, tracer *tracing.Tracing, entryPointName string) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
return NewEntryPoint(ctx, tracer, entryPointName, next), nil
}
}

View file

@ -0,0 +1,87 @@
package tracing
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEntryPointMiddleware(t *testing.T) {
type expected struct {
Tags map[string]interface{}
OperationName string
}
testCases := []struct {
desc string
entryPoint string
spanNameLimit int
tracing *trackingBackenMock
expected expected
}{
{
desc: "no truncation test",
entryPoint: "test",
spanNameLimit: 0,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
expected: expected{
Tags: map[string]interface{}{
"span.kind": ext.SpanKindRPCServerEnum,
"http.method": http.MethodGet,
"component": "",
"http.url": "http://www.test.com",
"http.host": "www.test.com",
},
OperationName: "EntryPoint test www.test.com",
},
},
{
desc: "basic test",
entryPoint: "test",
spanNameLimit: 25,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
expected: expected{
Tags: map[string]interface{}{
"span.kind": ext.SpanKindRPCServerEnum,
"http.method": http.MethodGet,
"component": "",
"http.url": "http://www.test.com",
"http.host": "www.test.com",
},
OperationName: "EntryPoint te... ww... 0c15301b",
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
newTracing, err := tracing.NewTracing("", test.spanNameLimit, test.tracing)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "http://www.test.com", nil)
rw := httptest.NewRecorder()
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
span := test.tracing.tracer.(*MockTracer).Span
tags := span.Tags
assert.Equal(t, test.expected.Tags, tags)
assert.Equal(t, test.expected.OperationName, span.OpName)
})
handler := NewEntryPoint(context.Background(), newTracing, test.entryPoint, next)
handler.ServeHTTP(rw, req)
})
}
}

View file

@ -0,0 +1,58 @@
package tracing
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
forwarderTypeName = "TracingForwarder"
)
type forwarderMiddleware struct {
router string
service string
next http.Handler
}
// NewForwarder creates a new forwarder middleware that traces the outgoing request.
func NewForwarder(ctx context.Context, router, service string, next http.Handler) http.Handler {
middlewares.GetLogger(ctx, "tracing", forwarderTypeName).
Debugf("Added outgoing tracing middleware %s", service)
return &forwarderMiddleware{
router: router,
service: service,
next: next,
}
}
func (f *forwarderMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
tr, err := tracing.FromContext(req.Context())
if err != nil {
f.next.ServeHTTP(rw, req)
return
}
opParts := []string{f.service, f.router}
span, req, finish := tr.StartSpanf(req, ext.SpanKindRPCClientEnum, "forward", opParts, "/")
defer finish()
span.SetTag("service.name", f.service)
span.SetTag("router.name", f.router)
ext.HTTPMethod.Set(span, req.Method)
ext.HTTPUrl.Set(span, req.URL.String())
span.SetTag("http.host", req.Host)
tracing.InjectRequestHeaders(req)
recorder := newStatusCodeRecoder(rw, 200)
f.next.ServeHTTP(recorder, req)
tracing.LogResponseCode(span, recorder.Status())
}

View file

@ -0,0 +1,136 @@
package tracing
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewForwarder(t *testing.T) {
type expected struct {
Tags map[string]interface{}
OperationName string
}
testCases := []struct {
desc string
spanNameLimit int
tracing *trackingBackenMock
service string
router string
expected expected
}{
{
desc: "Simple Forward Tracer without truncation and hashing",
spanNameLimit: 101,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
service: "some-service.domain.tld",
router: "some-service.domain.tld",
expected: expected{
Tags: map[string]interface{}{
"http.host": "www.test.com",
"http.method": "GET",
"http.url": "http://www.test.com/toto",
"service.name": "some-service.domain.tld",
"router.name": "some-service.domain.tld",
"span.kind": ext.SpanKindRPCClientEnum,
},
OperationName: "forward some-service.domain.tld/some-service.domain.tld",
},
}, {
desc: "Simple Forward Tracer with truncation and hashing",
spanNameLimit: 101,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
service: "some-service-100.slug.namespace.environment.domain.tld",
router: "some-service-100.slug.namespace.environment.domain.tld",
expected: expected{
Tags: map[string]interface{}{
"http.host": "www.test.com",
"http.method": "GET",
"http.url": "http://www.test.com/toto",
"service.name": "some-service-100.slug.namespace.environment.domain.tld",
"router.name": "some-service-100.slug.namespace.environment.domain.tld",
"span.kind": ext.SpanKindRPCClientEnum,
},
OperationName: "forward some-service-100.slug.namespace.enviro.../some-service-100.slug.namespace.enviro.../bc4a0d48",
},
},
{
desc: "Exactly 101 chars",
spanNameLimit: 101,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
service: "some-service1.namespace.environment.domain.tld",
router: "some-service1.namespace.environment.domain.tld",
expected: expected{
Tags: map[string]interface{}{
"http.host": "www.test.com",
"http.method": "GET",
"http.url": "http://www.test.com/toto",
"service.name": "some-service1.namespace.environment.domain.tld",
"router.name": "some-service1.namespace.environment.domain.tld",
"span.kind": ext.SpanKindRPCClientEnum,
},
OperationName: "forward some-service1.namespace.environment.domain.tld/some-service1.namespace.environment.domain.tld",
},
},
{
desc: "More than 101 chars",
spanNameLimit: 101,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
service: "some-service1.frontend.namespace.environment.domain.tld",
router: "some-service1.backend.namespace.environment.domain.tld",
expected: expected{
Tags: map[string]interface{}{
"http.host": "www.test.com",
"http.method": "GET",
"http.url": "http://www.test.com/toto",
"service.name": "some-service1.frontend.namespace.environment.domain.tld",
"router.name": "some-service1.backend.namespace.environment.domain.tld",
"span.kind": ext.SpanKindRPCClientEnum,
},
OperationName: "forward some-service1.frontend.namespace.envir.../some-service1.backend.namespace.enviro.../fa49dd23",
},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
newTracing, err := tracing.NewTracing("", test.spanNameLimit, test.tracing)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodGet, "http://www.test.com/toto", nil)
req = req.WithContext(tracing.WithTracing(req.Context(), newTracing))
rw := httptest.NewRecorder()
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
span := test.tracing.tracer.(*MockTracer).Span
tags := span.Tags
assert.Equal(t, test.expected.Tags, tags)
assert.True(t, len(test.expected.OperationName) <= test.spanNameLimit,
"the len of the operation name %q [len: %d] doesn't respect limit %d",
test.expected.OperationName, len(test.expected.OperationName), test.spanNameLimit)
assert.Equal(t, test.expected.OperationName, span.OpName)
})
handler := NewForwarder(context.Background(), test.router, test.service, next)
handler.ServeHTTP(rw, req)
})
}
}

View file

@ -0,0 +1,70 @@
package tracing
import (
"io"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/log"
)
type MockTracer struct {
Span *MockSpan
}
// StartSpan belongs to the Tracer interface.
func (n MockTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
n.Span.OpName = operationName
return n.Span
}
// Inject belongs to the Tracer interface.
func (n MockTracer) Inject(sp opentracing.SpanContext, format interface{}, carrier interface{}) error {
return nil
}
// Extract belongs to the Tracer interface.
func (n MockTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) {
return nil, opentracing.ErrSpanContextNotFound
}
// MockSpanContext
type MockSpanContext struct{}
func (n MockSpanContext) ForeachBaggageItem(handler func(k, v string) bool) {}
// MockSpan
type MockSpan struct {
OpName string
Tags map[string]interface{}
}
func (n MockSpan) Context() opentracing.SpanContext { return MockSpanContext{} }
func (n MockSpan) SetBaggageItem(key, val string) opentracing.Span {
return MockSpan{Tags: make(map[string]interface{})}
}
func (n MockSpan) BaggageItem(key string) string { return "" }
func (n MockSpan) SetTag(key string, value interface{}) opentracing.Span {
n.Tags[key] = value
return n
}
func (n MockSpan) LogFields(fields ...log.Field) {}
func (n MockSpan) LogKV(keyVals ...interface{}) {}
func (n MockSpan) Finish() {}
func (n MockSpan) FinishWithOptions(opts opentracing.FinishOptions) {}
func (n MockSpan) SetOperationName(operationName string) opentracing.Span { return n }
func (n MockSpan) Tracer() opentracing.Tracer { return MockTracer{} }
func (n MockSpan) LogEvent(event string) {}
func (n MockSpan) LogEventWithPayload(event string, payload interface{}) {}
func (n MockSpan) Log(data opentracing.LogData) {}
func (n MockSpan) Reset() {
n.Tags = make(map[string]interface{})
}
type trackingBackenMock struct {
tracer opentracing.Tracer
}
func (t *trackingBackenMock) Setup(componentName string) (opentracing.Tracer, io.Closer, error) {
opentracing.SetGlobalTracer(t.tracer)
return t.tracer, nil, nil
}

View file

@ -0,0 +1,57 @@
package tracing
import (
"bufio"
"net"
"net/http"
)
type statusCodeRecoder interface {
http.ResponseWriter
Status() int
}
type statusCodeWithoutCloseNotify struct {
http.ResponseWriter
status int
}
// WriteHeader captures the status code for later retrieval.
func (s *statusCodeWithoutCloseNotify) WriteHeader(status int) {
s.status = status
s.ResponseWriter.WriteHeader(status)
}
// Status get response status
func (s *statusCodeWithoutCloseNotify) Status() int {
return s.status
}
// Hijack hijacks the connection
func (s *statusCodeWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return s.ResponseWriter.(http.Hijacker).Hijack()
}
// Flush sends any buffered data to the client.
func (s *statusCodeWithoutCloseNotify) Flush() {
if flusher, ok := s.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
type statusCodeWithCloseNotify struct {
*statusCodeWithoutCloseNotify
}
func (s *statusCodeWithCloseNotify) CloseNotify() <-chan bool {
return s.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// newStatusCodeRecoder returns an initialized statusCodeRecoder.
func newStatusCodeRecoder(rw http.ResponseWriter, status int) statusCodeRecoder {
recorder := &statusCodeWithoutCloseNotify{rw, status}
if _, ok := rw.(http.CloseNotifier); ok {
return &statusCodeWithCloseNotify{recorder}
}
return recorder
}

View file

@ -0,0 +1,68 @@
package tracing
import (
"context"
"net/http"
"github.com/containous/alice"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
// Tracable embeds tracing information.
type Tracable interface {
GetTracingInformation() (name string, spanKind ext.SpanKindEnum)
}
// Wrap adds tracability to an alice.Constructor.
func Wrap(ctx context.Context, constructor alice.Constructor) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
if constructor == nil {
return nil, nil
}
handler, err := constructor(next)
if err != nil {
return nil, err
}
if tracableHandler, ok := handler.(Tracable); ok {
name, spanKind := tracableHandler.GetTracingInformation()
log.FromContext(ctx).WithField(log.MiddlewareName, name).Debug("Adding tracing to middleware")
return NewWrapper(handler, name, spanKind), nil
}
return handler, nil
}
}
// NewWrapper returns a http.Handler struct
func NewWrapper(next http.Handler, name string, spanKind ext.SpanKindEnum) http.Handler {
return &Wrapper{
next: next,
name: name,
spanKind: spanKind,
}
}
// Wrapper is used to wrap http handler middleware.
type Wrapper struct {
next http.Handler
name string
spanKind ext.SpanKindEnum
}
func (w *Wrapper) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
_, err := tracing.FromContext(req.Context())
if err != nil {
w.next.ServeHTTP(rw, req)
return
}
var finish func()
_, req, finish = tracing.StartSpan(req, w.name, w.spanKind)
defer finish()
if w.next != nil {
w.next.ServeHTTP(rw, req)
}
}