Dynamic Configuration Refactoring
This commit is contained in:
parent
d3ae88f108
commit
a09dfa3ce1
452 changed files with 21023 additions and 9419 deletions
|
@ -6,7 +6,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/old/middlewares"
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
59
middlewares/accesslog/field_middleware.go
Normal file
59
middlewares/accesslog/field_middleware.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// FieldApply function hook to add data in accesslog
|
||||
type FieldApply func(rw http.ResponseWriter, r *http.Request, next http.Handler, data *LogData)
|
||||
|
||||
// FieldHandler sends a new field to the logger.
|
||||
type FieldHandler struct {
|
||||
next http.Handler
|
||||
name string
|
||||
value string
|
||||
applyFn FieldApply
|
||||
}
|
||||
|
||||
// NewFieldHandler creates a Field handler.
|
||||
func NewFieldHandler(next http.Handler, name string, value string, applyFn FieldApply) http.Handler {
|
||||
return &FieldHandler{next: next, name: name, value: value, applyFn: applyFn}
|
||||
}
|
||||
|
||||
func (f *FieldHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
table := GetLogData(req)
|
||||
if table == nil {
|
||||
f.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
table.Core[f.name] = f.value
|
||||
|
||||
if f.applyFn != nil {
|
||||
f.applyFn(rw, req, f.next, table)
|
||||
} else {
|
||||
f.next.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// AddServiceFields add service fields
|
||||
func AddServiceFields(rw http.ResponseWriter, req *http.Request, next http.Handler, data *LogData) {
|
||||
data.Core[ServiceURL] = req.URL // note that this is *not* the original incoming URL
|
||||
data.Core[ServiceAddr] = req.URL.Host
|
||||
|
||||
crw := &captureResponseWriter{rw: rw}
|
||||
start := time.Now().UTC()
|
||||
|
||||
next.ServeHTTP(crw, req)
|
||||
|
||||
// use UTC to handle switchover of daylight saving correctly
|
||||
data.Core[OriginDuration] = time.Now().UTC().Sub(start)
|
||||
data.Core[OriginStatus] = crw.Status()
|
||||
// make copy of headers so we can ensure there is no subsequent mutation during response processing
|
||||
data.OriginResponse = make(http.Header)
|
||||
utils.CopyHeaders(data.OriginResponse, crw.Header())
|
||||
data.Core[OriginContentSize] = crw.Size()
|
||||
}
|
|
@ -12,14 +12,16 @@ const (
|
|||
// Duration is the map key used for the total time taken by processing the response, including the origin server's time but
|
||||
// not the log writing time.
|
||||
Duration = "Duration"
|
||||
// FrontendName is the map key used for the name of the Traefik frontend.
|
||||
FrontendName = "FrontendName"
|
||||
// BackendName is the map key used for the name of the Traefik backend.
|
||||
BackendName = "BackendName"
|
||||
// BackendURL is the map key used for the URL of the Traefik backend.
|
||||
BackendURL = "BackendURL"
|
||||
// BackendAddr is the map key used for the IP:port of the Traefik backend (extracted from BackendURL)
|
||||
BackendAddr = "BackendAddr"
|
||||
|
||||
// RouterName is the map key used for the name of the Traefik router.
|
||||
RouterName = "RouterName"
|
||||
// ServiceName is the map key used for the name of the Traefik backend.
|
||||
ServiceName = "ServiceName"
|
||||
// ServiceURL is the map key used for the URL of the Traefik backend.
|
||||
ServiceURL = "ServiceURL"
|
||||
// ServiceAddr is the map key used for the IP:port of the Traefik backend (extracted from BackendURL)
|
||||
ServiceAddr = "ServiceAddr"
|
||||
|
||||
// ClientAddr is the map key used for the remote address in its original form (usually IP:port).
|
||||
ClientAddr = "ClientAddr"
|
||||
// ClientHost is the map key used for the remote IP address from which the client request was received.
|
||||
|
@ -72,9 +74,9 @@ const (
|
|||
var defaultCoreKeys = [...]string{
|
||||
StartUTC,
|
||||
Duration,
|
||||
FrontendName,
|
||||
BackendName,
|
||||
BackendURL,
|
||||
RouterName,
|
||||
ServiceName,
|
||||
ServiceURL,
|
||||
ClientHost,
|
||||
ClientPort,
|
||||
ClientUsername,
|
||||
|
@ -99,7 +101,7 @@ func init() {
|
|||
for _, k := range defaultCoreKeys {
|
||||
allCoreKeys[k] = struct{}{}
|
||||
}
|
||||
allCoreKeys[BackendAddr] = struct{}{}
|
||||
allCoreKeys[ServiceAddr] = struct{}{}
|
||||
allCoreKeys[ClientAddr] = struct{}{}
|
||||
allCoreKeys[RequestAddr] = struct{}{}
|
||||
allCoreKeys[GzipRatio] = struct{}{}
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/types"
|
||||
|
@ -21,36 +22,44 @@ import (
|
|||
type key string
|
||||
|
||||
const (
|
||||
// DataTableKey is the key within the request context used to
|
||||
// store the Log Data Table
|
||||
// DataTableKey is the key within the request context used to store the Log Data Table.
|
||||
DataTableKey key = "LogDataTable"
|
||||
|
||||
// CommonFormat is the common logging format (CLF)
|
||||
// CommonFormat is the common logging format (CLF).
|
||||
CommonFormat string = "common"
|
||||
|
||||
// JSONFormat is the JSON logging format
|
||||
// JSONFormat is the JSON logging format.
|
||||
JSONFormat string = "json"
|
||||
)
|
||||
|
||||
type logHandlerParams struct {
|
||||
type handlerParams struct {
|
||||
logDataTable *LogData
|
||||
crr *captureRequestReader
|
||||
crw *captureResponseWriter
|
||||
}
|
||||
|
||||
// LogHandler will write each request and its response to the access log.
|
||||
type LogHandler struct {
|
||||
// Handler will write each request and its response to the access log.
|
||||
type Handler struct {
|
||||
config *types.AccessLog
|
||||
logger *logrus.Logger
|
||||
file *os.File
|
||||
mu sync.Mutex
|
||||
httpCodeRanges types.HTTPCodeRanges
|
||||
logHandlerChan chan logHandlerParams
|
||||
logHandlerChan chan handlerParams
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewLogHandler creates a new LogHandler
|
||||
func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
|
||||
// WrapHandler Wraps access log handler into an Alice Constructor.
|
||||
func WrapHandler(handler *Handler) alice.Constructor {
|
||||
return func(next http.Handler) (http.Handler, error) {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
handler.ServeHTTP(rw, req, next.ServeHTTP)
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewHandler creates a new Handler.
|
||||
func NewHandler(config *types.AccessLog) (*Handler, error) {
|
||||
file := os.Stdout
|
||||
if len(config.FilePath) > 0 {
|
||||
f, err := openAccessLogFile(config.FilePath)
|
||||
|
@ -59,7 +68,7 @@ func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
|
|||
}
|
||||
file = f
|
||||
}
|
||||
logHandlerChan := make(chan logHandlerParams, config.BufferingSize)
|
||||
logHandlerChan := make(chan handlerParams, config.BufferingSize)
|
||||
|
||||
var formatter logrus.Formatter
|
||||
|
||||
|
@ -79,7 +88,7 @@ func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
|
|||
Level: logrus.InfoLevel,
|
||||
}
|
||||
|
||||
logHandler := &LogHandler{
|
||||
logHandler := &Handler{
|
||||
config: config,
|
||||
logger: logger,
|
||||
file: file,
|
||||
|
@ -88,7 +97,7 @@ func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
|
|||
|
||||
if config.Filters != nil {
|
||||
if httpCodeRanges, err := types.NewHTTPCodeRanges(config.Filters.StatusCodes); err != nil {
|
||||
log.Errorf("Failed to create new HTTP code ranges: %s", err)
|
||||
log.WithoutContext().Errorf("Failed to create new HTTP code ranges: %s", err)
|
||||
} else {
|
||||
logHandler.httpCodeRanges = httpCodeRanges
|
||||
}
|
||||
|
@ -122,17 +131,16 @@ func openAccessLogFile(filePath string) (*os.File, error) {
|
|||
return file, nil
|
||||
}
|
||||
|
||||
// GetLogDataTable gets the request context object that contains logging data.
|
||||
// GetLogData gets the request context object that contains logging data.
|
||||
// This creates data as the request passes through the middleware chain.
|
||||
func GetLogDataTable(req *http.Request) *LogData {
|
||||
func GetLogData(req *http.Request) *LogData {
|
||||
if ld, ok := req.Context().Value(DataTableKey).(*LogData); ok {
|
||||
return ld
|
||||
}
|
||||
log.Errorf("%s is nil", DataTableKey)
|
||||
return &LogData{Core: make(CoreLogData)}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LogHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
core := CoreLogData{
|
||||
|
@ -179,46 +187,45 @@ func (l *LogHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next h
|
|||
|
||||
next.ServeHTTP(crw, reqWithDataTable)
|
||||
|
||||
core[ClientUsername] = formatUsernameForLog(core[ClientUsername])
|
||||
core[ClientUsername] = usernameIfPresent(reqWithDataTable.URL)
|
||||
|
||||
logDataTable.DownstreamResponse = crw.Header()
|
||||
|
||||
if l.config.BufferingSize > 0 {
|
||||
l.logHandlerChan <- logHandlerParams{
|
||||
if h.config.BufferingSize > 0 {
|
||||
h.logHandlerChan <- handlerParams{
|
||||
logDataTable: logDataTable,
|
||||
crr: crr,
|
||||
crw: crw,
|
||||
}
|
||||
} else {
|
||||
l.logTheRoundTrip(logDataTable, crr, crw)
|
||||
h.logTheRoundTrip(logDataTable, crr, crw)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the Logger (i.e. the file, drain logHandlerChan, etc).
|
||||
func (l *LogHandler) Close() error {
|
||||
close(l.logHandlerChan)
|
||||
l.wg.Wait()
|
||||
return l.file.Close()
|
||||
func (h *Handler) Close() error {
|
||||
close(h.logHandlerChan)
|
||||
h.wg.Wait()
|
||||
return h.file.Close()
|
||||
}
|
||||
|
||||
// Rotate closes and reopens the log file to allow for rotation
|
||||
// by an external source.
|
||||
func (l *LogHandler) Rotate() error {
|
||||
// Rotate closes and reopens the log file to allow for rotation by an external source.
|
||||
func (h *Handler) Rotate() error {
|
||||
var err error
|
||||
|
||||
if l.file != nil {
|
||||
if h.file != nil {
|
||||
defer func(f *os.File) {
|
||||
f.Close()
|
||||
}(l.file)
|
||||
}(h.file)
|
||||
}
|
||||
|
||||
l.file, err = os.OpenFile(l.config.FilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
|
||||
h.file, err = os.OpenFile(h.config.FilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.logger.Out = l.file
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.logger.Out = h.file
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -230,16 +237,17 @@ func silentSplitHostPort(value string) (host string, port string) {
|
|||
return host, port
|
||||
}
|
||||
|
||||
func formatUsernameForLog(usernameField interface{}) string {
|
||||
username, ok := usernameField.(string)
|
||||
if ok && len(username) != 0 {
|
||||
return username
|
||||
func usernameIfPresent(theURL *url.URL) string {
|
||||
if theURL.User != nil {
|
||||
if name := theURL.User.Username(); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
return "-"
|
||||
}
|
||||
|
||||
// Logging handler to log frontend name, backend name, and elapsed time
|
||||
func (l *LogHandler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestReader, crw *captureResponseWriter) {
|
||||
// Logging handler to log frontend name, backend name, and elapsed time.
|
||||
func (h *Handler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestReader, crw *captureResponseWriter) {
|
||||
core := logDataTable.Core
|
||||
|
||||
retryAttempts, ok := core[RetryAttempts].(int)
|
||||
|
@ -254,11 +262,11 @@ func (l *LogHandler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestR
|
|||
|
||||
core[DownstreamStatus] = crw.Status()
|
||||
|
||||
// n.b. take care to perform time arithmetic using UTC to avoid errors at DST boundaries
|
||||
// n.b. take care to perform time arithmetic using UTC to avoid errors at DST boundaries.
|
||||
totalDuration := time.Now().UTC().Sub(core[StartUTC].(time.Time))
|
||||
core[Duration] = totalDuration
|
||||
|
||||
if l.keepAccessLog(crw.Status(), retryAttempts, totalDuration) {
|
||||
if h.keepAccessLog(crw.Status(), retryAttempts, totalDuration) {
|
||||
core[DownstreamContentSize] = crw.Size()
|
||||
if original, ok := core[OriginContentSize]; ok {
|
||||
o64 := original.(int64)
|
||||
|
@ -275,24 +283,24 @@ func (l *LogHandler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestR
|
|||
fields := logrus.Fields{}
|
||||
|
||||
for k, v := range logDataTable.Core {
|
||||
if l.config.Fields.Keep(k) {
|
||||
if h.config.Fields.Keep(k) {
|
||||
fields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
l.redactHeaders(logDataTable.Request, fields, "request_")
|
||||
l.redactHeaders(logDataTable.OriginResponse, fields, "origin_")
|
||||
l.redactHeaders(logDataTable.DownstreamResponse, fields, "downstream_")
|
||||
h.redactHeaders(logDataTable.Request, fields, "request_")
|
||||
h.redactHeaders(logDataTable.OriginResponse, fields, "origin_")
|
||||
h.redactHeaders(logDataTable.DownstreamResponse, fields, "downstream_")
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.logger.WithFields(fields).Println()
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.logger.WithFields(fields).Println()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LogHandler) redactHeaders(headers http.Header, fields logrus.Fields, prefix string) {
|
||||
func (h *Handler) redactHeaders(headers http.Header, fields logrus.Fields, prefix string) {
|
||||
for k := range headers {
|
||||
v := l.config.Fields.KeepHeader(k)
|
||||
v := h.config.Fields.KeepHeader(k)
|
||||
if v == types.AccessLogKeep {
|
||||
fields[prefix+k] = headers.Get(k)
|
||||
} else if v == types.AccessLogRedact {
|
||||
|
@ -301,26 +309,26 @@ func (l *LogHandler) redactHeaders(headers http.Header, fields logrus.Fields, pr
|
|||
}
|
||||
}
|
||||
|
||||
func (l *LogHandler) keepAccessLog(statusCode, retryAttempts int, duration time.Duration) bool {
|
||||
if l.config.Filters == nil {
|
||||
func (h *Handler) keepAccessLog(statusCode, retryAttempts int, duration time.Duration) bool {
|
||||
if h.config.Filters == nil {
|
||||
// no filters were specified
|
||||
return true
|
||||
}
|
||||
|
||||
if len(l.httpCodeRanges) == 0 && !l.config.Filters.RetryAttempts && l.config.Filters.MinDuration == 0 {
|
||||
if len(h.httpCodeRanges) == 0 && !h.config.Filters.RetryAttempts && h.config.Filters.MinDuration == 0 {
|
||||
// empty filters were specified, e.g. by passing --accessLog.filters only (without other filter options)
|
||||
return true
|
||||
}
|
||||
|
||||
if l.httpCodeRanges.Contains(statusCode) {
|
||||
if h.httpCodeRanges.Contains(statusCode) {
|
||||
return true
|
||||
}
|
||||
|
||||
if l.config.Filters.RetryAttempts && retryAttempts > 0 {
|
||||
if h.config.Filters.RetryAttempts && retryAttempts > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if l.config.Filters.MinDuration > 0 && (parse.Duration(duration) > l.config.Filters.MinDuration) {
|
||||
if h.config.Filters.MinDuration > 0 && (parse.Duration(duration) > h.config.Filters.MinDuration) {
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
|
@ -8,16 +8,16 @@ import (
|
|||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// default format for time presentation
|
||||
// default format for time presentation.
|
||||
const (
|
||||
commonLogTimeFormat = "02/Jan/2006:15:04:05 -0700"
|
||||
defaultValue = "-"
|
||||
)
|
||||
|
||||
// CommonLogFormatter provides formatting in the Traefik common log format
|
||||
// CommonLogFormatter provides formatting in the Traefik common log format.
|
||||
type CommonLogFormatter struct{}
|
||||
|
||||
// Format formats the log entry in the Traefik common log format
|
||||
// Format formats the log entry in the Traefik common log format.
|
||||
func (f *CommonLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
|
@ -43,8 +43,8 @@ func (f *CommonLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
|||
toLog(entry.Data, "request_Referer", `"-"`, true),
|
||||
toLog(entry.Data, "request_User-Agent", `"-"`, true),
|
||||
toLog(entry.Data, RequestCount, defaultValue, true),
|
||||
toLog(entry.Data, FrontendName, defaultValue, true),
|
||||
toLog(entry.Data, BackendURL, defaultValue, true),
|
||||
toLog(entry.Data, RouterName, defaultValue, true),
|
||||
toLog(entry.Data, ServiceURL, defaultValue, true),
|
||||
elapsedMillis)
|
||||
|
||||
return b.Bytes(), err
|
||||
|
@ -70,6 +70,7 @@ func toLog(fields logrus.Fields, key string, defaultValue string, quoted bool) i
|
|||
return defaultValue
|
||||
|
||||
}
|
||||
|
||||
func toLogEntry(s string, defaultValue string, quote bool) string {
|
||||
if len(s) == 0 {
|
||||
return defaultValue
|
||||
|
|
|
@ -32,8 +32,8 @@ func TestCommonLogFormatter_Format(t *testing.T) {
|
|||
RequestRefererHeader: "",
|
||||
RequestUserAgentHeader: "",
|
||||
RequestCount: 0,
|
||||
FrontendName: "",
|
||||
BackendURL: "",
|
||||
RouterName: "",
|
||||
ServiceURL: "",
|
||||
},
|
||||
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" - - "-" "-" 0 - - 123000ms
|
||||
`,
|
||||
|
@ -53,8 +53,8 @@ func TestCommonLogFormatter_Format(t *testing.T) {
|
|||
RequestRefererHeader: "referer",
|
||||
RequestUserAgentHeader: "agent",
|
||||
RequestCount: nil,
|
||||
FrontendName: "foo",
|
||||
BackendURL: "http://10.0.0.2/toto",
|
||||
RouterName: "foo",
|
||||
ServiceURL: "http://10.0.0.2/toto",
|
||||
},
|
||||
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" 123 132 "referer" "agent" - "foo" "http://10.0.0.2/toto" 123000ms
|
||||
`,
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -24,8 +23,8 @@ import (
|
|||
var (
|
||||
logFileNameSuffix = "/traefik/logger/test.log"
|
||||
testContent = "Hello, World"
|
||||
testBackendName = "http://127.0.0.1/testBackend"
|
||||
testFrontendName = "testFrontend"
|
||||
testServiceName = "http://127.0.0.1/testService"
|
||||
testRouterName = "testRouter"
|
||||
testStatus = 123
|
||||
testContentSize int64 = 12
|
||||
testHostname = "TestHost"
|
||||
|
@ -50,7 +49,7 @@ func TestLogRotation(t *testing.T) {
|
|||
rotatedFileName := fileName + ".rotated"
|
||||
|
||||
config := &types.AccessLog{FilePath: fileName, Format: CommonFormat}
|
||||
logHandler, err := NewLogHandler(config)
|
||||
logHandler, err := NewHandler(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating new log handler: %s", err)
|
||||
}
|
||||
|
@ -129,7 +128,7 @@ func TestLoggerCLF(t *testing.T) {
|
|||
logData, err := ioutil.ReadFile(logFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`
|
||||
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
|
||||
assertValidLogData(t, expectedLog, logData)
|
||||
}
|
||||
|
||||
|
@ -144,7 +143,7 @@ func TestAsyncLoggerCLF(t *testing.T) {
|
|||
logData, err := ioutil.ReadFile(logFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`
|
||||
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
|
||||
assertValidLogData(t, expectedLog, logData)
|
||||
}
|
||||
|
||||
|
@ -156,11 +155,11 @@ func assertString(exp string) func(t *testing.T, actual interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
func assertNotEqual(exp string) func(t *testing.T, actual interface{}) {
|
||||
func assertNotEmpty() func(t *testing.T, actual interface{}) {
|
||||
return func(t *testing.T, actual interface{}) {
|
||||
t.Helper()
|
||||
|
||||
assert.NotEqual(t, exp, actual)
|
||||
assert.NotEqual(t, "", actual)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -205,8 +204,8 @@ func TestLoggerJSON(t *testing.T) {
|
|||
OriginStatus: assertFloat64(float64(testStatus)),
|
||||
RequestRefererHeader: assertString(testReferer),
|
||||
RequestUserAgentHeader: assertString(testUserAgent),
|
||||
FrontendName: assertString(testFrontendName),
|
||||
BackendURL: assertString(testBackendName),
|
||||
RouterName: assertString(testRouterName),
|
||||
ServiceURL: assertString(testServiceName),
|
||||
ClientUsername: assertString(testUsername),
|
||||
ClientHost: assertString(testHostname),
|
||||
ClientPort: assertString(fmt.Sprintf("%d", testPort)),
|
||||
|
@ -218,9 +217,9 @@ func TestLoggerJSON(t *testing.T) {
|
|||
Duration: assertFloat64NotZero(),
|
||||
Overhead: assertFloat64NotZero(),
|
||||
RetryAttempts: assertFloat64(float64(testRetryAttempts)),
|
||||
"time": assertNotEqual(""),
|
||||
"StartLocal": assertNotEqual(""),
|
||||
"StartUTC": assertNotEqual(""),
|
||||
"time": assertNotEmpty(),
|
||||
"StartLocal": assertNotEmpty(),
|
||||
"StartUTC": assertNotEmpty(),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -235,7 +234,7 @@ func TestLoggerJSON(t *testing.T) {
|
|||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEqual(""),
|
||||
"time": assertNotEmpty(),
|
||||
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
|
||||
RequestRefererHeader: assertString(testReferer),
|
||||
RequestUserAgentHeader: assertString(testUserAgent),
|
||||
|
@ -256,7 +255,7 @@ func TestLoggerJSON(t *testing.T) {
|
|||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEqual(""),
|
||||
"time": assertNotEmpty(),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -274,7 +273,7 @@ func TestLoggerJSON(t *testing.T) {
|
|||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEqual(""),
|
||||
"time": assertNotEmpty(),
|
||||
"downstream_Content-Type": assertString("REDACTED"),
|
||||
RequestRefererHeader: assertString("REDACTED"),
|
||||
RequestUserAgentHeader: assertString("REDACTED"),
|
||||
|
@ -302,7 +301,7 @@ func TestLoggerJSON(t *testing.T) {
|
|||
RequestHost: assertString(testHostname),
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEqual(""),
|
||||
"time": assertNotEmpty(),
|
||||
RequestRefererHeader: assertString(testReferer),
|
||||
},
|
||||
},
|
||||
|
@ -349,7 +348,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
|
|||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "default config with empty filters",
|
||||
|
@ -358,7 +357,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
|
|||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Status code filter not matching",
|
||||
|
@ -380,7 +379,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
|
|||
StatusCodes: []string{"123"},
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Duration filter not matching",
|
||||
|
@ -402,7 +401,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
|
|||
MinDuration: parse.Duration(1 * time.Millisecond),
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Retry attempts filter matching",
|
||||
|
@ -413,7 +412,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
|
|||
RetryAttempts: true,
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode keep",
|
||||
|
@ -424,7 +423,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
|
|||
DefaultMode: "keep",
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode keep with override",
|
||||
|
@ -438,7 +437,7 @@ func TestNewLogHandlerOutputStdout(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
expectedLog: `- - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop",
|
||||
|
@ -572,8 +571,8 @@ func assertValidLogData(t *testing.T, expected string, logData []byte) {
|
|||
assert.Equal(t, resultExpected[RequestRefererHeader], result[RequestRefererHeader], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestUserAgentHeader], result[RequestUserAgentHeader], formatErrMessage)
|
||||
assert.Regexp(t, regexp.MustCompile("[0-9]*"), result[RequestCount], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[FrontendName], result[FrontendName], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[BackendURL], result[BackendURL], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RouterName], result[RouterName], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[ServiceURL], result[ServiceURL], formatErrMessage)
|
||||
assert.Regexp(t, regexp.MustCompile("[0-9]*ms"), result[Duration], formatErrMessage)
|
||||
}
|
||||
|
||||
|
@ -599,7 +598,7 @@ func createTempDir(t *testing.T, prefix string) string {
|
|||
}
|
||||
|
||||
func doLogging(t *testing.T, config *types.AccessLog) {
|
||||
logger, err := NewLogHandler(config)
|
||||
logger, err := NewHandler(config)
|
||||
require.NoError(t, err)
|
||||
defer logger.Close()
|
||||
|
||||
|
@ -618,6 +617,7 @@ func doLogging(t *testing.T, config *types.AccessLog) {
|
|||
Method: testMethod,
|
||||
RemoteAddr: fmt.Sprintf("%s:%d", testHostname, testPort),
|
||||
URL: &url.URL{
|
||||
User: url.UserPassword(testUsername, ""),
|
||||
Path: testPath,
|
||||
},
|
||||
}
|
||||
|
@ -627,18 +627,23 @@ func doLogging(t *testing.T, config *types.AccessLog) {
|
|||
|
||||
func logWriterTestHandlerFunc(rw http.ResponseWriter, r *http.Request) {
|
||||
if _, err := rw.Write([]byte(testContent)); err != nil {
|
||||
log.Error(err)
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
logData := GetLogData(r)
|
||||
if logData != nil {
|
||||
logData.Core[RouterName] = testRouterName
|
||||
logData.Core[ServiceURL] = testServiceName
|
||||
logData.Core[OriginStatus] = testStatus
|
||||
logData.Core[OriginContentSize] = testContentSize
|
||||
logData.Core[RetryAttempts] = testRetryAttempts
|
||||
logData.Core[StartUTC] = testStart.UTC()
|
||||
logData.Core[StartLocal] = testStart.Local()
|
||||
} else {
|
||||
http.Error(rw, "LogData is nil", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(testStatus)
|
||||
|
||||
logDataTable := GetLogDataTable(r)
|
||||
logDataTable.Core[FrontendName] = testFrontendName
|
||||
logDataTable.Core[BackendURL] = testBackendName
|
||||
logDataTable.Core[OriginStatus] = testStatus
|
||||
logDataTable.Core[OriginContentSize] = testContentSize
|
||||
logDataTable.Core[RetryAttempts] = testRetryAttempts
|
||||
logDataTable.Core[StartUTC] = testStart.UTC()
|
||||
logDataTable.Core[StartLocal] = testStart.Local()
|
||||
logDataTable.Core[ClientUsername] = testUsername
|
||||
}
|
||||
|
|
|
@ -45,8 +45,8 @@ func ParseAccessLog(data string) (map[string]string, error) {
|
|||
result[RequestRefererHeader] = submatch[9]
|
||||
result[RequestUserAgentHeader] = submatch[10]
|
||||
result[RequestCount] = submatch[11]
|
||||
result[FrontendName] = submatch[12]
|
||||
result[BackendURL] = submatch[13]
|
||||
result[RouterName] = submatch[12]
|
||||
result[ServiceURL] = submatch[13]
|
||||
result[Duration] = submatch[14]
|
||||
}
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ func TestParseAccessLog(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
desc: "full log",
|
||||
value: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
value: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
expected: map[string]string{
|
||||
ClientHost: "TestHost",
|
||||
ClientUsername: "TestUser",
|
||||
|
@ -27,14 +27,14 @@ func TestParseAccessLog(t *testing.T) {
|
|||
RequestRefererHeader: `"testReferer"`,
|
||||
RequestUserAgentHeader: `"testUserAgent"`,
|
||||
RequestCount: "1",
|
||||
FrontendName: `"testFrontend"`,
|
||||
BackendURL: `"http://127.0.0.1/testBackend"`,
|
||||
RouterName: `"testRouter"`,
|
||||
ServiceURL: `"http://127.0.0.1/testService"`,
|
||||
Duration: "1ms",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "log with space",
|
||||
value: `127.0.0.1 - - [09/Mar/2018:10:51:32 +0000] "GET / HTTP/1.1" 401 17 "-" "Go-http-client/1.1" 1 "testFrontend with space" - 0ms`,
|
||||
value: `127.0.0.1 - - [09/Mar/2018:10:51:32 +0000] "GET / HTTP/1.1" 401 17 "-" "Go-http-client/1.1" 1 "testRouter with space" - 0ms`,
|
||||
expected: map[string]string{
|
||||
ClientHost: "127.0.0.1",
|
||||
ClientUsername: "-",
|
||||
|
@ -47,8 +47,8 @@ func TestParseAccessLog(t *testing.T) {
|
|||
RequestRefererHeader: `"-"`,
|
||||
RequestUserAgentHeader: `"Go-http-client/1.1"`,
|
||||
RequestCount: "1",
|
||||
FrontendName: `"testFrontend with space"`,
|
||||
BackendURL: `-`,
|
||||
RouterName: `"testRouter with space"`,
|
||||
ServiceURL: `-`,
|
||||
Duration: "0ms",
|
||||
},
|
||||
},
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// SaveBackend sends the backend name to the logger.
|
||||
// These are always used with a corresponding SaveFrontend handler.
|
||||
type SaveBackend struct {
|
||||
next http.Handler
|
||||
backendName string
|
||||
}
|
||||
|
||||
// NewSaveBackend creates a SaveBackend handler.
|
||||
func NewSaveBackend(next http.Handler, backendName string) http.Handler {
|
||||
return &SaveBackend{next, backendName}
|
||||
}
|
||||
|
||||
func (sb *SaveBackend) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
serveSaveBackend(rw, r, sb.backendName, func(crw *captureResponseWriter) {
|
||||
sb.next.ServeHTTP(crw, r)
|
||||
})
|
||||
}
|
||||
|
||||
// SaveNegroniBackend sends the backend name to the logger.
|
||||
type SaveNegroniBackend struct {
|
||||
next negroni.Handler
|
||||
backendName string
|
||||
}
|
||||
|
||||
// NewSaveNegroniBackend creates a SaveBackend handler.
|
||||
func NewSaveNegroniBackend(next negroni.Handler, backendName string) negroni.Handler {
|
||||
return &SaveNegroniBackend{next, backendName}
|
||||
}
|
||||
|
||||
func (sb *SaveNegroniBackend) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
serveSaveBackend(rw, r, sb.backendName, func(crw *captureResponseWriter) {
|
||||
sb.next.ServeHTTP(crw, r, next)
|
||||
})
|
||||
}
|
||||
|
||||
func serveSaveBackend(rw http.ResponseWriter, r *http.Request, backendName string, apply func(*captureResponseWriter)) {
|
||||
table := GetLogDataTable(r)
|
||||
table.Core[BackendName] = backendName
|
||||
table.Core[BackendURL] = r.URL // note that this is *not* the original incoming URL
|
||||
table.Core[BackendAddr] = r.URL.Host
|
||||
|
||||
crw := &captureResponseWriter{rw: rw}
|
||||
start := time.Now().UTC()
|
||||
|
||||
apply(crw)
|
||||
|
||||
// use UTC to handle switchover of daylight saving correctly
|
||||
table.Core[OriginDuration] = time.Now().UTC().Sub(start)
|
||||
table.Core[OriginStatus] = crw.Status()
|
||||
// make copy of headers so we can ensure there is no subsequent mutation during response processing
|
||||
table.OriginResponse = make(http.Header)
|
||||
utils.CopyHeaders(table.OriginResponse, crw.Header())
|
||||
table.Core[OriginContentSize] = crw.Size()
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// SaveFrontend sends the frontend name to the logger.
|
||||
// These are sometimes used with a corresponding SaveBackend handler, but not always.
|
||||
// For example, redirected requests don't reach a backend.
|
||||
type SaveFrontend struct {
|
||||
next http.Handler
|
||||
frontendName string
|
||||
}
|
||||
|
||||
// NewSaveFrontend creates a SaveFrontend handler.
|
||||
func NewSaveFrontend(next http.Handler, frontendName string) http.Handler {
|
||||
return &SaveFrontend{next, frontendName}
|
||||
}
|
||||
|
||||
func (sf *SaveFrontend) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
serveSaveFrontend(r, sf.frontendName, func() {
|
||||
sf.next.ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
|
||||
// SaveNegroniFrontend sends the frontend name to the logger.
|
||||
type SaveNegroniFrontend struct {
|
||||
next negroni.Handler
|
||||
frontendName string
|
||||
}
|
||||
|
||||
// NewSaveNegroniFrontend creates a SaveNegroniFrontend handler.
|
||||
func NewSaveNegroniFrontend(next negroni.Handler, frontendName string) negroni.Handler {
|
||||
return &SaveNegroniFrontend{next, frontendName}
|
||||
}
|
||||
|
||||
func (sf *SaveNegroniFrontend) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
serveSaveFrontend(r, sf.frontendName, func() {
|
||||
sf.next.ServeHTTP(rw, r, next)
|
||||
})
|
||||
}
|
||||
|
||||
func serveSaveFrontend(r *http.Request, frontendName string, apply func()) {
|
||||
table := GetLogDataTable(r)
|
||||
table.Core[FrontendName] = strings.TrimPrefix(frontendName, "frontend-")
|
||||
|
||||
apply()
|
||||
}
|
|
@ -14,6 +14,8 @@ func (s *SaveRetries) Retried(req *http.Request, attempt int) {
|
|||
attempt--
|
||||
}
|
||||
|
||||
table := GetLogDataTable(req)
|
||||
table.Core[RetryAttempts] = attempt
|
||||
table := GetLogData(req)
|
||||
if table != nil {
|
||||
table.Core[RetryAttempts] = attempt
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,60 +0,0 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
const (
|
||||
clientUsernameKey key = "ClientUsername"
|
||||
)
|
||||
|
||||
// SaveUsername sends the Username name to the access logger.
|
||||
type SaveUsername struct {
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// NewSaveUsername creates a SaveUsername handler.
|
||||
func NewSaveUsername(next http.Handler) http.Handler {
|
||||
return &SaveUsername{next}
|
||||
}
|
||||
|
||||
func (sf *SaveUsername) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
serveSaveUsername(r, func() {
|
||||
sf.next.ServeHTTP(rw, r)
|
||||
})
|
||||
}
|
||||
|
||||
// SaveNegroniUsername adds the Username to the access logger data table.
|
||||
type SaveNegroniUsername struct {
|
||||
next negroni.Handler
|
||||
}
|
||||
|
||||
// NewSaveNegroniUsername creates a SaveNegroniUsername handler.
|
||||
func NewSaveNegroniUsername(next negroni.Handler) negroni.Handler {
|
||||
return &SaveNegroniUsername{next}
|
||||
}
|
||||
|
||||
func (sf *SaveNegroniUsername) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
serveSaveUsername(r, func() {
|
||||
sf.next.ServeHTTP(rw, r, next)
|
||||
})
|
||||
}
|
||||
|
||||
func serveSaveUsername(r *http.Request, apply func()) {
|
||||
table := GetLogDataTable(r)
|
||||
|
||||
username, ok := r.Context().Value(clientUsernameKey).(string)
|
||||
if ok {
|
||||
table.Core[ClientUsername] = username
|
||||
}
|
||||
|
||||
apply()
|
||||
}
|
||||
|
||||
// WithUserName adds a username to a requests' context
|
||||
func WithUserName(req *http.Request, username string) *http.Request {
|
||||
return req.WithContext(context.WithValue(req.Context(), clientUsernameKey, username))
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// AddPrefix is a middleware used to add prefix to an URL request
|
||||
type AddPrefix struct {
|
||||
Handler http.Handler
|
||||
Prefix string
|
||||
}
|
||||
|
||||
type key string
|
||||
|
||||
const (
|
||||
// AddPrefixKey is the key within the request context used to
|
||||
// store the added prefix
|
||||
AddPrefixKey key = "AddPrefix"
|
||||
)
|
||||
|
||||
func (s *AddPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r.URL.Path = s.Prefix + r.URL.Path
|
||||
if r.URL.RawPath != "" {
|
||||
r.URL.RawPath = s.Prefix + r.URL.RawPath
|
||||
}
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
r = r.WithContext(context.WithValue(r.Context(), AddPrefixKey, s.Prefix))
|
||||
s.Handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// SetHandler sets handler
|
||||
func (s *AddPrefix) SetHandler(Handler http.Handler) {
|
||||
s.Handler = Handler
|
||||
}
|
|
@ -1,64 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
prefix string
|
||||
path string
|
||||
expectedPath string
|
||||
expectedRawPath string
|
||||
}{
|
||||
{
|
||||
desc: "regular path",
|
||||
prefix: "/a",
|
||||
path: "/b",
|
||||
expectedPath: "/a/b",
|
||||
},
|
||||
{
|
||||
desc: "raw path is supported",
|
||||
prefix: "/a",
|
||||
path: "/b%2Fc",
|
||||
expectedPath: "/a/b/c",
|
||||
expectedRawPath: "/a/b%2Fc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualPath, actualRawPath, requestURI string
|
||||
handler := &AddPrefix{
|
||||
Prefix: test.prefix,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualRawPath = r.URL.RawPath
|
||||
requestURI = r.RequestURI
|
||||
}),
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
|
||||
handler.ServeHTTP(nil, req)
|
||||
|
||||
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
|
||||
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
|
||||
|
||||
expectedURI := test.expectedPath
|
||||
if test.expectedRawPath != "" {
|
||||
// go HTTP uses the raw path when existent in the RequestURI
|
||||
expectedURI = test.expectedRawPath
|
||||
}
|
||||
assert.Equal(t, expectedURI, requestURI, "Unexpected request URI.")
|
||||
})
|
||||
}
|
||||
}
|
62
middlewares/addprefix/add_prefix.go
Normal file
62
middlewares/addprefix/add_prefix.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package addprefix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "AddPrefix"
|
||||
)
|
||||
|
||||
// AddPrefix is a middleware used to add prefix to an URL request.
|
||||
type addPrefix struct {
|
||||
next http.Handler
|
||||
prefix string
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a new handler.
|
||||
func New(ctx context.Context, next http.Handler, config config.AddPrefix, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
var result *addPrefix
|
||||
|
||||
if len(config.Prefix) > 0 {
|
||||
result = &addPrefix{
|
||||
prefix: config.Prefix,
|
||||
next: next,
|
||||
name: name,
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("prefix cannot be empty")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (ap *addPrefix) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return ap.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (ap *addPrefix) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), ap.name, typeName)
|
||||
|
||||
oldURLPath := req.URL.Path
|
||||
req.URL.Path = ap.prefix + req.URL.Path
|
||||
logger.Debugf("URL.Path is now %s (was %s).", req.URL.Path, oldURLPath)
|
||||
|
||||
if req.URL.RawPath != "" {
|
||||
oldURLRawPath := req.URL.RawPath
|
||||
req.URL.RawPath = ap.prefix + req.URL.RawPath
|
||||
logger.Debugf("URL.RawPath is now %s (was %s).", req.URL.RawPath, oldURLRawPath)
|
||||
}
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
|
||||
ap.next.ServeHTTP(rw, req)
|
||||
}
|
104
middlewares/addprefix/add_prefix_test.go
Normal file
104
middlewares/addprefix/add_prefix_test.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package addprefix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewAddPrefix(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
prefix config.AddPrefix
|
||||
expectsError bool
|
||||
}{
|
||||
{
|
||||
desc: "Works with a non empty prefix",
|
||||
prefix: config.AddPrefix{Prefix: "/a"},
|
||||
},
|
||||
{
|
||||
desc: "Fails if prefix is empty",
|
||||
prefix: config.AddPrefix{Prefix: ""},
|
||||
expectsError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
_, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
|
||||
if test.expectsError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPrefix(t *testing.T) {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
testCases := []struct {
|
||||
desc string
|
||||
prefix config.AddPrefix
|
||||
path string
|
||||
expectedPath string
|
||||
expectedRawPath string
|
||||
}{
|
||||
{
|
||||
desc: "Works with a regular path",
|
||||
prefix: config.AddPrefix{Prefix: "/a"},
|
||||
path: "/b",
|
||||
expectedPath: "/a/b",
|
||||
},
|
||||
{
|
||||
desc: "Works with a raw path",
|
||||
prefix: config.AddPrefix{Prefix: "/a"},
|
||||
path: "/b%2Fc",
|
||||
expectedPath: "/a/b/c",
|
||||
expectedRawPath: "/a/b%2Fc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualPath, actualRawPath, requestURI string
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualRawPath = r.URL.RawPath
|
||||
requestURI = r.RequestURI
|
||||
})
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
|
||||
handler, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
|
||||
require.NoError(t, err)
|
||||
|
||||
handler.ServeHTTP(nil, req)
|
||||
|
||||
assert.Equal(t, test.expectedPath, actualPath)
|
||||
assert.Equal(t, test.expectedRawPath, actualRawPath)
|
||||
|
||||
expectedURI := test.expectedPath
|
||||
if test.expectedRawPath != "" {
|
||||
// go HTTP uses the raw path when existent in the RequestURI
|
||||
expectedURI = test.expectedRawPath
|
||||
}
|
||||
assert.Equal(t, expectedURI, requestURI)
|
||||
})
|
||||
}
|
||||
}
|
65
middlewares/auth/auth.go
Normal file
65
middlewares/auth/auth.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UserParser Parses a string and return a userName/userHash. An error if the format of the string is incorrect.
|
||||
type UserParser func(user string) (string, string, error)
|
||||
|
||||
const (
|
||||
defaultRealm = "traefik"
|
||||
authorizationHeader = "Authorization"
|
||||
)
|
||||
|
||||
func getUsers(fileName string, appendUsers []string, parser UserParser) (map[string]string, error) {
|
||||
users, err := loadUsers(fileName, appendUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userMap := make(map[string]string)
|
||||
for _, user := range users {
|
||||
userName, userHash, err := parser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userMap[userName] = userHash
|
||||
}
|
||||
|
||||
return userMap, nil
|
||||
}
|
||||
|
||||
func loadUsers(fileName string, appendUsers []string) ([]string, error) {
|
||||
var users []string
|
||||
var err error
|
||||
|
||||
if fileName != "" {
|
||||
users, err = getLinesFromFile(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return append(users, appendUsers...), nil
|
||||
}
|
||||
|
||||
func getLinesFromFile(filename string) ([]string, error) {
|
||||
dat, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Trim lines and filter out blanks
|
||||
rawLines := strings.Split(string(dat), "\n")
|
||||
var filteredLines []string
|
||||
for _, rawLine := range rawLines {
|
||||
line := strings.TrimSpace(rawLine)
|
||||
if line != "" {
|
||||
filteredLines = append(filteredLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredLines, nil
|
||||
}
|
|
@ -1,165 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
goauth "github.com/abbot/go-http-auth"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// Authenticator is a middleware that provides HTTP basic and digest authentication
|
||||
type Authenticator struct {
|
||||
handler negroni.Handler
|
||||
users map[string]string
|
||||
}
|
||||
|
||||
type tracingAuthenticator struct {
|
||||
name string
|
||||
handler negroni.Handler
|
||||
clientSpanKind bool
|
||||
}
|
||||
|
||||
const (
|
||||
authorizationHeader = "Authorization"
|
||||
)
|
||||
|
||||
// NewAuthenticator builds a new Authenticator given a config
|
||||
func NewAuthenticator(authConfig *types.Auth, tracingMiddleware *tracing.Tracing) (*Authenticator, error) {
|
||||
if authConfig == nil {
|
||||
return nil, fmt.Errorf("error creating Authenticator: auth is nil")
|
||||
}
|
||||
|
||||
var err error
|
||||
authenticator := &Authenticator{}
|
||||
tracingAuth := tracingAuthenticator{}
|
||||
|
||||
if authConfig.Basic != nil {
|
||||
authenticator.users, err = parserBasicUsers(authConfig.Basic)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
realm := "traefik"
|
||||
if authConfig.Basic.Realm != "" {
|
||||
realm = authConfig.Basic.Realm
|
||||
}
|
||||
basicAuth := goauth.NewBasicAuthenticator(realm, authenticator.secretBasic)
|
||||
tracingAuth.handler = createAuthBasicHandler(basicAuth, authConfig)
|
||||
tracingAuth.name = "Auth Basic"
|
||||
tracingAuth.clientSpanKind = false
|
||||
} else if authConfig.Digest != nil {
|
||||
authenticator.users, err = parserDigestUsers(authConfig.Digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
digestAuth := goauth.NewDigestAuthenticator("traefik", authenticator.secretDigest)
|
||||
tracingAuth.handler = createAuthDigestHandler(digestAuth, authConfig)
|
||||
tracingAuth.name = "Auth Digest"
|
||||
tracingAuth.clientSpanKind = false
|
||||
} else if authConfig.Forward != nil {
|
||||
tracingAuth.handler = createAuthForwardHandler(authConfig)
|
||||
tracingAuth.name = "Auth Forward"
|
||||
tracingAuth.clientSpanKind = true
|
||||
}
|
||||
|
||||
if tracingMiddleware != nil {
|
||||
authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuth.name, tracingAuth.handler, tracingAuth.clientSpanKind)
|
||||
} else {
|
||||
authenticator.handler = tracingAuth.handler
|
||||
}
|
||||
return authenticator, nil
|
||||
}
|
||||
|
||||
func createAuthForwardHandler(authConfig *types.Auth) negroni.HandlerFunc {
|
||||
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
Forward(authConfig.Forward, w, r, next)
|
||||
})
|
||||
}
|
||||
func createAuthDigestHandler(digestAuth *goauth.DigestAuth, authConfig *types.Auth) negroni.HandlerFunc {
|
||||
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if username, _ := digestAuth.CheckAuth(r); username == "" {
|
||||
log.Debugf("Digest auth failed")
|
||||
digestAuth.RequireAuth(w, r)
|
||||
} else {
|
||||
log.Debugf("Digest auth succeeded")
|
||||
|
||||
// set username in request context
|
||||
r = accesslog.WithUserName(r, username)
|
||||
|
||||
if authConfig.HeaderField != "" {
|
||||
r.Header[authConfig.HeaderField] = []string{username}
|
||||
}
|
||||
if authConfig.Digest.RemoveHeader {
|
||||
log.Debugf("Remove the Authorization header from the Digest auth")
|
||||
r.Header.Del(authorizationHeader)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
func createAuthBasicHandler(basicAuth *goauth.BasicAuth, authConfig *types.Auth) negroni.HandlerFunc {
|
||||
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if username := basicAuth.CheckAuth(r); username == "" {
|
||||
log.Debugf("Basic auth failed")
|
||||
basicAuth.RequireAuth(w, r)
|
||||
} else {
|
||||
log.Debugf("Basic auth succeeded")
|
||||
|
||||
// set username in request context
|
||||
r = accesslog.WithUserName(r, username)
|
||||
|
||||
if authConfig.HeaderField != "" {
|
||||
r.Header[authConfig.HeaderField] = []string{username}
|
||||
}
|
||||
if authConfig.Basic.RemoveHeader {
|
||||
log.Debugf("Remove the Authorization header from the Basic auth")
|
||||
r.Header.Del(authorizationHeader)
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func getLinesFromFile(filename string) ([]string, error) {
|
||||
dat, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Trim lines and filter out blanks
|
||||
rawLines := strings.Split(string(dat), "\n")
|
||||
var filteredLines []string
|
||||
for _, rawLine := range rawLines {
|
||||
line := strings.TrimSpace(rawLine)
|
||||
if line != "" {
|
||||
filteredLines = append(filteredLines, line)
|
||||
}
|
||||
}
|
||||
return filteredLines, nil
|
||||
}
|
||||
|
||||
func (a *Authenticator) secretBasic(user, realm string) string {
|
||||
if secret, ok := a.users[user]; ok {
|
||||
return secret
|
||||
}
|
||||
log.Debugf("User not found: %s", user)
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Authenticator) secretDigest(user, realm string) string {
|
||||
if secret, ok := a.users[user+":"+realm]; ok {
|
||||
return secret
|
||||
}
|
||||
log.Debugf("User not found: %s:%s", user, realm)
|
||||
return ""
|
||||
}
|
||||
|
||||
func (a *Authenticator) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
a.handler.ServeHTTP(rw, r, next)
|
||||
}
|
|
@ -1,297 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestAuthUsersFromFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
authType string
|
||||
usersStr string
|
||||
userKeys []string
|
||||
parserFunc func(fileName string) (map[string]string, error)
|
||||
}{
|
||||
{
|
||||
authType: "basic",
|
||||
usersStr: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
|
||||
userKeys: []string{"test", "test2"},
|
||||
parserFunc: func(fileName string) (map[string]string, error) {
|
||||
basic := &types.Basic{
|
||||
UsersFile: fileName,
|
||||
}
|
||||
return parserBasicUsers(basic)
|
||||
},
|
||||
},
|
||||
{
|
||||
authType: "digest",
|
||||
usersStr: "test:traefik:a2688e031edb4be6a3797f3882655c05 \ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
|
||||
userKeys: []string{"test:traefik", "test2:traefik"},
|
||||
parserFunc: func(fileName string) (map[string]string, error) {
|
||||
digest := &types.Digest{
|
||||
UsersFile: fileName,
|
||||
}
|
||||
return parserDigestUsers(digest)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.authType, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
usersFile, err := ioutil.TempFile("", "auth-users")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(usersFile.Name())
|
||||
|
||||
_, err = usersFile.Write([]byte(test.usersStr))
|
||||
require.NoError(t, err)
|
||||
|
||||
users, err := test.parserFunc(usersFile.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(users), "they should be equal")
|
||||
|
||||
_, ok := users[test.userKeys[0]]
|
||||
assert.True(t, ok, "user test should be found")
|
||||
_, ok = users[test.userKeys[1]]
|
||||
assert.True(t, ok, "user test2 should be found")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicAuthFail(t *testing.T) {
|
||||
_, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.Contains(t, err.Error(), "error parsing Authenticator user", "should contains")
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:test"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthSuccess(t *testing.T) {
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicRealm(t *testing.T) {
|
||||
authMiddlewareDefaultRealm, errdefault := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, errdefault)
|
||||
|
||||
authMiddlewareCustomRealm, errcustom := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Realm: "foobar",
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, errcustom)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
n := negroni.New(authMiddlewareDefaultRealm)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Basic realm=\"traefik\"", res.Header.Get("Www-Authenticate"), "they should be equal")
|
||||
|
||||
n = negroni.New(authMiddlewareCustomRealm)
|
||||
n.UseHandler(handler)
|
||||
ts = httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client = &http.Client{}
|
||||
req = testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err = client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Basic realm=\"foobar\"", res.Header.Get("Www-Authenticate"), "they should be equal")
|
||||
}
|
||||
|
||||
func TestDigestAuthFail(t *testing.T) {
|
||||
_, err := NewAuthenticator(&types.Auth{
|
||||
Digest: &types.Digest{
|
||||
Users: []string{"test"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.Contains(t, err.Error(), "error parsing Authenticator user", "should contains")
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Digest: &types.Digest{
|
||||
Users: []string{"test:traefik:test"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authMiddleware, "this should not be nil")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthUserHeader(t *testing.T) {
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
HeaderField: "X-Webauth-User",
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthHeaderRemoved(t *testing.T) {
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
RemoveHeader: true,
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Empty(t, r.Header.Get(authorizationHeader))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthHeaderPresent(t *testing.T) {
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
102
middlewares/auth/basic_auth.go
Normal file
102
middlewares/auth/basic_auth.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
goauth "github.com/abbot/go-http-auth"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
basicTypeName = "BasicAuth"
|
||||
)
|
||||
|
||||
type basicAuth struct {
|
||||
next http.Handler
|
||||
auth *goauth.BasicAuth
|
||||
users map[string]string
|
||||
headerField string
|
||||
removeHeader bool
|
||||
name string
|
||||
}
|
||||
|
||||
// NewBasic creates a basicAuth middleware.
|
||||
func NewBasic(ctx context.Context, next http.Handler, authConfig config.BasicAuth, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, basicTypeName).Debug("Creating middleware")
|
||||
users, err := getUsers(authConfig.UsersFile, authConfig.Users, basicUserParser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ba := &basicAuth{
|
||||
next: next,
|
||||
users: users,
|
||||
headerField: authConfig.HeaderField,
|
||||
removeHeader: authConfig.RemoveHeader,
|
||||
name: name,
|
||||
}
|
||||
|
||||
realm := defaultRealm
|
||||
if len(authConfig.Realm) > 0 {
|
||||
realm = authConfig.Realm
|
||||
}
|
||||
ba.auth = goauth.NewBasicAuthenticator(realm, ba.secretBasic)
|
||||
|
||||
return ba, nil
|
||||
}
|
||||
|
||||
func (b *basicAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return b.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), b.name, basicTypeName)
|
||||
|
||||
if username := b.auth.CheckAuth(req); username == "" {
|
||||
logger.Debug("Authentication failed")
|
||||
tracing.SetErrorWithEvent(req, "Authentication failed")
|
||||
b.auth.RequireAuth(rw, req)
|
||||
} else {
|
||||
logger.Debug("Authentication succeeded")
|
||||
req.URL.User = url.User(username)
|
||||
|
||||
logData := accesslog.GetLogData(req)
|
||||
if logData != nil {
|
||||
logData.Core[accesslog.ClientUsername] = username
|
||||
}
|
||||
|
||||
if b.headerField != "" {
|
||||
req.Header[b.headerField] = []string{username}
|
||||
}
|
||||
|
||||
if b.removeHeader {
|
||||
logger.Debug("Removing authorization header")
|
||||
req.Header.Del(authorizationHeader)
|
||||
}
|
||||
b.next.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *basicAuth) secretBasic(user, realm string) string {
|
||||
if secret, ok := b.users[user]; ok {
|
||||
return secret
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func basicUserParser(user string) (string, string, error) {
|
||||
split := strings.Split(user, ":")
|
||||
if len(split) != 2 {
|
||||
return "", "", fmt.Errorf("error parsing BasicUser: %v", user)
|
||||
}
|
||||
return split[0], split[1], nil
|
||||
}
|
274
middlewares/auth/basic_auth_test.go
Normal file
274
middlewares/auth/basic_auth_test.go
Normal file
|
@ -0,0 +1,274 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBasicAuthFail(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
Users: []string{"test"},
|
||||
}
|
||||
_, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.Error(t, err)
|
||||
|
||||
auth2 := config.BasicAuth{
|
||||
Users: []string{"test:test"},
|
||||
}
|
||||
authMiddleware, err := NewBasic(context.Background(), next, auth2, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthSuccess(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
}
|
||||
authMiddleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthUserHeader(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
HeaderField: "X-Webauth-User",
|
||||
}
|
||||
middleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestBasicAuthHeaderRemoved(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Empty(t, r.Header.Get(authorizationHeader))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
RemoveHeader: true,
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
}
|
||||
middleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestBasicAuthHeaderPresent(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
}
|
||||
middleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestBasicAuthUsersFromFile(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
userFileContent string
|
||||
expectedUsers map[string]string
|
||||
givenUsers []string
|
||||
realm string
|
||||
}{
|
||||
{
|
||||
desc: "Finds the users in the file",
|
||||
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
|
||||
givenUsers: []string{},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
|
||||
},
|
||||
{
|
||||
desc: "Merges given users with users from the file",
|
||||
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\n",
|
||||
givenUsers: []string{"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0", "test3:$apr1$3rJbDP0q$RfzJiorTk78jQ1EcKqWso0"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
|
||||
},
|
||||
{
|
||||
desc: "Given users have priority over users in the file",
|
||||
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
|
||||
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
|
||||
},
|
||||
{
|
||||
desc: "Should authenticate the correct user based on the realm",
|
||||
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
|
||||
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
|
||||
realm: "trafikee",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Creates the temporary configuration file with the users
|
||||
usersFile, err := ioutil.TempFile("", "auth-users")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(usersFile.Name())
|
||||
|
||||
_, err = usersFile.Write([]byte(test.userFileContent))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Creates the configuration for our Authenticator
|
||||
authenticatorConfiguration := config.BasicAuth{
|
||||
Users: test.givenUsers,
|
||||
UsersFile: usersFile.Name(),
|
||||
Realm: test.realm,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
authenticator, err := NewBasic(context.Background(), next, authenticatorConfiguration, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authenticator)
|
||||
defer ts.Close()
|
||||
|
||||
for userName, userPwd := range test.expectedUsers {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth(userName, userPwd)
|
||||
|
||||
var res *http.Response
|
||||
res, err = http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
|
||||
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
// Checks that user foo doesn't work
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("foo", "foo")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
if len(test.realm) > 0 {
|
||||
require.Equal(t, `Basic realm="`+test.realm+`"`, res.Header.Get("WWW-Authenticate"))
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotContains(t, "traefik", string(body))
|
||||
})
|
||||
}
|
||||
}
|
102
middlewares/auth/digest_auth.go
Normal file
102
middlewares/auth/digest_auth.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
goauth "github.com/abbot/go-http-auth"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/accesslog"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
digestTypeName = "digestAuth"
|
||||
)
|
||||
|
||||
type digestAuth struct {
|
||||
next http.Handler
|
||||
auth *goauth.DigestAuth
|
||||
users map[string]string
|
||||
headerField string
|
||||
removeHeader bool
|
||||
name string
|
||||
}
|
||||
|
||||
// NewDigest creates a digest auth middleware.
|
||||
func NewDigest(ctx context.Context, next http.Handler, authConfig config.DigestAuth, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, digestTypeName).Debug("Creating middleware")
|
||||
users, err := getUsers(authConfig.UsersFile, authConfig.Users, digestUserParser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
da := &digestAuth{
|
||||
next: next,
|
||||
users: users,
|
||||
headerField: authConfig.HeaderField,
|
||||
removeHeader: authConfig.RemoveHeader,
|
||||
name: name,
|
||||
}
|
||||
|
||||
realm := defaultRealm
|
||||
if len(authConfig.Realm) > 0 {
|
||||
realm = authConfig.Realm
|
||||
}
|
||||
da.auth = goauth.NewDigestAuthenticator(realm, da.secretDigest)
|
||||
|
||||
return da, nil
|
||||
}
|
||||
|
||||
func (d *digestAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return d.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (d *digestAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), d.name, digestTypeName)
|
||||
|
||||
if username, _ := d.auth.CheckAuth(req); username == "" {
|
||||
logger.Debug("Digest authentication failed")
|
||||
tracing.SetErrorWithEvent(req, "Digest authentication failed")
|
||||
d.auth.RequireAuth(rw, req)
|
||||
} else {
|
||||
logger.Debug("Digest authentication succeeded")
|
||||
req.URL.User = url.User(username)
|
||||
|
||||
logData := accesslog.GetLogData(req)
|
||||
if logData != nil {
|
||||
logData.Core[accesslog.ClientUsername] = username
|
||||
}
|
||||
|
||||
if d.headerField != "" {
|
||||
req.Header[d.headerField] = []string{username}
|
||||
}
|
||||
|
||||
if d.removeHeader {
|
||||
logger.Debug("Removing the Authorization header")
|
||||
req.Header.Del(authorizationHeader)
|
||||
}
|
||||
d.next.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *digestAuth) secretDigest(user, realm string) string {
|
||||
if secret, ok := d.users[user+":"+realm]; ok {
|
||||
return secret
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func digestUserParser(user string) (string, string, error) {
|
||||
split := strings.Split(user, ":")
|
||||
if len(split) != 3 {
|
||||
return "", "", fmt.Errorf("error parsing DigestUser: %v", user)
|
||||
}
|
||||
return split[0] + ":" + split[1], split[2], nil
|
||||
}
|
141
middlewares/auth/digest_auth_request_test.go
Normal file
141
middlewares/auth/digest_auth_request_test.go
Normal file
|
@ -0,0 +1,141 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
algorithm = "algorithm"
|
||||
authorization = "Authorization"
|
||||
nonce = "nonce"
|
||||
opaque = "opaque"
|
||||
qop = "qop"
|
||||
realm = "realm"
|
||||
wwwAuthenticate = "Www-Authenticate"
|
||||
)
|
||||
|
||||
// DigestRequest is a client for digest authentication requests
|
||||
type digestRequest struct {
|
||||
client *http.Client
|
||||
username, password string
|
||||
nonceCount nonceCount
|
||||
}
|
||||
|
||||
type nonceCount int
|
||||
|
||||
func (nc nonceCount) String() string {
|
||||
return fmt.Sprintf("%08x", int(nc))
|
||||
}
|
||||
|
||||
var wanted = []string{algorithm, nonce, opaque, qop, realm}
|
||||
|
||||
// New makes a DigestRequest instance
|
||||
func newDigestRequest(username, password string, client *http.Client) *digestRequest {
|
||||
return &digestRequest{
|
||||
client: client,
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
}
|
||||
|
||||
// Do does requests as http.Do does
|
||||
func (r *digestRequest) Do(req *http.Request) (*http.Response, error) {
|
||||
parts, err := r.makeParts(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parts != nil {
|
||||
req.Header.Set(authorization, r.makeAuthorization(req, parts))
|
||||
}
|
||||
|
||||
return r.client.Do(req)
|
||||
}
|
||||
|
||||
func (r *digestRequest) makeParts(req *http.Request) (map[string]string, error) {
|
||||
authReq, err := http.NewRequest(req.Method, req.URL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := r.client.Do(authReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(resp.Header[wwwAuthenticate]) == 0 {
|
||||
return nil, fmt.Errorf("headers do not have %s", wwwAuthenticate)
|
||||
}
|
||||
|
||||
headers := strings.Split(resp.Header[wwwAuthenticate][0], ",")
|
||||
parts := make(map[string]string, len(wanted))
|
||||
for _, r := range headers {
|
||||
for _, w := range wanted {
|
||||
if strings.Contains(r, w) {
|
||||
parts[w] = strings.Split(r, `"`)[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) != len(wanted) {
|
||||
return nil, fmt.Errorf("header is invalid: %+v", parts)
|
||||
}
|
||||
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
func getMD5(texts []string) string {
|
||||
h := md5.New()
|
||||
_, _ = io.WriteString(h, strings.Join(texts, ":"))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func (r *digestRequest) getNonceCount() string {
|
||||
r.nonceCount++
|
||||
return r.nonceCount.String()
|
||||
}
|
||||
|
||||
func (r *digestRequest) makeAuthorization(req *http.Request, parts map[string]string) string {
|
||||
ha1 := getMD5([]string{r.username, parts[realm], r.password})
|
||||
ha2 := getMD5([]string{req.Method, req.URL.String()})
|
||||
cnonce := generateRandom(16)
|
||||
nc := r.getNonceCount()
|
||||
response := getMD5([]string{
|
||||
ha1,
|
||||
parts[nonce],
|
||||
nc,
|
||||
cnonce,
|
||||
parts[qop],
|
||||
ha2,
|
||||
})
|
||||
return fmt.Sprintf(
|
||||
`Digest username="%s", realm="%s", nonce="%s", uri="%s", algorithm=%s, qop=%s, nc=%s, cnonce="%s", response="%s", opaque="%s"`,
|
||||
r.username,
|
||||
parts[realm],
|
||||
parts[nonce],
|
||||
req.URL.String(),
|
||||
parts[algorithm],
|
||||
parts[qop],
|
||||
nc,
|
||||
cnonce,
|
||||
response,
|
||||
parts[opaque],
|
||||
)
|
||||
}
|
||||
|
||||
// GenerateRandom generates random string
|
||||
func generateRandom(n int) string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = io.ReadFull(rand.Reader, b)
|
||||
return fmt.Sprintf("%x", b)[:n]
|
||||
}
|
156
middlewares/auth/digest_auth_test.go
Normal file
156
middlewares/auth/digest_auth_test.go
Normal file
|
@ -0,0 +1,156 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDigestAuthError(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.DigestAuth{
|
||||
Users: []string{"test"},
|
||||
}
|
||||
_, err := NewDigest(context.Background(), next, auth, "authName")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDigestAuthFail(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.DigestAuth{
|
||||
Users: []string{"test:traefik:a2688e031edb4be6a3797f3882655c05"},
|
||||
}
|
||||
authMiddleware, err := NewDigest(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authMiddleware, "this should not be nil")
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
client := http.DefaultClient
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestDigestAuthUsersFromFile(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
userFileContent string
|
||||
expectedUsers map[string]string
|
||||
givenUsers []string
|
||||
realm string
|
||||
}{
|
||||
{
|
||||
desc: "Finds the users in the file",
|
||||
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
|
||||
givenUsers: []string{},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
|
||||
},
|
||||
{
|
||||
desc: "Merges given users with users from the file",
|
||||
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\n",
|
||||
givenUsers: []string{"test2:traefik:518845800f9e2bfb1f1f740ec24f074e", "test3:traefik:c8e9f57ce58ecb4424407f665a91646c"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
|
||||
},
|
||||
{
|
||||
desc: "Given users have priority over users in the file",
|
||||
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
|
||||
givenUsers: []string{"test2:traefik:8de60a1c52da68ccf41f0c0ffb7c51a0"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
|
||||
},
|
||||
{
|
||||
desc: "Should authenticate the correct user based on the realm",
|
||||
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest:traefikee:316a669c158c8b7ab1048b03961a7aa5\n",
|
||||
givenUsers: []string{},
|
||||
expectedUsers: map[string]string{"test": "test2"},
|
||||
realm: "traefikee",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Creates the temporary configuration file with the users
|
||||
usersFile, err := ioutil.TempFile("", "auth-users")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(usersFile.Name())
|
||||
|
||||
_, err = usersFile.Write([]byte(test.userFileContent))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Creates the configuration for our Authenticator
|
||||
authenticatorConfiguration := config.DigestAuth{
|
||||
Users: test.givenUsers,
|
||||
UsersFile: usersFile.Name(),
|
||||
Realm: test.realm,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
authenticator, err := NewDigest(context.Background(), next, authenticatorConfiguration, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authenticator)
|
||||
defer ts.Close()
|
||||
|
||||
for userName, userPwd := range test.expectedUsers {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
digestRequest := newDigestRequest(userName, userPwd, http.DefaultClient)
|
||||
|
||||
var res *http.Response
|
||||
res, err = digestRequest.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
|
||||
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
// Checks that user foo doesn't work
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
digestRequest := newDigestRequest("foo", "foo", http.DefaultClient)
|
||||
|
||||
var res *http.Response
|
||||
res, err = digestRequest.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotContains(t, "traefik", string(body))
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,25 +1,68 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
xForwardedURI = "X-Forwarded-Uri"
|
||||
xForwardedMethod = "X-Forwarded-Method"
|
||||
xForwardedURI = "X-Forwarded-Uri"
|
||||
xForwardedMethod = "X-Forwarded-Method"
|
||||
forwardedTypeName = "ForwardedAuthType"
|
||||
)
|
||||
|
||||
// Forward the authentication to a external server
|
||||
func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
type forwardAuth struct {
|
||||
address string
|
||||
authResponseHeaders []string
|
||||
next http.Handler
|
||||
name string
|
||||
tlsConfig *tls.Config
|
||||
trustForwardHeader bool
|
||||
}
|
||||
|
||||
// NewForward creates a forward auth middleware.
|
||||
func NewForward(ctx context.Context, next http.Handler, config config.ForwardAuth, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, forwardedTypeName).Debug("Creating middleware")
|
||||
|
||||
fa := &forwardAuth{
|
||||
address: config.Address,
|
||||
authResponseHeaders: config.AuthResponseHeaders,
|
||||
next: next,
|
||||
name: name,
|
||||
trustForwardHeader: config.TrustForwardHeader,
|
||||
}
|
||||
|
||||
if config.TLS != nil {
|
||||
tlsConfig, err := config.TLS.CreateTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fa.tlsConfig = tlsConfig
|
||||
}
|
||||
|
||||
return fa, nil
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return fa.name, ext.SpanKindRPCClientEnum
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), fa.name, forwardedTypeName)
|
||||
|
||||
// Ensure our request client does not follow redirects
|
||||
httpClient := http.Client{
|
||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||
|
@ -27,42 +70,44 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
|
|||
},
|
||||
}
|
||||
|
||||
if config.TLS != nil {
|
||||
tlsConfig, err := config.TLS.CreateTLSConfig()
|
||||
if err != nil {
|
||||
tracing.SetErrorAndDebugLog(r, "Unable to configure TLS to call %s. Cause %s", config.Address, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if fa.tlsConfig != nil {
|
||||
httpClient.Transport = &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
TLSClientConfig: fa.tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
forwardReq, err := http.NewRequest(http.MethodGet, config.Address, http.NoBody)
|
||||
tracing.LogRequest(tracing.GetSpan(r), forwardReq)
|
||||
forwardReq, err := http.NewRequest(http.MethodGet, fa.address, nil)
|
||||
tracing.LogRequest(tracing.GetSpan(req), forwardReq)
|
||||
if err != nil {
|
||||
tracing.SetErrorAndDebugLog(r, "Error calling %s. Cause %s", config.Address, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logMessage := fmt.Sprintf("Error calling %s. Cause %s", fa.address, err)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeHeader(r, forwardReq, config.TrustForwardHeader)
|
||||
writeHeader(req, forwardReq, fa.trustForwardHeader)
|
||||
|
||||
tracing.InjectRequestHeaders(forwardReq)
|
||||
|
||||
forwardResponse, forwardErr := httpClient.Do(forwardReq)
|
||||
if forwardErr != nil {
|
||||
tracing.SetErrorAndDebugLog(r, "Error calling %s. Cause: %s", config.Address, forwardErr)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logMessage := fmt.Sprintf("Error calling %s. Cause: %s", fa.address, forwardErr)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
body, readError := ioutil.ReadAll(forwardResponse.Body)
|
||||
if readError != nil {
|
||||
tracing.SetErrorAndDebugLog(r, "Error reading body %s. Cause: %s", config.Address, readError)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logMessage := fmt.Sprintf("Error reading body %s. Cause: %s", fa.address, readError)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
defer forwardResponse.Body.Close()
|
||||
|
@ -70,40 +115,43 @@ func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next
|
|||
// Pass the forward response's body and selected headers if it
|
||||
// didn't return a response within the range of [200, 300).
|
||||
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
|
||||
log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode)
|
||||
logger.Debugf("Remote error %s. StatusCode: %d", fa.address, forwardResponse.StatusCode)
|
||||
|
||||
utils.CopyHeaders(w.Header(), forwardResponse.Header)
|
||||
utils.RemoveHeaders(w.Header(), forward.HopHeaders...)
|
||||
utils.CopyHeaders(rw.Header(), forwardResponse.Header)
|
||||
utils.RemoveHeaders(rw.Header(), forward.HopHeaders...)
|
||||
|
||||
// Grab the location header, if any.
|
||||
redirectURL, err := forwardResponse.Location()
|
||||
|
||||
if err != nil {
|
||||
if err != http.ErrNoLocation {
|
||||
tracing.SetErrorAndDebugLog(r, "Error reading response location header %s. Cause: %s", config.Address, err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
logMessage := fmt.Sprintf("Error reading response location header %s. Cause: %s", fa.address, err)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else if redirectURL.String() != "" {
|
||||
// Set the location in our response if one was sent back.
|
||||
w.Header().Set("Location", redirectURL.String())
|
||||
rw.Header().Set("Location", redirectURL.String())
|
||||
}
|
||||
|
||||
tracing.LogResponseCode(tracing.GetSpan(r), forwardResponse.StatusCode)
|
||||
w.WriteHeader(forwardResponse.StatusCode)
|
||||
tracing.LogResponseCode(tracing.GetSpan(req), forwardResponse.StatusCode)
|
||||
rw.WriteHeader(forwardResponse.StatusCode)
|
||||
|
||||
if _, err = w.Write(body); err != nil {
|
||||
log.Error(err)
|
||||
if _, err = rw.Write(body); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, headerName := range config.AuthResponseHeaders {
|
||||
r.Header.Set(headerName, forwardResponse.Header.Get(headerName))
|
||||
for _, headerName := range fa.authResponseHeaders {
|
||||
req.Header.Set(headerName, forwardResponse.Header.Get(headerName))
|
||||
}
|
||||
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
next(w, r)
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
fa.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
|
||||
|
|
|
@ -1,50 +1,49 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
)
|
||||
|
||||
func TestForwardAuthFail(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: server.URL,
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
middleware, err := NewForward(context.Background(), next, config.ForwardAuth{
|
||||
Address: server.URL,
|
||||
}, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Forbidden\n", string(body))
|
||||
}
|
||||
|
||||
func TestForwardAuthSuccess(t *testing.T) {
|
||||
|
@ -55,32 +54,32 @@ func TestForwardAuthSuccess(t *testing.T) {
|
|||
}))
|
||||
defer server.Close()
|
||||
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: server.URL,
|
||||
AuthResponseHeaders: []string{"X-Auth-User"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
|
||||
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
|
||||
auth := config.ForwardAuth{
|
||||
Address: server.URL,
|
||||
AuthResponseHeaders: []string{"X-Auth-User"},
|
||||
}
|
||||
middleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestForwardAuthRedirect(t *testing.T) {
|
||||
|
@ -89,19 +88,17 @@ func TestForwardAuthRedirect(t *testing.T) {
|
|||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: authTs.URL,
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
|
||||
auth := config.ForwardAuth{
|
||||
Address: authTs.URL,
|
||||
}
|
||||
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{
|
||||
|
@ -109,18 +106,23 @@ func TestForwardAuthRedirect(t *testing.T) {
|
|||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
|
||||
res, err := client.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||
|
||||
location, err := res.Location()
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://example.com/redirect-test", location.String())
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.NotEmpty(t, string(body), "there should be something in the body")
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, string(body))
|
||||
}
|
||||
|
||||
func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
|
||||
|
@ -138,19 +140,17 @@ func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
|
|||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: authTs.URL,
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
auth := config.ForwardAuth{
|
||||
Address: authTs.URL,
|
||||
}
|
||||
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{
|
||||
|
@ -185,30 +185,28 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) {
|
|||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: authTs.URL,
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
|
||||
auth := config.ForwardAuth{
|
||||
Address: authTs.URL,
|
||||
}
|
||||
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
client := &http.Client{}
|
||||
res, err := client.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
|
||||
require.Len(t, res.Cookies(), 1)
|
||||
for _, cookie := range res.Cookies() {
|
||||
assert.Equal(t, "testing", cookie.Value, "they should be equal")
|
||||
assert.Equal(t, "testing", cookie.Value)
|
||||
}
|
||||
|
||||
expectedHeaders := http.Header{
|
||||
|
@ -225,8 +223,11 @@ func TestForwardAuthFailResponseHeaders(t *testing.T) {
|
|||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Forbidden\n", string(body))
|
||||
}
|
||||
|
||||
func Test_writeHeader(t *testing.T) {
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/types"
|
||||
)
|
||||
|
||||
func parserBasicUsers(basic *types.Basic) (map[string]string, error) {
|
||||
var userStrs []string
|
||||
if basic.UsersFile != "" {
|
||||
var err error
|
||||
if userStrs, err = getLinesFromFile(basic.UsersFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
userStrs = append(basic.Users, userStrs...)
|
||||
userMap := make(map[string]string)
|
||||
for _, user := range userStrs {
|
||||
split := strings.Split(user, ":")
|
||||
if len(split) != 2 {
|
||||
return nil, fmt.Errorf("error parsing Authenticator user: %v", user)
|
||||
}
|
||||
userMap[split[0]] = split[1]
|
||||
}
|
||||
return userMap, nil
|
||||
}
|
||||
|
||||
func parserDigestUsers(digest *types.Digest) (map[string]string, error) {
|
||||
var userStrs []string
|
||||
if digest.UsersFile != "" {
|
||||
var err error
|
||||
if userStrs, err = getLinesFromFile(digest.UsersFile); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
userStrs = append(digest.Users, userStrs...)
|
||||
userMap := make(map[string]string)
|
||||
for _, user := range userStrs {
|
||||
split := strings.Split(user, ":")
|
||||
if len(split) != 3 {
|
||||
return nil, fmt.Errorf("error parsing Authenticator user: %v", user)
|
||||
}
|
||||
userMap[split[0]+":"+split[1]] = split[2]
|
||||
}
|
||||
return userMap, nil
|
||||
}
|
54
middlewares/buffering/buffering.go
Normal file
54
middlewares/buffering/buffering.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package buffering
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
oxybuffer "github.com/vulcand/oxy/buffer"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "Buffer"
|
||||
)
|
||||
|
||||
type buffer struct {
|
||||
name string
|
||||
buffer *oxybuffer.Buffer
|
||||
}
|
||||
|
||||
// New creates a buffering middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.Buffering, name string) (http.Handler, error) {
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
logger.Debug("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'",
|
||||
config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes, config.MaxResponseBodyBytes, config.RetryExpression)
|
||||
|
||||
oxyBuffer, err := oxybuffer.New(
|
||||
next,
|
||||
oxybuffer.MemRequestBodyBytes(config.MemRequestBodyBytes),
|
||||
oxybuffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes),
|
||||
oxybuffer.MemResponseBodyBytes(config.MemResponseBodyBytes),
|
||||
oxybuffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes),
|
||||
oxybuffer.CondSetter(len(config.RetryExpression) > 0, oxybuffer.Retry(config.RetryExpression)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &buffer{
|
||||
name: name,
|
||||
buffer: oxyBuffer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *buffer) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return b.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (b *buffer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
b.buffer.ServeHTTP(rw, req)
|
||||
}
|
|
@ -1,40 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/vulcand/oxy/cbreaker"
|
||||
)
|
||||
|
||||
// CircuitBreaker holds the oxy circuit breaker.
|
||||
type CircuitBreaker struct {
|
||||
circuitBreaker *cbreaker.CircuitBreaker
|
||||
}
|
||||
|
||||
// NewCircuitBreaker returns a new CircuitBreaker.
|
||||
func NewCircuitBreaker(next http.Handler, expression string, options ...cbreaker.CircuitBreakerOption) (*CircuitBreaker, error) {
|
||||
circuitBreaker, err := cbreaker.New(next, expression, options...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &CircuitBreaker{circuitBreaker}, nil
|
||||
}
|
||||
|
||||
// NewCircuitBreakerOptions returns a new CircuitBreakerOption
|
||||
func NewCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
|
||||
return cbreaker.Fallback(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tracing.LogEventf(r, "blocked by circuit-breaker (%q)", expression)
|
||||
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
|
||||
if _, err := w.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
cb.circuitBreaker.ServeHTTP(rw, r)
|
||||
}
|
30
middlewares/chain/chain.go
Normal file
30
middlewares/chain/chain.go
Normal file
|
@ -0,0 +1,30 @@
|
|||
package chain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "Chain"
|
||||
)
|
||||
|
||||
type chainBuilder interface {
|
||||
BuildChain(ctx context.Context, middlewares []string) (*alice.Chain, error)
|
||||
}
|
||||
|
||||
// New creates a chain middleware
|
||||
func New(ctx context.Context, next http.Handler, config config.Chain, builder chainBuilder, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
middlewareChain, err := builder.BuildChain(ctx, config.Middlewares)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return middlewareChain.Then(next)
|
||||
}
|
61
middlewares/circuitbreaker/circuit_breaker.go
Normal file
61
middlewares/circuitbreaker/circuit_breaker.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package circuitbreaker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/cbreaker"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "CircuitBreaker"
|
||||
)
|
||||
|
||||
type circuitBreaker struct {
|
||||
circuitBreaker *cbreaker.CircuitBreaker
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a new circuit breaker middleware.
|
||||
func New(ctx context.Context, next http.Handler, confCircuitBreaker config.CircuitBreaker, name string) (http.Handler, error) {
|
||||
expression := confCircuitBreaker.Expression
|
||||
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
logger.Debug("Setting up with expression: %s", expression)
|
||||
|
||||
oxyCircuitBreaker, err := cbreaker.New(next, expression, createCircuitBreakerOptions(expression))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &circuitBreaker{
|
||||
circuitBreaker: oxyCircuitBreaker,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewCircuitBreakerOptions returns a new CircuitBreakerOption
|
||||
func createCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
|
||||
return cbreaker.Fallback(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
tracing.SetErrorWithEvent(req, "blocked by circuit-breaker (%q)", expression)
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
|
||||
if _, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
|
||||
log.FromContext(req.Context()).Error(err)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func (c *circuitBreaker) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return c.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (c *circuitBreaker) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
middlewares.GetLogger(req.Context(), c.name, typeName).Debug("Entering middleware")
|
||||
c.circuitBreaker.ServeHTTP(rw, req)
|
||||
}
|
|
@ -1,33 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/containous/traefik/log"
|
||||
)
|
||||
|
||||
// Compress is a middleware that allows to compress the response
|
||||
type Compress struct{}
|
||||
|
||||
// ServeHTTP is a function used by Negroni
|
||||
func (c *Compress) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/grpc") {
|
||||
next.ServeHTTP(rw, r)
|
||||
} else {
|
||||
gzipHandler(next).ServeHTTP(rw, r)
|
||||
}
|
||||
}
|
||||
|
||||
func gzipHandler(h http.Handler) http.Handler {
|
||||
wrapper, err := gziphandler.GzipHandlerWithOpts(
|
||||
gziphandler.CompressionLevel(gzip.DefaultCompression),
|
||||
gziphandler.MinSize(gziphandler.DefaultMinSize))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
return wrapper(h)
|
||||
}
|
58
middlewares/compress/compress.go
Normal file
58
middlewares/compress/compress.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package compress
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "Compress"
|
||||
)
|
||||
|
||||
// Compress is a middleware that allows to compress the response.
|
||||
type compress struct {
|
||||
next http.Handler
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a new compress middleware.
|
||||
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
return &compress{
|
||||
next: next,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
contentType := req.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/grpc") {
|
||||
c.next.ServeHTTP(rw, req)
|
||||
} else {
|
||||
gzipHandler(c.next, middlewares.GetLogger(req.Context(), c.name, typeName)).ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *compress) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return c.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func gzipHandler(h http.Handler, logger logrus.FieldLogger) http.Handler {
|
||||
wrapper, err := gziphandler.GzipHandlerWithOpts(
|
||||
gziphandler.CompressionLevel(gzip.DefaultCompression),
|
||||
gziphandler.MinSize(gziphandler.DefaultMinSize))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
return wrapper(h)
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package middlewares
|
||||
package compress
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
|
@ -10,7 +10,6 @@ import (
|
|||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -22,19 +21,19 @@ const (
|
|||
)
|
||||
|
||||
func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
|
||||
handler := &Compress{}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
baseBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := func(rw http.ResponseWriter, r *http.Request) {
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write(baseBody)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
handler := &compress{next: next}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req, next)
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
|
||||
|
@ -45,20 +44,22 @@ func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
|
||||
handler := &Compress{}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
fakeCompressedBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := func(rw http.ResponseWriter, r *http.Request) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.Write(fakeCompressedBody)
|
||||
}
|
||||
_, err := rw.Write(fakeCompressedBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
handler := &compress{next: next}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req, next)
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
|
||||
|
@ -67,36 +68,40 @@ func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShouldNotCompressWhenNoAcceptEncodingHeader(t *testing.T) {
|
||||
handler := &Compress{}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
|
||||
fakeBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write(fakeBody)
|
||||
}
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write(fakeBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
handler := &compress{next: next}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req, next)
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
|
||||
assert.EqualValues(t, rw.Body.Bytes(), fakeBody)
|
||||
}
|
||||
|
||||
func TestShouldNotCompressWhenGRPC(t *testing.T) {
|
||||
handler := &Compress{}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
req.Header.Add(contentTypeHeader, "application/grpc")
|
||||
|
||||
baseBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write(baseBody)
|
||||
}
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write(baseBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
handler := &compress{next: next}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req, next)
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
assert.Empty(t, rw.Header().Get(acceptEncodingHeader))
|
||||
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
|
||||
|
@ -105,40 +110,43 @@ func TestShouldNotCompressWhenGRPC(t *testing.T) {
|
|||
|
||||
func TestIntegrationShouldNotCompress(t *testing.T) {
|
||||
fakeCompressedBody := generateBytes(100000)
|
||||
comp := &Compress{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
handler func(rw http.ResponseWriter, r *http.Request)
|
||||
handler http.Handler
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "when content already compressed",
|
||||
handler: func(rw http.ResponseWriter, r *http.Request) {
|
||||
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.Write(fakeCompressedBody)
|
||||
},
|
||||
_, err := rw.Write(fakeCompressedBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}),
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "when content already compressed and status code Created",
|
||||
handler: func(rw http.ResponseWriter, r *http.Request) {
|
||||
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
rw.Write(fakeCompressedBody)
|
||||
},
|
||||
_, err := rw.Write(fakeCompressedBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}),
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
negro := negroni.New(comp)
|
||||
negro.UseHandlerFunc(test.handler)
|
||||
ts := httptest.NewServer(negro)
|
||||
compress := &compress{next: test.handler}
|
||||
ts := httptest.NewServer(compress)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
|
@ -160,16 +168,18 @@ func TestIntegrationShouldNotCompress(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestShouldWriteHeaderWhenFlush(t *testing.T) {
|
||||
comp := &Compress{}
|
||||
negro := negroni.New(comp)
|
||||
negro.UseHandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
rw.(http.Flusher).Flush()
|
||||
rw.Write([]byte("short"))
|
||||
_, err := rw.Write([]byte("short"))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
ts := httptest.NewServer(negro)
|
||||
handler := &compress{next: next}
|
||||
ts := httptest.NewServer(handler)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
|
@ -189,34 +199,36 @@ func TestIntegrationShouldCompress(t *testing.T) {
|
|||
|
||||
testCases := []struct {
|
||||
name string
|
||||
handler func(rw http.ResponseWriter, r *http.Request)
|
||||
handler http.Handler
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "when AcceptEncoding header is present",
|
||||
handler: func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write(fakeBody)
|
||||
},
|
||||
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write(fakeBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}),
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "when AcceptEncoding header is present and status code Created",
|
||||
handler: func(rw http.ResponseWriter, r *http.Request) {
|
||||
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
rw.Write(fakeBody)
|
||||
},
|
||||
_, err := rw.Write(fakeBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}),
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
comp := &Compress{}
|
||||
|
||||
negro := negroni.New(comp)
|
||||
negro.UseHandlerFunc(test.handler)
|
||||
ts := httptest.NewServer(negro)
|
||||
compress := &compress{next: test.handler}
|
||||
ts := httptest.NewServer(compress)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
|
@ -1,9 +1,9 @@
|
|||
package errorpages
|
||||
package customerrors
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
@ -11,110 +11,120 @@ import (
|
|||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// Compile time validation that the response recorder implements http interfaces correctly.
|
||||
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{}
|
||||
|
||||
// Handler is a middleware that provides the custom error pages
|
||||
type Handler struct {
|
||||
BackendName string
|
||||
backendHandler http.Handler
|
||||
httpCodeRanges types.HTTPCodeRanges
|
||||
backendURL string
|
||||
backendQuery string
|
||||
FallbackURL string // Deprecated
|
||||
const (
|
||||
typeName = "customError"
|
||||
backendURL = "http://0.0.0.0"
|
||||
)
|
||||
|
||||
type serviceBuilder interface {
|
||||
Build(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
||||
}
|
||||
|
||||
// NewHandler initializes the utils.ErrorHandler for the custom error pages
|
||||
func NewHandler(errorPage *types.ErrorPage, backendName string) (*Handler, error) {
|
||||
if len(backendName) == 0 {
|
||||
return nil, errors.New("error pages: backend name is mandatory ")
|
||||
}
|
||||
// customErrors is a middleware that provides the custom error pages..
|
||||
type customErrors struct {
|
||||
name string
|
||||
next http.Handler
|
||||
backendHandler http.Handler
|
||||
httpCodeRanges types.HTTPCodeRanges
|
||||
backendQuery string
|
||||
}
|
||||
|
||||
httpCodeRanges, err := types.NewHTTPCodeRanges(errorPage.Status)
|
||||
// New creates a new custom error pages middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.ErrorPage, serviceBuilder serviceBuilder, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
httpCodeRanges, err := types.NewHTTPCodeRanges(config.Status)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
BackendName: backendName,
|
||||
backend, err := serviceBuilder.Build(ctx, config.Service, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &customErrors{
|
||||
name: name,
|
||||
next: next,
|
||||
backendHandler: backend,
|
||||
httpCodeRanges: httpCodeRanges,
|
||||
backendQuery: errorPage.Query,
|
||||
backendURL: "http://0.0.0.0",
|
||||
backendQuery: config.Query,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PostLoad adds backend handler if available
|
||||
func (h *Handler) PostLoad(backendHandler http.Handler) error {
|
||||
if backendHandler == nil {
|
||||
fwd, err := forward.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.backendHandler = fwd
|
||||
h.backendURL = h.FallbackURL
|
||||
} else {
|
||||
h.backendHandler = backendHandler
|
||||
}
|
||||
|
||||
return nil
|
||||
func (c *customErrors) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return c.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
if h.backendHandler == nil {
|
||||
log.Error("Error pages: no backend handler.")
|
||||
next.ServeHTTP(w, req)
|
||||
func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), c.name, typeName)
|
||||
|
||||
if c.backendHandler == nil {
|
||||
logger.Error("Error pages: no backend handler.")
|
||||
tracing.SetErrorWithEvent(req, "Error pages: no backend handler.")
|
||||
c.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
recorder := newResponseRecorder(w)
|
||||
next.ServeHTTP(recorder, req)
|
||||
recorder := newResponseRecorder(rw, middlewares.GetLogger(context.Background(), "test", typeName))
|
||||
c.next.ServeHTTP(recorder, req)
|
||||
|
||||
// check the recorder code against the configured http status code ranges
|
||||
for _, block := range h.httpCodeRanges {
|
||||
for _, block := range c.httpCodeRanges {
|
||||
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
|
||||
log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
|
||||
logger.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
|
||||
|
||||
var query string
|
||||
if len(h.backendQuery) > 0 {
|
||||
query = "/" + strings.TrimPrefix(h.backendQuery, "/")
|
||||
if len(c.backendQuery) > 0 {
|
||||
query = "/" + strings.TrimPrefix(c.backendQuery, "/")
|
||||
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
|
||||
}
|
||||
|
||||
pageReq, err := newRequest(h.backendURL + query)
|
||||
pageReq, err := newRequest(backendURL + query)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(recorder.GetCode())
|
||||
fmt.Fprint(w, http.StatusText(recorder.GetCode()))
|
||||
logger.Error(err)
|
||||
rw.WriteHeader(recorder.GetCode())
|
||||
_, err = fmt.Fprint(rw, http.StatusText(recorder.GetCode()))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
recorderErrorPage := newResponseRecorder(w)
|
||||
recorderErrorPage := newResponseRecorder(rw, middlewares.GetLogger(context.Background(), "test", typeName))
|
||||
utils.CopyHeaders(pageReq.Header, req.Header)
|
||||
|
||||
h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
|
||||
c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
|
||||
|
||||
utils.CopyHeaders(w.Header(), recorderErrorPage.Header())
|
||||
w.WriteHeader(recorder.GetCode())
|
||||
utils.CopyHeaders(rw.Header(), recorderErrorPage.Header())
|
||||
rw.WriteHeader(recorder.GetCode())
|
||||
|
||||
if _, err = w.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
|
||||
log.Error(err)
|
||||
if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// did not catch a configured status code so proceed with the request
|
||||
utils.CopyHeaders(w.Header(), recorder.Header())
|
||||
w.WriteHeader(recorder.GetCode())
|
||||
w.Write(recorder.GetBody().Bytes())
|
||||
utils.CopyHeaders(rw.Header(), recorder.Header())
|
||||
rw.WriteHeader(recorder.GetCode())
|
||||
_, err := rw.Write(recorder.GetBody().Bytes())
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func newRequest(baseURL string) (*http.Request, error) {
|
||||
|
@ -123,7 +133,7 @@ func newRequest(baseURL string) (*http.Request, error) {
|
|||
return nil, fmt.Errorf("error pages: error when parse URL: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody)
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error pages: error when create query: %v", err)
|
||||
}
|
||||
|
@ -141,12 +151,13 @@ type responseRecorder interface {
|
|||
}
|
||||
|
||||
// newResponseRecorder returns an initialized responseRecorder.
|
||||
func newResponseRecorder(rw http.ResponseWriter) responseRecorder {
|
||||
func newResponseRecorder(rw http.ResponseWriter, logger logrus.FieldLogger) responseRecorder {
|
||||
recorder := &responseRecorderWithoutCloseNotify{
|
||||
HeaderMap: make(http.Header),
|
||||
Body: new(bytes.Buffer),
|
||||
Code: http.StatusOK,
|
||||
responseWriter: rw,
|
||||
logger: logger,
|
||||
}
|
||||
if _, ok := rw.(http.CloseNotifier); ok {
|
||||
return &responseRecorderWithCloseNotify{recorder}
|
||||
|
@ -164,6 +175,7 @@ type responseRecorderWithoutCloseNotify struct {
|
|||
responseWriter http.ResponseWriter
|
||||
err error
|
||||
streamingResponseStarted bool
|
||||
logger logrus.FieldLogger
|
||||
}
|
||||
|
||||
type responseRecorderWithCloseNotify struct {
|
||||
|
@ -225,7 +237,7 @@ func (r *responseRecorderWithoutCloseNotify) Flush() {
|
|||
|
||||
_, err := r.responseWriter.Write(r.Body.Bytes())
|
||||
if err != nil {
|
||||
log.Errorf("Error writing response in responseRecorder: %v", err)
|
||||
r.logger.Errorf("Error writing response in responseRecorder: %v", err)
|
||||
r.err = err
|
||||
}
|
||||
r.Body.Reset()
|
176
middlewares/customerrors/custom_errors_test.go
Normal file
176
middlewares/customerrors/custom_errors_test.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
package customerrors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
errorPage *config.ErrorPage
|
||||
backendCode int
|
||||
backendErrorHandler http.HandlerFunc
|
||||
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
desc: "no error",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusOK,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "in the range",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusInternalServerError,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My error page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "not in the range",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusBadGateway,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
|
||||
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "query replacement",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503-503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Single code",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceBuilderMock := &mockServiceBuilder{handler: test.backendErrorHandler}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(test.backendCode)
|
||||
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
||||
})
|
||||
errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
errorPageHandler.ServeHTTP(recorder, req)
|
||||
|
||||
test.validate(t, recorder)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockServiceBuilder struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func (m *mockServiceBuilder) Build(_ context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
return m.handler, nil
|
||||
}
|
||||
|
||||
func TestNewResponseRecorder(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
rw http.ResponseWriter
|
||||
expected http.ResponseWriter
|
||||
}{
|
||||
{
|
||||
desc: "Without Close Notify",
|
||||
rw: httptest.NewRecorder(),
|
||||
expected: &responseRecorderWithoutCloseNotify{},
|
||||
},
|
||||
{
|
||||
desc: "With Close Notify",
|
||||
rw: &mockRWCloseNotify{},
|
||||
expected: &responseRecorderWithCloseNotify{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := newResponseRecorder(test.rw, middlewares.GetLogger(context.Background(), "test", typeName))
|
||||
assert.IsType(t, rec, test.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockRWCloseNotify struct{}
|
||||
|
||||
func (m *mockRWCloseNotify) CloseNotify() <-chan bool {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) Header() http.Header {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) Write([]byte) (int, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) WriteHeader(int) {
|
||||
panic("implement me")
|
||||
}
|
|
@ -1,30 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
)
|
||||
|
||||
// EmptyBackendHandler is a middlware that checks whether the current Backend
|
||||
// has at least one active Server in respect to the healthchecks and if this
|
||||
// is not the case, it will stop the middleware chain and respond with 503.
|
||||
type EmptyBackendHandler struct {
|
||||
next healthcheck.BalancerHandler
|
||||
}
|
||||
|
||||
// NewEmptyBackendHandler creates a new EmptyBackendHandler instance.
|
||||
func NewEmptyBackendHandler(lb healthcheck.BalancerHandler) *EmptyBackendHandler {
|
||||
return &EmptyBackendHandler{next: lb}
|
||||
}
|
||||
|
||||
// ServeHTTP responds with 503 when there is no active Server and otherwise
|
||||
// invokes the next handler in the middleware chain.
|
||||
func (h *EmptyBackendHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
if len(h.next.Servers()) == 0 {
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
|
||||
} else {
|
||||
h.next.ServeHTTP(rw, r)
|
||||
}
|
||||
}
|
33
middlewares/emptybackendhandler/empty_backend_handler.go
Normal file
33
middlewares/emptybackendhandler/empty_backend_handler.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package emptybackendhandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
)
|
||||
|
||||
// EmptyBackend is a middleware that checks whether the current Backend
|
||||
// has at least one active Server in respect to the healthchecks and if this
|
||||
// is not the case, it will stop the middleware chain and respond with 503.
|
||||
type emptyBackend struct {
|
||||
next healthcheck.BalancerHandler
|
||||
}
|
||||
|
||||
// New creates a new EmptyBackend middleware.
|
||||
func New(lb healthcheck.BalancerHandler) http.Handler {
|
||||
return &emptyBackend{next: lb}
|
||||
}
|
||||
|
||||
// ServeHTTP responds with 503 when there is no active Server and otherwise
|
||||
// invokes the next handler in the middleware chain.
|
||||
func (e *emptyBackend) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if len(e.next.Servers()) == 0 {
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
e.next.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package middlewares
|
||||
package emptybackendhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
@ -8,40 +8,38 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
func TestEmptyBackendHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
amountServer int
|
||||
wantStatusCode int
|
||||
testCases := []struct {
|
||||
amountServer int
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
amountServer: 0,
|
||||
wantStatusCode: http.StatusServiceUnavailable,
|
||||
amountServer: 0,
|
||||
expectedStatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
{
|
||||
amountServer: 1,
|
||||
wantStatusCode: http.StatusOK,
|
||||
amountServer: 1,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
|
||||
t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer})
|
||||
handler := New(&healthCheckLoadBalancer{amountServer: test.amountServer})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Result().StatusCode != test.wantStatusCode {
|
||||
t.Errorf("Received status code %d, wanted %d", recorder.Result().StatusCode, test.wantStatusCode)
|
||||
}
|
||||
assert.Equal(t, test.expectedStatusCode, recorder.Result().StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,384 +0,0 @@
|
|||
package errorpages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
errorPage *types.ErrorPage
|
||||
backendCode int
|
||||
backendErrorHandler http.HandlerFunc
|
||||
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
desc: "no error",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusOK,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusInternalServerError,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My error page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "not in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusBadGateway,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
|
||||
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "query replacement",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Single code",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errorPageHandler, err := NewHandler(test.errorPage, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
errorPageHandler.backendHandler = test.backendErrorHandler
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(test.backendCode)
|
||||
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
||||
})
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
|
||||
|
||||
n := negroni.New()
|
||||
n.Use(errorPageHandler)
|
||||
n.UseHandler(handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
n.ServeHTTP(recorder, req)
|
||||
|
||||
test.validate(t, recorder)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerOldWay(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
errorPage *types.ErrorPage
|
||||
backendCode int
|
||||
errorPageForwarder http.HandlerFunc
|
||||
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
desc: "no error",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusOK,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "OK")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusInternalServerError,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "My error page.")
|
||||
assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "not in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusBadGateway,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusBadGateway, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
|
||||
assert.NotContains(t, recorder.Body.String(), "My error page.", "Should return the oops page since we have not configured the 502 code")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "query replacement",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.RequestURI() == "/"+strconv.Itoa(503) {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Single code",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.RequestURI() == "/"+strconv.Itoa(503) {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errorPageHandler, err := NewHandler(test.errorPage, "test")
|
||||
require.NoError(t, err)
|
||||
errorPageHandler.FallbackURL = "http://localhost"
|
||||
|
||||
err = errorPageHandler.PostLoad(test.errorPageForwarder)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(test.backendCode)
|
||||
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
||||
})
|
||||
|
||||
n := negroni.New()
|
||||
n.Use(errorPageHandler)
|
||||
n.UseHandler(handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
n.ServeHTTP(recorder, req)
|
||||
|
||||
test.validate(t, recorder)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerOldWayIntegration(t *testing.T) {
|
||||
errorPagesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.RequestURI() == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Test Server")
|
||||
}
|
||||
}))
|
||||
defer errorPagesServer.Close()
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
errorPage *types.ErrorPage
|
||||
backendCode int
|
||||
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
desc: "no error",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusOK,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "OK")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusInternalServerError,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Test Server")
|
||||
assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "not in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusBadGateway,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusBadGateway, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
|
||||
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "query replacement",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Single code",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, errorPagesServer.URL+"/test", nil)
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
errorPageHandler, err := NewHandler(test.errorPage, "test")
|
||||
require.NoError(t, err)
|
||||
errorPageHandler.FallbackURL = errorPagesServer.URL
|
||||
|
||||
err = errorPageHandler.PostLoad(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(test.backendCode)
|
||||
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
||||
})
|
||||
|
||||
n := negroni.New()
|
||||
n.Use(errorPageHandler)
|
||||
n.UseHandler(handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
n.ServeHTTP(recorder, req)
|
||||
|
||||
test.validate(t, recorder)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResponseRecorder(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
rw http.ResponseWriter
|
||||
expected http.ResponseWriter
|
||||
}{
|
||||
{
|
||||
desc: "Without Close Notify",
|
||||
rw: httptest.NewRecorder(),
|
||||
expected: &responseRecorderWithoutCloseNotify{},
|
||||
},
|
||||
{
|
||||
desc: "With Close Notify",
|
||||
rw: &mockRWCloseNotify{},
|
||||
expected: &responseRecorderWithCloseNotify{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := newResponseRecorder(test.rw)
|
||||
|
||||
assert.IsType(t, rec, test.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockRWCloseNotify struct{}
|
||||
|
||||
func (m *mockRWCloseNotify) CloseNotify() <-chan bool {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) Header() http.Header {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) Write([]byte) (int, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) WriteHeader(int) {
|
||||
panic("implement me")
|
||||
}
|
|
@ -1,52 +0,0 @@
|
|||
package forwardedheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// XForwarded filter for XForwarded headers
|
||||
type XForwarded struct {
|
||||
insecure bool
|
||||
trustedIps []string
|
||||
ipChecker *ip.Checker
|
||||
}
|
||||
|
||||
// NewXforwarded creates a new XForwarded
|
||||
func NewXforwarded(insecure bool, trustedIps []string) (*XForwarded, error) {
|
||||
var ipChecker *ip.Checker
|
||||
if len(trustedIps) > 0 {
|
||||
var err error
|
||||
ipChecker, err = ip.NewChecker(trustedIps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &XForwarded{
|
||||
insecure: insecure,
|
||||
trustedIps: trustedIps,
|
||||
ipChecker: ipChecker,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (x *XForwarded) isTrustedIP(ip string) bool {
|
||||
if x.ipChecker == nil {
|
||||
return false
|
||||
}
|
||||
return x.ipChecker.IsAuthorized(ip) == nil
|
||||
}
|
||||
|
||||
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
|
||||
utils.RemoveHeaders(r.Header, forward.XHeaders...)
|
||||
}
|
||||
|
||||
// If there is a next, call it.
|
||||
if next != nil {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
|
@ -1,128 +0,0 @@
|
|||
package forwardedheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
insecure bool
|
||||
trustedIps []string
|
||||
incomingHeaders map[string]string
|
||||
remoteAddr string
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
desc: "all Empty",
|
||||
insecure: true,
|
||||
trustedIps: nil,
|
||||
remoteAddr: "",
|
||||
incomingHeaders: map[string]string{},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure true with incoming X-Forwarded-For",
|
||||
insecure: true,
|
||||
trustedIps: nil,
|
||||
remoteAddr: "",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For",
|
||||
insecure: false,
|
||||
trustedIps: nil,
|
||||
remoteAddr: "",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips",
|
||||
insecure: false,
|
||||
trustedIps: []string{"10.0.1.100"},
|
||||
remoteAddr: "10.0.1.100:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips",
|
||||
insecure: false,
|
||||
trustedIps: []string{"10.0.1.100"},
|
||||
remoteAddr: "10.0.1.101:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips CIDR",
|
||||
insecure: false,
|
||||
trustedIps: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "1.2.3.156:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips CIDR",
|
||||
insecure: false,
|
||||
trustedIps: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "10.0.1.101:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.RemoteAddr = test.remoteAddr
|
||||
|
||||
for k, v := range test.incomingHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
m, err := NewXforwarded(test.insecure, test.trustedIps)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.ServeHTTP(nil, req, nil)
|
||||
|
||||
for k, v := range test.expectedHeaders {
|
||||
assert.Equal(t, v, req.Header.Get(k))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -3,7 +3,6 @@ package middlewares
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/safe"
|
||||
)
|
||||
|
||||
|
@ -13,24 +12,24 @@ type HandlerSwitcher struct {
|
|||
}
|
||||
|
||||
// NewHandlerSwitcher builds a new instance of HandlerSwitcher
|
||||
func NewHandlerSwitcher(newHandler *mux.Router) (hs *HandlerSwitcher) {
|
||||
func NewHandlerSwitcher(newHandler http.Handler) (hs *HandlerSwitcher) {
|
||||
return &HandlerSwitcher{
|
||||
handler: safe.New(newHandler),
|
||||
}
|
||||
}
|
||||
|
||||
func (hs *HandlerSwitcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
handlerBackup := hs.handler.Get().(*mux.Router)
|
||||
handlerBackup.ServeHTTP(rw, r)
|
||||
func (h *HandlerSwitcher) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
handlerBackup := h.handler.Get().(http.Handler)
|
||||
handlerBackup.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
// GetHandler returns the current http.ServeMux
|
||||
func (hs *HandlerSwitcher) GetHandler() (newHandler *mux.Router) {
|
||||
handler := hs.handler.Get().(*mux.Router)
|
||||
func (h *HandlerSwitcher) GetHandler() (newHandler http.Handler) {
|
||||
handler := h.handler.Get().(http.Handler)
|
||||
return handler
|
||||
}
|
||||
|
||||
// UpdateHandler safely updates the current http.ServeMux with a new one
|
||||
func (hs *HandlerSwitcher) UpdateHandler(newHandler *mux.Router) {
|
||||
hs.handler.Set(newHandler)
|
||||
func (h *HandlerSwitcher) UpdateHandler(newHandler http.Handler) {
|
||||
h.handler.Set(newHandler)
|
||||
}
|
|
@ -1,71 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
// Middleware based on https://github.com/unrolled/secure
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/types"
|
||||
)
|
||||
|
||||
// HeaderOptions is a struct for specifying configuration options for the headers middleware.
|
||||
type HeaderOptions struct {
|
||||
// If Custom request headers are set, these will be added to the request
|
||||
CustomRequestHeaders map[string]string
|
||||
// If Custom response headers are set, these will be added to the ResponseWriter
|
||||
CustomResponseHeaders map[string]string
|
||||
}
|
||||
|
||||
// HeaderStruct is a middleware that helps setup a few basic security features. A single headerOptions struct can be
|
||||
// provided to configure which features should be enabled, and the ability to override a few of the default values.
|
||||
type HeaderStruct struct {
|
||||
// Customize headers with a headerOptions struct.
|
||||
opt HeaderOptions
|
||||
}
|
||||
|
||||
// NewHeaderFromStruct constructs a new header instance from supplied frontend header struct.
|
||||
func NewHeaderFromStruct(headers *types.Headers) *HeaderStruct {
|
||||
if headers == nil || !headers.HasCustomHeadersDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &HeaderStruct{
|
||||
opt: HeaderOptions{
|
||||
CustomRequestHeaders: headers.CustomRequestHeaders,
|
||||
CustomResponseHeaders: headers.CustomResponseHeaders,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HeaderStruct) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
s.ModifyRequestHeaders(r)
|
||||
// If there is a next, call it.
|
||||
if next != nil {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// ModifyRequestHeaders set or delete request headers
|
||||
func (s *HeaderStruct) ModifyRequestHeaders(r *http.Request) {
|
||||
// Loop through Custom request headers
|
||||
for header, value := range s.opt.CustomRequestHeaders {
|
||||
if value == "" {
|
||||
r.Header.Del(header)
|
||||
} else {
|
||||
r.Header.Set(header, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ModifyResponseHeaders set or delete response headers
|
||||
func (s *HeaderStruct) ModifyResponseHeaders(res *http.Response) error {
|
||||
// Loop through Custom response headers
|
||||
for header, value := range s.opt.CustomResponseHeaders {
|
||||
if value == "" {
|
||||
res.Header.Del(header)
|
||||
} else {
|
||||
res.Header.Set(header, value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
134
middlewares/headers/headers.go
Normal file
134
middlewares/headers/headers.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
// Package headers Middleware based on https://github.com/unrolled/secure.
|
||||
package headers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/unrolled/secure"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "Headers"
|
||||
)
|
||||
|
||||
type headers struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
// New creates a Headers middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.Headers, name string) (http.Handler, error) {
|
||||
// HeaderMiddleware -> SecureMiddleWare -> next
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
|
||||
if !config.HasSecureHeadersDefined() && !config.HasCustomHeadersDefined() {
|
||||
return nil, errors.New("headers configuration not valid")
|
||||
}
|
||||
|
||||
var handler http.Handler
|
||||
nextHandler := next
|
||||
|
||||
if config.HasSecureHeadersDefined() {
|
||||
logger.Debug("Setting up secureHeaders from %v", config)
|
||||
handler = newSecure(next, config)
|
||||
nextHandler = handler
|
||||
}
|
||||
|
||||
if config.HasCustomHeadersDefined() {
|
||||
logger.Debug("Setting up customHeaders from %v", config)
|
||||
handler = newHeader(nextHandler, config)
|
||||
}
|
||||
|
||||
return &headers{
|
||||
handler: handler,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *headers) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return h.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (h *headers) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
h.handler.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
type secureHeader struct {
|
||||
next http.Handler
|
||||
secure *secure.Secure
|
||||
}
|
||||
|
||||
// newSecure constructs a new secure instance with supplied options.
|
||||
func newSecure(next http.Handler, headers config.Headers) *secureHeader {
|
||||
opt := secure.Options{
|
||||
BrowserXssFilter: headers.BrowserXSSFilter,
|
||||
ContentTypeNosniff: headers.ContentTypeNosniff,
|
||||
ForceSTSHeader: headers.ForceSTSHeader,
|
||||
FrameDeny: headers.FrameDeny,
|
||||
IsDevelopment: headers.IsDevelopment,
|
||||
SSLRedirect: headers.SSLRedirect,
|
||||
SSLForceHost: headers.SSLForceHost,
|
||||
SSLTemporaryRedirect: headers.SSLTemporaryRedirect,
|
||||
STSIncludeSubdomains: headers.STSIncludeSubdomains,
|
||||
STSPreload: headers.STSPreload,
|
||||
ContentSecurityPolicy: headers.ContentSecurityPolicy,
|
||||
CustomBrowserXssValue: headers.CustomBrowserXSSValue,
|
||||
CustomFrameOptionsValue: headers.CustomFrameOptionsValue,
|
||||
PublicKey: headers.PublicKey,
|
||||
ReferrerPolicy: headers.ReferrerPolicy,
|
||||
SSLHost: headers.SSLHost,
|
||||
AllowedHosts: headers.AllowedHosts,
|
||||
HostsProxyHeaders: headers.HostsProxyHeaders,
|
||||
SSLProxyHeaders: headers.SSLProxyHeaders,
|
||||
STSSeconds: headers.STSSeconds,
|
||||
}
|
||||
|
||||
return &secureHeader{
|
||||
next: next,
|
||||
secure: secure.New(opt),
|
||||
}
|
||||
}
|
||||
|
||||
func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP)
|
||||
}
|
||||
|
||||
// Header is a middleware that helps setup a few basic security features. A single headerOptions struct can be
|
||||
// provided to configure which features should be enabled, and the ability to override a few of the default values.
|
||||
type header struct {
|
||||
next http.Handler
|
||||
// If Custom request headers are set, these will be added to the request
|
||||
customRequestHeaders map[string]string
|
||||
}
|
||||
|
||||
// NewHeader constructs a new header instance from supplied frontend header struct.
|
||||
func newHeader(next http.Handler, headers config.Headers) *header {
|
||||
return &header{
|
||||
next: next,
|
||||
customRequestHeaders: headers.CustomRequestHeaders,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
s.modifyRequestHeaders(req)
|
||||
s.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
// modifyRequestHeaders set or delete request headers.
|
||||
func (s *header) modifyRequestHeaders(req *http.Request) {
|
||||
// Loop through Custom request headers
|
||||
for header, value := range s.customRequestHeaders {
|
||||
if value == "" {
|
||||
req.Header.Del(header)
|
||||
} else {
|
||||
req.Header.Set(header, value)
|
||||
}
|
||||
}
|
||||
}
|
105
middlewares/headers/headers_test.go
Normal file
105
middlewares/headers/headers_test.go
Normal file
|
@ -0,0 +1,105 @@
|
|||
package headers
|
||||
|
||||
// Middleware tests based on https://github.com/unrolled/secure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomRequestHeader(t *testing.T) {
|
||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
header := newHeader(emptyHandler, config.Headers{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "test_request",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
|
||||
}
|
||||
|
||||
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
|
||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
header := newHeader(emptyHandler, config.Headers{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "test_request",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
|
||||
|
||||
header = newHeader(emptyHandler, config.Headers{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "",
|
||||
},
|
||||
})
|
||||
|
||||
header.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"))
|
||||
}
|
||||
|
||||
func TestSecureHeader(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fromHost string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "Should accept the request when given a host that is in the list",
|
||||
fromHost: "foo.com",
|
||||
expected: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "Should refuse the request when no host is given",
|
||||
fromHost: "",
|
||||
expected: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
desc: "Should refuse the request when no matching host is given",
|
||||
fromHost: "boo.com",
|
||||
expected: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
header, err := New(context.Background(), emptyHandler, config.Headers{
|
||||
AllowedHosts: []string{"foo.com", "bar.com"},
|
||||
}, "foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
req.Host = test.fromHost
|
||||
header.ServeHTTP(res, req)
|
||||
assert.Equal(t, test.expected, res.Code)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,118 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
// Middleware tests based on https://github.com/unrolled/secure
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("bar"))
|
||||
})
|
||||
|
||||
// newHeader constructs a new header instance with supplied options.
|
||||
func newHeader(options ...HeaderOptions) *HeaderStruct {
|
||||
var opt HeaderOptions
|
||||
if len(options) == 0 {
|
||||
opt = HeaderOptions{}
|
||||
} else {
|
||||
opt = options[0]
|
||||
}
|
||||
|
||||
return &HeaderStruct{opt: opt}
|
||||
}
|
||||
|
||||
func TestNoConfig(t *testing.T) {
|
||||
header := newHeader()
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req, myHandler)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "bar", res.Body.String(), "Body not the expected")
|
||||
}
|
||||
|
||||
func TestModifyResponseHeaders(t *testing.T) {
|
||||
header := newHeader(HeaderOptions{
|
||||
CustomResponseHeaders: map[string]string{
|
||||
"X-Custom-Response-Header": "test_response",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
res.HeaderMap.Add("X-Custom-Response-Header", "test_response")
|
||||
|
||||
err := header.ModifyResponseHeaders(res.Result())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "test_response", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
|
||||
|
||||
res = httptest.NewRecorder()
|
||||
res.HeaderMap.Add("X-Custom-Response-Header", "")
|
||||
|
||||
err = header.ModifyResponseHeaders(res.Result())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
|
||||
|
||||
res = httptest.NewRecorder()
|
||||
res.HeaderMap.Add("X-Custom-Response-Header", "test_override")
|
||||
|
||||
err = header.ModifyResponseHeaders(res.Result())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "test_override", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
|
||||
}
|
||||
|
||||
func TestCustomRequestHeader(t *testing.T) {
|
||||
header := newHeader(HeaderOptions{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "test_request",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req, nil)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header")
|
||||
}
|
||||
|
||||
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
|
||||
header := newHeader(HeaderOptions{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "test_request",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req, nil)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header")
|
||||
|
||||
header = newHeader(HeaderOptions{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "",
|
||||
},
|
||||
})
|
||||
|
||||
header.ServeHTTP(res, req, nil)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"), "This header is not expected")
|
||||
}
|
|
@ -1,67 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares/tracing"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// IPWhiteLister is a middleware that provides Checks of the Requesting IP against a set of Whitelists
|
||||
type IPWhiteLister struct {
|
||||
handler negroni.Handler
|
||||
whiteLister *ip.Checker
|
||||
strategy ip.Strategy
|
||||
}
|
||||
|
||||
// NewIPWhiteLister builds a new IPWhiteLister given a list of CIDR-Strings to whitelist
|
||||
func NewIPWhiteLister(whiteList []string, strategy ip.Strategy) (*IPWhiteLister, error) {
|
||||
if len(whiteList) == 0 {
|
||||
return nil, errors.New("no white list provided")
|
||||
}
|
||||
|
||||
checker, err := ip.NewChecker(whiteList)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing CIDR whitelist %s: %v", whiteList, err)
|
||||
}
|
||||
|
||||
whiteLister := IPWhiteLister{
|
||||
strategy: strategy,
|
||||
whiteLister: checker,
|
||||
}
|
||||
|
||||
whiteLister.handler = negroni.HandlerFunc(whiteLister.handle)
|
||||
log.Debugf("configured IP white list: %s", whiteList)
|
||||
|
||||
return &whiteLister, nil
|
||||
}
|
||||
|
||||
func (wl *IPWhiteLister) handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
err := wl.whiteLister.IsAuthorized(wl.strategy.GetIP(r))
|
||||
if err != nil {
|
||||
tracing.SetErrorAndDebugLog(r, "request %+v - rejecting: %v", r, err)
|
||||
reject(w)
|
||||
return
|
||||
}
|
||||
log.Debugf("Accept %s: %+v", wl.strategy.GetIP(r), r)
|
||||
tracing.SetErrorAndDebugLog(r, "request %+v matched white list %v - passing", r, wl.whiteLister)
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (wl *IPWhiteLister) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
wl.handler.ServeHTTP(rw, r, next)
|
||||
}
|
||||
|
||||
func reject(w http.ResponseWriter) {
|
||||
statusCode := http.StatusForbidden
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
_, err := w.Write([]byte(http.StatusText(statusCode)))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
85
middlewares/ipwhitelist/ip_whitelist.go
Normal file
85
middlewares/ipwhitelist/ip_whitelist.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package ipwhitelist
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "IPWhiteLister"
|
||||
)
|
||||
|
||||
// ipWhiteLister is a middleware that provides Checks of the Requesting IP against a set of Whitelists
|
||||
type ipWhiteLister struct {
|
||||
next http.Handler
|
||||
whiteLister *ip.Checker
|
||||
strategy ip.Strategy
|
||||
name string
|
||||
}
|
||||
|
||||
// New builds a new IPWhiteLister given a list of CIDR-Strings to whitelist
|
||||
func New(ctx context.Context, next http.Handler, config config.IPWhiteList, name string) (http.Handler, error) {
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
|
||||
if len(config.SourceRange) == 0 {
|
||||
return nil, errors.New("sourceRange is empty, IPWhiteLister not created")
|
||||
}
|
||||
|
||||
checker, err := ip.NewChecker(config.SourceRange)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse CIDR whitelist %s: %v", config.SourceRange, err)
|
||||
}
|
||||
|
||||
strategy, err := config.IPStrategy.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Debugf("Setting up IPWhiteLister with sourceRange: %s", config.SourceRange)
|
||||
return &ipWhiteLister{
|
||||
strategy: strategy,
|
||||
whiteLister: checker,
|
||||
next: next,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (wl *ipWhiteLister) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return wl.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (wl *ipWhiteLister) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), wl.name, typeName)
|
||||
|
||||
err := wl.whiteLister.IsAuthorized(wl.strategy.GetIP(req))
|
||||
if err != nil {
|
||||
logMessage := fmt.Sprintf("rejecting request %+v: %v", req, err)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
reject(logger, rw)
|
||||
return
|
||||
}
|
||||
logger.Debugf("Accept %s: %+v", wl.strategy.GetIP(req), req)
|
||||
|
||||
wl.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func reject(logger logrus.FieldLogger, rw http.ResponseWriter) {
|
||||
statusCode := http.StatusForbidden
|
||||
|
||||
rw.WriteHeader(statusCode)
|
||||
_, err := rw.Write([]byte(http.StatusText(statusCode)))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
|
@ -1,11 +1,12 @@
|
|||
package middlewares
|
||||
package ipwhitelist
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
@ -13,18 +14,21 @@ import (
|
|||
func TestNewIPWhiteLister(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whiteList []string
|
||||
expectedError string
|
||||
whiteList config.IPWhiteList
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
desc: "invalid IP",
|
||||
whiteList: []string{"foo"},
|
||||
expectedError: "parsing CIDR whitelist [foo]: parsing CIDR trusted IPs <nil>: invalid CIDR address: foo",
|
||||
desc: "invalid IP",
|
||||
whiteList: config.IPWhiteList{
|
||||
SourceRange: []string{"foo"},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
desc: "valid IP",
|
||||
whiteList: []string{"10.10.10.10"},
|
||||
expectedError: "",
|
||||
desc: "valid IP",
|
||||
whiteList: config.IPWhiteList{
|
||||
SourceRange: []string{"10.10.10.10"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -33,10 +37,11 @@ func TestNewIPWhiteLister(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
whiteLister, err := NewIPWhiteLister(test.whiteList, &ip.RemoteAddrStrategy{})
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
|
||||
|
||||
if len(test.expectedError) > 0 {
|
||||
assert.EqualError(t, err, test.expectedError)
|
||||
if test.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, whiteLister)
|
||||
|
@ -48,19 +53,23 @@ func TestNewIPWhiteLister(t *testing.T) {
|
|||
func TestIPWhiteLister_ServeHTTP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whiteList []string
|
||||
whiteList config.IPWhiteList
|
||||
remoteAddr string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "authorized with remote address",
|
||||
whiteList: []string{"20.20.20.20"},
|
||||
desc: "authorized with remote address",
|
||||
whiteList: config.IPWhiteList{
|
||||
SourceRange: []string{"20.20.20.20"},
|
||||
},
|
||||
remoteAddr: "20.20.20.20:1234",
|
||||
expected: 200,
|
||||
},
|
||||
{
|
||||
desc: "non authorized with remote address",
|
||||
whiteList: []string{"20.20.20.20"},
|
||||
desc: "non authorized with remote address",
|
||||
whiteList: config.IPWhiteList{
|
||||
SourceRange: []string{"20.20.20.20"},
|
||||
},
|
||||
remoteAddr: "20.20.20.21:1234",
|
||||
expected: 403,
|
||||
},
|
||||
|
@ -71,7 +80,8 @@ func TestIPWhiteLister_ServeHTTP(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
whiteLister, err := NewIPWhiteLister(test.whiteList, &ip.RemoteAddrStrategy{})
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
@ -82,9 +92,7 @@ func TestIPWhiteLister_ServeHTTP(t *testing.T) {
|
|||
req.RemoteAddr = test.remoteAddr
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
whiteLister.ServeHTTP(recorder, req, next)
|
||||
whiteLister.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, test.expected, recorder.Code)
|
||||
})
|
48
middlewares/maxconnection/max_connection.go
Normal file
48
middlewares/maxconnection/max_connection.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package maxconnection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/connlimit"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "MaxConnection"
|
||||
)
|
||||
|
||||
type maxConnection struct {
|
||||
handler http.Handler
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a max connection middleware.
|
||||
func New(ctx context.Context, next http.Handler, maxConns config.MaxConn, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
handler, err := connlimit.New(next, extractFunc, maxConns.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
return &maxConnection{handler: handler, name: name}, nil
|
||||
}
|
||||
|
||||
func (mc *maxConnection) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return mc.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (mc *maxConnection) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
mc.handler.ServeHTTP(rw, req)
|
||||
}
|
|
@ -1,131 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/metrics"
|
||||
gokitmetrics "github.com/go-kit/kit/metrics"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
const (
|
||||
protoHTTP = "http"
|
||||
protoSSE = "sse"
|
||||
protoWebsocket = "websocket"
|
||||
)
|
||||
|
||||
// NewEntryPointMetricsMiddleware creates a new metrics middleware for an Entrypoint.
|
||||
func NewEntryPointMetricsMiddleware(registry metrics.Registry, entryPointName string) negroni.Handler {
|
||||
return &metricsMiddleware{
|
||||
reqsCounter: registry.EntrypointReqsCounter(),
|
||||
reqDurationHistogram: registry.EntrypointReqDurationHistogram(),
|
||||
openConnsGauge: registry.EntrypointOpenConnsGauge(),
|
||||
baseLabels: []string{"entrypoint", entryPointName},
|
||||
}
|
||||
}
|
||||
|
||||
// NewBackendMetricsMiddleware creates a new metrics middleware for a Backend.
|
||||
func NewBackendMetricsMiddleware(registry metrics.Registry, backendName string) negroni.Handler {
|
||||
return &metricsMiddleware{
|
||||
reqsCounter: registry.BackendReqsCounter(),
|
||||
reqDurationHistogram: registry.BackendReqDurationHistogram(),
|
||||
openConnsGauge: registry.BackendOpenConnsGauge(),
|
||||
baseLabels: []string{"backend", backendName},
|
||||
}
|
||||
}
|
||||
|
||||
type metricsMiddleware struct {
|
||||
// Important: Since this int64 field is using sync/atomic, it has to be at the top of the struct due to a bug on 32-bit platform
|
||||
// See: https://golang.org/pkg/sync/atomic/ for more information
|
||||
openConns int64
|
||||
reqsCounter gokitmetrics.Counter
|
||||
reqDurationHistogram gokitmetrics.Histogram
|
||||
openConnsGauge gokitmetrics.Gauge
|
||||
baseLabels []string
|
||||
}
|
||||
|
||||
func (m *metricsMiddleware) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
labels := []string{"method", getMethod(r), "protocol", getRequestProtocol(r)}
|
||||
labels = append(labels, m.baseLabels...)
|
||||
|
||||
openConns := atomic.AddInt64(&m.openConns, 1)
|
||||
m.openConnsGauge.With(labels...).Set(float64(openConns))
|
||||
defer func(labelValues []string) {
|
||||
openConns := atomic.AddInt64(&m.openConns, -1)
|
||||
m.openConnsGauge.With(labelValues...).Set(float64(openConns))
|
||||
}(labels)
|
||||
|
||||
start := time.Now()
|
||||
recorder := &responseRecorder{rw, http.StatusOK}
|
||||
next(recorder, r)
|
||||
|
||||
labels = append(labels, "code", strconv.Itoa(recorder.statusCode))
|
||||
m.reqsCounter.With(labels...).Add(1)
|
||||
m.reqDurationHistogram.With(labels...).Observe(time.Since(start).Seconds())
|
||||
}
|
||||
|
||||
func getRequestProtocol(req *http.Request) string {
|
||||
switch {
|
||||
case isWebsocketRequest(req):
|
||||
return protoWebsocket
|
||||
case isSSERequest(req):
|
||||
return protoSSE
|
||||
default:
|
||||
return protoHTTP
|
||||
}
|
||||
}
|
||||
|
||||
// isWebsocketRequest determines if the specified HTTP request is a websocket handshake request.
|
||||
func isWebsocketRequest(req *http.Request) bool {
|
||||
return containsHeader(req, "Connection", "upgrade") && containsHeader(req, "Upgrade", "websocket")
|
||||
}
|
||||
|
||||
// isSSERequest determines if the specified HTTP request is a request for an event subscription.
|
||||
func isSSERequest(req *http.Request) bool {
|
||||
return containsHeader(req, "Accept", "text/event-stream")
|
||||
}
|
||||
|
||||
func containsHeader(req *http.Request, name, value string) bool {
|
||||
items := strings.Split(req.Header.Get(name), ",")
|
||||
for _, item := range items {
|
||||
if value == strings.ToLower(strings.TrimSpace(item)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getMethod(r *http.Request) string {
|
||||
if !utf8.ValidString(r.Method) {
|
||||
log.Warnf("Invalid HTTP method encoding: %s", r.Method)
|
||||
return "NON_UTF8_HTTP_METHOD"
|
||||
}
|
||||
return r.Method
|
||||
}
|
||||
|
||||
type retryMetrics interface {
|
||||
BackendRetriesCounter() gokitmetrics.Counter
|
||||
}
|
||||
|
||||
// NewMetricsRetryListener instantiates a MetricsRetryListener with the given retryMetrics.
|
||||
func NewMetricsRetryListener(retryMetrics retryMetrics, backendName string) RetryListener {
|
||||
return &MetricsRetryListener{retryMetrics: retryMetrics, backendName: backendName}
|
||||
}
|
||||
|
||||
// MetricsRetryListener is an implementation of the RetryListener interface to
|
||||
// record RequestMetrics about retry attempts.
|
||||
type MetricsRetryListener struct {
|
||||
retryMetrics retryMetrics
|
||||
backendName string
|
||||
}
|
||||
|
||||
// Retried tracks the retry in the RequestMetrics implementation.
|
||||
func (m *MetricsRetryListener) Retried(req *http.Request, attempt int) {
|
||||
m.retryMetrics.BackendRetriesCounter().With("backend", m.backendName).Add(1)
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/go-kit/kit/metrics"
|
||||
)
|
||||
|
||||
func TestMetricsRetryListener(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
retryMetrics := newCollectingRetryMetrics()
|
||||
retryListener := NewMetricsRetryListener(retryMetrics, "backendName")
|
||||
retryListener.Retried(req, 1)
|
||||
retryListener.Retried(req, 2)
|
||||
|
||||
wantCounterValue := float64(2)
|
||||
if retryMetrics.retriesCounter.CounterValue != wantCounterValue {
|
||||
t.Errorf("got counter value of %f, want %f", retryMetrics.retriesCounter.CounterValue, wantCounterValue)
|
||||
}
|
||||
|
||||
wantLabelValues := []string{"backend", "backendName"}
|
||||
if !reflect.DeepEqual(retryMetrics.retriesCounter.LastLabelValues, wantLabelValues) {
|
||||
t.Errorf("wrong label values %v used, want %v", retryMetrics.retriesCounter.LastLabelValues, wantLabelValues)
|
||||
}
|
||||
}
|
||||
|
||||
// collectingRetryMetrics is an implementation of the retryMetrics interface that can be used inside tests to collect the times Add() was called.
|
||||
type collectingRetryMetrics struct {
|
||||
retriesCounter *testhelpers.CollectingCounter
|
||||
}
|
||||
|
||||
func newCollectingRetryMetrics() *collectingRetryMetrics {
|
||||
return &collectingRetryMetrics{retriesCounter: &testhelpers.CollectingCounter{}}
|
||||
}
|
||||
|
||||
func (metrics *collectingRetryMetrics) BackendRetriesCounter() metrics.Counter {
|
||||
return metrics.retriesCounter
|
||||
}
|
13
middlewares/middleware.go
Normal file
13
middlewares/middleware.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GetLogger creates a logger configured with the middleware fields.
|
||||
func GetLogger(ctx context.Context, middleware string, middlewareType string) logrus.FieldLogger {
|
||||
return log.FromContext(ctx).WithField(log.MiddlewareName, middleware).WithField(log.MiddlewareType, middlewareType)
|
||||
}
|
253
middlewares/passtlsclientcert/pass_tls_client_cert.go
Normal file
253
middlewares/passtlsclientcert/pass_tls_client_cert.go
Normal file
|
@ -0,0 +1,253 @@
|
|||
package passtlsclientcert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
|
||||
xForwardedTLSClientCertInfos = "X-Forwarded-Tls-Client-Cert-infos"
|
||||
typeName = "PassClientTLSCert"
|
||||
)
|
||||
|
||||
// passTLSClientCert is a middleware that helps setup a few tls info features.
|
||||
type passTLSClientCert struct {
|
||||
next http.Handler
|
||||
name string
|
||||
pem bool // pass the sanitized pem to the backend in a specific header
|
||||
infos *tlsClientCertificateInfos // pass selected information from the client certificate
|
||||
}
|
||||
|
||||
// New constructs a new PassTLSClientCert instance from supplied frontend header struct.
|
||||
func New(ctx context.Context, next http.Handler, config config.PassTLSClientCert, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
return &passTLSClientCert{
|
||||
next: next,
|
||||
name: name,
|
||||
pem: config.PEM,
|
||||
infos: newTLSClientInfos(config.Infos),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tlsClientCertificateInfos is a struct for specifying the configuration for the passTLSClientCert middleware.
|
||||
type tlsClientCertificateInfos struct {
|
||||
notAfter bool
|
||||
notBefore bool
|
||||
subject *tlsCLientCertificateSubjectInfos
|
||||
sans bool
|
||||
}
|
||||
|
||||
func newTLSClientInfos(infos *config.TLSClientCertificateInfos) *tlsClientCertificateInfos {
|
||||
if infos == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &tlsClientCertificateInfos{
|
||||
notBefore: infos.NotBefore,
|
||||
notAfter: infos.NotAfter,
|
||||
sans: infos.Sans,
|
||||
subject: newTLSCLientCertificateSubjectInfos(infos.Subject),
|
||||
}
|
||||
}
|
||||
|
||||
// tlsCLientCertificateSubjectInfos contains the configuration for the certificate subject infos.
|
||||
type tlsCLientCertificateSubjectInfos struct {
|
||||
country bool
|
||||
province bool
|
||||
locality bool
|
||||
Organization bool
|
||||
commonName bool
|
||||
serialNumber bool
|
||||
}
|
||||
|
||||
func newTLSCLientCertificateSubjectInfos(infos *config.TLSCLientCertificateSubjectInfos) *tlsCLientCertificateSubjectInfos {
|
||||
if infos == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &tlsCLientCertificateSubjectInfos{
|
||||
serialNumber: infos.SerialNumber,
|
||||
commonName: infos.CommonName,
|
||||
country: infos.Country,
|
||||
locality: infos.Locality,
|
||||
Organization: infos.Organization,
|
||||
province: infos.Province,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *passTLSClientCert) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return p.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (p *passTLSClientCert) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), p.name, typeName)
|
||||
p.modifyRequestHeaders(logger, req)
|
||||
p.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
// getSubjectInfos extract the requested information from the certificate subject.
|
||||
func (p *passTLSClientCert) getSubjectInfos(cs *pkix.Name) string {
|
||||
var subject string
|
||||
|
||||
if p.infos != nil && p.infos.subject != nil {
|
||||
options := p.infos.subject
|
||||
|
||||
var content []string
|
||||
|
||||
if options.country && len(cs.Country) > 0 {
|
||||
content = append(content, fmt.Sprintf("C=%s", cs.Country[0]))
|
||||
}
|
||||
|
||||
if options.province && len(cs.Province) > 0 {
|
||||
content = append(content, fmt.Sprintf("ST=%s", cs.Province[0]))
|
||||
}
|
||||
|
||||
if options.locality && len(cs.Locality) > 0 {
|
||||
content = append(content, fmt.Sprintf("L=%s", cs.Locality[0]))
|
||||
}
|
||||
|
||||
if options.Organization && len(cs.Organization) > 0 {
|
||||
content = append(content, fmt.Sprintf("O=%s", cs.Organization[0]))
|
||||
}
|
||||
|
||||
if options.commonName && len(cs.CommonName) > 0 {
|
||||
content = append(content, fmt.Sprintf("CN=%s", cs.CommonName))
|
||||
}
|
||||
|
||||
if len(content) > 0 {
|
||||
subject = `Subject="` + strings.Join(content, ",") + `"`
|
||||
}
|
||||
}
|
||||
|
||||
return subject
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCertInfos Build a string with the wanted client certificates information
|
||||
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
|
||||
func (p *passTLSClientCert) getXForwardedTLSClientCertInfos(certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
var values []string
|
||||
var sans string
|
||||
var nb string
|
||||
var na string
|
||||
|
||||
subject := p.getSubjectInfos(&peerCert.Subject)
|
||||
if len(subject) > 0 {
|
||||
values = append(values, subject)
|
||||
}
|
||||
|
||||
ci := p.infos
|
||||
if ci != nil {
|
||||
if ci.notBefore {
|
||||
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
|
||||
values = append(values, nb)
|
||||
}
|
||||
if ci.notAfter {
|
||||
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
|
||||
values = append(values, na)
|
||||
}
|
||||
|
||||
if ci.sans {
|
||||
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
|
||||
values = append(values, sans)
|
||||
}
|
||||
}
|
||||
|
||||
value := strings.Join(values, ",")
|
||||
headerValues = append(headerValues, value)
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ";")
|
||||
}
|
||||
|
||||
// modifyRequestHeaders set the wanted headers with the certificates information.
|
||||
func (p *passTLSClientCert) modifyRequestHeaders(logger logrus.FieldLogger, r *http.Request) {
|
||||
if p.pem {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(logger, r.TLS.PeerCertificates))
|
||||
} else {
|
||||
logger.Warn("Try to extract certificate on a request without TLS")
|
||||
}
|
||||
}
|
||||
|
||||
if p.infos != nil {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
headerContent := p.getXForwardedTLSClientCertInfos(r.TLS.PeerCertificates)
|
||||
r.Header.Set(xForwardedTLSClientCertInfos, url.QueryEscape(headerContent))
|
||||
} else {
|
||||
logger.Warn("Try to extract certificate on a request without TLS")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sanitize As we pass the raw certificates, remove the useless data and make it http request compliant.
|
||||
func sanitize(cert []byte) string {
|
||||
s := string(cert)
|
||||
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
|
||||
"-----END CERTIFICATE-----", "",
|
||||
"\n", "")
|
||||
cleaned := r.Replace(s)
|
||||
|
||||
return url.QueryEscape(cleaned)
|
||||
}
|
||||
|
||||
// extractCertificate extract the certificate from the request.
|
||||
func extractCertificate(logger logrus.FieldLogger, cert *x509.Certificate) string {
|
||||
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
|
||||
certPEM := pem.EncodeToMemory(&b)
|
||||
if certPEM == nil {
|
||||
logger.Error("Cannot extract the certificate content")
|
||||
return ""
|
||||
}
|
||||
return sanitize(certPEM)
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCert Build a string with the client certificates.
|
||||
func getXForwardedTLSClientCert(logger logrus.FieldLogger, certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
headerValues = append(headerValues, extractCertificate(logger, peerCert))
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ",")
|
||||
}
|
||||
|
||||
// getSANs get the Subject Alternate Name values.
|
||||
func getSANs(cert *x509.Certificate) []string {
|
||||
var sans []string
|
||||
if cert == nil {
|
||||
return sans
|
||||
}
|
||||
|
||||
sans = append(cert.DNSNames, cert.EmailAddresses...)
|
||||
|
||||
var ips []string
|
||||
for _, ip := range cert.IPAddresses {
|
||||
ips = append(ips, ip.String())
|
||||
}
|
||||
sans = append(sans, ips...)
|
||||
|
||||
var uris []string
|
||||
for _, uri := range cert.URIs {
|
||||
uris = append(uris, uri.String())
|
||||
}
|
||||
|
||||
return append(sans, uris...)
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package middlewares
|
||||
package passtlsclientcert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
|
@ -12,8 +13,8 @@ import (
|
|||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
|
@ -69,8 +70,8 @@ jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=
|
|||
Validity
|
||||
Not Before: Jul 18 08:00:16 2018 GMT
|
||||
Not After : Jul 18 08:00:16 2019 GMT
|
||||
Subject: C=FR, ST=SomeState, L=Toulouse, O=Cheese, CN=*.cheese.org
|
||||
Subject Public Key Info:
|
||||
subject: C=FR, ST=SomeState, L=Toulouse, O=Cheese, CN=*.cheese.org
|
||||
subject Public Key Info:
|
||||
Public Key Algorithm: rsaEncryption
|
||||
Public-Key: (2048 bit)
|
||||
Modulus:
|
||||
|
@ -96,9 +97,9 @@ jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=
|
|||
X509v3 extensions:
|
||||
X509v3 Basic Constraints:
|
||||
CA:FALSE
|
||||
X509v3 Subject Alternative Name:
|
||||
X509v3 subject Alternative Name:
|
||||
DNS:*.cheese.org, DNS:*.cheese.net, DNS:cheese.in, IP Address:10.0.1.0, IP Address:10.0.1.2, email:test@cheese.org, email:test@cheese.net
|
||||
X509v3 Subject Key Identifier:
|
||||
X509v3 subject Key Identifier:
|
||||
AB:6B:89:25:11:FC:5E:7B:D4:B0:F7:D4:B6:D9:EB:D0:30:93:E5:58
|
||||
Signature Algorithm: sha1WithRSAEncryption
|
||||
ad:87:84:a0:88:a3:4c:d9:0a:c0:14:e4:2d:9a:1d:bb:57:b7:
|
||||
|
@ -147,7 +148,7 @@ func getCleanCertContents(certContents []string) string {
|
|||
|
||||
var cleanedCertContent []string
|
||||
for _, certContent := range certContents {
|
||||
cert := re.FindString(string(certContent))
|
||||
cert := re.FindString(certContent)
|
||||
cleanedCertContent = append(cleanedCertContent, sanitize([]byte(cert)))
|
||||
}
|
||||
|
||||
|
@ -183,8 +184,11 @@ func buildTLSWith(certContents []string) *tls.ConnectionState {
|
|||
return &tls.ConnectionState{PeerCertificates: peerCertificates}
|
||||
}
|
||||
|
||||
var myPassTLSClientCustomHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("bar"))
|
||||
var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("bar"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
func getExpectedSanitized(s string) string {
|
||||
|
@ -234,12 +238,12 @@ jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=`),
|
|||
|
||||
}
|
||||
|
||||
func TestTlsClientheadersWithPEM(t *testing.T) {
|
||||
func TestTLSClientHeadersWithPEM(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
certContents []string // set the request TLS attribute if defined
|
||||
tlsClientCertHeaders *types.TLSClientHeaders
|
||||
expectedHeader string
|
||||
desc string
|
||||
certContents []string // set the request TLS attribute if defined
|
||||
config config.PassTLSClientCert
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "No TLS, no option",
|
||||
|
@ -249,31 +253,32 @@ func TestTlsClientheadersWithPEM(t *testing.T) {
|
|||
certContents: []string{minimalCert},
|
||||
},
|
||||
{
|
||||
desc: "No TLS, with pem option true",
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
|
||||
desc: "No TLS, with pem option true",
|
||||
config: config.PassTLSClientCert{PEM: true},
|
||||
},
|
||||
{
|
||||
desc: "TLS with simple certificate, with pem option true",
|
||||
certContents: []string{minimalCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{minimalCert}),
|
||||
desc: "TLS with simple certificate, with pem option true",
|
||||
certContents: []string{minimalCert},
|
||||
config: config.PassTLSClientCert{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{minimalCert}),
|
||||
},
|
||||
{
|
||||
desc: "TLS with complete certificate, with pem option true",
|
||||
certContents: []string{completeCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{completeCert}),
|
||||
desc: "TLS with complete certificate, with pem option true",
|
||||
certContents: []string{completeCert},
|
||||
config: config.PassTLSClientCert{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{completeCert}),
|
||||
},
|
||||
{
|
||||
desc: "TLS with two certificate, with pem option true",
|
||||
certContents: []string{minimalCert, completeCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{minimalCert, completeCert}),
|
||||
desc: "TLS with two certificate, with pem option true",
|
||||
certContents: []string{minimalCert, completeCert},
|
||||
config: config.PassTLSClientCert{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{minimalCert, completeCert}),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
tlsClientHeaders := NewTLSClientHeaders(&types.Frontend{PassTLSClientCert: test.tlsClientCertHeaders})
|
||||
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
|
@ -282,7 +287,7 @@ func TestTlsClientheadersWithPEM(t *testing.T) {
|
|||
req.TLS = buildTLSWith(test.certContents)
|
||||
}
|
||||
|
||||
tlsClientHeaders.ServeHTTP(res, req, myPassTLSClientCustomHandler)
|
||||
tlsClientHeaders.ServeHTTP(res, req)
|
||||
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
@ -334,8 +339,8 @@ func TestGetSans(t *testing.T) {
|
|||
|
||||
for _, test := range testCases {
|
||||
sans := getSANs(test.cert)
|
||||
test := test
|
||||
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -351,15 +356,15 @@ func TestGetSans(t *testing.T) {
|
|||
|
||||
}
|
||||
|
||||
func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
||||
func TestTLSClientHeadersWithCertInfos(t *testing.T) {
|
||||
minimalCertAllInfos := `Subject="C=FR,ST=Some-State,O=Internet Widgits Pty Ltd",NB=1531902496,NA=1534494496,SAN=`
|
||||
completeCertAllInfos := `Subject="C=FR,ST=SomeState,L=Toulouse,O=Cheese,CN=*.cheese.org",NB=1531900816,NA=1563436816,SAN=*.cheese.org,*.cheese.net,cheese.in,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2`
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
certContents []string // set the request TLS attribute if defined
|
||||
tlsClientCertHeaders *types.TLSClientHeaders
|
||||
expectedHeader string
|
||||
desc string
|
||||
certContents []string // set the request TLS attribute if defined
|
||||
config config.PassTLSClientCert
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "No TLS, no option",
|
||||
|
@ -370,9 +375,9 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "No TLS, with pem option true",
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
config: config.PassTLSClientCert{
|
||||
Infos: &config.TLSClientCertificateInfos{
|
||||
Subject: &config.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
|
@ -385,21 +390,21 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
|||
},
|
||||
{
|
||||
desc: "No TLS, with pem option true with no flag",
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
config: config.PassTLSClientCert{
|
||||
PEM: false,
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{},
|
||||
Infos: &config.TLSClientCertificateInfos{
|
||||
Subject: &config.TLSCLientCertificateSubjectInfos{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "TLS with simple certificate, with all infos",
|
||||
certContents: []string{minimalCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
config: config.PassTLSClientCert{
|
||||
Infos: &config.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Subject: &config.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
|
@ -415,10 +420,10 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
|||
{
|
||||
desc: "TLS with simple certificate, with some infos",
|
||||
certContents: []string{minimalCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
config: config.PassTLSClientCert{
|
||||
Infos: &config.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Subject: &config.TLSCLientCertificateSubjectInfos{
|
||||
Organization: true,
|
||||
},
|
||||
Sans: true,
|
||||
|
@ -429,11 +434,11 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
|||
{
|
||||
desc: "TLS with complete certificate, with all infos",
|
||||
certContents: []string{completeCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
config: config.PassTLSClientCert{
|
||||
Infos: &config.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Subject: &config.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
|
@ -449,11 +454,11 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
|||
{
|
||||
desc: "TLS with 2 certificates, with all infos",
|
||||
certContents: []string{minimalCert, completeCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
config: config.PassTLSClientCert{
|
||||
Infos: &config.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Subject: &config.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
|
@ -469,7 +474,8 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
|||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
tlsClientHeaders := NewTLSClientHeaders(&types.Frontend{PassTLSClientCert: test.tlsClientCertHeaders})
|
||||
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
|
@ -478,7 +484,7 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
|||
req.TLS = buildTLSWith(test.certContents)
|
||||
}
|
||||
|
||||
tlsClientHeaders.ServeHTTP(res, req, myPassTLSClientCustomHandler)
|
||||
tlsClientHeaders.ServeHTTP(res, req)
|
||||
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
@ -497,303 +503,3 @@ func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
func TestNewTLSClientHeadersFromStruct(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
frontend *types.Frontend
|
||||
expected *TLSClientHeaders
|
||||
}{
|
||||
{
|
||||
desc: "Without frontend",
|
||||
},
|
||||
{
|
||||
desc: "frontend without the option",
|
||||
frontend: &types.Frontend{},
|
||||
expected: &TLSClientHeaders{},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the pem set false",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
PEM: false,
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{PEM: false},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the pem set true",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
PEM: true,
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{PEM: true},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos with no flag",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: false,
|
||||
NotBefore: false,
|
||||
Sans: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos basic",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
NotBefore: true,
|
||||
NotAfter: true,
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos NotAfter",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos NotBefore",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotBefore: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
NotBefore: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Sans",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject Organization",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Organization: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Organization: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject Country",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Country: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Country: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject SerialNumber",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
SerialNumber: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
SerialNumber: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject Province",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Province: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Province: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject Locality",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Locality: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Locality: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject CommonName",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos NotBefore",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos all",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Country: true,
|
||||
Locality: true,
|
||||
Organization: true,
|
||||
Province: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
NotBefore: true,
|
||||
NotAfter: true,
|
||||
Sans: true,
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Province: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
Country: true,
|
||||
CommonName: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, test.expected, NewTLSClientHeaders(test.frontend))
|
||||
})
|
||||
}
|
||||
|
||||
}
|
|
@ -1,62 +0,0 @@
|
|||
package pipelining
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Pipelining returns a middleware
|
||||
type Pipelining struct {
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// NewPipelining returns a new Pipelining instance
|
||||
func NewPipelining(next http.Handler) *Pipelining {
|
||||
return &Pipelining{
|
||||
next: next,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pipelining) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
// https://github.com/golang/go/blob/3d59583836630cf13ec4bfbed977d27b1b7adbdc/src/net/http/server.go#L201-L218
|
||||
if r.Method == http.MethodPut || r.Method == http.MethodPost {
|
||||
p.next.ServeHTTP(rw, r)
|
||||
} else {
|
||||
p.next.ServeHTTP(&writerWithoutCloseNotify{rw}, r)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// writerWithoutCloseNotify helps to disable closeNotify
|
||||
type writerWithoutCloseNotify struct {
|
||||
W http.ResponseWriter
|
||||
}
|
||||
|
||||
// Header returns the response headers.
|
||||
func (w *writerWithoutCloseNotify) Header() http.Header {
|
||||
return w.W.Header()
|
||||
}
|
||||
|
||||
// Write writes the data to the connection as part of an HTTP reply.
|
||||
func (w *writerWithoutCloseNotify) Write(buf []byte) (int, error) {
|
||||
return w.W.Write(buf)
|
||||
}
|
||||
|
||||
// WriteHeader sends an HTTP response header with the provided
|
||||
// status code.
|
||||
func (w *writerWithoutCloseNotify) WriteHeader(code int) {
|
||||
w.W.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client.
|
||||
func (w *writerWithoutCloseNotify) Flush() {
|
||||
if f, ok := w.W.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack hijacks the connection.
|
||||
func (w *writerWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.W.(http.Hijacker).Hijack()
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
package pipelining
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type recorderWithCloseNotify struct {
|
||||
*httptest.ResponseRecorder
|
||||
}
|
||||
|
||||
func (r *recorderWithCloseNotify) CloseNotify() <-chan bool {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestNewPipelining(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
HTTPMethod string
|
||||
implementCloseNotifier bool
|
||||
}{
|
||||
{
|
||||
desc: "should not implement CloseNotifier with GET method",
|
||||
HTTPMethod: http.MethodGet,
|
||||
implementCloseNotifier: false,
|
||||
},
|
||||
{
|
||||
desc: "should implement CloseNotifier with PUT method",
|
||||
HTTPMethod: http.MethodPut,
|
||||
implementCloseNotifier: true,
|
||||
},
|
||||
{
|
||||
desc: "should implement CloseNotifier with POST method",
|
||||
HTTPMethod: http.MethodPost,
|
||||
implementCloseNotifier: true,
|
||||
},
|
||||
{
|
||||
desc: "should not implement CloseNotifier with GET method",
|
||||
HTTPMethod: http.MethodHead,
|
||||
implementCloseNotifier: false,
|
||||
},
|
||||
{
|
||||
desc: "should not implement CloseNotifier with PROPFIND method",
|
||||
HTTPMethod: "PROPFIND",
|
||||
implementCloseNotifier: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := w.(http.CloseNotifier)
|
||||
assert.Equal(t, test.implementCloseNotifier, ok)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := NewPipelining(nextHandler)
|
||||
|
||||
req := httptest.NewRequest(test.HTTPMethod, "http://localhost", nil)
|
||||
|
||||
handler.ServeHTTP(&recorderWithCloseNotify{httptest.NewRecorder()}, req)
|
||||
})
|
||||
}
|
||||
}
|
54
middlewares/ratelimiter/rate_limiter.go
Normal file
54
middlewares/ratelimiter/rate_limiter.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package ratelimiter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/ratelimit"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "RateLimiterType"
|
||||
)
|
||||
|
||||
type rateLimiter struct {
|
||||
handler http.Handler
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates rate limiter middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.RateLimit, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
extractFunc, err := utils.NewExtractor(config.ExtractorFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rateSet := ratelimit.NewRateSet()
|
||||
for _, rate := range config.RateSet {
|
||||
if err = rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
rl, err := ratelimit.New(next, extractFunc, rateSet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rateLimiter{handler: rl, name: name}, nil
|
||||
}
|
||||
|
||||
func (r *rateLimiter) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return r.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (r *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
r.handler.ServeHTTP(rw, req)
|
||||
}
|
|
@ -1,51 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// RecoverHandler recovers from a panic in http handlers
|
||||
func RecoverHandler(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
defer recoverFunc(w, r)
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
// NegroniRecoverHandler recovers from a panic in negroni handlers
|
||||
func NegroniRecoverHandler() negroni.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
defer recoverFunc(w, r)
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return negroni.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
func recoverFunc(w http.ResponseWriter, r *http.Request) {
|
||||
if err := recover(); err != nil {
|
||||
if !shouldLogPanic(err) {
|
||||
log.Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("Recovered from panic in HTTP handler [%s - %s]: %+v", r.RemoteAddr, r.URL, err)
|
||||
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
log.Errorf("Stack: %s", buf)
|
||||
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/golang/go/blob/a0d6420d8be2ae7164797051ec74fa2a2df466a1/src/net/http/server.go#L1761-L1775
|
||||
// https://github.com/golang/go/blob/c33153f7b416c03983324b3e8f869ce1116d84bc/src/net/http/httputil/reverseproxy.go#L284
|
||||
func shouldLogPanic(panicValue interface{}) bool {
|
||||
return panicValue != nil && panicValue != http.ErrAbortHandler
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestRecoverHandler(t *testing.T) {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("I love panicing!")
|
||||
}
|
||||
recoverHandler := RecoverHandler(http.HandlerFunc(fn))
|
||||
server := httptest.NewServer(recoverHandler)
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNegroniRecoverHandler(t *testing.T) {
|
||||
n := negroni.New()
|
||||
n.Use(NegroniRecoverHandler())
|
||||
panicHandler := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
panic("I love panicing!")
|
||||
}
|
||||
n.UseFunc(negroni.HandlerFunc(panicHandler))
|
||||
server := httptest.NewServer(n)
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
||||
}
|
40
middlewares/recovery/recovery.go
Normal file
40
middlewares/recovery/recovery.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "Recovery"
|
||||
)
|
||||
|
||||
type recovery struct {
|
||||
next http.Handler
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates recovery middleware.
|
||||
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
return &recovery{
|
||||
next: next,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
defer recoverFunc(middlewares.GetLogger(req.Context(), re.name, typeName), rw)
|
||||
re.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func recoverFunc(logger logrus.FieldLogger, rw http.ResponseWriter) {
|
||||
if err := recover(); err != nil {
|
||||
logger.Errorf("Recovered from panic in http handler: %+v", err)
|
||||
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
27
middlewares/recovery/recovery_test.go
Normal file
27
middlewares/recovery/recovery_test.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
package recovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecoverHandler(t *testing.T) {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("I love panicing!")
|
||||
}
|
||||
recovery, err := New(context.Background(), http.HandlerFunc(fn), "foo-recovery")
|
||||
require.NoError(t, err)
|
||||
|
||||
server := httptest.NewServer(recovery)
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
|
@ -2,110 +2,87 @@ package redirect
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"context"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRedirectRegex = `^(?:https?:\/\/)?([\w\._-]+)(?::\d+)?(.*)$`
|
||||
typeName = "Redirect"
|
||||
)
|
||||
|
||||
// NewEntryPointHandler create a new redirection handler base on entry point
|
||||
func NewEntryPointHandler(dstEntryPoint *configuration.EntryPoint, permanent bool) (negroni.Handler, error) {
|
||||
exp := regexp.MustCompile(`(:\d+)`)
|
||||
match := exp.FindStringSubmatch(dstEntryPoint.Address)
|
||||
if len(match) == 0 {
|
||||
return nil, fmt.Errorf("bad Address format %q", dstEntryPoint.Address)
|
||||
}
|
||||
|
||||
protocol := "http"
|
||||
if dstEntryPoint.TLS != nil {
|
||||
protocol = "https"
|
||||
}
|
||||
|
||||
replacement := protocol + "://${1}" + match[0] + "${2}"
|
||||
|
||||
return NewRegexHandler(defaultRedirectRegex, replacement, permanent)
|
||||
type redirect struct {
|
||||
next http.Handler
|
||||
regex *regexp.Regexp
|
||||
replacement string
|
||||
permanent bool
|
||||
errHandler utils.ErrorHandler
|
||||
name string
|
||||
}
|
||||
|
||||
// NewRegexHandler create a new redirection handler base on regex
|
||||
func NewRegexHandler(exp string, replacement string, permanent bool) (negroni.Handler, error) {
|
||||
re, err := regexp.Compile(exp)
|
||||
// New creates a redirect middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.Redirect, name string) (http.Handler, error) {
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
logger.Debugf("Setting up redirect %s -> %s", config.Regex, config.Replacement)
|
||||
|
||||
re, err := regexp.Compile(config.Regex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &handler{
|
||||
regexp: re,
|
||||
replacement: replacement,
|
||||
permanent: permanent,
|
||||
return &redirect{
|
||||
regex: re,
|
||||
replacement: config.Replacement,
|
||||
permanent: config.Permanent,
|
||||
errHandler: utils.DefaultHandler,
|
||||
next: next,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
regexp *regexp.Regexp
|
||||
replacement string
|
||||
permanent bool
|
||||
errHandler utils.ErrorHandler
|
||||
func (r *redirect) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return r.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
func (r *redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
oldURL := rawURL(req)
|
||||
|
||||
// only continue if the Regexp param matches the URL
|
||||
if !h.regexp.MatchString(oldURL) {
|
||||
next.ServeHTTP(rw, req)
|
||||
// If the Regexp doesn't match, skip to the next handler
|
||||
if !r.regex.MatchString(oldURL) {
|
||||
r.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// apply a rewrite regexp to the URL
|
||||
newURL := h.regexp.ReplaceAllString(oldURL, h.replacement)
|
||||
newURL := r.regex.ReplaceAllString(oldURL, r.replacement)
|
||||
|
||||
// replace any variables that may be in there
|
||||
rewrittenURL := &bytes.Buffer{}
|
||||
if err := applyString(newURL, rewrittenURL, req); err != nil {
|
||||
h.errHandler.ServeHTTP(rw, req, err)
|
||||
r.errHandler.ServeHTTP(rw, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
// parse the rewritten URL and replace request URL with it
|
||||
parsedURL, err := url.Parse(rewrittenURL.String())
|
||||
if err != nil {
|
||||
h.errHandler.ServeHTTP(rw, req, err)
|
||||
r.errHandler.ServeHTTP(rw, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
if stripPrefix, stripPrefixOk := req.Context().Value(middlewares.StripPrefixKey).(string); stripPrefixOk {
|
||||
if len(stripPrefix) > 0 {
|
||||
parsedURL.Path = stripPrefix
|
||||
}
|
||||
}
|
||||
|
||||
if addPrefix, addPrefixOk := req.Context().Value(middlewares.AddPrefixKey).(string); addPrefixOk {
|
||||
if len(addPrefix) > 0 {
|
||||
parsedURL.Path = strings.Replace(parsedURL.Path, addPrefix, "", 1)
|
||||
}
|
||||
}
|
||||
|
||||
if replacePath, replacePathOk := req.Context().Value(middlewares.ReplacePathKey).(string); replacePathOk {
|
||||
if len(replacePath) > 0 {
|
||||
parsedURL.Path = replacePath
|
||||
}
|
||||
}
|
||||
|
||||
if newURL != oldURL {
|
||||
handler := &moveHandler{location: parsedURL, permanent: h.permanent}
|
||||
handler := &moveHandler{location: parsedURL, permanent: r.permanent}
|
||||
handler.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
@ -114,7 +91,7 @@ func (h *handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http
|
|||
|
||||
// make sure the request URI corresponds the rewritten URL
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
next.ServeHTTP(rw, req)
|
||||
r.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
type moveHandler struct {
|
||||
|
@ -124,21 +101,25 @@ type moveHandler struct {
|
|||
|
||||
func (m *moveHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Location", m.location.String())
|
||||
|
||||
status := http.StatusFound
|
||||
if m.permanent {
|
||||
status = http.StatusMovedPermanently
|
||||
}
|
||||
rw.WriteHeader(status)
|
||||
rw.Write([]byte(http.StatusText(status)))
|
||||
_, err := rw.Write([]byte(http.StatusText(status)))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func rawURL(request *http.Request) string {
|
||||
func rawURL(req *http.Request) string {
|
||||
scheme := "http"
|
||||
if request.TLS != nil || isXForwardedHTTPS(request) {
|
||||
if req.TLS != nil || isXForwardedHTTPS(req) {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
return strings.Join([]string{scheme, "://", request.Host, request.RequestURI}, "")
|
||||
return strings.Join([]string{scheme, "://", req.Host, req.RequestURI}, "")
|
||||
}
|
||||
|
||||
func isXForwardedHTTPS(request *http.Request) bool {
|
||||
|
@ -147,17 +128,13 @@ func isXForwardedHTTPS(request *http.Request) bool {
|
|||
return len(xForwardedProto) > 0 && xForwardedProto == "https"
|
||||
}
|
||||
|
||||
func applyString(in string, out io.Writer, request *http.Request) error {
|
||||
func applyString(in string, out io.Writer, req *http.Request) error {
|
||||
t, err := template.New("t").Parse(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := struct {
|
||||
Request *http.Request
|
||||
}{
|
||||
Request: request,
|
||||
}
|
||||
data := struct{ Request *http.Request }{Request: req}
|
||||
|
||||
return t.Execute(out, data)
|
||||
}
|
||||
|
|
|
@ -1,146 +1,129 @@
|
|||
package redirect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/configuration"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/tls"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewEntryPointHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
entryPoint *configuration.EntryPoint
|
||||
permanent bool
|
||||
url string
|
||||
expectedURL string
|
||||
expectedStatus int
|
||||
errorExpected bool
|
||||
}{
|
||||
{
|
||||
desc: "HTTP to HTTPS",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
|
||||
url: "http://foo:80",
|
||||
expectedURL: "https://foo:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS to HTTP",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":80"},
|
||||
url: "https://foo:443",
|
||||
expectedURL: "http://foo:80",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTP",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":88"},
|
||||
url: "http://foo:80",
|
||||
expectedURL: "http://foo:88",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTPS permanent",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
|
||||
permanent: true,
|
||||
url: "http://foo:80",
|
||||
expectedURL: "https://foo:443",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS to HTTP permanent",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":80"},
|
||||
permanent: true,
|
||||
url: "https://foo:443",
|
||||
expectedURL: "http://foo:80",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "invalid address",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":foo", TLS: &tls.TLS{}},
|
||||
url: "http://foo:80",
|
||||
errorExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := NewEntryPointHandler(test.entryPoint, test.permanent)
|
||||
|
||||
if test.errorExpected {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
handler.ServeHTTP(recorder, r, nil)
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedURL, location.String())
|
||||
assert.Equal(t, test.expectedStatus, recorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRegexHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
regex string
|
||||
replacement string
|
||||
permanent bool
|
||||
config config.Redirect
|
||||
url string
|
||||
expectedURL string
|
||||
expectedStatus int
|
||||
errorExpected bool
|
||||
secured bool
|
||||
}{
|
||||
{
|
||||
desc: "simple redirection",
|
||||
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
replacement: "https://${1}bar$2:443$4",
|
||||
desc: "simple redirection",
|
||||
config: config.Redirect{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
Replacement: "https://${1}bar$2:443$4",
|
||||
},
|
||||
url: "http://foo.com:80",
|
||||
expectedURL: "https://foobar.com:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "use request header",
|
||||
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
|
||||
desc: "use request header",
|
||||
config: config.Redirect{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
Replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
|
||||
},
|
||||
url: "http://foo.com:80",
|
||||
expectedURL: "https://foobar.com:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "URL doesn't match regex",
|
||||
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
replacement: "https://${1}bar$2:443$4",
|
||||
desc: "URL doesn't match regex",
|
||||
config: config.Redirect{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
Replacement: "https://${1}bar$2:443$4",
|
||||
},
|
||||
url: "http://bar.com:80",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "invalid rewritten URL",
|
||||
regex: `^(.*)$`,
|
||||
replacement: "http://192.168.0.%31/",
|
||||
desc: "invalid rewritten URL",
|
||||
config: config.Redirect{
|
||||
Regex: `^(.*)$`,
|
||||
Replacement: "http://192.168.0.%31/",
|
||||
},
|
||||
url: "http://foo.com:80",
|
||||
expectedStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
desc: "invalid regex",
|
||||
regex: `^(.*`,
|
||||
replacement: "$1",
|
||||
desc: "invalid regex",
|
||||
config: config.Redirect{
|
||||
Regex: `^(.*`,
|
||||
Replacement: "$1",
|
||||
},
|
||||
url: "http://foo.com:80",
|
||||
errorExpected: true,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTPS permanent",
|
||||
config: config.Redirect{
|
||||
Regex: `^http://`,
|
||||
Replacement: "https://$1",
|
||||
Permanent: true,
|
||||
},
|
||||
url: "http://foo",
|
||||
expectedURL: "https://foo",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS to HTTP permanent",
|
||||
config: config.Redirect{
|
||||
Regex: `https://foo`,
|
||||
Replacement: "http://foo",
|
||||
Permanent: true,
|
||||
},
|
||||
secured: true,
|
||||
url: "https://foo",
|
||||
expectedURL: "http://foo",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTPS",
|
||||
config: config.Redirect{
|
||||
Regex: `http://foo:80`,
|
||||
Replacement: "https://foo:443",
|
||||
},
|
||||
url: "http://foo:80",
|
||||
expectedURL: "https://foo:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS to HTTP",
|
||||
config: config.Redirect{
|
||||
Regex: `https://foo:443`,
|
||||
Replacement: "http://foo:80",
|
||||
},
|
||||
secured: true,
|
||||
url: "https://foo:443",
|
||||
expectedURL: "http://foo:80",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTP",
|
||||
config: config.Redirect{
|
||||
Regex: `http://foo:80`,
|
||||
Replacement: "http://foo:88",
|
||||
},
|
||||
url: "http://foo:80",
|
||||
expectedURL: "http://foo:88",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
|
@ -148,20 +131,23 @@ func TestNewRegexHandler(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := NewRegexHandler(test.regex, test.replacement, test.permanent)
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
handler, err := New(context.Background(), next, test.config, "traefikTest")
|
||||
|
||||
if test.errorExpected {
|
||||
require.Nil(t, handler)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, handler)
|
||||
} else {
|
||||
require.NotNil(t, handler)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
if test.secured {
|
||||
r.TLS = &tls.ConnectionState{}
|
||||
}
|
||||
r.Header.Set("X-Foo", "bar")
|
||||
next := func(rw http.ResponseWriter, req *http.Request) {}
|
||||
handler.ServeHTTP(recorder, r, next)
|
||||
handler.ServeHTTP(recorder, r)
|
||||
|
||||
if test.expectedStatus == http.StatusMovedPermanently || test.expectedStatus == http.StatusFound {
|
||||
assert.Equal(t, test.expectedStatus, recorder.Code)
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// ReplacePathKey is the key within the request context used to
|
||||
// store the replaced path
|
||||
ReplacePathKey key = "ReplacePath"
|
||||
// ReplacedPathHeader is the default header to set the old path to
|
||||
ReplacedPathHeader = "X-Replaced-Path"
|
||||
)
|
||||
|
||||
// ReplacePath is a middleware used to replace the path of a URL request
|
||||
type ReplacePath struct {
|
||||
Handler http.Handler
|
||||
Path string
|
||||
}
|
||||
|
||||
func (s *ReplacePath) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r = r.WithContext(context.WithValue(r.Context(), ReplacePathKey, r.URL.Path))
|
||||
r.Header.Add(ReplacedPathHeader, r.URL.Path)
|
||||
r.URL.Path = s.Path
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
s.Handler.ServeHTTP(w, r)
|
||||
}
|
|
@ -1,40 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
)
|
||||
|
||||
// ReplacePathRegex is a middleware used to replace the path of a URL request with a regular expression
|
||||
type ReplacePathRegex struct {
|
||||
Handler http.Handler
|
||||
Regexp *regexp.Regexp
|
||||
Replacement string
|
||||
}
|
||||
|
||||
// NewReplacePathRegexHandler returns a new ReplacePathRegex
|
||||
func NewReplacePathRegexHandler(regex string, replacement string, handler http.Handler) http.Handler {
|
||||
exp, err := regexp.Compile(strings.TrimSpace(regex))
|
||||
if err != nil {
|
||||
log.Errorf("Error compiling regular expression %s: %s", regex, err)
|
||||
}
|
||||
return &ReplacePathRegex{
|
||||
Regexp: exp,
|
||||
Replacement: strings.TrimSpace(replacement),
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ReplacePathRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if s.Regexp != nil && len(s.Replacement) > 0 && s.Regexp.MatchString(r.URL.Path) {
|
||||
r = r.WithContext(context.WithValue(r.Context(), ReplacePathKey, r.URL.Path))
|
||||
r.Header.Add(ReplacedPathHeader, r.URL.Path)
|
||||
r.URL.Path = s.Regexp.ReplaceAllString(r.URL.Path, s.Replacement)
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
}
|
||||
s.Handler.ServeHTTP(w, r)
|
||||
}
|
|
@ -1,80 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestReplacePathRegex(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
path string
|
||||
replacement string
|
||||
regex string
|
||||
expectedPath string
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "simple regex",
|
||||
path: "/whoami/and/whoami",
|
||||
replacement: "/who-am-i/$1",
|
||||
regex: `^/whoami/(.*)`,
|
||||
expectedPath: "/who-am-i/and/whoami",
|
||||
expectedHeader: "/whoami/and/whoami",
|
||||
},
|
||||
{
|
||||
desc: "simple replace (no regex)",
|
||||
path: "/whoami/and/whoami",
|
||||
replacement: "/who-am-i",
|
||||
regex: `/whoami`,
|
||||
expectedPath: "/who-am-i/and/who-am-i",
|
||||
expectedHeader: "/whoami/and/whoami",
|
||||
},
|
||||
{
|
||||
desc: "multiple replacement",
|
||||
path: "/downloads/src/source.go",
|
||||
replacement: "/downloads/$1-$2",
|
||||
regex: `^(?i)/downloads/([^/]+)/([^/]+)$`,
|
||||
expectedPath: "/downloads/src-source.go",
|
||||
expectedHeader: "/downloads/src/source.go",
|
||||
},
|
||||
{
|
||||
desc: "invalid regular expression",
|
||||
path: "/invalid/regexp/test",
|
||||
replacement: "/valid/regexp/$1",
|
||||
regex: `^(?err)/invalid/regexp/([^/]+)$`,
|
||||
expectedPath: "/invalid/regexp/test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualPath, actualHeader, requestURI string
|
||||
handler := NewReplacePathRegexHandler(
|
||||
test.regex,
|
||||
test.replacement,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualHeader = r.Header.Get(ReplacedPathHeader)
|
||||
requestURI = r.RequestURI
|
||||
}),
|
||||
)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
|
||||
handler.ServeHTTP(nil, req)
|
||||
|
||||
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
|
||||
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
|
||||
if test.expectedHeader != "" {
|
||||
assert.Equal(t, actualPath, requestURI, "Unexpected request URI.")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
46
middlewares/replacepath/replace_path.go
Normal file
46
middlewares/replacepath/replace_path.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package replacepath
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
// ReplacedPathHeader is the default header to set the old path to.
|
||||
ReplacedPathHeader = "X-Replaced-Path"
|
||||
typeName = "ReplacePath"
|
||||
)
|
||||
|
||||
// ReplacePath is a middleware used to replace the path of a URL request.
|
||||
type replacePath struct {
|
||||
next http.Handler
|
||||
path string
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a new replace path middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.ReplacePath, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
return &replacePath{
|
||||
next: next,
|
||||
path: config.Path,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *replacePath) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return r.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (r *replacePath) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
req.Header.Add(ReplacedPathHeader, req.URL.Path)
|
||||
req.URL.Path = r.path
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
r.next.ServeHTTP(rw, req)
|
||||
}
|
|
@ -1,15 +1,20 @@
|
|||
package middlewares
|
||||
package replacepath
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReplacePath(t *testing.T) {
|
||||
const replacementPath = "/replacement-path"
|
||||
var replacementConfig = config.ReplacePath{
|
||||
Path: "/replacement-path",
|
||||
}
|
||||
|
||||
paths := []string{
|
||||
"/example",
|
||||
|
@ -20,20 +25,20 @@ func TestReplacePath(t *testing.T) {
|
|||
t.Run(path, func(t *testing.T) {
|
||||
|
||||
var expectedPath, actualHeader, requestURI string
|
||||
handler := &ReplacePath{
|
||||
Path: replacementPath,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
expectedPath = r.URL.Path
|
||||
actualHeader = r.Header.Get(ReplacedPathHeader)
|
||||
requestURI = r.RequestURI
|
||||
}),
|
||||
}
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
expectedPath = r.URL.Path
|
||||
actualHeader = r.Header.Get(ReplacedPathHeader)
|
||||
requestURI = r.RequestURI
|
||||
})
|
||||
|
||||
handler, err := New(context.Background(), next, replacementConfig, "foo-replace-path")
|
||||
require.NoError(t, err)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+path, nil)
|
||||
|
||||
handler.ServeHTTP(nil, req)
|
||||
|
||||
assert.Equal(t, expectedPath, replacementPath, "Unexpected path.")
|
||||
assert.Equal(t, expectedPath, replacementConfig.Path, "Unexpected path.")
|
||||
assert.Equal(t, path, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
|
||||
assert.Equal(t, expectedPath, requestURI, "Unexpected request URI.")
|
||||
})
|
57
middlewares/replacepathregex/replace_path_regex.go
Normal file
57
middlewares/replacepathregex/replace_path_regex.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package replacepathregex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/replacepath"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "ReplacePathRegex"
|
||||
)
|
||||
|
||||
// ReplacePathRegex is a middleware used to replace the path of a URL request with a regular expression.
|
||||
type replacePathRegex struct {
|
||||
next http.Handler
|
||||
regexp *regexp.Regexp
|
||||
replacement string
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a new replace path regex middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.ReplacePathRegex, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
exp, err := regexp.Compile(strings.TrimSpace(config.Regex))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error compiling regular expression %s: %s", config.Regex, err)
|
||||
}
|
||||
|
||||
return &replacePathRegex{
|
||||
regexp: exp,
|
||||
replacement: strings.TrimSpace(config.Replacement),
|
||||
next: next,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (rp *replacePathRegex) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return rp.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (rp *replacePathRegex) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if rp.regexp != nil && len(rp.replacement) > 0 && rp.regexp.MatchString(req.URL.Path) {
|
||||
req.Header.Add(replacepath.ReplacedPathHeader, req.URL.Path)
|
||||
req.URL.Path = rp.regexp.ReplaceAllString(req.URL.Path, rp.replacement)
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
}
|
||||
rp.next.ServeHTTP(rw, req)
|
||||
}
|
104
middlewares/replacepathregex/replace_path_regex_test.go
Normal file
104
middlewares/replacepathregex/replace_path_regex_test.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package replacepathregex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares/replacepath"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestReplacePathRegex(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
path string
|
||||
config config.ReplacePathRegex
|
||||
expectedPath string
|
||||
expectedHeader string
|
||||
expectsError bool
|
||||
}{
|
||||
{
|
||||
desc: "simple regex",
|
||||
path: "/whoami/and/whoami",
|
||||
config: config.ReplacePathRegex{
|
||||
Replacement: "/who-am-i/$1",
|
||||
Regex: `^/whoami/(.*)`,
|
||||
},
|
||||
expectedPath: "/who-am-i/and/whoami",
|
||||
expectedHeader: "/whoami/and/whoami",
|
||||
},
|
||||
{
|
||||
desc: "simple replace (no regex)",
|
||||
path: "/whoami/and/whoami",
|
||||
config: config.ReplacePathRegex{
|
||||
Replacement: "/who-am-i",
|
||||
Regex: `/whoami`,
|
||||
},
|
||||
expectedPath: "/who-am-i/and/who-am-i",
|
||||
expectedHeader: "/whoami/and/whoami",
|
||||
},
|
||||
{
|
||||
desc: "no match",
|
||||
path: "/whoami/and/whoami",
|
||||
config: config.ReplacePathRegex{
|
||||
Replacement: "/whoami",
|
||||
Regex: `/no-match`,
|
||||
},
|
||||
expectedPath: "/whoami/and/whoami",
|
||||
},
|
||||
{
|
||||
desc: "multiple replacement",
|
||||
path: "/downloads/src/source.go",
|
||||
config: config.ReplacePathRegex{
|
||||
Replacement: "/downloads/$1-$2",
|
||||
Regex: `^(?i)/downloads/([^/]+)/([^/]+)$`,
|
||||
},
|
||||
expectedPath: "/downloads/src-source.go",
|
||||
expectedHeader: "/downloads/src/source.go",
|
||||
},
|
||||
{
|
||||
desc: "invalid regular expression",
|
||||
path: "/invalid/regexp/test",
|
||||
config: config.ReplacePathRegex{
|
||||
Replacement: "/valid/regexp/$1",
|
||||
Regex: `^(?err)/invalid/regexp/([^/]+)$`,
|
||||
},
|
||||
expectedPath: "/invalid/regexp/test",
|
||||
expectsError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
var actualPath, actualHeader, requestURI string
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualHeader = r.Header.Get(replacepath.ReplacedPathHeader)
|
||||
requestURI = r.RequestURI
|
||||
})
|
||||
|
||||
handler, err := New(context.Background(), next, test.config, "foo-replace-path-regexp")
|
||||
if test.expectsError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
req.RequestURI = test.path
|
||||
|
||||
handler.ServeHTTP(nil, req)
|
||||
|
||||
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
|
||||
assert.Equal(t, actualPath, requestURI, "Unexpected request URI.")
|
||||
if test.expectedHeader != "" {
|
||||
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", replacepath.ReplacedPathHeader)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,42 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/types"
|
||||
)
|
||||
|
||||
var requestHostKey struct{}
|
||||
|
||||
// RequestHost is the struct for the middleware that adds the CanonicalDomain of the request Host into a context for later use.
|
||||
type RequestHost struct{}
|
||||
|
||||
func (rh *RequestHost) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if next != nil {
|
||||
host := types.CanonicalDomain(parseHost(r.Host))
|
||||
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), requestHostKey, host)))
|
||||
}
|
||||
}
|
||||
|
||||
func parseHost(addr string) string {
|
||||
if !strings.Contains(addr, ":") {
|
||||
return addr
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// GetCanonizedHost plucks the canonized host key from the request of a context that was put through the middleware
|
||||
func GetCanonizedHost(ctx context.Context) string {
|
||||
if val, ok := ctx.Value(requestHostKey).(string); ok {
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
124
middlewares/requestdecorator/hostresolver.go
Normal file
124
middlewares/requestdecorator/hostresolver.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
package requestdecorator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
type cnameResolv struct {
|
||||
TTL time.Duration
|
||||
Record string
|
||||
}
|
||||
|
||||
type byTTL []*cnameResolv
|
||||
|
||||
func (a byTTL) Len() int { return len(a) }
|
||||
func (a byTTL) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a byTTL) Less(i, j int) bool { return a[i].TTL > a[j].TTL }
|
||||
|
||||
// Resolver used for host resolver.
|
||||
type Resolver struct {
|
||||
CnameFlattening bool
|
||||
ResolvConfig string
|
||||
ResolvDepth int
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
// CNAMEFlatten check if CNAME record exists, flatten if possible.
|
||||
func (hr *Resolver) CNAMEFlatten(ctx context.Context, host string) string {
|
||||
if hr.cache == nil {
|
||||
hr.cache = cache.New(30*time.Minute, 5*time.Minute)
|
||||
}
|
||||
|
||||
result := host
|
||||
request := host
|
||||
|
||||
value, found := hr.cache.Get(host)
|
||||
if found {
|
||||
return value.(string)
|
||||
}
|
||||
|
||||
logger := log.FromContext(ctx)
|
||||
var cacheDuration = 0 * time.Second
|
||||
for depth := 0; depth < hr.ResolvDepth; depth++ {
|
||||
resolv, err := cnameResolve(ctx, request, hr.ResolvConfig)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
break
|
||||
}
|
||||
if resolv == nil {
|
||||
break
|
||||
}
|
||||
|
||||
result = resolv.Record
|
||||
if depth == 0 {
|
||||
cacheDuration = resolv.TTL
|
||||
}
|
||||
request = resolv.Record
|
||||
}
|
||||
|
||||
if err := hr.cache.Add(host, result, cacheDuration); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// cnameResolve resolves CNAME if exists, and return with the highest TTL.
|
||||
func cnameResolve(ctx context.Context, host string, resolvPath string) (*cnameResolv, error) {
|
||||
config, err := dns.ClientConfigFromFile(resolvPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid resolver configuration file: %s", resolvPath)
|
||||
}
|
||||
|
||||
client := &dns.Client{Timeout: 30 * time.Second}
|
||||
|
||||
m := &dns.Msg{}
|
||||
m.SetQuestion(dns.Fqdn(host), dns.TypeCNAME)
|
||||
|
||||
var result []*cnameResolv
|
||||
for _, server := range config.Servers {
|
||||
tempRecord, err := getRecord(client, m, server, config.Port)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Errorf("Failed to resolve host %s: %v", host, err)
|
||||
continue
|
||||
}
|
||||
result = append(result, tempRecord)
|
||||
}
|
||||
|
||||
if len(result) <= 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Sort(byTTL(result))
|
||||
return result[0], nil
|
||||
}
|
||||
|
||||
func getRecord(client *dns.Client, msg *dns.Msg, server string, port string) (*cnameResolv, error) {
|
||||
resp, _, err := client.Exchange(msg, net.JoinHostPort(server, port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("exchange error for server %s: %v", server, err)
|
||||
}
|
||||
|
||||
if resp == nil || len(resp.Answer) == 0 {
|
||||
return nil, fmt.Errorf("empty answer for server %s", server)
|
||||
}
|
||||
|
||||
rr, ok := resp.Answer[0].(*dns.CNAME)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid response type for server %s", server)
|
||||
}
|
||||
|
||||
return &cnameResolv{
|
||||
TTL: time.Duration(rr.Hdr.Ttl) * time.Second,
|
||||
Record: strings.TrimSuffix(rr.Target, "."),
|
||||
}, nil
|
||||
}
|
51
middlewares/requestdecorator/hostresolver_test.go
Normal file
51
middlewares/requestdecorator/hostresolver_test.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package requestdecorator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCNAMEFlatten(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
resolvFile string
|
||||
domain string
|
||||
expectedDomain string
|
||||
}{
|
||||
{
|
||||
desc: "host request is CNAME record",
|
||||
resolvFile: "/etc/resolv.conf",
|
||||
domain: "www.github.com",
|
||||
expectedDomain: "github.com",
|
||||
},
|
||||
{
|
||||
desc: "resolve file not found",
|
||||
resolvFile: "/etc/resolv.oops",
|
||||
domain: "www.github.com",
|
||||
expectedDomain: "www.github.com",
|
||||
},
|
||||
{
|
||||
desc: "host request is not CNAME record",
|
||||
resolvFile: "/etc/resolv.conf",
|
||||
domain: "github.com",
|
||||
expectedDomain: "github.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hostResolver := &Resolver{
|
||||
ResolvConfig: test.resolvFile,
|
||||
ResolvDepth: 5,
|
||||
}
|
||||
|
||||
flatH := hostResolver.CNAMEFlatten(context.Background(), test.domain)
|
||||
assert.Equal(t, test.expectedDomain, flatH)
|
||||
})
|
||||
}
|
||||
}
|
88
middlewares/requestdecorator/request_decorator.go
Normal file
88
middlewares/requestdecorator/request_decorator.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
package requestdecorator
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/old/types"
|
||||
)
|
||||
|
||||
const (
|
||||
canonicalKey key = "canonical"
|
||||
flattenKey key = "flatten"
|
||||
)
|
||||
|
||||
type key string
|
||||
|
||||
// RequestDecorator is the struct for the middleware that adds the CanonicalDomain of the request Host into a context for later use.
|
||||
type RequestDecorator struct {
|
||||
hostResolver *Resolver
|
||||
}
|
||||
|
||||
// New creates a new request host middleware.
|
||||
func New(hostResolverConfig *static.HostResolverConfig) *RequestDecorator {
|
||||
requestDecorator := &RequestDecorator{}
|
||||
if hostResolverConfig != nil {
|
||||
requestDecorator.hostResolver = &Resolver{
|
||||
CnameFlattening: hostResolverConfig.CnameFlattening,
|
||||
ResolvConfig: hostResolverConfig.ResolvConfig,
|
||||
ResolvDepth: hostResolverConfig.ResolvDepth,
|
||||
}
|
||||
}
|
||||
return requestDecorator
|
||||
}
|
||||
|
||||
func (r *RequestDecorator) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
host := types.CanonicalDomain(parseHost(req.Host))
|
||||
reqt := req.WithContext(context.WithValue(req.Context(), canonicalKey, host))
|
||||
|
||||
if r.hostResolver != nil && r.hostResolver.CnameFlattening {
|
||||
flatHost := r.hostResolver.CNAMEFlatten(reqt.Context(), host)
|
||||
reqt = reqt.WithContext(context.WithValue(reqt.Context(), flattenKey, flatHost))
|
||||
}
|
||||
|
||||
next(rw, reqt)
|
||||
}
|
||||
|
||||
func parseHost(addr string) string {
|
||||
if !strings.Contains(addr, ":") {
|
||||
return addr
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// GetCanonizedHost retrieves the canonized host from the given context (previously stored in the request context by the middleware).
|
||||
func GetCanonizedHost(ctx context.Context) string {
|
||||
if val, ok := ctx.Value(canonicalKey).(string); ok {
|
||||
return val
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetCNAMEFlatten return the flat name if it is present in the context.
|
||||
func GetCNAMEFlatten(ctx context.Context) string {
|
||||
if val, ok := ctx.Value(flattenKey).(string); ok {
|
||||
return val
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// WrapHandler Wraps a ServeHTTP with next to an alice.Constructor.
|
||||
func WrapHandler(handler *RequestDecorator) alice.Constructor {
|
||||
return func(next http.Handler) (http.Handler, error) {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
handler.ServeHTTP(rw, req, next.ServeHTTP)
|
||||
}), nil
|
||||
}
|
||||
}
|
|
@ -1,9 +1,10 @@
|
|||
package middlewares
|
||||
package requestdecorator
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config/static"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
@ -36,19 +37,69 @@ func TestRequestHost(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
rh := &RequestHost{}
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
host := GetCanonizedHost(r.Context())
|
||||
assert.Equal(t, test.expected, host)
|
||||
})
|
||||
|
||||
rh := New(nil)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
|
||||
rh.ServeHTTP(nil, req, next)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestFlattening(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
url string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "host with flattening",
|
||||
url: "http://www.github.com",
|
||||
expected: "github.com",
|
||||
},
|
||||
{
|
||||
desc: "host without flattening",
|
||||
url: "http://github.com",
|
||||
expected: "github.com",
|
||||
},
|
||||
{
|
||||
desc: "ip without flattening",
|
||||
url: "http://127.0.0.1",
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
|
||||
rh.ServeHTTP(nil, req, func(_ http.ResponseWriter, r *http.Request) {
|
||||
host := GetCanonizedHost(r.Context())
|
||||
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||||
host := GetCNAMEFlatten(r.Context())
|
||||
assert.Equal(t, test.expected, host)
|
||||
})
|
||||
|
||||
rh := New(
|
||||
&static.HostResolverConfig{
|
||||
CnameFlattening: true,
|
||||
ResolvConfig: "/etc/resolv.conf",
|
||||
ResolvDepth: 5,
|
||||
},
|
||||
)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
|
||||
rh.ServeHTTP(nil, req, next)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,173 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
)
|
||||
|
||||
// Compile time validation that the response writer implements http interfaces correctly.
|
||||
var _ Stateful = &retryResponseWriterWithCloseNotify{}
|
||||
|
||||
// Retry is a middleware that retries requests
|
||||
type Retry struct {
|
||||
attempts int
|
||||
next http.Handler
|
||||
listener RetryListener
|
||||
}
|
||||
|
||||
// NewRetry returns a new Retry instance
|
||||
func NewRetry(attempts int, next http.Handler, listener RetryListener) *Retry {
|
||||
return &Retry{
|
||||
attempts: attempts,
|
||||
next: next,
|
||||
listener: listener,
|
||||
}
|
||||
}
|
||||
|
||||
func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
// if we might make multiple attempts, swap the body for an ioutil.NopCloser
|
||||
// cf https://github.com/containous/traefik/issues/1008
|
||||
if retry.attempts > 1 {
|
||||
body := r.Body
|
||||
if body == nil {
|
||||
body = http.NoBody
|
||||
}
|
||||
defer body.Close()
|
||||
r.Body = ioutil.NopCloser(body)
|
||||
}
|
||||
|
||||
attempts := 1
|
||||
for {
|
||||
attemptsExhausted := attempts >= retry.attempts
|
||||
|
||||
shouldRetry := !attemptsExhausted
|
||||
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)
|
||||
|
||||
// Disable retries when the backend already received request data
|
||||
trace := &httptrace.ClientTrace{
|
||||
WroteHeaders: func() {
|
||||
retryResponseWriter.DisableRetries()
|
||||
},
|
||||
WroteRequest: func(httptrace.WroteRequestInfo) {
|
||||
retryResponseWriter.DisableRetries()
|
||||
},
|
||||
}
|
||||
newCtx := httptrace.WithClientTrace(r.Context(), trace)
|
||||
|
||||
retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx))
|
||||
if !retryResponseWriter.ShouldRetry() {
|
||||
break
|
||||
}
|
||||
|
||||
attempts++
|
||||
log.Debugf("New attempt %d for request: %v", attempts, r.URL)
|
||||
retry.listener.Retried(r, attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// RetryListener is used to inform about retry attempts.
|
||||
type RetryListener interface {
|
||||
// Retried will be called when a retry happens, with the request attempt passed to it.
|
||||
// For the first retry this will be attempt 2.
|
||||
Retried(req *http.Request, attempt int)
|
||||
}
|
||||
|
||||
// RetryListeners is a convenience type to construct a list of RetryListener and notify
|
||||
// each of them about a retry attempt.
|
||||
type RetryListeners []RetryListener
|
||||
|
||||
// Retried exists to implement the RetryListener interface. It calls Retried on each of its slice entries.
|
||||
func (l RetryListeners) Retried(req *http.Request, attempt int) {
|
||||
for _, retryListener := range l {
|
||||
retryListener.Retried(req, attempt)
|
||||
}
|
||||
}
|
||||
|
||||
type retryResponseWriter interface {
|
||||
http.ResponseWriter
|
||||
http.Flusher
|
||||
ShouldRetry() bool
|
||||
DisableRetries()
|
||||
}
|
||||
|
||||
func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
|
||||
responseWriter := &retryResponseWriterWithoutCloseNotify{
|
||||
responseWriter: rw,
|
||||
shouldRetry: shouldRetry,
|
||||
}
|
||||
if _, ok := rw.(http.CloseNotifier); ok {
|
||||
return &retryResponseWriterWithCloseNotify{responseWriter}
|
||||
}
|
||||
return responseWriter
|
||||
}
|
||||
|
||||
type retryResponseWriterWithoutCloseNotify struct {
|
||||
responseWriter http.ResponseWriter
|
||||
shouldRetry bool
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool {
|
||||
return rr.shouldRetry
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() {
|
||||
rr.shouldRetry = false
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
|
||||
if rr.ShouldRetry() {
|
||||
return make(http.Header)
|
||||
}
|
||||
return rr.responseWriter.Header()
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
|
||||
if rr.ShouldRetry() {
|
||||
return len(buf), nil
|
||||
}
|
||||
return rr.responseWriter.Write(buf)
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
|
||||
if rr.ShouldRetry() && code == http.StatusServiceUnavailable {
|
||||
// We get a 503 HTTP Status Code when there is no backend server in the pool
|
||||
// to which the request could be sent. Also, note that rr.ShouldRetry()
|
||||
// will never return true in case there was a connection established to
|
||||
// the backend server and so we can be sure that the 503 was produced
|
||||
// inside Traefik already and we don't have to retry in this cases.
|
||||
rr.DisableRetries()
|
||||
}
|
||||
|
||||
if rr.ShouldRetry() {
|
||||
return
|
||||
}
|
||||
rr.responseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := rr.responseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", rr.responseWriter)
|
||||
}
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) Flush() {
|
||||
if flusher, ok := rr.responseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
type retryResponseWriterWithCloseNotify struct {
|
||||
*retryResponseWriterWithoutCloseNotify
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithCloseNotify) CloseNotify() <-chan bool {
|
||||
return rr.responseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
194
middlewares/retry/retry.go
Normal file
194
middlewares/retry/retry.go
Normal file
|
@ -0,0 +1,194 @@
|
|||
package retry
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
// Compile time validation that the response writer implements http interfaces correctly.
|
||||
var _ middlewares.Stateful = &responseWriterWithCloseNotify{}
|
||||
|
||||
const (
|
||||
typeName = "Retry"
|
||||
)
|
||||
|
||||
// Listener is used to inform about retry attempts.
|
||||
type Listener interface {
|
||||
// Retried will be called when a retry happens, with the request attempt passed to it.
|
||||
// For the first retry this will be attempt 2.
|
||||
Retried(req *http.Request, attempt int)
|
||||
}
|
||||
|
||||
// Listeners is a convenience type to construct a list of Listener and notify
|
||||
// each of them about a retry attempt.
|
||||
type Listeners []Listener
|
||||
|
||||
// retry is a middleware that retries requests.
|
||||
type retry struct {
|
||||
attempts int
|
||||
next http.Handler
|
||||
listener Listener
|
||||
name string
|
||||
}
|
||||
|
||||
// New returns a new retry middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.Retry, listener Listener, name string) (http.Handler, error) {
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
|
||||
if config.Attempts <= 0 {
|
||||
return nil, fmt.Errorf("incorrect (or empty) value for attempt (%d)", config.Attempts)
|
||||
}
|
||||
|
||||
return &retry{
|
||||
attempts: config.Attempts,
|
||||
next: next,
|
||||
listener: listener,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *retry) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return r.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// if we might make multiple attempts, swap the body for an ioutil.NopCloser
|
||||
// cf https://github.com/containous/traefik/issues/1008
|
||||
if r.attempts > 1 {
|
||||
body := req.Body
|
||||
defer body.Close()
|
||||
req.Body = ioutil.NopCloser(body)
|
||||
}
|
||||
|
||||
attempts := 1
|
||||
for {
|
||||
attemptsExhausted := attempts >= r.attempts
|
||||
shouldRetry := !attemptsExhausted
|
||||
retryResponseWriter := newResponseWriter(rw, shouldRetry)
|
||||
|
||||
// Disable retries when the backend already received request data
|
||||
trace := &httptrace.ClientTrace{
|
||||
WroteHeaders: func() {
|
||||
retryResponseWriter.DisableRetries()
|
||||
},
|
||||
WroteRequest: func(httptrace.WroteRequestInfo) {
|
||||
retryResponseWriter.DisableRetries()
|
||||
},
|
||||
}
|
||||
newCtx := httptrace.WithClientTrace(req.Context(), trace)
|
||||
|
||||
r.next.ServeHTTP(retryResponseWriter, req.WithContext(newCtx))
|
||||
|
||||
if !retryResponseWriter.ShouldRetry() {
|
||||
break
|
||||
}
|
||||
|
||||
attempts++
|
||||
logger := middlewares.GetLogger(req.Context(), r.name, typeName)
|
||||
logger.Debugf("New attempt %d for request: %v", attempts, req.URL)
|
||||
r.listener.Retried(req, attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// Retried exists to implement the Listener interface. It calls Retried on each of its slice entries.
|
||||
func (l Listeners) Retried(req *http.Request, attempt int) {
|
||||
for _, listener := range l {
|
||||
listener.Retried(req, attempt)
|
||||
}
|
||||
}
|
||||
|
||||
type responseWriter interface {
|
||||
http.ResponseWriter
|
||||
http.Flusher
|
||||
ShouldRetry() bool
|
||||
DisableRetries()
|
||||
}
|
||||
|
||||
func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter {
|
||||
responseWriter := &responseWriterWithoutCloseNotify{
|
||||
responseWriter: rw,
|
||||
shouldRetry: shouldRetry,
|
||||
}
|
||||
if _, ok := rw.(http.CloseNotifier); ok {
|
||||
return &responseWriterWithCloseNotify{
|
||||
responseWriterWithoutCloseNotify: responseWriter,
|
||||
}
|
||||
}
|
||||
return responseWriter
|
||||
}
|
||||
|
||||
type responseWriterWithoutCloseNotify struct {
|
||||
responseWriter http.ResponseWriter
|
||||
shouldRetry bool
|
||||
}
|
||||
|
||||
func (r *responseWriterWithoutCloseNotify) ShouldRetry() bool {
|
||||
return r.shouldRetry
|
||||
}
|
||||
|
||||
func (r *responseWriterWithoutCloseNotify) DisableRetries() {
|
||||
r.shouldRetry = false
|
||||
}
|
||||
|
||||
func (r *responseWriterWithoutCloseNotify) Header() http.Header {
|
||||
if r.ShouldRetry() {
|
||||
return make(http.Header)
|
||||
}
|
||||
return r.responseWriter.Header()
|
||||
}
|
||||
|
||||
func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
|
||||
if r.ShouldRetry() {
|
||||
return len(buf), nil
|
||||
}
|
||||
return r.responseWriter.Write(buf)
|
||||
}
|
||||
|
||||
func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) {
|
||||
if r.ShouldRetry() && code == http.StatusServiceUnavailable {
|
||||
// We get a 503 HTTP Status Code when there is no backend server in the pool
|
||||
// to which the request could be sent. Also, note that r.ShouldRetry()
|
||||
// will never return true in case there was a connection established to
|
||||
// the backend server and so we can be sure that the 503 was produced
|
||||
// inside Traefik already and we don't have to retry in this cases.
|
||||
r.DisableRetries()
|
||||
}
|
||||
|
||||
if r.ShouldRetry() {
|
||||
return
|
||||
}
|
||||
r.responseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (r *responseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := r.responseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", r.responseWriter)
|
||||
}
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
||||
func (r *responseWriterWithoutCloseNotify) Flush() {
|
||||
if flusher, ok := r.responseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
type responseWriterWithCloseNotify struct {
|
||||
*responseWriterWithoutCloseNotify
|
||||
}
|
||||
|
||||
func (r *responseWriterWithCloseNotify) CloseNotify() <-chan bool {
|
||||
return r.responseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
|
@ -1,11 +1,14 @@
|
|||
package middlewares
|
||||
package retry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares/emptybackendhandler"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -17,42 +20,42 @@ import (
|
|||
func TestRetry(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
maxRequestAttempts int
|
||||
config config.Retry
|
||||
wantRetryAttempts int
|
||||
wantResponseStatus int
|
||||
amountFaultyEndpoints int
|
||||
}{
|
||||
{
|
||||
desc: "no retry on success",
|
||||
maxRequestAttempts: 1,
|
||||
config: config.Retry{Attempts: 1},
|
||||
wantRetryAttempts: 0,
|
||||
wantResponseStatus: http.StatusOK,
|
||||
amountFaultyEndpoints: 0,
|
||||
},
|
||||
{
|
||||
desc: "no retry when max request attempts is one",
|
||||
maxRequestAttempts: 1,
|
||||
config: config.Retry{Attempts: 1},
|
||||
wantRetryAttempts: 0,
|
||||
wantResponseStatus: http.StatusInternalServerError,
|
||||
amountFaultyEndpoints: 1,
|
||||
},
|
||||
{
|
||||
desc: "one retry when one server is faulty",
|
||||
maxRequestAttempts: 2,
|
||||
config: config.Retry{Attempts: 2},
|
||||
wantRetryAttempts: 1,
|
||||
wantResponseStatus: http.StatusOK,
|
||||
amountFaultyEndpoints: 1,
|
||||
},
|
||||
{
|
||||
desc: "two retries when two servers are faulty",
|
||||
maxRequestAttempts: 3,
|
||||
config: config.Retry{Attempts: 3},
|
||||
wantRetryAttempts: 2,
|
||||
wantResponseStatus: http.StatusOK,
|
||||
amountFaultyEndpoints: 2,
|
||||
},
|
||||
{
|
||||
desc: "max attempts exhausted delivers the 5xx response",
|
||||
maxRequestAttempts: 3,
|
||||
config: config.Retry{Attempts: 3},
|
||||
wantRetryAttempts: 2,
|
||||
wantResponseStatus: http.StatusInternalServerError,
|
||||
amountFaultyEndpoints: 3,
|
||||
|
@ -61,24 +64,22 @@ func TestRetry(t *testing.T) {
|
|||
|
||||
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
rw.Write([]byte("OK"))
|
||||
_, err := rw.Write([]byte("OK"))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
forwarder, err := forward.New()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating forwarder: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
loadBalancer, err := roundrobin.New(forwarder)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating load balancer: %s", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
basePort := 33444
|
||||
for i := 0; i < test.amountFaultyEndpoints; i++ {
|
||||
|
@ -87,15 +88,16 @@ func TestRetry(t *testing.T) {
|
|||
// We only use the port specification here because the URL is used as identifier
|
||||
// in the load balancer and using the exact same URL would not add a new server.
|
||||
err = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// add the functioning server to the end of the load balancer list
|
||||
err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
retryListener := &countingRetryListener{}
|
||||
retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener)
|
||||
retry, err := New(context.Background(), loadBalancer, test.config, retryListener, "traefikTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
|
||||
|
@ -108,6 +110,78 @@ func TestRetry(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestRetryEmptyServerList(t *testing.T) {
|
||||
forwarder, err := forward.New()
|
||||
require.NoError(t, err)
|
||||
|
||||
loadBalancer, err := roundrobin.New(forwarder)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The EmptyBackend middleware ensures that there is a 503
|
||||
// response status set when there is no backend server in the pool.
|
||||
next := emptybackendhandler.New(loadBalancer)
|
||||
|
||||
retryListener := &countingRetryListener{}
|
||||
retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, retryListener, "traefikTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
|
||||
|
||||
retry.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
|
||||
assert.Equal(t, 0, retryListener.timesCalled)
|
||||
}
|
||||
|
||||
func TestRetryListeners(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
retryListeners := Listeners{&countingRetryListener{}, &countingRetryListener{}}
|
||||
|
||||
retryListeners.Retried(req, 1)
|
||||
retryListeners.Retried(req, 1)
|
||||
|
||||
for _, retryListener := range retryListeners {
|
||||
listener := retryListener.(*countingRetryListener)
|
||||
if listener.timesCalled != 2 {
|
||||
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// countingRetryListener is a Listener implementation to count the times the Retried fn is called.
|
||||
type countingRetryListener struct {
|
||||
timesCalled int
|
||||
}
|
||||
|
||||
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
|
||||
l.timesCalled++
|
||||
}
|
||||
|
||||
func TestRetryWithFlush(t *testing.T) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(200)
|
||||
_, err := rw.Write([]byte("FULL "))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
rw.(http.Flusher).Flush()
|
||||
_, err = rw.Write([]byte("DATA"))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
retry, err := New(context.Background(), next, config.Retry{Attempts: 1}, &countingRetryListener{}, "traefikTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
retry.ServeHTTP(responseRecorder, &http.Request{})
|
||||
|
||||
assert.Equal(t, "FULL DATA", responseRecorder.Body.String())
|
||||
}
|
||||
|
||||
func TestRetryWebsocket(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
|
@ -141,7 +215,10 @@ func TestRetryWebsocket(t *testing.T) {
|
|||
|
||||
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
upgrader := websocket.Upgrader{}
|
||||
upgrader.Upgrade(rw, req, nil)
|
||||
_, err := upgrader.Upgrade(rw, req, nil)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
for _, test := range testCases {
|
||||
|
@ -160,16 +237,17 @@ func TestRetryWebsocket(t *testing.T) {
|
|||
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
|
||||
// We only use the port specification here because the URL is used as identifier
|
||||
// in the load balancer and using the exact same URL would not add a new server.
|
||||
loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
|
||||
_ = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
|
||||
}
|
||||
|
||||
// add the functioning server to the end of the load balancer list
|
||||
loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
|
||||
|
||||
retryListener := &countingRetryListener{}
|
||||
retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener)
|
||||
retryH, err := New(context.Background(), loadBalancer, config.Retry{Attempts: test.maxRequestAttempts}, retryListener, "traefikTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
retryServer := httptest.NewServer(retry)
|
||||
retryServer := httptest.NewServer(retryH)
|
||||
|
||||
url := strings.Replace(retryServer.URL, "http", "ws", 1)
|
||||
_, response, err := websocket.DefaultDialer.Dial(url, nil)
|
||||
|
@ -183,78 +261,3 @@ func TestRetryWebsocket(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryEmptyServerList(t *testing.T) {
|
||||
forwarder, err := forward.New()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating forwarder: %s", err)
|
||||
}
|
||||
|
||||
loadBalancer, err := roundrobin.New(forwarder)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating load balancer: %s", err)
|
||||
}
|
||||
|
||||
// The EmptyBackendHandler middleware ensures that there is a 503
|
||||
// response status set when there is no backend server in the pool.
|
||||
next := NewEmptyBackendHandler(loadBalancer)
|
||||
|
||||
retryListener := &countingRetryListener{}
|
||||
retry := NewRetry(3, next, retryListener)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
|
||||
|
||||
retry.ServeHTTP(recorder, req)
|
||||
|
||||
const wantResponseStatus = http.StatusServiceUnavailable
|
||||
if wantResponseStatus != recorder.Code {
|
||||
t.Errorf("got status code %d, want %d", recorder.Code, wantResponseStatus)
|
||||
}
|
||||
const wantRetryAttempts = 0
|
||||
if wantRetryAttempts != retryListener.timesCalled {
|
||||
t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, wantRetryAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryListeners(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
retryListeners := RetryListeners{&countingRetryListener{}, &countingRetryListener{}}
|
||||
|
||||
retryListeners.Retried(req, 1)
|
||||
retryListeners.Retried(req, 1)
|
||||
|
||||
for _, retryListener := range retryListeners {
|
||||
listener := retryListener.(*countingRetryListener)
|
||||
if listener.timesCalled != 2 {
|
||||
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called.
|
||||
type countingRetryListener struct {
|
||||
timesCalled int
|
||||
}
|
||||
|
||||
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
|
||||
l.timesCalled++
|
||||
}
|
||||
|
||||
func TestRetryWithFlush(t *testing.T) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(200)
|
||||
rw.Write([]byte("FULL "))
|
||||
rw.(http.Flusher).Flush()
|
||||
rw.Write([]byte("DATA"))
|
||||
})
|
||||
|
||||
retry := NewRetry(1, next, &countingRetryListener{})
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
retry.ServeHTTP(responseRecorder, &http.Request{})
|
||||
|
||||
if responseRecorder.Body.String() != "FULL DATA" {
|
||||
t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA")
|
||||
}
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/mux"
|
||||
)
|
||||
|
||||
// Routes holds the gorilla mux routes (for the API & co).
|
||||
type Routes struct {
|
||||
router *mux.Router
|
||||
}
|
||||
|
||||
// NewRoutes return a Routes based on the given router.
|
||||
func NewRoutes(router *mux.Router) *Routes {
|
||||
return &Routes{router}
|
||||
}
|
||||
|
||||
func (router *Routes) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
routeMatch := mux.RouteMatch{}
|
||||
if router.router.Match(r, &routeMatch) {
|
||||
rt, _ := json.Marshal(routeMatch.Handler)
|
||||
log.Println("Request match route ", rt)
|
||||
}
|
||||
next(rw, r)
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"github.com/containous/traefik/types"
|
||||
"github.com/unrolled/secure"
|
||||
)
|
||||
|
||||
// NewSecure constructs a new Secure instance with supplied options.
|
||||
func NewSecure(headers *types.Headers) *secure.Secure {
|
||||
if headers == nil || !headers.HasSecureHeadersDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
opt := secure.Options{
|
||||
AllowedHosts: headers.AllowedHosts,
|
||||
HostsProxyHeaders: headers.HostsProxyHeaders,
|
||||
SSLRedirect: headers.SSLRedirect,
|
||||
SSLTemporaryRedirect: headers.SSLTemporaryRedirect,
|
||||
SSLHost: headers.SSLHost,
|
||||
SSLProxyHeaders: headers.SSLProxyHeaders,
|
||||
STSSeconds: headers.STSSeconds,
|
||||
STSIncludeSubdomains: headers.STSIncludeSubdomains,
|
||||
STSPreload: headers.STSPreload,
|
||||
ForceSTSHeader: headers.ForceSTSHeader,
|
||||
FrameDeny: headers.FrameDeny,
|
||||
CustomFrameOptionsValue: headers.CustomFrameOptionsValue,
|
||||
ContentTypeNosniff: headers.ContentTypeNosniff,
|
||||
BrowserXssFilter: headers.BrowserXSSFilter,
|
||||
CustomBrowserXssValue: headers.CustomBrowserXSSValue,
|
||||
ContentSecurityPolicy: headers.ContentSecurityPolicy,
|
||||
PublicKey: headers.PublicKey,
|
||||
ReferrerPolicy: headers.ReferrerPolicy,
|
||||
IsDevelopment: headers.IsDevelopment,
|
||||
}
|
||||
return secure.New(opt)
|
||||
}
|
|
@ -1,115 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Stateful = &responseRecorder{}
|
||||
)
|
||||
|
||||
// StatsRecorder is an optional middleware that records more details statistics
|
||||
// about requests and how they are processed. This currently consists of recent
|
||||
// requests that have caused errors (4xx and 5xx status codes), making it easy
|
||||
// to pinpoint problems.
|
||||
type StatsRecorder struct {
|
||||
mutex sync.RWMutex
|
||||
numRecentErrors int
|
||||
recentErrors []*statsError
|
||||
}
|
||||
|
||||
// NewStatsRecorder returns a new StatsRecorder
|
||||
func NewStatsRecorder(numRecentErrors int) *StatsRecorder {
|
||||
return &StatsRecorder{
|
||||
numRecentErrors: numRecentErrors,
|
||||
}
|
||||
}
|
||||
|
||||
// Stats includes all of the stats gathered by the recorder.
|
||||
type Stats struct {
|
||||
RecentErrors []*statsError `json:"recent_errors"`
|
||||
}
|
||||
|
||||
// statsError represents an error that has occurred during request processing.
|
||||
type statsError struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Method string `json:"method"`
|
||||
Host string `json:"host"`
|
||||
Path string `json:"path"`
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
// responseRecorder captures information from the response and preserves it for
|
||||
// later analysis.
|
||||
type responseRecorder struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code for later retrieval.
|
||||
func (r *responseRecorder) WriteHeader(status int) {
|
||||
r.ResponseWriter.WriteHeader(status)
|
||||
r.statusCode = status
|
||||
}
|
||||
|
||||
// Hijack hijacks the connection
|
||||
func (r *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return r.ResponseWriter.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
||||
// CloseNotify returns a channel that receives at most a
|
||||
// single value (true) when the client connection has gone
|
||||
// away.
|
||||
func (r *responseRecorder) CloseNotify() <-chan bool {
|
||||
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client.
|
||||
func (r *responseRecorder) Flush() {
|
||||
r.ResponseWriter.(http.Flusher).Flush()
|
||||
}
|
||||
|
||||
// ServeHTTP silently extracts information from the request and response as it
|
||||
// is processed. If the response is 4xx or 5xx, add it to the list of 10 most
|
||||
// recent errors.
|
||||
func (s *StatsRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
recorder := &responseRecorder{w, http.StatusOK}
|
||||
next(recorder, r)
|
||||
if recorder.statusCode >= http.StatusBadRequest {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
s.recentErrors = append([]*statsError{
|
||||
{
|
||||
StatusCode: recorder.statusCode,
|
||||
Status: http.StatusText(recorder.statusCode),
|
||||
Method: r.Method,
|
||||
Host: r.Host,
|
||||
Path: r.URL.Path,
|
||||
Time: time.Now(),
|
||||
},
|
||||
}, s.recentErrors...)
|
||||
// Limit the size of the list to numRecentErrors
|
||||
if len(s.recentErrors) > s.numRecentErrors {
|
||||
s.recentErrors = s.recentErrors[:s.numRecentErrors]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Data returns a copy of the statistics that have been gathered.
|
||||
func (s *StatsRecorder) Data() *Stats {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
// We can't return the slice directly or a race condition might develop
|
||||
recentErrors := make([]*statsError, len(s.recentErrors))
|
||||
copy(recentErrors, s.recentErrors)
|
||||
|
||||
return &Stats{
|
||||
RecentErrors: recentErrors,
|
||||
}
|
||||
}
|
|
@ -1,56 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// StripPrefixKey is the key within the request context used to
|
||||
// store the stripped prefix
|
||||
StripPrefixKey key = "StripPrefix"
|
||||
// ForwardedPrefixHeader is the default header to set prefix
|
||||
ForwardedPrefixHeader = "X-Forwarded-Prefix"
|
||||
)
|
||||
|
||||
// StripPrefix is a middleware used to strip prefix from an URL request
|
||||
type StripPrefix struct {
|
||||
Handler http.Handler
|
||||
Prefixes []string
|
||||
}
|
||||
|
||||
func (s *StripPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
for _, prefix := range s.Prefixes {
|
||||
if strings.HasPrefix(r.URL.Path, prefix) {
|
||||
rawReqPath := r.URL.Path
|
||||
r.URL.Path = stripPrefix(r.URL.Path, prefix)
|
||||
if r.URL.RawPath != "" {
|
||||
r.URL.RawPath = stripPrefix(r.URL.RawPath, prefix)
|
||||
}
|
||||
s.serveRequest(w, r, strings.TrimSpace(prefix), rawReqPath)
|
||||
return
|
||||
}
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
||||
func (s *StripPrefix) serveRequest(w http.ResponseWriter, r *http.Request, prefix string, rawReqPath string) {
|
||||
r = r.WithContext(context.WithValue(r.Context(), StripPrefixKey, rawReqPath))
|
||||
r.Header.Add(ForwardedPrefixHeader, prefix)
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
s.Handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// SetHandler sets handler
|
||||
func (s *StripPrefix) SetHandler(Handler http.Handler) {
|
||||
s.Handler = Handler
|
||||
}
|
||||
|
||||
func stripPrefix(s, prefix string) string {
|
||||
return ensureLeadingSlash(strings.TrimPrefix(s, prefix))
|
||||
}
|
||||
|
||||
func ensureLeadingSlash(str string) string {
|
||||
return "/" + strings.TrimPrefix(str, "/")
|
||||
}
|
|
@ -1,59 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/log"
|
||||
)
|
||||
|
||||
// StripPrefixRegex is a middleware used to strip prefix from an URL request
|
||||
type StripPrefixRegex struct {
|
||||
Handler http.Handler
|
||||
router *mux.Router
|
||||
}
|
||||
|
||||
// NewStripPrefixRegex builds a new StripPrefixRegex given a handler and prefixes
|
||||
func NewStripPrefixRegex(handler http.Handler, prefixes []string) *StripPrefixRegex {
|
||||
stripPrefix := StripPrefixRegex{Handler: handler, router: mux.NewRouter()}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
stripPrefix.router.PathPrefix(prefix)
|
||||
}
|
||||
|
||||
return &stripPrefix
|
||||
}
|
||||
|
||||
func (s *StripPrefixRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var match mux.RouteMatch
|
||||
if s.router.Match(r, &match) {
|
||||
params := make([]string, 0, len(match.Vars)*2)
|
||||
for key, val := range match.Vars {
|
||||
params = append(params, key)
|
||||
params = append(params, val)
|
||||
}
|
||||
|
||||
prefix, err := match.Route.URL(params...)
|
||||
if err != nil || len(prefix.Path) > len(r.URL.Path) {
|
||||
log.Error("Error in stripPrefix middleware", err)
|
||||
return
|
||||
}
|
||||
rawReqPath := r.URL.Path
|
||||
r.URL.Path = r.URL.Path[len(prefix.Path):]
|
||||
if r.URL.RawPath != "" {
|
||||
r.URL.RawPath = r.URL.RawPath[len(prefix.Path):]
|
||||
}
|
||||
r = r.WithContext(context.WithValue(r.Context(), StripPrefixKey, rawReqPath))
|
||||
r.Header.Add(ForwardedPrefixHeader, prefix.Path)
|
||||
r.RequestURI = ensureLeadingSlash(r.URL.RequestURI())
|
||||
s.Handler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
||||
// SetHandler sets handler
|
||||
func (s *StripPrefixRegex) SetHandler(Handler http.Handler) {
|
||||
s.Handler = Handler
|
||||
}
|
67
middlewares/stripprefix/strip_prefix.go
Normal file
67
middlewares/stripprefix/strip_prefix.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package stripprefix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
// ForwardedPrefixHeader is the default header to set prefix.
|
||||
ForwardedPrefixHeader = "X-Forwarded-Prefix"
|
||||
typeName = "StripPrefix"
|
||||
)
|
||||
|
||||
// stripPrefix is a middleware used to strip prefix from an URL request.
|
||||
type stripPrefix struct {
|
||||
next http.Handler
|
||||
prefixes []string
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a new strip prefix middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.StripPrefix, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
return &stripPrefix{
|
||||
prefixes: config.Prefixes,
|
||||
next: next,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stripPrefix) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return s.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (s *stripPrefix) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
for _, prefix := range s.prefixes {
|
||||
if strings.HasPrefix(req.URL.Path, prefix) {
|
||||
req.URL.Path = getPrefixStripped(req.URL.Path, prefix)
|
||||
if req.URL.RawPath != "" {
|
||||
req.URL.RawPath = getPrefixStripped(req.URL.RawPath, prefix)
|
||||
}
|
||||
s.serveRequest(rw, req, strings.TrimSpace(prefix))
|
||||
return
|
||||
}
|
||||
}
|
||||
http.NotFound(rw, req)
|
||||
}
|
||||
|
||||
func (s *stripPrefix) serveRequest(rw http.ResponseWriter, req *http.Request, prefix string) {
|
||||
req.Header.Add(ForwardedPrefixHeader, prefix)
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
s.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func getPrefixStripped(s, prefix string) string {
|
||||
return ensureLeadingSlash(strings.TrimPrefix(s, prefix))
|
||||
}
|
||||
|
||||
func ensureLeadingSlash(str string) string {
|
||||
return "/" + strings.TrimPrefix(str, "/")
|
||||
}
|
|
@ -1,18 +1,21 @@
|
|||
package middlewares
|
||||
package stripprefix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStripPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
prefixes []string
|
||||
config config.StripPrefix
|
||||
path string
|
||||
expectedStatusCode int
|
||||
expectedPath string
|
||||
|
@ -20,84 +23,107 @@ func TestStripPrefix(t *testing.T) {
|
|||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "no prefixes configured",
|
||||
prefixes: []string{},
|
||||
desc: "no prefixes configured",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{},
|
||||
},
|
||||
path: "/noprefixes",
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
desc: "wildcard (.*) requests",
|
||||
prefixes: []string{"/"},
|
||||
desc: "wildcard (.*) requests",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{"/"},
|
||||
},
|
||||
path: "/",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/",
|
||||
},
|
||||
{
|
||||
desc: "prefix and path matching",
|
||||
prefixes: []string{"/stat"},
|
||||
desc: "prefix and path matching",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{"/stat"},
|
||||
},
|
||||
path: "/stat",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "path prefix on exactly matching path",
|
||||
prefixes: []string{"/stat/"},
|
||||
desc: "path prefix on exactly matching path",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{"/stat/"},
|
||||
},
|
||||
path: "/stat/",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/stat/",
|
||||
},
|
||||
{
|
||||
desc: "path prefix on matching longer path",
|
||||
prefixes: []string{"/stat/"},
|
||||
desc: "path prefix on matching longer path",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{"/stat/"},
|
||||
},
|
||||
path: "/stat/us",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/us",
|
||||
expectedHeader: "/stat/",
|
||||
},
|
||||
{
|
||||
desc: "path prefix on mismatching path",
|
||||
prefixes: []string{"/stat/"},
|
||||
desc: "path prefix on mismatching path",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{"/stat/"},
|
||||
},
|
||||
path: "/status",
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
desc: "general prefix on matching path",
|
||||
prefixes: []string{"/stat"},
|
||||
desc: "general prefix on matching path",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{"/stat"},
|
||||
},
|
||||
path: "/stat/",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "earlier prefix matching",
|
||||
prefixes: []string{"/stat", "/stat/us"},
|
||||
desc: "earlier prefix matching",
|
||||
config: config.StripPrefix{
|
||||
|
||||
Prefixes: []string{"/stat", "/stat/us"},
|
||||
},
|
||||
path: "/stat/us",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/us",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "later prefix matching",
|
||||
prefixes: []string{"/mismatch", "/stat"},
|
||||
desc: "later prefix matching",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{"/mismatch", "/stat"},
|
||||
},
|
||||
path: "/stat",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "prefix matching within slash boundaries",
|
||||
prefixes: []string{"/stat"},
|
||||
desc: "prefix matching within slash boundaries",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{"/stat"},
|
||||
},
|
||||
path: "/status",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/us",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "raw path is also stripped",
|
||||
prefixes: []string{"/stat"},
|
||||
desc: "raw path is also stripped",
|
||||
config: config.StripPrefix{
|
||||
Prefixes: []string{"/stat"},
|
||||
},
|
||||
path: "/stat/a%2Fb",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/a/b",
|
||||
|
@ -106,21 +132,21 @@ func TestStripPrefix(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualPath, actualRawPath, actualHeader, requestURI string
|
||||
handler := &StripPrefix{
|
||||
Prefixes: test.prefixes,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualRawPath = r.URL.RawPath
|
||||
actualHeader = r.Header.Get(ForwardedPrefixHeader)
|
||||
requestURI = r.RequestURI
|
||||
}),
|
||||
}
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualRawPath = r.URL.RawPath
|
||||
actualHeader = r.Header.Get(ForwardedPrefixHeader)
|
||||
requestURI = r.RequestURI
|
||||
})
|
||||
|
||||
handler, err := New(context.Background(), next, test.config, "foo-strip-prefix")
|
||||
require.NoError(t, err)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
resp := &httptest.ResponseRecorder{Code: http.StatusOK}
|
79
middlewares/stripprefixregex/strip_prefix_regex.go
Normal file
79
middlewares/stripprefixregex/strip_prefix_regex.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package stripprefixregex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/middlewares/stripprefix"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "StripPrefixRegex"
|
||||
)
|
||||
|
||||
// StripPrefixRegex is a middleware used to strip prefix from an URL request.
|
||||
type stripPrefixRegex struct {
|
||||
next http.Handler
|
||||
router *mux.Router
|
||||
name string
|
||||
}
|
||||
|
||||
// New builds a new StripPrefixRegex middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.StripPrefixRegex, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
stripPrefix := stripPrefixRegex{
|
||||
next: next,
|
||||
router: mux.NewRouter(),
|
||||
name: name,
|
||||
}
|
||||
|
||||
for _, prefix := range config.Regex {
|
||||
stripPrefix.router.PathPrefix(prefix)
|
||||
}
|
||||
|
||||
return &stripPrefix, nil
|
||||
}
|
||||
|
||||
func (s *stripPrefixRegex) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return s.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (s *stripPrefixRegex) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
var match mux.RouteMatch
|
||||
if s.router.Match(req, &match) {
|
||||
params := make([]string, 0, len(match.Vars)*2)
|
||||
for key, val := range match.Vars {
|
||||
params = append(params, key)
|
||||
params = append(params, val)
|
||||
}
|
||||
|
||||
prefix, err := match.Route.URL(params...)
|
||||
if err != nil || len(prefix.Path) > len(req.URL.Path) {
|
||||
logger := middlewares.GetLogger(req.Context(), s.name, typeName)
|
||||
logger.Error("Error in stripPrefix middleware", err)
|
||||
return
|
||||
}
|
||||
|
||||
req.URL.Path = req.URL.Path[len(prefix.Path):]
|
||||
if req.URL.RawPath != "" {
|
||||
req.URL.RawPath = req.URL.RawPath[len(prefix.Path):]
|
||||
}
|
||||
req.Header.Add(stripprefix.ForwardedPrefixHeader, prefix.Path)
|
||||
req.RequestURI = ensureLeadingSlash(req.URL.RequestURI())
|
||||
|
||||
s.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
http.NotFound(rw, req)
|
||||
}
|
||||
|
||||
func ensureLeadingSlash(str string) string {
|
||||
return "/" + strings.TrimPrefix(str, "/")
|
||||
}
|
|
@ -1,18 +1,24 @@
|
|||
package middlewares
|
||||
package stripprefixregex
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/middlewares/stripprefix"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStripPrefixRegex(t *testing.T) {
|
||||
testPrefixRegex := []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"}
|
||||
testPrefixRegex := config.StripPrefixRegex{
|
||||
Regex: []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
testCases := []struct {
|
||||
path string
|
||||
expectedStatusCode int
|
||||
expectedPath string
|
||||
|
@ -70,7 +76,7 @@ func TestStripPrefixRegex(t *testing.T) {
|
|||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.path, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
@ -79,9 +85,10 @@ func TestStripPrefixRegex(t *testing.T) {
|
|||
handlerPath := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualRawPath = r.URL.RawPath
|
||||
actualHeader = r.Header.Get(ForwardedPrefixHeader)
|
||||
actualHeader = r.Header.Get(stripprefix.ForwardedPrefixHeader)
|
||||
})
|
||||
handler := NewStripPrefixRegex(handlerPath, testPrefixRegex)
|
||||
handler, err := New(context.Background(), handlerPath, testPrefixRegex, "foo-strip-prefix-regex")
|
||||
require.NoError(t, err)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
resp := &httptest.ResponseRecorder{Code: http.StatusOK}
|
||||
|
@ -91,7 +98,7 @@ func TestStripPrefixRegex(t *testing.T) {
|
|||
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
|
||||
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
|
||||
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
|
||||
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader)
|
||||
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", stripprefix.ForwardedPrefixHeader)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,251 +0,0 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/types"
|
||||
)
|
||||
|
||||
const xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
|
||||
const xForwardedTLSClientCertInfos = "X-Forwarded-Tls-Client-Cert-Infos"
|
||||
|
||||
// TLSClientCertificateInfos is a struct for specifying the configuration for the tlsClientHeaders middleware.
|
||||
type TLSClientCertificateInfos struct {
|
||||
NotAfter bool
|
||||
NotBefore bool
|
||||
Subject *TLSCLientCertificateSubjectInfos
|
||||
Sans bool
|
||||
}
|
||||
|
||||
// TLSCLientCertificateSubjectInfos contains the configuration for the certificate subject infos.
|
||||
type TLSCLientCertificateSubjectInfos struct {
|
||||
Country bool
|
||||
Province bool
|
||||
Locality bool
|
||||
Organization bool
|
||||
CommonName bool
|
||||
SerialNumber bool
|
||||
}
|
||||
|
||||
// TLSClientHeaders is a middleware that helps setup a few tls infos features.
|
||||
type TLSClientHeaders struct {
|
||||
PEM bool // pass the sanitized pem to the backend in a specific header
|
||||
Infos *TLSClientCertificateInfos // pass selected informations from the client certificate
|
||||
}
|
||||
|
||||
func newTLSCLientCertificateSubjectInfos(infos *types.TLSCLientCertificateSubjectInfos) *TLSCLientCertificateSubjectInfos {
|
||||
if infos == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &TLSCLientCertificateSubjectInfos{
|
||||
SerialNumber: infos.SerialNumber,
|
||||
CommonName: infos.CommonName,
|
||||
Country: infos.Country,
|
||||
Locality: infos.Locality,
|
||||
Organization: infos.Organization,
|
||||
Province: infos.Province,
|
||||
}
|
||||
}
|
||||
|
||||
func newTLSClientInfos(infos *types.TLSClientCertificateInfos) *TLSClientCertificateInfos {
|
||||
if infos == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &TLSClientCertificateInfos{
|
||||
NotBefore: infos.NotBefore,
|
||||
NotAfter: infos.NotAfter,
|
||||
Sans: infos.Sans,
|
||||
Subject: newTLSCLientCertificateSubjectInfos(infos.Subject),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTLSClientHeaders constructs a new TLSClientHeaders instance from supplied frontend header struct.
|
||||
func NewTLSClientHeaders(frontend *types.Frontend) *TLSClientHeaders {
|
||||
if frontend == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var pem bool
|
||||
var infos *TLSClientCertificateInfos
|
||||
|
||||
if frontend.PassTLSClientCert != nil {
|
||||
conf := frontend.PassTLSClientCert
|
||||
pem = conf.PEM
|
||||
infos = newTLSClientInfos(conf.Infos)
|
||||
}
|
||||
|
||||
return &TLSClientHeaders{
|
||||
PEM: pem,
|
||||
Infos: infos,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TLSClientHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
s.ModifyRequestHeaders(r)
|
||||
// If there is a next, call it.
|
||||
if next != nil {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// sanitize As we pass the raw certificates, remove the useless data and make it http request compliant
|
||||
func sanitize(cert []byte) string {
|
||||
s := string(cert)
|
||||
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
|
||||
"-----END CERTIFICATE-----", "",
|
||||
"\n", "")
|
||||
cleaned := r.Replace(s)
|
||||
|
||||
return url.QueryEscape(cleaned)
|
||||
}
|
||||
|
||||
// extractCertificate extract the certificate from the request
|
||||
func extractCertificate(cert *x509.Certificate) string {
|
||||
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
|
||||
certPEM := pem.EncodeToMemory(&b)
|
||||
if certPEM == nil {
|
||||
log.Error("Cannot extract the certificate content")
|
||||
return ""
|
||||
}
|
||||
return sanitize(certPEM)
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCert Build a string with the client certificates
|
||||
func getXForwardedTLSClientCert(certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
headerValues = append(headerValues, extractCertificate(peerCert))
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ",")
|
||||
}
|
||||
|
||||
// getSANs get the Subject Alternate Name values
|
||||
func getSANs(cert *x509.Certificate) []string {
|
||||
var sans []string
|
||||
if cert == nil {
|
||||
return sans
|
||||
}
|
||||
|
||||
sans = append(cert.DNSNames, cert.EmailAddresses...)
|
||||
|
||||
var ips []string
|
||||
for _, ip := range cert.IPAddresses {
|
||||
ips = append(ips, ip.String())
|
||||
}
|
||||
sans = append(sans, ips...)
|
||||
|
||||
var uris []string
|
||||
for _, uri := range cert.URIs {
|
||||
uris = append(uris, uri.String())
|
||||
}
|
||||
|
||||
return append(sans, uris...)
|
||||
}
|
||||
|
||||
// getSubjectInfos extract the requested informations from the certificate subject
|
||||
func (s *TLSClientHeaders) getSubjectInfos(cs *pkix.Name) string {
|
||||
var subject string
|
||||
|
||||
if s.Infos != nil && s.Infos.Subject != nil {
|
||||
options := s.Infos.Subject
|
||||
|
||||
var content []string
|
||||
|
||||
if options.Country && len(cs.Country) > 0 {
|
||||
content = append(content, fmt.Sprintf("C=%s", cs.Country[0]))
|
||||
}
|
||||
|
||||
if options.Province && len(cs.Province) > 0 {
|
||||
content = append(content, fmt.Sprintf("ST=%s", cs.Province[0]))
|
||||
}
|
||||
|
||||
if options.Locality && len(cs.Locality) > 0 {
|
||||
content = append(content, fmt.Sprintf("L=%s", cs.Locality[0]))
|
||||
}
|
||||
|
||||
if options.Organization && len(cs.Organization) > 0 {
|
||||
content = append(content, fmt.Sprintf("O=%s", cs.Organization[0]))
|
||||
}
|
||||
|
||||
if options.CommonName && len(cs.CommonName) > 0 {
|
||||
content = append(content, fmt.Sprintf("CN=%s", cs.CommonName))
|
||||
}
|
||||
|
||||
if len(content) > 0 {
|
||||
subject = `Subject="` + strings.Join(content, ",") + `"`
|
||||
}
|
||||
}
|
||||
|
||||
return subject
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCertInfos Build a string with the wanted client certificates informations
|
||||
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
|
||||
func (s *TLSClientHeaders) getXForwardedTLSClientCertInfos(certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
var values []string
|
||||
var sans string
|
||||
var nb string
|
||||
var na string
|
||||
|
||||
subject := s.getSubjectInfos(&peerCert.Subject)
|
||||
if len(subject) > 0 {
|
||||
values = append(values, subject)
|
||||
}
|
||||
|
||||
ci := s.Infos
|
||||
if ci != nil {
|
||||
if ci.NotBefore {
|
||||
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
|
||||
values = append(values, nb)
|
||||
}
|
||||
if ci.NotAfter {
|
||||
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
|
||||
values = append(values, na)
|
||||
}
|
||||
|
||||
if ci.Sans {
|
||||
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
|
||||
values = append(values, sans)
|
||||
}
|
||||
}
|
||||
|
||||
value := strings.Join(values, ",")
|
||||
headerValues = append(headerValues, value)
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ";")
|
||||
}
|
||||
|
||||
// ModifyRequestHeaders set the wanted headers with the certificates informations
|
||||
func (s *TLSClientHeaders) ModifyRequestHeaders(r *http.Request) {
|
||||
if s.PEM {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(r.TLS.PeerCertificates))
|
||||
} else {
|
||||
log.Warn("Try to extract certificate on a request without TLS")
|
||||
}
|
||||
}
|
||||
|
||||
if s.Infos != nil {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
headerContent := s.getXForwardedTLSClientCertInfos(r.TLS.PeerCertificates)
|
||||
r.Header.Set(xForwardedTLSClientCertInfos, url.QueryEscape(headerContent))
|
||||
} else {
|
||||
log.Warn("Try to extract certificate on a request without TLS")
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package tracing
|
||||
|
||||
import "net/http"
|
||||
|
||||
// HTTPHeadersCarrier custom implementation to fix duplicated headers
|
||||
// It has been fixed in https://github.com/opentracing/opentracing-go/pull/191
|
||||
type HTTPHeadersCarrier http.Header
|
||||
|
||||
// Set conforms to the TextMapWriter interface.
|
||||
func (c HTTPHeadersCarrier) Set(key, val string) {
|
||||
h := http.Header(c)
|
||||
h.Set(key, val)
|
||||
}
|
||||
|
||||
// ForeachKey conforms to the TextMapReader interface.
|
||||
func (c HTTPHeadersCarrier) ForeachKey(handler func(key, val string) error) error {
|
||||
for k, vals := range c {
|
||||
for _, v := range vals {
|
||||
if err := handler(k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,45 +0,0 @@
|
|||
package datadog
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
ddtracer "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer"
|
||||
datadog "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
// Name sets the name of this tracer
|
||||
const Name = "datadog"
|
||||
|
||||
// Config provides configuration settings for a datadog tracer
|
||||
type Config struct {
|
||||
LocalAgentHostPort string `description:"Set datadog-agent's host:port that the reporter will used. Defaults to localhost:8126" export:"false"`
|
||||
GlobalTag string `description:"Key:Value tag to be set on all the spans." export:"true"`
|
||||
Debug bool `description:"Enable DataDog debug." export:"true"`
|
||||
}
|
||||
|
||||
// Setup sets up the tracer
|
||||
func (c *Config) Setup(serviceName string) (opentracing.Tracer, io.Closer, error) {
|
||||
tag := strings.SplitN(c.GlobalTag, ":", 2)
|
||||
|
||||
value := ""
|
||||
if len(tag) == 2 {
|
||||
value = tag[1]
|
||||
}
|
||||
|
||||
tracer := ddtracer.New(
|
||||
datadog.WithAgentAddr(c.LocalAgentHostPort),
|
||||
datadog.WithServiceName(serviceName),
|
||||
datadog.WithGlobalTag(tag[0], value),
|
||||
datadog.WithDebugMode(c.Debug),
|
||||
)
|
||||
|
||||
// Without this, child spans are getting the NOOP tracer
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
|
||||
log.Debug("DataDog tracer configured")
|
||||
|
||||
return tracer, nil, nil
|
||||
}
|
|
@ -1,57 +1,57 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
type entryPointMiddleware struct {
|
||||
entryPoint string
|
||||
*Tracing
|
||||
}
|
||||
const (
|
||||
entryPointTypeName = "TracingEntryPoint"
|
||||
)
|
||||
|
||||
// NewEntryPoint creates a new middleware that the incoming request
|
||||
func (t *Tracing) NewEntryPoint(name string) negroni.Handler {
|
||||
log.Debug("Added entrypoint tracing middleware")
|
||||
return &entryPointMiddleware{Tracing: t, entryPoint: name}
|
||||
}
|
||||
// NewEntryPoint creates a new middleware that the incoming request.
|
||||
func NewEntryPoint(ctx context.Context, t *tracing.Tracing, entryPointName string, next http.Handler) http.Handler {
|
||||
middlewares.GetLogger(ctx, "tracing", entryPointTypeName).Debug("Creating middleware")
|
||||
|
||||
func (e *entryPointMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
opNameFunc := generateEntryPointSpanName
|
||||
|
||||
ctx, _ := e.Extract(opentracing.HTTPHeaders, HTTPHeadersCarrier(r.Header))
|
||||
span := e.StartSpan(opNameFunc(r, e.entryPoint, e.SpanNameLimit), ext.RPCServerOption(ctx))
|
||||
ext.Component.Set(span, e.ServiceName)
|
||||
LogRequest(span, r)
|
||||
ext.SpanKindRPCServer.Set(span)
|
||||
|
||||
r = r.WithContext(opentracing.ContextWithSpan(r.Context(), span))
|
||||
|
||||
recorder := newStatusCodeRecoder(w, 200)
|
||||
next(recorder, r)
|
||||
|
||||
LogResponseCode(span, recorder.Status())
|
||||
span.Finish()
|
||||
}
|
||||
|
||||
// generateEntryPointSpanName will return a Span name of an appropriate lenth based on the 'spanLimit' argument. If needed, it will be truncated, but will not be less than 24 characters.
|
||||
func generateEntryPointSpanName(r *http.Request, entryPoint string, spanLimit int) string {
|
||||
name := fmt.Sprintf("Entrypoint %s %s", entryPoint, r.Host)
|
||||
|
||||
if spanLimit > 0 && len(name) > spanLimit {
|
||||
if spanLimit < EntryPointMaxLengthNumber {
|
||||
log.Warnf("SpanNameLimit is set to be less than required static number of characters, defaulting to %d + 3", EntryPointMaxLengthNumber)
|
||||
spanLimit = EntryPointMaxLengthNumber + 3
|
||||
}
|
||||
hash := computeHash(name)
|
||||
limit := (spanLimit - EntryPointMaxLengthNumber) / 2
|
||||
name = fmt.Sprintf("Entrypoint %s %s %s", truncateString(entryPoint, limit), truncateString(r.Host, limit), hash)
|
||||
return &entryPointMiddleware{
|
||||
entryPoint: entryPointName,
|
||||
Tracing: t,
|
||||
next: next,
|
||||
}
|
||||
}
|
||||
|
||||
type entryPointMiddleware struct {
|
||||
*tracing.Tracing
|
||||
entryPoint string
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
func (e *entryPointMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
spanCtx, _ := e.Extract(opentracing.HTTPHeaders, tracing.HTTPHeadersCarrier(req.Header))
|
||||
|
||||
span, req, finish := e.StartSpanf(req, ext.SpanKindRPCServerEnum, "EntryPoint", []string{e.entryPoint, req.Host}, " ", ext.RPCServerOption(spanCtx))
|
||||
defer finish()
|
||||
|
||||
ext.Component.Set(span, e.ServiceName)
|
||||
tracing.LogRequest(span, req)
|
||||
|
||||
req = req.WithContext(tracing.WithTracing(req.Context(), e.Tracing))
|
||||
|
||||
recorder := newStatusCodeRecoder(rw, http.StatusOK)
|
||||
e.next.ServeHTTP(recorder, req)
|
||||
|
||||
tracing.LogResponseCode(span, recorder.Status())
|
||||
}
|
||||
|
||||
// WrapEntryPointHandler Wraps tracing to alice.Constructor.
|
||||
func WrapEntryPointHandler(ctx context.Context, tracer *tracing.Tracing, entryPointName string) alice.Constructor {
|
||||
return func(next http.Handler) (http.Handler, error) {
|
||||
return NewEntryPoint(ctx, tracer, entryPointName, next), nil
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
|
|
@ -1,70 +1,87 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEntryPointMiddlewareServeHTTP(t *testing.T) {
|
||||
expectedTags := map[string]interface{}{
|
||||
"span.kind": ext.SpanKindRPCServerEnum,
|
||||
"http.method": "GET",
|
||||
"component": "",
|
||||
"http.url": "http://www.test.com",
|
||||
"http.host": "www.test.com",
|
||||
func TestEntryPointMiddleware(t *testing.T) {
|
||||
type expected struct {
|
||||
Tags map[string]interface{}
|
||||
OperationName string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
entryPoint string
|
||||
tracing *Tracing
|
||||
expectedTags map[string]interface{}
|
||||
expectedName string
|
||||
desc string
|
||||
entryPoint string
|
||||
spanNameLimit int
|
||||
tracing *trackingBackenMock
|
||||
expected expected
|
||||
}{
|
||||
{
|
||||
desc: "no truncation test",
|
||||
entryPoint: "test",
|
||||
tracing: &Tracing{
|
||||
SpanNameLimit: 0,
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
desc: "no truncation test",
|
||||
entryPoint: "test",
|
||||
spanNameLimit: 0,
|
||||
tracing: &trackingBackenMock{
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
},
|
||||
expectedTags: expectedTags,
|
||||
expectedName: "Entrypoint test www.test.com",
|
||||
}, {
|
||||
desc: "basic test",
|
||||
entryPoint: "test",
|
||||
tracing: &Tracing{
|
||||
SpanNameLimit: 25,
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
expected: expected{
|
||||
Tags: map[string]interface{}{
|
||||
"span.kind": ext.SpanKindRPCServerEnum,
|
||||
"http.method": http.MethodGet,
|
||||
"component": "",
|
||||
"http.url": "http://www.test.com",
|
||||
"http.host": "www.test.com",
|
||||
},
|
||||
OperationName: "EntryPoint test www.test.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "basic test",
|
||||
entryPoint: "test",
|
||||
spanNameLimit: 25,
|
||||
tracing: &trackingBackenMock{
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
},
|
||||
expected: expected{
|
||||
Tags: map[string]interface{}{
|
||||
"span.kind": ext.SpanKindRPCServerEnum,
|
||||
"http.method": http.MethodGet,
|
||||
"component": "",
|
||||
"http.url": "http://www.test.com",
|
||||
"http.host": "www.test.com",
|
||||
},
|
||||
OperationName: "EntryPoint te... ww... 0c15301b",
|
||||
},
|
||||
expectedTags: expectedTags,
|
||||
expectedName: "Entrypoint te... ww... 39b97e58",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := &entryPointMiddleware{
|
||||
entryPoint: test.entryPoint,
|
||||
Tracing: test.tracing,
|
||||
}
|
||||
newTracing, err := tracing.NewTracing("", test.spanNameLimit, test.tracing)
|
||||
require.NoError(t, err)
|
||||
|
||||
next := func(http.ResponseWriter, *http.Request) {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://www.test.com", nil)
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
span := test.tracing.tracer.(*MockTracer).Span
|
||||
|
||||
actual := span.Tags
|
||||
assert.Equal(t, test.expectedTags, actual)
|
||||
assert.Equal(t, test.expectedName, span.OpName)
|
||||
}
|
||||
tags := span.Tags
|
||||
assert.Equal(t, test.expected.Tags, tags)
|
||||
assert.Equal(t, test.expected.OperationName, span.OpName)
|
||||
})
|
||||
|
||||
e.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "http://www.test.com", nil), next)
|
||||
handler := NewEntryPoint(context.Background(), newTracing, test.entryPoint, next)
|
||||
handler.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,63 +1,58 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/middlewares"
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
const (
|
||||
forwarderTypeName = "TracingForwarder"
|
||||
)
|
||||
|
||||
type forwarderMiddleware struct {
|
||||
frontend string
|
||||
backend string
|
||||
opName string
|
||||
*Tracing
|
||||
router string
|
||||
service string
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// NewForwarderMiddleware creates a new forwarder middleware that traces the outgoing request
|
||||
func (t *Tracing) NewForwarderMiddleware(frontend, backend string) negroni.Handler {
|
||||
log.Debugf("Added outgoing tracing middleware %s", frontend)
|
||||
// NewForwarder creates a new forwarder middleware that traces the outgoing request.
|
||||
func NewForwarder(ctx context.Context, router, service string, next http.Handler) http.Handler {
|
||||
middlewares.GetLogger(ctx, "tracing", forwarderTypeName).
|
||||
Debugf("Added outgoing tracing middleware %s", service)
|
||||
|
||||
return &forwarderMiddleware{
|
||||
Tracing: t,
|
||||
frontend: frontend,
|
||||
backend: backend,
|
||||
opName: generateForwardSpanName(frontend, backend, t.SpanNameLimit),
|
||||
router: router,
|
||||
service: service,
|
||||
next: next,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarderMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
span, r, finish := StartSpan(r, f.opName, true)
|
||||
func (f *forwarderMiddleware) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
tr, err := tracing.FromContext(req.Context())
|
||||
if err != nil {
|
||||
f.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
opParts := []string{f.service, f.router}
|
||||
span, req, finish := tr.StartSpanf(req, ext.SpanKindRPCClientEnum, "forward", opParts, "/")
|
||||
defer finish()
|
||||
span.SetTag("frontend.name", f.frontend)
|
||||
span.SetTag("backend.name", f.backend)
|
||||
ext.HTTPMethod.Set(span, r.Method)
|
||||
ext.HTTPUrl.Set(span, fmt.Sprintf("%s%s", r.URL.String(), r.RequestURI))
|
||||
span.SetTag("http.host", r.Host)
|
||||
|
||||
InjectRequestHeaders(r)
|
||||
span.SetTag("service.name", f.service)
|
||||
span.SetTag("router.name", f.router)
|
||||
ext.HTTPMethod.Set(span, req.Method)
|
||||
ext.HTTPUrl.Set(span, req.URL.String())
|
||||
span.SetTag("http.host", req.Host)
|
||||
|
||||
recorder := newStatusCodeRecoder(w, 200)
|
||||
tracing.InjectRequestHeaders(req)
|
||||
|
||||
next(recorder, r)
|
||||
recorder := newStatusCodeRecoder(rw, 200)
|
||||
|
||||
LogResponseCode(span, recorder.Status())
|
||||
}
|
||||
|
||||
// generateForwardSpanName will return a Span name of an appropriate lenth based on the 'spanLimit' argument. If needed, it will be truncated, but will not be less than 21 characters
|
||||
func generateForwardSpanName(frontend, backend string, spanLimit int) string {
|
||||
name := fmt.Sprintf("forward %s/%s", frontend, backend)
|
||||
|
||||
if spanLimit > 0 && len(name) > spanLimit {
|
||||
if spanLimit < ForwardMaxLengthNumber {
|
||||
log.Warnf("SpanNameLimit is set to be less than required static number of characters, defaulting to %d + 3", ForwardMaxLengthNumber)
|
||||
spanLimit = ForwardMaxLengthNumber + 3
|
||||
}
|
||||
hash := computeHash(name)
|
||||
limit := (spanLimit - ForwardMaxLengthNumber) / 2
|
||||
name = fmt.Sprintf("forward %s/%s/%s", truncateString(frontend, limit), truncateString(backend, limit), hash)
|
||||
}
|
||||
|
||||
return name
|
||||
f.next.ServeHTTP(recorder, req)
|
||||
|
||||
tracing.LogResponseCode(span, recorder.Status())
|
||||
}
|
||||
|
|
|
@ -1,93 +1,136 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTracingNewForwarderMiddleware(t *testing.T) {
|
||||
func TestNewForwarder(t *testing.T) {
|
||||
type expected struct {
|
||||
Tags map[string]interface{}
|
||||
OperationName string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
tracer *Tracing
|
||||
frontend string
|
||||
backend string
|
||||
expected *forwarderMiddleware
|
||||
desc string
|
||||
spanNameLimit int
|
||||
tracing *trackingBackenMock
|
||||
service string
|
||||
router string
|
||||
expected expected
|
||||
}{
|
||||
{
|
||||
desc: "Simple Forward Tracer without truncation and hashing",
|
||||
tracer: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
desc: "Simple Forward Tracer without truncation and hashing",
|
||||
spanNameLimit: 101,
|
||||
tracing: &trackingBackenMock{
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
},
|
||||
frontend: "some-service.domain.tld",
|
||||
backend: "some-service.domain.tld",
|
||||
expected: &forwarderMiddleware{
|
||||
Tracing: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
service: "some-service.domain.tld",
|
||||
router: "some-service.domain.tld",
|
||||
expected: expected{
|
||||
Tags: map[string]interface{}{
|
||||
"http.host": "www.test.com",
|
||||
"http.method": "GET",
|
||||
"http.url": "http://www.test.com/toto",
|
||||
"service.name": "some-service.domain.tld",
|
||||
"router.name": "some-service.domain.tld",
|
||||
"span.kind": ext.SpanKindRPCClientEnum,
|
||||
},
|
||||
frontend: "some-service.domain.tld",
|
||||
backend: "some-service.domain.tld",
|
||||
opName: "forward some-service.domain.tld/some-service.domain.tld",
|
||||
OperationName: "forward some-service.domain.tld/some-service.domain.tld",
|
||||
},
|
||||
}, {
|
||||
desc: "Simple Forward Tracer with truncation and hashing",
|
||||
tracer: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
desc: "Simple Forward Tracer with truncation and hashing",
|
||||
spanNameLimit: 101,
|
||||
tracing: &trackingBackenMock{
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
},
|
||||
frontend: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
backend: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
expected: &forwarderMiddleware{
|
||||
Tracing: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
service: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
router: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
expected: expected{
|
||||
Tags: map[string]interface{}{
|
||||
"http.host": "www.test.com",
|
||||
"http.method": "GET",
|
||||
"http.url": "http://www.test.com/toto",
|
||||
"service.name": "some-service-100.slug.namespace.environment.domain.tld",
|
||||
"router.name": "some-service-100.slug.namespace.environment.domain.tld",
|
||||
"span.kind": ext.SpanKindRPCClientEnum,
|
||||
},
|
||||
frontend: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
backend: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
opName: "forward some-service-100.slug.namespace.enviro.../some-service-100.slug.namespace.enviro.../bc4a0d48",
|
||||
OperationName: "forward some-service-100.slug.namespace.enviro.../some-service-100.slug.namespace.enviro.../bc4a0d48",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Exactly 101 chars",
|
||||
tracer: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
desc: "Exactly 101 chars",
|
||||
spanNameLimit: 101,
|
||||
tracing: &trackingBackenMock{
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
},
|
||||
frontend: "some-service1.namespace.environment.domain.tld",
|
||||
backend: "some-service1.namespace.environment.domain.tld",
|
||||
expected: &forwarderMiddleware{
|
||||
Tracing: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
service: "some-service1.namespace.environment.domain.tld",
|
||||
router: "some-service1.namespace.environment.domain.tld",
|
||||
expected: expected{
|
||||
Tags: map[string]interface{}{
|
||||
"http.host": "www.test.com",
|
||||
"http.method": "GET",
|
||||
"http.url": "http://www.test.com/toto",
|
||||
"service.name": "some-service1.namespace.environment.domain.tld",
|
||||
"router.name": "some-service1.namespace.environment.domain.tld",
|
||||
"span.kind": ext.SpanKindRPCClientEnum,
|
||||
},
|
||||
frontend: "some-service1.namespace.environment.domain.tld",
|
||||
backend: "some-service1.namespace.environment.domain.tld",
|
||||
opName: "forward some-service1.namespace.environment.domain.tld/some-service1.namespace.environment.domain.tld",
|
||||
OperationName: "forward some-service1.namespace.environment.domain.tld/some-service1.namespace.environment.domain.tld",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "More than 101 chars",
|
||||
tracer: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
desc: "More than 101 chars",
|
||||
spanNameLimit: 101,
|
||||
tracing: &trackingBackenMock{
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
},
|
||||
frontend: "some-service1.frontend.namespace.environment.domain.tld",
|
||||
backend: "some-service1.backend.namespace.environment.domain.tld",
|
||||
expected: &forwarderMiddleware{
|
||||
Tracing: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
service: "some-service1.frontend.namespace.environment.domain.tld",
|
||||
router: "some-service1.backend.namespace.environment.domain.tld",
|
||||
expected: expected{
|
||||
Tags: map[string]interface{}{
|
||||
"http.host": "www.test.com",
|
||||
"http.method": "GET",
|
||||
"http.url": "http://www.test.com/toto",
|
||||
"service.name": "some-service1.frontend.namespace.environment.domain.tld",
|
||||
"router.name": "some-service1.backend.namespace.environment.domain.tld",
|
||||
"span.kind": ext.SpanKindRPCClientEnum,
|
||||
},
|
||||
frontend: "some-service1.frontend.namespace.environment.domain.tld",
|
||||
backend: "some-service1.backend.namespace.environment.domain.tld",
|
||||
opName: "forward some-service1.frontend.namespace.envir.../some-service1.backend.namespace.enviro.../fa49dd23",
|
||||
OperationName: "forward some-service1.frontend.namespace.envir.../some-service1.backend.namespace.enviro.../fa49dd23",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := test.tracer.NewForwarderMiddleware(test.frontend, test.backend)
|
||||
newTracing, err := tracing.NewTracing("", test.spanNameLimit, test.tracing)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
assert.True(t, len(test.expected.opName) <= test.tracer.SpanNameLimit)
|
||||
req := httptest.NewRequest(http.MethodGet, "http://www.test.com/toto", nil)
|
||||
req = req.WithContext(tracing.WithTracing(req.Context(), newTracing))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
next := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
|
||||
span := test.tracing.tracer.(*MockTracer).Span
|
||||
|
||||
tags := span.Tags
|
||||
assert.Equal(t, test.expected.Tags, tags)
|
||||
assert.True(t, len(test.expected.OperationName) <= test.spanNameLimit,
|
||||
"the len of the operation name %q [len: %d] doesn't respect limit %d",
|
||||
test.expected.OperationName, len(test.expected.OperationName), test.spanNameLimit)
|
||||
assert.Equal(t, test.expected.OperationName, span.OpName)
|
||||
})
|
||||
|
||||
handler := NewForwarder(context.Background(), test.router, test.service, next)
|
||||
handler.ServeHTTP(rw, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
package jaeger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
jaegercfg "github.com/uber/jaeger-client-go/config"
|
||||
"github.com/uber/jaeger-client-go/zipkin"
|
||||
jaegermet "github.com/uber/jaeger-lib/metrics"
|
||||
)
|
||||
|
||||
// Name sets the name of this tracer
|
||||
const Name = "jaeger"
|
||||
|
||||
// Config provides configuration settings for a jaeger tracer
|
||||
type Config struct {
|
||||
SamplingServerURL string `description:"set the sampling server url." export:"false"`
|
||||
SamplingType string `description:"set the sampling type." export:"true"`
|
||||
SamplingParam float64 `description:"set the sampling parameter." export:"true"`
|
||||
LocalAgentHostPort string `description:"set jaeger-agent's host:port that the reporter will used." export:"false"`
|
||||
Gen128Bit bool `description:"generate 128 bit span IDs." export:"true"`
|
||||
Propagation string `description:"which propgation format to use (jaeger/b3)." export:"true"`
|
||||
}
|
||||
|
||||
// Setup sets up the tracer
|
||||
func (c *Config) Setup(componentName string) (opentracing.Tracer, io.Closer, error) {
|
||||
jcfg := jaegercfg.Configuration{
|
||||
Sampler: &jaegercfg.SamplerConfig{
|
||||
SamplingServerURL: c.SamplingServerURL,
|
||||
Type: c.SamplingType,
|
||||
Param: c.SamplingParam,
|
||||
},
|
||||
Reporter: &jaegercfg.ReporterConfig{
|
||||
LogSpans: true,
|
||||
LocalAgentHostPort: c.LocalAgentHostPort,
|
||||
},
|
||||
}
|
||||
|
||||
jMetricsFactory := jaegermet.NullFactory
|
||||
|
||||
opts := []jaegercfg.Option{
|
||||
jaegercfg.Logger(&jaegerLogger{}),
|
||||
jaegercfg.Metrics(jMetricsFactory),
|
||||
jaegercfg.Gen128Bit(c.Gen128Bit),
|
||||
}
|
||||
|
||||
switch c.Propagation {
|
||||
case "b3":
|
||||
p := zipkin.NewZipkinB3HTTPHeaderPropagator()
|
||||
opts = append(opts,
|
||||
jaegercfg.Injector(opentracing.HTTPHeaders, p),
|
||||
jaegercfg.Extractor(opentracing.HTTPHeaders, p),
|
||||
)
|
||||
case "jaeger", "":
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unknown propagation format: %s", c.Propagation)
|
||||
}
|
||||
|
||||
// Initialize tracer with a logger and a metrics factory
|
||||
closer, err := jcfg.InitGlobalTracer(
|
||||
componentName,
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warnf("Could not initialize jaeger tracer: %s", err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
log.Debug("Jaeger tracer configured")
|
||||
|
||||
return opentracing.GlobalTracer(), closer, nil
|
||||
}
|
|
@ -1,15 +0,0 @@
|
|||
package jaeger
|
||||
|
||||
import "github.com/containous/traefik/log"
|
||||
|
||||
// jaegerLogger is an implementation of the Logger interface that delegates to traefik log
|
||||
type jaegerLogger struct{}
|
||||
|
||||
func (l *jaegerLogger) Error(msg string) {
|
||||
log.Errorf("Tracing jaeger error: %s", msg)
|
||||
}
|
||||
|
||||
// Infof logs a message at debug priority
|
||||
func (l *jaegerLogger) Infof(msg string, args ...interface{}) {
|
||||
log.Debugf(msg, args...)
|
||||
}
|
|
@ -1,29 +1,43 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"io"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type MockTracer struct {
|
||||
Span *MockSpan
|
||||
}
|
||||
|
||||
// StartSpan belongs to the Tracer interface.
|
||||
func (n MockTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
|
||||
n.Span.OpName = operationName
|
||||
return n.Span
|
||||
}
|
||||
|
||||
// Inject belongs to the Tracer interface.
|
||||
func (n MockTracer) Inject(sp opentracing.SpanContext, format interface{}, carrier interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract belongs to the Tracer interface.
|
||||
func (n MockTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) {
|
||||
return nil, opentracing.ErrSpanContextNotFound
|
||||
}
|
||||
|
||||
// MockSpanContext
|
||||
type MockSpanContext struct{}
|
||||
|
||||
func (n MockSpanContext) ForeachBaggageItem(handler func(k, v string) bool) {}
|
||||
|
||||
// MockSpan
|
||||
type MockSpan struct {
|
||||
OpName string
|
||||
Tags map[string]interface{}
|
||||
}
|
||||
|
||||
type MockSpanContext struct {
|
||||
}
|
||||
|
||||
// MockSpanContext:
|
||||
func (n MockSpanContext) ForeachBaggageItem(handler func(k, v string) bool) {}
|
||||
|
||||
// MockSpan:
|
||||
func (n MockSpan) Context() opentracing.SpanContext { return MockSpanContext{} }
|
||||
func (n MockSpan) SetBaggageItem(key, val string) opentracing.Span {
|
||||
return MockSpan{Tags: make(map[string]interface{})}
|
||||
|
@ -46,88 +60,11 @@ func (n MockSpan) Reset() {
|
|||
n.Tags = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// StartSpan belongs to the Tracer interface.
|
||||
func (n MockTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
|
||||
n.Span.OpName = operationName
|
||||
return n.Span
|
||||
type trackingBackenMock struct {
|
||||
tracer opentracing.Tracer
|
||||
}
|
||||
|
||||
// Inject belongs to the Tracer interface.
|
||||
func (n MockTracer) Inject(sp opentracing.SpanContext, format interface{}, carrier interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract belongs to the Tracer interface.
|
||||
func (n MockTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) {
|
||||
return nil, opentracing.ErrSpanContextNotFound
|
||||
}
|
||||
|
||||
func TestTruncateString(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
text string
|
||||
limit int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "short text less than limit 10",
|
||||
text: "short",
|
||||
limit: 10,
|
||||
expected: "short",
|
||||
},
|
||||
{
|
||||
desc: "basic truncate with limit 10",
|
||||
text: "some very long pice of text",
|
||||
limit: 10,
|
||||
expected: "some ve...",
|
||||
},
|
||||
{
|
||||
desc: "truncate long FQDN to 39 chars",
|
||||
text: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
limit: 39,
|
||||
expected: "some-service-100.slug.namespace.envi...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := truncateString(test.text, test.limit)
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
assert.True(t, len(actual) <= test.limit)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeHash(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
text string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "hashing",
|
||||
text: "some very long pice of text",
|
||||
expected: "0258ea1c",
|
||||
},
|
||||
{
|
||||
desc: "short text less than limit 10",
|
||||
text: "short",
|
||||
expected: "f9b0078b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := computeHash(test.text)
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
func (t *trackingBackenMock) Setup(componentName string) (opentracing.Tracer, io.Closer, error) {
|
||||
opentracing.SetGlobalTracer(t.tracer)
|
||||
return t.tracer, nil, nil
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue