1
0
Fork 0

Dynamic Configuration Refactoring

This commit is contained in:
Ludovic Fernandez 2018-11-14 10:18:03 +01:00 committed by Traefiker Bot
parent d3ae88f108
commit a09dfa3ce1
452 changed files with 21023 additions and 9419 deletions

View file

@ -6,7 +6,7 @@ import (
"net"
"net/http"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/old/middlewares"
)
var (

View file

@ -0,0 +1,59 @@
package accesslog
import (
"net/http"
"time"
"github.com/vulcand/oxy/utils"
)
// FieldApply function hook to add data in accesslog
type FieldApply func(rw http.ResponseWriter, r *http.Request, next http.Handler, data *LogData)
// FieldHandler sends a new field to the logger.
type FieldHandler struct {
next http.Handler
name string
value string
applyFn FieldApply
}
// NewFieldHandler creates a Field handler.
func NewFieldHandler(next http.Handler, name string, value string, applyFn FieldApply) http.Handler {
return &FieldHandler{next: next, name: name, value: value, applyFn: applyFn}
}
func (f *FieldHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
table := GetLogData(req)
if table == nil {
f.next.ServeHTTP(rw, req)
return
}
table.Core[f.name] = f.value
if f.applyFn != nil {
f.applyFn(rw, req, f.next, table)
} else {
f.next.ServeHTTP(rw, req)
}
}
// AddServiceFields add service fields
func AddServiceFields(rw http.ResponseWriter, req *http.Request, next http.Handler, data *LogData) {
data.Core[ServiceURL] = req.URL // note that this is *not* the original incoming URL
data.Core[ServiceAddr] = req.URL.Host
crw := &captureResponseWriter{rw: rw}
start := time.Now().UTC()
next.ServeHTTP(crw, req)
// use UTC to handle switchover of daylight saving correctly
data.Core[OriginDuration] = time.Now().UTC().Sub(start)
data.Core[OriginStatus] = crw.Status()
// make copy of headers so we can ensure there is no subsequent mutation during response processing
data.OriginResponse = make(http.Header)
utils.CopyHeaders(data.OriginResponse, crw.Header())
data.Core[OriginContentSize] = crw.Size()
}

View file

@ -12,14 +12,16 @@ const (
// Duration is the map key used for the total time taken by processing the response, including the origin server's time but
// not the log writing time.
Duration = "Duration"
// FrontendName is the map key used for the name of the Traefik frontend.
FrontendName = "FrontendName"
// BackendName is the map key used for the name of the Traefik backend.
BackendName = "BackendName"
// BackendURL is the map key used for the URL of the Traefik backend.
BackendURL = "BackendURL"
// BackendAddr is the map key used for the IP:port of the Traefik backend (extracted from BackendURL)
BackendAddr = "BackendAddr"
// RouterName is the map key used for the name of the Traefik router.
RouterName = "RouterName"
// ServiceName is the map key used for the name of the Traefik backend.
ServiceName = "ServiceName"
// ServiceURL is the map key used for the URL of the Traefik backend.
ServiceURL = "ServiceURL"
// ServiceAddr is the map key used for the IP:port of the Traefik backend (extracted from BackendURL)
ServiceAddr = "ServiceAddr"
// ClientAddr is the map key used for the remote address in its original form (usually IP:port).
ClientAddr = "ClientAddr"
// ClientHost is the map key used for the remote IP address from which the client request was received.
@ -72,9 +74,9 @@ const (
var defaultCoreKeys = [...]string{
StartUTC,
Duration,
FrontendName,
BackendName,
BackendURL,
RouterName,
ServiceName,
ServiceURL,
ClientHost,
ClientPort,
ClientUsername,
@ -99,7 +101,7 @@ func init() {
for _, k := range defaultCoreKeys {
allCoreKeys[k] = struct{}{}
}
allCoreKeys[BackendAddr] = struct{}{}
allCoreKeys[ServiceAddr] = struct{}{}
allCoreKeys[ClientAddr] = struct{}{}
allCoreKeys[RequestAddr] = struct{}{}
allCoreKeys[GzipRatio] = struct{}{}

View file

@ -12,6 +12,7 @@ import (
"sync/atomic"
"time"
"github.com/containous/alice"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/log"
"github.com/containous/traefik/types"
@ -21,36 +22,44 @@ import (
type key string
const (
// DataTableKey is the key within the request context used to
// store the Log Data Table
// DataTableKey is the key within the request context used to store the Log Data Table.
DataTableKey key = "LogDataTable"
// CommonFormat is the common logging format (CLF)
// CommonFormat is the common logging format (CLF).
CommonFormat string = "common"
// JSONFormat is the JSON logging format
// JSONFormat is the JSON logging format.
JSONFormat string = "json"
)
type logHandlerParams struct {
type handlerParams struct {
logDataTable *LogData
crr *captureRequestReader
crw *captureResponseWriter
}
// LogHandler will write each request and its response to the access log.
type LogHandler struct {
// Handler will write each request and its response to the access log.
type Handler struct {
config *types.AccessLog
logger *logrus.Logger
file *os.File
mu sync.Mutex
httpCodeRanges types.HTTPCodeRanges
logHandlerChan chan logHandlerParams
logHandlerChan chan handlerParams
wg sync.WaitGroup
}
// NewLogHandler creates a new LogHandler
func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
// WrapHandler Wraps access log handler into an Alice Constructor.
func WrapHandler(handler *Handler) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(rw, req, next.ServeHTTP)
}), nil
}
}
// NewHandler creates a new Handler.
func NewHandler(config *types.AccessLog) (*Handler, error) {
file := os.Stdout
if len(config.FilePath) > 0 {
f, err := openAccessLogFile(config.FilePath)
@ -59,7 +68,7 @@ func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
}
file = f
}
logHandlerChan := make(chan logHandlerParams, config.BufferingSize)
logHandlerChan := make(chan handlerParams, config.BufferingSize)
var formatter logrus.Formatter
@ -79,7 +88,7 @@ func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
Level: logrus.InfoLevel,
}
logHandler := &LogHandler{
logHandler := &Handler{
config: config,
logger: logger,
file: file,
@ -88,7 +97,7 @@ func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
if config.Filters != nil {
if httpCodeRanges, err := types.NewHTTPCodeRanges(config.Filters.StatusCodes); err != nil {
log.Errorf("Failed to create new HTTP code ranges: %s", err)
log.WithoutContext().Errorf("Failed to create new HTTP code ranges: %s", err)
} else {
logHandler.httpCodeRanges = httpCodeRanges
}
@ -122,17 +131,16 @@ func openAccessLogFile(filePath string) (*os.File, error) {
return file, nil
}
// GetLogDataTable gets the request context object that contains logging data.
// GetLogData gets the request context object that contains logging data.
// This creates data as the request passes through the middleware chain.
func GetLogDataTable(req *http.Request) *LogData {
func GetLogData(req *http.Request) *LogData {
if ld, ok := req.Context().Value(DataTableKey).(*LogData); ok {
return ld
}
log.Errorf("%s is nil", DataTableKey)
return &LogData{Core: make(CoreLogData)}
return nil
}
func (l *LogHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
now := time.Now().UTC()
core := CoreLogData{
@ -179,46 +187,45 @@ func (l *LogHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next h
next.ServeHTTP(crw, reqWithDataTable)
core[ClientUsername] = formatUsernameForLog(core[ClientUsername])
core[ClientUsername] = usernameIfPresent(reqWithDataTable.URL)
logDataTable.DownstreamResponse = crw.Header()
if l.config.BufferingSize > 0 {
l.logHandlerChan <- logHandlerParams{
if h.config.BufferingSize > 0 {
h.logHandlerChan <- handlerParams{
logDataTable: logDataTable,
crr: crr,
crw: crw,
}
} else {
l.logTheRoundTrip(logDataTable, crr, crw)
h.logTheRoundTrip(logDataTable, crr, crw)
}
}
// Close closes the Logger (i.e. the file, drain logHandlerChan, etc).
func (l *LogHandler) Close() error {
close(l.logHandlerChan)
l.wg.Wait()
return l.file.Close()
func (h *Handler) Close() error {
close(h.logHandlerChan)
h.wg.Wait()
return h.file.Close()
}
// Rotate closes and reopens the log file to allow for rotation
// by an external source.
func (l *LogHandler) Rotate() error {
// Rotate closes and reopens the log file to allow for rotation by an external source.
func (h *Handler) Rotate() error {
var err error
if l.file != nil {
if h.file != nil {
defer func(f *os.File) {
f.Close()
}(l.file)
}(h.file)
}
l.file, err = os.OpenFile(l.config.FilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
h.file, err = os.OpenFile(h.config.FilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
if err != nil {
return err
}
l.mu.Lock()
defer l.mu.Unlock()
l.logger.Out = l.file
h.mu.Lock()
defer h.mu.Unlock()
h.logger.Out = h.file
return nil
}
@ -230,16 +237,17 @@ func silentSplitHostPort(value string) (host string, port string) {
return host, port
}
func formatUsernameForLog(usernameField interface{}) string {
username, ok := usernameField.(string)
if ok && len(username) != 0 {
return username
func usernameIfPresent(theURL *url.URL) string {
if theURL.User != nil {
if name := theURL.User.Username(); name != "" {
return name
}
}
return "-"
}
// Logging handler to log frontend name, backend name, and elapsed time
func (l *LogHandler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestReader, crw *captureResponseWriter) {
// Logging handler to log frontend name, backend name, and elapsed time.
func (h *Handler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestReader, crw *captureResponseWriter) {
core := logDataTable.Core
retryAttempts, ok := core[RetryAttempts].(int)
@ -254,11 +262,11 @@ func (l *LogHandler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestR
core[DownstreamStatus] = crw.Status()
// n.b. take care to perform time arithmetic using UTC to avoid errors at DST boundaries
// n.b. take care to perform time arithmetic using UTC to avoid errors at DST boundaries.
totalDuration := time.Now().UTC().Sub(core[StartUTC].(time.Time))
core[Duration] = totalDuration
if l.keepAccessLog(crw.Status(), retryAttempts, totalDuration) {
if h.keepAccessLog(crw.Status(), retryAttempts, totalDuration) {
core[DownstreamContentSize] = crw.Size()
if original, ok := core[OriginContentSize]; ok {
o64 := original.(int64)
@ -275,24 +283,24 @@ func (l *LogHandler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestR
fields := logrus.Fields{}
for k, v := range logDataTable.Core {
if l.config.Fields.Keep(k) {
if h.config.Fields.Keep(k) {
fields[k] = v
}
}
l.redactHeaders(logDataTable.Request, fields, "request_")
l.redactHeaders(logDataTable.OriginResponse, fields, "origin_")
l.redactHeaders(logDataTable.DownstreamResponse, fields, "downstream_")
h.redactHeaders(logDataTable.Request, fields, "request_")
h.redactHeaders(logDataTable.OriginResponse, fields, "origin_")
h.redactHeaders(logDataTable.DownstreamResponse, fields, "downstream_")
l.mu.Lock()
defer l.mu.Unlock()
l.logger.WithFields(fields).Println()
h.mu.Lock()
defer h.mu.Unlock()
h.logger.WithFields(fields).Println()
}
}
func (l *LogHandler) redactHeaders(headers http.Header, fields logrus.Fields, prefix string) {
func (h *Handler) redactHeaders(headers http.Header, fields logrus.Fields, prefix string) {
for k := range headers {
v := l.config.Fields.KeepHeader(k)
v := h.config.Fields.KeepHeader(k)
if v == types.AccessLogKeep {
fields[prefix+k] = headers.Get(k)
} else if v == types.AccessLogRedact {
@ -301,26 +309,26 @@ func (l *LogHandler) redactHeaders(headers http.Header, fields logrus.Fields, pr
}
}
func (l *LogHandler) keepAccessLog(statusCode, retryAttempts int, duration time.Duration) bool {
if l.config.Filters == nil {
func (h *Handler) keepAccessLog(statusCode, retryAttempts int, duration time.Duration) bool {
if h.config.Filters == nil {
// no filters were specified
return true
}
if len(l.httpCodeRanges) == 0 && !l.config.Filters.RetryAttempts && l.config.Filters.MinDuration == 0 {
if len(h.httpCodeRanges) == 0 && !h.config.Filters.RetryAttempts && h.config.Filters.MinDuration == 0 {
// empty filters were specified, e.g. by passing --accessLog.filters only (without other filter options)
return true
}
if l.httpCodeRanges.Contains(statusCode) {
if h.httpCodeRanges.Contains(statusCode) {
return true
}
if l.config.Filters.RetryAttempts && retryAttempts > 0 {
if h.config.Filters.RetryAttempts && retryAttempts > 0 {
return true
}
if l.config.Filters.MinDuration > 0 && (parse.Duration(duration) > l.config.Filters.MinDuration) {
if h.config.Filters.MinDuration > 0 && (parse.Duration(duration) > h.config.Filters.MinDuration) {
return true
}

View file

@ -8,16 +8,16 @@ import (
"github.com/sirupsen/logrus"
)
// default format for time presentation
// default format for time presentation.
const (
commonLogTimeFormat = "02/Jan/2006:15:04:05 -0700"
defaultValue = "-"
)
// CommonLogFormatter provides formatting in the Traefik common log format
// CommonLogFormatter provides formatting in the Traefik common log format.
type CommonLogFormatter struct{}
// Format formats the log entry in the Traefik common log format
// Format formats the log entry in the Traefik common log format.
func (f *CommonLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
b := &bytes.Buffer{}
@ -43,8 +43,8 @@ func (f *CommonLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
toLog(entry.Data, "request_Referer", `"-"`, true),
toLog(entry.Data, "request_User-Agent", `"-"`, true),
toLog(entry.Data, RequestCount, defaultValue, true),
toLog(entry.Data, FrontendName, defaultValue, true),
toLog(entry.Data, BackendURL, defaultValue, true),
toLog(entry.Data, RouterName, defaultValue, true),
toLog(entry.Data, ServiceURL, defaultValue, true),
elapsedMillis)
return b.Bytes(), err
@ -70,6 +70,7 @@ func toLog(fields logrus.Fields, key string, defaultValue string, quoted bool) i
return defaultValue
}
func toLogEntry(s string, defaultValue string, quote bool) string {
if len(s) == 0 {
return defaultValue

View file

@ -32,8 +32,8 @@ func TestCommonLogFormatter_Format(t *testing.T) {
RequestRefererHeader: "",
RequestUserAgentHeader: "",
RequestCount: 0,
FrontendName: "",
BackendURL: "",
RouterName: "",
ServiceURL: "",
},
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" - - "-" "-" 0 - - 123000ms
`,
@ -53,8 +53,8 @@ func TestCommonLogFormatter_Format(t *testing.T) {
RequestRefererHeader: "referer",
RequestUserAgentHeader: "agent",
RequestCount: nil,
FrontendName: "foo",
BackendURL: "http://10.0.0.2/toto",
RouterName: "foo",
ServiceURL: "http://10.0.0.2/toto",
},
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" 123 132 "referer" "agent" - "foo" "http://10.0.0.2/toto" 123000ms
`,

View file

@ -15,7 +15,6 @@ import (
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/log"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -24,8 +23,8 @@ import (
var (
logFileNameSuffix = "/traefik/logger/test.log"
testContent = "Hello, World"
testBackendName = "http://127.0.0.1/testBackend"
testFrontendName = "testFrontend"
testServiceName = "http://127.0.0.1/testService"
testRouterName = "testRouter"
testStatus = 123
testContentSize int64 = 12
testHostname = "TestHost"
@ -50,7 +49,7 @@ func TestLogRotation(t *testing.T) {
rotatedFileName := fileName + ".rotated"
config := &types.AccessLog{FilePath: fileName, Format: CommonFormat}
logHandler, err := NewLogHandler(config)
logHandler, err := NewHandler(config)
if err != nil {
t.Fatalf("Error creating new log handler: %s", err)
}
@ -129,7 +128,7 @@ func TestLoggerCLF(t *testing.T) {
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
assertValidLogData(t, expectedLog, logData)
}
@ -144,7 +143,7 @@ func TestAsyncLoggerCLF(t *testing.T) {
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
assertValidLogData(t, expectedLog, logData)
}
@ -156,11 +155,11 @@ func assertString(exp string) func(t *testing.T, actual interface{}) {
}
}
func assertNotEqual(exp string) func(t *testing.T, actual interface{}) {
func assertNotEmpty() func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.NotEqual(t, exp, actual)
assert.NotEqual(t, "", actual)
}
}
@ -205,8 +204,8 @@ func TestLoggerJSON(t *testing.T) {
OriginStatus: assertFloat64(float64(testStatus)),
RequestRefererHeader: assertString(testReferer),
RequestUserAgentHeader: assertString(testUserAgent),
FrontendName: assertString(testFrontendName),
BackendURL: assertString(testBackendName),
RouterName: assertString(testRouterName),
ServiceURL: assertString(testServiceName),
ClientUsername: assertString(testUsername),
ClientHost: assertString(testHostname),
ClientPort: assertString(fmt.Sprintf("%d", testPort)),
@ -218,9 +217,9 @@ func TestLoggerJSON(t *testing.T) {
Duration: assertFloat64NotZero(),
Overhead: assertFloat64NotZero(),
RetryAttempts: assertFloat64(float64(testRetryAttempts)),
"time": assertNotEqual(""),
"StartLocal": assertNotEqual(""),
"StartUTC": assertNotEqual(""),
"time": assertNotEmpty(),
"StartLocal": assertNotEmpty(),
"StartUTC": assertNotEmpty(),
},
},
{
@ -235,7 +234,7 @@ func TestLoggerJSON(t *testing.T) {
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEqual(""),
"time": assertNotEmpty(),
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
RequestRefererHeader: assertString(testReferer),
RequestUserAgentHeader: assertString(testUserAgent),
@ -256,7 +255,7 @@ func TestLoggerJSON(t *testing.T) {
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEqual(""),
"time": assertNotEmpty(),
},
},
{
@ -274,7 +273,7 @@ func TestLoggerJSON(t *testing.T) {
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEqual(""),
"time": assertNotEmpty(),
"downstream_Content-Type": assertString("REDACTED"),
RequestRefererHeader: assertString("REDACTED"),
RequestUserAgentHeader: assertString("REDACTED"),
@ -302,7 +301,7 @@ func TestLoggerJSON(t *testing.T) {
RequestHost: assertString(testHostname),
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEqual(""),
"time": assertNotEmpty(),
RequestRefererHeader: assertString(testReferer),
},
},
@ -349,7 +348,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
FilePath: "",
Format: CommonFormat,
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "default config with empty filters",
@ -358,7 +357,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
Format: CommonFormat,
Filters: &types.AccessLogFilters{},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Status code filter not matching",
@ -380,7 +379,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
StatusCodes: []string{"123"},
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Duration filter not matching",
@ -402,7 +401,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
MinDuration: parse.Duration(1 * time.Millisecond),
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Retry attempts filter matching",
@ -413,7 +412,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
RetryAttempts: true,
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Default mode keep",
@ -424,7 +423,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
DefaultMode: "keep",
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Default mode keep with override",
@ -438,7 +437,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
},
},
},
expectedLog: `- - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
expectedLog: `- - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Default mode drop",
@ -572,8 +571,8 @@ func assertValidLogData(t *testing.T, expected string, logData []byte) {
assert.Equal(t, resultExpected[RequestRefererHeader], result[RequestRefererHeader], formatErrMessage)
assert.Equal(t, resultExpected[RequestUserAgentHeader], result[RequestUserAgentHeader], formatErrMessage)
assert.Regexp(t, regexp.MustCompile("[0-9]*"), result[RequestCount], formatErrMessage)
assert.Equal(t, resultExpected[FrontendName], result[FrontendName], formatErrMessage)
assert.Equal(t, resultExpected[BackendURL], result[BackendURL], formatErrMessage)
assert.Equal(t, resultExpected[RouterName], result[RouterName], formatErrMessage)
assert.Equal(t, resultExpected[ServiceURL], result[ServiceURL], formatErrMessage)
assert.Regexp(t, regexp.MustCompile("[0-9]*ms"), result[Duration], formatErrMessage)
}
@ -599,7 +598,7 @@ func createTempDir(t *testing.T, prefix string) string {
}
func doLogging(t *testing.T, config *types.AccessLog) {
logger, err := NewLogHandler(config)
logger, err := NewHandler(config)
require.NoError(t, err)
defer logger.Close()
@ -618,6 +617,7 @@ func doLogging(t *testing.T, config *types.AccessLog) {
Method: testMethod,
RemoteAddr: fmt.Sprintf("%s:%d", testHostname, testPort),
URL: &url.URL{
User: url.UserPassword(testUsername, ""),
Path: testPath,
},
}
@ -627,18 +627,23 @@ func doLogging(t *testing.T, config *types.AccessLog) {
func logWriterTestHandlerFunc(rw http.ResponseWriter, r *http.Request) {
if _, err := rw.Write([]byte(testContent)); err != nil {
log.Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
logData := GetLogData(r)
if logData != nil {
logData.Core[RouterName] = testRouterName
logData.Core[ServiceURL] = testServiceName
logData.Core[OriginStatus] = testStatus
logData.Core[OriginContentSize] = testContentSize
logData.Core[RetryAttempts] = testRetryAttempts
logData.Core[StartUTC] = testStart.UTC()
logData.Core[StartLocal] = testStart.Local()
} else {
http.Error(rw, "LogData is nil", http.StatusInternalServerError)
return
}
rw.WriteHeader(testStatus)
logDataTable := GetLogDataTable(r)
logDataTable.Core[FrontendName] = testFrontendName
logDataTable.Core[BackendURL] = testBackendName
logDataTable.Core[OriginStatus] = testStatus
logDataTable.Core[OriginContentSize] = testContentSize
logDataTable.Core[RetryAttempts] = testRetryAttempts
logDataTable.Core[StartUTC] = testStart.UTC()
logDataTable.Core[StartLocal] = testStart.Local()
logDataTable.Core[ClientUsername] = testUsername
}

View file

@ -45,8 +45,8 @@ func ParseAccessLog(data string) (map[string]string, error) {
result[RequestRefererHeader] = submatch[9]
result[RequestUserAgentHeader] = submatch[10]
result[RequestCount] = submatch[11]
result[FrontendName] = submatch[12]
result[BackendURL] = submatch[13]
result[RouterName] = submatch[12]
result[ServiceURL] = submatch[13]
result[Duration] = submatch[14]
}

View file

@ -14,7 +14,7 @@ func TestParseAccessLog(t *testing.T) {
}{
{
desc: "full log",
value: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
value: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`,
expected: map[string]string{
ClientHost: "TestHost",
ClientUsername: "TestUser",
@ -27,14 +27,14 @@ func TestParseAccessLog(t *testing.T) {
RequestRefererHeader: `"testReferer"`,
RequestUserAgentHeader: `"testUserAgent"`,
RequestCount: "1",
FrontendName: `"testFrontend"`,
BackendURL: `"http://127.0.0.1/testBackend"`,
RouterName: `"testRouter"`,
ServiceURL: `"http://127.0.0.1/testService"`,
Duration: "1ms",
},
},
{
desc: "log with space",
value: `127.0.0.1 - - [09/Mar/2018:10:51:32 +0000] "GET / HTTP/1.1" 401 17 "-" "Go-http-client/1.1" 1 "testFrontend with space" - 0ms`,
value: `127.0.0.1 - - [09/Mar/2018:10:51:32 +0000] "GET / HTTP/1.1" 401 17 "-" "Go-http-client/1.1" 1 "testRouter with space" - 0ms`,
expected: map[string]string{
ClientHost: "127.0.0.1",
ClientUsername: "-",
@ -47,8 +47,8 @@ func TestParseAccessLog(t *testing.T) {
RequestRefererHeader: `"-"`,
RequestUserAgentHeader: `"Go-http-client/1.1"`,
RequestCount: "1",
FrontendName: `"testFrontend with space"`,
BackendURL: `-`,
RouterName: `"testRouter with space"`,
ServiceURL: `-`,
Duration: "0ms",
},
},

View file

@ -1,64 +0,0 @@
package accesslog
import (
"net/http"
"time"
"github.com/urfave/negroni"
"github.com/vulcand/oxy/utils"
)
// SaveBackend sends the backend name to the logger.
// These are always used with a corresponding SaveFrontend handler.
type SaveBackend struct {
next http.Handler
backendName string
}
// NewSaveBackend creates a SaveBackend handler.
func NewSaveBackend(next http.Handler, backendName string) http.Handler {
return &SaveBackend{next, backendName}
}
func (sb *SaveBackend) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
serveSaveBackend(rw, r, sb.backendName, func(crw *captureResponseWriter) {
sb.next.ServeHTTP(crw, r)
})
}
// SaveNegroniBackend sends the backend name to the logger.
type SaveNegroniBackend struct {
next negroni.Handler
backendName string
}
// NewSaveNegroniBackend creates a SaveBackend handler.
func NewSaveNegroniBackend(next negroni.Handler, backendName string) negroni.Handler {
return &SaveNegroniBackend{next, backendName}
}
func (sb *SaveNegroniBackend) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
serveSaveBackend(rw, r, sb.backendName, func(crw *captureResponseWriter) {
sb.next.ServeHTTP(crw, r, next)
})
}
func serveSaveBackend(rw http.ResponseWriter, r *http.Request, backendName string, apply func(*captureResponseWriter)) {
table := GetLogDataTable(r)
table.Core[BackendName] = backendName
table.Core[BackendURL] = r.URL // note that this is *not* the original incoming URL
table.Core[BackendAddr] = r.URL.Host
crw := &captureResponseWriter{rw: rw}
start := time.Now().UTC()
apply(crw)
// use UTC to handle switchover of daylight saving correctly
table.Core[OriginDuration] = time.Now().UTC().Sub(start)
table.Core[OriginStatus] = crw.Status()
// make copy of headers so we can ensure there is no subsequent mutation during response processing
table.OriginResponse = make(http.Header)
utils.CopyHeaders(table.OriginResponse, crw.Header())
table.Core[OriginContentSize] = crw.Size()
}

View file

@ -1,51 +0,0 @@
package accesslog
import (
"net/http"
"strings"
"github.com/urfave/negroni"
)
// SaveFrontend sends the frontend name to the logger.
// These are sometimes used with a corresponding SaveBackend handler, but not always.
// For example, redirected requests don't reach a backend.
type SaveFrontend struct {
next http.Handler
frontendName string
}
// NewSaveFrontend creates a SaveFrontend handler.
func NewSaveFrontend(next http.Handler, frontendName string) http.Handler {
return &SaveFrontend{next, frontendName}
}
func (sf *SaveFrontend) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
serveSaveFrontend(r, sf.frontendName, func() {
sf.next.ServeHTTP(rw, r)
})
}
// SaveNegroniFrontend sends the frontend name to the logger.
type SaveNegroniFrontend struct {
next negroni.Handler
frontendName string
}
// NewSaveNegroniFrontend creates a SaveNegroniFrontend handler.
func NewSaveNegroniFrontend(next negroni.Handler, frontendName string) negroni.Handler {
return &SaveNegroniFrontend{next, frontendName}
}
func (sf *SaveNegroniFrontend) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
serveSaveFrontend(r, sf.frontendName, func() {
sf.next.ServeHTTP(rw, r, next)
})
}
func serveSaveFrontend(r *http.Request, frontendName string, apply func()) {
table := GetLogDataTable(r)
table.Core[FrontendName] = strings.TrimPrefix(frontendName, "frontend-")
apply()
}

View file

@ -14,6 +14,8 @@ func (s *SaveRetries) Retried(req *http.Request, attempt int) {
attempt--
}
table := GetLogDataTable(req)
table.Core[RetryAttempts] = attempt
table := GetLogData(req)
if table != nil {
table.Core[RetryAttempts] = attempt
}
}

View file

@ -1,60 +0,0 @@
package accesslog
import (
"context"
"net/http"
"github.com/urfave/negroni"
)
const (
clientUsernameKey key = "ClientUsername"
)
// SaveUsername sends the Username name to the access logger.
type SaveUsername struct {
next http.Handler
}
// NewSaveUsername creates a SaveUsername handler.
func NewSaveUsername(next http.Handler) http.Handler {
return &SaveUsername{next}
}
func (sf *SaveUsername) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
serveSaveUsername(r, func() {
sf.next.ServeHTTP(rw, r)
})
}
// SaveNegroniUsername adds the Username to the access logger data table.
type SaveNegroniUsername struct {
next negroni.Handler
}
// NewSaveNegroniUsername creates a SaveNegroniUsername handler.
func NewSaveNegroniUsername(next negroni.Handler) negroni.Handler {
return &SaveNegroniUsername{next}
}
func (sf *SaveNegroniUsername) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
serveSaveUsername(r, func() {
sf.next.ServeHTTP(rw, r, next)
})
}
func serveSaveUsername(r *http.Request, apply func()) {
table := GetLogDataTable(r)
username, ok := r.Context().Value(clientUsernameKey).(string)
if ok {
table.Core[ClientUsername] = username
}
apply()
}
// WithUserName adds a username to a requests' context
func WithUserName(req *http.Request, username string) *http.Request {
return req.WithContext(context.WithValue(req.Context(), clientUsernameKey, username))
}

View file

@ -1,35 +0,0 @@
package middlewares
import (
"context"
"net/http"
)
// AddPrefix is a middleware used to add prefix to an URL request
type AddPrefix struct {
Handler http.Handler
Prefix string
}
type key string
const (
// AddPrefixKey is the key within the request context used to
// store the added prefix
AddPrefixKey key = "AddPrefix"
)
func (s *AddPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.URL.Path = s.Prefix + r.URL.Path
if r.URL.RawPath != "" {
r.URL.RawPath = s.Prefix + r.URL.RawPath
}
r.RequestURI = r.URL.RequestURI()
r = r.WithContext(context.WithValue(r.Context(), AddPrefixKey, s.Prefix))
s.Handler.ServeHTTP(w, r)
}
// SetHandler sets handler
func (s *AddPrefix) SetHandler(Handler http.Handler) {
s.Handler = Handler
}

View file

@ -1,64 +0,0 @@
package middlewares
import (
"net/http"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestAddPrefix(t *testing.T) {
tests := []struct {
desc string
prefix string
path string
expectedPath string
expectedRawPath string
}{
{
desc: "regular path",
prefix: "/a",
path: "/b",
expectedPath: "/a/b",
},
{
desc: "raw path is supported",
prefix: "/a",
path: "/b%2Fc",
expectedPath: "/a/b/c",
expectedRawPath: "/a/b%2Fc",
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, requestURI string
handler := &AddPrefix{
Prefix: test.prefix,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
requestURI = r.RequestURI
}),
}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
expectedURI := test.expectedPath
if test.expectedRawPath != "" {
// go HTTP uses the raw path when existent in the RequestURI
expectedURI = test.expectedRawPath
}
assert.Equal(t, expectedURI, requestURI, "Unexpected request URI.")
})
}
}

View file

@ -0,0 +1,62 @@
package addprefix
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
typeName = "AddPrefix"
)
// AddPrefix is a middleware used to add prefix to an URL request.
type addPrefix struct {
next http.Handler
prefix string
name string
}
// New creates a new handler.
func New(ctx context.Context, next http.Handler, config config.AddPrefix, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
var result *addPrefix
if len(config.Prefix) > 0 {
result = &addPrefix{
prefix: config.Prefix,
next: next,
name: name,
}
} else {
return nil, fmt.Errorf("prefix cannot be empty")
}
return result, nil
}
func (ap *addPrefix) GetTracingInformation() (string, ext.SpanKindEnum) {
return ap.name, tracing.SpanKindNoneEnum
}
func (ap *addPrefix) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), ap.name, typeName)
oldURLPath := req.URL.Path
req.URL.Path = ap.prefix + req.URL.Path
logger.Debugf("URL.Path is now %s (was %s).", req.URL.Path, oldURLPath)
if req.URL.RawPath != "" {
oldURLRawPath := req.URL.RawPath
req.URL.RawPath = ap.prefix + req.URL.RawPath
logger.Debugf("URL.RawPath is now %s (was %s).", req.URL.RawPath, oldURLRawPath)
}
req.RequestURI = req.URL.RequestURI()
ap.next.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,104 @@
package addprefix
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAddPrefix(t *testing.T) {
testCases := []struct {
desc string
prefix config.AddPrefix
expectsError bool
}{
{
desc: "Works with a non empty prefix",
prefix: config.AddPrefix{Prefix: "/a"},
},
{
desc: "Fails if prefix is empty",
prefix: config.AddPrefix{Prefix: ""},
expectsError: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
_, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
if test.expectsError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestAddPrefix(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel)
testCases := []struct {
desc string
prefix config.AddPrefix
path string
expectedPath string
expectedRawPath string
}{
{
desc: "Works with a regular path",
prefix: config.AddPrefix{Prefix: "/a"},
path: "/b",
expectedPath: "/a/b",
},
{
desc: "Works with a raw path",
prefix: config.AddPrefix{Prefix: "/a"},
path: "/b%2Fc",
expectedPath: "/a/b/c",
expectedRawPath: "/a/b%2Fc",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, requestURI string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
requestURI = r.RequestURI
})
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
handler, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
require.NoError(t, err)
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath)
assert.Equal(t, test.expectedRawPath, actualRawPath)
expectedURI := test.expectedPath
if test.expectedRawPath != "" {
// go HTTP uses the raw path when existent in the RequestURI
expectedURI = test.expectedRawPath
}
assert.Equal(t, expectedURI, requestURI)
})
}
}

65
middlewares/auth/auth.go Normal file
View file

@ -0,0 +1,65 @@
package auth
import (
"io/ioutil"
"strings"
)
// UserParser Parses a string and return a userName/userHash. An error if the format of the string is incorrect.
type UserParser func(user string) (string, string, error)
const (
defaultRealm = "traefik"
authorizationHeader = "Authorization"
)
func getUsers(fileName string, appendUsers []string, parser UserParser) (map[string]string, error) {
users, err := loadUsers(fileName, appendUsers)
if err != nil {
return nil, err
}
userMap := make(map[string]string)
for _, user := range users {
userName, userHash, err := parser(user)
if err != nil {
return nil, err
}
userMap[userName] = userHash
}
return userMap, nil
}
func loadUsers(fileName string, appendUsers []string) ([]string, error) {
var users []string
var err error
if fileName != "" {
users, err = getLinesFromFile(fileName)
if err != nil {
return nil, err
}
}
return append(users, appendUsers...), nil
}
func getLinesFromFile(filename string) ([]string, error) {
dat, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
// Trim lines and filter out blanks
rawLines := strings.Split(string(dat), "\n")
var filteredLines []string
for _, rawLine := range rawLines {
line := strings.TrimSpace(rawLine)
if line != "" {
filteredLines = append(filteredLines, line)
}
}
return filteredLines, nil
}

View file

@ -1,165 +0,0 @@
package auth
import (
"fmt"
"io/ioutil"
"net/http"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/middlewares/tracing"
"github.com/containous/traefik/types"
"github.com/urfave/negroni"
)
// Authenticator is a middleware that provides HTTP basic and digest authentication
type Authenticator struct {
handler negroni.Handler
users map[string]string
}
type tracingAuthenticator struct {
name string
handler negroni.Handler
clientSpanKind bool
}
const (
authorizationHeader = "Authorization"
)
// NewAuthenticator builds a new Authenticator given a config
func NewAuthenticator(authConfig *types.Auth, tracingMiddleware *tracing.Tracing) (*Authenticator, error) {
if authConfig == nil {
return nil, fmt.Errorf("error creating Authenticator: auth is nil")
}
var err error
authenticator := &Authenticator{}
tracingAuth := tracingAuthenticator{}
if authConfig.Basic != nil {
authenticator.users, err = parserBasicUsers(authConfig.Basic)
if err != nil {
return nil, err
}
realm := "traefik"
if authConfig.Basic.Realm != "" {
realm = authConfig.Basic.Realm
}
basicAuth := goauth.NewBasicAuthenticator(realm, authenticator.secretBasic)
tracingAuth.handler = createAuthBasicHandler(basicAuth, authConfig)
tracingAuth.name = "Auth Basic"
tracingAuth.clientSpanKind = false
} else if authConfig.Digest != nil {
authenticator.users, err = parserDigestUsers(authConfig.Digest)
if err != nil {
return nil, err
}
digestAuth := goauth.NewDigestAuthenticator("traefik", authenticator.secretDigest)
tracingAuth.handler = createAuthDigestHandler(digestAuth, authConfig)
tracingAuth.name = "Auth Digest"
tracingAuth.clientSpanKind = false
} else if authConfig.Forward != nil {
tracingAuth.handler = createAuthForwardHandler(authConfig)
tracingAuth.name = "Auth Forward"
tracingAuth.clientSpanKind = true
}
if tracingMiddleware != nil {
authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuth.name, tracingAuth.handler, tracingAuth.clientSpanKind)
} else {
authenticator.handler = tracingAuth.handler
}
return authenticator, nil
}
func createAuthForwardHandler(authConfig *types.Auth) negroni.HandlerFunc {
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
Forward(authConfig.Forward, w, r, next)
})
}
func createAuthDigestHandler(digestAuth *goauth.DigestAuth, authConfig *types.Auth) negroni.HandlerFunc {
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if username, _ := digestAuth.CheckAuth(r); username == "" {
log.Debugf("Digest auth failed")
digestAuth.RequireAuth(w, r)
} else {
log.Debugf("Digest auth succeeded")
// set username in request context
r = accesslog.WithUserName(r, username)
if authConfig.HeaderField != "" {
r.Header[authConfig.HeaderField] = []string{username}
}
if authConfig.Digest.RemoveHeader {
log.Debugf("Remove the Authorization header from the Digest auth")
r.Header.Del(authorizationHeader)
}
next.ServeHTTP(w, r)
}
})
}
func createAuthBasicHandler(basicAuth *goauth.BasicAuth, authConfig *types.Auth) negroni.HandlerFunc {
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if username := basicAuth.CheckAuth(r); username == "" {
log.Debugf("Basic auth failed")
basicAuth.RequireAuth(w, r)
} else {
log.Debugf("Basic auth succeeded")
// set username in request context
r = accesslog.WithUserName(r, username)
if authConfig.HeaderField != "" {
r.Header[authConfig.HeaderField] = []string{username}
}
if authConfig.Basic.RemoveHeader {
log.Debugf("Remove the Authorization header from the Basic auth")
r.Header.Del(authorizationHeader)
}
next.ServeHTTP(w, r)
}
})
}
func getLinesFromFile(filename string) ([]string, error) {
dat, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
// Trim lines and filter out blanks
rawLines := strings.Split(string(dat), "\n")
var filteredLines []string
for _, rawLine := range rawLines {
line := strings.TrimSpace(rawLine)
if line != "" {
filteredLines = append(filteredLines, line)
}
}
return filteredLines, nil
}
func (a *Authenticator) secretBasic(user, realm string) string {
if secret, ok := a.users[user]; ok {
return secret
}
log.Debugf("User not found: %s", user)
return ""
}
func (a *Authenticator) secretDigest(user, realm string) string {
if secret, ok := a.users[user+":"+realm]; ok {
return secret
}
log.Debugf("User not found: %s:%s", user, realm)
return ""
}
func (a *Authenticator) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
a.handler.ServeHTTP(rw, r, next)
}

View file

@ -1,297 +0,0 @@
package auth
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/middlewares/tracing"
"github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func TestAuthUsersFromFile(t *testing.T) {
tests := []struct {
authType string
usersStr string
userKeys []string
parserFunc func(fileName string) (map[string]string, error)
}{
{
authType: "basic",
usersStr: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
userKeys: []string{"test", "test2"},
parserFunc: func(fileName string) (map[string]string, error) {
basic := &types.Basic{
UsersFile: fileName,
}
return parserBasicUsers(basic)
},
},
{
authType: "digest",
usersStr: "test:traefik:a2688e031edb4be6a3797f3882655c05 \ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
userKeys: []string{"test:traefik", "test2:traefik"},
parserFunc: func(fileName string) (map[string]string, error) {
digest := &types.Digest{
UsersFile: fileName,
}
return parserDigestUsers(digest)
},
},
}
for _, test := range tests {
test := test
t.Run(test.authType, func(t *testing.T) {
t.Parallel()
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.usersStr))
require.NoError(t, err)
users, err := test.parserFunc(usersFile.Name())
require.NoError(t, err)
assert.Equal(t, 2, len(users), "they should be equal")
_, ok := users[test.userKeys[0]]
assert.True(t, ok, "user test should be found")
_, ok = users[test.userKeys[1]]
assert.True(t, ok, "user test2 should be found")
})
}
}
func TestBasicAuthFail(t *testing.T) {
_, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test"},
},
}, &tracing.Tracing{})
assert.Contains(t, err.Error(), "error parsing Authenticator user", "should contains")
authMiddleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:test"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
}
func TestBasicAuthSuccess(t *testing.T) {
authMiddleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicRealm(t *testing.T) {
authMiddlewareDefaultRealm, errdefault := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, errdefault)
authMiddlewareCustomRealm, errcustom := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Realm: "foobar",
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, errcustom)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddlewareDefaultRealm)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, "Basic realm=\"traefik\"", res.Header.Get("Www-Authenticate"), "they should be equal")
n = negroni.New(authMiddlewareCustomRealm)
n.UseHandler(handler)
ts = httptest.NewServer(n)
defer ts.Close()
client = &http.Client{}
req = testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err = client.Do(req)
require.NoError(t, err)
assert.Equal(t, "Basic realm=\"foobar\"", res.Header.Get("Www-Authenticate"), "they should be equal")
}
func TestDigestAuthFail(t *testing.T) {
_, err := NewAuthenticator(&types.Auth{
Digest: &types.Digest{
Users: []string{"test"},
},
}, &tracing.Tracing{})
assert.Contains(t, err.Error(), "error parsing Authenticator user", "should contains")
authMiddleware, err := NewAuthenticator(&types.Auth{
Digest: &types.Digest{
Users: []string{"test:traefik:test"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
assert.NotNil(t, authMiddleware, "this should not be nil")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
}
func TestBasicAuthUserHeader(t *testing.T) {
middleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
HeaderField: "X-Webauth-User",
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicAuthHeaderRemoved(t *testing.T) {
middleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
RemoveHeader: true,
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicAuthHeaderPresent(t *testing.T) {
middleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}

View file

@ -0,0 +1,102 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
basicTypeName = "BasicAuth"
)
type basicAuth struct {
next http.Handler
auth *goauth.BasicAuth
users map[string]string
headerField string
removeHeader bool
name string
}
// NewBasic creates a basicAuth middleware.
func NewBasic(ctx context.Context, next http.Handler, authConfig config.BasicAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, basicTypeName).Debug("Creating middleware")
users, err := getUsers(authConfig.UsersFile, authConfig.Users, basicUserParser)
if err != nil {
return nil, err
}
ba := &basicAuth{
next: next,
users: users,
headerField: authConfig.HeaderField,
removeHeader: authConfig.RemoveHeader,
name: name,
}
realm := defaultRealm
if len(authConfig.Realm) > 0 {
realm = authConfig.Realm
}
ba.auth = goauth.NewBasicAuthenticator(realm, ba.secretBasic)
return ba, nil
}
func (b *basicAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return b.name, tracing.SpanKindNoneEnum
}
func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), b.name, basicTypeName)
if username := b.auth.CheckAuth(req); username == "" {
logger.Debug("Authentication failed")
tracing.SetErrorWithEvent(req, "Authentication failed")
b.auth.RequireAuth(rw, req)
} else {
logger.Debug("Authentication succeeded")
req.URL.User = url.User(username)
logData := accesslog.GetLogData(req)
if logData != nil {
logData.Core[accesslog.ClientUsername] = username
}
if b.headerField != "" {
req.Header[b.headerField] = []string{username}
}
if b.removeHeader {
logger.Debug("Removing authorization header")
req.Header.Del(authorizationHeader)
}
b.next.ServeHTTP(rw, req)
}
}
func (b *basicAuth) secretBasic(user, realm string) string {
if secret, ok := b.users[user]; ok {
return secret
}
return ""
}
func basicUserParser(user string) (string, string, error) {
split := strings.Split(user, ":")
if len(split) != 2 {
return "", "", fmt.Errorf("error parsing BasicUser: %v", user)
}
return split[0], split[1], nil
}

View file

@ -0,0 +1,274 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBasicAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test"},
}
_, err := NewBasic(context.Background(), next, auth, "authName")
require.Error(t, err)
auth2 := config.BasicAuth{
Users: []string{"test:test"},
}
authMiddleware, err := NewBasic(context.Background(), next, auth2, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
}
func TestBasicAuthSuccess(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
authMiddleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicAuthUserHeader(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
HeaderField: "X-Webauth-User",
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthHeaderRemoved(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
RemoveHeader: true,
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthHeaderPresent(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthUsersFromFile(t *testing.T) {
testCases := []struct {
desc string
userFileContent string
expectedUsers map[string]string
givenUsers []string
realm string
}{
{
desc: "Finds the users in the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
},
{
desc: "Merges given users with users from the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\n",
givenUsers: []string{"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0", "test3:$apr1$3rJbDP0q$RfzJiorTk78jQ1EcKqWso0"},
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
},
{
desc: "Given users have priority over users in the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
},
{
desc: "Should authenticate the correct user based on the realm",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
realm: "trafikee",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
// Creates the temporary configuration file with the users
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.userFileContent))
require.NoError(t, err)
// Creates the configuration for our Authenticator
authenticatorConfiguration := config.BasicAuth{
Users: test.givenUsers,
UsersFile: usersFile.Name(),
Realm: test.realm,
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
authenticator, err := NewBasic(context.Background(), next, authenticatorConfiguration, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authenticator)
defer ts.Close()
for userName, userPwd := range test.expectedUsers {
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth(userName, userPwd)
var res *http.Response
res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.Equal(t, "traefik\n", string(body))
}
// Checks that user foo doesn't work
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("foo", "foo")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
if len(test.realm) > 0 {
require.Equal(t, `Basic realm="`+test.realm+`"`, res.Header.Get("WWW-Authenticate"))
}
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.NotContains(t, "traefik", string(body))
})
}
}

View file

@ -0,0 +1,102 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/accesslog"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
digestTypeName = "digestAuth"
)
type digestAuth struct {
next http.Handler
auth *goauth.DigestAuth
users map[string]string
headerField string
removeHeader bool
name string
}
// NewDigest creates a digest auth middleware.
func NewDigest(ctx context.Context, next http.Handler, authConfig config.DigestAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, digestTypeName).Debug("Creating middleware")
users, err := getUsers(authConfig.UsersFile, authConfig.Users, digestUserParser)
if err != nil {
return nil, err
}
da := &digestAuth{
next: next,
users: users,
headerField: authConfig.HeaderField,
removeHeader: authConfig.RemoveHeader,
name: name,
}
realm := defaultRealm
if len(authConfig.Realm) > 0 {
realm = authConfig.Realm
}
da.auth = goauth.NewDigestAuthenticator(realm, da.secretDigest)
return da, nil
}
func (d *digestAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return d.name, tracing.SpanKindNoneEnum
}
func (d *digestAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), d.name, digestTypeName)
if username, _ := d.auth.CheckAuth(req); username == "" {
logger.Debug("Digest authentication failed")
tracing.SetErrorWithEvent(req, "Digest authentication failed")
d.auth.RequireAuth(rw, req)
} else {
logger.Debug("Digest authentication succeeded")
req.URL.User = url.User(username)
logData := accesslog.GetLogData(req)
if logData != nil {
logData.Core[accesslog.ClientUsername] = username
}
if d.headerField != "" {
req.Header[d.headerField] = []string{username}
}
if d.removeHeader {
logger.Debug("Removing the Authorization header")
req.Header.Del(authorizationHeader)
}
d.next.ServeHTTP(rw, req)
}
}
func (d *digestAuth) secretDigest(user, realm string) string {
if secret, ok := d.users[user+":"+realm]; ok {
return secret
}
return ""
}
func digestUserParser(user string) (string, string, error) {
split := strings.Split(user, ":")
if len(split) != 3 {
return "", "", fmt.Errorf("error parsing DigestUser: %v", user)
}
return split[0] + ":" + split[1], split[2], nil
}

View file

@ -0,0 +1,141 @@
package auth
import (
"crypto/md5"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
)
const (
algorithm = "algorithm"
authorization = "Authorization"
nonce = "nonce"
opaque = "opaque"
qop = "qop"
realm = "realm"
wwwAuthenticate = "Www-Authenticate"
)
// DigestRequest is a client for digest authentication requests
type digestRequest struct {
client *http.Client
username, password string
nonceCount nonceCount
}
type nonceCount int
func (nc nonceCount) String() string {
return fmt.Sprintf("%08x", int(nc))
}
var wanted = []string{algorithm, nonce, opaque, qop, realm}
// New makes a DigestRequest instance
func newDigestRequest(username, password string, client *http.Client) *digestRequest {
return &digestRequest{
client: client,
username: username,
password: password,
}
}
// Do does requests as http.Do does
func (r *digestRequest) Do(req *http.Request) (*http.Response, error) {
parts, err := r.makeParts(req)
if err != nil {
return nil, err
}
if parts != nil {
req.Header.Set(authorization, r.makeAuthorization(req, parts))
}
return r.client.Do(req)
}
func (r *digestRequest) makeParts(req *http.Request) (map[string]string, error) {
authReq, err := http.NewRequest(req.Method, req.URL.String(), nil)
if err != nil {
return nil, err
}
resp, err := r.client.Do(authReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
return nil, nil
}
if len(resp.Header[wwwAuthenticate]) == 0 {
return nil, fmt.Errorf("headers do not have %s", wwwAuthenticate)
}
headers := strings.Split(resp.Header[wwwAuthenticate][0], ",")
parts := make(map[string]string, len(wanted))
for _, r := range headers {
for _, w := range wanted {
if strings.Contains(r, w) {
parts[w] = strings.Split(r, `"`)[1]
}
}
}
if len(parts) != len(wanted) {
return nil, fmt.Errorf("header is invalid: %+v", parts)
}
return parts, nil
}
func getMD5(texts []string) string {
h := md5.New()
_, _ = io.WriteString(h, strings.Join(texts, ":"))
return hex.EncodeToString(h.Sum(nil))
}
func (r *digestRequest) getNonceCount() string {
r.nonceCount++
return r.nonceCount.String()
}
func (r *digestRequest) makeAuthorization(req *http.Request, parts map[string]string) string {
ha1 := getMD5([]string{r.username, parts[realm], r.password})
ha2 := getMD5([]string{req.Method, req.URL.String()})
cnonce := generateRandom(16)
nc := r.getNonceCount()
response := getMD5([]string{
ha1,
parts[nonce],
nc,
cnonce,
parts[qop],
ha2,
})
return fmt.Sprintf(
`Digest username="%s", realm="%s", nonce="%s", uri="%s", algorithm=%s, qop=%s, nc=%s, cnonce="%s", response="%s", opaque="%s"`,
r.username,
parts[realm],
parts[nonce],
req.URL.String(),
parts[algorithm],
parts[qop],
nc,
cnonce,
response,
parts[opaque],
)
}
// GenerateRandom generates random string
func generateRandom(n int) string {
b := make([]byte, 8)
_, _ = io.ReadFull(rand.Reader, b)
return fmt.Sprintf("%x", b)[:n]
}

View file

@ -0,0 +1,156 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestAuthError(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.DigestAuth{
Users: []string{"test"},
}
_, err := NewDigest(context.Background(), next, auth, "authName")
assert.Error(t, err)
}
func TestDigestAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.DigestAuth{
Users: []string{"test:traefik:a2688e031edb4be6a3797f3882655c05"},
}
authMiddleware, err := NewDigest(context.Background(), next, auth, "authName")
require.NoError(t, err)
assert.NotNil(t, authMiddleware, "this should not be nil")
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := http.DefaultClient
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}
func TestDigestAuthUsersFromFile(t *testing.T) {
testCases := []struct {
desc string
userFileContent string
expectedUsers map[string]string
givenUsers []string
realm string
}{
{
desc: "Finds the users in the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
},
{
desc: "Merges given users with users from the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\n",
givenUsers: []string{"test2:traefik:518845800f9e2bfb1f1f740ec24f074e", "test3:traefik:c8e9f57ce58ecb4424407f665a91646c"},
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
},
{
desc: "Given users have priority over users in the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
givenUsers: []string{"test2:traefik:8de60a1c52da68ccf41f0c0ffb7c51a0"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
},
{
desc: "Should authenticate the correct user based on the realm",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest:traefikee:316a669c158c8b7ab1048b03961a7aa5\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test2"},
realm: "traefikee",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
// Creates the temporary configuration file with the users
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.userFileContent))
require.NoError(t, err)
// Creates the configuration for our Authenticator
authenticatorConfiguration := config.DigestAuth{
Users: test.givenUsers,
UsersFile: usersFile.Name(),
Realm: test.realm,
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
authenticator, err := NewDigest(context.Background(), next, authenticatorConfiguration, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authenticator)
defer ts.Close()
for userName, userPwd := range test.expectedUsers {
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
digestRequest := newDigestRequest(userName, userPwd, http.DefaultClient)
var res *http.Response
res, err = digestRequest.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.Equal(t, "traefik\n", string(body))
}
// Checks that user foo doesn't work
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
digestRequest := newDigestRequest("foo", "foo", http.DefaultClient)
var res *http.Response
res, err = digestRequest.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.NotContains(t, "traefik", string(body))
})
}
}

View file

@ -1,25 +1,68 @@
package auth
import (
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares/tracing"
"github.com/containous/traefik/types"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
const (
xForwardedURI = "X-Forwarded-Uri"
xForwardedMethod = "X-Forwarded-Method"
xForwardedURI = "X-Forwarded-Uri"
xForwardedMethod = "X-Forwarded-Method"
forwardedTypeName = "ForwardedAuthType"
)
// Forward the authentication to a external server
func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
type forwardAuth struct {
address string
authResponseHeaders []string
next http.Handler
name string
tlsConfig *tls.Config
trustForwardHeader bool
}
// NewForward creates a forward auth middleware.
func NewForward(ctx context.Context, next http.Handler, config config.ForwardAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, forwardedTypeName).Debug("Creating middleware")
fa := &forwardAuth{
address: config.Address,
authResponseHeaders: config.AuthResponseHeaders,
next: next,
name: name,
trustForwardHeader: config.TrustForwardHeader,
}
if config.TLS != nil {
tlsConfig, err := config.TLS.CreateTLSConfig()
if err != nil {
return nil, err
}
fa.tlsConfig = tlsConfig
}
return fa, nil
}
func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return fa.name, ext.SpanKindRPCClientEnum
}
func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), fa.name, forwardedTypeName)
// Ensure our request client does not follow redirects
httpClient := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
@ -27,42 +70,44 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
},
}
if config.TLS != nil {
tlsConfig, err := config.TLS.CreateTLSConfig()
if err != nil {
tracing.SetErrorAndDebugLog(r, "Unable to configure TLS to call %s. Cause %s", config.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
if fa.tlsConfig != nil {
httpClient.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
TLSClientConfig: fa.tlsConfig,
}
}
forwardReq, err := http.NewRequest(http.MethodGet, config.Address, http.NoBody)
tracing.LogRequest(tracing.GetSpan(r), forwardReq)
forwardReq, err := http.NewRequest(http.MethodGet, fa.address, nil)
tracing.LogRequest(tracing.GetSpan(req), forwardReq)
if err != nil {
tracing.SetErrorAndDebugLog(r, "Error calling %s. Cause %s", config.Address, err)
w.WriteHeader(http.StatusInternalServerError)
logMessage := fmt.Sprintf("Error calling %s. Cause %s", fa.address, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
writeHeader(r, forwardReq, config.TrustForwardHeader)
writeHeader(req, forwardReq, fa.trustForwardHeader)
tracing.InjectRequestHeaders(forwardReq)
forwardResponse, forwardErr := httpClient.Do(forwardReq)
if forwardErr != nil {
tracing.SetErrorAndDebugLog(r, "Error calling %s. Cause: %s", config.Address, forwardErr)
w.WriteHeader(http.StatusInternalServerError)
logMessage := fmt.Sprintf("Error calling %s. Cause: %s", fa.address, forwardErr)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
body, readError := ioutil.ReadAll(forwardResponse.Body)
if readError != nil {
tracing.SetErrorAndDebugLog(r, "Error reading body %s. Cause: %s", config.Address, readError)
w.WriteHeader(http.StatusInternalServerError)
logMessage := fmt.Sprintf("Error reading body %s. Cause: %s", fa.address, readError)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
defer forwardResponse.Body.Close()
@ -70,40 +115,43 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
// Pass the forward response's body and selected headers if it
// didn't return a response within the range of [200, 300).
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode)
logger.Debugf("Remote error %s. StatusCode: %d", fa.address, forwardResponse.StatusCode)
utils.CopyHeaders(w.Header(), forwardResponse.Header)
utils.RemoveHeaders(w.Header(), forward.HopHeaders...)
utils.CopyHeaders(rw.Header(), forwardResponse.Header)
utils.RemoveHeaders(rw.Header(), forward.HopHeaders...)
// Grab the location header, if any.
redirectURL, err := forwardResponse.Location()
if err != nil {
if err != http.ErrNoLocation {
tracing.SetErrorAndDebugLog(r, "Error reading response location header %s. Cause: %s", config.Address, err)
w.WriteHeader(http.StatusInternalServerError)
logMessage := fmt.Sprintf("Error reading response location header %s. Cause: %s", fa.address, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
} else if redirectURL.String() != "" {
// Set the location in our response if one was sent back.
w.Header().Set("Location", redirectURL.String())
rw.Header().Set("Location", redirectURL.String())
}
tracing.LogResponseCode(tracing.GetSpan(r), forwardResponse.StatusCode)
w.WriteHeader(forwardResponse.StatusCode)
tracing.LogResponseCode(tracing.GetSpan(req), forwardResponse.StatusCode)
rw.WriteHeader(forwardResponse.StatusCode)
if _, err = w.Write(body); err != nil {
log.Error(err)
if _, err = rw.Write(body); err != nil {
logger.Error(err)
}
return
}
for _, headerName := range config.AuthResponseHeaders {
r.Header.Set(headerName, forwardResponse.Header.Get(headerName))
for _, headerName := range fa.authResponseHeaders {
req.Header.Set(headerName, forwardResponse.Header.Get(headerName))
}
r.RequestURI = r.URL.RequestURI()
next(w, r)
req.RequestURI = req.URL.RequestURI()
fa.next.ServeHTTP(rw, req)
}
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {

View file

@ -1,50 +1,49 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/middlewares/tracing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
"github.com/vulcand/oxy/forward"
)
func TestForwardAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer server.Close()
middleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: server.URL,
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
middleware, err := NewForward(context.Background(), next, config.ForwardAuth{
Address: server.URL,
}, "authTest")
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(body))
}
func TestForwardAuthSuccess(t *testing.T) {
@ -55,32 +54,32 @@ func TestForwardAuthSuccess(t *testing.T) {
}))
defer server.Close()
middleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: server.URL,
AuthResponseHeaders: []string{"X-Auth-User"},
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
auth := config.ForwardAuth{
Address: server.URL,
AuthResponseHeaders: []string{"X-Auth-User"},
}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "traefik\n", string(body), "they should be equal")
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestForwardAuthRedirect(t *testing.T) {
@ -89,19 +88,17 @@ func TestForwardAuthRedirect(t *testing.T) {
}))
defer authTs.Close()
authMiddleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: authTs.URL,
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := &http.Client{
@ -109,18 +106,23 @@ func TestForwardAuthRedirect(t *testing.T) {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
require.NoError(t, err)
assert.Equal(t, http.StatusFound, res.StatusCode)
location, err := res.Location()
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
require.NoError(t, err)
assert.Equal(t, "http://example.com/redirect-test", location.String())
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.NotEmpty(t, string(body), "there should be something in the body")
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.NotEmpty(t, string(body))
}
func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
@ -138,19 +140,17 @@ func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
}))
defer authTs.Close()
authMiddleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: authTs.URL,
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
assert.NoError(t, err, "there should be no error")
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := &http.Client{
@ -185,30 +185,28 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) {
}))
defer authTs.Close()
authMiddleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: authTs.URL,
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
client := &http.Client{}
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
require.Len(t, res.Cookies(), 1)
for _, cookie := range res.Cookies() {
assert.Equal(t, "testing", cookie.Value, "they should be equal")
assert.Equal(t, "testing", cookie.Value)
}
expectedHeaders := http.Header{
@ -225,8 +223,11 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) {
}
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(body))
}
func Test_writeHeader(t *testing.T) {

View file

@ -1,48 +0,0 @@
package auth
import (
"fmt"
"strings"
"github.com/containous/traefik/types"
)
func parserBasicUsers(basic *types.Basic) (map[string]string, error) {
var userStrs []string
if basic.UsersFile != "" {
var err error
if userStrs, err = getLinesFromFile(basic.UsersFile); err != nil {
return nil, err
}
}
userStrs = append(basic.Users, userStrs...)
userMap := make(map[string]string)
for _, user := range userStrs {
split := strings.Split(user, ":")
if len(split) != 2 {
return nil, fmt.Errorf("error parsing Authenticator user: %v", user)
}
userMap[split[0]] = split[1]
}
return userMap, nil
}
func parserDigestUsers(digest *types.Digest) (map[string]string, error) {
var userStrs []string
if digest.UsersFile != "" {
var err error
if userStrs, err = getLinesFromFile(digest.UsersFile); err != nil {
return nil, err
}
}
userStrs = append(digest.Users, userStrs...)
userMap := make(map[string]string)
for _, user := range userStrs {
split := strings.Split(user, ":")
if len(split) != 3 {
return nil, fmt.Errorf("error parsing Authenticator user: %v", user)
}
userMap[split[0]+":"+split[1]] = split[2]
}
return userMap, nil
}

View file

@ -0,0 +1,54 @@
package buffering
import (
"context"
"net/http"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
oxybuffer "github.com/vulcand/oxy/buffer"
)
const (
typeName = "Buffer"
)
type buffer struct {
name string
buffer *oxybuffer.Buffer
}
// New creates a buffering middleware.
func New(ctx context.Context, next http.Handler, config config.Buffering, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
logger.Debug("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'",
config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes, config.MaxResponseBodyBytes, config.RetryExpression)
oxyBuffer, err := oxybuffer.New(
next,
oxybuffer.MemRequestBodyBytes(config.MemRequestBodyBytes),
oxybuffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes),
oxybuffer.MemResponseBodyBytes(config.MemResponseBodyBytes),
oxybuffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes),
oxybuffer.CondSetter(len(config.RetryExpression) > 0, oxybuffer.Retry(config.RetryExpression)),
)
if err != nil {
return nil, err
}
return &buffer{
name: name,
buffer: oxyBuffer,
}, nil
}
func (b *buffer) GetTracingInformation() (string, ext.SpanKindEnum) {
return b.name, tracing.SpanKindNoneEnum
}
func (b *buffer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
b.buffer.ServeHTTP(rw, req)
}

View file

@ -1,40 +0,0 @@
package middlewares
import (
"net/http"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares/tracing"
"github.com/vulcand/oxy/cbreaker"
)
// CircuitBreaker holds the oxy circuit breaker.
type CircuitBreaker struct {
circuitBreaker *cbreaker.CircuitBreaker
}
// NewCircuitBreaker returns a new CircuitBreaker.
func NewCircuitBreaker(next http.Handler, expression string, options ...cbreaker.CircuitBreakerOption) (*CircuitBreaker, error) {
circuitBreaker, err := cbreaker.New(next, expression, options...)
if err != nil {
return nil, err
}
return &CircuitBreaker{circuitBreaker}, nil
}
// NewCircuitBreakerOptions returns a new CircuitBreakerOption
func NewCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
return cbreaker.Fallback(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tracing.LogEventf(r, "blocked by circuit-breaker (%q)", expression)
w.WriteHeader(http.StatusServiceUnavailable)
if _, err := w.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
log.Error(err)
}
}))
}
func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
cb.circuitBreaker.ServeHTTP(rw, r)
}

View file

@ -0,0 +1,30 @@
package chain
import (
"context"
"net/http"
"github.com/containous/alice"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
)
const (
typeName = "Chain"
)
type chainBuilder interface {
BuildChain(ctx context.Context, middlewares []string) (*alice.Chain, error)
}
// New creates a chain middleware
func New(ctx context.Context, next http.Handler, config config.Chain, builder chainBuilder, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
middlewareChain, err := builder.BuildChain(ctx, config.Middlewares)
if err != nil {
return nil, err
}
return middlewareChain.Then(next)
}

View file

@ -0,0 +1,61 @@
package circuitbreaker
import (
"context"
"net/http"
"github.com/containous/traefik/config"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/cbreaker"
)
const (
typeName = "CircuitBreaker"
)
type circuitBreaker struct {
circuitBreaker *cbreaker.CircuitBreaker
name string
}
// New creates a new circuit breaker middleware.
func New(ctx context.Context, next http.Handler, confCircuitBreaker config.CircuitBreaker, name string) (http.Handler, error) {
expression := confCircuitBreaker.Expression
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
logger.Debug("Setting up with expression: %s", expression)
oxyCircuitBreaker, err := cbreaker.New(next, expression, createCircuitBreakerOptions(expression))
if err != nil {
return nil, err
}
return &circuitBreaker{
circuitBreaker: oxyCircuitBreaker,
name: name,
}, nil
}
// NewCircuitBreakerOptions returns a new CircuitBreakerOption
func createCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
return cbreaker.Fallback(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
tracing.SetErrorWithEvent(req, "blocked by circuit-breaker (%q)", expression)
rw.WriteHeader(http.StatusServiceUnavailable)
if _, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
log.FromContext(req.Context()).Error(err)
}
}))
}
func (c *circuitBreaker) GetTracingInformation() (string, ext.SpanKindEnum) {
return c.name, tracing.SpanKindNoneEnum
}
func (c *circuitBreaker) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
middlewares.GetLogger(req.Context(), c.name, typeName).Debug("Entering middleware")
c.circuitBreaker.ServeHTTP(rw, req)
}

View file

@ -1,33 +0,0 @@
package middlewares
import (
"compress/gzip"
"net/http"
"strings"
"github.com/NYTimes/gziphandler"
"github.com/containous/traefik/log"
)
// Compress is a middleware that allows to compress the response
type Compress struct{}
// ServeHTTP is a function used by Negroni
func (c *Compress) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
contentType := r.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/grpc") {
next.ServeHTTP(rw, r)
} else {
gzipHandler(next).ServeHTTP(rw, r)
}
}
func gzipHandler(h http.Handler) http.Handler {
wrapper, err := gziphandler.GzipHandlerWithOpts(
gziphandler.CompressionLevel(gzip.DefaultCompression),
gziphandler.MinSize(gziphandler.DefaultMinSize))
if err != nil {
log.Error(err)
}
return wrapper(h)
}

View file

@ -0,0 +1,58 @@
package compress
import (
"compress/gzip"
"context"
"net/http"
"strings"
"github.com/NYTimes/gziphandler"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/sirupsen/logrus"
)
const (
typeName = "Compress"
)
// Compress is a middleware that allows to compress the response.
type compress struct {
next http.Handler
name string
}
// New creates a new compress middleware.
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &compress{
next: next,
name: name,
}, nil
}
func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/grpc") {
c.next.ServeHTTP(rw, req)
} else {
gzipHandler(c.next, middlewares.GetLogger(req.Context(), c.name, typeName)).ServeHTTP(rw, req)
}
}
func (c *compress) GetTracingInformation() (string, ext.SpanKindEnum) {
return c.name, tracing.SpanKindNoneEnum
}
func gzipHandler(h http.Handler, logger logrus.FieldLogger) http.Handler {
wrapper, err := gziphandler.GzipHandlerWithOpts(
gziphandler.CompressionLevel(gzip.DefaultCompression),
gziphandler.MinSize(gziphandler.DefaultMinSize))
if err != nil {
logger.Error(err)
}
return wrapper(h)
}

View file

@ -1,4 +1,4 @@
package middlewares
package compress
import (
"io/ioutil"
@ -10,7 +10,6 @@ import (
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
const (
@ -22,19 +21,19 @@ const (
)
func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
handler := &Compress{}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
baseBody := generateBytes(gziphandler.DefaultMinSize)
next := func(rw http.ResponseWriter, r *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(baseBody)
assert.NoError(t, err)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req, next)
handler.ServeHTTP(rw, req)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
@ -45,20 +44,22 @@ func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
}
func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
handler := &Compress{}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
fakeCompressedBody := generateBytes(gziphandler.DefaultMinSize)
next := func(rw http.ResponseWriter, r *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.Write(fakeCompressedBody)
}
_, err := rw.Write(fakeCompressedBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req, next)
handler.ServeHTTP(rw, req)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
@ -67,36 +68,40 @@ func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
}
func TestShouldNotCompressWhenNoAcceptEncodingHeader(t *testing.T) {
handler := &Compress{}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
fakeBody := generateBytes(gziphandler.DefaultMinSize)
next := func(rw http.ResponseWriter, r *http.Request) {
rw.Write(fakeBody)
}
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(fakeBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req, next)
handler.ServeHTTP(rw, req)
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
assert.EqualValues(t, rw.Body.Bytes(), fakeBody)
}
func TestShouldNotCompressWhenGRPC(t *testing.T) {
handler := &Compress{}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(contentTypeHeader, "application/grpc")
baseBody := generateBytes(gziphandler.DefaultMinSize)
next := func(rw http.ResponseWriter, r *http.Request) {
rw.Write(baseBody)
}
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(baseBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req, next)
handler.ServeHTTP(rw, req)
assert.Empty(t, rw.Header().Get(acceptEncodingHeader))
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
@ -105,40 +110,43 @@ func TestShouldNotCompressWhenGRPC(t *testing.T) {
func TestIntegrationShouldNotCompress(t *testing.T) {
fakeCompressedBody := generateBytes(100000)
comp := &Compress{}
testCases := []struct {
name string
handler func(rw http.ResponseWriter, r *http.Request)
handler http.Handler
expectedStatusCode int
}{
{
name: "when content already compressed",
handler: func(rw http.ResponseWriter, r *http.Request) {
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.Write(fakeCompressedBody)
},
_, err := rw.Write(fakeCompressedBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusOK,
},
{
name: "when content already compressed and status code Created",
handler: func(rw http.ResponseWriter, r *http.Request) {
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusCreated)
rw.Write(fakeCompressedBody)
},
_, err := rw.Write(fakeCompressedBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusCreated,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
negro := negroni.New(comp)
negro.UseHandlerFunc(test.handler)
ts := httptest.NewServer(negro)
compress := &compress{next: test.handler}
ts := httptest.NewServer(compress)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
@ -160,16 +168,18 @@ func TestIntegrationShouldNotCompress(t *testing.T) {
}
func TestShouldWriteHeaderWhenFlush(t *testing.T) {
comp := &Compress{}
negro := negroni.New(comp)
negro.UseHandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusUnauthorized)
rw.(http.Flusher).Flush()
rw.Write([]byte("short"))
_, err := rw.Write([]byte("short"))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
ts := httptest.NewServer(negro)
handler := &compress{next: next}
ts := httptest.NewServer(handler)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
@ -189,34 +199,36 @@ func TestIntegrationShouldCompress(t *testing.T) {
testCases := []struct {
name string
handler func(rw http.ResponseWriter, r *http.Request)
handler http.Handler
expectedStatusCode int
}{
{
name: "when AcceptEncoding header is present",
handler: func(rw http.ResponseWriter, r *http.Request) {
rw.Write(fakeBody)
},
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(fakeBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusOK,
},
{
name: "when AcceptEncoding header is present and status code Created",
handler: func(rw http.ResponseWriter, r *http.Request) {
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusCreated)
rw.Write(fakeBody)
},
_, err := rw.Write(fakeBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusCreated,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
comp := &Compress{}
negro := negroni.New(comp)
negro.UseHandlerFunc(test.handler)
ts := httptest.NewServer(negro)
compress := &compress{next: test.handler}
ts := httptest.NewServer(compress)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)

View file

@ -1,9 +1,9 @@
package errorpages
package customerrors
import (
"bufio"
"bytes"
"errors"
"context"
"fmt"
"net"
"net/http"
@ -11,110 +11,120 @@ import (
"strconv"
"strings"
"github.com/containous/traefik/log"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/types"
"github.com/vulcand/oxy/forward"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
// Compile time validation that the response recorder implements http interfaces correctly.
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{}
// Handler is a middleware that provides the custom error pages
type Handler struct {
BackendName string
backendHandler http.Handler
httpCodeRanges types.HTTPCodeRanges
backendURL string
backendQuery string
FallbackURL string // Deprecated
const (
typeName = "customError"
backendURL = "http://0.0.0.0"
)
type serviceBuilder interface {
Build(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
}
// NewHandler initializes the utils.ErrorHandler for the custom error pages
func NewHandler(errorPage *types.ErrorPage, backendName string) (*Handler, error) {
if len(backendName) == 0 {
return nil, errors.New("error pages: backend name is mandatory ")
}
// customErrors is a middleware that provides the custom error pages..
type customErrors struct {
name string
next http.Handler
backendHandler http.Handler
httpCodeRanges types.HTTPCodeRanges
backendQuery string
}
httpCodeRanges, err := types.NewHTTPCodeRanges(errorPage.Status)
// New creates a new custom error pages middleware.
func New(ctx context.Context, next http.Handler, config config.ErrorPage, serviceBuilder serviceBuilder, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
httpCodeRanges, err := types.NewHTTPCodeRanges(config.Status)
if err != nil {
return nil, err
}
return &Handler{
BackendName: backendName,
backend, err := serviceBuilder.Build(ctx, config.Service, nil)
if err != nil {
return nil, err
}
return &customErrors{
name: name,
next: next,
backendHandler: backend,
httpCodeRanges: httpCodeRanges,
backendQuery: errorPage.Query,
backendURL: "http://0.0.0.0",
backendQuery: config.Query,
}, nil
}
// PostLoad adds backend handler if available
func (h *Handler) PostLoad(backendHandler http.Handler) error {
if backendHandler == nil {
fwd, err := forward.New()
if err != nil {
return err
}
h.backendHandler = fwd
h.backendURL = h.FallbackURL
} else {
h.backendHandler = backendHandler
}
return nil
func (c *customErrors) GetTracingInformation() (string, ext.SpanKindEnum) {
return c.name, tracing.SpanKindNoneEnum
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
if h.backendHandler == nil {
log.Error("Error pages: no backend handler.")
next.ServeHTTP(w, req)
func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), c.name, typeName)
if c.backendHandler == nil {
logger.Error("Error pages: no backend handler.")
tracing.SetErrorWithEvent(req, "Error pages: no backend handler.")
c.next.ServeHTTP(rw, req)
return
}
recorder := newResponseRecorder(w)
next.ServeHTTP(recorder, req)
recorder := newResponseRecorder(rw, middlewares.GetLogger(context.Background(), "test", typeName))
c.next.ServeHTTP(recorder, req)
// check the recorder code against the configured http status code ranges
for _, block := range h.httpCodeRanges {
for _, block := range c.httpCodeRanges {
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
logger.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
var query string
if len(h.backendQuery) > 0 {
query = "/" + strings.TrimPrefix(h.backendQuery, "/")
if len(c.backendQuery) > 0 {
query = "/" + strings.TrimPrefix(c.backendQuery, "/")
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
}
pageReq, err := newRequest(h.backendURL + query)
pageReq, err := newRequest(backendURL + query)
if err != nil {
log.Error(err)
w.WriteHeader(recorder.GetCode())
fmt.Fprint(w, http.StatusText(recorder.GetCode()))
logger.Error(err)
rw.WriteHeader(recorder.GetCode())
_, err = fmt.Fprint(rw, http.StatusText(recorder.GetCode()))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
return
}
recorderErrorPage := newResponseRecorder(w)
recorderErrorPage := newResponseRecorder(rw, middlewares.GetLogger(context.Background(), "test", typeName))
utils.CopyHeaders(pageReq.Header, req.Header)
h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
utils.CopyHeaders(w.Header(), recorderErrorPage.Header())
w.WriteHeader(recorder.GetCode())
utils.CopyHeaders(rw.Header(), recorderErrorPage.Header())
rw.WriteHeader(recorder.GetCode())
if _, err = w.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
log.Error(err)
if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
logger.Error(err)
}
return
}
}
// did not catch a configured status code so proceed with the request
utils.CopyHeaders(w.Header(), recorder.Header())
w.WriteHeader(recorder.GetCode())
w.Write(recorder.GetBody().Bytes())
utils.CopyHeaders(rw.Header(), recorder.Header())
rw.WriteHeader(recorder.GetCode())
_, err := rw.Write(recorder.GetBody().Bytes())
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func newRequest(baseURL string) (*http.Request, error) {
@ -123,7 +133,7 @@ func newRequest(baseURL string) (*http.Request, error) {
return nil, fmt.Errorf("error pages: error when parse URL: %v", err)
}
req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody)
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, fmt.Errorf("error pages: error when create query: %v", err)
}
@ -141,12 +151,13 @@ type responseRecorder interface {
}
// newResponseRecorder returns an initialized responseRecorder.
func newResponseRecorder(rw http.ResponseWriter) responseRecorder {
func newResponseRecorder(rw http.ResponseWriter, logger logrus.FieldLogger) responseRecorder {
recorder := &responseRecorderWithoutCloseNotify{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
Code: http.StatusOK,
responseWriter: rw,
logger: logger,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &responseRecorderWithCloseNotify{recorder}
@ -164,6 +175,7 @@ type responseRecorderWithoutCloseNotify struct {
responseWriter http.ResponseWriter
err error
streamingResponseStarted bool
logger logrus.FieldLogger
}
type responseRecorderWithCloseNotify struct {
@ -225,7 +237,7 @@ func (r *responseRecorderWithoutCloseNotify) Flush() {
_, err := r.responseWriter.Write(r.Body.Bytes())
if err != nil {
log.Errorf("Error writing response in responseRecorder: %v", err)
r.logger.Errorf("Error writing response in responseRecorder: %v", err)
r.err = err
}
r.Body.Reset()

View file

@ -0,0 +1,176 @@
package customerrors
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHandler(t *testing.T) {
testCases := []struct {
desc string
errorPage *config.ErrorPage
backendCode int
backendErrorHandler http.HandlerFunc
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
desc: "no error",
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusOK,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
},
},
{
desc: "in the range",
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusInternalServerError,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My error page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "not in the range",
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusBadGateway,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
},
},
{
desc: "query replacement",
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503-503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "Single code",
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
serviceBuilderMock := &mockServiceBuilder{handler: test.backendErrorHandler}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
recorder := httptest.NewRecorder()
errorPageHandler.ServeHTTP(recorder, req)
test.validate(t, recorder)
})
}
}
type mockServiceBuilder struct {
handler http.Handler
}
func (m *mockServiceBuilder) Build(_ context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
return m.handler, nil
}
func TestNewResponseRecorder(t *testing.T) {
testCases := []struct {
desc string
rw http.ResponseWriter
expected http.ResponseWriter
}{
{
desc: "Without Close Notify",
rw: httptest.NewRecorder(),
expected: &responseRecorderWithoutCloseNotify{},
},
{
desc: "With Close Notify",
rw: &mockRWCloseNotify{},
expected: &responseRecorderWithCloseNotify{},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rec := newResponseRecorder(test.rw, middlewares.GetLogger(context.Background(), "test", typeName))
assert.IsType(t, rec, test.expected)
})
}
}
type mockRWCloseNotify struct{}
func (m *mockRWCloseNotify) CloseNotify() <-chan bool {
panic("implement me")
}
func (m *mockRWCloseNotify) Header() http.Header {
panic("implement me")
}
func (m *mockRWCloseNotify) Write([]byte) (int, error) {
panic("implement me")
}
func (m *mockRWCloseNotify) WriteHeader(int) {
panic("implement me")
}

View file

@ -1,30 +0,0 @@
package middlewares
import (
"net/http"
"github.com/containous/traefik/healthcheck"
)
// EmptyBackendHandler is a middlware that checks whether the current Backend
// has at least one active Server in respect to the healthchecks and if this
// is not the case, it will stop the middleware chain and respond with 503.
type EmptyBackendHandler struct {
next healthcheck.BalancerHandler
}
// NewEmptyBackendHandler creates a new EmptyBackendHandler instance.
func NewEmptyBackendHandler(lb healthcheck.BalancerHandler) *EmptyBackendHandler {
return &EmptyBackendHandler{next: lb}
}
// ServeHTTP responds with 503 when there is no active Server and otherwise
// invokes the next handler in the middleware chain.
func (h *EmptyBackendHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
if len(h.next.Servers()) == 0 {
rw.WriteHeader(http.StatusServiceUnavailable)
rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
} else {
h.next.ServeHTTP(rw, r)
}
}

View file

@ -0,0 +1,33 @@
package emptybackendhandler
import (
"net/http"
"github.com/containous/traefik/healthcheck"
)
// EmptyBackend is a middleware that checks whether the current Backend
// has at least one active Server in respect to the healthchecks and if this
// is not the case, it will stop the middleware chain and respond with 503.
type emptyBackend struct {
next healthcheck.BalancerHandler
}
// New creates a new EmptyBackend middleware.
func New(lb healthcheck.BalancerHandler) http.Handler {
return &emptyBackend{next: lb}
}
// ServeHTTP responds with 503 when there is no active Server and otherwise
// invokes the next handler in the middleware chain.
func (e *emptyBackend) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if len(e.next.Servers()) == 0 {
rw.WriteHeader(http.StatusServiceUnavailable)
_, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
} else {
e.next.ServeHTTP(rw, req)
}
}

View file

@ -1,4 +1,4 @@
package middlewares
package emptybackendhandler
import (
"fmt"
@ -8,40 +8,38 @@ import (
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/vulcand/oxy/roundrobin"
)
func TestEmptyBackendHandler(t *testing.T) {
tests := []struct {
amountServer int
wantStatusCode int
testCases := []struct {
amountServer int
expectedStatusCode int
}{
{
amountServer: 0,
wantStatusCode: http.StatusServiceUnavailable,
amountServer: 0,
expectedStatusCode: http.StatusServiceUnavailable,
},
{
amountServer: 1,
wantStatusCode: http.StatusOK,
amountServer: 1,
expectedStatusCode: http.StatusOK,
},
}
for _, test := range tests {
for _, test := range testCases {
test := test
t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) {
t.Parallel()
handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer})
handler := New(&healthCheckLoadBalancer{amountServer: test.amountServer})
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler.ServeHTTP(recorder, req)
if recorder.Result().StatusCode != test.wantStatusCode {
t.Errorf("Received status code %d, wanted %d", recorder.Result().StatusCode, test.wantStatusCode)
}
assert.Equal(t, test.expectedStatusCode, recorder.Result().StatusCode)
})
}
}

View file

@ -1,384 +0,0 @@
package errorpages
import (
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func TestHandler(t *testing.T) {
testCases := []struct {
desc string
errorPage *types.ErrorPage
backendCode int
backendErrorHandler http.HandlerFunc
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
desc: "no error",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusOK,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
},
},
{
desc: "in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusInternalServerError,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My error page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "not in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusBadGateway,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
},
},
{
desc: "query replacement",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "Single code",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
errorPageHandler, err := NewHandler(test.errorPage, "test")
require.NoError(t, err)
errorPageHandler.backendHandler = test.backendErrorHandler
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
n := negroni.New()
n.Use(errorPageHandler)
n.UseHandler(handler)
recorder := httptest.NewRecorder()
n.ServeHTTP(recorder, req)
test.validate(t, recorder)
})
}
}
func TestHandlerOldWay(t *testing.T) {
testCases := []struct {
desc string
errorPage *types.ErrorPage
backendCode int
errorPageForwarder http.HandlerFunc
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
desc: "no error",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusOK,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "OK")
},
},
{
desc: "in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusInternalServerError,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
assert.Contains(t, recorder.Body.String(), "My error page.")
assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page")
},
},
{
desc: "not in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusBadGateway,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusBadGateway, recorder.Code)
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
assert.NotContains(t, recorder.Body.String(), "My error page.", "Should return the oops page since we have not configured the 502 code")
},
},
{
desc: "query replacement",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
backendCode: http.StatusServiceUnavailable,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.RequestURI() == "/"+strconv.Itoa(503) {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "Single code",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
backendCode: http.StatusServiceUnavailable,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.RequestURI() == "/"+strconv.Itoa(503) {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
errorPageHandler, err := NewHandler(test.errorPage, "test")
require.NoError(t, err)
errorPageHandler.FallbackURL = "http://localhost"
err = errorPageHandler.PostLoad(test.errorPageForwarder)
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
n := negroni.New()
n.Use(errorPageHandler)
n.UseHandler(handler)
recorder := httptest.NewRecorder()
n.ServeHTTP(recorder, req)
test.validate(t, recorder)
})
}
}
func TestHandlerOldWayIntegration(t *testing.T) {
errorPagesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.RequestURI() == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Test Server")
}
}))
defer errorPagesServer.Close()
testCases := []struct {
desc string
errorPage *types.ErrorPage
backendCode int
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
desc: "no error",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusOK,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "OK")
},
},
{
desc: "in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusInternalServerError,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Test Server")
assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page")
},
},
{
desc: "not in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusBadGateway,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusBadGateway, recorder.Code)
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
},
},
{
desc: "query replacement",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
backendCode: http.StatusServiceUnavailable,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "Single code",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
backendCode: http.StatusServiceUnavailable,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
}
req := testhelpers.MustNewRequest(http.MethodGet, errorPagesServer.URL+"/test", nil)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
errorPageHandler, err := NewHandler(test.errorPage, "test")
require.NoError(t, err)
errorPageHandler.FallbackURL = errorPagesServer.URL
err = errorPageHandler.PostLoad(nil)
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
n := negroni.New()
n.Use(errorPageHandler)
n.UseHandler(handler)
recorder := httptest.NewRecorder()
n.ServeHTTP(recorder, req)
test.validate(t, recorder)
})
}
}
func TestNewResponseRecorder(t *testing.T) {
testCases := []struct {
desc string
rw http.ResponseWriter
expected http.ResponseWriter
}{
{
desc: "Without Close Notify",
rw: httptest.NewRecorder(),
expected: &responseRecorderWithoutCloseNotify{},
},
{
desc: "With Close Notify",
rw: &mockRWCloseNotify{},
expected: &responseRecorderWithCloseNotify{},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rec := newResponseRecorder(test.rw)
assert.IsType(t, rec, test.expected)
})
}
}
type mockRWCloseNotify struct{}
func (m *mockRWCloseNotify) CloseNotify() <-chan bool {
panic("implement me")
}
func (m *mockRWCloseNotify) Header() http.Header {
panic("implement me")
}
func (m *mockRWCloseNotify) Write([]byte) (int, error) {
panic("implement me")
}
func (m *mockRWCloseNotify) WriteHeader(int) {
panic("implement me")
}

View file

@ -1,52 +0,0 @@
package forwardedheaders
import (
"net/http"
"github.com/containous/traefik/ip"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
// XForwarded filter for XForwarded headers
type XForwarded struct {
insecure bool
trustedIps []string
ipChecker *ip.Checker
}
// NewXforwarded creates a new XForwarded
func NewXforwarded(insecure bool, trustedIps []string) (*XForwarded, error) {
var ipChecker *ip.Checker
if len(trustedIps) > 0 {
var err error
ipChecker, err = ip.NewChecker(trustedIps)
if err != nil {
return nil, err
}
}
return &XForwarded{
insecure: insecure,
trustedIps: trustedIps,
ipChecker: ipChecker,
}, nil
}
func (x *XForwarded) isTrustedIP(ip string) bool {
if x.ipChecker == nil {
return false
}
return x.ipChecker.IsAuthorized(ip) == nil
}
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
utils.RemoveHeaders(r.Header, forward.XHeaders...)
}
// If there is a next, call it.
if next != nil {
next(w, r)
}
}

View file

@ -1,128 +0,0 @@
package forwardedheaders
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServeHTTP(t *testing.T) {
testCases := []struct {
desc string
insecure bool
trustedIps []string
incomingHeaders map[string]string
remoteAddr string
expectedHeaders map[string]string
}{
{
desc: "all Empty",
insecure: true,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure true with incoming X-Forwarded-For",
insecure: true,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For",
insecure: false,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips",
insecure: false,
trustedIps: []string{"10.0.1.100"},
remoteAddr: "10.0.1.100:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips",
insecure: false,
trustedIps: []string{"10.0.1.100"},
remoteAddr: "10.0.1.101:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips CIDR",
insecure: false,
trustedIps: []string{"1.2.3.4/24"},
remoteAddr: "1.2.3.156:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips CIDR",
insecure: false,
trustedIps: []string{"1.2.3.4/24"},
remoteAddr: "10.0.1.101:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)
req.RemoteAddr = test.remoteAddr
for k, v := range test.incomingHeaders {
req.Header.Set(k, v)
}
m, err := NewXforwarded(test.insecure, test.trustedIps)
require.NoError(t, err)
m.ServeHTTP(nil, req, nil)
for k, v := range test.expectedHeaders {
assert.Equal(t, v, req.Header.Get(k))
}
})
}
}

View file

@ -3,7 +3,6 @@ package middlewares
import (
"net/http"
"github.com/containous/mux"
"github.com/containous/traefik/safe"
)
@ -13,24 +12,24 @@ type HandlerSwitcher struct {
}
// NewHandlerSwitcher builds a new instance of HandlerSwitcher
func NewHandlerSwitcher(newHandler *mux.Router) (hs *HandlerSwitcher) {
func NewHandlerSwitcher(newHandler http.Handler) (hs *HandlerSwitcher) {
return &HandlerSwitcher{
handler: safe.New(newHandler),
}
}
func (hs *HandlerSwitcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
handlerBackup := hs.handler.Get().(*mux.Router)
handlerBackup.ServeHTTP(rw, r)
func (h *HandlerSwitcher) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
handlerBackup := h.handler.Get().(http.Handler)
handlerBackup.ServeHTTP(rw, req)
}
// GetHandler returns the current http.ServeMux
func (hs *HandlerSwitcher) GetHandler() (newHandler *mux.Router) {
handler := hs.handler.Get().(*mux.Router)
func (h *HandlerSwitcher) GetHandler() (newHandler http.Handler) {
handler := h.handler.Get().(http.Handler)
return handler
}
// UpdateHandler safely updates the current http.ServeMux with a new one
func (hs *HandlerSwitcher) UpdateHandler(newHandler *mux.Router) {
hs.handler.Set(newHandler)
func (h *HandlerSwitcher) UpdateHandler(newHandler http.Handler) {
h.handler.Set(newHandler)
}

View file

@ -1,71 +0,0 @@
package middlewares
// Middleware based on https://github.com/unrolled/secure
import (
"net/http"
"github.com/containous/traefik/types"
)
// HeaderOptions is a struct for specifying configuration options for the headers middleware.
type HeaderOptions struct {
// If Custom request headers are set, these will be added to the request
CustomRequestHeaders map[string]string
// If Custom response headers are set, these will be added to the ResponseWriter
CustomResponseHeaders map[string]string
}
// HeaderStruct is a middleware that helps setup a few basic security features. A single headerOptions struct can be
// provided to configure which features should be enabled, and the ability to override a few of the default values.
type HeaderStruct struct {
// Customize headers with a headerOptions struct.
opt HeaderOptions
}
// NewHeaderFromStruct constructs a new header instance from supplied frontend header struct.
func NewHeaderFromStruct(headers *types.Headers) *HeaderStruct {
if headers == nil || !headers.HasCustomHeadersDefined() {
return nil
}
return &HeaderStruct{
opt: HeaderOptions{
CustomRequestHeaders: headers.CustomRequestHeaders,
CustomResponseHeaders: headers.CustomResponseHeaders,
},
}
}
func (s *HeaderStruct) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
s.ModifyRequestHeaders(r)
// If there is a next, call it.
if next != nil {
next(w, r)
}
}
// ModifyRequestHeaders set or delete request headers
func (s *HeaderStruct) ModifyRequestHeaders(r *http.Request) {
// Loop through Custom request headers
for header, value := range s.opt.CustomRequestHeaders {
if value == "" {
r.Header.Del(header)
} else {
r.Header.Set(header, value)
}
}
}
// ModifyResponseHeaders set or delete response headers
func (s *HeaderStruct) ModifyResponseHeaders(res *http.Response) error {
// Loop through Custom response headers
for header, value := range s.opt.CustomResponseHeaders {
if value == "" {
res.Header.Del(header)
} else {
res.Header.Set(header, value)
}
}
return nil
}

View file

@ -0,0 +1,134 @@
// Package headers Middleware based on https://github.com/unrolled/secure.
package headers
import (
"context"
"errors"
"net/http"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/unrolled/secure"
)
const (
typeName = "Headers"
)
type headers struct {
name string
handler http.Handler
}
// New creates a Headers middleware.
func New(ctx context.Context, next http.Handler, config config.Headers, name string) (http.Handler, error) {
// HeaderMiddleware -> SecureMiddleWare -> next
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
if !config.HasSecureHeadersDefined() && !config.HasCustomHeadersDefined() {
return nil, errors.New("headers configuration not valid")
}
var handler http.Handler
nextHandler := next
if config.HasSecureHeadersDefined() {
logger.Debug("Setting up secureHeaders from %v", config)
handler = newSecure(next, config)
nextHandler = handler
}
if config.HasCustomHeadersDefined() {
logger.Debug("Setting up customHeaders from %v", config)
handler = newHeader(nextHandler, config)
}
return &headers{
handler: handler,
name: name,
}, nil
}
func (h *headers) GetTracingInformation() (string, ext.SpanKindEnum) {
return h.name, tracing.SpanKindNoneEnum
}
func (h *headers) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
h.handler.ServeHTTP(rw, req)
}
type secureHeader struct {
next http.Handler
secure *secure.Secure
}
// newSecure constructs a new secure instance with supplied options.
func newSecure(next http.Handler, headers config.Headers) *secureHeader {
opt := secure.Options{
BrowserXssFilter: headers.BrowserXSSFilter,
ContentTypeNosniff: headers.ContentTypeNosniff,
ForceSTSHeader: headers.ForceSTSHeader,
FrameDeny: headers.FrameDeny,
IsDevelopment: headers.IsDevelopment,
SSLRedirect: headers.SSLRedirect,
SSLForceHost: headers.SSLForceHost,
SSLTemporaryRedirect: headers.SSLTemporaryRedirect,
STSIncludeSubdomains: headers.STSIncludeSubdomains,
STSPreload: headers.STSPreload,
ContentSecurityPolicy: headers.ContentSecurityPolicy,
CustomBrowserXssValue: headers.CustomBrowserXSSValue,
CustomFrameOptionsValue: headers.CustomFrameOptionsValue,
PublicKey: headers.PublicKey,
ReferrerPolicy: headers.ReferrerPolicy,
SSLHost: headers.SSLHost,
AllowedHosts: headers.AllowedHosts,
HostsProxyHeaders: headers.HostsProxyHeaders,
SSLProxyHeaders: headers.SSLProxyHeaders,
STSSeconds: headers.STSSeconds,
}
return &secureHeader{
next: next,
secure: secure.New(opt),
}
}
func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP)
}
// Header is a middleware that helps setup a few basic security features. A single headerOptions struct can be
// provided to configure which features should be enabled, and the ability to override a few of the default values.
type header struct {
next http.Handler
// If Custom request headers are set, these will be added to the request
customRequestHeaders map[string]string
}
// NewHeader constructs a new header instance from supplied frontend header struct.
func newHeader(next http.Handler, headers config.Headers) *header {
return &header{
next: next,
customRequestHeaders: headers.CustomRequestHeaders,
}
}
func (s *header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
s.modifyRequestHeaders(req)
s.next.ServeHTTP(rw, req)
}
// modifyRequestHeaders set or delete request headers.
func (s *header) modifyRequestHeaders(req *http.Request) {
// Loop through Custom request headers
for header, value := range s.customRequestHeaders {
if value == "" {
req.Header.Del(header)
} else {
req.Header.Set(header, value)
}
}
}

View file

@ -0,0 +1,105 @@
package headers
// Middleware tests based on https://github.com/unrolled/secure
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCustomRequestHeader(t *testing.T) {
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
header := newHeader(emptyHandler, config.Headers{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
}
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
header := newHeader(emptyHandler, config.Headers{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
header = newHeader(emptyHandler, config.Headers{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "",
},
})
header.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"))
}
func TestSecureHeader(t *testing.T) {
testCases := []struct {
desc string
fromHost string
expected int
}{
{
desc: "Should accept the request when given a host that is in the list",
fromHost: "foo.com",
expected: http.StatusOK,
},
{
desc: "Should refuse the request when no host is given",
fromHost: "",
expected: http.StatusInternalServerError,
},
{
desc: "Should refuse the request when no matching host is given",
fromHost: "boo.com",
expected: http.StatusInternalServerError,
},
}
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
header, err := New(context.Background(), emptyHandler, config.Headers{
AllowedHosts: []string{"foo.com", "bar.com"},
}, "foo")
require.NoError(t, err)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
req.Host = test.fromHost
header.ServeHTTP(res, req)
assert.Equal(t, test.expected, res.Code)
})
}
}

View file

@ -1,118 +0,0 @@
package middlewares
// Middleware tests based on https://github.com/unrolled/secure
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("bar"))
})
// newHeader constructs a new header instance with supplied options.
func newHeader(options ...HeaderOptions) *HeaderStruct {
var opt HeaderOptions
if len(options) == 0 {
opt = HeaderOptions{}
} else {
opt = options[0]
}
return &HeaderStruct{opt: opt}
}
func TestNoConfig(t *testing.T) {
header := newHeader()
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
header.ServeHTTP(res, req, myHandler)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "bar", res.Body.String(), "Body not the expected")
}
func TestModifyResponseHeaders(t *testing.T) {
header := newHeader(HeaderOptions{
CustomResponseHeaders: map[string]string{
"X-Custom-Response-Header": "test_response",
},
})
res := httptest.NewRecorder()
res.HeaderMap.Add("X-Custom-Response-Header", "test_response")
err := header.ModifyResponseHeaders(res.Result())
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "test_response", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
res = httptest.NewRecorder()
res.HeaderMap.Add("X-Custom-Response-Header", "")
err = header.ModifyResponseHeaders(res.Result())
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
res = httptest.NewRecorder()
res.HeaderMap.Add("X-Custom-Response-Header", "test_override")
err = header.ModifyResponseHeaders(res.Result())
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "test_override", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
}
func TestCustomRequestHeader(t *testing.T) {
header := newHeader(HeaderOptions{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req, nil)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header")
}
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
header := newHeader(HeaderOptions{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req, nil)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header")
header = newHeader(HeaderOptions{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "",
},
})
header.ServeHTTP(res, req, nil)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"), "This header is not expected")
}

View file

@ -1,67 +0,0 @@
package middlewares
import (
"fmt"
"net/http"
"github.com/containous/traefik/ip"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares/tracing"
"github.com/pkg/errors"
"github.com/urfave/negroni"
)
// IPWhiteLister is a middleware that provides Checks of the Requesting IP against a set of Whitelists
type IPWhiteLister struct {
handler negroni.Handler
whiteLister *ip.Checker
strategy ip.Strategy
}
// NewIPWhiteLister builds a new IPWhiteLister given a list of CIDR-Strings to whitelist
func NewIPWhiteLister(whiteList []string, strategy ip.Strategy) (*IPWhiteLister, error) {
if len(whiteList) == 0 {
return nil, errors.New("no white list provided")
}
checker, err := ip.NewChecker(whiteList)
if err != nil {
return nil, fmt.Errorf("parsing CIDR whitelist %s: %v", whiteList, err)
}
whiteLister := IPWhiteLister{
strategy: strategy,
whiteLister: checker,
}
whiteLister.handler = negroni.HandlerFunc(whiteLister.handle)
log.Debugf("configured IP white list: %s", whiteList)
return &whiteLister, nil
}
func (wl *IPWhiteLister) handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
err := wl.whiteLister.IsAuthorized(wl.strategy.GetIP(r))
if err != nil {
tracing.SetErrorAndDebugLog(r, "request %+v - rejecting: %v", r, err)
reject(w)
return
}
log.Debugf("Accept %s: %+v", wl.strategy.GetIP(r), r)
tracing.SetErrorAndDebugLog(r, "request %+v matched white list %v - passing", r, wl.whiteLister)
next.ServeHTTP(w, r)
}
func (wl *IPWhiteLister) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
wl.handler.ServeHTTP(rw, r, next)
}
func reject(w http.ResponseWriter) {
statusCode := http.StatusForbidden
w.WriteHeader(statusCode)
_, err := w.Write([]byte(http.StatusText(statusCode)))
if err != nil {
log.Error(err)
}
}

View file

@ -0,0 +1,85 @@
package ipwhitelist
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/config"
"github.com/containous/traefik/ip"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
typeName = "IPWhiteLister"
)
// ipWhiteLister is a middleware that provides Checks of the Requesting IP against a set of Whitelists
type ipWhiteLister struct {
next http.Handler
whiteLister *ip.Checker
strategy ip.Strategy
name string
}
// New builds a new IPWhiteLister given a list of CIDR-Strings to whitelist
func New(ctx context.Context, next http.Handler, config config.IPWhiteList, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
if len(config.SourceRange) == 0 {
return nil, errors.New("sourceRange is empty, IPWhiteLister not created")
}
checker, err := ip.NewChecker(config.SourceRange)
if err != nil {
return nil, fmt.Errorf("cannot parse CIDR whitelist %s: %v", config.SourceRange, err)
}
strategy, err := config.IPStrategy.Get()
if err != nil {
return nil, err
}
logger.Debugf("Setting up IPWhiteLister with sourceRange: %s", config.SourceRange)
return &ipWhiteLister{
strategy: strategy,
whiteLister: checker,
next: next,
name: name,
}, nil
}
func (wl *ipWhiteLister) GetTracingInformation() (string, ext.SpanKindEnum) {
return wl.name, tracing.SpanKindNoneEnum
}
func (wl *ipWhiteLister) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), wl.name, typeName)
err := wl.whiteLister.IsAuthorized(wl.strategy.GetIP(req))
if err != nil {
logMessage := fmt.Sprintf("rejecting request %+v: %v", req, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
reject(logger, rw)
return
}
logger.Debugf("Accept %s: %+v", wl.strategy.GetIP(req), req)
wl.next.ServeHTTP(rw, req)
}
func reject(logger logrus.FieldLogger, rw http.ResponseWriter) {
statusCode := http.StatusForbidden
rw.WriteHeader(statusCode)
_, err := rw.Write([]byte(http.StatusText(statusCode)))
if err != nil {
logger.Error(err)
}
}

View file

@ -1,11 +1,12 @@
package middlewares
package ipwhitelist
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/ip"
"github.com/containous/traefik/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -13,18 +14,21 @@ import (
func TestNewIPWhiteLister(t *testing.T) {
testCases := []struct {
desc string
whiteList []string
expectedError string
whiteList config.IPWhiteList
expectedError bool
}{
{
desc: "invalid IP",
whiteList: []string{"foo"},
expectedError: "parsing CIDR whitelist [foo]: parsing CIDR trusted IPs <nil>: invalid CIDR address: foo",
desc: "invalid IP",
whiteList: config.IPWhiteList{
SourceRange: []string{"foo"},
},
expectedError: true,
},
{
desc: "valid IP",
whiteList: []string{"10.10.10.10"},
expectedError: "",
desc: "valid IP",
whiteList: config.IPWhiteList{
SourceRange: []string{"10.10.10.10"},
},
},
}
@ -33,10 +37,11 @@ func TestNewIPWhiteLister(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
whiteLister, err := NewIPWhiteLister(test.whiteList, &ip.RemoteAddrStrategy{})
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
if len(test.expectedError) > 0 {
assert.EqualError(t, err, test.expectedError)
if test.expectedError {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.NotNil(t, whiteLister)
@ -48,19 +53,23 @@ func TestNewIPWhiteLister(t *testing.T) {
func TestIPWhiteLister_ServeHTTP(t *testing.T) {
testCases := []struct {
desc string
whiteList []string
whiteList config.IPWhiteList
remoteAddr string
expected int
}{
{
desc: "authorized with remote address",
whiteList: []string{"20.20.20.20"},
desc: "authorized with remote address",
whiteList: config.IPWhiteList{
SourceRange: []string{"20.20.20.20"},
},
remoteAddr: "20.20.20.20:1234",
expected: 200,
},
{
desc: "non authorized with remote address",
whiteList: []string{"20.20.20.20"},
desc: "non authorized with remote address",
whiteList: config.IPWhiteList{
SourceRange: []string{"20.20.20.20"},
},
remoteAddr: "20.20.20.21:1234",
expected: 403,
},
@ -71,7 +80,8 @@ func TestIPWhiteLister_ServeHTTP(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
whiteLister, err := NewIPWhiteLister(test.whiteList, &ip.RemoteAddrStrategy{})
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
require.NoError(t, err)
recorder := httptest.NewRecorder()
@ -82,9 +92,7 @@ func TestIPWhiteLister_ServeHTTP(t *testing.T) {
req.RemoteAddr = test.remoteAddr
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
whiteLister.ServeHTTP(recorder, req, next)
whiteLister.ServeHTTP(recorder, req)
assert.Equal(t, test.expected, recorder.Code)
})

View file

@ -0,0 +1,48 @@
package maxconnection
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/connlimit"
"github.com/vulcand/oxy/utils"
)
const (
typeName = "MaxConnection"
)
type maxConnection struct {
handler http.Handler
name string
}
// New creates a max connection middleware.
func New(ctx context.Context, next http.Handler, maxConns config.MaxConn, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
handler, err := connlimit.New(next, extractFunc, maxConns.Amount)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
return &maxConnection{handler: handler, name: name}, nil
}
func (mc *maxConnection) GetTracingInformation() (string, ext.SpanKindEnum) {
return mc.name, tracing.SpanKindNoneEnum
}
func (mc *maxConnection) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
mc.handler.ServeHTTP(rw, req)
}

View file

@ -1,131 +0,0 @@
package middlewares
import (
"net/http"
"strconv"
"strings"
"sync/atomic"
"time"
"unicode/utf8"
"github.com/containous/traefik/log"
"github.com/containous/traefik/metrics"
gokitmetrics "github.com/go-kit/kit/metrics"
"github.com/urfave/negroni"
)
const (
protoHTTP = "http"
protoSSE = "sse"
protoWebsocket = "websocket"
)
// NewEntryPointMetricsMiddleware creates a new metrics middleware for an Entrypoint.
func NewEntryPointMetricsMiddleware(registry metrics.Registry, entryPointName string) negroni.Handler {
return &metricsMiddleware{
reqsCounter: registry.EntrypointReqsCounter(),
reqDurationHistogram: registry.EntrypointReqDurationHistogram(),
openConnsGauge: registry.EntrypointOpenConnsGauge(),
baseLabels: []string{"entrypoint", entryPointName},
}
}
// NewBackendMetricsMiddleware creates a new metrics middleware for a Backend.
func NewBackendMetricsMiddleware(registry metrics.Registry, backendName string) negroni.Handler {
return &metricsMiddleware{
reqsCounter: registry.BackendReqsCounter(),
reqDurationHistogram: registry.BackendReqDurationHistogram(),
openConnsGauge: registry.BackendOpenConnsGauge(),
baseLabels: []string{"backend", backendName},
}
}
type metricsMiddleware struct {
// Important: Since this int64 field is using sync/atomic, it has to be at the top of the struct due to a bug on 32-bit platform
// See: https://golang.org/pkg/sync/atomic/ for more information
openConns int64
reqsCounter gokitmetrics.Counter
reqDurationHistogram gokitmetrics.Histogram
openConnsGauge gokitmetrics.Gauge
baseLabels []string
}
func (m *metricsMiddleware) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
labels := []string{"method", getMethod(r), "protocol", getRequestProtocol(r)}
labels = append(labels, m.baseLabels...)
openConns := atomic.AddInt64(&m.openConns, 1)
m.openConnsGauge.With(labels...).Set(float64(openConns))
defer func(labelValues []string) {
openConns := atomic.AddInt64(&m.openConns, -1)
m.openConnsGauge.With(labelValues...).Set(float64(openConns))
}(labels)
start := time.Now()
recorder := &responseRecorder{rw, http.StatusOK}
next(recorder, r)
labels = append(labels, "code", strconv.Itoa(recorder.statusCode))
m.reqsCounter.With(labels...).Add(1)
m.reqDurationHistogram.With(labels...).Observe(time.Since(start).Seconds())
}
func getRequestProtocol(req *http.Request) string {
switch {
case isWebsocketRequest(req):
return protoWebsocket
case isSSERequest(req):
return protoSSE
default:
return protoHTTP
}
}
// isWebsocketRequest determines if the specified HTTP request is a websocket handshake request.
func isWebsocketRequest(req *http.Request) bool {
return containsHeader(req, "Connection", "upgrade") && containsHeader(req, "Upgrade", "websocket")
}
// isSSERequest determines if the specified HTTP request is a request for an event subscription.
func isSSERequest(req *http.Request) bool {
return containsHeader(req, "Accept", "text/event-stream")
}
func containsHeader(req *http.Request, name, value string) bool {
items := strings.Split(req.Header.Get(name), ",")
for _, item := range items {
if value == strings.ToLower(strings.TrimSpace(item)) {
return true
}
}
return false
}
func getMethod(r *http.Request) string {
if !utf8.ValidString(r.Method) {
log.Warnf("Invalid HTTP method encoding: %s", r.Method)
return "NON_UTF8_HTTP_METHOD"
}
return r.Method
}
type retryMetrics interface {
BackendRetriesCounter() gokitmetrics.Counter
}
// NewMetricsRetryListener instantiates a MetricsRetryListener with the given retryMetrics.
func NewMetricsRetryListener(retryMetrics retryMetrics, backendName string) RetryListener {
return &MetricsRetryListener{retryMetrics: retryMetrics, backendName: backendName}
}
// MetricsRetryListener is an implementation of the RetryListener interface to
// record RequestMetrics about retry attempts.
type MetricsRetryListener struct {
retryMetrics retryMetrics
backendName string
}
// Retried tracks the retry in the RequestMetrics implementation.
func (m *MetricsRetryListener) Retried(req *http.Request, attempt int) {
m.retryMetrics.BackendRetriesCounter().With("backend", m.backendName).Add(1)
}

View file

@ -1,42 +0,0 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/go-kit/kit/metrics"
)
func TestMetricsRetryListener(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
retryMetrics := newCollectingRetryMetrics()
retryListener := NewMetricsRetryListener(retryMetrics, "backendName")
retryListener.Retried(req, 1)
retryListener.Retried(req, 2)
wantCounterValue := float64(2)
if retryMetrics.retriesCounter.CounterValue != wantCounterValue {
t.Errorf("got counter value of %f, want %f", retryMetrics.retriesCounter.CounterValue, wantCounterValue)
}
wantLabelValues := []string{"backend", "backendName"}
if !reflect.DeepEqual(retryMetrics.retriesCounter.LastLabelValues, wantLabelValues) {
t.Errorf("wrong label values %v used, want %v", retryMetrics.retriesCounter.LastLabelValues, wantLabelValues)
}
}
// collectingRetryMetrics is an implementation of the retryMetrics interface that can be used inside tests to collect the times Add() was called.
type collectingRetryMetrics struct {
retriesCounter *testhelpers.CollectingCounter
}
func newCollectingRetryMetrics() *collectingRetryMetrics {
return &collectingRetryMetrics{retriesCounter: &testhelpers.CollectingCounter{}}
}
func (metrics *collectingRetryMetrics) BackendRetriesCounter() metrics.Counter {
return metrics.retriesCounter
}

13
middlewares/middleware.go Normal file
View file

@ -0,0 +1,13 @@
package middlewares
import (
"context"
"github.com/containous/traefik/log"
"github.com/sirupsen/logrus"
)
// GetLogger creates a logger configured with the middleware fields.
func GetLogger(ctx context.Context, middleware string, middlewareType string) logrus.FieldLogger {
return log.FromContext(ctx).WithField(log.MiddlewareName, middleware).WithField(log.MiddlewareType, middlewareType)
}

View file

@ -0,0 +1,253 @@
package passtlsclientcert
import (
"context"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/sirupsen/logrus"
)
const (
xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
xForwardedTLSClientCertInfos = "X-Forwarded-Tls-Client-Cert-infos"
typeName = "PassClientTLSCert"
)
// passTLSClientCert is a middleware that helps setup a few tls info features.
type passTLSClientCert struct {
next http.Handler
name string
pem bool // pass the sanitized pem to the backend in a specific header
infos *tlsClientCertificateInfos // pass selected information from the client certificate
}
// New constructs a new PassTLSClientCert instance from supplied frontend header struct.
func New(ctx context.Context, next http.Handler, config config.PassTLSClientCert, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &passTLSClientCert{
next: next,
name: name,
pem: config.PEM,
infos: newTLSClientInfos(config.Infos),
}, nil
}
// tlsClientCertificateInfos is a struct for specifying the configuration for the passTLSClientCert middleware.
type tlsClientCertificateInfos struct {
notAfter bool
notBefore bool
subject *tlsCLientCertificateSubjectInfos
sans bool
}
func newTLSClientInfos(infos *config.TLSClientCertificateInfos) *tlsClientCertificateInfos {
if infos == nil {
return nil
}
return &tlsClientCertificateInfos{
notBefore: infos.NotBefore,
notAfter: infos.NotAfter,
sans: infos.Sans,
subject: newTLSCLientCertificateSubjectInfos(infos.Subject),
}
}
// tlsCLientCertificateSubjectInfos contains the configuration for the certificate subject infos.
type tlsCLientCertificateSubjectInfos struct {
country bool
province bool
locality bool
Organization bool
commonName bool
serialNumber bool
}
func newTLSCLientCertificateSubjectInfos(infos *config.TLSCLientCertificateSubjectInfos) *tlsCLientCertificateSubjectInfos {
if infos == nil {
return nil
}
return &tlsCLientCertificateSubjectInfos{
serialNumber: infos.SerialNumber,
commonName: infos.CommonName,
country: infos.Country,
locality: infos.Locality,
Organization: infos.Organization,
province: infos.Province,
}
}
func (p *passTLSClientCert) GetTracingInformation() (string, ext.SpanKindEnum) {
return p.name, tracing.SpanKindNoneEnum
}
func (p *passTLSClientCert) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), p.name, typeName)
p.modifyRequestHeaders(logger, req)
p.next.ServeHTTP(rw, req)
}
// getSubjectInfos extract the requested information from the certificate subject.
func (p *passTLSClientCert) getSubjectInfos(cs *pkix.Name) string {
var subject string
if p.infos != nil && p.infos.subject != nil {
options := p.infos.subject
var content []string
if options.country && len(cs.Country) > 0 {
content = append(content, fmt.Sprintf("C=%s", cs.Country[0]))
}
if options.province && len(cs.Province) > 0 {
content = append(content, fmt.Sprintf("ST=%s", cs.Province[0]))
}
if options.locality && len(cs.Locality) > 0 {
content = append(content, fmt.Sprintf("L=%s", cs.Locality[0]))
}
if options.Organization && len(cs.Organization) > 0 {
content = append(content, fmt.Sprintf("O=%s", cs.Organization[0]))
}
if options.commonName && len(cs.CommonName) > 0 {
content = append(content, fmt.Sprintf("CN=%s", cs.CommonName))
}
if len(content) > 0 {
subject = `Subject="` + strings.Join(content, ",") + `"`
}
}
return subject
}
// getXForwardedTLSClientCertInfos Build a string with the wanted client certificates information
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
func (p *passTLSClientCert) getXForwardedTLSClientCertInfos(certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
var values []string
var sans string
var nb string
var na string
subject := p.getSubjectInfos(&peerCert.Subject)
if len(subject) > 0 {
values = append(values, subject)
}
ci := p.infos
if ci != nil {
if ci.notBefore {
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
values = append(values, nb)
}
if ci.notAfter {
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
values = append(values, na)
}
if ci.sans {
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
values = append(values, sans)
}
}
value := strings.Join(values, ",")
headerValues = append(headerValues, value)
}
return strings.Join(headerValues, ";")
}
// modifyRequestHeaders set the wanted headers with the certificates information.
func (p *passTLSClientCert) modifyRequestHeaders(logger logrus.FieldLogger, r *http.Request) {
if p.pem {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(logger, r.TLS.PeerCertificates))
} else {
logger.Warn("Try to extract certificate on a request without TLS")
}
}
if p.infos != nil {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
headerContent := p.getXForwardedTLSClientCertInfos(r.TLS.PeerCertificates)
r.Header.Set(xForwardedTLSClientCertInfos, url.QueryEscape(headerContent))
} else {
logger.Warn("Try to extract certificate on a request without TLS")
}
}
}
// sanitize As we pass the raw certificates, remove the useless data and make it http request compliant.
func sanitize(cert []byte) string {
s := string(cert)
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
"-----END CERTIFICATE-----", "",
"\n", "")
cleaned := r.Replace(s)
return url.QueryEscape(cleaned)
}
// extractCertificate extract the certificate from the request.
func extractCertificate(logger logrus.FieldLogger, cert *x509.Certificate) string {
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
certPEM := pem.EncodeToMemory(&b)
if certPEM == nil {
logger.Error("Cannot extract the certificate content")
return ""
}
return sanitize(certPEM)
}
// getXForwardedTLSClientCert Build a string with the client certificates.
func getXForwardedTLSClientCert(logger logrus.FieldLogger, certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
headerValues = append(headerValues, extractCertificate(logger, peerCert))
}
return strings.Join(headerValues, ",")
}
// getSANs get the Subject Alternate Name values.
func getSANs(cert *x509.Certificate) []string {
var sans []string
if cert == nil {
return sans
}
sans = append(cert.DNSNames, cert.EmailAddresses...)
var ips []string
for _, ip := range cert.IPAddresses {
ips = append(ips, ip.String())
}
sans = append(sans, ips...)
var uris []string
for _, uri := range cert.URIs {
uris = append(uris, uri.String())
}
return append(sans, uris...)
}

View file

@ -1,6 +1,7 @@
package middlewares
package passtlsclientcert
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
@ -12,8 +13,8 @@ import (
"strings"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/types"
"github.com/stretchr/testify/require"
)
@ -69,8 +70,8 @@ jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=
Validity
Not Before: Jul 18 08:00:16 2018 GMT
Not After : Jul 18 08:00:16 2019 GMT
Subject: C=FR, ST=SomeState, L=Toulouse, O=Cheese, CN=*.cheese.org
Subject Public Key Info:
subject: C=FR, ST=SomeState, L=Toulouse, O=Cheese, CN=*.cheese.org
subject Public Key Info:
Public Key Algorithm: rsaEncryption
Public-Key: (2048 bit)
Modulus:
@ -96,9 +97,9 @@ jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=
X509v3 extensions:
X509v3 Basic Constraints:
CA:FALSE
X509v3 Subject Alternative Name:
X509v3 subject Alternative Name:
DNS:*.cheese.org, DNS:*.cheese.net, DNS:cheese.in, IP Address:10.0.1.0, IP Address:10.0.1.2, email:test@cheese.org, email:test@cheese.net
X509v3 Subject Key Identifier:
X509v3 subject Key Identifier:
AB:6B:89:25:11:FC:5E:7B:D4:B0:F7:D4:B6:D9:EB:D0:30:93:E5:58
Signature Algorithm: sha1WithRSAEncryption
ad:87:84:a0:88:a3:4c:d9:0a:c0:14:e4:2d:9a:1d:bb:57:b7:
@ -147,7 +148,7 @@ func getCleanCertContents(certContents []string) string {
var cleanedCertContent []string
for _, certContent := range certContents {
cert := re.FindString(string(certContent))
cert := re.FindString(certContent)
cleanedCertContent = append(cleanedCertContent, sanitize([]byte(cert)))
}
@ -183,8 +184,11 @@ func buildTLSWith(certContents []string) *tls.ConnectionState {
return &tls.ConnectionState{PeerCertificates: peerCertificates}
}
var myPassTLSClientCustomHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("bar"))
var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("bar"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})
func getExpectedSanitized(s string) string {
@ -234,12 +238,12 @@ jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=`),
}
func TestTlsClientheadersWithPEM(t *testing.T) {
func TestTLSClientHeadersWithPEM(t *testing.T) {
testCases := []struct {
desc string
certContents []string // set the request TLS attribute if defined
tlsClientCertHeaders *types.TLSClientHeaders
expectedHeader string
desc string
certContents []string // set the request TLS attribute if defined
config config.PassTLSClientCert
expectedHeader string
}{
{
desc: "No TLS, no option",
@ -249,31 +253,32 @@ func TestTlsClientheadersWithPEM(t *testing.T) {
certContents: []string{minimalCert},
},
{
desc: "No TLS, with pem option true",
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
desc: "No TLS, with pem option true",
config: config.PassTLSClientCert{PEM: true},
},
{
desc: "TLS with simple certificate, with pem option true",
certContents: []string{minimalCert},
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert}),
desc: "TLS with simple certificate, with pem option true",
certContents: []string{minimalCert},
config: config.PassTLSClientCert{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert}),
},
{
desc: "TLS with complete certificate, with pem option true",
certContents: []string{completeCert},
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
expectedHeader: getCleanCertContents([]string{completeCert}),
desc: "TLS with complete certificate, with pem option true",
certContents: []string{completeCert},
config: config.PassTLSClientCert{PEM: true},
expectedHeader: getCleanCertContents([]string{completeCert}),
},
{
desc: "TLS with two certificate, with pem option true",
certContents: []string{minimalCert, completeCert},
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert, completeCert}),
desc: "TLS with two certificate, with pem option true",
certContents: []string{minimalCert, completeCert},
config: config.PassTLSClientCert{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert, completeCert}),
},
}
for _, test := range testCases {
tlsClientHeaders := NewTLSClientHeaders(&types.Frontend{PassTLSClientCert: test.tlsClientCertHeaders})
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
require.NoError(t, err)
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
@ -282,7 +287,7 @@ func TestTlsClientheadersWithPEM(t *testing.T) {
req.TLS = buildTLSWith(test.certContents)
}
tlsClientHeaders.ServeHTTP(res, req, myPassTLSClientCustomHandler)
tlsClientHeaders.ServeHTTP(res, req)
test := test
t.Run(test.desc, func(t *testing.T) {
@ -334,8 +339,8 @@ func TestGetSans(t *testing.T) {
for _, test := range testCases {
sans := getSANs(test.cert)
test := test
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
@ -351,15 +356,15 @@ func TestGetSans(t *testing.T) {
}
func TestTlsClientheadersWithCertInfos(t *testing.T) {
func TestTLSClientHeadersWithCertInfos(t *testing.T) {
minimalCertAllInfos := `Subject="C=FR,ST=Some-State,O=Internet Widgits Pty Ltd",NB=1531902496,NA=1534494496,SAN=`
completeCertAllInfos := `Subject="C=FR,ST=SomeState,L=Toulouse,O=Cheese,CN=*.cheese.org",NB=1531900816,NA=1563436816,SAN=*.cheese.org,*.cheese.net,cheese.in,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2`
testCases := []struct {
desc string
certContents []string // set the request TLS attribute if defined
tlsClientCertHeaders *types.TLSClientHeaders
expectedHeader string
desc string
certContents []string // set the request TLS attribute if defined
config config.PassTLSClientCert
expectedHeader string
}{
{
desc: "No TLS, no option",
@ -370,9 +375,9 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
},
{
desc: "No TLS, with pem option true",
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
config: config.PassTLSClientCert{
Infos: &config.TLSClientCertificateInfos{
Subject: &config.TLSCLientCertificateSubjectInfos{
CommonName: true,
Organization: true,
Locality: true,
@ -385,21 +390,21 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
},
{
desc: "No TLS, with pem option true with no flag",
tlsClientCertHeaders: &types.TLSClientHeaders{
config: config.PassTLSClientCert{
PEM: false,
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{},
Infos: &config.TLSClientCertificateInfos{
Subject: &config.TLSCLientCertificateSubjectInfos{},
},
},
},
{
desc: "TLS with simple certificate, with all infos",
certContents: []string{minimalCert},
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
config: config.PassTLSClientCert{
Infos: &config.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
Subject: &config.TLSCLientCertificateSubjectInfos{
CommonName: true,
Organization: true,
Locality: true,
@ -415,10 +420,10 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
{
desc: "TLS with simple certificate, with some infos",
certContents: []string{minimalCert},
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
config: config.PassTLSClientCert{
Infos: &config.TLSClientCertificateInfos{
NotAfter: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
Subject: &config.TLSCLientCertificateSubjectInfos{
Organization: true,
},
Sans: true,
@ -429,11 +434,11 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
{
desc: "TLS with complete certificate, with all infos",
certContents: []string{completeCert},
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
config: config.PassTLSClientCert{
Infos: &config.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
Subject: &config.TLSCLientCertificateSubjectInfos{
CommonName: true,
Organization: true,
Locality: true,
@ -449,11 +454,11 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
{
desc: "TLS with 2 certificates, with all infos",
certContents: []string{minimalCert, completeCert},
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
config: config.PassTLSClientCert{
Infos: &config.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
Subject: &config.TLSCLientCertificateSubjectInfos{
CommonName: true,
Organization: true,
Locality: true,
@ -469,7 +474,8 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
}
for _, test := range testCases {
tlsClientHeaders := NewTLSClientHeaders(&types.Frontend{PassTLSClientCert: test.tlsClientCertHeaders})
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
require.NoError(t, err)
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
@ -478,7 +484,7 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
req.TLS = buildTLSWith(test.certContents)
}
tlsClientHeaders.ServeHTTP(res, req, myPassTLSClientCustomHandler)
tlsClientHeaders.ServeHTTP(res, req)
test := test
t.Run(test.desc, func(t *testing.T) {
@ -497,303 +503,3 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
}
}
func TestNewTLSClientHeadersFromStruct(t *testing.T) {
testCases := []struct {
desc string
frontend *types.Frontend
expected *TLSClientHeaders
}{
{
desc: "Without frontend",
},
{
desc: "frontend without the option",
frontend: &types.Frontend{},
expected: &TLSClientHeaders{},
},
{
desc: "frontend with the pem set false",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
PEM: false,
},
},
expected: &TLSClientHeaders{PEM: false},
},
{
desc: "frontend with the pem set true",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
PEM: true,
},
},
expected: &TLSClientHeaders{PEM: true},
},
{
desc: "frontend with the Infos with no flag",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: false,
NotBefore: false,
Sans: false,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{},
},
},
{
desc: "frontend with the Infos basic",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Sans: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
NotBefore: true,
NotAfter: true,
Sans: true,
},
},
},
{
desc: "frontend with the Infos NotAfter",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
NotAfter: true,
},
},
},
{
desc: "frontend with the Infos NotBefore",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotBefore: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
NotBefore: true,
},
},
},
{
desc: "frontend with the Infos Sans",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Sans: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Sans: true,
},
},
},
{
desc: "frontend with the Infos Subject Organization",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
Organization: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
Organization: true,
},
},
},
},
{
desc: "frontend with the Infos Subject Country",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
Country: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
Country: true,
},
},
},
},
{
desc: "frontend with the Infos Subject SerialNumber",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
SerialNumber: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
SerialNumber: true,
},
},
},
},
{
desc: "frontend with the Infos Subject Province",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
Province: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
Province: true,
},
},
},
},
{
desc: "frontend with the Infos Subject Locality",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
Locality: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
Locality: true,
},
},
},
},
{
desc: "frontend with the Infos Subject CommonName",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
CommonName: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
CommonName: true,
},
},
},
},
{
desc: "frontend with the Infos NotBefore",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Sans: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Sans: true,
},
},
},
{
desc: "frontend with the Infos all",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
CommonName: true,
Country: true,
Locality: true,
Organization: true,
Province: true,
SerialNumber: true,
},
Sans: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
NotBefore: true,
NotAfter: true,
Sans: true,
Subject: &TLSCLientCertificateSubjectInfos{
Province: true,
Organization: true,
Locality: true,
Country: true,
CommonName: true,
SerialNumber: true,
},
}},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, test.expected, NewTLSClientHeaders(test.frontend))
})
}
}

View file

@ -1,62 +0,0 @@
package pipelining
import (
"bufio"
"net"
"net/http"
)
// Pipelining returns a middleware
type Pipelining struct {
next http.Handler
}
// NewPipelining returns a new Pipelining instance
func NewPipelining(next http.Handler) *Pipelining {
return &Pipelining{
next: next,
}
}
func (p *Pipelining) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// https://github.com/golang/go/blob/3d59583836630cf13ec4bfbed977d27b1b7adbdc/src/net/http/server.go#L201-L218
if r.Method == http.MethodPut || r.Method == http.MethodPost {
p.next.ServeHTTP(rw, r)
} else {
p.next.ServeHTTP(&writerWithoutCloseNotify{rw}, r)
}
}
// writerWithoutCloseNotify helps to disable closeNotify
type writerWithoutCloseNotify struct {
W http.ResponseWriter
}
// Header returns the response headers.
func (w *writerWithoutCloseNotify) Header() http.Header {
return w.W.Header()
}
// Write writes the data to the connection as part of an HTTP reply.
func (w *writerWithoutCloseNotify) Write(buf []byte) (int, error) {
return w.W.Write(buf)
}
// WriteHeader sends an HTTP response header with the provided
// status code.
func (w *writerWithoutCloseNotify) WriteHeader(code int) {
w.W.WriteHeader(code)
}
// Flush sends any buffered data to the client.
func (w *writerWithoutCloseNotify) Flush() {
if f, ok := w.W.(http.Flusher); ok {
f.Flush()
}
}
// Hijack hijacks the connection.
func (w *writerWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.W.(http.Hijacker).Hijack()
}

View file

@ -1,69 +0,0 @@
package pipelining
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
type recorderWithCloseNotify struct {
*httptest.ResponseRecorder
}
func (r *recorderWithCloseNotify) CloseNotify() <-chan bool {
panic("implement me")
}
func TestNewPipelining(t *testing.T) {
testCases := []struct {
desc string
HTTPMethod string
implementCloseNotifier bool
}{
{
desc: "should not implement CloseNotifier with GET method",
HTTPMethod: http.MethodGet,
implementCloseNotifier: false,
},
{
desc: "should implement CloseNotifier with PUT method",
HTTPMethod: http.MethodPut,
implementCloseNotifier: true,
},
{
desc: "should implement CloseNotifier with POST method",
HTTPMethod: http.MethodPost,
implementCloseNotifier: true,
},
{
desc: "should not implement CloseNotifier with GET method",
HTTPMethod: http.MethodHead,
implementCloseNotifier: false,
},
{
desc: "should not implement CloseNotifier with PROPFIND method",
HTTPMethod: "PROPFIND",
implementCloseNotifier: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := w.(http.CloseNotifier)
assert.Equal(t, test.implementCloseNotifier, ok)
w.WriteHeader(http.StatusOK)
})
handler := NewPipelining(nextHandler)
req := httptest.NewRequest(test.HTTPMethod, "http://localhost", nil)
handler.ServeHTTP(&recorderWithCloseNotify{httptest.NewRecorder()}, req)
})
}
}

View file

@ -0,0 +1,54 @@
package ratelimiter
import (
"context"
"net/http"
"time"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/ratelimit"
"github.com/vulcand/oxy/utils"
)
const (
typeName = "RateLimiterType"
)
type rateLimiter struct {
handler http.Handler
name string
}
// New creates rate limiter middleware.
func New(ctx context.Context, next http.Handler, config config.RateLimit, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
extractFunc, err := utils.NewExtractor(config.ExtractorFunc)
if err != nil {
return nil, err
}
rateSet := ratelimit.NewRateSet()
for _, rate := range config.RateSet {
if err = rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
return nil, err
}
}
rl, err := ratelimit.New(next, extractFunc, rateSet)
if err != nil {
return nil, err
}
return &rateLimiter{handler: rl, name: name}, nil
}
func (r *rateLimiter) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
r.handler.ServeHTTP(rw, req)
}

View file

@ -1,51 +0,0 @@
package middlewares
import (
"net/http"
"runtime"
"github.com/containous/traefik/log"
"github.com/urfave/negroni"
)
// RecoverHandler recovers from a panic in http handlers
func RecoverHandler(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
defer recoverFunc(w, r)
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
// NegroniRecoverHandler recovers from a panic in negroni handlers
func NegroniRecoverHandler() negroni.Handler {
fn := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
defer recoverFunc(w, r)
next.ServeHTTP(w, r)
}
return negroni.HandlerFunc(fn)
}
func recoverFunc(w http.ResponseWriter, r *http.Request) {
if err := recover(); err != nil {
if !shouldLogPanic(err) {
log.Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err)
return
}
log.Errorf("Recovered from panic in HTTP handler [%s - %s]: %+v", r.RemoteAddr, r.URL, err)
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
log.Errorf("Stack: %s", buf)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}
// https://github.com/golang/go/blob/a0d6420d8be2ae7164797051ec74fa2a2df466a1/src/net/http/server.go#L1761-L1775
// https://github.com/golang/go/blob/c33153f7b416c03983324b3e8f869ce1116d84bc/src/net/http/httputil/reverseproxy.go#L284
func shouldLogPanic(panicValue interface{}) bool {
return panicValue != nil && panicValue != http.ErrAbortHandler
}

View file

@ -1,45 +0,0 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/urfave/negroni"
)
func TestRecoverHandler(t *testing.T) {
fn := func(w http.ResponseWriter, r *http.Request) {
panic("I love panicing!")
}
recoverHandler := RecoverHandler(http.HandlerFunc(fn))
server := httptest.NewServer(recoverHandler)
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
}
}
func TestNegroniRecoverHandler(t *testing.T) {
n := negroni.New()
n.Use(NegroniRecoverHandler())
panicHandler := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
panic("I love panicing!")
}
n.UseFunc(negroni.HandlerFunc(panicHandler))
server := httptest.NewServer(n)
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
}
}

View file

@ -0,0 +1,40 @@
package recovery
import (
"context"
"net/http"
"github.com/containous/traefik/middlewares"
"github.com/sirupsen/logrus"
)
const (
typeName = "Recovery"
)
type recovery struct {
next http.Handler
name string
}
// New creates recovery middleware.
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &recovery{
next: next,
name: name,
}, nil
}
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
defer recoverFunc(middlewares.GetLogger(req.Context(), re.name, typeName), rw)
re.next.ServeHTTP(rw, req)
}
func recoverFunc(logger logrus.FieldLogger, rw http.ResponseWriter) {
if err := recover(); err != nil {
logger.Errorf("Recovered from panic in http handler: %+v", err)
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}

View file

@ -0,0 +1,27 @@
package recovery
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRecoverHandler(t *testing.T) {
fn := func(w http.ResponseWriter, r *http.Request) {
panic("I love panicing!")
}
recovery, err := New(context.Background(), http.HandlerFunc(fn), "foo-recovery")
require.NoError(t, err)
server := httptest.NewServer(recovery)
defer server.Close()
resp, err := http.Get(server.URL)
require.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
}

View file

@ -2,110 +2,87 @@ package redirect
import (
"bytes"
"fmt"
"context"
"html/template"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"text/template"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/urfave/negroni"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/utils"
)
const (
defaultRedirectRegex = `^(?:https?:\/\/)?([\w\._-]+)(?::\d+)?(.*)$`
typeName = "Redirect"
)
// NewEntryPointHandler create a new redirection handler base on entry point
func NewEntryPointHandler(dstEntryPoint *configuration.EntryPoint, permanent bool) (negroni.Handler, error) {
exp := regexp.MustCompile(`(:\d+)`)
match := exp.FindStringSubmatch(dstEntryPoint.Address)
if len(match) == 0 {
return nil, fmt.Errorf("bad Address format %q", dstEntryPoint.Address)
}
protocol := "http"
if dstEntryPoint.TLS != nil {
protocol = "https"
}
replacement := protocol + "://${1}" + match[0] + "${2}"
return NewRegexHandler(defaultRedirectRegex, replacement, permanent)
type redirect struct {
next http.Handler
regex *regexp.Regexp
replacement string
permanent bool
errHandler utils.ErrorHandler
name string
}
// NewRegexHandler create a new redirection handler base on regex
func NewRegexHandler(exp string, replacement string, permanent bool) (negroni.Handler, error) {
re, err := regexp.Compile(exp)
// New creates a redirect middleware.
func New(ctx context.Context, next http.Handler, config config.Redirect, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
logger.Debugf("Setting up redirect %s -> %s", config.Regex, config.Replacement)
re, err := regexp.Compile(config.Regex)
if err != nil {
return nil, err
}
return &handler{
regexp: re,
replacement: replacement,
permanent: permanent,
return &redirect{
regex: re,
replacement: config.Replacement,
permanent: config.Permanent,
errHandler: utils.DefaultHandler,
next: next,
name: name,
}, nil
}
type handler struct {
regexp *regexp.Regexp
replacement string
permanent bool
errHandler utils.ErrorHandler
func (r *redirect) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (h *handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
func (r *redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
oldURL := rawURL(req)
// only continue if the Regexp param matches the URL
if !h.regexp.MatchString(oldURL) {
next.ServeHTTP(rw, req)
// If the Regexp doesn't match, skip to the next handler
if !r.regex.MatchString(oldURL) {
r.next.ServeHTTP(rw, req)
return
}
// apply a rewrite regexp to the URL
newURL := h.regexp.ReplaceAllString(oldURL, h.replacement)
newURL := r.regex.ReplaceAllString(oldURL, r.replacement)
// replace any variables that may be in there
rewrittenURL := &bytes.Buffer{}
if err := applyString(newURL, rewrittenURL, req); err != nil {
h.errHandler.ServeHTTP(rw, req, err)
r.errHandler.ServeHTTP(rw, req, err)
return
}
// parse the rewritten URL and replace request URL with it
parsedURL, err := url.Parse(rewrittenURL.String())
if err != nil {
h.errHandler.ServeHTTP(rw, req, err)
r.errHandler.ServeHTTP(rw, req, err)
return
}
if stripPrefix, stripPrefixOk := req.Context().Value(middlewares.StripPrefixKey).(string); stripPrefixOk {
if len(stripPrefix) > 0 {
parsedURL.Path = stripPrefix
}
}
if addPrefix, addPrefixOk := req.Context().Value(middlewares.AddPrefixKey).(string); addPrefixOk {
if len(addPrefix) > 0 {
parsedURL.Path = strings.Replace(parsedURL.Path, addPrefix, "", 1)
}
}
if replacePath, replacePathOk := req.Context().Value(middlewares.ReplacePathKey).(string); replacePathOk {
if len(replacePath) > 0 {
parsedURL.Path = replacePath
}
}
if newURL != oldURL {
handler := &moveHandler{location: parsedURL, permanent: h.permanent}
handler := &moveHandler{location: parsedURL, permanent: r.permanent}
handler.ServeHTTP(rw, req)
return
}
@ -114,7 +91,7 @@ func (h *handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http
// make sure the request URI corresponds the rewritten URL
req.RequestURI = req.URL.RequestURI()
next.ServeHTTP(rw, req)
r.next.ServeHTTP(rw, req)
}
type moveHandler struct {
@ -124,21 +101,25 @@ type moveHandler struct {
func (m *moveHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Location", m.location.String())
status := http.StatusFound
if m.permanent {
status = http.StatusMovedPermanently
}
rw.WriteHeader(status)
rw.Write([]byte(http.StatusText(status)))
_, err := rw.Write([]byte(http.StatusText(status)))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func rawURL(request *http.Request) string {
func rawURL(req *http.Request) string {
scheme := "http"
if request.TLS != nil || isXForwardedHTTPS(request) {
if req.TLS != nil || isXForwardedHTTPS(req) {
scheme = "https"
}
return strings.Join([]string{scheme, "://", request.Host, request.RequestURI}, "")
return strings.Join([]string{scheme, "://", req.Host, req.RequestURI}, "")
}
func isXForwardedHTTPS(request *http.Request) bool {
@ -147,17 +128,13 @@ func isXForwardedHTTPS(request *http.Request) bool {
return len(xForwardedProto) > 0 && xForwardedProto == "https"
}
func applyString(in string, out io.Writer, request *http.Request) error {
func applyString(in string, out io.Writer, req *http.Request) error {
t, err := template.New("t").Parse(in)
if err != nil {
return err
}
data := struct {
Request *http.Request
}{
Request: request,
}
data := struct{ Request *http.Request }{Request: req}
return t.Execute(out, data)
}

View file

@ -1,146 +1,129 @@
package redirect
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/configuration"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewEntryPointHandler(t *testing.T) {
testCases := []struct {
desc string
entryPoint *configuration.EntryPoint
permanent bool
url string
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "HTTP to HTTPS",
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS to HTTP",
entryPoint: &configuration.EntryPoint{Address: ":80"},
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP",
entryPoint: &configuration.EntryPoint{Address: ":88"},
url: "http://foo:80",
expectedURL: "http://foo:88",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTPS permanent",
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
permanent: true,
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTPS to HTTP permanent",
entryPoint: &configuration.EntryPoint{Address: ":80"},
permanent: true,
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "invalid address",
entryPoint: &configuration.EntryPoint{Address: ":foo", TLS: &tls.TLS{}},
url: "http://foo:80",
errorExpected: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := NewEntryPointHandler(test.entryPoint, test.permanent)
if test.errorExpected {
require.Error(t, err)
} else {
require.NoError(t, err)
recorder := httptest.NewRecorder()
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
handler.ServeHTTP(recorder, r, nil)
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
assert.Equal(t, test.expectedStatus, recorder.Code)
}
})
}
}
func TestNewRegexHandler(t *testing.T) {
testCases := []struct {
desc string
regex string
replacement string
permanent bool
config config.Redirect
url string
expectedURL string
expectedStatus int
errorExpected bool
secured bool
}{
{
desc: "simple redirection",
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
replacement: "https://${1}bar$2:443$4",
desc: "simple redirection",
config: config.Redirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: "https://${1}bar$2:443$4",
},
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "use request header",
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
desc: "use request header",
config: config.Redirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
},
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "URL doesn't match regex",
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
replacement: "https://${1}bar$2:443$4",
desc: "URL doesn't match regex",
config: config.Redirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: "https://${1}bar$2:443$4",
},
url: "http://bar.com:80",
expectedStatus: http.StatusOK,
},
{
desc: "invalid rewritten URL",
regex: `^(.*)$`,
replacement: "http://192.168.0.%31/",
desc: "invalid rewritten URL",
config: config.Redirect{
Regex: `^(.*)$`,
Replacement: "http://192.168.0.%31/",
},
url: "http://foo.com:80",
expectedStatus: http.StatusBadGateway,
},
{
desc: "invalid regex",
regex: `^(.*`,
replacement: "$1",
desc: "invalid regex",
config: config.Redirect{
Regex: `^(.*`,
Replacement: "$1",
},
url: "http://foo.com:80",
errorExpected: true,
},
{
desc: "HTTP to HTTPS permanent",
config: config.Redirect{
Regex: `^http://`,
Replacement: "https://$1",
Permanent: true,
},
url: "http://foo",
expectedURL: "https://foo",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTPS to HTTP permanent",
config: config.Redirect{
Regex: `https://foo`,
Replacement: "http://foo",
Permanent: true,
},
secured: true,
url: "https://foo",
expectedURL: "http://foo",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTP to HTTPS",
config: config.Redirect{
Regex: `http://foo:80`,
Replacement: "https://foo:443",
},
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS to HTTP",
config: config.Redirect{
Regex: `https://foo:443`,
Replacement: "http://foo:80",
},
secured: true,
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP",
config: config.Redirect{
Regex: `http://foo:80`,
Replacement: "http://foo:88",
},
url: "http://foo:80",
expectedURL: "http://foo:88",
expectedStatus: http.StatusFound,
},
}
for _, test := range testCases {
@ -148,20 +131,23 @@ func TestNewRegexHandler(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := NewRegexHandler(test.regex, test.replacement, test.permanent)
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
handler, err := New(context.Background(), next, test.config, "traefikTest")
if test.errorExpected {
require.Nil(t, handler)
require.Error(t, err)
require.Nil(t, handler)
} else {
require.NotNil(t, handler)
require.NoError(t, err)
require.NotNil(t, handler)
recorder := httptest.NewRecorder()
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
if test.secured {
r.TLS = &tls.ConnectionState{}
}
r.Header.Set("X-Foo", "bar")
next := func(rw http.ResponseWriter, req *http.Request) {}
handler.ServeHTTP(recorder, r, next)
handler.ServeHTTP(recorder, r)
if test.expectedStatus == http.StatusMovedPermanently || test.expectedStatus == http.StatusFound {
assert.Equal(t, test.expectedStatus, recorder.Code)

View file

@ -1,28 +0,0 @@
package middlewares
import (
"context"
"net/http"
)
const (
// ReplacePathKey is the key within the request context used to
// store the replaced path
ReplacePathKey key = "ReplacePath"
// ReplacedPathHeader is the default header to set the old path to
ReplacedPathHeader = "X-Replaced-Path"
)
// ReplacePath is a middleware used to replace the path of a URL request
type ReplacePath struct {
Handler http.Handler
Path string
}
func (s *ReplacePath) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), ReplacePathKey, r.URL.Path))
r.Header.Add(ReplacedPathHeader, r.URL.Path)
r.URL.Path = s.Path
r.RequestURI = r.URL.RequestURI()
s.Handler.ServeHTTP(w, r)
}

View file

@ -1,40 +0,0 @@
package middlewares
import (
"context"
"net/http"
"regexp"
"strings"
"github.com/containous/traefik/log"
)
// ReplacePathRegex is a middleware used to replace the path of a URL request with a regular expression
type ReplacePathRegex struct {
Handler http.Handler
Regexp *regexp.Regexp
Replacement string
}
// NewReplacePathRegexHandler returns a new ReplacePathRegex
func NewReplacePathRegexHandler(regex string, replacement string, handler http.Handler) http.Handler {
exp, err := regexp.Compile(strings.TrimSpace(regex))
if err != nil {
log.Errorf("Error compiling regular expression %s: %s", regex, err)
}
return &ReplacePathRegex{
Regexp: exp,
Replacement: strings.TrimSpace(replacement),
Handler: handler,
}
}
func (s *ReplacePathRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.Regexp != nil && len(s.Replacement) > 0 && s.Regexp.MatchString(r.URL.Path) {
r = r.WithContext(context.WithValue(r.Context(), ReplacePathKey, r.URL.Path))
r.Header.Add(ReplacedPathHeader, r.URL.Path)
r.URL.Path = s.Regexp.ReplaceAllString(r.URL.Path, s.Replacement)
r.RequestURI = r.URL.RequestURI()
}
s.Handler.ServeHTTP(w, r)
}

View file

@ -1,80 +0,0 @@
package middlewares
import (
"net/http"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestReplacePathRegex(t *testing.T) {
testCases := []struct {
desc string
path string
replacement string
regex string
expectedPath string
expectedHeader string
}{
{
desc: "simple regex",
path: "/whoami/and/whoami",
replacement: "/who-am-i/$1",
regex: `^/whoami/(.*)`,
expectedPath: "/who-am-i/and/whoami",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "simple replace (no regex)",
path: "/whoami/and/whoami",
replacement: "/who-am-i",
regex: `/whoami`,
expectedPath: "/who-am-i/and/who-am-i",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "multiple replacement",
path: "/downloads/src/source.go",
replacement: "/downloads/$1-$2",
regex: `^(?i)/downloads/([^/]+)/([^/]+)$`,
expectedPath: "/downloads/src-source.go",
expectedHeader: "/downloads/src/source.go",
},
{
desc: "invalid regular expression",
path: "/invalid/regexp/test",
replacement: "/valid/regexp/$1",
regex: `^(?err)/invalid/regexp/([^/]+)$`,
expectedPath: "/invalid/regexp/test",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualHeader, requestURI string
handler := NewReplacePathRegexHandler(
test.regex,
test.replacement,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualHeader = r.Header.Get(ReplacedPathHeader)
requestURI = r.RequestURI
}),
)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
if test.expectedHeader != "" {
assert.Equal(t, actualPath, requestURI, "Unexpected request URI.")
}
})
}
}

View file

@ -0,0 +1,46 @@
package replacepath
import (
"context"
"net/http"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
// ReplacedPathHeader is the default header to set the old path to.
ReplacedPathHeader = "X-Replaced-Path"
typeName = "ReplacePath"
)
// ReplacePath is a middleware used to replace the path of a URL request.
type replacePath struct {
next http.Handler
path string
name string
}
// New creates a new replace path middleware.
func New(ctx context.Context, next http.Handler, config config.ReplacePath, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &replacePath{
next: next,
path: config.Path,
name: name,
}, nil
}
func (r *replacePath) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *replacePath) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
req.Header.Add(ReplacedPathHeader, req.URL.Path)
req.URL.Path = r.path
req.RequestURI = req.URL.RequestURI()
r.next.ServeHTTP(rw, req)
}

View file

@ -1,15 +1,20 @@
package middlewares
package replacepath
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReplacePath(t *testing.T) {
const replacementPath = "/replacement-path"
var replacementConfig = config.ReplacePath{
Path: "/replacement-path",
}
paths := []string{
"/example",
@ -20,20 +25,20 @@ func TestReplacePath(t *testing.T) {
t.Run(path, func(t *testing.T) {
var expectedPath, actualHeader, requestURI string
handler := &ReplacePath{
Path: replacementPath,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath = r.URL.Path
actualHeader = r.Header.Get(ReplacedPathHeader)
requestURI = r.RequestURI
}),
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath = r.URL.Path
actualHeader = r.Header.Get(ReplacedPathHeader)
requestURI = r.RequestURI
})
handler, err := New(context.Background(), next, replacementConfig, "foo-replace-path")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+path, nil)
handler.ServeHTTP(nil, req)
assert.Equal(t, expectedPath, replacementPath, "Unexpected path.")
assert.Equal(t, expectedPath, replacementConfig.Path, "Unexpected path.")
assert.Equal(t, path, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
assert.Equal(t, expectedPath, requestURI, "Unexpected request URI.")
})

View file

@ -0,0 +1,57 @@
package replacepathregex
import (
"context"
"fmt"
"net/http"
"regexp"
"strings"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/replacepath"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
typeName = "ReplacePathRegex"
)
// ReplacePathRegex is a middleware used to replace the path of a URL request with a regular expression.
type replacePathRegex struct {
next http.Handler
regexp *regexp.Regexp
replacement string
name string
}
// New creates a new replace path regex middleware.
func New(ctx context.Context, next http.Handler, config config.ReplacePathRegex, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
exp, err := regexp.Compile(strings.TrimSpace(config.Regex))
if err != nil {
return nil, fmt.Errorf("error compiling regular expression %s: %s", config.Regex, err)
}
return &replacePathRegex{
regexp: exp,
replacement: strings.TrimSpace(config.Replacement),
next: next,
name: name,
}, nil
}
func (rp *replacePathRegex) GetTracingInformation() (string, ext.SpanKindEnum) {
return rp.name, tracing.SpanKindNoneEnum
}
func (rp *replacePathRegex) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if rp.regexp != nil && len(rp.replacement) > 0 && rp.regexp.MatchString(req.URL.Path) {
req.Header.Add(replacepath.ReplacedPathHeader, req.URL.Path)
req.URL.Path = rp.regexp.ReplaceAllString(req.URL.Path, rp.replacement)
req.RequestURI = req.URL.RequestURI()
}
rp.next.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,104 @@
package replacepathregex
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares/replacepath"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReplacePathRegex(t *testing.T) {
testCases := []struct {
desc string
path string
config config.ReplacePathRegex
expectedPath string
expectedHeader string
expectsError bool
}{
{
desc: "simple regex",
path: "/whoami/and/whoami",
config: config.ReplacePathRegex{
Replacement: "/who-am-i/$1",
Regex: `^/whoami/(.*)`,
},
expectedPath: "/who-am-i/and/whoami",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "simple replace (no regex)",
path: "/whoami/and/whoami",
config: config.ReplacePathRegex{
Replacement: "/who-am-i",
Regex: `/whoami`,
},
expectedPath: "/who-am-i/and/who-am-i",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "no match",
path: "/whoami/and/whoami",
config: config.ReplacePathRegex{
Replacement: "/whoami",
Regex: `/no-match`,
},
expectedPath: "/whoami/and/whoami",
},
{
desc: "multiple replacement",
path: "/downloads/src/source.go",
config: config.ReplacePathRegex{
Replacement: "/downloads/$1-$2",
Regex: `^(?i)/downloads/([^/]+)/([^/]+)$`,
},
expectedPath: "/downloads/src-source.go",
expectedHeader: "/downloads/src/source.go",
},
{
desc: "invalid regular expression",
path: "/invalid/regexp/test",
config: config.ReplacePathRegex{
Replacement: "/valid/regexp/$1",
Regex: `^(?err)/invalid/regexp/([^/]+)$`,
},
expectedPath: "/invalid/regexp/test",
expectsError: true,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
var actualPath, actualHeader, requestURI string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualHeader = r.Header.Get(replacepath.ReplacedPathHeader)
requestURI = r.RequestURI
})
handler, err := New(context.Background(), next, test.config, "foo-replace-path-regexp")
if test.expectsError {
require.Error(t, err)
} else {
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
req.RequestURI = test.path
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, actualPath, requestURI, "Unexpected request URI.")
if test.expectedHeader != "" {
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", replacepath.ReplacedPathHeader)
}
}
})
}
}

View file

@ -1,42 +0,0 @@
package middlewares
import (
"context"
"net"
"net/http"
"strings"
"github.com/containous/traefik/types"
)
var requestHostKey struct{}
// RequestHost is the struct for the middleware that adds the CanonicalDomain of the request Host into a context for later use.
type RequestHost struct{}
func (rh *RequestHost) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if next != nil {
host := types.CanonicalDomain(parseHost(r.Host))
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), requestHostKey, host)))
}
}
func parseHost(addr string) string {
if !strings.Contains(addr, ":") {
return addr
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
return host
}
// GetCanonizedHost plucks the canonized host key from the request of a context that was put through the middleware
func GetCanonizedHost(ctx context.Context) string {
if val, ok := ctx.Value(requestHostKey).(string); ok {
return val
}
return ""
}

View file

@ -0,0 +1,124 @@
package requestdecorator
import (
"context"
"fmt"
"net"
"sort"
"strings"
"time"
"github.com/containous/traefik/log"
"github.com/miekg/dns"
"github.com/patrickmn/go-cache"
)
type cnameResolv struct {
TTL time.Duration
Record string
}
type byTTL []*cnameResolv
func (a byTTL) Len() int { return len(a) }
func (a byTTL) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byTTL) Less(i, j int) bool { return a[i].TTL > a[j].TTL }
// Resolver used for host resolver.
type Resolver struct {
CnameFlattening bool
ResolvConfig string
ResolvDepth int
cache *cache.Cache
}
// CNAMEFlatten check if CNAME record exists, flatten if possible.
func (hr *Resolver) CNAMEFlatten(ctx context.Context, host string) string {
if hr.cache == nil {
hr.cache = cache.New(30*time.Minute, 5*time.Minute)
}
result := host
request := host
value, found := hr.cache.Get(host)
if found {
return value.(string)
}
logger := log.FromContext(ctx)
var cacheDuration = 0 * time.Second
for depth := 0; depth < hr.ResolvDepth; depth++ {
resolv, err := cnameResolve(ctx, request, hr.ResolvConfig)
if err != nil {
logger.Error(err)
break
}
if resolv == nil {
break
}
result = resolv.Record
if depth == 0 {
cacheDuration = resolv.TTL
}
request = resolv.Record
}
if err := hr.cache.Add(host, result, cacheDuration); err != nil {
logger.Error(err)
}
return result
}
// cnameResolve resolves CNAME if exists, and return with the highest TTL.
func cnameResolve(ctx context.Context, host string, resolvPath string) (*cnameResolv, error) {
config, err := dns.ClientConfigFromFile(resolvPath)
if err != nil {
return nil, fmt.Errorf("invalid resolver configuration file: %s", resolvPath)
}
client := &dns.Client{Timeout: 30 * time.Second}
m := &dns.Msg{}
m.SetQuestion(dns.Fqdn(host), dns.TypeCNAME)
var result []*cnameResolv
for _, server := range config.Servers {
tempRecord, err := getRecord(client, m, server, config.Port)
if err != nil {
log.FromContext(ctx).Errorf("Failed to resolve host %s: %v", host, err)
continue
}
result = append(result, tempRecord)
}
if len(result) <= 0 {
return nil, nil
}
sort.Sort(byTTL(result))
return result[0], nil
}
func getRecord(client *dns.Client, msg *dns.Msg, server string, port string) (*cnameResolv, error) {
resp, _, err := client.Exchange(msg, net.JoinHostPort(server, port))
if err != nil {
return nil, fmt.Errorf("exchange error for server %s: %v", server, err)
}
if resp == nil || len(resp.Answer) == 0 {
return nil, fmt.Errorf("empty answer for server %s", server)
}
rr, ok := resp.Answer[0].(*dns.CNAME)
if !ok {
return nil, fmt.Errorf("invalid response type for server %s", server)
}
return &cnameResolv{
TTL: time.Duration(rr.Hdr.Ttl) * time.Second,
Record: strings.TrimSuffix(rr.Target, "."),
}, nil
}

View file

@ -0,0 +1,51 @@
package requestdecorator
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCNAMEFlatten(t *testing.T) {
testCases := []struct {
desc string
resolvFile string
domain string
expectedDomain string
}{
{
desc: "host request is CNAME record",
resolvFile: "/etc/resolv.conf",
domain: "www.github.com",
expectedDomain: "github.com",
},
{
desc: "resolve file not found",
resolvFile: "/etc/resolv.oops",
domain: "www.github.com",
expectedDomain: "www.github.com",
},
{
desc: "host request is not CNAME record",
resolvFile: "/etc/resolv.conf",
domain: "github.com",
expectedDomain: "github.com",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
hostResolver := &Resolver{
ResolvConfig: test.resolvFile,
ResolvDepth: 5,
}
flatH := hostResolver.CNAMEFlatten(context.Background(), test.domain)
assert.Equal(t, test.expectedDomain, flatH)
})
}
}

View file

@ -0,0 +1,88 @@
package requestdecorator
import (
"context"
"net"
"net/http"
"strings"
"github.com/containous/alice"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/old/types"
)
const (
canonicalKey key = "canonical"
flattenKey key = "flatten"
)
type key string
// RequestDecorator is the struct for the middleware that adds the CanonicalDomain of the request Host into a context for later use.
type RequestDecorator struct {
hostResolver *Resolver
}
// New creates a new request host middleware.
func New(hostResolverConfig *static.HostResolverConfig) *RequestDecorator {
requestDecorator := &RequestDecorator{}
if hostResolverConfig != nil {
requestDecorator.hostResolver = &Resolver{
CnameFlattening: hostResolverConfig.CnameFlattening,
ResolvConfig: hostResolverConfig.ResolvConfig,
ResolvDepth: hostResolverConfig.ResolvDepth,
}
}
return requestDecorator
}
func (r *RequestDecorator) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
host := types.CanonicalDomain(parseHost(req.Host))
reqt := req.WithContext(context.WithValue(req.Context(), canonicalKey, host))
if r.hostResolver != nil && r.hostResolver.CnameFlattening {
flatHost := r.hostResolver.CNAMEFlatten(reqt.Context(), host)
reqt = reqt.WithContext(context.WithValue(reqt.Context(), flattenKey, flatHost))
}
next(rw, reqt)
}
func parseHost(addr string) string {
if !strings.Contains(addr, ":") {
return addr
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
return host
}
// GetCanonizedHost retrieves the canonized host from the given context (previously stored in the request context by the middleware).
func GetCanonizedHost(ctx context.Context) string {
if val, ok := ctx.Value(canonicalKey).(string); ok {
return val
}
return ""
}
// GetCNAMEFlatten return the flat name if it is present in the context.
func GetCNAMEFlatten(ctx context.Context) string {
if val, ok := ctx.Value(flattenKey).(string); ok {
return val
}
return ""
}
// WrapHandler Wraps a ServeHTTP with next to an alice.Constructor.
func WrapHandler(handler *RequestDecorator) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(rw, req, next.ServeHTTP)
}), nil
}
}

View file

@ -1,9 +1,10 @@
package middlewares
package requestdecorator
import (
"net/http"
"testing"
"github.com/containous/traefik/config/static"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
@ -36,19 +37,69 @@ func TestRequestHost(t *testing.T) {
},
}
rh := &RequestHost{}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
host := GetCanonizedHost(r.Context())
assert.Equal(t, test.expected, host)
})
rh := New(nil)
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
rh.ServeHTTP(nil, req, next)
})
}
}
func TestRequestFlattening(t *testing.T) {
testCases := []struct {
desc string
url string
expected string
}{
{
desc: "host with flattening",
url: "http://www.github.com",
expected: "github.com",
},
{
desc: "host without flattening",
url: "http://github.com",
expected: "github.com",
},
{
desc: "ip without flattening",
url: "http://127.0.0.1",
expected: "127.0.0.1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
rh.ServeHTTP(nil, req, func(_ http.ResponseWriter, r *http.Request) {
host := GetCanonizedHost(r.Context())
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
host := GetCNAMEFlatten(r.Context())
assert.Equal(t, test.expected, host)
})
rh := New(
&static.HostResolverConfig{
CnameFlattening: true,
ResolvConfig: "/etc/resolv.conf",
ResolvDepth: 5,
},
)
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
rh.ServeHTTP(nil, req, next)
})
}
}

View file

@ -1,173 +0,0 @@
package middlewares
import (
"bufio"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"
"github.com/containous/traefik/log"
)
// Compile time validation that the response writer implements http interfaces correctly.
var _ Stateful = &retryResponseWriterWithCloseNotify{}
// Retry is a middleware that retries requests
type Retry struct {
attempts int
next http.Handler
listener RetryListener
}
// NewRetry returns a new Retry instance
func NewRetry(attempts int, next http.Handler, listener RetryListener) *Retry {
return &Retry{
attempts: attempts,
next: next,
listener: listener,
}
}
func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// if we might make multiple attempts, swap the body for an ioutil.NopCloser
// cf https://github.com/containous/traefik/issues/1008
if retry.attempts > 1 {
body := r.Body
if body == nil {
body = http.NoBody
}
defer body.Close()
r.Body = ioutil.NopCloser(body)
}
attempts := 1
for {
attemptsExhausted := attempts >= retry.attempts
shouldRetry := !attemptsExhausted
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)
// Disable retries when the backend already received request data
trace := &httptrace.ClientTrace{
WroteHeaders: func() {
retryResponseWriter.DisableRetries()
},
WroteRequest: func(httptrace.WroteRequestInfo) {
retryResponseWriter.DisableRetries()
},
}
newCtx := httptrace.WithClientTrace(r.Context(), trace)
retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx))
if !retryResponseWriter.ShouldRetry() {
break
}
attempts++
log.Debugf("New attempt %d for request: %v", attempts, r.URL)
retry.listener.Retried(r, attempts)
}
}
// RetryListener is used to inform about retry attempts.
type RetryListener interface {
// Retried will be called when a retry happens, with the request attempt passed to it.
// For the first retry this will be attempt 2.
Retried(req *http.Request, attempt int)
}
// RetryListeners is a convenience type to construct a list of RetryListener and notify
// each of them about a retry attempt.
type RetryListeners []RetryListener
// Retried exists to implement the RetryListener interface. It calls Retried on each of its slice entries.
func (l RetryListeners) Retried(req *http.Request, attempt int) {
for _, retryListener := range l {
retryListener.Retried(req, attempt)
}
}
type retryResponseWriter interface {
http.ResponseWriter
http.Flusher
ShouldRetry() bool
DisableRetries()
}
func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
responseWriter := &retryResponseWriterWithoutCloseNotify{
responseWriter: rw,
shouldRetry: shouldRetry,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &retryResponseWriterWithCloseNotify{responseWriter}
}
return responseWriter
}
type retryResponseWriterWithoutCloseNotify struct {
responseWriter http.ResponseWriter
shouldRetry bool
}
func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool {
return rr.shouldRetry
}
func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() {
rr.shouldRetry = false
}
func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
if rr.ShouldRetry() {
return make(http.Header)
}
return rr.responseWriter.Header()
}
func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
if rr.ShouldRetry() {
return len(buf), nil
}
return rr.responseWriter.Write(buf)
}
func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
if rr.ShouldRetry() && code == http.StatusServiceUnavailable {
// We get a 503 HTTP Status Code when there is no backend server in the pool
// to which the request could be sent. Also, note that rr.ShouldRetry()
// will never return true in case there was a connection established to
// the backend server and so we can be sure that the 503 was produced
// inside Traefik already and we don't have to retry in this cases.
rr.DisableRetries()
}
if rr.ShouldRetry() {
return
}
rr.responseWriter.WriteHeader(code)
}
func (rr *retryResponseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rr.responseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", rr.responseWriter)
}
return hijacker.Hijack()
}
func (rr *retryResponseWriterWithoutCloseNotify) Flush() {
if flusher, ok := rr.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
type retryResponseWriterWithCloseNotify struct {
*retryResponseWriterWithoutCloseNotify
}
func (rr *retryResponseWriterWithCloseNotify) CloseNotify() <-chan bool {
return rr.responseWriter.(http.CloseNotifier).CloseNotify()
}

194
middlewares/retry/retry.go Normal file
View file

@ -0,0 +1,194 @@
package retry
import (
"bufio"
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
)
// Compile time validation that the response writer implements http interfaces correctly.
var _ middlewares.Stateful = &responseWriterWithCloseNotify{}
const (
typeName = "Retry"
)
// Listener is used to inform about retry attempts.
type Listener interface {
// Retried will be called when a retry happens, with the request attempt passed to it.
// For the first retry this will be attempt 2.
Retried(req *http.Request, attempt int)
}
// Listeners is a convenience type to construct a list of Listener and notify
// each of them about a retry attempt.
type Listeners []Listener
// retry is a middleware that retries requests.
type retry struct {
attempts int
next http.Handler
listener Listener
name string
}
// New returns a new retry middleware.
func New(ctx context.Context, next http.Handler, config config.Retry, listener Listener, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
if config.Attempts <= 0 {
return nil, fmt.Errorf("incorrect (or empty) value for attempt (%d)", config.Attempts)
}
return &retry{
attempts: config.Attempts,
next: next,
listener: listener,
name: name,
}, nil
}
func (r *retry) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// if we might make multiple attempts, swap the body for an ioutil.NopCloser
// cf https://github.com/containous/traefik/issues/1008
if r.attempts > 1 {
body := req.Body
defer body.Close()
req.Body = ioutil.NopCloser(body)
}
attempts := 1
for {
attemptsExhausted := attempts >= r.attempts
shouldRetry := !attemptsExhausted
retryResponseWriter := newResponseWriter(rw, shouldRetry)
// Disable retries when the backend already received request data
trace := &httptrace.ClientTrace{
WroteHeaders: func() {
retryResponseWriter.DisableRetries()
},
WroteRequest: func(httptrace.WroteRequestInfo) {
retryResponseWriter.DisableRetries()
},
}
newCtx := httptrace.WithClientTrace(req.Context(), trace)
r.next.ServeHTTP(retryResponseWriter, req.WithContext(newCtx))
if !retryResponseWriter.ShouldRetry() {
break
}
attempts++
logger := middlewares.GetLogger(req.Context(), r.name, typeName)
logger.Debugf("New attempt %d for request: %v", attempts, req.URL)
r.listener.Retried(req, attempts)
}
}
// Retried exists to implement the Listener interface. It calls Retried on each of its slice entries.
func (l Listeners) Retried(req *http.Request, attempt int) {
for _, listener := range l {
listener.Retried(req, attempt)
}
}
type responseWriter interface {
http.ResponseWriter
http.Flusher
ShouldRetry() bool
DisableRetries()
}
func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter {
responseWriter := &responseWriterWithoutCloseNotify{
responseWriter: rw,
shouldRetry: shouldRetry,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &responseWriterWithCloseNotify{
responseWriterWithoutCloseNotify: responseWriter,
}
}
return responseWriter
}
type responseWriterWithoutCloseNotify struct {
responseWriter http.ResponseWriter
shouldRetry bool
}
func (r *responseWriterWithoutCloseNotify) ShouldRetry() bool {
return r.shouldRetry
}
func (r *responseWriterWithoutCloseNotify) DisableRetries() {
r.shouldRetry = false
}
func (r *responseWriterWithoutCloseNotify) Header() http.Header {
if r.ShouldRetry() {
return make(http.Header)
}
return r.responseWriter.Header()
}
func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
if r.ShouldRetry() {
return len(buf), nil
}
return r.responseWriter.Write(buf)
}
func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) {
if r.ShouldRetry() && code == http.StatusServiceUnavailable {
// We get a 503 HTTP Status Code when there is no backend server in the pool
// to which the request could be sent. Also, note that r.ShouldRetry()
// will never return true in case there was a connection established to
// the backend server and so we can be sure that the 503 was produced
// inside Traefik already and we don't have to retry in this cases.
r.DisableRetries()
}
if r.ShouldRetry() {
return
}
r.responseWriter.WriteHeader(code)
}
func (r *responseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := r.responseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", r.responseWriter)
}
return hijacker.Hijack()
}
func (r *responseWriterWithoutCloseNotify) Flush() {
if flusher, ok := r.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
type responseWriterWithCloseNotify struct {
*responseWriterWithoutCloseNotify
}
func (r *responseWriterWithCloseNotify) CloseNotify() <-chan bool {
return r.responseWriter.(http.CloseNotifier).CloseNotify()
}

View file

@ -1,11 +1,14 @@
package middlewares
package retry
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares/emptybackendhandler"
"github.com/containous/traefik/testhelpers"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
@ -17,42 +20,42 @@ import (
func TestRetry(t *testing.T) {
testCases := []struct {
desc string
maxRequestAttempts int
config config.Retry
wantRetryAttempts int
wantResponseStatus int
amountFaultyEndpoints int
}{
{
desc: "no retry on success",
maxRequestAttempts: 1,
config: config.Retry{Attempts: 1},
wantRetryAttempts: 0,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 0,
},
{
desc: "no retry when max request attempts is one",
maxRequestAttempts: 1,
config: config.Retry{Attempts: 1},
wantRetryAttempts: 0,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 1,
},
{
desc: "one retry when one server is faulty",
maxRequestAttempts: 2,
config: config.Retry{Attempts: 2},
wantRetryAttempts: 1,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 1,
},
{
desc: "two retries when two servers are faulty",
maxRequestAttempts: 3,
config: config.Retry{Attempts: 3},
wantRetryAttempts: 2,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 2,
},
{
desc: "max attempts exhausted delivers the 5xx response",
maxRequestAttempts: 3,
config: config.Retry{Attempts: 3},
wantRetryAttempts: 2,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 3,
@ -61,24 +64,22 @@ func TestRetry(t *testing.T) {
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("OK"))
_, err := rw.Write([]byte("OK"))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}))
forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %s", err)
}
require.NoError(t, err)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
loadBalancer, err := roundrobin.New(forwarder)
if err != nil {
t.Fatalf("Error creating load balancer: %s", err)
}
require.NoError(t, err)
basePort := 33444
for i := 0; i < test.amountFaultyEndpoints; i++ {
@ -87,15 +88,16 @@ func TestRetry(t *testing.T) {
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
err = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
assert.NoError(t, err)
require.NoError(t, err)
}
// add the functioning server to the end of the load balancer list
err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
assert.NoError(t, err)
require.NoError(t, err)
retryListener := &countingRetryListener{}
retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener)
retry, err := New(context.Background(), loadBalancer, test.config, retryListener, "traefikTest")
require.NoError(t, err)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
@ -108,6 +110,78 @@ func TestRetry(t *testing.T) {
}
}
func TestRetryEmptyServerList(t *testing.T) {
forwarder, err := forward.New()
require.NoError(t, err)
loadBalancer, err := roundrobin.New(forwarder)
require.NoError(t, err)
// The EmptyBackend middleware ensures that there is a 503
// response status set when there is no backend server in the pool.
next := emptybackendhandler.New(loadBalancer)
retryListener := &countingRetryListener{}
retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, retryListener, "traefikTest")
require.NoError(t, err)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
retry.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
assert.Equal(t, 0, retryListener.timesCalled)
}
func TestRetryListeners(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
retryListeners := Listeners{&countingRetryListener{}, &countingRetryListener{}}
retryListeners.Retried(req, 1)
retryListeners.Retried(req, 1)
for _, retryListener := range retryListeners {
listener := retryListener.(*countingRetryListener)
if listener.timesCalled != 2 {
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
}
}
}
// countingRetryListener is a Listener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
timesCalled int
}
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
l.timesCalled++
}
func TestRetryWithFlush(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(200)
_, err := rw.Write([]byte("FULL "))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
rw.(http.Flusher).Flush()
_, err = rw.Write([]byte("DATA"))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
retry, err := New(context.Background(), next, config.Retry{Attempts: 1}, &countingRetryListener{}, "traefikTest")
require.NoError(t, err)
responseRecorder := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, &http.Request{})
assert.Equal(t, "FULL DATA", responseRecorder.Body.String())
}
func TestRetryWebsocket(t *testing.T) {
testCases := []struct {
desc string
@ -141,7 +215,10 @@ func TestRetryWebsocket(t *testing.T) {
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
upgrader := websocket.Upgrader{}
upgrader.Upgrade(rw, req, nil)
_, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}))
for _, test := range testCases {
@ -160,16 +237,17 @@ func TestRetryWebsocket(t *testing.T) {
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
_ = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
}
// add the functioning server to the end of the load balancer list
loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
retryListener := &countingRetryListener{}
retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener)
retryH, err := New(context.Background(), loadBalancer, config.Retry{Attempts: test.maxRequestAttempts}, retryListener, "traefikTest")
require.NoError(t, err)
retryServer := httptest.NewServer(retry)
retryServer := httptest.NewServer(retryH)
url := strings.Replace(retryServer.URL, "http", "ws", 1)
_, response, err := websocket.DefaultDialer.Dial(url, nil)
@ -183,78 +261,3 @@ func TestRetryWebsocket(t *testing.T) {
})
}
}
func TestRetryEmptyServerList(t *testing.T) {
forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %s", err)
}
loadBalancer, err := roundrobin.New(forwarder)
if err != nil {
t.Fatalf("Error creating load balancer: %s", err)
}
// The EmptyBackendHandler middleware ensures that there is a 503
// response status set when there is no backend server in the pool.
next := NewEmptyBackendHandler(loadBalancer)
retryListener := &countingRetryListener{}
retry := NewRetry(3, next, retryListener)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
retry.ServeHTTP(recorder, req)
const wantResponseStatus = http.StatusServiceUnavailable
if wantResponseStatus != recorder.Code {
t.Errorf("got status code %d, want %d", recorder.Code, wantResponseStatus)
}
const wantRetryAttempts = 0
if wantRetryAttempts != retryListener.timesCalled {
t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, wantRetryAttempts)
}
}
func TestRetryListeners(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
retryListeners := RetryListeners{&countingRetryListener{}, &countingRetryListener{}}
retryListeners.Retried(req, 1)
retryListeners.Retried(req, 1)
for _, retryListener := range retryListeners {
listener := retryListener.(*countingRetryListener)
if listener.timesCalled != 2 {
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
}
}
}
// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
timesCalled int
}
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
l.timesCalled++
}
func TestRetryWithFlush(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(200)
rw.Write([]byte("FULL "))
rw.(http.Flusher).Flush()
rw.Write([]byte("DATA"))
})
retry := NewRetry(1, next, &countingRetryListener{})
responseRecorder := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, &http.Request{})
if responseRecorder.Body.String() != "FULL DATA" {
t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA")
}
}

View file

@ -1,28 +0,0 @@
package middlewares
import (
"encoding/json"
"log"
"net/http"
"github.com/containous/mux"
)
// Routes holds the gorilla mux routes (for the API & co).
type Routes struct {
router *mux.Router
}
// NewRoutes return a Routes based on the given router.
func NewRoutes(router *mux.Router) *Routes {
return &Routes{router}
}
func (router *Routes) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
routeMatch := mux.RouteMatch{}
if router.router.Match(r, &routeMatch) {
rt, _ := json.Marshal(routeMatch.Handler)
log.Println("Request match route ", rt)
}
next(rw, r)
}

View file

@ -1,36 +0,0 @@
package middlewares
import (
"github.com/containous/traefik/types"
"github.com/unrolled/secure"
)
// NewSecure constructs a new Secure instance with supplied options.
func NewSecure(headers *types.Headers) *secure.Secure {
if headers == nil || !headers.HasSecureHeadersDefined() {
return nil
}
opt := secure.Options{
AllowedHosts: headers.AllowedHosts,
HostsProxyHeaders: headers.HostsProxyHeaders,
SSLRedirect: headers.SSLRedirect,
SSLTemporaryRedirect: headers.SSLTemporaryRedirect,
SSLHost: headers.SSLHost,
SSLProxyHeaders: headers.SSLProxyHeaders,
STSSeconds: headers.STSSeconds,
STSIncludeSubdomains: headers.STSIncludeSubdomains,
STSPreload: headers.STSPreload,
ForceSTSHeader: headers.ForceSTSHeader,
FrameDeny: headers.FrameDeny,
CustomFrameOptionsValue: headers.CustomFrameOptionsValue,
ContentTypeNosniff: headers.ContentTypeNosniff,
BrowserXssFilter: headers.BrowserXSSFilter,
CustomBrowserXssValue: headers.CustomBrowserXSSValue,
ContentSecurityPolicy: headers.ContentSecurityPolicy,
PublicKey: headers.PublicKey,
ReferrerPolicy: headers.ReferrerPolicy,
IsDevelopment: headers.IsDevelopment,
}
return secure.New(opt)
}

View file

@ -1,115 +0,0 @@
package middlewares
import (
"bufio"
"net"
"net/http"
"sync"
"time"
)
var (
_ Stateful = &responseRecorder{}
)
// StatsRecorder is an optional middleware that records more details statistics
// about requests and how they are processed. This currently consists of recent
// requests that have caused errors (4xx and 5xx status codes), making it easy
// to pinpoint problems.
type StatsRecorder struct {
mutex sync.RWMutex
numRecentErrors int
recentErrors []*statsError
}
// NewStatsRecorder returns a new StatsRecorder
func NewStatsRecorder(numRecentErrors int) *StatsRecorder {
return &StatsRecorder{
numRecentErrors: numRecentErrors,
}
}
// Stats includes all of the stats gathered by the recorder.
type Stats struct {
RecentErrors []*statsError `json:"recent_errors"`
}
// statsError represents an error that has occurred during request processing.
type statsError struct {
StatusCode int `json:"status_code"`
Status string `json:"status"`
Method string `json:"method"`
Host string `json:"host"`
Path string `json:"path"`
Time time.Time `json:"time"`
}
// responseRecorder captures information from the response and preserves it for
// later analysis.
type responseRecorder struct {
http.ResponseWriter
statusCode int
}
// WriteHeader captures the status code for later retrieval.
func (r *responseRecorder) WriteHeader(status int) {
r.ResponseWriter.WriteHeader(status)
r.statusCode = status
}
// Hijack hijacks the connection
func (r *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.ResponseWriter.(http.Hijacker).Hijack()
}
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone
// away.
func (r *responseRecorder) CloseNotify() <-chan bool {
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// Flush sends any buffered data to the client.
func (r *responseRecorder) Flush() {
r.ResponseWriter.(http.Flusher).Flush()
}
// ServeHTTP silently extracts information from the request and response as it
// is processed. If the response is 4xx or 5xx, add it to the list of 10 most
// recent errors.
func (s *StatsRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
recorder := &responseRecorder{w, http.StatusOK}
next(recorder, r)
if recorder.statusCode >= http.StatusBadRequest {
s.mutex.Lock()
defer s.mutex.Unlock()
s.recentErrors = append([]*statsError{
{
StatusCode: recorder.statusCode,
Status: http.StatusText(recorder.statusCode),
Method: r.Method,
Host: r.Host,
Path: r.URL.Path,
Time: time.Now(),
},
}, s.recentErrors...)
// Limit the size of the list to numRecentErrors
if len(s.recentErrors) > s.numRecentErrors {
s.recentErrors = s.recentErrors[:s.numRecentErrors]
}
}
}
// Data returns a copy of the statistics that have been gathered.
func (s *StatsRecorder) Data() *Stats {
s.mutex.RLock()
defer s.mutex.RUnlock()
// We can't return the slice directly or a race condition might develop
recentErrors := make([]*statsError, len(s.recentErrors))
copy(recentErrors, s.recentErrors)
return &Stats{
RecentErrors: recentErrors,
}
}

View file

@ -1,56 +0,0 @@
package middlewares
import (
"context"
"net/http"
"strings"
)
const (
// StripPrefixKey is the key within the request context used to
// store the stripped prefix
StripPrefixKey key = "StripPrefix"
// ForwardedPrefixHeader is the default header to set prefix
ForwardedPrefixHeader = "X-Forwarded-Prefix"
)
// StripPrefix is a middleware used to strip prefix from an URL request
type StripPrefix struct {
Handler http.Handler
Prefixes []string
}
func (s *StripPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for _, prefix := range s.Prefixes {
if strings.HasPrefix(r.URL.Path, prefix) {
rawReqPath := r.URL.Path
r.URL.Path = stripPrefix(r.URL.Path, prefix)
if r.URL.RawPath != "" {
r.URL.RawPath = stripPrefix(r.URL.RawPath, prefix)
}
s.serveRequest(w, r, strings.TrimSpace(prefix), rawReqPath)
return
}
}
http.NotFound(w, r)
}
func (s *StripPrefix) serveRequest(w http.ResponseWriter, r *http.Request, prefix string, rawReqPath string) {
r = r.WithContext(context.WithValue(r.Context(), StripPrefixKey, rawReqPath))
r.Header.Add(ForwardedPrefixHeader, prefix)
r.RequestURI = r.URL.RequestURI()
s.Handler.ServeHTTP(w, r)
}
// SetHandler sets handler
func (s *StripPrefix) SetHandler(Handler http.Handler) {
s.Handler = Handler
}
func stripPrefix(s, prefix string) string {
return ensureLeadingSlash(strings.TrimPrefix(s, prefix))
}
func ensureLeadingSlash(str string) string {
return "/" + strings.TrimPrefix(str, "/")
}

View file

@ -1,59 +0,0 @@
package middlewares
import (
"context"
"net/http"
"github.com/containous/mux"
"github.com/containous/traefik/log"
)
// StripPrefixRegex is a middleware used to strip prefix from an URL request
type StripPrefixRegex struct {
Handler http.Handler
router *mux.Router
}
// NewStripPrefixRegex builds a new StripPrefixRegex given a handler and prefixes
func NewStripPrefixRegex(handler http.Handler, prefixes []string) *StripPrefixRegex {
stripPrefix := StripPrefixRegex{Handler: handler, router: mux.NewRouter()}
for _, prefix := range prefixes {
stripPrefix.router.PathPrefix(prefix)
}
return &stripPrefix
}
func (s *StripPrefixRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var match mux.RouteMatch
if s.router.Match(r, &match) {
params := make([]string, 0, len(match.Vars)*2)
for key, val := range match.Vars {
params = append(params, key)
params = append(params, val)
}
prefix, err := match.Route.URL(params...)
if err != nil || len(prefix.Path) > len(r.URL.Path) {
log.Error("Error in stripPrefix middleware", err)
return
}
rawReqPath := r.URL.Path
r.URL.Path = r.URL.Path[len(prefix.Path):]
if r.URL.RawPath != "" {
r.URL.RawPath = r.URL.RawPath[len(prefix.Path):]
}
r = r.WithContext(context.WithValue(r.Context(), StripPrefixKey, rawReqPath))
r.Header.Add(ForwardedPrefixHeader, prefix.Path)
r.RequestURI = ensureLeadingSlash(r.URL.RequestURI())
s.Handler.ServeHTTP(w, r)
return
}
http.NotFound(w, r)
}
// SetHandler sets handler
func (s *StripPrefixRegex) SetHandler(Handler http.Handler) {
s.Handler = Handler
}

View file

@ -0,0 +1,67 @@
package stripprefix
import (
"context"
"net/http"
"strings"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
// ForwardedPrefixHeader is the default header to set prefix.
ForwardedPrefixHeader = "X-Forwarded-Prefix"
typeName = "StripPrefix"
)
// stripPrefix is a middleware used to strip prefix from an URL request.
type stripPrefix struct {
next http.Handler
prefixes []string
name string
}
// New creates a new strip prefix middleware.
func New(ctx context.Context, next http.Handler, config config.StripPrefix, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &stripPrefix{
prefixes: config.Prefixes,
next: next,
name: name,
}, nil
}
func (s *stripPrefix) GetTracingInformation() (string, ext.SpanKindEnum) {
return s.name, tracing.SpanKindNoneEnum
}
func (s *stripPrefix) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
for _, prefix := range s.prefixes {
if strings.HasPrefix(req.URL.Path, prefix) {
req.URL.Path = getPrefixStripped(req.URL.Path, prefix)
if req.URL.RawPath != "" {
req.URL.RawPath = getPrefixStripped(req.URL.RawPath, prefix)
}
s.serveRequest(rw, req, strings.TrimSpace(prefix))
return
}
}
http.NotFound(rw, req)
}
func (s *stripPrefix) serveRequest(rw http.ResponseWriter, req *http.Request, prefix string) {
req.Header.Add(ForwardedPrefixHeader, prefix)
req.RequestURI = req.URL.RequestURI()
s.next.ServeHTTP(rw, req)
}
func getPrefixStripped(s, prefix string) string {
return ensureLeadingSlash(strings.TrimPrefix(s, prefix))
}
func ensureLeadingSlash(str string) string {
return "/" + strings.TrimPrefix(str, "/")
}

View file

@ -1,18 +1,21 @@
package middlewares
package stripprefix
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStripPrefix(t *testing.T) {
tests := []struct {
testCases := []struct {
desc string
prefixes []string
config config.StripPrefix
path string
expectedStatusCode int
expectedPath string
@ -20,84 +23,107 @@ func TestStripPrefix(t *testing.T) {
expectedHeader string
}{
{
desc: "no prefixes configured",
prefixes: []string{},
desc: "no prefixes configured",
config: config.StripPrefix{
Prefixes: []string{},
},
path: "/noprefixes",
expectedStatusCode: http.StatusNotFound,
},
{
desc: "wildcard (.*) requests",
prefixes: []string{"/"},
desc: "wildcard (.*) requests",
config: config.StripPrefix{
Prefixes: []string{"/"},
},
path: "/",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/",
},
{
desc: "prefix and path matching",
prefixes: []string{"/stat"},
desc: "prefix and path matching",
config: config.StripPrefix{
Prefixes: []string{"/stat"},
},
path: "/stat",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat",
},
{
desc: "path prefix on exactly matching path",
prefixes: []string{"/stat/"},
desc: "path prefix on exactly matching path",
config: config.StripPrefix{
Prefixes: []string{"/stat/"},
},
path: "/stat/",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat/",
},
{
desc: "path prefix on matching longer path",
prefixes: []string{"/stat/"},
desc: "path prefix on matching longer path",
config: config.StripPrefix{
Prefixes: []string{"/stat/"},
},
path: "/stat/us",
expectedStatusCode: http.StatusOK,
expectedPath: "/us",
expectedHeader: "/stat/",
},
{
desc: "path prefix on mismatching path",
prefixes: []string{"/stat/"},
desc: "path prefix on mismatching path",
config: config.StripPrefix{
Prefixes: []string{"/stat/"},
},
path: "/status",
expectedStatusCode: http.StatusNotFound,
},
{
desc: "general prefix on matching path",
prefixes: []string{"/stat"},
desc: "general prefix on matching path",
config: config.StripPrefix{
Prefixes: []string{"/stat"},
},
path: "/stat/",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat",
},
{
desc: "earlier prefix matching",
prefixes: []string{"/stat", "/stat/us"},
desc: "earlier prefix matching",
config: config.StripPrefix{
Prefixes: []string{"/stat", "/stat/us"},
},
path: "/stat/us",
expectedStatusCode: http.StatusOK,
expectedPath: "/us",
expectedHeader: "/stat",
},
{
desc: "later prefix matching",
prefixes: []string{"/mismatch", "/stat"},
desc: "later prefix matching",
config: config.StripPrefix{
Prefixes: []string{"/mismatch", "/stat"},
},
path: "/stat",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat",
},
{
desc: "prefix matching within slash boundaries",
prefixes: []string{"/stat"},
desc: "prefix matching within slash boundaries",
config: config.StripPrefix{
Prefixes: []string{"/stat"},
},
path: "/status",
expectedStatusCode: http.StatusOK,
expectedPath: "/us",
expectedHeader: "/stat",
},
{
desc: "raw path is also stripped",
prefixes: []string{"/stat"},
desc: "raw path is also stripped",
config: config.StripPrefix{
Prefixes: []string{"/stat"},
},
path: "/stat/a%2Fb",
expectedStatusCode: http.StatusOK,
expectedPath: "/a/b",
@ -106,21 +132,21 @@ func TestStripPrefix(t *testing.T) {
},
}
for _, test := range tests {
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, actualHeader, requestURI string
handler := &StripPrefix{
Prefixes: test.prefixes,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
actualHeader = r.Header.Get(ForwardedPrefixHeader)
requestURI = r.RequestURI
}),
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
actualHeader = r.Header.Get(ForwardedPrefixHeader)
requestURI = r.RequestURI
})
handler, err := New(context.Background(), next, test.config, "foo-strip-prefix")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
resp := &httptest.ResponseRecorder{Code: http.StatusOK}

View file

@ -0,0 +1,79 @@
package stripprefixregex
import (
"context"
"net/http"
"strings"
"github.com/containous/mux"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/middlewares/stripprefix"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
typeName = "StripPrefixRegex"
)
// StripPrefixRegex is a middleware used to strip prefix from an URL request.
type stripPrefixRegex struct {
next http.Handler
router *mux.Router
name string
}
// New builds a new StripPrefixRegex middleware.
func New(ctx context.Context, next http.Handler, config config.StripPrefixRegex, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
stripPrefix := stripPrefixRegex{
next: next,
router: mux.NewRouter(),
name: name,
}
for _, prefix := range config.Regex {
stripPrefix.router.PathPrefix(prefix)
}
return &stripPrefix, nil
}
func (s *stripPrefixRegex) GetTracingInformation() (string, ext.SpanKindEnum) {
return s.name, tracing.SpanKindNoneEnum
}
func (s *stripPrefixRegex) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
var match mux.RouteMatch
if s.router.Match(req, &match) {
params := make([]string, 0, len(match.Vars)*2)
for key, val := range match.Vars {
params = append(params, key)
params = append(params, val)
}
prefix, err := match.Route.URL(params...)
if err != nil || len(prefix.Path) > len(req.URL.Path) {
logger := middlewares.GetLogger(req.Context(), s.name, typeName)
logger.Error("Error in stripPrefix middleware", err)
return
}
req.URL.Path = req.URL.Path[len(prefix.Path):]
if req.URL.RawPath != "" {
req.URL.RawPath = req.URL.RawPath[len(prefix.Path):]
}
req.Header.Add(stripprefix.ForwardedPrefixHeader, prefix.Path)
req.RequestURI = ensureLeadingSlash(req.URL.RequestURI())
s.next.ServeHTTP(rw, req)
return
}
http.NotFound(rw, req)
}
func ensureLeadingSlash(str string) string {
return "/" + strings.TrimPrefix(str, "/")
}

View file

@ -1,18 +1,24 @@
package middlewares
package stripprefixregex
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/config"
"github.com/containous/traefik/middlewares/stripprefix"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestStripPrefixRegex(t *testing.T) {
testPrefixRegex := []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"}
testPrefixRegex := config.StripPrefixRegex{
Regex: []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"},
}
tests := []struct {
testCases := []struct {
path string
expectedStatusCode int
expectedPath string
@ -70,7 +76,7 @@ func TestStripPrefixRegex(t *testing.T) {
},
}
for _, test := range tests {
for _, test := range testCases {
test := test
t.Run(test.path, func(t *testing.T) {
t.Parallel()
@ -79,9 +85,10 @@ func TestStripPrefixRegex(t *testing.T) {
handlerPath := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
actualHeader = r.Header.Get(ForwardedPrefixHeader)
actualHeader = r.Header.Get(stripprefix.ForwardedPrefixHeader)
})
handler := NewStripPrefixRegex(handlerPath, testPrefixRegex)
handler, err := New(context.Background(), handlerPath, testPrefixRegex, "foo-strip-prefix-regex")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
resp := &httptest.ResponseRecorder{Code: http.StatusOK}
@ -91,7 +98,7 @@ func TestStripPrefixRegex(t *testing.T) {
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader)
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", stripprefix.ForwardedPrefixHeader)
})
}
}

View file

@ -1,251 +0,0 @@
package middlewares
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/containous/traefik/log"
"github.com/containous/traefik/types"
)
const xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
const xForwardedTLSClientCertInfos = "X-Forwarded-Tls-Client-Cert-Infos"
// TLSClientCertificateInfos is a struct for specifying the configuration for the tlsClientHeaders middleware.
type TLSClientCertificateInfos struct {
NotAfter bool
NotBefore bool
Subject *TLSCLientCertificateSubjectInfos
Sans bool
}
// TLSCLientCertificateSubjectInfos contains the configuration for the certificate subject infos.
type TLSCLientCertificateSubjectInfos struct {
Country bool
Province bool
Locality bool
Organization bool
CommonName bool
SerialNumber bool
}
// TLSClientHeaders is a middleware that helps setup a few tls infos features.
type TLSClientHeaders struct {
PEM bool // pass the sanitized pem to the backend in a specific header
Infos *TLSClientCertificateInfos // pass selected informations from the client certificate
}
func newTLSCLientCertificateSubjectInfos(infos *types.TLSCLientCertificateSubjectInfos) *TLSCLientCertificateSubjectInfos {
if infos == nil {
return nil
}
return &TLSCLientCertificateSubjectInfos{
SerialNumber: infos.SerialNumber,
CommonName: infos.CommonName,
Country: infos.Country,
Locality: infos.Locality,
Organization: infos.Organization,
Province: infos.Province,
}
}
func newTLSClientInfos(infos *types.TLSClientCertificateInfos) *TLSClientCertificateInfos {
if infos == nil {
return nil
}
return &TLSClientCertificateInfos{
NotBefore: infos.NotBefore,
NotAfter: infos.NotAfter,
Sans: infos.Sans,
Subject: newTLSCLientCertificateSubjectInfos(infos.Subject),
}
}
// NewTLSClientHeaders constructs a new TLSClientHeaders instance from supplied frontend header struct.
func NewTLSClientHeaders(frontend *types.Frontend) *TLSClientHeaders {
if frontend == nil {
return nil
}
var pem bool
var infos *TLSClientCertificateInfos
if frontend.PassTLSClientCert != nil {
conf := frontend.PassTLSClientCert
pem = conf.PEM
infos = newTLSClientInfos(conf.Infos)
}
return &TLSClientHeaders{
PEM: pem,
Infos: infos,
}
}
func (s *TLSClientHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
s.ModifyRequestHeaders(r)
// If there is a next, call it.
if next != nil {
next(w, r)
}
}
// sanitize As we pass the raw certificates, remove the useless data and make it http request compliant
func sanitize(cert []byte) string {
s := string(cert)
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
"-----END CERTIFICATE-----", "",
"\n", "")
cleaned := r.Replace(s)
return url.QueryEscape(cleaned)
}
// extractCertificate extract the certificate from the request
func extractCertificate(cert *x509.Certificate) string {
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
certPEM := pem.EncodeToMemory(&b)
if certPEM == nil {
log.Error("Cannot extract the certificate content")
return ""
}
return sanitize(certPEM)
}
// getXForwardedTLSClientCert Build a string with the client certificates
func getXForwardedTLSClientCert(certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
headerValues = append(headerValues, extractCertificate(peerCert))
}
return strings.Join(headerValues, ",")
}
// getSANs get the Subject Alternate Name values
func getSANs(cert *x509.Certificate) []string {
var sans []string
if cert == nil {
return sans
}
sans = append(cert.DNSNames, cert.EmailAddresses...)
var ips []string
for _, ip := range cert.IPAddresses {
ips = append(ips, ip.String())
}
sans = append(sans, ips...)
var uris []string
for _, uri := range cert.URIs {
uris = append(uris, uri.String())
}
return append(sans, uris...)
}
// getSubjectInfos extract the requested informations from the certificate subject
func (s *TLSClientHeaders) getSubjectInfos(cs *pkix.Name) string {
var subject string
if s.Infos != nil && s.Infos.Subject != nil {
options := s.Infos.Subject
var content []string
if options.Country && len(cs.Country) > 0 {
content = append(content, fmt.Sprintf("C=%s", cs.Country[0]))
}
if options.Province && len(cs.Province) > 0 {
content = append(content, fmt.Sprintf("ST=%s", cs.Province[0]))
}
if options.Locality && len(cs.Locality) > 0 {
content = append(content, fmt.Sprintf("L=%s", cs.Locality[0]))
}
if options.Organization && len(cs.Organization) > 0 {
content = append(content, fmt.Sprintf("O=%s", cs.Organization[0]))
}
if options.CommonName && len(cs.CommonName) > 0 {
content = append(content, fmt.Sprintf("CN=%s", cs.CommonName))
}
if len(content) > 0 {
subject = `Subject="` + strings.Join(content, ",") + `"`
}
}
return subject
}
// getXForwardedTLSClientCertInfos Build a string with the wanted client certificates informations
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
func (s *TLSClientHeaders) getXForwardedTLSClientCertInfos(certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
var values []string
var sans string
var nb string
var na string
subject := s.getSubjectInfos(&peerCert.Subject)
if len(subject) > 0 {
values = append(values, subject)
}
ci := s.Infos
if ci != nil {
if ci.NotBefore {
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
values = append(values, nb)
}
if ci.NotAfter {
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
values = append(values, na)
}
if ci.Sans {
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
values = append(values, sans)
}
}
value := strings.Join(values, ",")
headerValues = append(headerValues, value)
}
return strings.Join(headerValues, ";")
}
// ModifyRequestHeaders set the wanted headers with the certificates informations
func (s *TLSClientHeaders) ModifyRequestHeaders(r *http.Request) {
if s.PEM {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(r.TLS.PeerCertificates))
} else {
log.Warn("Try to extract certificate on a request without TLS")
}
}
if s.Infos != nil {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
headerContent := s.getXForwardedTLSClientCertInfos(r.TLS.PeerCertificates)
r.Header.Set(xForwardedTLSClientCertInfos, url.QueryEscape(headerContent))
} else {
log.Warn("Try to extract certificate on a request without TLS")
}
}
}

View file

@ -1,25 +0,0 @@
package tracing
import "net/http"
// HTTPHeadersCarrier custom implementation to fix duplicated headers
// It has been fixed in https://github.com/opentracing/opentracing-go/pull/191
type HTTPHeadersCarrier http.Header
// Set conforms to the TextMapWriter interface.
func (c HTTPHeadersCarrier) Set(key, val string) {
h := http.Header(c)
h.Set(key, val)
}
// ForeachKey conforms to the TextMapReader interface.
func (c HTTPHeadersCarrier) ForeachKey(handler func(key, val string) error) error {
for k, vals := range c {
for _, v := range vals {
if err := handler(k, v); err != nil {
return err
}
}
}
return nil
}

View file

@ -1,45 +0,0 @@
package datadog
import (
"io"
"strings"
"github.com/containous/traefik/log"
"github.com/opentracing/opentracing-go"
ddtracer "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer"
datadog "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// Name sets the name of this tracer
const Name = "datadog"
// Config provides configuration settings for a datadog tracer
type Config struct {
LocalAgentHostPort string `description:"Set datadog-agent's host:port that the reporter will used. Defaults to localhost:8126" export:"false"`
GlobalTag string `description:"Key:Value tag to be set on all the spans." export:"true"`
Debug bool `description:"Enable DataDog debug." export:"true"`
}
// Setup sets up the tracer
func (c *Config) Setup(serviceName string) (opentracing.Tracer, io.Closer, error) {
tag := strings.SplitN(c.GlobalTag, ":", 2)
value := ""
if len(tag) == 2 {
value = tag[1]
}
tracer := ddtracer.New(
datadog.WithAgentAddr(c.LocalAgentHostPort),
datadog.WithServiceName(serviceName),
datadog.WithGlobalTag(tag[0], value),
datadog.WithDebugMode(c.Debug),
)
// Without this, child spans are getting the NOOP tracer
opentracing.SetGlobalTracer(tracer)
log.Debug("DataDog tracer configured")
return tracer, nil, nil
}

View file

@ -1,57 +1,57 @@
package tracing
import (
"fmt"
"context"
"net/http"
"github.com/containous/traefik/log"
"github.com/containous/alice"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/urfave/negroni"
)
type entryPointMiddleware struct {
entryPoint string
*Tracing
}
const (
entryPointTypeName = "TracingEntryPoint"
)
// NewEntryPoint creates a new middleware that the incoming request
func (t *Tracing) NewEntryPoint(name string) negroni.Handler {
log.Debug("Added entrypoint tracing middleware")
return &entryPointMiddleware{Tracing: t, entryPoint: name}
}
// NewEntryPoint creates a new middleware that the incoming request.
func NewEntryPoint(ctx context.Context, t *tracing.Tracing, entryPointName string, next http.Handler) http.Handler {
middlewares.GetLogger(ctx, "tracing", entryPointTypeName).Debug("Creating middleware")
func (e *entryPointMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
opNameFunc := generateEntryPointSpanName
ctx, _ := e.Extract(opentracing.HTTPHeaders, HTTPHeadersCarrier(r.Header))
span := e.StartSpan(opNameFunc(r, e.entryPoint, e.SpanNameLimit), ext.RPCServerOption(ctx))
ext.Component.Set(span, e.ServiceName)
LogRequest(span, r)
ext.SpanKindRPCServer.Set(span)
r = r.WithContext(opentracing.ContextWithSpan(r.Context(), span))
recorder := newStatusCodeRecoder(w, 200)
next(recorder, r)
LogResponseCode(span, recorder.Status())
span.Finish()
}
// generateEntryPointSpanName will return a Span name of an appropriate lenth based on the 'spanLimit' argument. If needed, it will be truncated, but will not be less than 24 characters.
func generateEntryPointSpanName(r *http.Request, entryPoint string, spanLimit int) string {
name := fmt.Sprintf("Entrypoint %s %s", entryPoint, r.Host)
if spanLimit > 0 && len(name) > spanLimit {
if spanLimit < EntryPointMaxLengthNumber {
log.Warnf("SpanNameLimit is set to be less than required static number of characters, defaulting to %d + 3", EntryPointMaxLengthNumber)
spanLimit = EntryPointMaxLengthNumber + 3
}
hash := computeHash(name)
limit := (spanLimit - EntryPointMaxLengthNumber) / 2
name = fmt.Sprintf("Entrypoint %s %s %s", truncateString(entryPoint, limit), truncateString(r.Host, limit), hash)
return &entryPointMiddleware{
entryPoint: entryPointName,
Tracing: t,
next: next,
}
}
type entryPointMiddleware struct {
*tracing.Tracing
entryPoint string
next http.Handler
}
func (e *entryPointMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
spanCtx, _ := e.Extract(opentracing.HTTPHeaders, tracing.HTTPHeadersCarrier(req.Header))
span, req, finish := e.StartSpanf(req, ext.SpanKindRPCServerEnum, "EntryPoint", []string{e.entryPoint, req.Host}, " ", ext.RPCServerOption(spanCtx))
defer finish()
ext.Component.Set(span, e.ServiceName)
tracing.LogRequest(span, req)
req = req.WithContext(tracing.WithTracing(req.Context(), e.Tracing))
recorder := newStatusCodeRecoder(rw, http.StatusOK)
e.next.ServeHTTP(recorder, req)
tracing.LogResponseCode(span, recorder.Status())
}
// WrapEntryPointHandler Wraps tracing to alice.Constructor.
func WrapEntryPointHandler(ctx context.Context, tracer *tracing.Tracing, entryPointName string) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
return NewEntryPoint(ctx, tracer, entryPointName, next), nil
}
return name
}

View file

@ -1,70 +1,87 @@
package tracing
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEntryPointMiddlewareServeHTTP(t *testing.T) {
expectedTags := map[string]interface{}{
"span.kind": ext.SpanKindRPCServerEnum,
"http.method": "GET",
"component": "",
"http.url": "http://www.test.com",
"http.host": "www.test.com",
func TestEntryPointMiddleware(t *testing.T) {
type expected struct {
Tags map[string]interface{}
OperationName string
}
testCases := []struct {
desc string
entryPoint string
tracing *Tracing
expectedTags map[string]interface{}
expectedName string
desc string
entryPoint string
spanNameLimit int
tracing *trackingBackenMock
expected expected
}{
{
desc: "no truncation test",
entryPoint: "test",
tracing: &Tracing{
SpanNameLimit: 0,
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
desc: "no truncation test",
entryPoint: "test",
spanNameLimit: 0,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
expectedTags: expectedTags,
expectedName: "Entrypoint test www.test.com",
}, {
desc: "basic test",
entryPoint: "test",
tracing: &Tracing{
SpanNameLimit: 25,
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
expected: expected{
Tags: map[string]interface{}{
"span.kind": ext.SpanKindRPCServerEnum,
"http.method": http.MethodGet,
"component": "",
"http.url": "http://www.test.com",
"http.host": "www.test.com",
},
OperationName: "EntryPoint test www.test.com",
},
},
{
desc: "basic test",
entryPoint: "test",
spanNameLimit: 25,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
expected: expected{
Tags: map[string]interface{}{
"span.kind": ext.SpanKindRPCServerEnum,
"http.method": http.MethodGet,
"component": "",
"http.url": "http://www.test.com",
"http.host": "www.test.com",
},
OperationName: "EntryPoint te... ww... 0c15301b",
},
expectedTags: expectedTags,
expectedName: "Entrypoint te... ww... 39b97e58",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
e := &entryPointMiddleware{
entryPoint: test.entryPoint,
Tracing: test.tracing,
}
newTracing, err := tracing.NewTracing("", test.spanNameLimit, test.tracing)
require.NoError(t, err)
next := func(http.ResponseWriter, *http.Request) {
req := httptest.NewRequest(http.MethodGet, "http://www.test.com", nil)
rw := httptest.NewRecorder()
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
span := test.tracing.tracer.(*MockTracer).Span
actual := span.Tags
assert.Equal(t, test.expectedTags, actual)
assert.Equal(t, test.expectedName, span.OpName)
}
tags := span.Tags
assert.Equal(t, test.expected.Tags, tags)
assert.Equal(t, test.expected.OperationName, span.OpName)
})
e.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "http://www.test.com", nil), next)
handler := NewEntryPoint(context.Background(), newTracing, test.entryPoint, next)
handler.ServeHTTP(rw, req)
})
}
}

View file

@ -1,63 +1,58 @@
package tracing
import (
"fmt"
"context"
"net/http"
"github.com/containous/traefik/log"
"github.com/containous/traefik/middlewares"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/urfave/negroni"
)
const (
forwarderTypeName = "TracingForwarder"
)
type forwarderMiddleware struct {
frontend string
backend string
opName string
*Tracing
router string
service string
next http.Handler
}
// NewForwarderMiddleware creates a new forwarder middleware that traces the outgoing request
func (t *Tracing) NewForwarderMiddleware(frontend, backend string) negroni.Handler {
log.Debugf("Added outgoing tracing middleware %s", frontend)
// NewForwarder creates a new forwarder middleware that traces the outgoing request.
func NewForwarder(ctx context.Context, router, service string, next http.Handler) http.Handler {
middlewares.GetLogger(ctx, "tracing", forwarderTypeName).
Debugf("Added outgoing tracing middleware %s", service)
return &forwarderMiddleware{
Tracing: t,
frontend: frontend,
backend: backend,
opName: generateForwardSpanName(frontend, backend, t.SpanNameLimit),
router: router,
service: service,
next: next,
}
}
func (f *forwarderMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
span, r, finish := StartSpan(r, f.opName, true)
func (f *forwarderMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
tr, err := tracing.FromContext(req.Context())
if err != nil {
f.next.ServeHTTP(rw, req)
return
}
opParts := []string{f.service, f.router}
span, req, finish := tr.StartSpanf(req, ext.SpanKindRPCClientEnum, "forward", opParts, "/")
defer finish()
span.SetTag("frontend.name", f.frontend)
span.SetTag("backend.name", f.backend)
ext.HTTPMethod.Set(span, r.Method)
ext.HTTPUrl.Set(span, fmt.Sprintf("%s%s", r.URL.String(), r.RequestURI))
span.SetTag("http.host", r.Host)
InjectRequestHeaders(r)
span.SetTag("service.name", f.service)
span.SetTag("router.name", f.router)
ext.HTTPMethod.Set(span, req.Method)
ext.HTTPUrl.Set(span, req.URL.String())
span.SetTag("http.host", req.Host)
recorder := newStatusCodeRecoder(w, 200)
tracing.InjectRequestHeaders(req)
next(recorder, r)
recorder := newStatusCodeRecoder(rw, 200)
LogResponseCode(span, recorder.Status())
}
// generateForwardSpanName will return a Span name of an appropriate lenth based on the 'spanLimit' argument. If needed, it will be truncated, but will not be less than 21 characters
func generateForwardSpanName(frontend, backend string, spanLimit int) string {
name := fmt.Sprintf("forward %s/%s", frontend, backend)
if spanLimit > 0 && len(name) > spanLimit {
if spanLimit < ForwardMaxLengthNumber {
log.Warnf("SpanNameLimit is set to be less than required static number of characters, defaulting to %d + 3", ForwardMaxLengthNumber)
spanLimit = ForwardMaxLengthNumber + 3
}
hash := computeHash(name)
limit := (spanLimit - ForwardMaxLengthNumber) / 2
name = fmt.Sprintf("forward %s/%s/%s", truncateString(frontend, limit), truncateString(backend, limit), hash)
}
return name
f.next.ServeHTTP(recorder, req)
tracing.LogResponseCode(span, recorder.Status())
}

View file

@ -1,93 +1,136 @@
package tracing
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTracingNewForwarderMiddleware(t *testing.T) {
func TestNewForwarder(t *testing.T) {
type expected struct {
Tags map[string]interface{}
OperationName string
}
testCases := []struct {
desc string
tracer *Tracing
frontend string
backend string
expected *forwarderMiddleware
desc string
spanNameLimit int
tracing *trackingBackenMock
service string
router string
expected expected
}{
{
desc: "Simple Forward Tracer without truncation and hashing",
tracer: &Tracing{
SpanNameLimit: 101,
desc: "Simple Forward Tracer without truncation and hashing",
spanNameLimit: 101,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
frontend: "some-service.domain.tld",
backend: "some-service.domain.tld",
expected: &forwarderMiddleware{
Tracing: &Tracing{
SpanNameLimit: 101,
service: "some-service.domain.tld",
router: "some-service.domain.tld",
expected: expected{
Tags: map[string]interface{}{
"http.host": "www.test.com",
"http.method": "GET",
"http.url": "http://www.test.com/toto",
"service.name": "some-service.domain.tld",
"router.name": "some-service.domain.tld",
"span.kind": ext.SpanKindRPCClientEnum,
},
frontend: "some-service.domain.tld",
backend: "some-service.domain.tld",
opName: "forward some-service.domain.tld/some-service.domain.tld",
OperationName: "forward some-service.domain.tld/some-service.domain.tld",
},
}, {
desc: "Simple Forward Tracer with truncation and hashing",
tracer: &Tracing{
SpanNameLimit: 101,
desc: "Simple Forward Tracer with truncation and hashing",
spanNameLimit: 101,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
frontend: "some-service-100.slug.namespace.environment.domain.tld",
backend: "some-service-100.slug.namespace.environment.domain.tld",
expected: &forwarderMiddleware{
Tracing: &Tracing{
SpanNameLimit: 101,
service: "some-service-100.slug.namespace.environment.domain.tld",
router: "some-service-100.slug.namespace.environment.domain.tld",
expected: expected{
Tags: map[string]interface{}{
"http.host": "www.test.com",
"http.method": "GET",
"http.url": "http://www.test.com/toto",
"service.name": "some-service-100.slug.namespace.environment.domain.tld",
"router.name": "some-service-100.slug.namespace.environment.domain.tld",
"span.kind": ext.SpanKindRPCClientEnum,
},
frontend: "some-service-100.slug.namespace.environment.domain.tld",
backend: "some-service-100.slug.namespace.environment.domain.tld",
opName: "forward some-service-100.slug.namespace.enviro.../some-service-100.slug.namespace.enviro.../bc4a0d48",
OperationName: "forward some-service-100.slug.namespace.enviro.../some-service-100.slug.namespace.enviro.../bc4a0d48",
},
},
{
desc: "Exactly 101 chars",
tracer: &Tracing{
SpanNameLimit: 101,
desc: "Exactly 101 chars",
spanNameLimit: 101,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
frontend: "some-service1.namespace.environment.domain.tld",
backend: "some-service1.namespace.environment.domain.tld",
expected: &forwarderMiddleware{
Tracing: &Tracing{
SpanNameLimit: 101,
service: "some-service1.namespace.environment.domain.tld",
router: "some-service1.namespace.environment.domain.tld",
expected: expected{
Tags: map[string]interface{}{
"http.host": "www.test.com",
"http.method": "GET",
"http.url": "http://www.test.com/toto",
"service.name": "some-service1.namespace.environment.domain.tld",
"router.name": "some-service1.namespace.environment.domain.tld",
"span.kind": ext.SpanKindRPCClientEnum,
},
frontend: "some-service1.namespace.environment.domain.tld",
backend: "some-service1.namespace.environment.domain.tld",
opName: "forward some-service1.namespace.environment.domain.tld/some-service1.namespace.environment.domain.tld",
OperationName: "forward some-service1.namespace.environment.domain.tld/some-service1.namespace.environment.domain.tld",
},
},
{
desc: "More than 101 chars",
tracer: &Tracing{
SpanNameLimit: 101,
desc: "More than 101 chars",
spanNameLimit: 101,
tracing: &trackingBackenMock{
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
frontend: "some-service1.frontend.namespace.environment.domain.tld",
backend: "some-service1.backend.namespace.environment.domain.tld",
expected: &forwarderMiddleware{
Tracing: &Tracing{
SpanNameLimit: 101,
service: "some-service1.frontend.namespace.environment.domain.tld",
router: "some-service1.backend.namespace.environment.domain.tld",
expected: expected{
Tags: map[string]interface{}{
"http.host": "www.test.com",
"http.method": "GET",
"http.url": "http://www.test.com/toto",
"service.name": "some-service1.frontend.namespace.environment.domain.tld",
"router.name": "some-service1.backend.namespace.environment.domain.tld",
"span.kind": ext.SpanKindRPCClientEnum,
},
frontend: "some-service1.frontend.namespace.environment.domain.tld",
backend: "some-service1.backend.namespace.environment.domain.tld",
opName: "forward some-service1.frontend.namespace.envir.../some-service1.backend.namespace.enviro.../fa49dd23",
OperationName: "forward some-service1.frontend.namespace.envir.../some-service1.backend.namespace.enviro.../fa49dd23",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := test.tracer.NewForwarderMiddleware(test.frontend, test.backend)
newTracing, err := tracing.NewTracing("", test.spanNameLimit, test.tracing)
require.NoError(t, err)
assert.Equal(t, test.expected, actual)
assert.True(t, len(test.expected.opName) <= test.tracer.SpanNameLimit)
req := httptest.NewRequest(http.MethodGet, "http://www.test.com/toto", nil)
req = req.WithContext(tracing.WithTracing(req.Context(), newTracing))
rw := httptest.NewRecorder()
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
span := test.tracing.tracer.(*MockTracer).Span
tags := span.Tags
assert.Equal(t, test.expected.Tags, tags)
assert.True(t, len(test.expected.OperationName) <= test.spanNameLimit,
"the len of the operation name %q [len: %d] doesn't respect limit %d",
test.expected.OperationName, len(test.expected.OperationName), test.spanNameLimit)
assert.Equal(t, test.expected.OperationName, span.OpName)
})
handler := NewForwarder(context.Background(), test.router, test.service, next)
handler.ServeHTTP(rw, req)
})
}
}

View file

@ -1,73 +0,0 @@
package jaeger
import (
"fmt"
"io"
"github.com/containous/traefik/log"
"github.com/opentracing/opentracing-go"
jaegercfg "github.com/uber/jaeger-client-go/config"
"github.com/uber/jaeger-client-go/zipkin"
jaegermet "github.com/uber/jaeger-lib/metrics"
)
// Name sets the name of this tracer
const Name = "jaeger"
// Config provides configuration settings for a jaeger tracer
type Config struct {
SamplingServerURL string `description:"set the sampling server url." export:"false"`
SamplingType string `description:"set the sampling type." export:"true"`
SamplingParam float64 `description:"set the sampling parameter." export:"true"`
LocalAgentHostPort string `description:"set jaeger-agent's host:port that the reporter will used." export:"false"`
Gen128Bit bool `description:"generate 128 bit span IDs." export:"true"`
Propagation string `description:"which propgation format to use (jaeger/b3)." export:"true"`
}
// Setup sets up the tracer
func (c *Config) Setup(componentName string) (opentracing.Tracer, io.Closer, error) {
jcfg := jaegercfg.Configuration{
Sampler: &jaegercfg.SamplerConfig{
SamplingServerURL: c.SamplingServerURL,
Type: c.SamplingType,
Param: c.SamplingParam,
},
Reporter: &jaegercfg.ReporterConfig{
LogSpans: true,
LocalAgentHostPort: c.LocalAgentHostPort,
},
}
jMetricsFactory := jaegermet.NullFactory
opts := []jaegercfg.Option{
jaegercfg.Logger(&jaegerLogger{}),
jaegercfg.Metrics(jMetricsFactory),
jaegercfg.Gen128Bit(c.Gen128Bit),
}
switch c.Propagation {
case "b3":
p := zipkin.NewZipkinB3HTTPHeaderPropagator()
opts = append(opts,
jaegercfg.Injector(opentracing.HTTPHeaders, p),
jaegercfg.Extractor(opentracing.HTTPHeaders, p),
)
case "jaeger", "":
default:
return nil, nil, fmt.Errorf("unknown propagation format: %s", c.Propagation)
}
// Initialize tracer with a logger and a metrics factory
closer, err := jcfg.InitGlobalTracer(
componentName,
opts...,
)
if err != nil {
log.Warnf("Could not initialize jaeger tracer: %s", err.Error())
return nil, nil, err
}
log.Debug("Jaeger tracer configured")
return opentracing.GlobalTracer(), closer, nil
}

View file

@ -1,15 +0,0 @@
package jaeger
import "github.com/containous/traefik/log"
// jaegerLogger is an implementation of the Logger interface that delegates to traefik log
type jaegerLogger struct{}
func (l *jaegerLogger) Error(msg string) {
log.Errorf("Tracing jaeger error: %s", msg)
}
// Infof logs a message at debug priority
func (l *jaegerLogger) Infof(msg string, args ...interface{}) {
log.Debugf(msg, args...)
}

View file

@ -1,29 +1,43 @@
package tracing
import (
"testing"
"io"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/log"
"github.com/stretchr/testify/assert"
)
type MockTracer struct {
Span *MockSpan
}
// StartSpan belongs to the Tracer interface.
func (n MockTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
n.Span.OpName = operationName
return n.Span
}
// Inject belongs to the Tracer interface.
func (n MockTracer) Inject(sp opentracing.SpanContext, format interface{}, carrier interface{}) error {
return nil
}
// Extract belongs to the Tracer interface.
func (n MockTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) {
return nil, opentracing.ErrSpanContextNotFound
}
// MockSpanContext
type MockSpanContext struct{}
func (n MockSpanContext) ForeachBaggageItem(handler func(k, v string) bool) {}
// MockSpan
type MockSpan struct {
OpName string
Tags map[string]interface{}
}
type MockSpanContext struct {
}
// MockSpanContext:
func (n MockSpanContext) ForeachBaggageItem(handler func(k, v string) bool) {}
// MockSpan:
func (n MockSpan) Context() opentracing.SpanContext { return MockSpanContext{} }
func (n MockSpan) SetBaggageItem(key, val string) opentracing.Span {
return MockSpan{Tags: make(map[string]interface{})}
@ -46,88 +60,11 @@ func (n MockSpan) Reset() {
n.Tags = make(map[string]interface{})
}
// StartSpan belongs to the Tracer interface.
func (n MockTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
n.Span.OpName = operationName
return n.Span
type trackingBackenMock struct {
tracer opentracing.Tracer
}
// Inject belongs to the Tracer interface.
func (n MockTracer) Inject(sp opentracing.SpanContext, format interface{}, carrier interface{}) error {
return nil
}
// Extract belongs to the Tracer interface.
func (n MockTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) {
return nil, opentracing.ErrSpanContextNotFound
}
func TestTruncateString(t *testing.T) {
testCases := []struct {
desc string
text string
limit int
expected string
}{
{
desc: "short text less than limit 10",
text: "short",
limit: 10,
expected: "short",
},
{
desc: "basic truncate with limit 10",
text: "some very long pice of text",
limit: 10,
expected: "some ve...",
},
{
desc: "truncate long FQDN to 39 chars",
text: "some-service-100.slug.namespace.environment.domain.tld",
limit: 39,
expected: "some-service-100.slug.namespace.envi...",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := truncateString(test.text, test.limit)
assert.Equal(t, test.expected, actual)
assert.True(t, len(actual) <= test.limit)
})
}
}
func TestComputeHash(t *testing.T) {
testCases := []struct {
desc string
text string
expected string
}{
{
desc: "hashing",
text: "some very long pice of text",
expected: "0258ea1c",
},
{
desc: "short text less than limit 10",
text: "short",
expected: "f9b0078b",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := computeHash(test.text)
assert.Equal(t, test.expected, actual)
})
}
func (t *trackingBackenMock) Setup(componentName string) (opentracing.Tracer, io.Closer, error) {
opentracing.SetGlobalTracer(t.tracer)
return t.tracer, nil, nil
}

Some files were not shown because too many files have changed in this diff Show more