1
0
Fork 0

Move code to pkg

This commit is contained in:
Ludovic Fernandez 2019-03-15 09:42:03 +01:00 committed by Traefiker Bot
parent bd4c822670
commit f1b085fa36
465 changed files with 656 additions and 680 deletions

136
pkg/anonymize/anonymize.go Normal file
View file

@ -0,0 +1,136 @@
package anonymize
import (
"encoding/json"
"fmt"
"reflect"
"regexp"
"github.com/mitchellh/copystructure"
"github.com/mvdan/xurls"
)
const (
maskShort = "xxxx"
maskLarge = maskShort + maskShort + maskShort + maskShort + maskShort + maskShort + maskShort + maskShort
)
// Do configuration.
func Do(baseConfig interface{}, indent bool) (string, error) {
anomConfig, err := copystructure.Copy(baseConfig)
if err != nil {
return "", err
}
val := reflect.ValueOf(anomConfig)
err = doOnStruct(val)
if err != nil {
return "", err
}
configJSON, err := marshal(anomConfig, indent)
if err != nil {
return "", err
}
return doOnJSON(string(configJSON)), nil
}
func doOnJSON(input string) string {
mailExp := regexp.MustCompile(`\w[-._\w]*\w@\w[-._\w]*\w\.\w{2,3}"`)
return xurls.Relaxed.ReplaceAllString(mailExp.ReplaceAllString(input, maskLarge+"\""), maskLarge)
}
func doOnStruct(field reflect.Value) error {
switch field.Kind() {
case reflect.Ptr:
if !field.IsNil() {
if err := doOnStruct(field.Elem()); err != nil {
return err
}
}
case reflect.Struct:
for i := 0; i < field.NumField(); i++ {
fld := field.Field(i)
stField := field.Type().Field(i)
if !isExported(stField) {
continue
}
if stField.Tag.Get("export") == "true" {
if err := doOnStruct(fld); err != nil {
return err
}
} else {
if err := reset(fld, stField.Name); err != nil {
return err
}
}
}
case reflect.Map:
for _, key := range field.MapKeys() {
if err := doOnStruct(field.MapIndex(key)); err != nil {
return err
}
}
case reflect.Slice:
for j := 0; j < field.Len(); j++ {
if err := doOnStruct(field.Index(j)); err != nil {
return err
}
}
}
return nil
}
func reset(field reflect.Value, name string) error {
if !field.CanSet() {
return fmt.Errorf("cannot reset field %s", name)
}
switch field.Kind() {
case reflect.Ptr:
if !field.IsNil() {
field.Set(reflect.Zero(field.Type()))
}
case reflect.Struct:
if field.IsValid() {
field.Set(reflect.Zero(field.Type()))
}
case reflect.String:
if field.String() != "" {
field.Set(reflect.ValueOf(maskShort))
}
case reflect.Map:
if field.Len() > 0 {
field.Set(reflect.MakeMap(field.Type()))
}
case reflect.Slice:
if field.Len() > 0 {
field.Set(reflect.MakeSlice(field.Type(), 0, 0))
}
case reflect.Interface:
if !field.IsNil() {
return reset(field.Elem(), "")
}
default:
// Primitive type
field.Set(reflect.Zero(field.Type()))
}
return nil
}
// isExported return true is a struct field is exported, else false
func isExported(f reflect.StructField) bool {
if f.PkgPath != "" && !f.Anonymous {
return false
}
return true
}
func marshal(anomConfig interface{}, indent bool) ([]byte, error) {
if indent {
return json.MarshalIndent(anomConfig, "", " ")
}
return json.Marshal(anomConfig)
}

View file

@ -0,0 +1,329 @@
package anonymize
import (
"os"
"testing"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/config/static"
"github.com/containous/traefik/pkg/ping"
"github.com/containous/traefik/pkg/provider"
"github.com/containous/traefik/pkg/provider/acme"
acmeprovider "github.com/containous/traefik/pkg/provider/acme"
"github.com/containous/traefik/pkg/provider/docker"
"github.com/containous/traefik/pkg/provider/file"
"github.com/containous/traefik/pkg/provider/kubernetes/crd"
"github.com/containous/traefik/pkg/provider/kubernetes/ingress"
traefiktls "github.com/containous/traefik/pkg/tls"
"github.com/containous/traefik/pkg/tracing/datadog"
"github.com/containous/traefik/pkg/tracing/instana"
"github.com/containous/traefik/pkg/tracing/jaeger"
"github.com/containous/traefik/pkg/tracing/zipkin"
"github.com/containous/traefik/pkg/types"
assetfs "github.com/elazarl/go-bindata-assetfs"
)
func TestDo_globalConfiguration(t *testing.T) {
config := &static.Configuration{}
sendAnonymousUsage := true
config.Global = &static.Global{
Debug: true,
CheckNewVersion: true,
SendAnonymousUsage: &sendAnonymousUsage,
}
config.AccessLog = &types.AccessLog{
FilePath: "AccessLog FilePath",
Format: "AccessLog Format",
Filters: &types.AccessLogFilters{
StatusCodes: types.StatusCodes{"200", "500"},
RetryAttempts: true,
MinDuration: 10,
},
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
"RequestHost": "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "drop",
Names: types.FieldHeaderNames{
"Referer": "keep",
},
},
},
BufferingSize: 4,
}
config.Log = &types.TraefikLog{
LogLevel: "LogLevel",
FilePath: "/foo/path",
Format: "json",
}
config.EntryPoints = static.EntryPoints{
"foo": {
Address: "foo Address",
Transport: &static.EntryPointsTransport{
RespondingTimeouts: &static.RespondingTimeouts{
ReadTimeout: parse.Duration(111 * time.Second),
WriteTimeout: parse.Duration(111 * time.Second),
IdleTimeout: parse.Duration(111 * time.Second),
},
},
ProxyProtocol: &static.ProxyProtocol{
TrustedIPs: []string{"127.0.0.1/32", "192.168.0.1"},
},
},
"fii": {
Address: "fii Address",
Transport: &static.EntryPointsTransport{
RespondingTimeouts: &static.RespondingTimeouts{
ReadTimeout: parse.Duration(111 * time.Second),
WriteTimeout: parse.Duration(111 * time.Second),
IdleTimeout: parse.Duration(111 * time.Second),
},
},
ProxyProtocol: &static.ProxyProtocol{
TrustedIPs: []string{"127.0.0.1/32", "192.168.0.1"},
},
},
}
config.ACME = &acme.Configuration{
Email: "acme Email",
ACMELogging: true,
CAServer: "CAServer",
Storage: "Storage",
EntryPoint: "EntryPoint",
KeyType: "MyKeyType",
OnHostRule: true,
DNSChallenge: &acmeprovider.DNSChallenge{Provider: "DNSProvider"},
HTTPChallenge: &acmeprovider.HTTPChallenge{
EntryPoint: "MyEntryPoint",
},
TLSChallenge: &acmeprovider.TLSChallenge{},
Domains: []types.Domain{
{
Main: "Domains Main",
SANs: []string{"Domains acme SANs 1", "Domains acme SANs 2", "Domains acme SANs 3"},
},
},
}
config.Providers = &static.Providers{
ProvidersThrottleDuration: parse.Duration(111 * time.Second),
}
config.ServersTransport = &static.ServersTransport{
InsecureSkipVerify: true,
RootCAs: traefiktls.FilesOrContents{"RootCAs 1", "RootCAs 2", "RootCAs 3"},
MaxIdleConnsPerHost: 111,
ForwardingTimeouts: &static.ForwardingTimeouts{
DialTimeout: parse.Duration(111 * time.Second),
ResponseHeaderTimeout: parse.Duration(111 * time.Second),
},
}
config.API = &static.API{
EntryPoint: "traefik",
Dashboard: true,
Statistics: &types.Statistics{
RecentErrors: 111,
},
DashboardAssets: &assetfs.AssetFS{
Asset: func(path string) ([]byte, error) {
return nil, nil
},
AssetDir: func(path string) ([]string, error) {
return nil, nil
},
AssetInfo: func(path string) (os.FileInfo, error) {
return nil, nil
},
Prefix: "fii",
},
Middlewares: []string{"first", "second"},
}
config.Providers.File = &file.Provider{
BaseProvider: provider.BaseProvider{
Watch: true,
Filename: "file Filename",
Constraints: types.Constraints{
{
Key: "file Constraints Key 1",
Regex: "file Constraints Regex 2",
MustMatch: true,
},
{
Key: "file Constraints Key 1",
Regex: "file Constraints Regex 2",
MustMatch: true,
},
},
Trace: true,
DebugLogGeneratedTemplate: true,
},
Directory: "file Directory",
}
config.Providers.Docker = &docker.Provider{
BaseProvider: provider.BaseProvider{
Watch: true,
Filename: "myfilename",
Constraints: nil,
Trace: true,
DebugLogGeneratedTemplate: true,
},
Endpoint: "MyEndPoint",
DefaultRule: "PathPrefix(`/`)",
TLS: &types.ClientTLS{
CA: "myCa",
CAOptional: true,
Cert: "mycert.pem",
Key: "mycert.key",
InsecureSkipVerify: true,
},
ExposedByDefault: true,
UseBindPortIP: true,
SwarmMode: true,
Network: "MyNetwork",
SwarmModeRefreshSeconds: 42,
}
config.Providers.Kubernetes = &ingress.Provider{
BaseProvider: provider.BaseProvider{
Watch: true,
Filename: "myFileName",
Constraints: types.Constraints{
{
Key: "k8s Constraints Key 1",
Regex: "k8s Constraints Regex 2",
MustMatch: true,
},
{
Key: "k8s Constraints Key 1",
Regex: "k8s Constraints Regex 2",
MustMatch: true,
},
},
Trace: true,
DebugLogGeneratedTemplate: true,
},
Endpoint: "MyEndpoint",
Token: "MyToken",
CertAuthFilePath: "MyCertAuthPath",
DisablePassHostHeaders: true,
EnablePassTLSCert: true,
Namespaces: []string{"a", "b"},
LabelSelector: "myLabelSelector",
IngressClass: "MyIngressClass",
}
config.Providers.KubernetesCRD = &crd.Provider{
BaseProvider: provider.BaseProvider{
Watch: true,
Filename: "myFileName",
Constraints: types.Constraints{
{
Key: "k8s Constraints Key 1",
Regex: "k8s Constraints Regex 2",
MustMatch: true,
},
{
Key: "k8s Constraints Key 1",
Regex: "k8s Constraints Regex 2",
MustMatch: true,
},
},
Trace: true,
DebugLogGeneratedTemplate: true,
},
Endpoint: "MyEndpoint",
Token: "MyToken",
CertAuthFilePath: "MyCertAuthPath",
DisablePassHostHeaders: true,
EnablePassTLSCert: true,
Namespaces: []string{"a", "b"},
LabelSelector: "myLabelSelector",
IngressClass: "MyIngressClass",
}
// FIXME Test the other providers once they are migrated
config.Metrics = &types.Metrics{
Prometheus: &types.Prometheus{
Buckets: types.Buckets{0.1, 0.3, 1.2, 5},
EntryPoint: "MyEntryPoint",
Middlewares: []string{"m1", "m2"},
},
Datadog: &types.Datadog{
Address: "localhost:8181",
PushInterval: "12",
},
StatsD: &types.Statsd{
Address: "localhost:8182",
PushInterval: "42",
},
InfluxDB: &types.InfluxDB{
Address: "localhost:8183",
Protocol: "http",
PushInterval: "22",
Database: "myDB",
RetentionPolicy: "12",
Username: "a",
Password: "aaaa",
},
}
config.Ping = &ping.Handler{
EntryPoint: "MyEntryPoint",
Middlewares: []string{"m1", "m2", "m3"},
}
config.Tracing = &static.Tracing{
Backend: "myBackend",
ServiceName: "myServiceName",
SpanNameLimit: 3,
Jaeger: &jaeger.Config{
SamplingServerURL: "aaa",
SamplingType: "bbb",
SamplingParam: 43,
LocalAgentHostPort: "ccc",
Gen128Bit: true,
Propagation: "ddd",
TraceContextHeaderName: "eee",
},
Zipkin: &zipkin.Config{
HTTPEndpoint: "fff",
SameSpan: true,
ID128Bit: true,
Debug: true,
SampleRate: 53,
},
DataDog: &datadog.Config{
LocalAgentHostPort: "ggg",
GlobalTag: "eee",
Debug: true,
PrioritySampling: true,
},
Instana: &instana.Config{
LocalAgentHost: "fff",
LocalAgentPort: 32,
LogLevel: "ggg",
},
}
config.HostResolver = &types.HostResolverConfig{
CnameFlattening: true,
ResolvConfig: "aaa",
ResolvDepth: 3,
}
cleanJSON, err := Do(config, true)
if err != nil {
t.Fatal(err, cleanJSON)
}
}

View file

@ -0,0 +1,225 @@
package anonymize
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_doOnJSON(t *testing.T) {
baseConfiguration := `
{
"GraceTimeOut": 10000000000,
"Debug": false,
"CheckNewVersion": true,
"AccessLogsFile": "",
"TraefikLogsFile": "",
"LogLevel": "ERROR",
"EntryPoints": {
"http": {
"Network": "",
"Address": ":80",
"TLS": null,
"Auth": null,
"Compress": false
},
"https": {
"Address": ":443",
"TLS": {
"MinVersion": "",
"CipherSuites": null,
"Certificates": null,
"ClientCAFiles": null
},
"Auth": null,
"Compress": false
}
},
"Cluster": null,
"Constraints": [],
"ACME": {
"Email": "foo@bar.com",
"Domains": [
{
"Main": "foo@bar.com",
"SANs": null
},
{
"Main": "foo@bar.com",
"SANs": null
}
],
"Storage": "",
"StorageFile": "/acme/acme.json",
"OnDemand": true,
"OnHostRule": true,
"CAServer": "",
"EntryPoint": "https",
"DNSProvider": "",
"DelayDontCheckDNS": 0,
"ACMELogging": false,
"TLSOptions": null
},
"DefaultEntryPoints": [
"https",
"http"
],
"ProvidersThrottleDuration": 2000000000,
"MaxIdleConnsPerHost": 200,
"IdleTimeout": 180000000000,
"InsecureSkipVerify": false,
"Retry": null,
"HealthCheck": {
"Interval": 30000000000
},
"Docker": null,
"File": null,
"Web": null,
"Marathon": null,
"Consul": null,
"ConsulCatalog": null,
"Etcd": null,
"Zookeeper": null,
"Boltdb": null,
"Kubernetes": null,
"Mesos": null,
"Eureka": null,
"ECS": null,
"Rancher": null,
"DynamoDB": null,
"ConfigFile": "/etc/traefik/traefik.toml"
}
`
expectedConfiguration := `
{
"GraceTimeOut": 10000000000,
"Debug": false,
"CheckNewVersion": true,
"AccessLogsFile": "",
"TraefikLogsFile": "",
"LogLevel": "ERROR",
"EntryPoints": {
"http": {
"Network": "",
"Address": ":80",
"TLS": null,
"Auth": null,
"Compress": false
},
"https": {
"Address": ":443",
"TLS": {
"MinVersion": "",
"CipherSuites": null,
"Certificates": null,
"ClientCAFiles": null
},
"Auth": null,
"Compress": false
}
},
"Cluster": null,
"Constraints": [],
"ACME": {
"Email": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"Domains": [
{
"Main": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"SANs": null
},
{
"Main": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"SANs": null
}
],
"Storage": "",
"StorageFile": "/acme/acme.json",
"OnDemand": true,
"OnHostRule": true,
"CAServer": "",
"EntryPoint": "https",
"DNSProvider": "",
"DelayDontCheckDNS": 0,
"ACMELogging": false,
"TLSOptions": null
},
"DefaultEntryPoints": [
"https",
"http"
],
"ProvidersThrottleDuration": 2000000000,
"MaxIdleConnsPerHost": 200,
"IdleTimeout": 180000000000,
"InsecureSkipVerify": false,
"Retry": null,
"HealthCheck": {
"Interval": 30000000000
},
"Docker": null,
"File": null,
"Web": null,
"Marathon": null,
"Consul": null,
"ConsulCatalog": null,
"Etcd": null,
"Zookeeper": null,
"Boltdb": null,
"Kubernetes": null,
"Mesos": null,
"Eureka": null,
"ECS": null,
"Rancher": null,
"DynamoDB": null,
"ConfigFile": "/etc/traefik/traefik.toml"
}
`
anomConfiguration := doOnJSON(baseConfiguration)
if anomConfiguration != expectedConfiguration {
t.Errorf("Got %s, want %s.", anomConfiguration, expectedConfiguration)
}
}
func Test_doOnJSON_simple(t *testing.T) {
testCases := []struct {
name string
input string
expectedOutput string
}{
{
name: "email",
input: `{
"email1": "goo@example.com",
"email2": "foo.bargoo@example.com",
"email3": "foo.bargoo@example.com.us"
}`,
expectedOutput: `{
"email1": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"email2": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"email3": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
}`,
},
{
name: "url",
input: `{
"URL": "foo domain.com foo",
"URL": "foo sub.domain.com foo",
"URL": "foo sub.sub.domain.com foo",
"URL": "foo sub.sub.sub.domain.com.us foo"
}`,
expectedOutput: `{
"URL": "foo xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx foo",
"URL": "foo xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx foo",
"URL": "foo xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx foo",
"URL": "foo xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx foo"
}`,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
output := doOnJSON(test.input)
assert.Equal(t, test.expectedOutput, output)
})
}
}

View file

@ -0,0 +1,176 @@
package anonymize
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
type Courgette struct {
Ji string
Ho string
}
type Tomate struct {
Ji string
Ho string
}
type Carotte struct {
Name string
Value int
Courgette Courgette
ECourgette Courgette `export:"true"`
Pourgette *Courgette
EPourgette *Courgette `export:"true"`
Aubergine map[string]string
EAubergine map[string]string `export:"true"`
SAubergine map[string]Tomate
ESAubergine map[string]Tomate `export:"true"`
PSAubergine map[string]*Tomate
EPAubergine map[string]*Tomate `export:"true"`
}
func Test_doOnStruct(t *testing.T) {
testCase := []struct {
name string
base *Carotte
expected *Carotte
hasError bool
}{
{
name: "primitive",
base: &Carotte{
Name: "koko",
Value: 666,
},
expected: &Carotte{
Name: "xxxx",
},
},
{
name: "struct",
base: &Carotte{
Name: "koko",
Courgette: Courgette{
Ji: "huu",
},
},
expected: &Carotte{
Name: "xxxx",
},
},
{
name: "pointer",
base: &Carotte{
Name: "koko",
Pourgette: &Courgette{
Ji: "hoo",
},
},
expected: &Carotte{
Name: "xxxx",
Pourgette: nil,
},
},
{
name: "export struct",
base: &Carotte{
Name: "koko",
ECourgette: Courgette{
Ji: "huu",
},
},
expected: &Carotte{
Name: "xxxx",
ECourgette: Courgette{
Ji: "xxxx",
},
},
},
{
name: "export pointer struct",
base: &Carotte{
Name: "koko",
ECourgette: Courgette{
Ji: "huu",
},
},
expected: &Carotte{
Name: "xxxx",
ECourgette: Courgette{
Ji: "xxxx",
},
},
},
{
name: "export map string/string",
base: &Carotte{
Name: "koko",
EAubergine: map[string]string{
"foo": "bar",
},
},
expected: &Carotte{
Name: "xxxx",
EAubergine: map[string]string{
"foo": "bar",
},
},
},
{
name: "export map string/pointer",
base: &Carotte{
Name: "koko",
EPAubergine: map[string]*Tomate{
"foo": {
Ji: "fdskljf",
},
},
},
expected: &Carotte{
Name: "xxxx",
EPAubergine: map[string]*Tomate{
"foo": {
Ji: "xxxx",
},
},
},
},
{
name: "export map string/struct (UNSAFE)",
base: &Carotte{
Name: "koko",
ESAubergine: map[string]Tomate{
"foo": {
Ji: "JiJiJi",
},
},
},
expected: &Carotte{
Name: "xxxx",
ESAubergine: map[string]Tomate{
"foo": {
Ji: "JiJiJi",
},
},
},
hasError: true,
},
}
for _, test := range testCase {
t.Run(test.name, func(t *testing.T) {
val := reflect.ValueOf(test.base).Elem()
err := doOnStruct(val)
if !test.hasError && err != nil {
t.Fatal(err)
}
if test.hasError && err == nil {
t.Fatal("Got no error but want an error.")
}
assert.EqualValues(t, test.expected, test.base)
})
}
}

39
pkg/api/dashboard.go Normal file
View file

@ -0,0 +1,39 @@
package api
import (
"net/http"
"github.com/containous/mux"
"github.com/containous/traefik/pkg/log"
assetfs "github.com/elazarl/go-bindata-assetfs"
)
// DashboardHandler expose dashboard routes
type DashboardHandler struct {
Assets *assetfs.AssetFS
}
// Append add dashboard routes on a router
func (g DashboardHandler) Append(router *mux.Router) {
if g.Assets == nil {
log.WithoutContext().Error("No assets for dashboard")
return
}
// Expose dashboard
router.Methods(http.MethodGet).
Path("/").
HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
http.Redirect(response, request, request.Header.Get("X-Forwarded-Prefix")+"/dashboard/", http.StatusFound)
})
router.Methods(http.MethodGet).
Path("/dashboard/status").
HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
http.Redirect(response, request, "/dashboard/", http.StatusFound)
})
router.Methods(http.MethodGet).
PathPrefix("/dashboard/").
Handler(http.StripPrefix("/dashboard/", http.FileServer(g.Assets)))
}

49
pkg/api/debug.go Normal file
View file

@ -0,0 +1,49 @@
package api
import (
"expvar"
"fmt"
"net/http"
"net/http/pprof"
"runtime"
"github.com/containous/mux"
)
func init() {
// FIXME Goroutines2 -> Goroutines
expvar.Publish("Goroutines2", expvar.Func(goroutines))
}
func goroutines() interface{} {
return runtime.NumGoroutine()
}
// DebugHandler expose debug routes
type DebugHandler struct{}
// Append add debug routes on a router
func (g DebugHandler) Append(router *mux.Router) {
router.Methods(http.MethodGet).Path("/debug/vars").
HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprint(w, "{\n")
first := true
expvar.Do(func(kv expvar.KeyValue) {
if !first {
fmt.Fprint(w, ",\n")
}
first = false
fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
})
fmt.Fprint(w, "\n}\n")
})
runtime.SetBlockProfileRate(1)
runtime.SetMutexProfileFraction(5)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/cmdline").HandlerFunc(pprof.Cmdline)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/profile").HandlerFunc(pprof.Profile)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/symbol").HandlerFunc(pprof.Symbol)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/trace").HandlerFunc(pprof.Trace)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index)
}

355
pkg/api/handler.go Normal file
View file

@ -0,0 +1,355 @@
package api
import (
"io"
"net/http"
"github.com/containous/mux"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/safe"
"github.com/containous/traefik/pkg/types"
"github.com/containous/traefik/pkg/version"
assetfs "github.com/elazarl/go-bindata-assetfs"
thoasstats "github.com/thoas/stats"
"github.com/unrolled/render"
)
// ResourceIdentifier a resource identifier
type ResourceIdentifier struct {
ID string `json:"id"`
Path string `json:"path"`
}
// ProviderRepresentation a provider with resource identifiers
type ProviderRepresentation struct {
Routers []ResourceIdentifier `json:"routers,omitempty"`
Middlewares []ResourceIdentifier `json:"middlewares,omitempty"`
Services []ResourceIdentifier `json:"services,omitempty"`
}
// RouterRepresentation extended version of a router configuration with an ID
type RouterRepresentation struct {
*config.Router
ID string `json:"id"`
}
// MiddlewareRepresentation extended version of a middleware configuration with an ID
type MiddlewareRepresentation struct {
*config.Middleware
ID string `json:"id"`
}
// ServiceRepresentation extended version of a service configuration with an ID
type ServiceRepresentation struct {
*config.Service
ID string `json:"id"`
}
// Handler expose api routes
type Handler struct {
EntryPoint string
Dashboard bool
Debug bool
CurrentConfigurations *safe.Safe
Statistics *types.Statistics
Stats *thoasstats.Stats
// StatsRecorder *middlewares.StatsRecorder // FIXME stats
DashboardAssets *assetfs.AssetFS
}
var templateRenderer jsonRenderer = render.New(render.Options{Directory: "nowhere"})
type jsonRenderer interface {
JSON(w io.Writer, status int, v interface{}) error
}
// Append add api routes on a router
func (h Handler) Append(router *mux.Router) {
if h.Debug {
DebugHandler{}.Append(router)
}
router.Methods(http.MethodGet).Path("/api/rawdata").HandlerFunc(h.getRawData)
router.Methods(http.MethodGet).Path("/api/providers").HandlerFunc(h.getProvidersHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}").HandlerFunc(h.getProviderHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/routers").HandlerFunc(h.getRoutersHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/routers/{router}").HandlerFunc(h.getRouterHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/middlewares").HandlerFunc(h.getMiddlewaresHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/middlewares/{middleware}").HandlerFunc(h.getMiddlewareHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/services").HandlerFunc(h.getServicesHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/services/{service}").HandlerFunc(h.getServiceHandler)
// FIXME stats
// health route
//router.Methods(http.MethodGet).Path("/health").HandlerFunc(p.getHealthHandler)
version.Handler{}.Append(router)
if h.Dashboard {
DashboardHandler{Assets: h.DashboardAssets}.Append(router)
}
}
func (h Handler) getRawData(rw http.ResponseWriter, request *http.Request) {
if h.CurrentConfigurations != nil {
currentConfigurations, ok := h.CurrentConfigurations.Get().(config.Configurations)
if !ok {
rw.WriteHeader(http.StatusOK)
return
}
err := templateRenderer.JSON(rw, http.StatusOK, currentConfigurations)
if err != nil {
log.FromContext(request.Context()).Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
}
func (h Handler) getProvidersHandler(rw http.ResponseWriter, request *http.Request) {
// FIXME handle currentConfiguration
if h.CurrentConfigurations != nil {
currentConfigurations, ok := h.CurrentConfigurations.Get().(config.Configurations)
if !ok {
rw.WriteHeader(http.StatusOK)
return
}
var providers []ResourceIdentifier
for name := range currentConfigurations {
providers = append(providers, ResourceIdentifier{
ID: name,
Path: "/api/providers/" + name,
})
}
err := templateRenderer.JSON(rw, http.StatusOK, providers)
if err != nil {
log.FromContext(request.Context()).Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
}
func (h Handler) getProviderHandler(rw http.ResponseWriter, request *http.Request) {
providerID := mux.Vars(request)["provider"]
currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)
provider, ok := currentConfigurations[providerID]
if !ok {
http.NotFound(rw, request)
return
}
if provider.HTTP == nil {
http.NotFound(rw, request)
return
}
var routers []ResourceIdentifier
for name := range provider.HTTP.Routers {
routers = append(routers, ResourceIdentifier{
ID: name,
Path: "/api/providers/" + providerID + "/routers",
})
}
var services []ResourceIdentifier
for name := range provider.HTTP.Services {
services = append(services, ResourceIdentifier{
ID: name,
Path: "/api/providers/" + providerID + "/services",
})
}
var middlewares []ResourceIdentifier
for name := range provider.HTTP.Middlewares {
middlewares = append(middlewares, ResourceIdentifier{
ID: name,
Path: "/api/providers/" + providerID + "/middlewares",
})
}
providers := ProviderRepresentation{Routers: routers, Middlewares: middlewares, Services: services}
err := templateRenderer.JSON(rw, http.StatusOK, providers)
if err != nil {
log.FromContext(request.Context()).Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func (h Handler) getRoutersHandler(rw http.ResponseWriter, request *http.Request) {
providerID := mux.Vars(request)["provider"]
currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)
provider, ok := currentConfigurations[providerID]
if !ok {
http.NotFound(rw, request)
return
}
if provider.HTTP == nil {
http.NotFound(rw, request)
return
}
var routers []RouterRepresentation
for name, router := range provider.HTTP.Routers {
routers = append(routers, RouterRepresentation{Router: router, ID: name})
}
err := templateRenderer.JSON(rw, http.StatusOK, routers)
if err != nil {
log.FromContext(request.Context()).Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func (h Handler) getRouterHandler(rw http.ResponseWriter, request *http.Request) {
providerID := mux.Vars(request)["provider"]
routerID := mux.Vars(request)["router"]
currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)
provider, ok := currentConfigurations[providerID]
if !ok {
http.NotFound(rw, request)
return
}
if provider.HTTP == nil {
http.NotFound(rw, request)
return
}
router, ok := provider.HTTP.Routers[routerID]
if !ok {
http.NotFound(rw, request)
return
}
err := templateRenderer.JSON(rw, http.StatusOK, router)
if err != nil {
log.FromContext(request.Context()).Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func (h Handler) getMiddlewaresHandler(rw http.ResponseWriter, request *http.Request) {
providerID := mux.Vars(request)["provider"]
currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)
provider, ok := currentConfigurations[providerID]
if !ok {
http.NotFound(rw, request)
return
}
if provider.HTTP == nil {
http.NotFound(rw, request)
return
}
var middlewares []MiddlewareRepresentation
for name, middleware := range provider.HTTP.Middlewares {
middlewares = append(middlewares, MiddlewareRepresentation{Middleware: middleware, ID: name})
}
err := templateRenderer.JSON(rw, http.StatusOK, middlewares)
if err != nil {
log.FromContext(request.Context()).Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func (h Handler) getMiddlewareHandler(rw http.ResponseWriter, request *http.Request) {
providerID := mux.Vars(request)["provider"]
middlewareID := mux.Vars(request)["middleware"]
currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)
provider, ok := currentConfigurations[providerID]
if !ok {
http.NotFound(rw, request)
return
}
if provider.HTTP == nil {
http.NotFound(rw, request)
return
}
middleware, ok := provider.HTTP.Middlewares[middlewareID]
if !ok {
http.NotFound(rw, request)
return
}
err := templateRenderer.JSON(rw, http.StatusOK, middleware)
if err != nil {
log.FromContext(request.Context()).Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func (h Handler) getServicesHandler(rw http.ResponseWriter, request *http.Request) {
providerID := mux.Vars(request)["provider"]
currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)
provider, ok := currentConfigurations[providerID]
if !ok {
http.NotFound(rw, request)
return
}
if provider.HTTP == nil {
http.NotFound(rw, request)
return
}
var services []ServiceRepresentation
for name, service := range provider.HTTP.Services {
services = append(services, ServiceRepresentation{Service: service, ID: name})
}
err := templateRenderer.JSON(rw, http.StatusOK, services)
if err != nil {
log.FromContext(request.Context()).Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func (h Handler) getServiceHandler(rw http.ResponseWriter, request *http.Request) {
providerID := mux.Vars(request)["provider"]
serviceID := mux.Vars(request)["service"]
currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)
provider, ok := currentConfigurations[providerID]
if !ok {
http.NotFound(rw, request)
return
}
if provider.HTTP == nil {
http.NotFound(rw, request)
return
}
service, ok := provider.HTTP.Services[serviceID]
if !ok {
http.NotFound(rw, request)
return
}
err := templateRenderer.JSON(rw, http.StatusOK, service)
if err != nil {
log.FromContext(request.Context()).Error(err)
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}

226
pkg/api/handler_test.go Normal file
View file

@ -0,0 +1,226 @@
package api
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/mux"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/safe"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHandler_Configuration(t *testing.T) {
type expected struct {
statusCode int
body string
}
testCases := []struct {
desc string
path string
configuration config.Configurations
expected expected
}{
{
desc: "Get all the providers",
path: "/api/providers",
configuration: config.Configurations{
"foo": {
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"bar": {EntryPoints: []string{"foo", "bar"}},
},
},
},
},
expected: expected{statusCode: http.StatusOK, body: `[{"id":"foo","path":"/api/providers/foo"}]`},
},
{
desc: "Get a provider",
path: "/api/providers/foo",
configuration: config.Configurations{
"foo": {
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"bar": {EntryPoints: []string{"foo", "bar"}},
},
Middlewares: map[string]*config.Middleware{
"bar": {
AddPrefix: &config.AddPrefix{Prefix: "bar"},
},
},
Services: map[string]*config.Service{
"foo": {
LoadBalancer: &config.LoadBalancerService{
Method: "wrr",
},
},
},
},
},
},
expected: expected{statusCode: http.StatusOK, body: `{"routers":[{"id":"bar","path":"/api/providers/foo/routers"}],"middlewares":[{"id":"bar","path":"/api/providers/foo/middlewares"}],"services":[{"id":"foo","path":"/api/providers/foo/services"}]}`},
},
{
desc: "Provider not found",
path: "/api/providers/foo",
configuration: config.Configurations{},
expected: expected{statusCode: http.StatusNotFound, body: "404 page not found\n"},
},
{
desc: "Get all routers",
path: "/api/providers/foo/routers",
configuration: config.Configurations{
"foo": {
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"bar": {EntryPoints: []string{"foo", "bar"}},
},
},
},
},
expected: expected{statusCode: http.StatusOK, body: `[{"entryPoints":["foo","bar"],"id":"bar"}]`},
},
{
desc: "Get a router",
path: "/api/providers/foo/routers/bar",
configuration: config.Configurations{
"foo": {
HTTP: &config.HTTPConfiguration{
Routers: map[string]*config.Router{
"bar": {EntryPoints: []string{"foo", "bar"}},
},
},
},
},
expected: expected{statusCode: http.StatusOK, body: `{"entryPoints":["foo","bar"]}`},
},
{
desc: "Router not found",
path: "/api/providers/foo/routers/bar",
configuration: config.Configurations{
"foo": {},
},
expected: expected{statusCode: http.StatusNotFound, body: "404 page not found\n"},
},
{
desc: "Get all services",
path: "/api/providers/foo/services",
configuration: config.Configurations{
"foo": {
HTTP: &config.HTTPConfiguration{
Services: map[string]*config.Service{
"foo": {
LoadBalancer: &config.LoadBalancerService{
Method: "wrr",
},
},
},
},
},
},
expected: expected{statusCode: http.StatusOK, body: `[{"loadbalancer":{"method":"wrr","passHostHeader":false},"id":"foo"}]`},
},
{
desc: "Get a service",
path: "/api/providers/foo/services/foo",
configuration: config.Configurations{
"foo": {
HTTP: &config.HTTPConfiguration{
Services: map[string]*config.Service{
"foo": {
LoadBalancer: &config.LoadBalancerService{
Method: "wrr",
},
},
},
},
},
},
expected: expected{statusCode: http.StatusOK, body: `{"loadbalancer":{"method":"wrr","passHostHeader":false}}`},
},
{
desc: "Service not found",
path: "/api/providers/foo/services/bar",
configuration: config.Configurations{
"foo": {},
},
expected: expected{statusCode: http.StatusNotFound, body: "404 page not found\n"},
},
{
desc: "Get all middlewares",
path: "/api/providers/foo/middlewares",
configuration: config.Configurations{
"foo": {
HTTP: &config.HTTPConfiguration{
Middlewares: map[string]*config.Middleware{
"bar": {
AddPrefix: &config.AddPrefix{Prefix: "bar"},
},
},
},
},
},
expected: expected{statusCode: http.StatusOK, body: `[{"addPrefix":{"prefix":"bar"},"id":"bar"}]`},
},
{
desc: "Get a middleware",
path: "/api/providers/foo/middlewares/bar",
configuration: config.Configurations{
"foo": {
HTTP: &config.HTTPConfiguration{
Middlewares: map[string]*config.Middleware{
"bar": {
AddPrefix: &config.AddPrefix{Prefix: "bar"},
},
},
},
},
},
expected: expected{statusCode: http.StatusOK, body: `{"addPrefix":{"prefix":"bar"}}`},
},
{
desc: "Middleware not found",
path: "/api/providers/foo/middlewares/bar",
configuration: config.Configurations{
"foo": {},
},
expected: expected{statusCode: http.StatusNotFound, body: "404 page not found\n"},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
currentConfiguration := &safe.Safe{}
currentConfiguration.Set(test.configuration)
handler := Handler{
CurrentConfigurations: currentConfiguration,
}
router := mux.NewRouter()
handler.Append(router)
server := httptest.NewServer(router)
resp, err := http.DefaultClient.Get(server.URL + test.path)
require.NoError(t, err)
assert.Equal(t, test.expected.statusCode, resp.StatusCode)
content, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
err = resp.Body.Close()
require.NoError(t, err)
assert.Equal(t, test.expected.body, string(content))
})
}
}

View file

@ -0,0 +1,80 @@
package collector
import (
"bytes"
"encoding/base64"
"encoding/json"
"net"
"net/http"
"strconv"
"time"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/pkg/anonymize"
"github.com/containous/traefik/pkg/config/static"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/version"
"github.com/mitchellh/hashstructure"
)
// collectorURL URL where the stats are send
const collectorURL = "https://collect.traefik.io/9vxmmkcdmalbdi635d4jgc5p5rx0h7h8"
// Collected data
type data struct {
Version string
Codename string
BuildDate string
Configuration string
Hash string
}
// Collect anonymous data.
func Collect(staticConfiguration *static.Configuration) error {
anonConfig, err := anonymize.Do(staticConfiguration, false)
if err != nil {
return err
}
log.Infof("Anonymous stats sent to %s: %s", collectorURL, anonConfig)
hashConf, err := hashstructure.Hash(staticConfiguration, nil)
if err != nil {
return err
}
data := &data{
Version: version.Version,
Codename: version.Codename,
BuildDate: version.BuildDate,
Hash: strconv.FormatUint(hashConf, 10),
Configuration: base64.StdEncoding.EncodeToString([]byte(anonConfig)),
}
buf := new(bytes.Buffer)
err = json.NewEncoder(buf).Encode(data)
if err != nil {
return err
}
_, err = makeHTTPClient().Post(collectorURL, "application/json; charset=utf-8", buf)
return err
}
func makeHTTPClient() *http.Client {
dialer := &net.Dialer{
Timeout: configuration.DefaultDialTimeout,
KeepAlive: 30 * time.Second,
DualStack: true,
}
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
return &http.Client{Transport: transport}
}

230
pkg/config/dyn_config.go Normal file
View file

@ -0,0 +1,230 @@
package config
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"os"
"reflect"
traefiktls "github.com/containous/traefik/pkg/tls"
)
// Router holds the router configuration.
type Router struct {
EntryPoints []string `json:"entryPoints"`
Middlewares []string `json:"middlewares,omitempty" toml:",omitempty"`
Service string `json:"service,omitempty" toml:",omitempty"`
Rule string `json:"rule,omitempty" toml:",omitempty"`
Priority int `json:"priority,omitempty" toml:"priority,omitzero"`
TLS *RouterTLSConfig `json:"tls,omitempty" toml:"tls,omitzero" label:"allowEmpty"`
}
// RouterTLSConfig holds the TLS configuration for a router
type RouterTLSConfig struct{}
// TCPRouter holds the router configuration.
type TCPRouter struct {
EntryPoints []string `json:"entryPoints"`
Service string `json:"service,omitempty" toml:",omitempty"`
Rule string `json:"rule,omitempty" toml:",omitempty"`
TLS *RouterTCPTLSConfig `json:"tls,omitempty" toml:"tls,omitzero" label:"allowEmpty"`
}
// RouterTCPTLSConfig holds the TLS configuration for a router
type RouterTCPTLSConfig struct {
Passthrough bool `json:"passthrough" toml:"passthrough,omitzero"`
}
// LoadBalancerService holds the LoadBalancerService configuration.
type LoadBalancerService struct {
Stickiness *Stickiness `json:"stickiness,omitempty" toml:",omitempty" label:"allowEmpty"`
Servers []Server `json:"servers,omitempty" toml:",omitempty" label-slice-as-struct:"server"`
Method string `json:"method,omitempty" toml:",omitempty"`
HealthCheck *HealthCheck `json:"healthCheck,omitempty" toml:",omitempty"`
PassHostHeader bool `json:"passHostHeader" toml:",omitempty"`
ResponseForwarding *ResponseForwarding `json:"forwardingResponse,omitempty" toml:",omitempty"`
}
// TCPLoadBalancerService holds the LoadBalancerService configuration.
type TCPLoadBalancerService struct {
Servers []TCPServer `json:"servers,omitempty" toml:",omitempty" label-slice-as-struct:"server"`
Method string `json:"method,omitempty" toml:",omitempty"`
}
// Mergeable tells if the given service is mergeable.
func (l *LoadBalancerService) Mergeable(loadBalancer *LoadBalancerService) bool {
savedServers := l.Servers
defer func() {
l.Servers = savedServers
}()
l.Servers = nil
savedServersLB := loadBalancer.Servers
defer func() {
loadBalancer.Servers = savedServersLB
}()
loadBalancer.Servers = nil
return reflect.DeepEqual(l, loadBalancer)
}
// SetDefaults Default values for a LoadBalancerService.
func (l *LoadBalancerService) SetDefaults() {
l.PassHostHeader = true
l.Method = "wrr"
}
// ResponseForwarding holds configuration for the forward of the response.
type ResponseForwarding struct {
FlushInterval string `json:"flushInterval,omitempty" toml:",omitempty"`
}
// Stickiness holds the stickiness configuration.
type Stickiness struct {
CookieName string `json:"cookieName,omitempty" toml:",omitempty"`
}
// Server holds the server configuration.
type Server struct {
URL string `json:"url" label:"-"`
Scheme string `toml:"-" json:"-"`
Port string `toml:"-" json:"-"`
Weight int `json:"weight"`
}
// TCPServer holds a TCP Server configuration
type TCPServer struct {
Address string `json:"address" label:"-"`
Weight int `json:"weight"`
}
// SetDefaults Default values for a Server.
func (s *Server) SetDefaults() {
s.Weight = 1
s.Scheme = "http"
}
// HealthCheck holds the HealthCheck configuration.
type HealthCheck struct {
Scheme string `json:"scheme,omitempty" toml:",omitempty"`
Path string `json:"path,omitempty" toml:",omitempty"`
Port int `json:"port,omitempty" toml:",omitempty,omitzero"`
// FIXME change string to parse.Duration
Interval string `json:"interval,omitempty" toml:",omitempty"`
// FIXME change string to parse.Duration
Timeout string `json:"timeout,omitempty" toml:",omitempty"`
Hostname string `json:"hostname,omitempty" toml:",omitempty"`
Headers map[string]string `json:"headers,omitempty" toml:",omitempty"`
}
// CreateTLSConfig creates a TLS config from ClientTLS structures.
func (clientTLS *ClientTLS) CreateTLSConfig() (*tls.Config, error) {
if clientTLS == nil {
return nil, nil
}
var err error
caPool := x509.NewCertPool()
clientAuth := tls.NoClientCert
if clientTLS.CA != "" {
var ca []byte
if _, errCA := os.Stat(clientTLS.CA); errCA == nil {
ca, err = ioutil.ReadFile(clientTLS.CA)
if err != nil {
return nil, fmt.Errorf("failed to read CA. %s", err)
}
} else {
ca = []byte(clientTLS.CA)
}
if !caPool.AppendCertsFromPEM(ca) {
return nil, fmt.Errorf("failed to parse CA")
}
if clientTLS.CAOptional {
clientAuth = tls.VerifyClientCertIfGiven
} else {
clientAuth = tls.RequireAndVerifyClientCert
}
}
cert := tls.Certificate{}
_, errKeyIsFile := os.Stat(clientTLS.Key)
if !clientTLS.InsecureSkipVerify && (len(clientTLS.Cert) == 0 || len(clientTLS.Key) == 0) {
return nil, fmt.Errorf("TLS Certificate or Key file must be set when TLS configuration is created")
}
if len(clientTLS.Cert) > 0 && len(clientTLS.Key) > 0 {
if _, errCertIsFile := os.Stat(clientTLS.Cert); errCertIsFile == nil {
if errKeyIsFile == nil {
cert, err = tls.LoadX509KeyPair(clientTLS.Cert, clientTLS.Key)
if err != nil {
return nil, fmt.Errorf("failed to load TLS keypair: %v", err)
}
} else {
return nil, fmt.Errorf("tls cert is a file, but tls key is not")
}
} else {
if errKeyIsFile != nil {
cert, err = tls.X509KeyPair([]byte(clientTLS.Cert), []byte(clientTLS.Key))
if err != nil {
return nil, fmt.Errorf("failed to load TLS keypair: %v", err)
}
} else {
return nil, fmt.Errorf("TLS key is a file, but tls cert is not")
}
}
}
return &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: caPool,
InsecureSkipVerify: clientTLS.InsecureSkipVerify,
ClientAuth: clientAuth,
}, nil
}
// Message holds configuration information exchanged between parts of traefik.
type Message struct {
ProviderName string
Configuration *Configuration
}
// Configuration is the root of the dynamic configuration
type Configuration struct {
HTTP *HTTPConfiguration
TCP *TCPConfiguration
TLS []*traefiktls.Configuration `json:"-" label:"-"`
TLSOptions map[string]traefiktls.TLS
TLSStores map[string]traefiktls.Store
}
// Configurations is for currentConfigurations Map.
type Configurations map[string]*Configuration
// HTTPConfiguration FIXME better name?
type HTTPConfiguration struct {
Routers map[string]*Router `json:"routers,omitempty" toml:",omitempty"`
Middlewares map[string]*Middleware `json:"middlewares,omitempty" toml:",omitempty"`
Services map[string]*Service `json:"services,omitempty" toml:",omitempty"`
}
// TCPConfiguration FIXME better name?
type TCPConfiguration struct {
Routers map[string]*TCPRouter `json:"routers,omitempty" toml:",omitempty"`
Services map[string]*TCPService `json:"services,omitempty" toml:",omitempty"`
}
// Service holds a service configuration (can only be of one type at the same time).
type Service struct {
LoadBalancer *LoadBalancerService `json:"loadbalancer,omitempty" toml:",omitempty,omitzero"`
}
// TCPService holds a tcp service configuration (can only be of one type at the same time).
type TCPService struct {
LoadBalancer *TCPLoadBalancerService `json:"loadbalancer,omitempty" toml:",omitempty,omitzero"`
}

363
pkg/config/middlewares.go Normal file
View file

@ -0,0 +1,363 @@
package config
import (
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/ip"
)
// +k8s:deepcopy-gen=true
// Middleware holds the Middleware configuration.
type Middleware struct {
AddPrefix *AddPrefix `json:"addPrefix,omitempty"`
StripPrefix *StripPrefix `json:"stripPrefix,omitempty"`
StripPrefixRegex *StripPrefixRegex `json:"stripPrefixRegex,omitempty"`
ReplacePath *ReplacePath `json:"replacePath,omitempty"`
ReplacePathRegex *ReplacePathRegex `json:"replacePathRegex,omitempty"`
Chain *Chain `json:"chain,omitempty"`
IPWhiteList *IPWhiteList `json:"ipWhiteList,omitempty"`
Headers *Headers `json:"headers,omitempty"`
Errors *ErrorPage `json:"errors,omitempty"`
RateLimit *RateLimit `json:"rateLimit,omitempty"`
RedirectRegex *RedirectRegex `json:"redirectregex,omitempty"`
RedirectScheme *RedirectScheme `json:"redirectscheme,omitempty"`
BasicAuth *BasicAuth `json:"basicAuth,omitempty"`
DigestAuth *DigestAuth `json:"digestAuth,omitempty"`
ForwardAuth *ForwardAuth `json:"forwardAuth,omitempty"`
MaxConn *MaxConn `json:"maxConn,omitempty"`
Buffering *Buffering `json:"buffering,omitempty"`
CircuitBreaker *CircuitBreaker `json:"circuitBreaker,omitempty"`
Compress *Compress `json:"compress,omitempty" label:"allowEmpty"`
PassTLSClientCert *PassTLSClientCert `json:"passTLSClientCert,omitempty"`
Retry *Retry `json:"retry,omitempty"`
}
// +k8s:deepcopy-gen=true
// AddPrefix holds the AddPrefix configuration.
type AddPrefix struct {
Prefix string `json:"prefix,omitempty"`
}
// +k8s:deepcopy-gen=true
// Auth holds the authentication configuration (BASIC, DIGEST, users).
type Auth struct {
Basic *BasicAuth `json:"basic,omitempty" export:"true"`
Digest *DigestAuth `json:"digest,omitempty" export:"true"`
Forward *ForwardAuth `json:"forward,omitempty" export:"true"`
}
// +k8s:deepcopy-gen=true
// BasicAuth holds the HTTP basic authentication configuration.
type BasicAuth struct {
Users `json:"users,omitempty" mapstructure:","`
UsersFile string `json:"usersFile,omitempty"`
Realm string `json:"realm,omitempty"`
RemoveHeader bool `json:"removeHeader,omitempty"`
HeaderField string `json:"headerField,omitempty" export:"true"`
}
// +k8s:deepcopy-gen=true
// Buffering holds the request/response buffering configuration.
type Buffering struct {
MaxRequestBodyBytes int64 `json:"maxRequestBodyBytes,omitempty"`
MemRequestBodyBytes int64 `json:"memRequestBodyBytes,omitempty"`
MaxResponseBodyBytes int64 `json:"maxResponseBodyBytes,omitempty"`
MemResponseBodyBytes int64 `json:"memResponseBodyBytes,omitempty"`
RetryExpression string `json:"retryExpression,omitempty"`
}
// +k8s:deepcopy-gen=true
// Chain holds a chain of middlewares
type Chain struct {
Middlewares []string `json:"middlewares"`
}
// +k8s:deepcopy-gen=true
// CircuitBreaker holds the circuit breaker configuration.
type CircuitBreaker struct {
Expression string `json:"expression,omitempty"`
}
// +k8s:deepcopy-gen=true
// Compress holds the compress configuration.
type Compress struct{}
// +k8s:deepcopy-gen=true
// DigestAuth holds the Digest HTTP authentication configuration.
type DigestAuth struct {
Users `json:"users,omitempty" mapstructure:","`
UsersFile string `json:"usersFile,omitempty"`
RemoveHeader bool `json:"removeHeader,omitempty"`
Realm string `json:"realm,omitempty" mapstructure:","`
HeaderField string `json:"headerField,omitempty" export:"true"`
}
// +k8s:deepcopy-gen=true
// ErrorPage holds the custom error page configuration.
type ErrorPage struct {
Status []string `json:"status,omitempty"`
Service string `json:"service,omitempty"`
Query string `json:"query,omitempty"`
}
// +k8s:deepcopy-gen=true
// ForwardAuth holds the http forward authentication configuration.
type ForwardAuth struct {
Address string `description:"Authentication server address" json:"address,omitempty"`
TLS *ClientTLS `description:"Enable TLS support" json:"tls,omitempty" export:"true"`
TrustForwardHeader bool `description:"Trust X-Forwarded-* headers" json:"trustForwardHeader,omitempty" export:"true"`
AuthResponseHeaders []string `description:"Headers to be forwarded from auth response" json:"authResponseHeaders,omitempty"`
}
// +k8s:deepcopy-gen=true
// Headers holds the custom header configuration.
type Headers struct {
CustomRequestHeaders map[string]string `json:"customRequestHeaders,omitempty"`
CustomResponseHeaders map[string]string `json:"customResponseHeaders,omitempty"`
AllowedHosts []string `json:"allowedHosts,omitempty"`
HostsProxyHeaders []string `json:"hostsProxyHeaders,omitempty"`
SSLRedirect bool `json:"sslRedirect,omitempty"`
SSLTemporaryRedirect bool `json:"sslTemporaryRedirect,omitempty"`
SSLHost string `json:"sslHost,omitempty"`
SSLProxyHeaders map[string]string `json:"sslProxyHeaders,omitempty"`
SSLForceHost bool `json:"sslForceHost,omitempty"`
STSSeconds int64 `json:"stsSeconds,omitempty"`
STSIncludeSubdomains bool `json:"stsIncludeSubdomains,omitempty"`
STSPreload bool `json:"stsPreload,omitempty"`
ForceSTSHeader bool `json:"forceSTSHeader,omitempty"`
FrameDeny bool `json:"frameDeny,omitempty"`
CustomFrameOptionsValue string `json:"customFrameOptionsValue,omitempty"`
ContentTypeNosniff bool `json:"contentTypeNosniff,omitempty"`
BrowserXSSFilter bool `json:"browserXssFilter,omitempty"`
CustomBrowserXSSValue string `json:"customBrowserXSSValue,omitempty"`
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
PublicKey string `json:"publicKey,omitempty"`
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
IsDevelopment bool `json:"isDevelopment,omitempty"`
}
// HasCustomHeadersDefined checks to see if any of the custom header elements have been set
func (h *Headers) HasCustomHeadersDefined() bool {
return h != nil && (len(h.CustomResponseHeaders) != 0 ||
len(h.CustomRequestHeaders) != 0)
}
// HasSecureHeadersDefined checks to see if any of the secure header elements have been set
func (h *Headers) HasSecureHeadersDefined() bool {
return h != nil && (len(h.AllowedHosts) != 0 ||
len(h.HostsProxyHeaders) != 0 ||
h.SSLRedirect ||
h.SSLTemporaryRedirect ||
h.SSLForceHost ||
h.SSLHost != "" ||
len(h.SSLProxyHeaders) != 0 ||
h.STSSeconds != 0 ||
h.STSIncludeSubdomains ||
h.STSPreload ||
h.ForceSTSHeader ||
h.FrameDeny ||
h.CustomFrameOptionsValue != "" ||
h.ContentTypeNosniff ||
h.BrowserXSSFilter ||
h.CustomBrowserXSSValue != "" ||
h.ContentSecurityPolicy != "" ||
h.PublicKey != "" ||
h.ReferrerPolicy != "" ||
h.IsDevelopment)
}
// +k8s:deepcopy-gen=true
// IPStrategy holds the ip strategy configuration.
type IPStrategy struct {
Depth int `json:"depth,omitempty" export:"true"`
ExcludedIPs []string `json:"excludedIPs,omitempty"`
}
// Get an IP selection strategy
// if nil return the RemoteAddr strategy
// else return a strategy base on the configuration using the X-Forwarded-For Header.
// Depth override the ExcludedIPs
func (s *IPStrategy) Get() (ip.Strategy, error) {
if s == nil {
return &ip.RemoteAddrStrategy{}, nil
}
if s.Depth > 0 {
return &ip.DepthStrategy{
Depth: s.Depth,
}, nil
}
if len(s.ExcludedIPs) > 0 {
checker, err := ip.NewChecker(s.ExcludedIPs)
if err != nil {
return nil, err
}
return &ip.CheckerStrategy{
Checker: checker,
}, nil
}
return &ip.RemoteAddrStrategy{}, nil
}
// +k8s:deepcopy-gen=true
// IPWhiteList holds the ip white list configuration.
type IPWhiteList struct {
SourceRange []string `json:"sourceRange,omitempty"`
IPStrategy *IPStrategy `json:"ipStrategy,omitempty" label:"allowEmpty"`
}
// +k8s:deepcopy-gen=true
// MaxConn holds maximum connection configuration.
type MaxConn struct {
Amount int64 `json:"amount,omitempty"`
ExtractorFunc string `json:"extractorFunc,omitempty"`
}
// SetDefaults Default values for a MaxConn.
func (m *MaxConn) SetDefaults() {
m.ExtractorFunc = "request.host"
}
// +k8s:deepcopy-gen=true
// PassTLSClientCert holds the TLS client cert headers configuration.
type PassTLSClientCert struct {
PEM bool `description:"Enable header with escaped client pem" json:"pem"`
Info *TLSClientCertificateInfo `description:"Enable header with configured client cert info" json:"info,omitempty"`
}
// +k8s:deepcopy-gen=true
// Rate holds the rate limiting configuration for a specific time period.
type Rate struct {
Period parse.Duration `json:"period,omitempty"`
Average int64 `json:"average,omitempty"`
Burst int64 `json:"burst,omitempty"`
}
// +k8s:deepcopy-gen=true
// RateLimit holds the rate limiting configuration for a given frontend.
type RateLimit struct {
RateSet map[string]*Rate `json:"rateset,omitempty"`
// FIXME replace by ipStrategy see oxy and replace
ExtractorFunc string `json:"extractorFunc,omitempty"`
}
// SetDefaults Default values for a MaxConn.
func (r *RateLimit) SetDefaults() {
r.ExtractorFunc = "request.host"
}
// +k8s:deepcopy-gen=true
// RedirectRegex holds the redirection configuration.
type RedirectRegex struct {
Regex string `json:"regex,omitempty"`
Replacement string `json:"replacement,omitempty"`
Permanent bool `json:"permanent,omitempty"`
}
// +k8s:deepcopy-gen=true
// RedirectScheme holds the scheme redirection configuration.
type RedirectScheme struct {
Scheme string `json:"scheme,omitempty"`
Port string `json:"port,omitempty"`
Permanent bool `json:"permanent,omitempty"`
}
// +k8s:deepcopy-gen=true
// ReplacePath holds the ReplacePath configuration.
type ReplacePath struct {
Path string `json:"path,omitempty"`
}
// +k8s:deepcopy-gen=true
// ReplacePathRegex holds the ReplacePathRegex configuration.
type ReplacePathRegex struct {
Regex string `json:"regex,omitempty"`
Replacement string `json:"replacement,omitempty"`
}
// +k8s:deepcopy-gen=true
// Retry holds the retry configuration.
type Retry struct {
Attempts int `description:"Number of attempts" export:"true"`
}
// +k8s:deepcopy-gen=true
// StripPrefix holds the StripPrefix configuration.
type StripPrefix struct {
Prefixes []string `json:"prefixes,omitempty"`
}
// +k8s:deepcopy-gen=true
// StripPrefixRegex holds the StripPrefixRegex configuration.
type StripPrefixRegex struct {
Regex []string `json:"regex,omitempty"`
}
// +k8s:deepcopy-gen=true
// TLSClientCertificateInfo holds the client TLS certificate info configuration.
type TLSClientCertificateInfo struct {
NotAfter bool `description:"Add NotAfter info in header" json:"notAfter"`
NotBefore bool `description:"Add NotBefore info in header" json:"notBefore"`
Sans bool `description:"Add Sans info in header" json:"sans"`
Subject *TLSCLientCertificateDNInfo `description:"Add Subject info in header" json:"subject,omitempty"`
Issuer *TLSCLientCertificateDNInfo `description:"Add Issuer info in header" json:"issuer,omitempty"`
}
// +k8s:deepcopy-gen=true
// TLSCLientCertificateDNInfo holds the client TLS certificate distinguished name info configuration
// cf https://tools.ietf.org/html/rfc3739
type TLSCLientCertificateDNInfo struct {
Country bool `description:"Add Country info in header" json:"country"`
Province bool `description:"Add Province info in header" json:"province"`
Locality bool `description:"Add Locality info in header" json:"locality"`
Organization bool `description:"Add Organization info in header" json:"organization"`
CommonName bool `description:"Add CommonName info in header" json:"commonName"`
SerialNumber bool `description:"Add SerialNumber info in header" json:"serialNumber"`
DomainComponent bool `description:"Add Domain Component info in header" json:"domainComponent"`
}
// +k8s:deepcopy-gen=true
// Users holds a list of users
type Users []string
// +k8s:deepcopy-gen=true
// ClientTLS holds the TLS specific configurations as client
// CA, Cert and Key can be either path or file contents.
type ClientTLS struct {
CA string `description:"TLS CA" json:"ca,omitempty"`
CAOptional bool `description:"TLS CA.Optional" json:"caOptional,omitempty"`
Cert string `description:"TLS cert" json:"cert,omitempty"`
Key string `description:"TLS key" json:"key,omitempty"`
InsecureSkipVerify bool `description:"TLS insecure skip verify" json:"insecureSkipVerify,omitempty"`
}

View file

@ -0,0 +1,134 @@
package static
import (
"fmt"
"strings"
"github.com/containous/traefik/pkg/log"
)
// EntryPoint holds the entry point configuration.
type EntryPoint struct {
Address string
Transport *EntryPointsTransport
ProxyProtocol *ProxyProtocol
ForwardedHeaders *ForwardedHeaders
}
// ForwardedHeaders Trust client forwarding headers.
type ForwardedHeaders struct {
Insecure bool
TrustedIPs []string
}
// ProxyProtocol contains Proxy-Protocol configuration.
type ProxyProtocol struct {
Insecure bool `export:"true"`
TrustedIPs []string
}
// EntryPoints holds the HTTP entry point list.
type EntryPoints map[string]*EntryPoint
// EntryPointsTransport configures communication between clients and Traefik.
type EntryPointsTransport struct {
LifeCycle *LifeCycle `description:"Timeouts influencing the server life cycle" export:"true"`
RespondingTimeouts *RespondingTimeouts `description:"Timeouts for incoming requests to the Traefik instance" export:"true"`
}
// String is the method to format the flag's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics.
func (ep EntryPoints) String() string {
return fmt.Sprintf("%+v", map[string]*EntryPoint(ep))
}
// Get return the EntryPoints map.
func (ep *EntryPoints) Get() interface{} {
return *ep
}
// SetValue sets the EntryPoints map with val.
func (ep *EntryPoints) SetValue(val interface{}) {
*ep = val.(EntryPoints)
}
// Type is type of the struct.
func (ep *EntryPoints) Type() string {
return "entrypoints"
}
// Set is the method to set the flag value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the flag.
// It's a comma-separated list, so we split it.
func (ep *EntryPoints) Set(value string) error {
result := parseEntryPointsConfiguration(value)
(*ep)[result["name"]] = &EntryPoint{
Address: result["address"],
ProxyProtocol: makeEntryPointProxyProtocol(result),
ForwardedHeaders: makeEntryPointForwardedHeaders(result),
}
return nil
}
func makeEntryPointProxyProtocol(result map[string]string) *ProxyProtocol {
var proxyProtocol *ProxyProtocol
ppTrustedIPs := result["proxyprotocol_trustedips"]
if len(result["proxyprotocol_insecure"]) > 0 || len(ppTrustedIPs) > 0 {
proxyProtocol = &ProxyProtocol{
Insecure: toBool(result, "proxyprotocol_insecure"),
}
if len(ppTrustedIPs) > 0 {
proxyProtocol.TrustedIPs = strings.Split(ppTrustedIPs, ",")
}
}
if proxyProtocol != nil && proxyProtocol.Insecure {
log.Warn("ProxyProtocol.insecure:true is dangerous. Please use 'ProxyProtocol.TrustedIPs:IPs' and remove 'ProxyProtocol.insecure:true'")
}
return proxyProtocol
}
func parseEntryPointsConfiguration(raw string) map[string]string {
sections := strings.Fields(raw)
config := make(map[string]string)
for _, part := range sections {
field := strings.SplitN(part, ":", 2)
name := strings.ToLower(strings.Replace(field[0], ".", "_", -1))
if len(field) > 1 {
config[name] = field[1]
} else {
if strings.EqualFold(name, "TLS") {
config["tls_acme"] = "TLS"
} else {
config[name] = ""
}
}
}
return config
}
func toBool(conf map[string]string, key string) bool {
if val, ok := conf[key]; ok {
return strings.EqualFold(val, "true") ||
strings.EqualFold(val, "enable") ||
strings.EqualFold(val, "on")
}
return false
}
func makeEntryPointForwardedHeaders(result map[string]string) *ForwardedHeaders {
forwardedHeaders := &ForwardedHeaders{}
forwardedHeaders.Insecure = toBool(result, "forwardedheaders_insecure")
fhTrustedIPs := result["forwardedheaders_trustedips"]
if len(fhTrustedIPs) > 0 {
forwardedHeaders.TrustedIPs = strings.Split(fhTrustedIPs, ",")
}
return forwardedHeaders
}

View file

@ -0,0 +1,257 @@
package static
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_parseEntryPointsConfiguration(t *testing.T) {
testCases := []struct {
name string
value string
expectedResult map[string]string
}{
{
name: "all parameters",
value: "Name:foo " +
"Address::8000 " +
"CA:car " +
"CA.Optional:true " +
"Redirect.EntryPoint:https " +
"Redirect.Regex:http://localhost/(.*) " +
"Redirect.Replacement:http://mydomain/$1 " +
"Redirect.Permanent:true " +
"Compress:true " +
"ProxyProtocol.TrustedIPs:192.168.0.1 " +
"ForwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24 " +
"Auth.Basic.Realm:myRealm " +
"Auth.Basic.Users:test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0 " +
"Auth.Basic.RemoveHeader:true " +
"Auth.Digest.Users:test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e " +
"Auth.Digest.RemoveHeader:true " +
"Auth.HeaderField:X-WebAuth-User " +
"Auth.Forward.Address:https://authserver.com/auth " +
"Auth.Forward.AuthResponseHeaders:X-Auth,X-Test,X-Secret " +
"Auth.Forward.TrustForwardHeader:true " +
"Auth.Forward.TLS.CA:path/to/local.crt " +
"Auth.Forward.TLS.CAOptional:true " +
"Auth.Forward.TLS.Cert:path/to/foo.cert " +
"Auth.Forward.TLS.Key:path/to/foo.key " +
"Auth.Forward.TLS.InsecureSkipVerify:true " +
"WhiteList.SourceRange:10.42.0.0/16,152.89.1.33/32,afed:be44::/16 " +
"WhiteList.IPStrategy.depth:3 " +
"WhiteList.IPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 " +
"ClientIPStrategy.depth:3 " +
"ClientIPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 ",
expectedResult: map[string]string{
"address": ":8000",
"auth_basic_realm": "myRealm",
"auth_basic_users": "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0",
"auth_basic_removeheader": "true",
"auth_digest_users": "test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e",
"auth_digest_removeheader": "true",
"auth_forward_address": "https://authserver.com/auth",
"auth_forward_authresponseheaders": "X-Auth,X-Test,X-Secret",
"auth_forward_tls_ca": "path/to/local.crt",
"auth_forward_tls_caoptional": "true",
"auth_forward_tls_cert": "path/to/foo.cert",
"auth_forward_tls_insecureskipverify": "true",
"auth_forward_tls_key": "path/to/foo.key",
"auth_forward_trustforwardheader": "true",
"auth_headerfield": "X-WebAuth-User",
"ca": "car",
"ca_optional": "true",
"compress": "true",
"forwardedheaders_trustedips": "10.0.0.3/24,20.0.0.3/24",
"name": "foo",
"proxyprotocol_trustedips": "192.168.0.1",
"redirect_entrypoint": "https",
"redirect_permanent": "true",
"redirect_regex": "http://localhost/(.*)",
"redirect_replacement": "http://mydomain/$1",
"whitelist_sourcerange": "10.42.0.0/16,152.89.1.33/32,afed:be44::/16",
"whitelist_ipstrategy_depth": "3",
"whitelist_ipstrategy_excludedips": "10.0.0.3/24,20.0.0.3/24",
"clientipstrategy_depth": "3",
"clientipstrategy_excludedips": "10.0.0.3/24,20.0.0.3/24",
},
},
{
name: "compress on",
value: "name:foo Compress:on",
expectedResult: map[string]string{
"name": "foo",
"compress": "on",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
conf := parseEntryPointsConfiguration(test.value)
assert.Len(t, conf, len(test.expectedResult))
assert.Equal(t, test.expectedResult, conf)
})
}
}
func Test_toBool(t *testing.T) {
testCases := []struct {
name string
value string
key string
expectedBool bool
}{
{
name: "on",
value: "on",
key: "foo",
expectedBool: true,
},
{
name: "true",
value: "true",
key: "foo",
expectedBool: true,
},
{
name: "enable",
value: "enable",
key: "foo",
expectedBool: true,
},
{
name: "arbitrary string",
value: "bar",
key: "foo",
expectedBool: false,
},
{
name: "no existing entry",
value: "bar",
key: "fii",
expectedBool: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
conf := map[string]string{
"foo": test.value,
}
result := toBool(conf, test.key)
assert.Equal(t, test.expectedBool, result)
})
}
}
func TestEntryPoints_Set(t *testing.T) {
testCases := []struct {
name string
expression string
expectedEntryPointName string
expectedEntryPoint *EntryPoint
}{
{
name: "all parameters camelcase",
expression: "Name:foo " +
"Address::8000 " +
"CA:car " +
"CA.Optional:true " +
"ProxyProtocol.TrustedIPs:192.168.0.1 ",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
Address: ":8000",
ProxyProtocol: &ProxyProtocol{
Insecure: false,
TrustedIPs: []string{"192.168.0.1"},
},
ForwardedHeaders: &ForwardedHeaders{},
// FIXME Test ServersTransport
},
},
{
name: "all parameters lowercase",
expression: "Name:foo " +
"address::8000 " +
"tls " +
"tls.minversion:VersionTLS11 " +
"tls.ciphersuites:TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA " +
"ca:car " +
"ca.Optional:true " +
"proxyProtocol.TrustedIPs:192.168.0.1 ",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
Address: ":8000",
ProxyProtocol: &ProxyProtocol{
Insecure: false,
TrustedIPs: []string{"192.168.0.1"},
},
ForwardedHeaders: &ForwardedHeaders{},
// FIXME Test ServersTransport
},
},
{
name: "default",
expression: "Name:foo",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ForwardedHeaders: &ForwardedHeaders{},
},
},
{
name: "ProxyProtocol insecure true",
expression: "Name:foo ProxyProtocol.insecure:true",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ProxyProtocol: &ProxyProtocol{Insecure: true},
ForwardedHeaders: &ForwardedHeaders{},
},
},
{
name: "ProxyProtocol insecure false",
expression: "Name:foo ProxyProtocol.insecure:false",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ProxyProtocol: &ProxyProtocol{},
ForwardedHeaders: &ForwardedHeaders{},
},
},
{
name: "ProxyProtocol TrustedIPs",
expression: "Name:foo ProxyProtocol.TrustedIPs:10.0.0.3/24,20.0.0.3/24",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ProxyProtocol: &ProxyProtocol{
TrustedIPs: []string{"10.0.0.3/24", "20.0.0.3/24"},
},
ForwardedHeaders: &ForwardedHeaders{},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
eps := EntryPoints{}
err := eps.Set(test.expression)
require.NoError(t, err)
ep := eps[test.expectedEntryPointName]
assert.EqualValues(t, test.expectedEntryPoint, ep)
})
}
}

View file

@ -0,0 +1,401 @@
package static
import (
"errors"
"strings"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/old/provider/boltdb"
"github.com/containous/traefik/old/provider/consul"
"github.com/containous/traefik/old/provider/consulcatalog"
"github.com/containous/traefik/old/provider/dynamodb"
"github.com/containous/traefik/old/provider/ecs"
"github.com/containous/traefik/old/provider/etcd"
"github.com/containous/traefik/old/provider/eureka"
"github.com/containous/traefik/old/provider/mesos"
"github.com/containous/traefik/old/provider/rancher"
"github.com/containous/traefik/old/provider/zk"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/ping"
acmeprovider "github.com/containous/traefik/pkg/provider/acme"
"github.com/containous/traefik/pkg/provider/docker"
"github.com/containous/traefik/pkg/provider/file"
"github.com/containous/traefik/pkg/provider/kubernetes/crd"
"github.com/containous/traefik/pkg/provider/kubernetes/ingress"
"github.com/containous/traefik/pkg/provider/marathon"
"github.com/containous/traefik/pkg/provider/rest"
"github.com/containous/traefik/pkg/tls"
"github.com/containous/traefik/pkg/tracing/datadog"
"github.com/containous/traefik/pkg/tracing/instana"
"github.com/containous/traefik/pkg/tracing/jaeger"
"github.com/containous/traefik/pkg/tracing/zipkin"
"github.com/containous/traefik/pkg/types"
assetfs "github.com/elazarl/go-bindata-assetfs"
"github.com/go-acme/lego/challenge/dns01"
jaegercli "github.com/uber/jaeger-client-go"
)
const (
// DefaultInternalEntryPointName the name of the default internal entry point
DefaultInternalEntryPointName = "traefik"
// DefaultGraceTimeout controls how long Traefik serves pending requests
// prior to shutting down.
DefaultGraceTimeout = 10 * time.Second
// DefaultIdleTimeout before closing an idle connection.
DefaultIdleTimeout = 180 * time.Second
// DefaultAcmeCAServer is the default ACME API endpoint
DefaultAcmeCAServer = "https://acme-v02.api.letsencrypt.org/directory"
)
// Configuration is the static configuration
type Configuration struct {
Global *Global `description:"Global configuration options" export:"true"`
ServersTransport *ServersTransport `description:"Servers default transport" export:"true"`
EntryPoints EntryPoints `description:"Entrypoints definition using format: --entryPoints='Name:http Address::8000 Redirect.EntryPoint:https' --entryPoints='Name:https Address::4442 TLS:tests/traefik.crt,tests/traefik.key;prod/traefik.crt,prod/traefik.key'" export:"true"`
Providers *Providers `description:"Providers configuration" export:"true"`
API *API `description:"Enable api/dashboard" export:"true"`
Metrics *types.Metrics `description:"Enable a metrics exporter" export:"true"`
Ping *ping.Handler `description:"Enable ping" export:"true"`
// Rest *rest.Provider `description:"Enable Rest backend with default settings" export:"true"`
Log *types.TraefikLog
AccessLog *types.AccessLog `description:"Access log settings" export:"true"`
Tracing *Tracing `description:"OpenTracing configuration" export:"true"`
HostResolver *types.HostResolverConfig `description:"Enable CNAME Flattening" export:"true"`
ACME *acmeprovider.Configuration `description:"Enable ACME (Let's Encrypt): automatic SSL" export:"true"`
}
// Global holds the global configuration.
type Global struct {
Debug bool `short:"d" description:"Enable debug mode" export:"true"`
CheckNewVersion bool `description:"Periodically check if a new version has been released" export:"true"`
SendAnonymousUsage *bool `description:"send periodically anonymous usage statistics" export:"true"`
}
// ServersTransport options to configure communication between Traefik and the servers
type ServersTransport struct {
InsecureSkipVerify bool `description:"Disable SSL certificate verification" export:"true"`
RootCAs tls.FilesOrContents `description:"Add cert file for self-signed certificate"`
MaxIdleConnsPerHost int `description:"If non-zero, controls the maximum idle (keep-alive) to keep per-host. If zero, DefaultMaxIdleConnsPerHost is used" export:"true"`
ForwardingTimeouts *ForwardingTimeouts `description:"Timeouts for requests forwarded to the backend servers" export:"true"`
}
// API holds the API configuration
type API struct {
EntryPoint string `description:"EntryPoint" export:"true"`
Dashboard bool `description:"Activate dashboard" export:"true"`
Statistics *types.Statistics `description:"Enable more detailed statistics" export:"true"`
Middlewares []string `description:"Middleware list" export:"true"`
DashboardAssets *assetfs.AssetFS `json:"-"`
}
// RespondingTimeouts contains timeout configurations for incoming requests to the Traefik instance.
type RespondingTimeouts struct {
ReadTimeout parse.Duration `description:"ReadTimeout is the maximum duration for reading the entire request, including the body. If zero, no timeout is set" export:"true"`
WriteTimeout parse.Duration `description:"WriteTimeout is the maximum duration before timing out writes of the response. If zero, no timeout is set" export:"true"`
IdleTimeout parse.Duration `description:"IdleTimeout is the maximum amount duration an idle (keep-alive) connection will remain idle before closing itself. Defaults to 180 seconds. If zero, no timeout is set" export:"true"`
}
// ForwardingTimeouts contains timeout configurations for forwarding requests to the backend servers.
type ForwardingTimeouts struct {
DialTimeout parse.Duration `description:"The amount of time to wait until a connection to a backend server can be established. Defaults to 30 seconds. If zero, no timeout exists" export:"true"`
ResponseHeaderTimeout parse.Duration `description:"The amount of time to wait for a server's response headers after fully writing the request (including its body, if any). If zero, no timeout exists" export:"true"`
}
// LifeCycle contains configurations relevant to the lifecycle (such as the shutdown phase) of Traefik.
type LifeCycle struct {
RequestAcceptGraceTimeout parse.Duration `description:"Duration to keep accepting requests before Traefik initiates the graceful shutdown procedure"`
GraceTimeOut parse.Duration `description:"Duration to give active requests a chance to finish before Traefik stops"`
}
// Tracing holds the tracing configuration.
type Tracing struct {
Backend string `description:"Selects the tracking backend ('jaeger','zipkin','datadog','instana')." export:"true"`
ServiceName string `description:"Set the name for this service" export:"true"`
SpanNameLimit int `description:"Set the maximum character limit for Span names (default 0 = no limit)" export:"true"`
Jaeger *jaeger.Config `description:"Settings for jaeger"`
Zipkin *zipkin.Config `description:"Settings for zipkin"`
DataDog *datadog.Config `description:"Settings for DataDog"`
Instana *instana.Config `description:"Settings for Instana"`
}
// Providers contains providers configuration
type Providers struct {
ProvidersThrottleDuration parse.Duration `description:"Backends throttle duration: minimum duration between 2 events from providers before applying a new configuration. It avoids unnecessary reloads if multiples events are sent in a short amount of time." export:"true"`
Docker *docker.Provider `description:"Enable Docker backend with default settings" export:"true"`
File *file.Provider `description:"Enable File backend with default settings" export:"true"`
Marathon *marathon.Provider `description:"Enable Marathon backend with default settings" export:"true"`
Consul *consul.Provider `description:"Enable Consul backend with default settings" export:"true"`
ConsulCatalog *consulcatalog.Provider `description:"Enable Consul catalog backend with default settings" export:"true"`
Etcd *etcd.Provider `description:"Enable Etcd backend with default settings" export:"true"`
Zookeeper *zk.Provider `description:"Enable Zookeeper backend with default settings" export:"true"`
Boltdb *boltdb.Provider `description:"Enable Boltdb backend with default settings" export:"true"`
Kubernetes *ingress.Provider `description:"Enable Kubernetes backend with default settings" export:"true"`
KubernetesCRD *crd.Provider `description:"Enable Kubernetes backend with default settings" export:"true"`
Mesos *mesos.Provider `description:"Enable Mesos backend with default settings" export:"true"`
Eureka *eureka.Provider `description:"Enable Eureka backend with default settings" export:"true"`
ECS *ecs.Provider `description:"Enable ECS backend with default settings" export:"true"`
Rancher *rancher.Provider `description:"Enable Rancher backend with default settings" export:"true"`
DynamoDB *dynamodb.Provider `description:"Enable DynamoDB backend with default settings" export:"true"`
Rest *rest.Provider `description:"Enable Rest backend with default settings" export:"true"`
}
// SetEffectiveConfiguration adds missing configuration parameters derived from existing ones.
// It also takes care of maintaining backwards compatibility.
func (c *Configuration) SetEffectiveConfiguration(configFile string) {
if len(c.EntryPoints) == 0 {
c.EntryPoints = EntryPoints{
"http": &EntryPoint{
Address: ":80",
},
}
}
if (c.API != nil && c.API.EntryPoint == DefaultInternalEntryPointName) ||
(c.Ping != nil && c.Ping.EntryPoint == DefaultInternalEntryPointName) ||
(c.Metrics != nil && c.Metrics.Prometheus != nil && c.Metrics.Prometheus.EntryPoint == DefaultInternalEntryPointName) ||
(c.Providers.Rest != nil && c.Providers.Rest.EntryPoint == DefaultInternalEntryPointName) {
if _, ok := c.EntryPoints[DefaultInternalEntryPointName]; !ok {
c.EntryPoints[DefaultInternalEntryPointName] = &EntryPoint{Address: ":8080"}
}
}
for _, entryPoint := range c.EntryPoints {
if entryPoint.Transport == nil {
entryPoint.Transport = &EntryPointsTransport{}
}
// Make sure LifeCycle isn't nil to spare nil checks elsewhere.
if entryPoint.Transport.LifeCycle == nil {
entryPoint.Transport.LifeCycle = &LifeCycle{
GraceTimeOut: parse.Duration(DefaultGraceTimeout),
}
entryPoint.Transport.RespondingTimeouts = &RespondingTimeouts{
IdleTimeout: parse.Duration(DefaultIdleTimeout),
}
}
if entryPoint.ForwardedHeaders == nil {
entryPoint.ForwardedHeaders = &ForwardedHeaders{}
}
}
if c.Providers.Rancher != nil {
// Ensure backwards compatibility for now
if len(c.Providers.Rancher.AccessKey) > 0 ||
len(c.Providers.Rancher.Endpoint) > 0 ||
len(c.Providers.Rancher.SecretKey) > 0 {
if c.Providers.Rancher.API == nil {
c.Providers.Rancher.API = &rancher.APIConfiguration{
AccessKey: c.Providers.Rancher.AccessKey,
SecretKey: c.Providers.Rancher.SecretKey,
Endpoint: c.Providers.Rancher.Endpoint,
}
}
log.Warn("Deprecated configuration found: rancher.[accesskey|secretkey|endpoint]. " +
"Please use rancher.api.[accesskey|secretkey|endpoint] instead.")
}
if c.Providers.Rancher.Metadata != nil && len(c.Providers.Rancher.Metadata.Prefix) == 0 {
c.Providers.Rancher.Metadata.Prefix = "latest"
}
}
if c.Providers.Docker != nil {
if c.Providers.Docker.SwarmModeRefreshSeconds <= 0 {
c.Providers.Docker.SwarmModeRefreshSeconds = 15
}
}
if c.Providers.File != nil {
c.Providers.File.TraefikFile = configFile
}
c.initACMEProvider()
c.initTracing()
}
func (c *Configuration) initTracing() {
if c.Tracing != nil {
switch c.Tracing.Backend {
case jaeger.Name:
if c.Tracing.Jaeger == nil {
c.Tracing.Jaeger = &jaeger.Config{
SamplingServerURL: "http://localhost:5778/sampling",
SamplingType: "const",
SamplingParam: 1.0,
LocalAgentHostPort: "127.0.0.1:6831",
Propagation: "jaeger",
Gen128Bit: false,
TraceContextHeaderName: jaegercli.TraceContextHeaderName,
}
}
if c.Tracing.Zipkin != nil {
log.Warn("Zipkin configuration will be ignored")
c.Tracing.Zipkin = nil
}
if c.Tracing.DataDog != nil {
log.Warn("DataDog configuration will be ignored")
c.Tracing.DataDog = nil
}
if c.Tracing.Instana != nil {
log.Warn("Instana configuration will be ignored")
c.Tracing.Instana = nil
}
case zipkin.Name:
if c.Tracing.Zipkin == nil {
c.Tracing.Zipkin = &zipkin.Config{
HTTPEndpoint: "http://localhost:9411/api/v1/spans",
SameSpan: false,
ID128Bit: true,
Debug: false,
SampleRate: 1.0,
}
}
if c.Tracing.Jaeger != nil {
log.Warn("Jaeger configuration will be ignored")
c.Tracing.Jaeger = nil
}
if c.Tracing.DataDog != nil {
log.Warn("DataDog configuration will be ignored")
c.Tracing.DataDog = nil
}
if c.Tracing.Instana != nil {
log.Warn("Instana configuration will be ignored")
c.Tracing.Instana = nil
}
case datadog.Name:
if c.Tracing.DataDog == nil {
c.Tracing.DataDog = &datadog.Config{
LocalAgentHostPort: "localhost:8126",
GlobalTag: "",
Debug: false,
}
}
if c.Tracing.Zipkin != nil {
log.Warn("Zipkin configuration will be ignored")
c.Tracing.Zipkin = nil
}
if c.Tracing.Jaeger != nil {
log.Warn("Jaeger configuration will be ignored")
c.Tracing.Jaeger = nil
}
if c.Tracing.Instana != nil {
log.Warn("Instana configuration will be ignored")
c.Tracing.Instana = nil
}
case instana.Name:
if c.Tracing.Instana == nil {
c.Tracing.Instana = &instana.Config{
LocalAgentHost: "localhost",
LocalAgentPort: 42699,
LogLevel: "info",
}
}
if c.Tracing.Zipkin != nil {
log.Warn("Zipkin configuration will be ignored")
c.Tracing.Zipkin = nil
}
if c.Tracing.Jaeger != nil {
log.Warn("Jaeger configuration will be ignored")
c.Tracing.Jaeger = nil
}
if c.Tracing.DataDog != nil {
log.Warn("DataDog configuration will be ignored")
c.Tracing.DataDog = nil
}
default:
log.Warnf("Unknown tracer %q", c.Tracing.Backend)
return
}
}
}
// FIXME handle on new configuration ACME struct
func (c *Configuration) initACMEProvider() {
if c.ACME != nil {
c.ACME.CAServer = getSafeACMECAServer(c.ACME.CAServer)
if c.ACME.DNSChallenge != nil && c.ACME.HTTPChallenge != nil {
log.Warn("Unable to use DNS challenge and HTTP challenge at the same time. Fallback to DNS challenge.")
c.ACME.HTTPChallenge = nil
}
if c.ACME.DNSChallenge != nil && c.ACME.TLSChallenge != nil {
log.Warn("Unable to use DNS challenge and TLS challenge at the same time. Fallback to DNS challenge.")
c.ACME.TLSChallenge = nil
}
if c.ACME.HTTPChallenge != nil && c.ACME.TLSChallenge != nil {
log.Warn("Unable to use HTTP challenge and TLS challenge at the same time. Fallback to TLS challenge.")
c.ACME.HTTPChallenge = nil
}
}
}
// InitACMEProvider create an acme provider from the ACME part of globalConfiguration
func (c *Configuration) InitACMEProvider() (*acmeprovider.Provider, error) {
if c.ACME != nil {
if len(c.ACME.Storage) == 0 {
return nil, errors.New("unable to initialize ACME provider with no storage location for the certificates")
}
return &acmeprovider.Provider{
Configuration: c.ACME,
}, nil
}
return nil, nil
}
// ValidateConfiguration validate that configuration is coherent
func (c *Configuration) ValidateConfiguration() {
if c.ACME != nil {
for _, domain := range c.ACME.Domains {
if domain.Main != dns01.UnFqdn(domain.Main) {
log.Warnf("FQDN detected, please remove the trailing dot: %s", domain.Main)
}
for _, san := range domain.SANs {
if san != dns01.UnFqdn(san) {
log.Warnf("FQDN detected, please remove the trailing dot: %s", san)
}
}
}
}
// FIXME Validate store config?
// if c.ACME != nil {
// if _, ok := c.EntryPoints[c.ACME.EntryPoint]; !ok {
// log.Fatalf("Unknown entrypoint %q for ACME configuration", c.ACME.EntryPoint)
// }
// else if c.EntryPoints[c.ACME.EntryPoint].TLS == nil {
// log.Fatalf("Entrypoint %q has no TLS configuration for ACME configuration", c.ACME.EntryPoint)
// }
// }
}
func getSafeACMECAServer(caServerSrc string) string {
if len(caServerSrc) == 0 {
return DefaultAcmeCAServer
}
if strings.HasPrefix(caServerSrc, "https://acme-v01.api.letsencrypt.org") {
caServer := strings.Replace(caServerSrc, "v01", "v02", 1)
log.Warnf("The CA server %[1]q refers to a v01 endpoint of the ACME API, please change to %[2]q. Fallback to %[2]q.", caServerSrc, caServer)
return caServer
}
if strings.HasPrefix(caServerSrc, "https://acme-staging.api.letsencrypt.org") {
caServer := strings.Replace(caServerSrc, "https://acme-staging.api.letsencrypt.org", "https://acme-staging-v02.api.letsencrypt.org", 1)
log.Warnf("The CA server %[1]q refers to a v01 endpoint of the ACME API, please change to %[2]q. Fallback to %[2]q.", caServerSrc, caServer)
return caServer
}
return caServerSrc
}

View file

@ -0,0 +1,733 @@
// +build !ignore_autogenerated
/*
The MIT License (MIT)
Copyright (c) 2016-2019 Containous SAS
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/
// Code generated by deepcopy-gen. DO NOT EDIT.
package config
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *AddPrefix) DeepCopyInto(out *AddPrefix) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AddPrefix.
func (in *AddPrefix) DeepCopy() *AddPrefix {
if in == nil {
return nil
}
out := new(AddPrefix)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Auth) DeepCopyInto(out *Auth) {
*out = *in
if in.Basic != nil {
in, out := &in.Basic, &out.Basic
*out = new(BasicAuth)
(*in).DeepCopyInto(*out)
}
if in.Digest != nil {
in, out := &in.Digest, &out.Digest
*out = new(DigestAuth)
(*in).DeepCopyInto(*out)
}
if in.Forward != nil {
in, out := &in.Forward, &out.Forward
*out = new(ForwardAuth)
(*in).DeepCopyInto(*out)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Auth.
func (in *Auth) DeepCopy() *Auth {
if in == nil {
return nil
}
out := new(Auth)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *BasicAuth) DeepCopyInto(out *BasicAuth) {
*out = *in
if in.Users != nil {
in, out := &in.Users, &out.Users
*out = make(Users, len(*in))
copy(*out, *in)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BasicAuth.
func (in *BasicAuth) DeepCopy() *BasicAuth {
if in == nil {
return nil
}
out := new(BasicAuth)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Buffering) DeepCopyInto(out *Buffering) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Buffering.
func (in *Buffering) DeepCopy() *Buffering {
if in == nil {
return nil
}
out := new(Buffering)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Chain) DeepCopyInto(out *Chain) {
*out = *in
if in.Middlewares != nil {
in, out := &in.Middlewares, &out.Middlewares
*out = make([]string, len(*in))
copy(*out, *in)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Chain.
func (in *Chain) DeepCopy() *Chain {
if in == nil {
return nil
}
out := new(Chain)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *CircuitBreaker) DeepCopyInto(out *CircuitBreaker) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CircuitBreaker.
func (in *CircuitBreaker) DeepCopy() *CircuitBreaker {
if in == nil {
return nil
}
out := new(CircuitBreaker)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ClientTLS) DeepCopyInto(out *ClientTLS) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ClientTLS.
func (in *ClientTLS) DeepCopy() *ClientTLS {
if in == nil {
return nil
}
out := new(ClientTLS)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Compress) DeepCopyInto(out *Compress) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Compress.
func (in *Compress) DeepCopy() *Compress {
if in == nil {
return nil
}
out := new(Compress)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *DigestAuth) DeepCopyInto(out *DigestAuth) {
*out = *in
if in.Users != nil {
in, out := &in.Users, &out.Users
*out = make(Users, len(*in))
copy(*out, *in)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DigestAuth.
func (in *DigestAuth) DeepCopy() *DigestAuth {
if in == nil {
return nil
}
out := new(DigestAuth)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ErrorPage) DeepCopyInto(out *ErrorPage) {
*out = *in
if in.Status != nil {
in, out := &in.Status, &out.Status
*out = make([]string, len(*in))
copy(*out, *in)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ErrorPage.
func (in *ErrorPage) DeepCopy() *ErrorPage {
if in == nil {
return nil
}
out := new(ErrorPage)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ForwardAuth) DeepCopyInto(out *ForwardAuth) {
*out = *in
if in.TLS != nil {
in, out := &in.TLS, &out.TLS
*out = new(ClientTLS)
**out = **in
}
if in.AuthResponseHeaders != nil {
in, out := &in.AuthResponseHeaders, &out.AuthResponseHeaders
*out = make([]string, len(*in))
copy(*out, *in)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ForwardAuth.
func (in *ForwardAuth) DeepCopy() *ForwardAuth {
if in == nil {
return nil
}
out := new(ForwardAuth)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Headers) DeepCopyInto(out *Headers) {
*out = *in
if in.CustomRequestHeaders != nil {
in, out := &in.CustomRequestHeaders, &out.CustomRequestHeaders
*out = make(map[string]string, len(*in))
for key, val := range *in {
(*out)[key] = val
}
}
if in.CustomResponseHeaders != nil {
in, out := &in.CustomResponseHeaders, &out.CustomResponseHeaders
*out = make(map[string]string, len(*in))
for key, val := range *in {
(*out)[key] = val
}
}
if in.AllowedHosts != nil {
in, out := &in.AllowedHosts, &out.AllowedHosts
*out = make([]string, len(*in))
copy(*out, *in)
}
if in.HostsProxyHeaders != nil {
in, out := &in.HostsProxyHeaders, &out.HostsProxyHeaders
*out = make([]string, len(*in))
copy(*out, *in)
}
if in.SSLProxyHeaders != nil {
in, out := &in.SSLProxyHeaders, &out.SSLProxyHeaders
*out = make(map[string]string, len(*in))
for key, val := range *in {
(*out)[key] = val
}
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Headers.
func (in *Headers) DeepCopy() *Headers {
if in == nil {
return nil
}
out := new(Headers)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *IPStrategy) DeepCopyInto(out *IPStrategy) {
*out = *in
if in.ExcludedIPs != nil {
in, out := &in.ExcludedIPs, &out.ExcludedIPs
*out = make([]string, len(*in))
copy(*out, *in)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new IPStrategy.
func (in *IPStrategy) DeepCopy() *IPStrategy {
if in == nil {
return nil
}
out := new(IPStrategy)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *IPWhiteList) DeepCopyInto(out *IPWhiteList) {
*out = *in
if in.SourceRange != nil {
in, out := &in.SourceRange, &out.SourceRange
*out = make([]string, len(*in))
copy(*out, *in)
}
if in.IPStrategy != nil {
in, out := &in.IPStrategy, &out.IPStrategy
*out = new(IPStrategy)
(*in).DeepCopyInto(*out)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new IPWhiteList.
func (in *IPWhiteList) DeepCopy() *IPWhiteList {
if in == nil {
return nil
}
out := new(IPWhiteList)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *MaxConn) DeepCopyInto(out *MaxConn) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MaxConn.
func (in *MaxConn) DeepCopy() *MaxConn {
if in == nil {
return nil
}
out := new(MaxConn)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Middleware) DeepCopyInto(out *Middleware) {
*out = *in
if in.AddPrefix != nil {
in, out := &in.AddPrefix, &out.AddPrefix
*out = new(AddPrefix)
**out = **in
}
if in.StripPrefix != nil {
in, out := &in.StripPrefix, &out.StripPrefix
*out = new(StripPrefix)
(*in).DeepCopyInto(*out)
}
if in.StripPrefixRegex != nil {
in, out := &in.StripPrefixRegex, &out.StripPrefixRegex
*out = new(StripPrefixRegex)
(*in).DeepCopyInto(*out)
}
if in.ReplacePath != nil {
in, out := &in.ReplacePath, &out.ReplacePath
*out = new(ReplacePath)
**out = **in
}
if in.ReplacePathRegex != nil {
in, out := &in.ReplacePathRegex, &out.ReplacePathRegex
*out = new(ReplacePathRegex)
**out = **in
}
if in.Chain != nil {
in, out := &in.Chain, &out.Chain
*out = new(Chain)
(*in).DeepCopyInto(*out)
}
if in.IPWhiteList != nil {
in, out := &in.IPWhiteList, &out.IPWhiteList
*out = new(IPWhiteList)
(*in).DeepCopyInto(*out)
}
if in.Headers != nil {
in, out := &in.Headers, &out.Headers
*out = new(Headers)
(*in).DeepCopyInto(*out)
}
if in.Errors != nil {
in, out := &in.Errors, &out.Errors
*out = new(ErrorPage)
(*in).DeepCopyInto(*out)
}
if in.RateLimit != nil {
in, out := &in.RateLimit, &out.RateLimit
*out = new(RateLimit)
(*in).DeepCopyInto(*out)
}
if in.RedirectRegex != nil {
in, out := &in.RedirectRegex, &out.RedirectRegex
*out = new(RedirectRegex)
**out = **in
}
if in.RedirectScheme != nil {
in, out := &in.RedirectScheme, &out.RedirectScheme
*out = new(RedirectScheme)
**out = **in
}
if in.BasicAuth != nil {
in, out := &in.BasicAuth, &out.BasicAuth
*out = new(BasicAuth)
(*in).DeepCopyInto(*out)
}
if in.DigestAuth != nil {
in, out := &in.DigestAuth, &out.DigestAuth
*out = new(DigestAuth)
(*in).DeepCopyInto(*out)
}
if in.ForwardAuth != nil {
in, out := &in.ForwardAuth, &out.ForwardAuth
*out = new(ForwardAuth)
(*in).DeepCopyInto(*out)
}
if in.MaxConn != nil {
in, out := &in.MaxConn, &out.MaxConn
*out = new(MaxConn)
**out = **in
}
if in.Buffering != nil {
in, out := &in.Buffering, &out.Buffering
*out = new(Buffering)
**out = **in
}
if in.CircuitBreaker != nil {
in, out := &in.CircuitBreaker, &out.CircuitBreaker
*out = new(CircuitBreaker)
**out = **in
}
if in.Compress != nil {
in, out := &in.Compress, &out.Compress
*out = new(Compress)
**out = **in
}
if in.PassTLSClientCert != nil {
in, out := &in.PassTLSClientCert, &out.PassTLSClientCert
*out = new(PassTLSClientCert)
(*in).DeepCopyInto(*out)
}
if in.Retry != nil {
in, out := &in.Retry, &out.Retry
*out = new(Retry)
**out = **in
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Middleware.
func (in *Middleware) DeepCopy() *Middleware {
if in == nil {
return nil
}
out := new(Middleware)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *PassTLSClientCert) DeepCopyInto(out *PassTLSClientCert) {
*out = *in
if in.Info != nil {
in, out := &in.Info, &out.Info
*out = new(TLSClientCertificateInfo)
(*in).DeepCopyInto(*out)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PassTLSClientCert.
func (in *PassTLSClientCert) DeepCopy() *PassTLSClientCert {
if in == nil {
return nil
}
out := new(PassTLSClientCert)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Rate) DeepCopyInto(out *Rate) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Rate.
func (in *Rate) DeepCopy() *Rate {
if in == nil {
return nil
}
out := new(Rate)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *RateLimit) DeepCopyInto(out *RateLimit) {
*out = *in
if in.RateSet != nil {
in, out := &in.RateSet, &out.RateSet
*out = make(map[string]*Rate, len(*in))
for key, val := range *in {
var outVal *Rate
if val == nil {
(*out)[key] = nil
} else {
in, out := &val, &outVal
*out = new(Rate)
**out = **in
}
(*out)[key] = outVal
}
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RateLimit.
func (in *RateLimit) DeepCopy() *RateLimit {
if in == nil {
return nil
}
out := new(RateLimit)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *RedirectRegex) DeepCopyInto(out *RedirectRegex) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RedirectRegex.
func (in *RedirectRegex) DeepCopy() *RedirectRegex {
if in == nil {
return nil
}
out := new(RedirectRegex)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *RedirectScheme) DeepCopyInto(out *RedirectScheme) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RedirectScheme.
func (in *RedirectScheme) DeepCopy() *RedirectScheme {
if in == nil {
return nil
}
out := new(RedirectScheme)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ReplacePath) DeepCopyInto(out *ReplacePath) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ReplacePath.
func (in *ReplacePath) DeepCopy() *ReplacePath {
if in == nil {
return nil
}
out := new(ReplacePath)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *ReplacePathRegex) DeepCopyInto(out *ReplacePathRegex) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ReplacePathRegex.
func (in *ReplacePathRegex) DeepCopy() *ReplacePathRegex {
if in == nil {
return nil
}
out := new(ReplacePathRegex)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *Retry) DeepCopyInto(out *Retry) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Retry.
func (in *Retry) DeepCopy() *Retry {
if in == nil {
return nil
}
out := new(Retry)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *StripPrefix) DeepCopyInto(out *StripPrefix) {
*out = *in
if in.Prefixes != nil {
in, out := &in.Prefixes, &out.Prefixes
*out = make([]string, len(*in))
copy(*out, *in)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new StripPrefix.
func (in *StripPrefix) DeepCopy() *StripPrefix {
if in == nil {
return nil
}
out := new(StripPrefix)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *StripPrefixRegex) DeepCopyInto(out *StripPrefixRegex) {
*out = *in
if in.Regex != nil {
in, out := &in.Regex, &out.Regex
*out = make([]string, len(*in))
copy(*out, *in)
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new StripPrefixRegex.
func (in *StripPrefixRegex) DeepCopy() *StripPrefixRegex {
if in == nil {
return nil
}
out := new(StripPrefixRegex)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *TLSCLientCertificateDNInfo) DeepCopyInto(out *TLSCLientCertificateDNInfo) {
*out = *in
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TLSCLientCertificateDNInfo.
func (in *TLSCLientCertificateDNInfo) DeepCopy() *TLSCLientCertificateDNInfo {
if in == nil {
return nil
}
out := new(TLSCLientCertificateDNInfo)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *TLSClientCertificateInfo) DeepCopyInto(out *TLSClientCertificateInfo) {
*out = *in
if in.Subject != nil {
in, out := &in.Subject, &out.Subject
*out = new(TLSCLientCertificateDNInfo)
**out = **in
}
if in.Issuer != nil {
in, out := &in.Issuer, &out.Issuer
*out = new(TLSCLientCertificateDNInfo)
**out = **in
}
return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TLSClientCertificateInfo.
func (in *TLSClientCertificateInfo) DeepCopy() *TLSClientCertificateInfo {
if in == nil {
return nil
}
out := new(TLSClientCertificateInfo)
in.DeepCopyInto(out)
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in Users) DeepCopyInto(out *Users) {
{
in := &in
*out = make(Users, len(*in))
copy(*out, *in)
return
}
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Users.
func (in Users) DeepCopy() Users {
if in == nil {
return nil
}
out := new(Users)
in.DeepCopyInto(out)
return *out
}

476
pkg/h2c/h2c.go Normal file
View file

@ -0,0 +1,476 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package h2c implements the h2c part of HTTP/2.
//
// The h2c protocol is the non-TLS secured version of HTTP/2 which is not
// available from net/http.
package h2c
import (
"bufio"
"bytes"
"encoding/base64"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"os"
"strings"
"github.com/containous/traefik/pkg/log"
"golang.org/x/net/http/httpguts"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
)
var (
http2VerboseLogs bool
)
func init() {
e := os.Getenv("GODEBUG")
if strings.Contains(e, "http2debug=1") || strings.Contains(e, "http2debug=2") {
http2VerboseLogs = true
}
}
// Server implements net.Handler and enables h2c. Users who want h2c just need
// to provide an http.Server.
type Server struct {
*http.Server
}
// Serve Put a middleware around the original handler to handle h2c
func (s Server) Serve(l net.Listener) error {
originalHandler := s.Server.Handler
if originalHandler == nil {
originalHandler = http.DefaultServeMux
}
s.Server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "PRI" && r.URL.Path == "*" && r.Proto == "HTTP/2.0" {
if http2VerboseLogs {
log.Debugf("Attempting h2c with prior knowledge.")
}
conn, err := initH2CWithPriorKnowledge(w)
if err != nil {
if http2VerboseLogs {
log.Debugf("Error h2c with prior knowledge: %v", err)
}
return
}
defer conn.Close()
h2cSrv := &http2.Server{}
h2cSrv.ServeConn(conn, &http2.ServeConnOpts{Handler: originalHandler})
return
}
if conn, err := h2cUpgrade(w, r); err == nil {
defer conn.Close()
h2cSrv := &http2.Server{}
h2cSrv.ServeConn(conn, &http2.ServeConnOpts{Handler: originalHandler})
return
}
originalHandler.ServeHTTP(w, r)
})
return s.Server.Serve(l)
}
// initH2CWithPriorKnowledge implements creating a h2c connection with prior
// knowledge (Section 3.4) and creates a net.Conn suitable for http2.ServeConn.
// All we have to do is look for the client preface that is suppose to be part
// of the body, and reforward the client preface on the net.Conn this function
// creates.
func initH2CWithPriorKnowledge(w http.ResponseWriter) (net.Conn, error) {
hijacker, ok := w.(http.Hijacker)
if !ok {
return nil, errors.New("hijack not supported")
}
conn, rw, err := hijacker.Hijack()
if err != nil {
return nil, fmt.Errorf("hijack failed: %v", err)
}
expectedBody := "SM\r\n\r\n"
buf := make([]byte, len(expectedBody))
n, err := io.ReadFull(rw, buf)
if err != nil {
return nil, fmt.Errorf("fail to read body: %v", err)
}
if bytes.Equal(buf[0:n], []byte(expectedBody)) {
c := &rwConn{
Conn: conn,
Reader: io.MultiReader(bytes.NewBuffer([]byte(http2.ClientPreface)), rw),
BufWriter: rw.Writer,
}
return c, nil
}
conn.Close()
if http2VerboseLogs {
log.Infof(
"Missing the request body portion of the client preface. Wanted: %v Got: %v",
[]byte(expectedBody),
buf[0:n],
)
}
return nil, errors.New("invalid client preface")
}
// drainClientPreface reads a single instance of the HTTP/2 client preface from
// the supplied reader.
func drainClientPreface(r io.Reader) error {
var buf bytes.Buffer
prefaceLen := int64(len([]byte(http2.ClientPreface)))
n, err := io.CopyN(&buf, r, prefaceLen)
if err != nil {
return err
}
if n != prefaceLen || buf.String() != http2.ClientPreface {
return fmt.Errorf("client never sent: %s", http2.ClientPreface)
}
return nil
}
// h2cUpgrade establishes a h2c connection using the HTTP/1 upgrade (Section 3.2).
func h2cUpgrade(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
if !isH2CUpgrade(r.Header) {
return nil, errors.New("non-conforming h2c headers")
}
// Initial bytes we put into conn to fool http2 server
initBytes, _, err := convertH1ReqToH2(r)
if err != nil {
return nil, err
}
hijacker, ok := w.(http.Hijacker)
if !ok {
return nil, errors.New("hijack not supported")
}
conn, rw, err := hijacker.Hijack()
if err != nil {
return nil, fmt.Errorf("hijack failed: %v", err)
}
rw.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n" +
"Connection: Upgrade\r\n" +
"Upgrade: h2c\r\n\r\n"))
rw.Flush()
// A conforming client will now send an H2 client preface which need to drain
// since we already sent this.
if err := drainClientPreface(rw); err != nil {
return nil, err
}
c := &rwConn{
Conn: conn,
Reader: io.MultiReader(initBytes, rw),
BufWriter: newSettingsAckSwallowWriter(rw.Writer),
}
return c, nil
}
// convert the data contained in the HTTP/1 upgrade request into the HTTP/2
// version in byte form.
func convertH1ReqToH2(r *http.Request) (*bytes.Buffer, []http2.Setting, error) {
h2Bytes := bytes.NewBuffer([]byte((http2.ClientPreface)))
framer := http2.NewFramer(h2Bytes, nil)
settings, err := getH2Settings(r.Header)
if err != nil {
return nil, nil, err
}
if err := framer.WriteSettings(settings...); err != nil {
return nil, nil, err
}
headerBytes, err := getH2HeaderBytes(r, getMaxHeaderTableSize(settings))
if err != nil {
return nil, nil, err
}
maxFrameSize := int(getMaxFrameSize(settings))
needOneHeader := len(headerBytes) < maxFrameSize
err = framer.WriteHeaders(http2.HeadersFrameParam{
StreamID: 1,
BlockFragment: headerBytes,
EndHeaders: needOneHeader,
})
if err != nil {
return nil, nil, err
}
for i := maxFrameSize; i < len(headerBytes); i += maxFrameSize {
if len(headerBytes)-i > maxFrameSize {
if err := framer.WriteContinuation(1,
false, // endHeaders
headerBytes[i:maxFrameSize]); err != nil {
return nil, nil, err
}
} else {
if err := framer.WriteContinuation(1,
true, // endHeaders
headerBytes[i:]); err != nil {
return nil, nil, err
}
}
}
return h2Bytes, settings, nil
}
// getMaxFrameSize returns the SETTINGS_MAX_FRAME_SIZE. If not present default
// value is 16384 as specified by RFC 7540 Section 6.5.2.
func getMaxFrameSize(settings []http2.Setting) uint32 {
for _, setting := range settings {
if setting.ID == http2.SettingMaxFrameSize {
return setting.Val
}
}
return 16384
}
// getMaxHeaderTableSize returns the SETTINGS_HEADER_TABLE_SIZE. If not present
// default value is 4096 as specified by RFC 7540 Section 6.5.2.
func getMaxHeaderTableSize(settings []http2.Setting) uint32 {
for _, setting := range settings {
if setting.ID == http2.SettingHeaderTableSize {
return setting.Val
}
}
return 4096
}
// bufWriter is a Writer interface that also has a Flush method.
type bufWriter interface {
io.Writer
Flush() error
}
// rwConn implements net.Conn but overrides Read and Write so that reads and
// writes are forwarded to the provided io.Reader and bufWriter.
type rwConn struct {
net.Conn
io.Reader
BufWriter bufWriter
}
// Read forwards reads to the underlying Reader.
func (c *rwConn) Read(p []byte) (int, error) {
return c.Reader.Read(p)
}
// Write forwards writes to the underlying bufWriter and immediately flushes.
func (c *rwConn) Write(p []byte) (int, error) {
n, err := c.BufWriter.Write(p)
if err := c.BufWriter.Flush(); err != nil {
return 0, err
}
return n, err
}
// settingsAckSwallowWriter is a writer that normally forwards bytes to it's
// underlying Writer, but swallows the first SettingsAck frame that it sees.
type settingsAckSwallowWriter struct {
Writer *bufio.Writer
buf []byte
didSwallow bool
}
// newSettingsAckSwallowWriter returns a new settingsAckSwallowWriter.
func newSettingsAckSwallowWriter(w *bufio.Writer) *settingsAckSwallowWriter {
return &settingsAckSwallowWriter{
Writer: w,
buf: make([]byte, 0),
didSwallow: false,
}
}
// Write implements io.Writer interface. Normally forwards bytes to w.Writer,
// except for the first Settings ACK frame that it sees.
func (w *settingsAckSwallowWriter) Write(p []byte) (int, error) {
if !w.didSwallow {
w.buf = append(w.buf, p...)
// Process all the frames we have collected into w.buf
for {
// Append until we get full frame header which is 9 bytes
if len(w.buf) < 9 {
break
}
// Check if we have collected a whole frame.
fh, err := http2.ReadFrameHeader(bytes.NewBuffer(w.buf))
if err != nil {
// Corrupted frame, fail current Write
return 0, err
}
fSize := fh.Length + 9
if uint32(len(w.buf)) < fSize {
// Have not collected whole frame. Stop processing buf, and withhold on
// forward bytes to w.Writer until we get the full frame.
break
}
// We have now collected a whole frame.
if fh.Type == http2.FrameSettings && fh.Flags.Has(http2.FlagSettingsAck) {
// If Settings ACK frame, do not forward to underlying writer, remove
// bytes from w.buf, and record that we have swallowed Settings Ack
// frame.
w.didSwallow = true
w.buf = w.buf[fSize:]
continue
}
// Not settings ack frame. Forward bytes to w.Writer.
if _, err := w.Writer.Write(w.buf[:fSize]); err != nil {
// Couldn't forward bytes. Fail current Write.
return 0, err
}
w.buf = w.buf[fSize:]
}
return len(p), nil
}
return w.Writer.Write(p)
}
// Flush calls w.Writer.Flush.
func (w *settingsAckSwallowWriter) Flush() error {
return w.Writer.Flush()
}
// isH2CUpgrade returns true if the header properly request an upgrade to h2c
// as specified by Section 3.2.
func isH2CUpgrade(h http.Header) bool {
return httpguts.HeaderValuesContainsToken(h[textproto.CanonicalMIMEHeaderKey("Upgrade")], "h2c") &&
httpguts.HeaderValuesContainsToken(h[textproto.CanonicalMIMEHeaderKey("Connection")], "HTTP2-Settings")
}
// getH2Settings returns the []http2.Setting that are encoded in the
// HTTP2-Settings header.
func getH2Settings(h http.Header) ([]http2.Setting, error) {
vals, ok := h[textproto.CanonicalMIMEHeaderKey("HTTP2-Settings")]
if !ok {
return nil, errors.New("missing HTTP2-Settings header")
}
if len(vals) != 1 {
return nil, fmt.Errorf("expected 1 HTTP2-Settings. Got: %v", vals)
}
settings, err := decodeSettings(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid HTTP2-Settings: %q", vals[0])
}
return settings, nil
}
// decodeSettings decodes the base64url header value of the HTTP2-Settings
// header. RFC 7540 Section 3.2.1.
func decodeSettings(headerVal string) ([]http2.Setting, error) {
b, err := base64.RawURLEncoding.DecodeString(headerVal)
if err != nil {
return nil, err
}
if len(b)%6 != 0 {
return nil, err
}
settings := make([]http2.Setting, 0)
for i := 0; i < len(b)/6; i++ {
settings = append(settings, http2.Setting{
ID: http2.SettingID(binary.BigEndian.Uint16(b[i*6 : i*6+2])),
Val: binary.BigEndian.Uint32(b[i*6+2 : i*6+6]),
})
}
return settings, nil
}
// getH2HeaderBytes return the headers in r a []bytes encoded by HPACK.
func getH2HeaderBytes(r *http.Request, maxHeaderTableSize uint32) ([]byte, error) {
headerBytes := bytes.NewBuffer(nil)
hpackEnc := hpack.NewEncoder(headerBytes)
hpackEnc.SetMaxDynamicTableSize(maxHeaderTableSize)
// Section 8.1.2.3
err := hpackEnc.WriteField(hpack.HeaderField{
Name: ":method",
Value: r.Method,
})
if err != nil {
return nil, err
}
err = hpackEnc.WriteField(hpack.HeaderField{
Name: ":scheme",
Value: "http",
})
if err != nil {
return nil, err
}
err = hpackEnc.WriteField(hpack.HeaderField{
Name: ":authority",
Value: r.Host,
})
if err != nil {
return nil, err
}
path := r.URL.Path
if r.URL.RawQuery != "" {
path = strings.Join([]string{path, r.URL.RawQuery}, "?")
}
err = hpackEnc.WriteField(hpack.HeaderField{
Name: ":path",
Value: path,
})
if err != nil {
return nil, err
}
// TODO Implement Section 8.3
for header, values := range r.Header {
// Skip non h2 headers
if isNonH2Header(header) {
continue
}
for _, v := range values {
err := hpackEnc.WriteField(hpack.HeaderField{
Name: strings.ToLower(header),
Value: v,
})
if err != nil {
return nil, err
}
}
}
return headerBytes.Bytes(), nil
}
// Connection specific headers listed in RFC 7540 Section 8.1.2.2 that are not
// suppose to be transferred to HTTP/2. The Http2-Settings header is skipped
// since already use to create the HTTP/2 SETTINGS frame.
var nonH2Headers = []string{
"Connection",
"Keep-Alive",
"Proxy-Connection",
"Transfer-Encoding",
"Upgrade",
"Http2-Settings",
}
// isNonH2Header returns true if header should not be transferred to HTTP/2.
func isNonH2Header(header string) bool {
for _, nonH2h := range nonH2Headers {
if header == nonH2h {
return true
}
}
return false
}

View file

@ -0,0 +1,223 @@
package healthcheck
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"sync"
"time"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/safe"
"github.com/go-kit/kit/metrics"
"github.com/vulcand/oxy/roundrobin"
)
var singleton *HealthCheck
var once sync.Once
// BalancerHandler includes functionality for load-balancing management.
type BalancerHandler interface {
ServeHTTP(w http.ResponseWriter, req *http.Request)
Servers() []*url.URL
RemoveServer(u *url.URL) error
UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error
}
// metricsRegistry is a local interface in the health check package, exposing only the required metrics
// necessary for the health check package. This makes it easier for the tests.
type metricsRegistry interface {
BackendServerUpGauge() metrics.Gauge
}
// Options are the public health check options.
type Options struct {
Headers map[string]string
Hostname string
Scheme string
Path string
Port int
Transport http.RoundTripper
Interval time.Duration
Timeout time.Duration
LB BalancerHandler
}
func (opt Options) String() string {
return fmt.Sprintf("[Hostname: %s Headers: %v Path: %s Port: %d Interval: %s Timeout: %s]", opt.Hostname, opt.Headers, opt.Path, opt.Port, opt.Interval, opt.Timeout)
}
// BackendConfig HealthCheck configuration for a backend
type BackendConfig struct {
Options
name string
disabledURLs []*url.URL
}
func (b *BackendConfig) newRequest(serverURL *url.URL) (*http.Request, error) {
u, err := serverURL.Parse(b.Path)
if err != nil {
return nil, err
}
if len(b.Scheme) > 0 {
u.Scheme = b.Scheme
}
if b.Port != 0 {
u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port))
}
return http.NewRequest(http.MethodGet, u.String(), http.NoBody)
}
// this function adds additional http headers and hostname to http.request
func (b *BackendConfig) addHeadersAndHost(req *http.Request) *http.Request {
if b.Options.Hostname != "" {
req.Host = b.Options.Hostname
}
for k, v := range b.Options.Headers {
req.Header.Set(k, v)
}
return req
}
// HealthCheck struct
type HealthCheck struct {
Backends map[string]*BackendConfig
metrics metricsRegistry
cancel context.CancelFunc
}
// SetBackendsConfiguration set backends configuration
func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendConfig) {
hc.Backends = backends
if hc.cancel != nil {
hc.cancel()
}
ctx, cancel := context.WithCancel(parentCtx)
hc.cancel = cancel
for _, backend := range backends {
currentBackend := backend
safe.Go(func() {
hc.execute(ctx, currentBackend)
})
}
}
func (hc *HealthCheck) execute(ctx context.Context, backend *BackendConfig) {
log.Debugf("Initial health check for backend: %q", backend.name)
hc.checkBackend(backend)
ticker := time.NewTicker(backend.Interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Debugf("Stopping current health check goroutines of backend: %s", backend.name)
return
case <-ticker.C:
log.Debugf("Refreshing health check for backend: %s", backend.name)
hc.checkBackend(backend)
}
}
}
func (hc *HealthCheck) checkBackend(backend *BackendConfig) {
enabledURLs := backend.LB.Servers()
var newDisabledURLs []*url.URL
// FIXME re enable metrics
for _, disableURL := range backend.disabledURLs {
// FIXME serverUpMetricValue := float64(0)
if err := checkHealth(disableURL, backend); err == nil {
log.Warnf("Health check up: Returning to server list. Backend: %q URL: %q", backend.name, disableURL.String())
if err = backend.LB.UpsertServer(disableURL, roundrobin.Weight(1)); err != nil {
log.Error(err)
}
// FIXME serverUpMetricValue = 1
} else {
log.Warnf("Health check still failing. Backend: %q URL: %q Reason: %s", backend.name, disableURL.String(), err)
newDisabledURLs = append(newDisabledURLs, disableURL)
}
// FIXME labelValues := []string{"backend", backend.name, "url", disableURL.String()}
// FIXME hc.metrics.BackendServerUpGauge().With(labelValues...).Set(serverUpMetricValue)
}
backend.disabledURLs = newDisabledURLs
// FIXME re enable metrics
for _, enableURL := range enabledURLs {
// FIXME serverUpMetricValue := float64(1)
if err := checkHealth(enableURL, backend); err != nil {
log.Warnf("Health check failed: Remove from server list. Backend: %q URL: %q Reason: %s", backend.name, enableURL.String(), err)
if err := backend.LB.RemoveServer(enableURL); err != nil {
log.Error(err)
}
backend.disabledURLs = append(backend.disabledURLs, enableURL)
// FIXME serverUpMetricValue = 0
}
// FIXME labelValues := []string{"backend", backend.name, "url", enableURL.String()}
// FIXME hc.metrics.BackendServerUpGauge().With(labelValues...).Set(serverUpMetricValue)
}
}
// FIXME re add metrics
//func GetHealthCheck(metrics metricsRegistry) *HealthCheck {
// GetHealthCheck returns the health check which is guaranteed to be a singleton.
func GetHealthCheck() *HealthCheck {
once.Do(func() {
singleton = newHealthCheck()
//singleton = newHealthCheck(metrics)
})
return singleton
}
// FIXME re add metrics
//func newHealthCheck(metrics metricsRegistry) *HealthCheck {
func newHealthCheck() *HealthCheck {
return &HealthCheck{
Backends: make(map[string]*BackendConfig),
//metrics: metrics,
}
}
// NewBackendConfig Instantiate a new BackendConfig
func NewBackendConfig(options Options, backendName string) *BackendConfig {
return &BackendConfig{
Options: options,
name: backendName,
}
}
// checkHealth returns a nil error in case it was successful and otherwise
// a non-nil error with a meaningful description why the health check failed.
func checkHealth(serverURL *url.URL, backend *BackendConfig) error {
req, err := backend.newRequest(serverURL)
if err != nil {
return fmt.Errorf("failed to create HTTP request: %s", err)
}
req = backend.addHeadersAndHost(req)
client := http.Client{
Timeout: backend.Options.Timeout,
Transport: backend.Options.Transport,
}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("HTTP request failed: %s", err)
}
defer resp.Body.Close()
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
return fmt.Errorf("received error status code: %v", resp.StatusCode)
}
return nil
}

View file

@ -0,0 +1,429 @@
package healthcheck
import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"sync"
"testing"
"time"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/roundrobin"
)
const healthCheckInterval = 200 * time.Millisecond
const healthCheckTimeout = 100 * time.Millisecond
type testHandler struct {
done func()
healthSequence []int
}
func TestSetBackendsConfiguration(t *testing.T) {
testCases := []struct {
desc string
startHealthy bool
healthSequence []int
expectedNumRemovedServers int
expectedNumUpsertedServers int
expectedGaugeValue float64
}{
{
desc: "healthy server staying healthy",
startHealthy: true,
healthSequence: []int{http.StatusOK},
expectedNumRemovedServers: 0,
expectedNumUpsertedServers: 0,
expectedGaugeValue: 1,
},
{
desc: "healthy server staying healthy (StatusNoContent)",
startHealthy: true,
healthSequence: []int{http.StatusNoContent},
expectedNumRemovedServers: 0,
expectedNumUpsertedServers: 0,
expectedGaugeValue: 1,
},
{
desc: "healthy server staying healthy (StatusPermanentRedirect)",
startHealthy: true,
healthSequence: []int{http.StatusPermanentRedirect},
expectedNumRemovedServers: 0,
expectedNumUpsertedServers: 0,
expectedGaugeValue: 1,
},
{
desc: "healthy server becoming sick",
startHealthy: true,
healthSequence: []int{http.StatusServiceUnavailable},
expectedNumRemovedServers: 1,
expectedNumUpsertedServers: 0,
expectedGaugeValue: 0,
},
{
desc: "sick server becoming healthy",
startHealthy: false,
healthSequence: []int{http.StatusOK},
expectedNumRemovedServers: 0,
expectedNumUpsertedServers: 1,
expectedGaugeValue: 1,
},
{
desc: "sick server staying sick",
startHealthy: false,
healthSequence: []int{http.StatusServiceUnavailable},
expectedNumRemovedServers: 0,
expectedNumUpsertedServers: 0,
expectedGaugeValue: 0,
},
{
desc: "healthy server toggling to sick and back to healthy",
startHealthy: true,
healthSequence: []int{http.StatusServiceUnavailable, http.StatusOK},
expectedNumRemovedServers: 1,
expectedNumUpsertedServers: 1,
expectedGaugeValue: 1,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
// The context is passed to the health check and canonically canceled by
// the test server once all expected requests have been received.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ts := newTestServer(cancel, test.healthSequence)
defer ts.Close()
lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}}
backend := NewBackendConfig(Options{
Path: "/path",
Interval: healthCheckInterval,
Timeout: healthCheckTimeout,
LB: lb,
}, "backendName")
serverURL := testhelpers.MustParseURL(ts.URL)
if test.startHealthy {
lb.servers = append(lb.servers, serverURL)
} else {
backend.disabledURLs = append(backend.disabledURLs, serverURL)
}
collectingMetrics := testhelpers.NewCollectingHealthCheckMetrics()
check := HealthCheck{
Backends: make(map[string]*BackendConfig),
metrics: collectingMetrics,
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
check.execute(ctx, backend)
wg.Done()
}()
// Make test timeout dependent on number of expected requests, health
// check interval, and a safety margin.
timeout := time.Duration(len(test.healthSequence)*int(healthCheckInterval) + 500)
select {
case <-time.After(timeout):
t.Fatal("test did not complete in time")
case <-ctx.Done():
wg.Wait()
}
lb.Lock()
defer lb.Unlock()
assert.Equal(t, test.expectedNumRemovedServers, lb.numRemovedServers, "removed servers")
assert.Equal(t, test.expectedNumUpsertedServers, lb.numUpsertedServers, "upserted servers")
// FIXME re add metrics
//assert.Equal(t, test.expectedGaugeValue, collectingMetrics.Gauge.GaugeValue, "ServerUp Gauge")
})
}
}
func TestNewRequest(t *testing.T) {
type expected struct {
err bool
value string
}
testCases := []struct {
desc string
serverURL string
options Options
expected expected
}{
{
desc: "no port override",
serverURL: "http://backend1:80",
options: Options{
Path: "/test",
Port: 0,
},
expected: expected{
err: false,
value: "http://backend1:80/test",
},
},
{
desc: "port override",
serverURL: "http://backend2:80",
options: Options{
Path: "/test",
Port: 8080,
},
expected: expected{
err: false,
value: "http://backend2:8080/test",
},
},
{
desc: "no port override with no port in server URL",
serverURL: "http://backend1",
options: Options{
Path: "/health",
Port: 0,
},
expected: expected{
err: false,
value: "http://backend1/health",
},
},
{
desc: "port override with no port in server URL",
serverURL: "http://backend2",
options: Options{
Path: "/health",
Port: 8080,
},
expected: expected{
err: false,
value: "http://backend2:8080/health",
},
},
{
desc: "scheme override",
serverURL: "https://backend1:80",
options: Options{
Scheme: "http",
Path: "/test",
Port: 0,
},
expected: expected{
err: false,
value: "http://backend1:80/test",
},
},
{
desc: "path with param",
serverURL: "http://backend1:80",
options: Options{
Path: "/health?powpow=do",
Port: 0,
},
expected: expected{
err: false,
value: "http://backend1:80/health?powpow=do",
},
},
{
desc: "path with params",
serverURL: "http://backend1:80",
options: Options{
Path: "/health?powpow=do&do=powpow",
Port: 0,
},
expected: expected{
err: false,
value: "http://backend1:80/health?powpow=do&do=powpow",
},
},
{
desc: "path with invalid path",
serverURL: "http://backend1:80",
options: Options{
Path: ":",
Port: 0,
},
expected: expected{
err: true,
value: "",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
backend := NewBackendConfig(test.options, "backendName")
u := testhelpers.MustParseURL(test.serverURL)
req, err := backend.newRequest(u)
if test.expected.err {
require.Error(t, err)
assert.Nil(t, nil)
} else {
require.NoError(t, err, "failed to create new backend request")
require.NotNil(t, req)
assert.Equal(t, test.expected.value, req.URL.String())
}
})
}
}
func TestAddHeadersAndHost(t *testing.T) {
testCases := []struct {
desc string
serverURL string
options Options
expectedHostname string
expectedHeader string
}{
{
desc: "override hostname",
serverURL: "http://backend1:80",
options: Options{
Hostname: "myhost",
Path: "/",
},
expectedHostname: "myhost",
expectedHeader: "",
},
{
desc: "not override hostname",
serverURL: "http://backend1:80",
options: Options{
Hostname: "",
Path: "/",
},
expectedHostname: "backend1:80",
expectedHeader: "",
},
{
desc: "custom header",
serverURL: "http://backend1:80",
options: Options{
Headers: map[string]string{"Custom-Header": "foo"},
Hostname: "",
Path: "/",
},
expectedHostname: "backend1:80",
expectedHeader: "foo",
},
{
desc: "custom header with hostname override",
serverURL: "http://backend1:80",
options: Options{
Headers: map[string]string{"Custom-Header": "foo"},
Hostname: "myhost",
Path: "/",
},
expectedHostname: "myhost",
expectedHeader: "foo",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
backend := NewBackendConfig(test.options, "backendName")
u, err := url.Parse(test.serverURL)
require.NoError(t, err)
req, err := backend.newRequest(u)
require.NoError(t, err, "failed to create new backend request")
req = backend.addHeadersAndHost(req)
assert.Equal(t, "http://backend1:80/", req.URL.String())
assert.Equal(t, test.expectedHostname, req.Host)
assert.Equal(t, test.expectedHeader, req.Header.Get("Custom-Header"))
})
}
}
type testLoadBalancer struct {
// RWMutex needed due to parallel test execution: Both the system-under-test
// and the test assertions reference the counters.
*sync.RWMutex
numRemovedServers int
numUpsertedServers int
servers []*url.URL
}
func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// noop
}
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
lb.Lock()
defer lb.Unlock()
lb.numRemovedServers++
lb.removeServer(u)
return nil
}
func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
lb.Lock()
defer lb.Unlock()
lb.numUpsertedServers++
lb.servers = append(lb.servers, u)
return nil
}
func (lb *testLoadBalancer) Servers() []*url.URL {
return lb.servers
}
func (lb *testLoadBalancer) removeServer(u *url.URL) {
var i int
var serverURL *url.URL
for i, serverURL = range lb.servers {
if *serverURL == *u {
break
}
}
lb.servers = append(lb.servers[:i], lb.servers[i+1:]...)
}
func newTestServer(done func(), healthSequence []int) *httptest.Server {
handler := &testHandler{
done: done,
healthSequence: healthSequence,
}
return httptest.NewServer(handler)
}
// ServeHTTP returns HTTP response codes following a status sequences.
// It calls the given 'done' function once all request health indicators have been depleted.
func (th *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(th.healthSequence) == 0 {
panic("received unexpected request")
}
w.WriteHeader(th.healthSequence[0])
th.healthSequence = th.healthSequence[1:]
if len(th.healthSequence) == 0 {
th.done()
}
}

View file

@ -0,0 +1,123 @@
package hostresolver
import (
"fmt"
"net"
"sort"
"strings"
"time"
"github.com/containous/traefik/pkg/log"
"github.com/miekg/dns"
"github.com/patrickmn/go-cache"
)
type cnameResolv struct {
TTL time.Duration
Record string
}
type byTTL []*cnameResolv
func (a byTTL) Len() int { return len(a) }
func (a byTTL) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byTTL) Less(i, j int) bool { return a[i].TTL > a[j].TTL }
// Resolver used for host resolver
type Resolver struct {
CnameFlattening bool
ResolvConfig string
ResolvDepth int
cache *cache.Cache
}
// CNAMEFlatten check if CNAME record exists, flatten if possible
func (hr *Resolver) CNAMEFlatten(host string) (string, string) {
if hr.cache == nil {
hr.cache = cache.New(30*time.Minute, 5*time.Minute)
}
result := []string{host}
request := host
value, found := hr.cache.Get(host)
if found {
result = strings.Split(value.(string), ",")
} else {
var cacheDuration = 0 * time.Second
for depth := 0; depth < hr.ResolvDepth; depth++ {
resolv, err := cnameResolve(request, hr.ResolvConfig)
if err != nil {
log.Error(err)
break
}
if resolv == nil {
break
}
result = append(result, resolv.Record)
if depth == 0 {
cacheDuration = resolv.TTL
}
request = resolv.Record
}
if err := hr.cache.Add(host, strings.Join(result, ","), cacheDuration); err != nil {
log.Error(err)
}
}
return result[0], result[len(result)-1]
}
// cnameResolve resolves CNAME if exists, and return with the highest TTL
func cnameResolve(host string, resolvPath string) (*cnameResolv, error) {
config, err := dns.ClientConfigFromFile(resolvPath)
if err != nil {
return nil, fmt.Errorf("invalid resolver configuration file: %s", resolvPath)
}
client := &dns.Client{Timeout: 30 * time.Second}
m := &dns.Msg{}
m.SetQuestion(dns.Fqdn(host), dns.TypeCNAME)
var result []*cnameResolv
for _, server := range config.Servers {
tempRecord, err := getRecord(client, m, server, config.Port)
if err != nil {
log.Errorf("Failed to resolve host %s: %v", host, err)
continue
}
result = append(result, tempRecord)
}
if len(result) == 0 {
return nil, nil
}
sort.Sort(byTTL(result))
return result[0], nil
}
func getRecord(client *dns.Client, msg *dns.Msg, server string, port string) (*cnameResolv, error) {
resp, _, err := client.Exchange(msg, net.JoinHostPort(server, port))
if err != nil {
return nil, fmt.Errorf("exchange error for server %s: %v", server, err)
}
if resp == nil || len(resp.Answer) == 0 {
return nil, fmt.Errorf("empty answer for server %s", server)
}
rr, ok := resp.Answer[0].(*dns.CNAME)
if !ok {
return nil, fmt.Errorf("invalid response type for server %s", server)
}
return &cnameResolv{
TTL: time.Duration(rr.Hdr.Ttl) * time.Second,
Record: strings.TrimSuffix(rr.Target, "."),
}, nil
}

View file

@ -0,0 +1,61 @@
package hostresolver
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCNAMEFlatten(t *testing.T) {
testCase := []struct {
desc string
resolvFile string
domain string
expectedDomain string
isCNAME bool
}{
{
desc: "host request is CNAME record",
resolvFile: "/etc/resolv.conf",
domain: "www.github.com",
expectedDomain: "github.com",
isCNAME: true,
},
{
desc: "resolve file not found",
resolvFile: "/etc/resolv.oops",
domain: "www.github.com",
expectedDomain: "www.github.com",
isCNAME: false,
},
{
desc: "host request is not CNAME record",
resolvFile: "/etc/resolv.conf",
domain: "github.com",
expectedDomain: "github.com",
isCNAME: false,
},
}
for _, test := range testCase {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
hostResolver := &Resolver{
ResolvConfig: test.resolvFile,
ResolvDepth: 5,
}
reqH, flatH := hostResolver.CNAMEFlatten(test.domain)
assert.Equal(t, test.domain, reqH)
assert.Equal(t, test.expectedDomain, flatH)
if test.isCNAME {
assert.NotEqual(t, test.expectedDomain, reqH)
} else {
assert.Equal(t, test.expectedDomain, reqH)
}
})
}
}

99
pkg/ip/checker.go Normal file
View file

@ -0,0 +1,99 @@
package ip
import (
"errors"
"fmt"
"net"
"strings"
)
// Checker allows to check that addresses are in a trusted IPs
type Checker struct {
authorizedIPs []*net.IP
authorizedIPsNet []*net.IPNet
}
// NewChecker builds a new Checker given a list of CIDR-Strings to trusted IPs
func NewChecker(trustedIPs []string) (*Checker, error) {
if len(trustedIPs) == 0 {
return nil, errors.New("no trusted IPs provided")
}
checker := &Checker{}
for _, ipMask := range trustedIPs {
if ipAddr := net.ParseIP(ipMask); ipAddr != nil {
checker.authorizedIPs = append(checker.authorizedIPs, &ipAddr)
} else {
_, ipAddr, err := net.ParseCIDR(ipMask)
if err != nil {
return nil, fmt.Errorf("parsing CIDR trusted IPs %s: %v", ipAddr, err)
}
checker.authorizedIPsNet = append(checker.authorizedIPsNet, ipAddr)
}
}
return checker, nil
}
// IsAuthorized checks if provided request is authorized by the trusted IPs
func (ip *Checker) IsAuthorized(addr string) error {
var invalidMatches []string
host, _, err := net.SplitHostPort(addr)
if err != nil {
host = addr
}
ok, err := ip.Contains(host)
if err != nil {
return err
}
if !ok {
invalidMatches = append(invalidMatches, addr)
return fmt.Errorf("%q matched none of the trusted IPs", strings.Join(invalidMatches, ", "))
}
return nil
}
// Contains checks if provided address is in the trusted IPs
func (ip *Checker) Contains(addr string) (bool, error) {
if len(addr) == 0 {
return false, errors.New("empty IP address")
}
ipAddr, err := parseIP(addr)
if err != nil {
return false, fmt.Errorf("unable to parse address: %s: %s", addr, err)
}
return ip.ContainsIP(ipAddr), nil
}
// ContainsIP checks if provided address is in the trusted IPs
func (ip *Checker) ContainsIP(addr net.IP) bool {
for _, authorizedIP := range ip.authorizedIPs {
if authorizedIP.Equal(addr) {
return true
}
}
for _, authorizedNet := range ip.authorizedIPsNet {
if authorizedNet.Contains(addr) {
return true
}
}
return false
}
func parseIP(addr string) (net.IP, error) {
userIP := net.ParseIP(addr)
if userIP == nil {
return nil, fmt.Errorf("can't parse IP from address %s", addr)
}
return userIP, nil
}

326
pkg/ip/checker_test.go Normal file
View file

@ -0,0 +1,326 @@
package ip
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsAuthorized(t *testing.T) {
testCases := []struct {
desc string
whiteList []string
remoteAddr string
authorized bool
}{
{
desc: "remoteAddr not in range",
whiteList: []string{"1.2.3.4/24"},
remoteAddr: "10.2.3.1:123",
authorized: false,
},
{
desc: "remoteAddr in range",
whiteList: []string{"1.2.3.4/24"},
remoteAddr: "1.2.3.1:123",
authorized: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
ipChecker, err := NewChecker(test.whiteList)
require.NoError(t, err)
err = ipChecker.IsAuthorized(test.remoteAddr)
if test.authorized {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
func TestNew(t *testing.T) {
testCases := []struct {
desc string
trustedIPs []string
expectedAuthorizedIPs []*net.IPNet
errMessage string
}{
{
desc: "nil trusted IPs",
trustedIPs: nil,
expectedAuthorizedIPs: nil,
errMessage: "no trusted IPs provided",
}, {
desc: "empty trusted IPs",
trustedIPs: []string{},
expectedAuthorizedIPs: nil,
errMessage: "no trusted IPs provided",
}, {
desc: "trusted IPs containing empty string",
trustedIPs: []string{
"1.2.3.4/24",
"",
"fe80::/16",
},
expectedAuthorizedIPs: nil,
errMessage: "parsing CIDR trusted IPs <nil>: invalid CIDR address: ",
}, {
desc: "trusted IPs containing only an empty string",
trustedIPs: []string{
"",
},
expectedAuthorizedIPs: nil,
errMessage: "parsing CIDR trusted IPs <nil>: invalid CIDR address: ",
}, {
desc: "trusted IPs containing an invalid string",
trustedIPs: []string{
"foo",
},
expectedAuthorizedIPs: nil,
errMessage: "parsing CIDR trusted IPs <nil>: invalid CIDR address: foo",
}, {
desc: "IPv4 & IPv6 trusted IPs",
trustedIPs: []string{
"1.2.3.4/24",
"fe80::/16",
},
expectedAuthorizedIPs: []*net.IPNet{
{IP: net.IPv4(1, 2, 3, 0).To4(), Mask: net.IPv4Mask(255, 255, 255, 0)},
{IP: net.ParseIP("fe80::"), Mask: net.IPMask(net.ParseIP("ffff::"))},
},
errMessage: "",
}, {
desc: "IPv4 only",
trustedIPs: []string{
"127.0.0.1/8",
},
expectedAuthorizedIPs: []*net.IPNet{
{IP: net.IPv4(127, 0, 0, 0).To4(), Mask: net.IPv4Mask(255, 0, 0, 0)},
},
errMessage: "",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
ipChecker, err := NewChecker(test.trustedIPs)
if test.errMessage != "" {
require.EqualError(t, err, test.errMessage)
} else {
require.NoError(t, err)
for index, actual := range ipChecker.authorizedIPsNet {
expected := test.expectedAuthorizedIPs[index]
assert.Equal(t, expected.IP, actual.IP)
assert.Equal(t, expected.Mask.String(), actual.Mask.String())
}
}
})
}
}
func TestContainsIsAllowed(t *testing.T) {
testCases := []struct {
desc string
trustedIPs []string
passIPs []string
rejectIPs []string
}{
{
desc: "IPv4",
trustedIPs: []string{"1.2.3.4/24"},
passIPs: []string{
"1.2.3.1",
"1.2.3.32",
"1.2.3.156",
"1.2.3.255",
},
rejectIPs: []string{
"1.2.16.1",
"1.2.32.1",
"127.0.0.1",
"8.8.8.8",
},
},
{
desc: "IPv4 single IP",
trustedIPs: []string{"8.8.8.8"},
passIPs: []string{"8.8.8.8"},
rejectIPs: []string{
"8.8.8.7",
"8.8.8.9",
"8.8.8.0",
"8.8.8.255",
"4.4.4.4",
"127.0.0.1",
},
},
{
desc: "IPv4 Net single IP",
trustedIPs: []string{"8.8.8.8/32"},
passIPs: []string{"8.8.8.8"},
rejectIPs: []string{
"8.8.8.7",
"8.8.8.9",
"8.8.8.0",
"8.8.8.255",
"4.4.4.4",
"127.0.0.1",
},
},
{
desc: "multiple IPv4",
trustedIPs: []string{"1.2.3.4/24", "8.8.8.8/8"},
passIPs: []string{
"1.2.3.1",
"1.2.3.32",
"1.2.3.156",
"1.2.3.255",
"8.8.4.4",
"8.0.0.1",
"8.32.42.128",
"8.255.255.255",
},
rejectIPs: []string{
"1.2.16.1",
"1.2.32.1",
"127.0.0.1",
"4.4.4.4",
"4.8.8.8",
},
},
{
desc: "IPv6",
trustedIPs: []string{"2a03:4000:6:d080::/64"},
passIPs: []string{
"2a03:4000:6:d080::",
"2a03:4000:6:d080::1",
"2a03:4000:6:d080:dead:beef:ffff:ffff",
"2a03:4000:6:d080::42",
},
rejectIPs: []string{
"2a03:4000:7:d080::",
"2a03:4000:7:d080::1",
"fe80::",
"4242::1",
},
},
{
desc: "IPv6 single IP",
trustedIPs: []string{"2a03:4000:6:d080::42/128"},
passIPs: []string{"2a03:4000:6:d080::42"},
rejectIPs: []string{
"2a03:4000:6:d080::1",
"2a03:4000:6:d080:dead:beef:ffff:ffff",
"2a03:4000:6:d080::43",
},
},
{
desc: "multiple IPv6",
trustedIPs: []string{"2a03:4000:6:d080::/64", "fe80::/16"},
passIPs: []string{
"2a03:4000:6:d080::",
"2a03:4000:6:d080::1",
"2a03:4000:6:d080:dead:beef:ffff:ffff",
"2a03:4000:6:d080::42",
"fe80::1",
"fe80:aa00:00bb:4232:ff00:eeee:00ff:1111",
"fe80::fe80",
},
rejectIPs: []string{
"2a03:4000:7:d080::",
"2a03:4000:7:d080::1",
"4242::1",
},
},
{
desc: "multiple IPv6 & IPv4",
trustedIPs: []string{"2a03:4000:6:d080::/64", "fe80::/16", "1.2.3.4/24", "8.8.8.8/8"},
passIPs: []string{
"2a03:4000:6:d080::",
"2a03:4000:6:d080::1",
"2a03:4000:6:d080:dead:beef:ffff:ffff",
"2a03:4000:6:d080::42",
"fe80::1",
"fe80:aa00:00bb:4232:ff00:eeee:00ff:1111",
"fe80::fe80",
"1.2.3.1",
"1.2.3.32",
"1.2.3.156",
"1.2.3.255",
"8.8.4.4",
"8.0.0.1",
"8.32.42.128",
"8.255.255.255",
},
rejectIPs: []string{
"2a03:4000:7:d080::",
"2a03:4000:7:d080::1",
"4242::1",
"1.2.16.1",
"1.2.32.1",
"127.0.0.1",
"4.4.4.4",
"4.8.8.8",
},
},
{
desc: "broken IP-addresses",
trustedIPs: []string{"127.0.0.1/32"},
passIPs: nil,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
ipChecker, err := NewChecker(test.trustedIPs)
require.NoError(t, err)
require.NotNil(t, ipChecker)
for _, testIP := range test.passIPs {
allowed, err := ipChecker.Contains(testIP)
require.NoError(t, err)
assert.Truef(t, allowed, "%s should have passed.", testIP)
}
for _, testIP := range test.rejectIPs {
allowed, err := ipChecker.Contains(testIP)
require.NoError(t, err)
assert.Falsef(t, allowed, "%s should not have passed.", testIP)
}
})
}
}
func TestContainsBrokenIPs(t *testing.T) {
brokenIPs := []string{
"foo",
"10.0.0.350",
"fe:::80",
"",
"\\&$§&/(",
}
ipChecker, err := NewChecker([]string{"1.2.3.4/24"})
require.NoError(t, err)
for _, testIP := range brokenIPs {
_, err := ipChecker.Contains(testIP)
assert.Error(t, err)
}
}

63
pkg/ip/strategy.go Normal file
View file

@ -0,0 +1,63 @@
package ip
import (
"net/http"
"strings"
)
const (
xForwardedFor = "X-Forwarded-For"
)
// Strategy a strategy for IP selection
type Strategy interface {
GetIP(req *http.Request) string
}
// RemoteAddrStrategy a strategy that always return the remote address
type RemoteAddrStrategy struct{}
// GetIP return the selected IP
func (s *RemoteAddrStrategy) GetIP(req *http.Request) string {
return req.RemoteAddr
}
// DepthStrategy a strategy based on the depth inside the X-Forwarded-For from right to left
type DepthStrategy struct {
Depth int
}
// GetIP return the selected IP
func (s *DepthStrategy) GetIP(req *http.Request) string {
xff := req.Header.Get(xForwardedFor)
xffs := strings.Split(xff, ",")
if len(xffs) < s.Depth {
return ""
}
return strings.TrimSpace(xffs[len(xffs)-s.Depth])
}
// CheckerStrategy a strategy based on an IP Checker
// allows to check that addresses are in a trusted IPs
type CheckerStrategy struct {
Checker *Checker
}
// GetIP return the selected IP
func (s *CheckerStrategy) GetIP(req *http.Request) string {
if s.Checker == nil {
return ""
}
xff := req.Header.Get(xForwardedFor)
xffs := strings.Split(xff, ",")
for i := len(xffs) - 1; i >= 0; i-- {
xffTrimmed := strings.TrimSpace(xffs[i])
if contain, _ := s.Checker.Contains(xffTrimmed); !contain {
return xffTrimmed
}
}
return ""
}

125
pkg/ip/strategy_test.go Normal file
View file

@ -0,0 +1,125 @@
package ip
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRemoteAddrStrategy_GetIP(t *testing.T) {
testCases := []struct {
desc string
expected string
}{
{
desc: "Use RemoteAddr",
expected: "192.0.2.1:1234",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
strategy := RemoteAddrStrategy{}
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
actual := strategy.GetIP(req)
assert.Equal(t, test.expected, actual)
})
}
}
func TestDepthStrategy_GetIP(t *testing.T) {
testCases := []struct {
desc string
depth int
xForwardedFor string
expected string
}{
{
desc: "Use depth",
depth: 3,
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
expected: "10.0.0.3",
},
{
desc: "Use non existing depth in XForwardedFor",
depth: 2,
xForwardedFor: "",
expected: "",
},
{
desc: "Use depth that match the first IP in XForwardedFor",
depth: 2,
xForwardedFor: "10.0.0.2,10.0.0.1",
expected: "10.0.0.2",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
strategy := DepthStrategy{Depth: test.depth}
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
req.Header.Set(xForwardedFor, test.xForwardedFor)
actual := strategy.GetIP(req)
assert.Equal(t, test.expected, actual)
})
}
}
func TestExcludedIPsStrategy_GetIP(t *testing.T) {
testCases := []struct {
desc string
excludedIPs []string
xForwardedFor string
expected string
}{
{
desc: "Use excluded all IPs",
excludedIPs: []string{"10.0.0.4", "10.0.0.3", "10.0.0.2", "10.0.0.1"},
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
expected: "",
},
{
desc: "Use excluded IPs",
excludedIPs: []string{"10.0.0.2", "10.0.0.1"},
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
expected: "10.0.0.3",
},
{
desc: "Use excluded IPs CIDR",
excludedIPs: []string{"10.0.0.1/24"},
xForwardedFor: "127.0.0.1,10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
expected: "127.0.0.1",
},
{
desc: "Use excluded all IPs CIDR",
excludedIPs: []string{"10.0.0.1/24"},
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
expected: "",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
checker, err := NewChecker(test.excludedIPs)
require.NoError(t, err)
strategy := CheckerStrategy{Checker: checker}
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
req.Header.Set(xForwardedFor, test.xForwardedFor)
actual := strategy.GetIP(req)
assert.Equal(t, test.expected, actual)
})
}
}

40
pkg/job/job.go Normal file
View file

@ -0,0 +1,40 @@
package job
import (
"time"
"github.com/cenkalti/backoff"
)
var (
_ backoff.BackOff = (*BackOff)(nil)
)
const (
defaultMinJobInterval = 30 * time.Second
)
// BackOff is an exponential backoff implementation for long running jobs.
// In long running jobs, an operation() that fails after a long Duration should not increments the backoff period.
// If operation() takes more than MinJobInterval, Reset() is called in NextBackOff().
type BackOff struct {
*backoff.ExponentialBackOff
MinJobInterval time.Duration
}
// NewBackOff creates an instance of BackOff using default values.
func NewBackOff(backOff *backoff.ExponentialBackOff) *BackOff {
backOff.MaxElapsedTime = 0
return &BackOff{
ExponentialBackOff: backOff,
MinJobInterval: defaultMinJobInterval,
}
}
// NextBackOff calculates the next backoff interval.
func (b *BackOff) NextBackOff() time.Duration {
if b.GetElapsedTime() >= b.MinJobInterval {
b.Reset()
}
return b.ExponentialBackOff.NextBackOff()
}

45
pkg/job/job_test.go Normal file
View file

@ -0,0 +1,45 @@
package job
import (
"testing"
"time"
"github.com/cenkalti/backoff"
)
func TestJobBackOff(t *testing.T) {
var (
testInitialInterval = 500 * time.Millisecond
testRandomizationFactor = 0.1
testMultiplier = 2.0
testMaxInterval = 5 * time.Second
testMinJobInterval = 1 * time.Second
)
exp := NewBackOff(backoff.NewExponentialBackOff())
exp.InitialInterval = testInitialInterval
exp.RandomizationFactor = testRandomizationFactor
exp.Multiplier = testMultiplier
exp.MaxInterval = testMaxInterval
exp.MinJobInterval = testMinJobInterval
exp.Reset()
var expectedResults = []time.Duration{500, 500, 500, 1000, 2000, 4000, 5000, 5000, 500, 1000, 2000, 4000, 5000, 5000}
for i, d := range expectedResults {
expectedResults[i] = d * time.Millisecond
}
for i, expected := range expectedResults {
// Assert that the next backoff falls in the expected range.
var minInterval = expected - time.Duration(testRandomizationFactor*float64(expected))
var maxInterval = expected + time.Duration(testRandomizationFactor*float64(expected))
if i < 3 || i == 8 {
time.Sleep(2 * time.Second)
}
var actualInterval = exp.NextBackOff()
if !(minInterval <= actualInterval && actualInterval <= maxInterval) {
t.Error("error")
}
// assertEquals(t, expected, exp.currentInterval)
}
}

139
pkg/log/deprecated.go Normal file
View file

@ -0,0 +1,139 @@
package log
import (
"bufio"
"io"
"runtime"
"github.com/sirupsen/logrus"
)
// Debug logs a message at level Debug on the standard logger.
// Deprecated
func Debug(args ...interface{}) {
mainLogger.Debug(args...)
}
// Debugf logs a message at level Debug on the standard logger.
// Deprecated
func Debugf(format string, args ...interface{}) {
mainLogger.Debugf(format, args...)
}
// Info logs a message at level Info on the standard logger.
// Deprecated
func Info(args ...interface{}) {
mainLogger.Info(args...)
}
// Infof logs a message at level Info on the standard logger.
// Deprecated
func Infof(format string, args ...interface{}) {
mainLogger.Infof(format, args...)
}
// Warn logs a message at level Warn on the standard logger.
// Deprecated
func Warn(args ...interface{}) {
mainLogger.Warn(args...)
}
// Warnf logs a message at level Warn on the standard logger.
// Deprecated
func Warnf(format string, args ...interface{}) {
mainLogger.Warnf(format, args...)
}
// Error logs a message at level Error on the standard logger.
// Deprecated
func Error(args ...interface{}) {
mainLogger.Error(args...)
}
// Errorf logs a message at level Error on the standard logger.
// Deprecated
func Errorf(format string, args ...interface{}) {
mainLogger.Errorf(format, args...)
}
// Panic logs a message at level Panic on the standard logger.
// Deprecated
func Panic(args ...interface{}) {
mainLogger.Panic(args...)
}
// Panicf logs a message at level Panic on the standard logger.
// Deprecated
func Panicf(format string, args ...interface{}) {
mainLogger.Panicf(format, args...)
}
// Fatal logs a message at level Fatal on the standard logger.
// Deprecated
func Fatal(args ...interface{}) {
mainLogger.Fatal(args...)
}
// Fatalf logs a message at level Fatal on the standard logger.
// Deprecated
func Fatalf(format string, args ...interface{}) {
mainLogger.Fatalf(format, args...)
}
// AddHook adds a hook to the standard logger hooks.
func AddHook(hook logrus.Hook) {
logrus.AddHook(hook)
}
// CustomWriterLevel logs writer for a specific level. (with a custom scanner buffer size.)
// adapted from github.com/Sirupsen/logrus/writer.go
func CustomWriterLevel(level logrus.Level, maxScanTokenSize int) *io.PipeWriter {
reader, writer := io.Pipe()
var printFunc func(args ...interface{})
switch level {
case logrus.DebugLevel:
printFunc = Debug
case logrus.InfoLevel:
printFunc = Info
case logrus.WarnLevel:
printFunc = Warn
case logrus.ErrorLevel:
printFunc = Error
case logrus.FatalLevel:
printFunc = Fatal
case logrus.PanicLevel:
printFunc = Panic
default:
printFunc = mainLogger.Print
}
go writerScanner(reader, maxScanTokenSize, printFunc)
runtime.SetFinalizer(writer, writerFinalizer)
return writer
}
// extract from github.com/Sirupsen/logrus/writer.go
// Hack the buffer size
func writerScanner(reader io.ReadCloser, scanTokenSize int, printFunc func(args ...interface{})) {
scanner := bufio.NewScanner(reader)
if scanTokenSize > bufio.MaxScanTokenSize {
buf := make([]byte, bufio.MaxScanTokenSize)
scanner.Buffer(buf, scanTokenSize)
}
for scanner.Scan() {
printFunc(scanner.Text())
}
if err := scanner.Err(); err != nil {
Errorf("Error while reading from Writer: %s", err)
}
reader.Close()
}
func writerFinalizer(writer *io.PipeWriter) {
writer.Close()
}

15
pkg/log/fields.go Normal file
View file

@ -0,0 +1,15 @@
package log
// Log entry name
const (
EntryPointName = "entryPointName"
RouterName = "routerName"
Rule = "rule"
MiddlewareName = "middlewareName"
MiddlewareType = "middlewareType"
ProviderName = "providerName"
ServiceName = "serviceName"
MetricsProviderName = "metricsProviderName"
TracingProviderName = "tracingProviderName"
ServerName = "serverName"
)

145
pkg/log/log.go Normal file
View file

@ -0,0 +1,145 @@
package log
import (
"context"
"fmt"
"io"
"os"
"github.com/sirupsen/logrus"
)
type contextKey int
const (
loggerKey contextKey = iota
)
// Logger the Traefik logger
type Logger interface {
logrus.FieldLogger
WriterLevel(logrus.Level) *io.PipeWriter
}
var (
mainLogger Logger
logFilePath string
logFile *os.File
)
func init() {
mainLogger = logrus.StandardLogger()
logrus.SetOutput(os.Stdout)
}
// SetLogger sets the logger.
func SetLogger(l Logger) {
mainLogger = l
}
// SetOutput sets the standard logger output.
func SetOutput(out io.Writer) {
logrus.SetOutput(out)
}
// SetFormatter sets the standard logger formatter.
func SetFormatter(formatter logrus.Formatter) {
logrus.SetFormatter(formatter)
}
// SetLevel sets the standard logger level.
func SetLevel(level logrus.Level) {
logrus.SetLevel(level)
}
// GetLevel returns the standard logger level.
func GetLevel() logrus.Level {
return logrus.GetLevel()
}
// Str adds a string field
func Str(key, value string) func(logrus.Fields) {
return func(fields logrus.Fields) {
fields[key] = value
}
}
// With Adds fields
func With(ctx context.Context, opts ...func(logrus.Fields)) context.Context {
logger := FromContext(ctx)
fields := make(logrus.Fields)
for _, opt := range opts {
opt(fields)
}
logger = logger.WithFields(fields)
return context.WithValue(ctx, loggerKey, logger)
}
// FromContext Gets the logger from context
func FromContext(ctx context.Context) Logger {
if ctx == nil {
panic("nil context")
}
logger, ok := ctx.Value(loggerKey).(Logger)
if !ok {
logger = mainLogger
}
return logger
}
// WithoutContext Gets the main logger
func WithoutContext() Logger {
return mainLogger
}
// OpenFile opens the log file using the specified path
func OpenFile(path string) error {
logFilePath = path
var err error
logFile, err = os.OpenFile(logFilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err != nil {
return err
}
SetOutput(logFile)
return nil
}
// CloseFile closes the log and sets the Output to stdout
func CloseFile() error {
logrus.SetOutput(os.Stdout)
if logFile != nil {
return logFile.Close()
}
return nil
}
// RotateFile closes and reopens the log file to allow for rotation
// by an external source. If the log isn't backed by a file then
// it does nothing.
func RotateFile() error {
logger := FromContext(context.Background())
if logFile == nil && logFilePath == "" {
logger.Debug("Traefik log is not writing to a file, ignoring rotate request")
return nil
}
if logFile != nil {
defer func(f *os.File) {
_ = f.Close()
}(logFile)
}
if err := OpenFile(logFilePath); err != nil {
return fmt.Errorf("error opening log file: %s", err)
}
return nil
}

58
pkg/log/log_test.go Normal file
View file

@ -0,0 +1,58 @@
package log
import (
"bytes"
"context"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestLog(t *testing.T) {
testCases := []struct {
desc string
fields map[string]string
expected string
}{
{
desc: "Log with one field",
fields: map[string]string{
"foo": "bar",
},
expected: ` level=error msg="message test" foo=bar$`,
},
{
desc: "Log with two fields",
fields: map[string]string{
"foo": "bar",
"oof": "rab",
},
expected: ` level=error msg="message test" foo=bar oof=rab$`,
},
{
desc: "Log without field",
fields: map[string]string{},
expected: ` level=error msg="message test"$`,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
var buffer bytes.Buffer
SetOutput(&buffer)
ctx := context.Background()
for key, value := range test.fields {
ctx = With(ctx, Str(key, value))
}
FromContext(ctx).Error("message test")
assert.Regexp(t, test.expected, strings.TrimSpace(buffer.String()))
})
}
}

88
pkg/metrics/datadog.go Normal file
View file

@ -0,0 +1,88 @@
package metrics
import (
"context"
"time"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/safe"
"github.com/containous/traefik/pkg/types"
kitlog "github.com/go-kit/kit/log"
"github.com/go-kit/kit/metrics/dogstatsd"
)
var datadogClient = dogstatsd.New("traefik.", kitlog.LoggerFunc(func(keyvals ...interface{}) error {
log.WithoutContext().WithField(log.MetricsProviderName, "datadog").Info(keyvals)
return nil
}))
var datadogTicker *time.Ticker
// Metric names consistent with https://github.com/DataDog/integrations-extras/pull/64
const (
ddMetricsBackendReqsName = "backend.request.total"
ddMetricsBackendLatencyName = "backend.request.duration"
ddRetriesTotalName = "backend.retries.total"
ddConfigReloadsName = "config.reload.total"
ddConfigReloadsFailureTagName = "failure"
ddLastConfigReloadSuccessName = "config.reload.lastSuccessTimestamp"
ddLastConfigReloadFailureName = "config.reload.lastFailureTimestamp"
ddEntrypointReqsName = "entrypoint.request.total"
ddEntrypointReqDurationName = "entrypoint.request.duration"
ddEntrypointOpenConnsName = "entrypoint.connections.open"
ddOpenConnsName = "backend.connections.open"
ddServerUpName = "backend.server.up"
)
// RegisterDatadog registers the metrics pusher if this didn't happen yet and creates a datadog Registry instance.
func RegisterDatadog(ctx context.Context, config *types.Datadog) Registry {
if datadogTicker == nil {
datadogTicker = initDatadogClient(ctx, config)
}
registry := &standardRegistry{
enabled: true,
configReloadsCounter: datadogClient.NewCounter(ddConfigReloadsName, 1.0),
configReloadsFailureCounter: datadogClient.NewCounter(ddConfigReloadsName, 1.0).With(ddConfigReloadsFailureTagName, "true"),
lastConfigReloadSuccessGauge: datadogClient.NewGauge(ddLastConfigReloadSuccessName),
lastConfigReloadFailureGauge: datadogClient.NewGauge(ddLastConfigReloadFailureName),
entrypointReqsCounter: datadogClient.NewCounter(ddEntrypointReqsName, 1.0),
entrypointReqDurationHistogram: datadogClient.NewHistogram(ddEntrypointReqDurationName, 1.0),
entrypointOpenConnsGauge: datadogClient.NewGauge(ddEntrypointOpenConnsName),
backendReqsCounter: datadogClient.NewCounter(ddMetricsBackendReqsName, 1.0),
backendReqDurationHistogram: datadogClient.NewHistogram(ddMetricsBackendLatencyName, 1.0),
backendRetriesCounter: datadogClient.NewCounter(ddRetriesTotalName, 1.0),
backendOpenConnsGauge: datadogClient.NewGauge(ddOpenConnsName),
backendServerUpGauge: datadogClient.NewGauge(ddServerUpName),
}
return registry
}
func initDatadogClient(ctx context.Context, config *types.Datadog) *time.Ticker {
address := config.Address
if len(address) == 0 {
address = "localhost:8125"
}
pushInterval, err := time.ParseDuration(config.PushInterval)
if err != nil {
log.FromContext(ctx).Warnf("Unable to parse %s from config.PushInterval: using 10s as the default value", config.PushInterval)
pushInterval = 10 * time.Second
}
report := time.NewTicker(pushInterval)
safe.Go(func() {
datadogClient.SendLoop(report.C, "udp", address)
})
return report
}
// StopDatadog stops internal datadogTicker which controls the pushing of metrics to DD Agent and resets it to `nil`.
func StopDatadog() {
if datadogTicker != nil {
datadogTicker.Stop()
}
datadogTicker = nil
}

View file

@ -0,0 +1,53 @@
package metrics
import (
"context"
"net/http"
"strconv"
"testing"
"time"
"github.com/containous/traefik/pkg/types"
"github.com/stvp/go-udp-testing"
)
func TestDatadog(t *testing.T) {
udp.SetAddr(":18125")
// This is needed to make sure that UDP Listener listens for data a bit longer, otherwise it will quit after a millisecond
udp.Timeout = 5 * time.Second
datadogRegistry := RegisterDatadog(context.Background(), &types.Datadog{Address: ":18125", PushInterval: "1s"})
defer StopDatadog()
if !datadogRegistry.IsEnabled() {
t.Errorf("DatadogRegistry should return true for IsEnabled()")
}
expected := []string{
// We are only validating counts, as it is nearly impossible to validate latency, since it varies every run
"traefik.backend.request.total:1.000000|c|#service:test,code:404,method:GET\n",
"traefik.backend.request.total:1.000000|c|#service:test,code:200,method:GET\n",
"traefik.backend.retries.total:2.000000|c|#service:test\n",
"traefik.backend.request.duration:10000.000000|h|#service:test,code:200\n",
"traefik.config.reload.total:1.000000|c\n",
"traefik.config.reload.total:1.000000|c|#failure:true\n",
"traefik.entrypoint.request.total:1.000000|c|#entrypoint:test\n",
"traefik.entrypoint.request.duration:10000.000000|h|#entrypoint:test\n",
"traefik.entrypoint.connections.open:1.000000|g|#entrypoint:test\n",
"traefik.backend.server.up:1.000000|g|#backend:test,url:http://127.0.0.1,one:two\n",
}
udp.ShouldReceiveAll(t, expected, func() {
datadogRegistry.BackendReqsCounter().With("service", "test", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet).Add(1)
datadogRegistry.BackendReqsCounter().With("service", "test", "code", strconv.Itoa(http.StatusNotFound), "method", http.MethodGet).Add(1)
datadogRegistry.BackendReqDurationHistogram().With("service", "test", "code", strconv.Itoa(http.StatusOK)).Observe(10000)
datadogRegistry.BackendRetriesCounter().With("service", "test").Add(1)
datadogRegistry.BackendRetriesCounter().With("service", "test").Add(1)
datadogRegistry.ConfigReloadsCounter().Add(1)
datadogRegistry.ConfigReloadsFailureCounter().Add(1)
datadogRegistry.EntrypointReqsCounter().With("entrypoint", "test").Add(1)
datadogRegistry.EntrypointReqDurationHistogram().With("entrypoint", "test").Observe(10000)
datadogRegistry.EntrypointOpenConnsGauge().With("entrypoint", "test").Set(1)
datadogRegistry.BackendServerUpGauge().With("backend", "test", "url", "http://127.0.0.1", "one", "two").Set(1)
})
}

213
pkg/metrics/influxdb.go Normal file
View file

@ -0,0 +1,213 @@
package metrics
import (
"bytes"
"context"
"fmt"
"net/url"
"regexp"
"time"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/safe"
"github.com/containous/traefik/pkg/types"
kitlog "github.com/go-kit/kit/log"
"github.com/go-kit/kit/metrics/influx"
influxdb "github.com/influxdata/influxdb/client/v2"
)
var influxDBClient *influx.Influx
type influxDBWriter struct {
buf bytes.Buffer
config *types.InfluxDB
}
var influxDBTicker *time.Ticker
const (
influxDBMetricsBackendReqsName = "traefik.backend.requests.total"
influxDBMetricsBackendLatencyName = "traefik.backend.request.duration"
influxDBRetriesTotalName = "traefik.backend.retries.total"
influxDBConfigReloadsName = "traefik.config.reload.total"
influxDBConfigReloadsFailureName = influxDBConfigReloadsName + ".failure"
influxDBLastConfigReloadSuccessName = "traefik.config.reload.lastSuccessTimestamp"
influxDBLastConfigReloadFailureName = "traefik.config.reload.lastFailureTimestamp"
influxDBEntrypointReqsName = "traefik.entrypoint.requests.total"
influxDBEntrypointReqDurationName = "traefik.entrypoint.request.duration"
influxDBEntrypointOpenConnsName = "traefik.entrypoint.connections.open"
influxDBOpenConnsName = "traefik.backend.connections.open"
influxDBServerUpName = "traefik.backend.server.up"
)
const (
protocolHTTP = "http"
protocolUDP = "udp"
)
// RegisterInfluxDB registers the metrics pusher if this didn't happen yet and creates a InfluxDB Registry instance.
func RegisterInfluxDB(ctx context.Context, config *types.InfluxDB) Registry {
if influxDBClient == nil {
influxDBClient = initInfluxDBClient(ctx, config)
}
if influxDBTicker == nil {
influxDBTicker = initInfluxDBTicker(ctx, config)
}
return &standardRegistry{
enabled: true,
configReloadsCounter: influxDBClient.NewCounter(influxDBConfigReloadsName),
configReloadsFailureCounter: influxDBClient.NewCounter(influxDBConfigReloadsFailureName),
lastConfigReloadSuccessGauge: influxDBClient.NewGauge(influxDBLastConfigReloadSuccessName),
lastConfigReloadFailureGauge: influxDBClient.NewGauge(influxDBLastConfigReloadFailureName),
entrypointReqsCounter: influxDBClient.NewCounter(influxDBEntrypointReqsName),
entrypointReqDurationHistogram: influxDBClient.NewHistogram(influxDBEntrypointReqDurationName),
entrypointOpenConnsGauge: influxDBClient.NewGauge(influxDBEntrypointOpenConnsName),
backendReqsCounter: influxDBClient.NewCounter(influxDBMetricsBackendReqsName),
backendReqDurationHistogram: influxDBClient.NewHistogram(influxDBMetricsBackendLatencyName),
backendRetriesCounter: influxDBClient.NewCounter(influxDBRetriesTotalName),
backendOpenConnsGauge: influxDBClient.NewGauge(influxDBOpenConnsName),
backendServerUpGauge: influxDBClient.NewGauge(influxDBServerUpName),
}
}
// initInfluxDBTicker creates a influxDBClient
func initInfluxDBClient(ctx context.Context, config *types.InfluxDB) *influx.Influx {
logger := log.FromContext(ctx)
// TODO deprecated: move this switch into configuration.SetEffectiveConfiguration when web provider will be removed.
switch config.Protocol {
case protocolUDP:
if len(config.Database) > 0 || len(config.RetentionPolicy) > 0 {
logger.Warn("Database and RetentionPolicy options have no effect with UDP.")
config.Database = ""
config.RetentionPolicy = ""
}
case protocolHTTP:
if u, err := url.Parse(config.Address); err == nil {
if u.Scheme != "http" && u.Scheme != "https" {
logger.Warnf("InfluxDB address %s should specify a scheme (http or https): falling back on HTTP.", config.Address)
config.Address = "http://" + config.Address
}
} else {
logger.Errorf("Unable to parse the InfluxDB address %v: falling back on UDP.", err)
config.Protocol = protocolUDP
config.Database = ""
config.RetentionPolicy = ""
}
default:
logger.Warnf("Unsupported protocol %s: falling back on UDP.", config.Protocol)
config.Protocol = protocolUDP
config.Database = ""
config.RetentionPolicy = ""
}
return influx.New(
map[string]string{},
influxdb.BatchPointsConfig{
Database: config.Database,
RetentionPolicy: config.RetentionPolicy,
},
kitlog.LoggerFunc(func(keyvals ...interface{}) error {
log.WithoutContext().WithField(log.MetricsProviderName, "influxdb").Info(keyvals)
return nil
}))
}
// initInfluxDBTicker initializes metrics pusher
func initInfluxDBTicker(ctx context.Context, config *types.InfluxDB) *time.Ticker {
pushInterval, err := time.ParseDuration(config.PushInterval)
if err != nil {
log.FromContext(ctx).Warnf("Unable to parse %s from config.PushInterval: using 10s as the default value", config.PushInterval)
pushInterval = 10 * time.Second
}
report := time.NewTicker(pushInterval)
safe.Go(func() {
var buf bytes.Buffer
influxDBClient.WriteLoop(report.C, &influxDBWriter{buf: buf, config: config})
})
return report
}
// StopInfluxDB stops internal influxDBTicker which controls the pushing of metrics to InfluxDB Agent and resets it to `nil`
func StopInfluxDB() {
if influxDBTicker != nil {
influxDBTicker.Stop()
}
influxDBTicker = nil
}
// Write creates a http or udp client and attempts to write BatchPoints.
// If a "database not found" error is encountered, a CREATE DATABASE
// query is attempted when using protocol http.
func (w *influxDBWriter) Write(bp influxdb.BatchPoints) error {
c, err := w.initWriteClient()
if err != nil {
return err
}
defer c.Close()
if writeErr := c.Write(bp); writeErr != nil {
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "influxdb"))
log.FromContext(ctx).Errorf("Error while writing to InfluxDB: %s", writeErr.Error())
if handleErr := w.handleWriteError(ctx, c, writeErr); handleErr != nil {
return handleErr
}
// Retry write after successful handling of writeErr
return c.Write(bp)
}
return nil
}
func (w *influxDBWriter) initWriteClient() (influxdb.Client, error) {
if w.config.Protocol == "http" {
return influxdb.NewHTTPClient(influxdb.HTTPConfig{
Addr: w.config.Address,
Username: w.config.Username,
Password: w.config.Password,
})
}
return influxdb.NewUDPClient(influxdb.UDPConfig{
Addr: w.config.Address,
})
}
func (w *influxDBWriter) handleWriteError(ctx context.Context, c influxdb.Client, writeErr error) error {
if w.config.Protocol != protocolHTTP {
return writeErr
}
match, matchErr := regexp.MatchString("database not found", writeErr.Error())
if matchErr != nil || !match {
return writeErr
}
qStr := fmt.Sprintf("CREATE DATABASE \"%s\"", w.config.Database)
if w.config.RetentionPolicy != "" {
qStr = fmt.Sprintf("%s WITH NAME \"%s\"", qStr, w.config.RetentionPolicy)
}
logger := log.FromContext(ctx)
logger.Debugf("InfluxDB database not found: attempting to create one with %s", qStr)
q := influxdb.NewQuery(qStr, "", "")
response, queryErr := c.Query(q)
if queryErr == nil && response.Error() != nil {
queryErr = response.Error()
}
if queryErr != nil {
logger.Errorf("Error while creating the InfluxDB database %s", queryErr)
return queryErr
}
logger.Debugf("Successfully created the InfluxDB database %s", w.config.Database)
return nil
}

View file

@ -0,0 +1,135 @@
package metrics
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"regexp"
"strconv"
"testing"
"time"
"github.com/containous/traefik/pkg/types"
"github.com/stvp/go-udp-testing"
)
func TestInfluxDB(t *testing.T) {
udp.SetAddr(":8089")
// This is needed to make sure that UDP Listener listens for data a bit longer, otherwise it will quit after a millisecond
udp.Timeout = 5 * time.Second
influxDBRegistry := RegisterInfluxDB(context.Background(), &types.InfluxDB{Address: ":8089", PushInterval: "1s"})
defer StopInfluxDB()
if !influxDBRegistry.IsEnabled() {
t.Fatalf("InfluxDB registry must be enabled")
}
expectedBackend := []string{
`(traefik\.backend\.requests\.total,backend=test,code=200,method=GET count=1) [\d]{19}`,
`(traefik\.backend\.requests\.total,backend=test,code=404,method=GET count=1) [\d]{19}`,
`(traefik\.backend\.request\.duration,backend=test,code=200 p50=10000,p90=10000,p95=10000,p99=10000) [\d]{19}`,
`(traefik\.backend\.retries\.total(?:,code=[\d]{3},method=GET)?,backend=test count=2) [\d]{19}`,
`(traefik\.config\.reload\.total(?:[a-z=0-9A-Z,]+)? count=1) [\d]{19}`,
`(traefik\.config\.reload\.total\.failure(?:[a-z=0-9A-Z,]+)? count=1) [\d]{19}`,
`(traefik\.backend\.server\.up,backend=test(?:[a-z=0-9A-Z,]+)?,url=http://127.0.0.1 value=1) [\d]{19}`,
}
msgBackend := udp.ReceiveString(t, func() {
influxDBRegistry.BackendReqsCounter().With("backend", "test", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet).Add(1)
influxDBRegistry.BackendReqsCounter().With("backend", "test", "code", strconv.Itoa(http.StatusNotFound), "method", http.MethodGet).Add(1)
influxDBRegistry.BackendRetriesCounter().With("backend", "test").Add(1)
influxDBRegistry.BackendRetriesCounter().With("backend", "test").Add(1)
influxDBRegistry.BackendReqDurationHistogram().With("backend", "test", "code", strconv.Itoa(http.StatusOK)).Observe(10000)
influxDBRegistry.ConfigReloadsCounter().Add(1)
influxDBRegistry.ConfigReloadsFailureCounter().Add(1)
influxDBRegistry.BackendServerUpGauge().With("backend", "test", "url", "http://127.0.0.1").Set(1)
})
assertMessage(t, msgBackend, expectedBackend)
expectedEntrypoint := []string{
`(traefik\.entrypoint\.requests\.total,entrypoint=test(?:[a-z=0-9A-Z,:/.]+)? count=1) [\d]{19}`,
`(traefik\.entrypoint\.request\.duration(?:,code=[\d]{3})?,entrypoint=test(?:[a-z=0-9A-Z,:/.]+)? p50=10000,p90=10000,p95=10000,p99=10000) [\d]{19}`,
`(traefik\.entrypoint\.connections\.open,entrypoint=test value=1) [\d]{19}`,
}
msgEntrypoint := udp.ReceiveString(t, func() {
influxDBRegistry.EntrypointReqsCounter().With("entrypoint", "test").Add(1)
influxDBRegistry.EntrypointReqDurationHistogram().With("entrypoint", "test").Observe(10000)
influxDBRegistry.EntrypointOpenConnsGauge().With("entrypoint", "test").Set(1)
})
assertMessage(t, msgEntrypoint, expectedEntrypoint)
}
func TestInfluxDBHTTP(t *testing.T) {
c := make(chan *string)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, "can't read body "+err.Error(), http.StatusBadRequest)
return
}
bodyStr := string(body)
c <- &bodyStr
fmt.Fprintln(w, "ok")
}))
defer ts.Close()
influxDBRegistry := RegisterInfluxDB(context.Background(), &types.InfluxDB{Address: ts.URL, Protocol: "http", PushInterval: "1s", Database: "test", RetentionPolicy: "autogen"})
defer StopInfluxDB()
if !influxDBRegistry.IsEnabled() {
t.Fatalf("InfluxDB registry must be enabled")
}
expectedBackend := []string{
`(traefik\.backend\.requests\.total,backend=test,code=200,method=GET count=1) [\d]{19}`,
`(traefik\.backend\.requests\.total,backend=test,code=404,method=GET count=1) [\d]{19}`,
`(traefik\.backend\.request\.duration,backend=test,code=200 p50=10000,p90=10000,p95=10000,p99=10000) [\d]{19}`,
`(traefik\.backend\.retries\.total(?:,code=[\d]{3},method=GET)?,backend=test count=2) [\d]{19}`,
`(traefik\.config\.reload\.total(?:[a-z=0-9A-Z,]+)? count=1) [\d]{19}`,
`(traefik\.config\.reload\.total\.failure(?:[a-z=0-9A-Z,]+)? count=1) [\d]{19}`,
`(traefik\.backend\.server\.up,backend=test(?:[a-z=0-9A-Z,]+)?,url=http://127.0.0.1 value=1) [\d]{19}`,
}
influxDBRegistry.BackendReqsCounter().With("backend", "test", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet).Add(1)
influxDBRegistry.BackendReqsCounter().With("backend", "test", "code", strconv.Itoa(http.StatusNotFound), "method", http.MethodGet).Add(1)
influxDBRegistry.BackendRetriesCounter().With("backend", "test").Add(1)
influxDBRegistry.BackendRetriesCounter().With("backend", "test").Add(1)
influxDBRegistry.BackendReqDurationHistogram().With("backend", "test", "code", strconv.Itoa(http.StatusOK)).Observe(10000)
influxDBRegistry.ConfigReloadsCounter().Add(1)
influxDBRegistry.ConfigReloadsFailureCounter().Add(1)
influxDBRegistry.BackendServerUpGauge().With("backend", "test", "url", "http://127.0.0.1").Set(1)
msgBackend := <-c
assertMessage(t, *msgBackend, expectedBackend)
expectedEntrypoint := []string{
`(traefik\.entrypoint\.requests\.total,entrypoint=test(?:[a-z=0-9A-Z,:/.]+)? count=1) [\d]{19}`,
`(traefik\.entrypoint\.request\.duration(?:,code=[\d]{3})?,entrypoint=test(?:[a-z=0-9A-Z,:/.]+)? p50=10000,p90=10000,p95=10000,p99=10000) [\d]{19}`,
`(traefik\.entrypoint\.connections\.open,entrypoint=test value=1) [\d]{19}`,
}
influxDBRegistry.EntrypointReqsCounter().With("entrypoint", "test").Add(1)
influxDBRegistry.EntrypointReqDurationHistogram().With("entrypoint", "test").Observe(10000)
influxDBRegistry.EntrypointOpenConnsGauge().With("entrypoint", "test").Set(1)
msgEntrypoint := <-c
assertMessage(t, *msgEntrypoint, expectedEntrypoint)
}
func assertMessage(t *testing.T, msg string, patterns []string) {
t.Helper()
for _, pattern := range patterns {
re := regexp.MustCompile(pattern)
match := re.FindStringSubmatch(msg)
if len(match) != 2 {
t.Errorf("Got %q %v, want %q", msg, match, pattern)
}
}
}

177
pkg/metrics/metrics.go Normal file
View file

@ -0,0 +1,177 @@
package metrics
import (
"github.com/go-kit/kit/metrics"
"github.com/go-kit/kit/metrics/multi"
)
// Registry has to implemented by any system that wants to monitor and expose metrics.
type Registry interface {
// IsEnabled shows whether metrics instrumentation is enabled.
IsEnabled() bool
// server metrics
ConfigReloadsCounter() metrics.Counter
ConfigReloadsFailureCounter() metrics.Counter
LastConfigReloadSuccessGauge() metrics.Gauge
LastConfigReloadFailureGauge() metrics.Gauge
// entry point metrics
EntrypointReqsCounter() metrics.Counter
EntrypointReqDurationHistogram() metrics.Histogram
EntrypointOpenConnsGauge() metrics.Gauge
// backend metrics
BackendReqsCounter() metrics.Counter
BackendReqDurationHistogram() metrics.Histogram
BackendOpenConnsGauge() metrics.Gauge
BackendRetriesCounter() metrics.Counter
BackendServerUpGauge() metrics.Gauge
}
// NewVoidRegistry is a noop implementation of metrics.Registry.
// It is used to avoid nil checking in components that do metric collections.
func NewVoidRegistry() Registry {
return NewMultiRegistry([]Registry{})
}
// NewMultiRegistry is an implementation of metrics.Registry that wraps multiple registries.
// It handles the case when a registry hasn't registered some metric and returns nil.
// This allows for feature imparity between the different metric implementations.
func NewMultiRegistry(registries []Registry) Registry {
var configReloadsCounter []metrics.Counter
var configReloadsFailureCounter []metrics.Counter
var lastConfigReloadSuccessGauge []metrics.Gauge
var lastConfigReloadFailureGauge []metrics.Gauge
var entrypointReqsCounter []metrics.Counter
var entrypointReqDurationHistogram []metrics.Histogram
var entrypointOpenConnsGauge []metrics.Gauge
var backendReqsCounter []metrics.Counter
var backendReqDurationHistogram []metrics.Histogram
var backendOpenConnsGauge []metrics.Gauge
var backendRetriesCounter []metrics.Counter
var backendServerUpGauge []metrics.Gauge
for _, r := range registries {
if r.ConfigReloadsCounter() != nil {
configReloadsCounter = append(configReloadsCounter, r.ConfigReloadsCounter())
}
if r.ConfigReloadsFailureCounter() != nil {
configReloadsFailureCounter = append(configReloadsFailureCounter, r.ConfigReloadsFailureCounter())
}
if r.LastConfigReloadSuccessGauge() != nil {
lastConfigReloadSuccessGauge = append(lastConfigReloadSuccessGauge, r.LastConfigReloadSuccessGauge())
}
if r.LastConfigReloadFailureGauge() != nil {
lastConfigReloadFailureGauge = append(lastConfigReloadFailureGauge, r.LastConfigReloadFailureGauge())
}
if r.EntrypointReqsCounter() != nil {
entrypointReqsCounter = append(entrypointReqsCounter, r.EntrypointReqsCounter())
}
if r.EntrypointReqDurationHistogram() != nil {
entrypointReqDurationHistogram = append(entrypointReqDurationHistogram, r.EntrypointReqDurationHistogram())
}
if r.EntrypointOpenConnsGauge() != nil {
entrypointOpenConnsGauge = append(entrypointOpenConnsGauge, r.EntrypointOpenConnsGauge())
}
if r.BackendReqsCounter() != nil {
backendReqsCounter = append(backendReqsCounter, r.BackendReqsCounter())
}
if r.BackendReqDurationHistogram() != nil {
backendReqDurationHistogram = append(backendReqDurationHistogram, r.BackendReqDurationHistogram())
}
if r.BackendOpenConnsGauge() != nil {
backendOpenConnsGauge = append(backendOpenConnsGauge, r.BackendOpenConnsGauge())
}
if r.BackendRetriesCounter() != nil {
backendRetriesCounter = append(backendRetriesCounter, r.BackendRetriesCounter())
}
if r.BackendServerUpGauge() != nil {
backendServerUpGauge = append(backendServerUpGauge, r.BackendServerUpGauge())
}
}
return &standardRegistry{
enabled: len(registries) > 0,
configReloadsCounter: multi.NewCounter(configReloadsCounter...),
configReloadsFailureCounter: multi.NewCounter(configReloadsFailureCounter...),
lastConfigReloadSuccessGauge: multi.NewGauge(lastConfigReloadSuccessGauge...),
lastConfigReloadFailureGauge: multi.NewGauge(lastConfigReloadFailureGauge...),
entrypointReqsCounter: multi.NewCounter(entrypointReqsCounter...),
entrypointReqDurationHistogram: multi.NewHistogram(entrypointReqDurationHistogram...),
entrypointOpenConnsGauge: multi.NewGauge(entrypointOpenConnsGauge...),
backendReqsCounter: multi.NewCounter(backendReqsCounter...),
backendReqDurationHistogram: multi.NewHistogram(backendReqDurationHistogram...),
backendOpenConnsGauge: multi.NewGauge(backendOpenConnsGauge...),
backendRetriesCounter: multi.NewCounter(backendRetriesCounter...),
backendServerUpGauge: multi.NewGauge(backendServerUpGauge...),
}
}
type standardRegistry struct {
enabled bool
configReloadsCounter metrics.Counter
configReloadsFailureCounter metrics.Counter
lastConfigReloadSuccessGauge metrics.Gauge
lastConfigReloadFailureGauge metrics.Gauge
entrypointReqsCounter metrics.Counter
entrypointReqDurationHistogram metrics.Histogram
entrypointOpenConnsGauge metrics.Gauge
backendReqsCounter metrics.Counter
backendReqDurationHistogram metrics.Histogram
backendOpenConnsGauge metrics.Gauge
backendRetriesCounter metrics.Counter
backendServerUpGauge metrics.Gauge
}
func (r *standardRegistry) IsEnabled() bool {
return r.enabled
}
func (r *standardRegistry) ConfigReloadsCounter() metrics.Counter {
return r.configReloadsCounter
}
func (r *standardRegistry) ConfigReloadsFailureCounter() metrics.Counter {
return r.configReloadsFailureCounter
}
func (r *standardRegistry) LastConfigReloadSuccessGauge() metrics.Gauge {
return r.lastConfigReloadSuccessGauge
}
func (r *standardRegistry) LastConfigReloadFailureGauge() metrics.Gauge {
return r.lastConfigReloadFailureGauge
}
func (r *standardRegistry) EntrypointReqsCounter() metrics.Counter {
return r.entrypointReqsCounter
}
func (r *standardRegistry) EntrypointReqDurationHistogram() metrics.Histogram {
return r.entrypointReqDurationHistogram
}
func (r *standardRegistry) EntrypointOpenConnsGauge() metrics.Gauge {
return r.entrypointOpenConnsGauge
}
func (r *standardRegistry) BackendReqsCounter() metrics.Counter {
return r.backendReqsCounter
}
func (r *standardRegistry) BackendReqDurationHistogram() metrics.Histogram {
return r.backendReqDurationHistogram
}
func (r *standardRegistry) BackendOpenConnsGauge() metrics.Gauge {
return r.backendOpenConnsGauge
}
func (r *standardRegistry) BackendRetriesCounter() metrics.Counter {
return r.backendRetriesCounter
}
func (r *standardRegistry) BackendServerUpGauge() metrics.Gauge {
return r.backendServerUpGauge
}

View file

@ -0,0 +1,76 @@
package metrics
import (
"testing"
"github.com/go-kit/kit/metrics"
"github.com/stretchr/testify/assert"
)
func TestNewMultiRegistry(t *testing.T) {
registries := []Registry{newCollectingRetryMetrics(), newCollectingRetryMetrics()}
registry := NewMultiRegistry(registries)
registry.BackendReqsCounter().With("key", "requests").Add(1)
registry.BackendReqDurationHistogram().With("key", "durations").Observe(2)
registry.BackendRetriesCounter().With("key", "retries").Add(3)
for _, collectingRegistry := range registries {
cReqsCounter := collectingRegistry.BackendReqsCounter().(*counterMock)
cReqDurationHistogram := collectingRegistry.BackendReqDurationHistogram().(*histogramMock)
cRetriesCounter := collectingRegistry.BackendRetriesCounter().(*counterMock)
wantCounterValue := float64(1)
if cReqsCounter.counterValue != wantCounterValue {
t.Errorf("Got value %f for ReqsCounter, want %f", cReqsCounter.counterValue, wantCounterValue)
}
wantHistogramValue := float64(2)
if cReqDurationHistogram.lastHistogramValue != wantHistogramValue {
t.Errorf("Got last observation %f for ReqDurationHistogram, want %f", cReqDurationHistogram.lastHistogramValue, wantHistogramValue)
}
wantCounterValue = float64(3)
if cRetriesCounter.counterValue != wantCounterValue {
t.Errorf("Got value %f for RetriesCounter, want %f", cRetriesCounter.counterValue, wantCounterValue)
}
assert.Equal(t, []string{"key", "requests"}, cReqsCounter.lastLabelValues)
assert.Equal(t, []string{"key", "durations"}, cReqDurationHistogram.lastLabelValues)
assert.Equal(t, []string{"key", "retries"}, cRetriesCounter.lastLabelValues)
}
}
func newCollectingRetryMetrics() Registry {
return &standardRegistry{
backendReqsCounter: &counterMock{},
backendReqDurationHistogram: &histogramMock{},
backendRetriesCounter: &counterMock{},
}
}
type counterMock struct {
counterValue float64
lastLabelValues []string
}
func (c *counterMock) With(labelValues ...string) metrics.Counter {
c.lastLabelValues = labelValues
return c
}
func (c *counterMock) Add(delta float64) {
c.counterValue += delta
}
type histogramMock struct {
lastHistogramValue float64
lastLabelValues []string
}
func (c *histogramMock) With(labelValues ...string) metrics.Histogram {
c.lastLabelValues = labelValues
return c
}
func (c *histogramMock) Observe(value float64) {
c.lastHistogramValue = value
}

512
pkg/metrics/prometheus.go Normal file
View file

@ -0,0 +1,512 @@
package metrics
import (
"context"
"net/http"
"sort"
"strings"
"sync"
"github.com/containous/mux"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/safe"
"github.com/containous/traefik/pkg/types"
"github.com/go-kit/kit/metrics"
stdprometheus "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
const (
// MetricNamePrefix prefix of all metric names
MetricNamePrefix = "traefik_"
// server meta information
metricConfigPrefix = MetricNamePrefix + "config_"
configReloadsTotalName = metricConfigPrefix + "reloads_total"
configReloadsFailuresTotalName = metricConfigPrefix + "reloads_failure_total"
configLastReloadSuccessName = metricConfigPrefix + "last_reload_success"
configLastReloadFailureName = metricConfigPrefix + "last_reload_failure"
// entrypoint
metricEntryPointPrefix = MetricNamePrefix + "entrypoint_"
entrypointReqsTotalName = metricEntryPointPrefix + "requests_total"
entrypointReqDurationName = metricEntryPointPrefix + "request_duration_seconds"
entrypointOpenConnsName = metricEntryPointPrefix + "open_connections"
// backend level.
// MetricBackendPrefix prefix of all backend metric names
MetricBackendPrefix = MetricNamePrefix + "backend_"
backendReqsTotalName = MetricBackendPrefix + "requests_total"
backendReqDurationName = MetricBackendPrefix + "request_duration_seconds"
backendOpenConnsName = MetricBackendPrefix + "open_connections"
backendRetriesTotalName = MetricBackendPrefix + "retries_total"
backendServerUpName = MetricBackendPrefix + "server_up"
)
// promState holds all metric state internally and acts as the only Collector we register for Prometheus.
//
// This enables control to remove metrics that belong to outdated configuration.
// As an example why this is required, consider Traefik learns about a new service.
// It populates the 'traefik_server_backend_up' metric for it with a value of 1 (alive).
// When the backend is undeployed now the metric is still there in the client library
// and will be returned on the metrics endpoint until Traefik would be restarted.
//
// To solve this problem promState keeps track of Traefik's dynamic configuration.
// Metrics that "belong" to a dynamic configuration part like backends or entrypoints
// are removed after they were scraped at least once when the corresponding object
// doesn't exist anymore.
var promState = newPrometheusState()
// PrometheusHandler exposes Prometheus routes.
type PrometheusHandler struct{}
// Append adds Prometheus routes on a router.
func (h PrometheusHandler) Append(router *mux.Router) {
router.Methods(http.MethodGet).Path("/metrics").Handler(promhttp.Handler())
}
// RegisterPrometheus registers all Prometheus metrics.
// It must be called only once and failing to register the metrics will lead to a panic.
func RegisterPrometheus(ctx context.Context, config *types.Prometheus) Registry {
standardRegistry := initStandardRegistry(config)
if !registerPromState(ctx) {
return nil
}
return standardRegistry
}
func initStandardRegistry(config *types.Prometheus) Registry {
buckets := []float64{0.1, 0.3, 1.2, 5.0}
if config.Buckets != nil {
buckets = config.Buckets
}
safe.Go(func() {
promState.ListenValueUpdates()
})
configReloads := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
Name: configReloadsTotalName,
Help: "Config reloads",
}, []string{})
configReloadsFailures := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
Name: configReloadsFailuresTotalName,
Help: "Config failure reloads",
}, []string{})
lastConfigReloadSuccess := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
Name: configLastReloadSuccessName,
Help: "Last config reload success",
}, []string{})
lastConfigReloadFailure := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
Name: configLastReloadFailureName,
Help: "Last config reload failure",
}, []string{})
entrypointReqs := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
Name: entrypointReqsTotalName,
Help: "How many HTTP requests processed on an entrypoint, partitioned by status code, protocol, and method.",
}, []string{"code", "method", "protocol", "entrypoint"})
entrypointReqDurations := newHistogramFrom(promState.collectors, stdprometheus.HistogramOpts{
Name: entrypointReqDurationName,
Help: "How long it took to process the request on an entrypoint, partitioned by status code, protocol, and method.",
Buckets: buckets,
}, []string{"code", "method", "protocol", "entrypoint"})
entrypointOpenConns := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
Name: entrypointOpenConnsName,
Help: "How many open connections exist on an entrypoint, partitioned by method and protocol.",
}, []string{"method", "protocol", "entrypoint"})
backendReqs := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
Name: backendReqsTotalName,
Help: "How many HTTP requests processed on a backend, partitioned by status code, protocol, and method.",
}, []string{"code", "method", "protocol", "backend"})
backendReqDurations := newHistogramFrom(promState.collectors, stdprometheus.HistogramOpts{
Name: backendReqDurationName,
Help: "How long it took to process the request on a backend, partitioned by status code, protocol, and method.",
Buckets: buckets,
}, []string{"code", "method", "protocol", "backend"})
backendOpenConns := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
Name: backendOpenConnsName,
Help: "How many open connections exist on a backend, partitioned by method and protocol.",
}, []string{"method", "protocol", "backend"})
backendRetries := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
Name: backendRetriesTotalName,
Help: "How many request retries happened on a backend.",
}, []string{"backend"})
backendServerUp := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
Name: backendServerUpName,
Help: "Backend server is up, described by gauge value of 0 or 1.",
}, []string{"backend", "url"})
promState.describers = []func(chan<- *stdprometheus.Desc){
configReloads.cv.Describe,
configReloadsFailures.cv.Describe,
lastConfigReloadSuccess.gv.Describe,
lastConfigReloadFailure.gv.Describe,
entrypointReqs.cv.Describe,
entrypointReqDurations.hv.Describe,
entrypointOpenConns.gv.Describe,
backendReqs.cv.Describe,
backendReqDurations.hv.Describe,
backendOpenConns.gv.Describe,
backendRetries.cv.Describe,
backendServerUp.gv.Describe,
}
return &standardRegistry{
enabled: true,
configReloadsCounter: configReloads,
configReloadsFailureCounter: configReloadsFailures,
lastConfigReloadSuccessGauge: lastConfigReloadSuccess,
lastConfigReloadFailureGauge: lastConfigReloadFailure,
entrypointReqsCounter: entrypointReqs,
entrypointReqDurationHistogram: entrypointReqDurations,
entrypointOpenConnsGauge: entrypointOpenConns,
backendReqsCounter: backendReqs,
backendReqDurationHistogram: backendReqDurations,
backendOpenConnsGauge: backendOpenConns,
backendRetriesCounter: backendRetries,
backendServerUpGauge: backendServerUp,
}
}
func registerPromState(ctx context.Context) bool {
if err := stdprometheus.Register(promState); err != nil {
logger := log.FromContext(ctx)
if _, ok := err.(stdprometheus.AlreadyRegisteredError); !ok {
logger.Errorf("Unable to register Traefik to Prometheus: %v", err)
return false
}
logger.Debug("Prometheus collector already registered.")
}
return true
}
// OnConfigurationUpdate receives the current configuration from Traefik.
// It then converts the configuration to the optimized package internal format
// and sets it to the promState.
func OnConfigurationUpdate(configurations config.Configurations) {
dynamicConfig := newDynamicConfig()
// FIXME metrics
// for _, config := range configurations {
// for _, frontend := range config.Frontends {
// for _, entrypointName := range frontend.EntryPoints {
// dynamicConfig.entrypoints[entrypointName] = true
// }
// }
//
// for backendName, backend := range config.Backends {
// dynamicConfig.backends[backendName] = make(map[string]bool)
// for _, server := range backend.Servers {
// dynamicConfig.backends[backendName][server.URL] = true
// }
// }
// }
promState.SetDynamicConfig(dynamicConfig)
}
func newPrometheusState() *prometheusState {
return &prometheusState{
collectors: make(chan *collector),
dynamicConfig: newDynamicConfig(),
state: make(map[string]*collector),
}
}
type prometheusState struct {
collectors chan *collector
describers []func(ch chan<- *stdprometheus.Desc)
mtx sync.Mutex
dynamicConfig *dynamicConfig
state map[string]*collector
}
func (ps *prometheusState) SetDynamicConfig(dynamicConfig *dynamicConfig) {
ps.mtx.Lock()
defer ps.mtx.Unlock()
ps.dynamicConfig = dynamicConfig
}
func (ps *prometheusState) ListenValueUpdates() {
for collector := range ps.collectors {
ps.mtx.Lock()
ps.state[collector.id] = collector
ps.mtx.Unlock()
}
}
// Describe implements prometheus.Collector and simply calls
// the registered describer functions.
func (ps *prometheusState) Describe(ch chan<- *stdprometheus.Desc) {
for _, desc := range ps.describers {
desc(ch)
}
}
// Collect implements prometheus.Collector. It calls the Collect
// method of all metrics it received on the collectors channel.
// It's also responsible to remove metrics that belong to an outdated configuration.
// The removal happens only after their Collect method was called to ensure that
// also those metrics will be exported on the current scrape.
func (ps *prometheusState) Collect(ch chan<- stdprometheus.Metric) {
ps.mtx.Lock()
defer ps.mtx.Unlock()
var outdatedKeys []string
for key, cs := range ps.state {
cs.collector.Collect(ch)
if ps.isOutdated(cs) {
outdatedKeys = append(outdatedKeys, key)
}
}
for _, key := range outdatedKeys {
ps.state[key].delete()
delete(ps.state, key)
}
}
// isOutdated checks whether the passed collector has labels that mark
// it as belonging to an outdated configuration of Traefik.
func (ps *prometheusState) isOutdated(collector *collector) bool {
labels := collector.labels
if entrypointName, ok := labels["entrypoint"]; ok && !ps.dynamicConfig.hasEntrypoint(entrypointName) {
return true
}
if backendName, ok := labels["backend"]; ok {
if !ps.dynamicConfig.hasBackend(backendName) {
return true
}
if url, ok := labels["url"]; ok && !ps.dynamicConfig.hasServerURL(backendName, url) {
return true
}
}
return false
}
func newDynamicConfig() *dynamicConfig {
return &dynamicConfig{
entrypoints: make(map[string]bool),
backends: make(map[string]map[string]bool),
}
}
// dynamicConfig holds the current configuration for entrypoints, backends,
// and server URLs in an optimized way to check for existence. This provides
// a performant way to check whether the collected metrics belong to the
// current configuration or to an outdated one.
type dynamicConfig struct {
entrypoints map[string]bool
backends map[string]map[string]bool
}
func (d *dynamicConfig) hasEntrypoint(entrypointName string) bool {
_, ok := d.entrypoints[entrypointName]
return ok
}
func (d *dynamicConfig) hasBackend(backendName string) bool {
_, ok := d.backends[backendName]
return ok
}
func (d *dynamicConfig) hasServerURL(backendName, serverURL string) bool {
if backend, hasBackend := d.backends[backendName]; hasBackend {
_, ok := backend[serverURL]
return ok
}
return false
}
func newCollector(metricName string, labels stdprometheus.Labels, c stdprometheus.Collector, delete func()) *collector {
return &collector{
id: buildMetricID(metricName, labels),
labels: labels,
collector: c,
delete: delete,
}
}
// collector wraps a Collector object from the Prometheus client library.
// It adds information on how many generations this metric should be present
// in the /metrics output, relatived to the time it was last tracked.
type collector struct {
id string
labels stdprometheus.Labels
collector stdprometheus.Collector
delete func()
}
func buildMetricID(metricName string, labels stdprometheus.Labels) string {
var labelNamesValues []string
for name, value := range labels {
labelNamesValues = append(labelNamesValues, name, value)
}
sort.Strings(labelNamesValues)
return metricName + ":" + strings.Join(labelNamesValues, "|")
}
func newCounterFrom(collectors chan<- *collector, opts stdprometheus.CounterOpts, labelNames []string) *counter {
cv := stdprometheus.NewCounterVec(opts, labelNames)
c := &counter{
name: opts.Name,
cv: cv,
collectors: collectors,
}
if len(labelNames) == 0 {
c.Add(0)
}
return c
}
type counter struct {
name string
cv *stdprometheus.CounterVec
labelNamesValues labelNamesValues
collectors chan<- *collector
}
func (c *counter) With(labelValues ...string) metrics.Counter {
return &counter{
name: c.name,
cv: c.cv,
labelNamesValues: c.labelNamesValues.With(labelValues...),
collectors: c.collectors,
}
}
func (c *counter) Add(delta float64) {
labels := c.labelNamesValues.ToLabels()
collector := c.cv.With(labels)
collector.Add(delta)
c.collectors <- newCollector(c.name, labels, collector, func() {
c.cv.Delete(labels)
})
}
func (c *counter) Describe(ch chan<- *stdprometheus.Desc) {
c.cv.Describe(ch)
}
func newGaugeFrom(collectors chan<- *collector, opts stdprometheus.GaugeOpts, labelNames []string) *gauge {
gv := stdprometheus.NewGaugeVec(opts, labelNames)
g := &gauge{
name: opts.Name,
gv: gv,
collectors: collectors,
}
if len(labelNames) == 0 {
g.Set(0)
}
return g
}
type gauge struct {
name string
gv *stdprometheus.GaugeVec
labelNamesValues labelNamesValues
collectors chan<- *collector
}
func (g *gauge) With(labelValues ...string) metrics.Gauge {
return &gauge{
name: g.name,
gv: g.gv,
labelNamesValues: g.labelNamesValues.With(labelValues...),
collectors: g.collectors,
}
}
func (g *gauge) Add(delta float64) {
labels := g.labelNamesValues.ToLabels()
collector := g.gv.With(labels)
collector.Add(delta)
g.collectors <- newCollector(g.name, labels, collector, func() {
g.gv.Delete(labels)
})
}
func (g *gauge) Set(value float64) {
labels := g.labelNamesValues.ToLabels()
collector := g.gv.With(labels)
collector.Set(value)
g.collectors <- newCollector(g.name, labels, collector, func() {
g.gv.Delete(labels)
})
}
func (g *gauge) Describe(ch chan<- *stdprometheus.Desc) {
g.gv.Describe(ch)
}
func newHistogramFrom(collectors chan<- *collector, opts stdprometheus.HistogramOpts, labelNames []string) *histogram {
hv := stdprometheus.NewHistogramVec(opts, labelNames)
return &histogram{
name: opts.Name,
hv: hv,
collectors: collectors,
}
}
type histogram struct {
name string
hv *stdprometheus.HistogramVec
labelNamesValues labelNamesValues
collectors chan<- *collector
}
func (h *histogram) With(labelValues ...string) metrics.Histogram {
return &histogram{
name: h.name,
hv: h.hv,
labelNamesValues: h.labelNamesValues.With(labelValues...),
collectors: h.collectors,
}
}
func (h *histogram) Observe(value float64) {
labels := h.labelNamesValues.ToLabels()
collector := h.hv.With(labels)
collector.Observe(value)
h.collectors <- newCollector(h.name, labels, collector, func() {
h.hv.Delete(labels)
})
}
func (h *histogram) Describe(ch chan<- *stdprometheus.Desc) {
h.hv.Describe(ch)
}
// labelNamesValues is a type alias that provides validation on its With method.
// Metrics may include it as a member to help them satisfy With semantics and
// save some code duplication.
type labelNamesValues []string
// With validates the input, and returns a new aggregate labelNamesValues.
func (lvs labelNamesValues) With(labelValues ...string) labelNamesValues {
if len(labelValues)%2 != 0 {
labelValues = append(labelValues, "unknown")
}
return append(lvs, labelValues...)
}
// ToLabels is a convenience method to convert a labelNamesValues
// to the native prometheus.Labels.
func (lvs labelNamesValues) ToLabels() stdprometheus.Labels {
labels := stdprometheus.Labels{}
for i := 0; i < len(lvs); i += 2 {
labels[lvs[i]] = lvs[i+1]
}
return labels
}

View file

@ -0,0 +1,504 @@
package metrics
import (
"context"
"fmt"
"net/http"
"strconv"
"testing"
"time"
"github.com/containous/traefik/pkg/config"
th "github.com/containous/traefik/pkg/testhelpers"
"github.com/containous/traefik/pkg/types"
"github.com/prometheus/client_golang/prometheus"
dto "github.com/prometheus/client_model/go"
"github.com/stretchr/testify/assert"
)
func TestRegisterPromState(t *testing.T) {
// Reset state of global promState.
defer promState.reset()
testCases := []struct {
desc string
prometheusSlice []*types.Prometheus
initPromState bool
unregisterPromState bool
expectedNbRegistries int
}{
{
desc: "Register once",
prometheusSlice: []*types.Prometheus{{}},
expectedNbRegistries: 1,
initPromState: true,
},
{
desc: "Register once with no promState init",
prometheusSlice: []*types.Prometheus{{}},
expectedNbRegistries: 0,
},
{
desc: "Register twice",
prometheusSlice: []*types.Prometheus{{}, {}},
expectedNbRegistries: 2,
initPromState: true,
},
{
desc: "Register twice with no promstate init",
prometheusSlice: []*types.Prometheus{{}, {}},
expectedNbRegistries: 0,
},
{
desc: "Register twice with unregister",
prometheusSlice: []*types.Prometheus{{}, {}},
unregisterPromState: true,
expectedNbRegistries: 2,
initPromState: true,
},
{
desc: "Register twice with unregister but no promstate init",
prometheusSlice: []*types.Prometheus{{}, {}},
unregisterPromState: true,
expectedNbRegistries: 0,
},
}
for _, test := range testCases {
actualNbRegistries := 0
for _, prom := range test.prometheusSlice {
if test.initPromState {
initStandardRegistry(prom)
}
if registerPromState(context.Background()) {
actualNbRegistries++
}
if test.unregisterPromState {
prometheus.Unregister(promState)
}
promState.reset()
}
prometheus.Unregister(promState)
assert.Equal(t, test.expectedNbRegistries, actualNbRegistries)
}
}
// reset is a utility method for unit testing. It should be called after each
// test run that changes promState internally in order to avoid dependencies
// between unit tests.
func (ps *prometheusState) reset() {
ps.collectors = make(chan *collector)
ps.describers = []func(ch chan<- *prometheus.Desc){}
ps.dynamicConfig = newDynamicConfig()
ps.state = make(map[string]*collector)
}
func TestPrometheus(t *testing.T) {
// Reset state of global promState.
defer promState.reset()
prometheusRegistry := RegisterPrometheus(context.Background(), &types.Prometheus{})
defer prometheus.Unregister(promState)
if !prometheusRegistry.IsEnabled() {
t.Errorf("PrometheusRegistry should return true for IsEnabled()")
}
prometheusRegistry.ConfigReloadsCounter().Add(1)
prometheusRegistry.ConfigReloadsFailureCounter().Add(1)
prometheusRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
prometheusRegistry.LastConfigReloadFailureGauge().Set(float64(time.Now().Unix()))
prometheusRegistry.
EntrypointReqsCounter().
With("code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http", "entrypoint", "http").
Add(1)
prometheusRegistry.
EntrypointReqDurationHistogram().
With("code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http", "entrypoint", "http").
Observe(1)
prometheusRegistry.
EntrypointOpenConnsGauge().
With("method", http.MethodGet, "protocol", "http", "entrypoint", "http").
Set(1)
prometheusRegistry.
BackendReqsCounter().
With("backend", "backend1", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
Add(1)
prometheusRegistry.
BackendReqDurationHistogram().
With("backend", "backend1", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
Observe(10000)
prometheusRegistry.
BackendOpenConnsGauge().
With("backend", "backend1", "method", http.MethodGet, "protocol", "http").
Set(1)
prometheusRegistry.
BackendRetriesCounter().
With("backend", "backend1").
Add(1)
prometheusRegistry.
BackendServerUpGauge().
With("backend", "backend1", "url", "http://127.0.0.10:80").
Set(1)
delayForTrackingCompletion()
metricsFamilies := mustScrape()
tests := []struct {
name string
labels map[string]string
assert func(*dto.MetricFamily)
}{
{
name: configReloadsTotalName,
assert: buildCounterAssert(t, configReloadsTotalName, 1),
},
{
name: configReloadsFailuresTotalName,
assert: buildCounterAssert(t, configReloadsFailuresTotalName, 1),
},
{
name: configLastReloadSuccessName,
assert: buildTimestampAssert(t, configLastReloadSuccessName),
},
{
name: configLastReloadFailureName,
assert: buildTimestampAssert(t, configLastReloadFailureName),
},
{
name: entrypointReqsTotalName,
labels: map[string]string{
"code": "200",
"method": http.MethodGet,
"protocol": "http",
"entrypoint": "http",
},
assert: buildCounterAssert(t, entrypointReqsTotalName, 1),
},
{
name: entrypointReqDurationName,
labels: map[string]string{
"code": "200",
"method": http.MethodGet,
"protocol": "http",
"entrypoint": "http",
},
assert: buildHistogramAssert(t, entrypointReqDurationName, 1),
},
{
name: entrypointOpenConnsName,
labels: map[string]string{
"method": http.MethodGet,
"protocol": "http",
"entrypoint": "http",
},
assert: buildGaugeAssert(t, entrypointOpenConnsName, 1),
},
{
name: backendReqsTotalName,
labels: map[string]string{
"code": "200",
"method": http.MethodGet,
"protocol": "http",
"backend": "backend1",
},
assert: buildCounterAssert(t, backendReqsTotalName, 1),
},
{
name: backendReqDurationName,
labels: map[string]string{
"code": "200",
"method": http.MethodGet,
"protocol": "http",
"backend": "backend1",
},
assert: buildHistogramAssert(t, backendReqDurationName, 1),
},
{
name: backendOpenConnsName,
labels: map[string]string{
"method": http.MethodGet,
"protocol": "http",
"backend": "backend1",
},
assert: buildGaugeAssert(t, backendOpenConnsName, 1),
},
{
name: backendRetriesTotalName,
labels: map[string]string{
"backend": "backend1",
},
assert: buildGreaterThanCounterAssert(t, backendRetriesTotalName, 1),
},
{
name: backendServerUpName,
labels: map[string]string{
"backend": "backend1",
"url": "http://127.0.0.10:80",
},
assert: buildGaugeAssert(t, backendServerUpName, 1),
},
}
for _, test := range tests {
family := findMetricFamily(test.name, metricsFamilies)
if family == nil {
t.Errorf("gathered metrics do not contain %q", test.name)
continue
}
for _, label := range family.Metric[0].Label {
val, ok := test.labels[*label.Name]
if !ok {
t.Errorf("%q metric contains unexpected label %q", test.name, *label.Name)
} else if val != *label.Value {
t.Errorf("label %q in metric %q has wrong value %q, expected %q", *label.Name, test.name, *label.Value, val)
}
}
test.assert(family)
}
}
func TestPrometheusMetricRemoval(t *testing.T) {
// FIXME metrics
t.Skip("waiting for metrics")
// Reset state of global promState.
defer promState.reset()
prometheusRegistry := RegisterPrometheus(context.Background(), &types.Prometheus{})
defer prometheus.Unregister(promState)
configurations := make(config.Configurations)
configurations["providerName"] = &config.Configuration{
HTTP: th.BuildConfiguration(
th.WithRouters(
th.WithRouter("foo",
th.WithServiceName("bar")),
),
th.WithLoadBalancerServices(th.WithService("bar",
th.WithLBMethod("wrr"),
th.WithServers(th.WithServer("http://localhost:9000"))),
),
),
}
OnConfigurationUpdate(configurations)
// Register some metrics manually that are not part of the active configuration.
// Those metrics should be part of the /metrics output on the first scrape but
// should be removed after that scrape.
prometheusRegistry.
EntrypointReqsCounter().
With("entrypoint", "entrypoint2", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
Add(1)
prometheusRegistry.
BackendReqsCounter().
With("backend", "backend2", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
Add(1)
prometheusRegistry.
BackendServerUpGauge().
With("backend", "backend1", "url", "http://localhost:9999").
Set(1)
delayForTrackingCompletion()
assertMetricsExist(t, mustScrape(), entrypointReqsTotalName, backendReqsTotalName, backendServerUpName)
assertMetricsAbsent(t, mustScrape(), entrypointReqsTotalName, backendReqsTotalName, backendServerUpName)
// To verify that metrics belonging to active configurations are not removed
// here the counter examples.
prometheusRegistry.
EntrypointReqsCounter().
With("entrypoint", "entrypoint1", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
Add(1)
delayForTrackingCompletion()
assertMetricsExist(t, mustScrape(), entrypointReqsTotalName)
assertMetricsExist(t, mustScrape(), entrypointReqsTotalName)
}
func TestPrometheusRemovedMetricsReset(t *testing.T) {
// Reset state of global promState.
defer promState.reset()
prometheusRegistry := RegisterPrometheus(context.Background(), &types.Prometheus{})
defer prometheus.Unregister(promState)
labelNamesValues := []string{
"backend", "backend",
"code", strconv.Itoa(http.StatusOK),
"method", http.MethodGet,
"protocol", "http",
}
prometheusRegistry.
BackendReqsCounter().
With(labelNamesValues...).
Add(3)
delayForTrackingCompletion()
metricsFamilies := mustScrape()
assertCounterValue(t, 3, findMetricFamily(backendReqsTotalName, metricsFamilies), labelNamesValues...)
// There is no dynamic configuration and so this metric will be deleted
// after the first scrape.
assertMetricsAbsent(t, mustScrape(), backendReqsTotalName)
prometheusRegistry.
BackendReqsCounter().
With(labelNamesValues...).
Add(1)
delayForTrackingCompletion()
metricsFamilies = mustScrape()
assertCounterValue(t, 1, findMetricFamily(backendReqsTotalName, metricsFamilies), labelNamesValues...)
}
// Tracking and gathering the metrics happens concurrently.
// In practice this is no problem, because in case a tracked metric would miss
// the current scrape, it would just be there in the next one.
// That we can test reliably the tracking of all metrics here, we sleep
// for a short amount of time, to make sure the metric will be present
// in the next scrape.
func delayForTrackingCompletion() {
time.Sleep(250 * time.Millisecond)
}
func mustScrape() []*dto.MetricFamily {
families, err := prometheus.DefaultGatherer.Gather()
if err != nil {
panic(fmt.Sprintf("could not gather metrics families: %s", err))
}
return families
}
func assertMetricsExist(t *testing.T, families []*dto.MetricFamily, metricNames ...string) {
t.Helper()
for _, metricName := range metricNames {
if findMetricFamily(metricName, families) == nil {
t.Errorf("gathered metrics should contain %q", metricName)
}
}
}
func assertMetricsAbsent(t *testing.T, families []*dto.MetricFamily, metricNames ...string) {
t.Helper()
for _, metricName := range metricNames {
if findMetricFamily(metricName, families) != nil {
t.Errorf("gathered metrics should not contain %q", metricName)
}
}
}
func findMetricFamily(name string, families []*dto.MetricFamily) *dto.MetricFamily {
for _, family := range families {
if family.GetName() == name {
return family
}
}
return nil
}
func findMetricByLabelNamesValues(family *dto.MetricFamily, labelNamesValues ...string) *dto.Metric {
if family == nil {
return nil
}
for _, metric := range family.Metric {
if hasMetricAllLabelPairs(metric, labelNamesValues...) {
return metric
}
}
return nil
}
func hasMetricAllLabelPairs(metric *dto.Metric, labelNamesValues ...string) bool {
for i := 0; i < len(labelNamesValues); i += 2 {
name, val := labelNamesValues[i], labelNamesValues[i+1]
if !hasMetricLabelPair(metric, name, val) {
return false
}
}
return true
}
func hasMetricLabelPair(metric *dto.Metric, labelName, labelValue string) bool {
for _, labelPair := range metric.Label {
if labelPair.GetName() == labelName && labelPair.GetValue() == labelValue {
return true
}
}
return false
}
func assertCounterValue(t *testing.T, want float64, family *dto.MetricFamily, labelNamesValues ...string) {
t.Helper()
metric := findMetricByLabelNamesValues(family, labelNamesValues...)
if metric == nil {
t.Error("metric must not be nil")
return
}
if metric.Counter == nil {
t.Errorf("metric %s must be a counter", family.GetName())
return
}
if cv := metric.Counter.GetValue(); cv != want {
t.Errorf("metric %s has value %v, want %v", family.GetName(), cv, want)
}
}
func buildCounterAssert(t *testing.T, metricName string, expectedValue int) func(family *dto.MetricFamily) {
return func(family *dto.MetricFamily) {
if cv := int(family.Metric[0].Counter.GetValue()); cv != expectedValue {
t.Errorf("metric %s has value %d, want %d", metricName, cv, expectedValue)
}
}
}
func buildGreaterThanCounterAssert(t *testing.T, metricName string, expectedMinValue int) func(family *dto.MetricFamily) {
return func(family *dto.MetricFamily) {
if cv := int(family.Metric[0].Counter.GetValue()); cv < expectedMinValue {
t.Errorf("metric %s has value %d, want at least %d", metricName, cv, expectedMinValue)
}
}
}
func buildHistogramAssert(t *testing.T, metricName string, expectedSampleCount int) func(family *dto.MetricFamily) {
return func(family *dto.MetricFamily) {
if sc := int(family.Metric[0].Histogram.GetSampleCount()); sc != expectedSampleCount {
t.Errorf("metric %s has sample count value %d, want %d", metricName, sc, expectedSampleCount)
}
}
}
func buildGaugeAssert(t *testing.T, metricName string, expectedValue int) func(family *dto.MetricFamily) {
return func(family *dto.MetricFamily) {
if gv := int(family.Metric[0].Gauge.GetValue()); gv != expectedValue {
t.Errorf("metric %s has value %d, want %d", metricName, gv, expectedValue)
}
}
}
func buildTimestampAssert(t *testing.T, metricName string) func(family *dto.MetricFamily) {
return func(family *dto.MetricFamily) {
if ts := time.Unix(int64(family.Metric[0].Gauge.GetValue()), 0); time.Since(ts) > time.Minute {
t.Errorf("metric %s has wrong timestamp %v", metricName, ts)
}
}
}

86
pkg/metrics/statsd.go Normal file
View file

@ -0,0 +1,86 @@
package metrics
import (
"context"
"time"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/safe"
"github.com/containous/traefik/pkg/types"
kitlog "github.com/go-kit/kit/log"
"github.com/go-kit/kit/metrics/statsd"
)
var statsdClient = statsd.New("traefik.", kitlog.LoggerFunc(func(keyvals ...interface{}) error {
log.WithoutContext().WithField(log.MetricsProviderName, "statsd").Info(keyvals)
return nil
}))
var statsdTicker *time.Ticker
const (
statsdMetricsBackendReqsName = "backend.request.total"
statsdMetricsBackendLatencyName = "backend.request.duration"
statsdRetriesTotalName = "backend.retries.total"
statsdConfigReloadsName = "config.reload.total"
statsdConfigReloadsFailureName = statsdConfigReloadsName + ".failure"
statsdLastConfigReloadSuccessName = "config.reload.lastSuccessTimestamp"
statsdLastConfigReloadFailureName = "config.reload.lastFailureTimestamp"
statsdEntrypointReqsName = "entrypoint.request.total"
statsdEntrypointReqDurationName = "entrypoint.request.duration"
statsdEntrypointOpenConnsName = "entrypoint.connections.open"
statsdOpenConnsName = "backend.connections.open"
statsdServerUpName = "backend.server.up"
)
// RegisterStatsd registers the metrics pusher if this didn't happen yet and creates a statsd Registry instance.
func RegisterStatsd(ctx context.Context, config *types.Statsd) Registry {
if statsdTicker == nil {
statsdTicker = initStatsdTicker(ctx, config)
}
return &standardRegistry{
enabled: true,
configReloadsCounter: statsdClient.NewCounter(statsdConfigReloadsName, 1.0),
configReloadsFailureCounter: statsdClient.NewCounter(statsdConfigReloadsFailureName, 1.0),
lastConfigReloadSuccessGauge: statsdClient.NewGauge(statsdLastConfigReloadSuccessName),
lastConfigReloadFailureGauge: statsdClient.NewGauge(statsdLastConfigReloadFailureName),
entrypointReqsCounter: statsdClient.NewCounter(statsdEntrypointReqsName, 1.0),
entrypointReqDurationHistogram: statsdClient.NewTiming(statsdEntrypointReqDurationName, 1.0),
entrypointOpenConnsGauge: statsdClient.NewGauge(statsdEntrypointOpenConnsName),
backendReqsCounter: statsdClient.NewCounter(statsdMetricsBackendReqsName, 1.0),
backendReqDurationHistogram: statsdClient.NewTiming(statsdMetricsBackendLatencyName, 1.0),
backendRetriesCounter: statsdClient.NewCounter(statsdRetriesTotalName, 1.0),
backendOpenConnsGauge: statsdClient.NewGauge(statsdOpenConnsName),
backendServerUpGauge: statsdClient.NewGauge(statsdServerUpName),
}
}
// initStatsdTicker initializes metrics pusher and creates a statsdClient if not created already
func initStatsdTicker(ctx context.Context, config *types.Statsd) *time.Ticker {
address := config.Address
if len(address) == 0 {
address = "localhost:8125"
}
pushInterval, err := time.ParseDuration(config.PushInterval)
if err != nil {
log.FromContext(ctx).Warnf("Unable to parse %s from config.PushInterval: using 10s as the default value", config.PushInterval)
pushInterval = 10 * time.Second
}
report := time.NewTicker(pushInterval)
safe.Go(func() {
statsdClient.SendLoop(report.C, "udp", address)
})
return report
}
// StopStatsd stops internal statsdTicker which controls the pushing of metrics to StatsD Agent and resets it to `nil`
func StopStatsd() {
if statsdTicker != nil {
statsdTicker.Stop()
}
statsdTicker = nil
}

View file

@ -0,0 +1,51 @@
package metrics
import (
"context"
"net/http"
"testing"
"time"
"github.com/containous/traefik/pkg/types"
"github.com/stvp/go-udp-testing"
)
func TestStatsD(t *testing.T) {
udp.SetAddr(":18125")
// This is needed to make sure that UDP Listener listens for data a bit longer, otherwise it will quit after a millisecond
udp.Timeout = 5 * time.Second
statsdRegistry := RegisterStatsd(context.Background(), &types.Statsd{Address: ":18125", PushInterval: "1s"})
defer StopStatsd()
if !statsdRegistry.IsEnabled() {
t.Errorf("Statsd registry should return true for IsEnabled()")
}
expected := []string{
// We are only validating counts, as it is nearly impossible to validate latency, since it varies every run
"traefik.backend.request.total:2.000000|c\n",
"traefik.backend.retries.total:2.000000|c\n",
"traefik.backend.request.duration:10000.000000|ms",
"traefik.config.reload.total:1.000000|c\n",
"traefik.config.reload.total:1.000000|c\n",
"traefik.entrypoint.request.total:1.000000|c\n",
"traefik.entrypoint.request.duration:10000.000000|ms",
"traefik.entrypoint.connections.open:1.000000|g\n",
"traefik.backend.server.up:1.000000|g\n",
}
udp.ShouldReceiveAll(t, expected, func() {
statsdRegistry.BackendReqsCounter().With("service", "test", "code", string(http.StatusOK), "method", http.MethodGet).Add(1)
statsdRegistry.BackendReqsCounter().With("service", "test", "code", string(http.StatusNotFound), "method", http.MethodGet).Add(1)
statsdRegistry.BackendRetriesCounter().With("service", "test").Add(1)
statsdRegistry.BackendRetriesCounter().With("service", "test").Add(1)
statsdRegistry.BackendReqDurationHistogram().With("service", "test", "code", string(http.StatusOK)).Observe(10000)
statsdRegistry.ConfigReloadsCounter().Add(1)
statsdRegistry.ConfigReloadsFailureCounter().Add(1)
statsdRegistry.EntrypointReqsCounter().With("entrypoint", "test").Add(1)
statsdRegistry.EntrypointReqDurationHistogram().With("entrypoint", "test").Observe(10000)
statsdRegistry.EntrypointOpenConnsGauge().With("entrypoint", "test").Set(1)
statsdRegistry.BackendServerUpGauge().With("backend:test", "url", "http://127.0.0.1").Set(1)
})
}

View file

@ -0,0 +1,18 @@
package accesslog
import "io"
type captureRequestReader struct {
source io.ReadCloser
count int64
}
func (r *captureRequestReader) Read(p []byte) (int, error) {
n, err := r.source.Read(p)
r.count += int64(n)
return n, err
}
func (r *captureRequestReader) Close() error {
return r.source.Close()
}

View file

@ -0,0 +1,68 @@
package accesslog
import (
"bufio"
"fmt"
"net"
"net/http"
"github.com/containous/traefik/old/middlewares"
)
var (
_ middlewares.Stateful = &captureResponseWriter{}
)
// captureResponseWriter is a wrapper of type http.ResponseWriter
// that tracks request status and size
type captureResponseWriter struct {
rw http.ResponseWriter
status int
size int64
}
func (crw *captureResponseWriter) Header() http.Header {
return crw.rw.Header()
}
func (crw *captureResponseWriter) Write(b []byte) (int, error) {
if crw.status == 0 {
crw.status = http.StatusOK
}
size, err := crw.rw.Write(b)
crw.size += int64(size)
return size, err
}
func (crw *captureResponseWriter) WriteHeader(s int) {
crw.rw.WriteHeader(s)
crw.status = s
}
func (crw *captureResponseWriter) Flush() {
if f, ok := crw.rw.(http.Flusher); ok {
f.Flush()
}
}
func (crw *captureResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := crw.rw.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("not a hijacker: %T", crw.rw)
}
func (crw *captureResponseWriter) CloseNotify() <-chan bool {
if c, ok := crw.rw.(http.CloseNotifier); ok {
return c.CloseNotify()
}
return nil
}
func (crw *captureResponseWriter) Status() int {
return crw.status
}
func (crw *captureResponseWriter) Size() int64 {
return crw.size
}

View file

@ -0,0 +1,65 @@
package accesslog
import (
"net/http"
"time"
"github.com/vulcand/oxy/utils"
)
// FieldApply function hook to add data in accesslog
type FieldApply func(rw http.ResponseWriter, r *http.Request, next http.Handler, data *LogData)
// FieldHandler sends a new field to the logger.
type FieldHandler struct {
next http.Handler
name string
value string
applyFn FieldApply
}
// NewFieldHandler creates a Field handler.
func NewFieldHandler(next http.Handler, name string, value string, applyFn FieldApply) http.Handler {
return &FieldHandler{next: next, name: name, value: value, applyFn: applyFn}
}
func (f *FieldHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
table := GetLogData(req)
if table == nil {
f.next.ServeHTTP(rw, req)
return
}
table.Core[f.name] = f.value
if f.applyFn != nil {
f.applyFn(rw, req, f.next, table)
} else {
f.next.ServeHTTP(rw, req)
}
}
// AddServiceFields add service fields
func AddServiceFields(rw http.ResponseWriter, req *http.Request, next http.Handler, data *LogData) {
data.Core[ServiceURL] = req.URL // note that this is *not* the original incoming URL
data.Core[ServiceAddr] = req.URL.Host
next.ServeHTTP(rw, req)
}
// AddOriginFields add origin fields
func AddOriginFields(rw http.ResponseWriter, req *http.Request, next http.Handler, data *LogData) {
crw := &captureResponseWriter{rw: rw}
start := time.Now().UTC()
next.ServeHTTP(crw, req)
// use UTC to handle switchover of daylight saving correctly
data.Core[OriginDuration] = time.Now().UTC().Sub(start)
data.Core[OriginStatus] = crw.Status()
// make copy of headers so we can ensure there is no subsequent mutation during response processing
data.OriginResponse = make(http.Header)
utils.CopyHeaders(data.OriginResponse, crw.Header())
data.Core[OriginContentSize] = crw.Size()
}

View file

@ -0,0 +1,122 @@
package accesslog
import (
"net/http"
)
const (
// StartUTC is the map key used for the time at which request processing started.
StartUTC = "StartUTC"
// StartLocal is the map key used for the local time at which request processing started.
StartLocal = "StartLocal"
// Duration is the map key used for the total time taken by processing the response, including the origin server's time but
// not the log writing time.
Duration = "Duration"
// RouterName is the map key used for the name of the Traefik router.
RouterName = "RouterName"
// ServiceName is the map key used for the name of the Traefik backend.
ServiceName = "ServiceName"
// ServiceURL is the map key used for the URL of the Traefik backend.
ServiceURL = "ServiceURL"
// ServiceAddr is the map key used for the IP:port of the Traefik backend (extracted from BackendURL)
ServiceAddr = "ServiceAddr"
// ClientAddr is the map key used for the remote address in its original form (usually IP:port).
ClientAddr = "ClientAddr"
// ClientHost is the map key used for the remote IP address from which the client request was received.
ClientHost = "ClientHost"
// ClientPort is the map key used for the remote TCP port from which the client request was received.
ClientPort = "ClientPort"
// ClientUsername is the map key used for the username provided in the URL, if present.
ClientUsername = "ClientUsername"
// RequestAddr is the map key used for the HTTP Host header (usually IP:port). This is treated as not a header by the Go API.
RequestAddr = "RequestAddr"
// RequestHost is the map key used for the HTTP Host server name (not including port).
RequestHost = "RequestHost"
// RequestPort is the map key used for the TCP port from the HTTP Host.
RequestPort = "RequestPort"
// RequestMethod is the map key used for the HTTP method.
RequestMethod = "RequestMethod"
// RequestPath is the map key used for the HTTP request URI, not including the scheme, host or port.
RequestPath = "RequestPath"
// RequestProtocol is the map key used for the version of HTTP requested.
RequestProtocol = "RequestProtocol"
// RequestContentSize is the map key used for the number of bytes in the request entity (a.k.a. body) sent by the client.
RequestContentSize = "RequestContentSize"
// RequestRefererHeader is the Referer header in the request
RequestRefererHeader = "request_Referer"
// RequestUserAgentHeader is the User-Agent header in the request
RequestUserAgentHeader = "request_User-Agent"
// OriginDuration is the map key used for the time taken by the origin server ('upstream') to return its response.
OriginDuration = "OriginDuration"
// OriginContentSize is the map key used for the content length specified by the origin server, or 0 if unspecified.
OriginContentSize = "OriginContentSize"
// OriginStatus is the map key used for the HTTP status code returned by the origin server.
// If the request was handled by this Traefik instance (e.g. with a redirect), then this value will be absent.
OriginStatus = "OriginStatus"
// DownstreamStatus is the map key used for the HTTP status code returned to the client.
DownstreamStatus = "DownstreamStatus"
// DownstreamContentSize is the map key used for the number of bytes in the response entity returned to the client.
// This is in addition to the "Content-Length" header, which may be present in the origin response.
DownstreamContentSize = "DownstreamContentSize"
// RequestCount is the map key used for the number of requests received since the Traefik instance started.
RequestCount = "RequestCount"
// GzipRatio is the map key used for the response body compression ratio achieved.
GzipRatio = "GzipRatio"
// Overhead is the map key used for the processing time overhead caused by Traefik.
Overhead = "Overhead"
// RetryAttempts is the map key used for the amount of attempts the request was retried.
RetryAttempts = "RetryAttempts"
)
// These are written out in the default case when no config is provided to specify keys of interest.
var defaultCoreKeys = [...]string{
StartUTC,
Duration,
RouterName,
ServiceName,
ServiceURL,
ClientHost,
ClientPort,
ClientUsername,
RequestHost,
RequestPort,
RequestMethod,
RequestPath,
RequestProtocol,
RequestContentSize,
OriginDuration,
OriginContentSize,
OriginStatus,
DownstreamStatus,
DownstreamContentSize,
RequestCount,
}
// This contains the set of all keys, i.e. all the default keys plus all non-default keys.
var allCoreKeys = make(map[string]struct{})
func init() {
for _, k := range defaultCoreKeys {
allCoreKeys[k] = struct{}{}
}
allCoreKeys[ServiceAddr] = struct{}{}
allCoreKeys[ClientAddr] = struct{}{}
allCoreKeys[RequestAddr] = struct{}{}
allCoreKeys[GzipRatio] = struct{}{}
allCoreKeys[StartLocal] = struct{}{}
allCoreKeys[Overhead] = struct{}{}
allCoreKeys[RetryAttempts] = struct{}{}
}
// CoreLogData holds the fields computed from the request/response.
type CoreLogData map[string]interface{}
// LogData is the data captured by the middleware so that it can be logged.
type LogData struct {
Core CoreLogData
Request http.Header
OriginResponse http.Header
DownstreamResponse http.Header
}

View file

@ -0,0 +1,344 @@
package accesslog
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
"github.com/containous/alice"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/types"
"github.com/sirupsen/logrus"
)
type key string
const (
// DataTableKey is the key within the request context used to store the Log Data Table.
DataTableKey key = "LogDataTable"
// CommonFormat is the common logging format (CLF).
CommonFormat string = "common"
// JSONFormat is the JSON logging format.
JSONFormat string = "json"
)
type handlerParams struct {
logDataTable *LogData
crr *captureRequestReader
crw *captureResponseWriter
}
// Handler will write each request and its response to the access log.
type Handler struct {
config *types.AccessLog
logger *logrus.Logger
file *os.File
mu sync.Mutex
httpCodeRanges types.HTTPCodeRanges
logHandlerChan chan handlerParams
wg sync.WaitGroup
}
// WrapHandler Wraps access log handler into an Alice Constructor.
func WrapHandler(handler *Handler) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(rw, req, next.ServeHTTP)
}), nil
}
}
// NewHandler creates a new Handler.
func NewHandler(config *types.AccessLog) (*Handler, error) {
file := os.Stdout
if len(config.FilePath) > 0 {
f, err := openAccessLogFile(config.FilePath)
if err != nil {
return nil, fmt.Errorf("error opening access log file: %s", err)
}
file = f
}
logHandlerChan := make(chan handlerParams, config.BufferingSize)
var formatter logrus.Formatter
switch config.Format {
case CommonFormat:
formatter = new(CommonLogFormatter)
case JSONFormat:
formatter = new(logrus.JSONFormatter)
default:
return nil, fmt.Errorf("unsupported access log format: %s", config.Format)
}
logger := &logrus.Logger{
Out: file,
Formatter: formatter,
Hooks: make(logrus.LevelHooks),
Level: logrus.InfoLevel,
}
logHandler := &Handler{
config: config,
logger: logger,
file: file,
logHandlerChan: logHandlerChan,
}
if config.Filters != nil {
if httpCodeRanges, err := types.NewHTTPCodeRanges(config.Filters.StatusCodes); err != nil {
log.WithoutContext().Errorf("Failed to create new HTTP code ranges: %s", err)
} else {
logHandler.httpCodeRanges = httpCodeRanges
}
}
if config.BufferingSize > 0 {
logHandler.wg.Add(1)
go func() {
defer logHandler.wg.Done()
for handlerParams := range logHandler.logHandlerChan {
logHandler.logTheRoundTrip(handlerParams.logDataTable, handlerParams.crr, handlerParams.crw)
}
}()
}
return logHandler, nil
}
func openAccessLogFile(filePath string) (*os.File, error) {
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create log path %s: %s", dir, err)
}
file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
if err != nil {
return nil, fmt.Errorf("error opening file %s: %s", filePath, err)
}
return file, nil
}
// GetLogData gets the request context object that contains logging data.
// This creates data as the request passes through the middleware chain.
func GetLogData(req *http.Request) *LogData {
if ld, ok := req.Context().Value(DataTableKey).(*LogData); ok {
return ld
}
return nil
}
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
now := time.Now().UTC()
core := CoreLogData{
StartUTC: now,
StartLocal: now.Local(),
}
logDataTable := &LogData{Core: core, Request: req.Header}
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
var crr *captureRequestReader
if req.Body != nil {
crr = &captureRequestReader{source: req.Body, count: 0}
reqWithDataTable.Body = crr
}
core[RequestCount] = nextRequestCount()
if req.Host != "" {
core[RequestAddr] = req.Host
core[RequestHost], core[RequestPort] = silentSplitHostPort(req.Host)
}
// copy the URL without the scheme, hostname etc
urlCopy := &url.URL{
Path: req.URL.Path,
RawPath: req.URL.RawPath,
RawQuery: req.URL.RawQuery,
ForceQuery: req.URL.ForceQuery,
Fragment: req.URL.Fragment,
}
urlCopyString := urlCopy.String()
core[RequestMethod] = req.Method
core[RequestPath] = urlCopyString
core[RequestProtocol] = req.Proto
core[ClientAddr] = req.RemoteAddr
core[ClientHost], core[ClientPort] = silentSplitHostPort(req.RemoteAddr)
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
core[ClientHost] = forwardedFor
}
crw := &captureResponseWriter{rw: rw}
next.ServeHTTP(crw, reqWithDataTable)
if _, ok := core[ClientUsername]; !ok {
core[ClientUsername] = usernameIfPresent(reqWithDataTable.URL)
}
logDataTable.DownstreamResponse = crw.Header()
if h.config.BufferingSize > 0 {
h.logHandlerChan <- handlerParams{
logDataTable: logDataTable,
crr: crr,
crw: crw,
}
} else {
h.logTheRoundTrip(logDataTable, crr, crw)
}
}
// Close closes the Logger (i.e. the file, drain logHandlerChan, etc).
func (h *Handler) Close() error {
close(h.logHandlerChan)
h.wg.Wait()
return h.file.Close()
}
// Rotate closes and reopens the log file to allow for rotation by an external source.
func (h *Handler) Rotate() error {
var err error
if h.file != nil {
defer func(f *os.File) {
f.Close()
}(h.file)
}
h.file, err = os.OpenFile(h.config.FilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
if err != nil {
return err
}
h.mu.Lock()
defer h.mu.Unlock()
h.logger.Out = h.file
return nil
}
func silentSplitHostPort(value string) (host string, port string) {
host, port, err := net.SplitHostPort(value)
if err != nil {
return value, "-"
}
return host, port
}
func usernameIfPresent(theURL *url.URL) string {
if theURL.User != nil {
if name := theURL.User.Username(); name != "" {
return name
}
}
return "-"
}
// Logging handler to log frontend name, backend name, and elapsed time.
func (h *Handler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestReader, crw *captureResponseWriter) {
core := logDataTable.Core
retryAttempts, ok := core[RetryAttempts].(int)
if !ok {
retryAttempts = 0
}
core[RetryAttempts] = retryAttempts
if crr != nil {
core[RequestContentSize] = crr.count
}
core[DownstreamStatus] = crw.Status()
// n.b. take care to perform time arithmetic using UTC to avoid errors at DST boundaries.
totalDuration := time.Now().UTC().Sub(core[StartUTC].(time.Time))
core[Duration] = totalDuration
if h.keepAccessLog(crw.Status(), retryAttempts, totalDuration) {
core[DownstreamContentSize] = crw.Size()
if original, ok := core[OriginContentSize]; ok {
o64 := original.(int64)
if crw.Size() != o64 && crw.Size() != 0 {
core[GzipRatio] = float64(o64) / float64(crw.Size())
}
}
core[Overhead] = totalDuration
if origin, ok := core[OriginDuration]; ok {
core[Overhead] = totalDuration - origin.(time.Duration)
}
fields := logrus.Fields{}
for k, v := range logDataTable.Core {
if h.config.Fields.Keep(k) {
fields[k] = v
}
}
h.redactHeaders(logDataTable.Request, fields, "request_")
h.redactHeaders(logDataTable.OriginResponse, fields, "origin_")
h.redactHeaders(logDataTable.DownstreamResponse, fields, "downstream_")
h.mu.Lock()
defer h.mu.Unlock()
h.logger.WithFields(fields).Println()
}
}
func (h *Handler) redactHeaders(headers http.Header, fields logrus.Fields, prefix string) {
for k := range headers {
v := h.config.Fields.KeepHeader(k)
if v == types.AccessLogKeep {
fields[prefix+k] = headers.Get(k)
} else if v == types.AccessLogRedact {
fields[prefix+k] = "REDACTED"
}
}
}
func (h *Handler) keepAccessLog(statusCode, retryAttempts int, duration time.Duration) bool {
if h.config.Filters == nil {
// no filters were specified
return true
}
if len(h.httpCodeRanges) == 0 && !h.config.Filters.RetryAttempts && h.config.Filters.MinDuration == 0 {
// empty filters were specified, e.g. by passing --accessLog.filters only (without other filter options)
return true
}
if h.httpCodeRanges.Contains(statusCode) {
return true
}
if h.config.Filters.RetryAttempts && retryAttempts > 0 {
return true
}
if h.config.Filters.MinDuration > 0 && (parse.Duration(duration) > h.config.Filters.MinDuration) {
return true
}
return false
}
var requestCounter uint64 // Request ID
func nextRequestCount() uint64 {
return atomic.AddUint64(&requestCounter, 1)
}

View file

@ -0,0 +1,83 @@
package accesslog
import (
"bytes"
"fmt"
"time"
"github.com/sirupsen/logrus"
)
// default format for time presentation.
const (
commonLogTimeFormat = "02/Jan/2006:15:04:05 -0700"
defaultValue = "-"
)
// CommonLogFormatter provides formatting in the Traefik common log format.
type CommonLogFormatter struct{}
// Format formats the log entry in the Traefik common log format.
func (f *CommonLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
b := &bytes.Buffer{}
var timestamp = defaultValue
if v, ok := entry.Data[StartUTC]; ok {
timestamp = v.(time.Time).Format(commonLogTimeFormat)
}
var elapsedMillis int64
if v, ok := entry.Data[Duration]; ok {
elapsedMillis = v.(time.Duration).Nanoseconds() / 1000000
}
_, err := fmt.Fprintf(b, "%s - %s [%s] \"%s %s %s\" %v %v %s %s %v %s %s %dms\n",
toLog(entry.Data, ClientHost, defaultValue, false),
toLog(entry.Data, ClientUsername, defaultValue, false),
timestamp,
toLog(entry.Data, RequestMethod, defaultValue, false),
toLog(entry.Data, RequestPath, defaultValue, false),
toLog(entry.Data, RequestProtocol, defaultValue, false),
toLog(entry.Data, OriginStatus, defaultValue, true),
toLog(entry.Data, OriginContentSize, defaultValue, true),
toLog(entry.Data, "request_Referer", `"-"`, true),
toLog(entry.Data, "request_User-Agent", `"-"`, true),
toLog(entry.Data, RequestCount, defaultValue, true),
toLog(entry.Data, RouterName, defaultValue, true),
toLog(entry.Data, ServiceURL, defaultValue, true),
elapsedMillis)
return b.Bytes(), err
}
func toLog(fields logrus.Fields, key string, defaultValue string, quoted bool) interface{} {
if v, ok := fields[key]; ok {
if v == nil {
return defaultValue
}
switch s := v.(type) {
case string:
return toLogEntry(s, defaultValue, quoted)
case fmt.Stringer:
return toLogEntry(s.String(), defaultValue, quoted)
default:
return v
}
}
return defaultValue
}
func toLogEntry(s string, defaultValue string, quote bool) string {
if len(s) == 0 {
return defaultValue
}
if quote {
return `"` + s + `"`
}
return s
}

View file

@ -0,0 +1,140 @@
package accesslog
import (
"net/http"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func TestCommonLogFormatter_Format(t *testing.T) {
clf := CommonLogFormatter{}
testCases := []struct {
name string
data map[string]interface{}
expectedLog string
}{
{
name: "OriginStatus & OriginContentSize are nil",
data: map[string]interface{}{
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
Duration: 123 * time.Second,
ClientHost: "10.0.0.1",
ClientUsername: "Client",
RequestMethod: http.MethodGet,
RequestPath: "/foo",
RequestProtocol: "http",
OriginStatus: nil,
OriginContentSize: nil,
RequestRefererHeader: "",
RequestUserAgentHeader: "",
RequestCount: 0,
RouterName: "",
ServiceURL: "",
},
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" - - "-" "-" 0 - - 123000ms
`,
},
{
name: "all data",
data: map[string]interface{}{
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
Duration: 123 * time.Second,
ClientHost: "10.0.0.1",
ClientUsername: "Client",
RequestMethod: http.MethodGet,
RequestPath: "/foo",
RequestProtocol: "http",
OriginStatus: 123,
OriginContentSize: 132,
RequestRefererHeader: "referer",
RequestUserAgentHeader: "agent",
RequestCount: nil,
RouterName: "foo",
ServiceURL: "http://10.0.0.2/toto",
},
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" 123 132 "referer" "agent" - "foo" "http://10.0.0.2/toto" 123000ms
`,
},
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
entry := &logrus.Entry{Data: test.data}
raw, err := clf.Format(entry)
assert.NoError(t, err)
assert.Equal(t, test.expectedLog, string(raw))
})
}
}
func Test_toLog(t *testing.T) {
testCases := []struct {
desc string
fields logrus.Fields
fieldName string
defaultValue string
quoted bool
expectedLog interface{}
}{
{
desc: "Should return int 1",
fields: logrus.Fields{
"Powpow": 1,
},
fieldName: "Powpow",
defaultValue: defaultValue,
quoted: false,
expectedLog: 1,
},
{
desc: "Should return string foo",
fields: logrus.Fields{
"Powpow": "foo",
},
fieldName: "Powpow",
defaultValue: defaultValue,
quoted: true,
expectedLog: `"foo"`,
},
{
desc: "Should return defaultValue if fieldName does not exist",
fields: logrus.Fields{
"Powpow": "foo",
},
fieldName: "",
defaultValue: defaultValue,
quoted: false,
expectedLog: "-",
},
{
desc: "Should return defaultValue if fields is nil",
fields: nil,
fieldName: "",
defaultValue: defaultValue,
quoted: false,
expectedLog: "-",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
lg := toLog(test.fields, test.fieldName, defaultValue, test.quoted)
assert.Equal(t, test.expectedLog, lg)
})
}
}

View file

@ -0,0 +1,649 @@
package accesslog
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
logFileNameSuffix = "/traefik/logger/test.log"
testContent = "Hello, World"
testServiceName = "http://127.0.0.1/testService"
testRouterName = "testRouter"
testStatus = 123
testContentSize int64 = 12
testHostname = "TestHost"
testUsername = "TestUser"
testPath = "testpath"
testPort = 8181
testProto = "HTTP/0.0"
testMethod = http.MethodPost
testReferer = "testReferer"
testUserAgent = "testUserAgent"
testRetryAttempts = 2
testStart = time.Now()
)
func TestLogRotation(t *testing.T) {
tempDir, err := ioutil.TempDir("", "traefik_")
if err != nil {
t.Fatalf("Error setting up temporary directory: %s", err)
}
fileName := tempDir + "traefik.log"
rotatedFileName := fileName + ".rotated"
config := &types.AccessLog{FilePath: fileName, Format: CommonFormat}
logHandler, err := NewHandler(config)
if err != nil {
t.Fatalf("Error creating new log handler: %s", err)
}
defer logHandler.Close()
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
next := func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}
iterations := 20
halfDone := make(chan bool)
writeDone := make(chan bool)
go func() {
for i := 0; i < iterations; i++ {
logHandler.ServeHTTP(recorder, req, next)
if i == iterations/2 {
halfDone <- true
}
}
writeDone <- true
}()
<-halfDone
err = os.Rename(fileName, rotatedFileName)
if err != nil {
t.Fatalf("Error renaming file: %s", err)
}
err = logHandler.Rotate()
if err != nil {
t.Fatalf("Error rotating file: %s", err)
}
select {
case <-writeDone:
gotLineCount := lineCount(t, fileName) + lineCount(t, rotatedFileName)
if iterations != gotLineCount {
t.Errorf("Wanted %d written log lines, got %d", iterations, gotLineCount)
}
case <-time.After(500 * time.Millisecond):
t.Fatalf("test timed out")
}
close(halfDone)
close(writeDone)
}
func lineCount(t *testing.T, fileName string) int {
t.Helper()
fileContents, err := ioutil.ReadFile(fileName)
if err != nil {
t.Fatalf("Error reading from file %s: %s", fileName, err)
}
count := 0
for _, line := range strings.Split(string(fileContents), "\n") {
if strings.TrimSpace(line) == "" {
continue
}
count++
}
return count
}
func TestLoggerCLF(t *testing.T) {
tmpDir := createTempDir(t, CommonFormat)
defer os.RemoveAll(tmpDir)
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat}
doLogging(t, config)
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
assertValidLogData(t, expectedLog, logData)
}
func TestAsyncLoggerCLF(t *testing.T) {
tmpDir := createTempDir(t, CommonFormat)
defer os.RemoveAll(tmpDir)
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat, BufferingSize: 1024}
doLogging(t, config)
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
assertValidLogData(t, expectedLog, logData)
}
func assertString(exp string) func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.Equal(t, exp, actual)
}
}
func assertNotEmpty() func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.NotEqual(t, "", actual)
}
}
func assertFloat64(exp float64) func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.Equal(t, exp, actual)
}
}
func assertFloat64NotZero() func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.NotZero(t, actual)
}
}
func TestLoggerJSON(t *testing.T) {
testCases := []struct {
desc string
config *types.AccessLog
expected map[string]func(t *testing.T, value interface{})
}{
{
desc: "default config",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
},
expected: map[string]func(t *testing.T, value interface{}){
RequestHost: assertString(testHostname),
RequestAddr: assertString(testHostname),
RequestMethod: assertString(testMethod),
RequestPath: assertString(testPath),
RequestProtocol: assertString(testProto),
RequestPort: assertString("-"),
DownstreamStatus: assertFloat64(float64(testStatus)),
DownstreamContentSize: assertFloat64(float64(len(testContent))),
OriginContentSize: assertFloat64(float64(len(testContent))),
OriginStatus: assertFloat64(float64(testStatus)),
RequestRefererHeader: assertString(testReferer),
RequestUserAgentHeader: assertString(testUserAgent),
RouterName: assertString(testRouterName),
ServiceURL: assertString(testServiceName),
ClientUsername: assertString(testUsername),
ClientHost: assertString(testHostname),
ClientPort: assertString(fmt.Sprintf("%d", testPort)),
ClientAddr: assertString(fmt.Sprintf("%s:%d", testHostname, testPort)),
"level": assertString("info"),
"msg": assertString(""),
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
RequestCount: assertFloat64NotZero(),
Duration: assertFloat64NotZero(),
Overhead: assertFloat64NotZero(),
RetryAttempts: assertFloat64(float64(testRetryAttempts)),
"time": assertNotEmpty(),
"StartLocal": assertNotEmpty(),
"StartUTC": assertNotEmpty(),
},
},
{
desc: "default config drop all fields",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
},
},
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEmpty(),
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
RequestRefererHeader: assertString(testReferer),
RequestUserAgentHeader: assertString(testUserAgent),
},
},
{
desc: "default config drop all fields and headers",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Headers: &types.FieldHeaders{
DefaultMode: "drop",
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEmpty(),
},
},
{
desc: "default config drop all fields and redact headers",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Headers: &types.FieldHeaders{
DefaultMode: "redact",
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEmpty(),
"downstream_Content-Type": assertString("REDACTED"),
RequestRefererHeader: assertString("REDACTED"),
RequestUserAgentHeader: assertString("REDACTED"),
},
},
{
desc: "default config drop all fields and headers but kept someone",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
RequestHost: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "drop",
Names: types.FieldHeaderNames{
"Referer": "keep",
},
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
RequestHost: assertString(testHostname),
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEmpty(),
RequestRefererHeader: assertString(testReferer),
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
tmpDir := createTempDir(t, JSONFormat)
defer os.RemoveAll(tmpDir)
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
test.config.FilePath = logFilePath
doLogging(t, test.config)
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
jsonData := make(map[string]interface{})
err = json.Unmarshal(logData, &jsonData)
require.NoError(t, err)
assert.Equal(t, len(test.expected), len(jsonData))
for field, assertion := range test.expected {
assertion(t, jsonData[field])
}
})
}
}
func TestNewLogHandlerOutputStdout(t *testing.T) {
testCases := []struct {
desc string
config *types.AccessLog
expectedLog string
}{
{
desc: "default config",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "default config with empty filters",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Status code filter not matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
StatusCodes: []string{"200"},
},
},
expectedLog: ``,
},
{
desc: "Status code filter matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
StatusCodes: []string{"123"},
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Duration filter not matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
MinDuration: parse.Duration(1 * time.Hour),
},
},
expectedLog: ``,
},
{
desc: "Duration filter matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
MinDuration: parse.Duration(1 * time.Millisecond),
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Retry attempts filter matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
RetryAttempts: true,
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Default mode keep",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "keep",
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Default mode keep with override",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "keep",
Names: types.FieldNames{
ClientHost: "drop",
},
},
},
expectedLog: `- - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
},
{
desc: "Default mode drop",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
},
},
expectedLog: `- - - [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
},
{
desc: "Default mode drop with override",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
},
{
desc: "Default mode drop with header dropped",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "drop",
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "-" "-" - - - 0ms`,
},
{
desc: "Default mode drop with header redacted",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "redact",
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "REDACTED" - - - 0ms`,
},
{
desc: "Default mode drop with header redacted",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "keep",
Names: types.FieldHeaderNames{
"Referer": "redact",
},
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "testUserAgent" - - - 0ms`,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
// NOTE: It is not possible to run these cases in parallel because we capture Stdout
file, restoreStdout := captureStdout(t)
defer restoreStdout()
doLogging(t, test.config)
written, err := ioutil.ReadFile(file.Name())
require.NoError(t, err, "unable to read captured stdout from file")
assertValidLogData(t, test.expectedLog, written)
})
}
}
func assertValidLogData(t *testing.T, expected string, logData []byte) {
if len(expected) == 0 {
assert.Zero(t, len(logData))
t.Log(string(logData))
return
}
result, err := ParseAccessLog(string(logData))
require.NoError(t, err)
resultExpected, err := ParseAccessLog(expected)
require.NoError(t, err)
formatErrMessage := fmt.Sprintf(`
Expected: %s
Actual: %s`, expected, string(logData))
require.Equal(t, len(resultExpected), len(result), formatErrMessage)
assert.Equal(t, resultExpected[ClientHost], result[ClientHost], formatErrMessage)
assert.Equal(t, resultExpected[ClientUsername], result[ClientUsername], formatErrMessage)
assert.Equal(t, resultExpected[RequestMethod], result[RequestMethod], formatErrMessage)
assert.Equal(t, resultExpected[RequestPath], result[RequestPath], formatErrMessage)
assert.Equal(t, resultExpected[RequestProtocol], result[RequestProtocol], formatErrMessage)
assert.Equal(t, resultExpected[OriginStatus], result[OriginStatus], formatErrMessage)
assert.Equal(t, resultExpected[OriginContentSize], result[OriginContentSize], formatErrMessage)
assert.Equal(t, resultExpected[RequestRefererHeader], result[RequestRefererHeader], formatErrMessage)
assert.Equal(t, resultExpected[RequestUserAgentHeader], result[RequestUserAgentHeader], formatErrMessage)
assert.Regexp(t, regexp.MustCompile("[0-9]*"), result[RequestCount], formatErrMessage)
assert.Equal(t, resultExpected[RouterName], result[RouterName], formatErrMessage)
assert.Equal(t, resultExpected[ServiceURL], result[ServiceURL], formatErrMessage)
assert.Regexp(t, regexp.MustCompile("[0-9]*ms"), result[Duration], formatErrMessage)
}
func captureStdout(t *testing.T) (out *os.File, restoreStdout func()) {
file, err := ioutil.TempFile("", "testlogger")
require.NoError(t, err, "failed to create temp file")
original := os.Stdout
os.Stdout = file
restoreStdout = func() {
os.Stdout = original
}
return file, restoreStdout
}
func createTempDir(t *testing.T, prefix string) string {
tmpDir, err := ioutil.TempDir("", prefix)
require.NoError(t, err, "failed to create temp dir")
return tmpDir
}
func doLogging(t *testing.T, config *types.AccessLog) {
logger, err := NewHandler(config)
require.NoError(t, err)
defer logger.Close()
if config.FilePath != "" {
_, err = os.Stat(config.FilePath)
require.NoError(t, err, fmt.Sprintf("logger should create %s", config.FilePath))
}
req := &http.Request{
Header: map[string][]string{
"User-Agent": {testUserAgent},
"Referer": {testReferer},
},
Proto: testProto,
Host: testHostname,
Method: testMethod,
RemoteAddr: fmt.Sprintf("%s:%d", testHostname, testPort),
URL: &url.URL{
User: url.UserPassword(testUsername, ""),
Path: testPath,
},
}
logger.ServeHTTP(httptest.NewRecorder(), req, logWriterTestHandlerFunc)
}
func logWriterTestHandlerFunc(rw http.ResponseWriter, r *http.Request) {
if _, err := rw.Write([]byte(testContent)); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
return
}
logData := GetLogData(r)
if logData != nil {
logData.Core[RouterName] = testRouterName
logData.Core[ServiceURL] = testServiceName
logData.Core[OriginStatus] = testStatus
logData.Core[OriginContentSize] = testContentSize
logData.Core[RetryAttempts] = testRetryAttempts
logData.Core[StartUTC] = testStart.UTC()
logData.Core[StartLocal] = testStart.Local()
} else {
http.Error(rw, "LogData is nil", http.StatusInternalServerError)
return
}
rw.WriteHeader(testStatus)
}

View file

@ -0,0 +1,54 @@
package accesslog
import (
"bytes"
"regexp"
)
// ParseAccessLog parse line of access log and return a map with each fields
func ParseAccessLog(data string) (map[string]string, error) {
var buffer bytes.Buffer
buffer.WriteString(`(\S+)`) // 1 - ClientHost
buffer.WriteString(`\s-\s`) // - - Spaces
buffer.WriteString(`(\S+)\s`) // 2 - ClientUsername
buffer.WriteString(`\[([^]]+)\]\s`) // 3 - StartUTC
buffer.WriteString(`"(\S*)\s?`) // 4 - RequestMethod
buffer.WriteString(`((?:[^"]*(?:\\")?)*)\s`) // 5 - RequestPath
buffer.WriteString(`([^"]*)"\s`) // 6 - RequestProtocol
buffer.WriteString(`(\S+)\s`) // 7 - OriginStatus
buffer.WriteString(`(\S+)\s`) // 8 - OriginContentSize
buffer.WriteString(`("?\S+"?)\s`) // 9 - Referrer
buffer.WriteString(`("\S+")\s`) // 10 - User-Agent
buffer.WriteString(`(\S+)\s`) // 11 - RequestCount
buffer.WriteString(`("[^"]*"|-)\s`) // 12 - FrontendName
buffer.WriteString(`("[^"]*"|-)\s`) // 13 - BackendURL
buffer.WriteString(`(\S+)`) // 14 - Duration
regex, err := regexp.Compile(buffer.String())
if err != nil {
return nil, err
}
submatch := regex.FindStringSubmatch(data)
result := make(map[string]string)
// Need to be > 13 to match CLF format
if len(submatch) > 13 {
result[ClientHost] = submatch[1]
result[ClientUsername] = submatch[2]
result[StartUTC] = submatch[3]
result[RequestMethod] = submatch[4]
result[RequestPath] = submatch[5]
result[RequestProtocol] = submatch[6]
result[OriginStatus] = submatch[7]
result[OriginContentSize] = submatch[8]
result[RequestRefererHeader] = submatch[9]
result[RequestUserAgentHeader] = submatch[10]
result[RequestCount] = submatch[11]
result[RouterName] = submatch[12]
result[ServiceURL] = submatch[13]
result[Duration] = submatch[14]
}
return result, nil
}

View file

@ -0,0 +1,75 @@
package accesslog
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseAccessLog(t *testing.T) {
testCases := []struct {
desc string
value string
expected map[string]string
}{
{
desc: "full log",
value: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`,
expected: map[string]string{
ClientHost: "TestHost",
ClientUsername: "TestUser",
StartUTC: "13/Apr/2016:07:14:19 -0700",
RequestMethod: "POST",
RequestPath: "testpath",
RequestProtocol: "HTTP/0.0",
OriginStatus: "123",
OriginContentSize: "12",
RequestRefererHeader: `"testReferer"`,
RequestUserAgentHeader: `"testUserAgent"`,
RequestCount: "1",
RouterName: `"testRouter"`,
ServiceURL: `"http://127.0.0.1/testService"`,
Duration: "1ms",
},
},
{
desc: "log with space",
value: `127.0.0.1 - - [09/Mar/2018:10:51:32 +0000] "GET / HTTP/1.1" 401 17 "-" "Go-http-client/1.1" 1 "testRouter with space" - 0ms`,
expected: map[string]string{
ClientHost: "127.0.0.1",
ClientUsername: "-",
StartUTC: "09/Mar/2018:10:51:32 +0000",
RequestMethod: "GET",
RequestPath: "/",
RequestProtocol: "HTTP/1.1",
OriginStatus: "401",
OriginContentSize: "17",
RequestRefererHeader: `"-"`,
RequestUserAgentHeader: `"Go-http-client/1.1"`,
RequestCount: "1",
RouterName: `"testRouter with space"`,
ServiceURL: `-`,
Duration: "0ms",
},
},
{
desc: "bad log",
value: `bad`,
expected: map[string]string{},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
result, err := ParseAccessLog(test.value)
assert.NoError(t, err)
assert.Equal(t, len(test.expected), len(result))
for key, value := range test.expected {
assert.Equal(t, value, result[key])
}
})
}
}

View file

@ -0,0 +1,21 @@
package accesslog
import (
"net/http"
)
// SaveRetries is an implementation of RetryListener that stores RetryAttempts in the LogDataTable.
type SaveRetries struct{}
// Retried implements the RetryListener interface and will be called for each retry that happens.
func (s *SaveRetries) Retried(req *http.Request, attempt int) {
// it is the request attempt x, but the retry attempt is x-1
if attempt > 0 {
attempt--
}
table := GetLogData(req)
if table != nil {
table.Core[RetryAttempts] = attempt
}
}

View file

@ -0,0 +1,48 @@
package accesslog
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func TestSaveRetries(t *testing.T) {
tests := []struct {
requestAttempt int
wantRetryAttemptsInLog int
}{
{
requestAttempt: 0,
wantRetryAttemptsInLog: 0,
},
{
requestAttempt: 1,
wantRetryAttemptsInLog: 0,
},
{
requestAttempt: 3,
wantRetryAttemptsInLog: 2,
},
}
for _, test := range tests {
test := test
t.Run(fmt.Sprintf("%d retries", test.requestAttempt), func(t *testing.T) {
t.Parallel()
saveRetries := &SaveRetries{}
logDataTable := &LogData{Core: make(CoreLogData)}
req := httptest.NewRequest(http.MethodGet, "/some/path", nil)
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
saveRetries.Retried(reqWithDataTable, test.requestAttempt)
if logDataTable.Core[RetryAttempts] != test.wantRetryAttemptsInLog {
t.Errorf("got %v in logDataTable, want %v", logDataTable.Core[RetryAttempts], test.wantRetryAttemptsInLog)
}
})
}
}

View file

@ -0,0 +1,62 @@
package addprefix
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
typeName = "AddPrefix"
)
// AddPrefix is a middleware used to add prefix to an URL request.
type addPrefix struct {
next http.Handler
prefix string
name string
}
// New creates a new handler.
func New(ctx context.Context, next http.Handler, config config.AddPrefix, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
var result *addPrefix
if len(config.Prefix) > 0 {
result = &addPrefix{
prefix: config.Prefix,
next: next,
name: name,
}
} else {
return nil, fmt.Errorf("prefix cannot be empty")
}
return result, nil
}
func (ap *addPrefix) GetTracingInformation() (string, ext.SpanKindEnum) {
return ap.name, tracing.SpanKindNoneEnum
}
func (ap *addPrefix) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), ap.name, typeName)
oldURLPath := req.URL.Path
req.URL.Path = ap.prefix + req.URL.Path
logger.Debugf("URL.Path is now %s (was %s).", req.URL.Path, oldURLPath)
if req.URL.RawPath != "" {
oldURLRawPath := req.URL.RawPath
req.URL.RawPath = ap.prefix + req.URL.RawPath
logger.Debugf("URL.RawPath is now %s (was %s).", req.URL.RawPath, oldURLRawPath)
}
req.RequestURI = req.URL.RequestURI()
ap.next.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,104 @@
package addprefix
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewAddPrefix(t *testing.T) {
testCases := []struct {
desc string
prefix config.AddPrefix
expectsError bool
}{
{
desc: "Works with a non empty prefix",
prefix: config.AddPrefix{Prefix: "/a"},
},
{
desc: "Fails if prefix is empty",
prefix: config.AddPrefix{Prefix: ""},
expectsError: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
_, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
if test.expectsError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestAddPrefix(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel)
testCases := []struct {
desc string
prefix config.AddPrefix
path string
expectedPath string
expectedRawPath string
}{
{
desc: "Works with a regular path",
prefix: config.AddPrefix{Prefix: "/a"},
path: "/b",
expectedPath: "/a/b",
},
{
desc: "Works with a raw path",
prefix: config.AddPrefix{Prefix: "/a"},
path: "/b%2Fc",
expectedPath: "/a/b/c",
expectedRawPath: "/a/b%2Fc",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, requestURI string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
requestURI = r.RequestURI
})
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
handler, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
require.NoError(t, err)
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath)
assert.Equal(t, test.expectedRawPath, actualRawPath)
expectedURI := test.expectedPath
if test.expectedRawPath != "" {
// go HTTP uses the raw path when existent in the RequestURI
expectedURI = test.expectedRawPath
}
assert.Equal(t, expectedURI, requestURI)
})
}
}

View file

@ -0,0 +1,65 @@
package auth
import (
"io/ioutil"
"strings"
)
// UserParser Parses a string and return a userName/userHash. An error if the format of the string is incorrect.
type UserParser func(user string) (string, string, error)
const (
defaultRealm = "traefik"
authorizationHeader = "Authorization"
)
func getUsers(fileName string, appendUsers []string, parser UserParser) (map[string]string, error) {
users, err := loadUsers(fileName, appendUsers)
if err != nil {
return nil, err
}
userMap := make(map[string]string)
for _, user := range users {
userName, userHash, err := parser(user)
if err != nil {
return nil, err
}
userMap[userName] = userHash
}
return userMap, nil
}
func loadUsers(fileName string, appendUsers []string) ([]string, error) {
var users []string
var err error
if fileName != "" {
users, err = getLinesFromFile(fileName)
if err != nil {
return nil, err
}
}
return append(users, appendUsers...), nil
}
func getLinesFromFile(filename string) ([]string, error) {
dat, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
// Trim lines and filter out blanks
rawLines := strings.Split(string(dat), "\n")
var filteredLines []string
for _, rawLine := range rawLines {
line := strings.TrimSpace(rawLine)
if line != "" && !strings.HasPrefix(line, "#") {
filteredLines = append(filteredLines, line)
}
}
return filteredLines, nil
}

View file

@ -0,0 +1,102 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
basicTypeName = "BasicAuth"
)
type basicAuth struct {
next http.Handler
auth *goauth.BasicAuth
users map[string]string
headerField string
removeHeader bool
name string
}
// NewBasic creates a basicAuth middleware.
func NewBasic(ctx context.Context, next http.Handler, authConfig config.BasicAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, basicTypeName).Debug("Creating middleware")
users, err := getUsers(authConfig.UsersFile, authConfig.Users, basicUserParser)
if err != nil {
return nil, err
}
ba := &basicAuth{
next: next,
users: users,
headerField: authConfig.HeaderField,
removeHeader: authConfig.RemoveHeader,
name: name,
}
realm := defaultRealm
if len(authConfig.Realm) > 0 {
realm = authConfig.Realm
}
ba.auth = goauth.NewBasicAuthenticator(realm, ba.secretBasic)
return ba, nil
}
func (b *basicAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return b.name, tracing.SpanKindNoneEnum
}
func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), b.name, basicTypeName)
if username := b.auth.CheckAuth(req); username == "" {
logger.Debug("Authentication failed")
tracing.SetErrorWithEvent(req, "Authentication failed")
b.auth.RequireAuth(rw, req)
} else {
logger.Debug("Authentication succeeded")
req.URL.User = url.User(username)
logData := accesslog.GetLogData(req)
if logData != nil {
logData.Core[accesslog.ClientUsername] = username
}
if b.headerField != "" {
req.Header[b.headerField] = []string{username}
}
if b.removeHeader {
logger.Debug("Removing authorization header")
req.Header.Del(authorizationHeader)
}
b.next.ServeHTTP(rw, req)
}
}
func (b *basicAuth) secretBasic(user, realm string) string {
if secret, ok := b.users[user]; ok {
return secret
}
return ""
}
func basicUserParser(user string) (string, string, error) {
split := strings.Split(user, ":")
if len(split) != 2 {
return "", "", fmt.Errorf("error parsing BasicUser: %v", user)
}
return split[0], split[1], nil
}

View file

@ -0,0 +1,284 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBasicAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test"},
}
_, err := NewBasic(context.Background(), next, auth, "authName")
require.Error(t, err)
auth2 := config.BasicAuth{
Users: []string{"test:test"},
}
authMiddleware, err := NewBasic(context.Background(), next, auth2, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
}
func TestBasicAuthSuccess(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
authMiddleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicAuthUserHeader(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
HeaderField: "X-Webauth-User",
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthHeaderRemoved(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
RemoveHeader: true,
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthHeaderPresent(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthUsersFromFile(t *testing.T) {
testCases := []struct {
desc string
userFileContent string
expectedUsers map[string]string
givenUsers []string
realm string
}{
{
desc: "Finds the users in the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
},
{
desc: "Merges given users with users from the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\n",
givenUsers: []string{"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0", "test3:$apr1$3rJbDP0q$RfzJiorTk78jQ1EcKqWso0"},
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
},
{
desc: "Given users have priority over users in the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
},
{
desc: "Should authenticate the correct user based on the realm",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
realm: "traefik",
},
{
desc: "Should skip comments",
userFileContent: "#Comment\ntest:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
realm: "traefiker",
},
}
for _, test := range testCases {
if test.desc != "Should skip comments" {
continue
}
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
// Creates the temporary configuration file with the users
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.userFileContent))
require.NoError(t, err)
// Creates the configuration for our Authenticator
authenticatorConfiguration := config.BasicAuth{
Users: test.givenUsers,
UsersFile: usersFile.Name(),
Realm: test.realm,
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
authenticator, err := NewBasic(context.Background(), next, authenticatorConfiguration, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authenticator)
defer ts.Close()
for userName, userPwd := range test.expectedUsers {
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth(userName, userPwd)
var res *http.Response
res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.Equal(t, "traefik\n", string(body))
}
// Checks that user foo doesn't work
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("foo", "foo")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
if len(test.realm) > 0 {
require.Equal(t, `Basic realm="`+test.realm+`"`, res.Header.Get("WWW-Authenticate"))
}
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.NotContains(t, "traefik", string(body))
})
}
}

View file

@ -0,0 +1,102 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
digestTypeName = "digestAuth"
)
type digestAuth struct {
next http.Handler
auth *goauth.DigestAuth
users map[string]string
headerField string
removeHeader bool
name string
}
// NewDigest creates a digest auth middleware.
func NewDigest(ctx context.Context, next http.Handler, authConfig config.DigestAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, digestTypeName).Debug("Creating middleware")
users, err := getUsers(authConfig.UsersFile, authConfig.Users, digestUserParser)
if err != nil {
return nil, err
}
da := &digestAuth{
next: next,
users: users,
headerField: authConfig.HeaderField,
removeHeader: authConfig.RemoveHeader,
name: name,
}
realm := defaultRealm
if len(authConfig.Realm) > 0 {
realm = authConfig.Realm
}
da.auth = goauth.NewDigestAuthenticator(realm, da.secretDigest)
return da, nil
}
func (d *digestAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return d.name, tracing.SpanKindNoneEnum
}
func (d *digestAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), d.name, digestTypeName)
if username, _ := d.auth.CheckAuth(req); username == "" {
logger.Debug("Digest authentication failed")
tracing.SetErrorWithEvent(req, "Digest authentication failed")
d.auth.RequireAuth(rw, req)
} else {
logger.Debug("Digest authentication succeeded")
req.URL.User = url.User(username)
logData := accesslog.GetLogData(req)
if logData != nil {
logData.Core[accesslog.ClientUsername] = username
}
if d.headerField != "" {
req.Header[d.headerField] = []string{username}
}
if d.removeHeader {
logger.Debug("Removing the Authorization header")
req.Header.Del(authorizationHeader)
}
d.next.ServeHTTP(rw, req)
}
}
func (d *digestAuth) secretDigest(user, realm string) string {
if secret, ok := d.users[user+":"+realm]; ok {
return secret
}
return ""
}
func digestUserParser(user string) (string, string, error) {
split := strings.Split(user, ":")
if len(split) != 3 {
return "", "", fmt.Errorf("error parsing DigestUser: %v", user)
}
return split[0] + ":" + split[1], split[2], nil
}

View file

@ -0,0 +1,141 @@
package auth
import (
"crypto/md5"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
)
const (
algorithm = "algorithm"
authorization = "Authorization"
nonce = "nonce"
opaque = "opaque"
qop = "qop"
realm = "realm"
wwwAuthenticate = "Www-Authenticate"
)
// DigestRequest is a client for digest authentication requests
type digestRequest struct {
client *http.Client
username, password string
nonceCount nonceCount
}
type nonceCount int
func (nc nonceCount) String() string {
return fmt.Sprintf("%08x", int(nc))
}
var wanted = []string{algorithm, nonce, opaque, qop, realm}
// New makes a DigestRequest instance
func newDigestRequest(username, password string, client *http.Client) *digestRequest {
return &digestRequest{
client: client,
username: username,
password: password,
}
}
// Do does requests as http.Do does
func (r *digestRequest) Do(req *http.Request) (*http.Response, error) {
parts, err := r.makeParts(req)
if err != nil {
return nil, err
}
if parts != nil {
req.Header.Set(authorization, r.makeAuthorization(req, parts))
}
return r.client.Do(req)
}
func (r *digestRequest) makeParts(req *http.Request) (map[string]string, error) {
authReq, err := http.NewRequest(req.Method, req.URL.String(), nil)
if err != nil {
return nil, err
}
resp, err := r.client.Do(authReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
return nil, nil
}
if len(resp.Header[wwwAuthenticate]) == 0 {
return nil, fmt.Errorf("headers do not have %s", wwwAuthenticate)
}
headers := strings.Split(resp.Header[wwwAuthenticate][0], ",")
parts := make(map[string]string, len(wanted))
for _, r := range headers {
for _, w := range wanted {
if strings.Contains(r, w) {
parts[w] = strings.Split(r, `"`)[1]
}
}
}
if len(parts) != len(wanted) {
return nil, fmt.Errorf("header is invalid: %+v", parts)
}
return parts, nil
}
func getMD5(texts []string) string {
h := md5.New()
_, _ = io.WriteString(h, strings.Join(texts, ":"))
return hex.EncodeToString(h.Sum(nil))
}
func (r *digestRequest) getNonceCount() string {
r.nonceCount++
return r.nonceCount.String()
}
func (r *digestRequest) makeAuthorization(req *http.Request, parts map[string]string) string {
ha1 := getMD5([]string{r.username, parts[realm], r.password})
ha2 := getMD5([]string{req.Method, req.URL.String()})
cnonce := generateRandom(16)
nc := r.getNonceCount()
response := getMD5([]string{
ha1,
parts[nonce],
nc,
cnonce,
parts[qop],
ha2,
})
return fmt.Sprintf(
`Digest username="%s", realm="%s", nonce="%s", uri="%s", algorithm=%s, qop=%s, nc=%s, cnonce="%s", response="%s", opaque="%s"`,
r.username,
parts[realm],
parts[nonce],
req.URL.String(),
parts[algorithm],
parts[qop],
nc,
cnonce,
response,
parts[opaque],
)
}
// GenerateRandom generates random string
func generateRandom(n int) string {
b := make([]byte, 8)
_, _ = io.ReadFull(rand.Reader, b)
return fmt.Sprintf("%x", b)[:n]
}

View file

@ -0,0 +1,156 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestAuthError(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.DigestAuth{
Users: []string{"test"},
}
_, err := NewDigest(context.Background(), next, auth, "authName")
assert.Error(t, err)
}
func TestDigestAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.DigestAuth{
Users: []string{"test:traefik:a2688e031edb4be6a3797f3882655c05"},
}
authMiddleware, err := NewDigest(context.Background(), next, auth, "authName")
require.NoError(t, err)
assert.NotNil(t, authMiddleware, "this should not be nil")
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := http.DefaultClient
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}
func TestDigestAuthUsersFromFile(t *testing.T) {
testCases := []struct {
desc string
userFileContent string
expectedUsers map[string]string
givenUsers []string
realm string
}{
{
desc: "Finds the users in the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
},
{
desc: "Merges given users with users from the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\n",
givenUsers: []string{"test2:traefik:518845800f9e2bfb1f1f740ec24f074e", "test3:traefik:c8e9f57ce58ecb4424407f665a91646c"},
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
},
{
desc: "Given users have priority over users in the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
givenUsers: []string{"test2:traefik:8de60a1c52da68ccf41f0c0ffb7c51a0"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
},
{
desc: "Should authenticate the correct user based on the realm",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest:traefiker:a3d334dff2645b914918de78bec50bf4\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test2"},
realm: "traefiker",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
// Creates the temporary configuration file with the users
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.userFileContent))
require.NoError(t, err)
// Creates the configuration for our Authenticator
authenticatorConfiguration := config.DigestAuth{
Users: test.givenUsers,
UsersFile: usersFile.Name(),
Realm: test.realm,
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
authenticator, err := NewDigest(context.Background(), next, authenticatorConfiguration, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authenticator)
defer ts.Close()
for userName, userPwd := range test.expectedUsers {
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
digestRequest := newDigestRequest(userName, userPwd, http.DefaultClient)
var res *http.Response
res, err = digestRequest.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.Equal(t, "traefik\n", string(body))
}
// Checks that user foo doesn't work
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
digestRequest := newDigestRequest("foo", "foo", http.DefaultClient)
var res *http.Response
res, err = digestRequest.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.NotContains(t, "traefik", string(body))
})
}
}

View file

@ -0,0 +1,213 @@
package auth
import (
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
const (
xForwardedURI = "X-Forwarded-Uri"
xForwardedMethod = "X-Forwarded-Method"
forwardedTypeName = "ForwardedAuthType"
)
type forwardAuth struct {
address string
authResponseHeaders []string
next http.Handler
name string
tlsConfig *tls.Config
trustForwardHeader bool
}
// NewForward creates a forward auth middleware.
func NewForward(ctx context.Context, next http.Handler, config config.ForwardAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, forwardedTypeName).Debug("Creating middleware")
fa := &forwardAuth{
address: config.Address,
authResponseHeaders: config.AuthResponseHeaders,
next: next,
name: name,
trustForwardHeader: config.TrustForwardHeader,
}
if config.TLS != nil {
tlsConfig, err := config.TLS.CreateTLSConfig()
if err != nil {
return nil, err
}
fa.tlsConfig = tlsConfig
}
return fa, nil
}
func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return fa.name, ext.SpanKindRPCClientEnum
}
func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), fa.name, forwardedTypeName)
// Ensure our request client does not follow redirects
httpClient := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
if fa.tlsConfig != nil {
httpClient.Transport = &http.Transport{
TLSClientConfig: fa.tlsConfig,
}
}
forwardReq, err := http.NewRequest(http.MethodGet, fa.address, nil)
tracing.LogRequest(tracing.GetSpan(req), forwardReq)
if err != nil {
logMessage := fmt.Sprintf("Error calling %s. Cause %s", fa.address, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
writeHeader(req, forwardReq, fa.trustForwardHeader)
tracing.InjectRequestHeaders(forwardReq)
forwardResponse, forwardErr := httpClient.Do(forwardReq)
if forwardErr != nil {
logMessage := fmt.Sprintf("Error calling %s. Cause: %s", fa.address, forwardErr)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
body, readError := ioutil.ReadAll(forwardResponse.Body)
if readError != nil {
logMessage := fmt.Sprintf("Error reading body %s. Cause: %s", fa.address, readError)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
defer forwardResponse.Body.Close()
// Pass the forward response's body and selected headers if it
// didn't return a response within the range of [200, 300).
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
logger.Debugf("Remote error %s. StatusCode: %d", fa.address, forwardResponse.StatusCode)
utils.CopyHeaders(rw.Header(), forwardResponse.Header)
utils.RemoveHeaders(rw.Header(), forward.HopHeaders...)
// Grab the location header, if any.
redirectURL, err := forwardResponse.Location()
if err != nil {
if err != http.ErrNoLocation {
logMessage := fmt.Sprintf("Error reading response location header %s. Cause: %s", fa.address, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
} else if redirectURL.String() != "" {
// Set the location in our response if one was sent back.
rw.Header().Set("Location", redirectURL.String())
}
tracing.LogResponseCode(tracing.GetSpan(req), forwardResponse.StatusCode)
rw.WriteHeader(forwardResponse.StatusCode)
if _, err = rw.Write(body); err != nil {
logger.Error(err)
}
return
}
for _, headerName := range fa.authResponseHeaders {
req.Header.Set(headerName, forwardResponse.Header.Get(headerName))
}
req.RequestURI = req.URL.RequestURI()
fa.next.ServeHTTP(rw, req)
}
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
utils.CopyHeaders(forwardReq.Header, req.Header)
utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if trustForwardHeader {
if prior, ok := req.Header[forward.XForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
forwardReq.Header.Set(forward.XForwardedFor, clientIP)
}
xMethod := req.Header.Get(xForwardedMethod)
switch {
case xMethod != "" && trustForwardHeader:
forwardReq.Header.Set(xForwardedMethod, xMethod)
case req.Method != "":
forwardReq.Header.Set(xForwardedMethod, req.Method)
default:
forwardReq.Header.Del(xForwardedMethod)
}
xfp := req.Header.Get(forward.XForwardedProto)
switch {
case xfp != "" && trustForwardHeader:
forwardReq.Header.Set(forward.XForwardedProto, xfp)
case req.TLS != nil:
forwardReq.Header.Set(forward.XForwardedProto, "https")
default:
forwardReq.Header.Set(forward.XForwardedProto, "http")
}
if xfp := req.Header.Get(forward.XForwardedPort); xfp != "" && trustForwardHeader {
forwardReq.Header.Set(forward.XForwardedPort, xfp)
}
xfh := req.Header.Get(forward.XForwardedHost)
switch {
case xfh != "" && trustForwardHeader:
forwardReq.Header.Set(forward.XForwardedHost, xfh)
case req.Host != "":
forwardReq.Header.Set(forward.XForwardedHost, req.Host)
default:
forwardReq.Header.Del(forward.XForwardedHost)
}
xfURI := req.Header.Get(xForwardedURI)
switch {
case xfURI != "" && trustForwardHeader:
forwardReq.Header.Set(xForwardedURI, xfURI)
case req.URL.RequestURI() != "":
forwardReq.Header.Set(xForwardedURI, req.URL.RequestURI())
default:
forwardReq.Header.Del(xForwardedURI)
}
}

View file

@ -0,0 +1,393 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/forward"
)
func TestForwardAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer server.Close()
middleware, err := NewForward(context.Background(), next, config.ForwardAuth{
Address: server.URL,
}, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(body))
}
func TestForwardAuthSuccess(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Auth-User", "user@example.com")
w.Header().Set("X-Auth-Secret", "secret")
fmt.Fprintln(w, "Success")
}))
defer server.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: server.URL,
AuthResponseHeaders: []string{"X-Auth-User"},
}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestForwardAuthRedirect(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
}))
defer authTs.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := &http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusFound, res.StatusCode)
location, err := res.Location()
require.NoError(t, err)
assert.Equal(t, "http://example.com/redirect-test", location.String())
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.NotEmpty(t, string(body))
}
func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := w.Header()
for _, header := range forward.HopHeaders {
if header == forward.TransferEncoding {
headers.Add(header, "identity")
} else {
headers.Add(header, "test")
}
}
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
}))
defer authTs.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
assert.NoError(t, err, "there should be no error")
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := &http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
for _, header := range forward.HopHeaders {
assert.Equal(t, "", res.Header.Get(header), "hop-by-hop header '%s' mustn't be set", header)
}
location, err := res.Location()
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.NotEmpty(t, string(body), "there should be something in the body")
}
func TestForwardAuthFailResponseHeaders(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
http.SetCookie(w, cookie)
w.Header().Add("X-Foo", "bar")
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer authTs.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
require.Len(t, res.Cookies(), 1)
for _, cookie := range res.Cookies() {
assert.Equal(t, "testing", cookie.Value)
}
expectedHeaders := http.Header{
"Content-Length": []string{"10"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"X-Foo": []string{"bar"},
"Set-Cookie": []string{"example=testing; Path=/"},
"X-Content-Type-Options": []string{"nosniff"},
}
assert.Len(t, res.Header, 6)
for key, value := range expectedHeaders {
assert.Equal(t, value, res.Header[key])
}
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(body))
}
func Test_writeHeader(t *testing.T) {
testCases := []struct {
name string
headers map[string]string
trustForwardHeader bool
emptyHost bool
expectedHeaders map[string]string
checkForUnexpectedHeaders bool
}{
{
name: "trust Forward Header",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
},
{
name: "not trust Forward Header",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "foo.bar",
},
},
{
name: "trust Forward Header with empty Host",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: true,
emptyHost: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
},
{
name: "not trust Forward Header with empty Host",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: false,
emptyHost: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "",
},
},
{
name: "trust Forward Header with forwarded URI",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
},
{
name: "not trust Forward Header with forward requested URI",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "foo.bar",
"X-Forwarded-Uri": "/path?q=1",
},
}, {
name: "trust Forward Header with forwarded request Method",
headers: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
},
{
name: "not trust Forward Header with forward request Method",
headers: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"X-Forwarded-Method": "GET",
},
},
{
name: "remove hop-by-hop headers",
headers: map[string]string{
forward.Connection: "Connection",
forward.KeepAlive: "KeepAlive",
forward.ProxyAuthenticate: "ProxyAuthenticate",
forward.ProxyAuthorization: "ProxyAuthorization",
forward.Te: "Te",
forward.Trailers: "Trailers",
forward.TransferEncoding: "TransferEncoding",
forward.Upgrade: "Upgrade",
"X-CustomHeader": "CustomHeader",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"X-CustomHeader": "CustomHeader",
"X-Forwarded-Proto": "http",
"X-Forwarded-Host": "foo.bar",
"X-Forwarded-Uri": "/path?q=1",
"X-Forwarded-Method": "GET",
},
checkForUnexpectedHeaders: true,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
for key, value := range test.headers {
req.Header.Set(key, value)
}
if test.emptyHost {
req.Host = ""
}
forwardReq := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
writeHeader(req, forwardReq, test.trustForwardHeader)
actualHeaders := forwardReq.Header
expectedHeaders := test.expectedHeaders
for key, value := range expectedHeaders {
assert.Equal(t, value, actualHeaders.Get(key))
actualHeaders.Del(key)
}
if test.checkForUnexpectedHeaders {
for key := range actualHeaders {
assert.Fail(t, "Unexpected header found", key)
}
}
})
}
}

View file

@ -0,0 +1,54 @@
package buffering
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
oxybuffer "github.com/vulcand/oxy/buffer"
)
const (
typeName = "Buffer"
)
type buffer struct {
name string
buffer *oxybuffer.Buffer
}
// New creates a buffering middleware.
func New(ctx context.Context, next http.Handler, config config.Buffering, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
logger.Debug("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'",
config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes, config.MaxResponseBodyBytes, config.RetryExpression)
oxyBuffer, err := oxybuffer.New(
next,
oxybuffer.MemRequestBodyBytes(config.MemRequestBodyBytes),
oxybuffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes),
oxybuffer.MemResponseBodyBytes(config.MemResponseBodyBytes),
oxybuffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes),
oxybuffer.CondSetter(len(config.RetryExpression) > 0, oxybuffer.Retry(config.RetryExpression)),
)
if err != nil {
return nil, err
}
return &buffer{
name: name,
buffer: oxyBuffer,
}, nil
}
func (b *buffer) GetTracingInformation() (string, ext.SpanKindEnum) {
return b.name, tracing.SpanKindNoneEnum
}
func (b *buffer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
b.buffer.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,26 @@
package chain
import (
"context"
"net/http"
"github.com/containous/alice"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
)
const (
typeName = "Chain"
)
type chainBuilder interface {
BuildChain(ctx context.Context, middlewares []string) *alice.Chain
}
// New creates a chain middleware
func New(ctx context.Context, next http.Handler, config config.Chain, builder chainBuilder, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
middlewareChain := builder.BuildChain(ctx, config.Middlewares)
return middlewareChain.Then(next)
}

View file

@ -0,0 +1,61 @@
package circuitbreaker
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/cbreaker"
)
const (
typeName = "CircuitBreaker"
)
type circuitBreaker struct {
circuitBreaker *cbreaker.CircuitBreaker
name string
}
// New creates a new circuit breaker middleware.
func New(ctx context.Context, next http.Handler, confCircuitBreaker config.CircuitBreaker, name string) (http.Handler, error) {
expression := confCircuitBreaker.Expression
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
logger.Debug("Setting up with expression: %s", expression)
oxyCircuitBreaker, err := cbreaker.New(next, expression, createCircuitBreakerOptions(expression))
if err != nil {
return nil, err
}
return &circuitBreaker{
circuitBreaker: oxyCircuitBreaker,
name: name,
}, nil
}
// NewCircuitBreakerOptions returns a new CircuitBreakerOption
func createCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
return cbreaker.Fallback(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
tracing.SetErrorWithEvent(req, "blocked by circuit-breaker (%q)", expression)
rw.WriteHeader(http.StatusServiceUnavailable)
if _, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
log.FromContext(req.Context()).Error(err)
}
}))
}
func (c *circuitBreaker) GetTracingInformation() (string, ext.SpanKindEnum) {
return c.name, tracing.SpanKindNoneEnum
}
func (c *circuitBreaker) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
middlewares.GetLogger(req.Context(), c.name, typeName).Debug("Entering middleware")
c.circuitBreaker.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,58 @@
package compress
import (
"compress/gzip"
"context"
"net/http"
"strings"
"github.com/NYTimes/gziphandler"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/sirupsen/logrus"
)
const (
typeName = "Compress"
)
// Compress is a middleware that allows to compress the response.
type compress struct {
next http.Handler
name string
}
// New creates a new compress middleware.
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &compress{
next: next,
name: name,
}, nil
}
func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/grpc") {
c.next.ServeHTTP(rw, req)
} else {
gzipHandler(c.next, middlewares.GetLogger(req.Context(), c.name, typeName)).ServeHTTP(rw, req)
}
}
func (c *compress) GetTracingInformation() (string, ext.SpanKindEnum) {
return c.name, tracing.SpanKindNoneEnum
}
func gzipHandler(h http.Handler, logger logrus.FieldLogger) http.Handler {
wrapper, err := gziphandler.GzipHandlerWithOpts(
gziphandler.CompressionLevel(gzip.DefaultCompression),
gziphandler.MinSize(gziphandler.DefaultMinSize))
if err != nil {
logger.Error(err)
}
return wrapper(h)
}

View file

@ -0,0 +1,260 @@
package compress
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/NYTimes/gziphandler"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
acceptEncodingHeader = "Accept-Encoding"
contentEncodingHeader = "Content-Encoding"
contentTypeHeader = "Content-Type"
varyHeader = "Vary"
gzipValue = "gzip"
)
func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
baseBody := generateBytes(gziphandler.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(baseBody)
assert.NoError(t, err)
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
if assert.ObjectsAreEqualValues(rw.Body.Bytes(), baseBody) {
assert.Fail(t, "expected a compressed body", "got %v", rw.Body.Bytes())
}
}
func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
fakeCompressedBody := generateBytes(gziphandler.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
_, err := rw.Write(fakeCompressedBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
assert.EqualValues(t, rw.Body.Bytes(), fakeCompressedBody)
}
func TestShouldNotCompressWhenNoAcceptEncodingHeader(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
fakeBody := generateBytes(gziphandler.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(fakeBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
assert.EqualValues(t, rw.Body.Bytes(), fakeBody)
}
func TestShouldNotCompressWhenGRPC(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(contentTypeHeader, "application/grpc")
baseBody := generateBytes(gziphandler.DefaultMinSize)
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(baseBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
assert.Empty(t, rw.Header().Get(acceptEncodingHeader))
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
assert.EqualValues(t, rw.Body.Bytes(), baseBody)
}
func TestIntegrationShouldNotCompress(t *testing.T) {
fakeCompressedBody := generateBytes(100000)
testCases := []struct {
name string
handler http.Handler
expectedStatusCode int
}{
{
name: "when content already compressed",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
_, err := rw.Write(fakeCompressedBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusOK,
},
{
name: "when content already compressed and status code Created",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusCreated)
_, err := rw.Write(fakeCompressedBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusCreated,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
compress := &compress{next: test.handler}
ts := httptest.NewServer(compress)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.EqualValues(t, fakeCompressedBody, body)
})
}
}
func TestShouldWriteHeaderWhenFlush(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusUnauthorized)
rw.(http.Flusher).Flush()
_, err := rw.Write([]byte("short"))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
handler := &compress{next: next}
ts := httptest.NewServer(handler)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
}
func TestIntegrationShouldCompress(t *testing.T) {
fakeBody := generateBytes(100000)
testCases := []struct {
name string
handler http.Handler
expectedStatusCode int
}{
{
name: "when AcceptEncoding header is present",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(fakeBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusOK,
},
{
name: "when AcceptEncoding header is present and status code Created",
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusCreated)
_, err := rw.Write(fakeBody)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}),
expectedStatusCode: http.StatusCreated,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
compress := &compress{next: test.handler}
ts := httptest.NewServer(compress)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
if assert.ObjectsAreEqualValues(body, fakeBody) {
assert.Fail(t, "expected a compressed body", "got %v", body)
}
})
}
}
func generateBytes(len int) []byte {
var value []byte
for i := 0; i < len; i++ {
value = append(value, 0x61+byte(i))
}
return value
}

View file

@ -0,0 +1,248 @@
package customerrors
import (
"bufio"
"bytes"
"context"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/sirupsen/logrus"
"github.com/vulcand/oxy/utils"
)
// Compile time validation that the response recorder implements http interfaces correctly.
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{}
const (
typeName = "customError"
backendURL = "http://0.0.0.0"
)
type serviceBuilder interface {
BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
}
// customErrors is a middleware that provides the custom error pages..
type customErrors struct {
name string
next http.Handler
backendHandler http.Handler
httpCodeRanges types.HTTPCodeRanges
backendQuery string
}
// New creates a new custom error pages middleware.
func New(ctx context.Context, next http.Handler, config config.ErrorPage, serviceBuilder serviceBuilder, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
httpCodeRanges, err := types.NewHTTPCodeRanges(config.Status)
if err != nil {
return nil, err
}
backend, err := serviceBuilder.BuildHTTP(ctx, config.Service, nil)
if err != nil {
return nil, err
}
return &customErrors{
name: name,
next: next,
backendHandler: backend,
httpCodeRanges: httpCodeRanges,
backendQuery: config.Query,
}, nil
}
func (c *customErrors) GetTracingInformation() (string, ext.SpanKindEnum) {
return c.name, tracing.SpanKindNoneEnum
}
func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), c.name, typeName)
if c.backendHandler == nil {
logger.Error("Error pages: no backend handler.")
tracing.SetErrorWithEvent(req, "Error pages: no backend handler.")
c.next.ServeHTTP(rw, req)
return
}
recorder := newResponseRecorder(rw, middlewares.GetLogger(context.Background(), "test", typeName))
c.next.ServeHTTP(recorder, req)
// check the recorder code against the configured http status code ranges
for _, block := range c.httpCodeRanges {
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
logger.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
var query string
if len(c.backendQuery) > 0 {
query = "/" + strings.TrimPrefix(c.backendQuery, "/")
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
}
pageReq, err := newRequest(backendURL + query)
if err != nil {
logger.Error(err)
rw.WriteHeader(recorder.GetCode())
_, err = fmt.Fprint(rw, http.StatusText(recorder.GetCode()))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
return
}
recorderErrorPage := newResponseRecorder(rw, middlewares.GetLogger(context.Background(), "test", typeName))
utils.CopyHeaders(pageReq.Header, req.Header)
c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
utils.CopyHeaders(rw.Header(), recorderErrorPage.Header())
rw.WriteHeader(recorder.GetCode())
if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
logger.Error(err)
}
return
}
}
// did not catch a configured status code so proceed with the request
utils.CopyHeaders(rw.Header(), recorder.Header())
rw.WriteHeader(recorder.GetCode())
_, err := rw.Write(recorder.GetBody().Bytes())
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func newRequest(baseURL string) (*http.Request, error) {
u, err := url.Parse(baseURL)
if err != nil {
return nil, fmt.Errorf("error pages: error when parse URL: %v", err)
}
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
if err != nil {
return nil, fmt.Errorf("error pages: error when create query: %v", err)
}
req.RequestURI = u.RequestURI()
return req, nil
}
type responseRecorder interface {
http.ResponseWriter
http.Flusher
GetCode() int
GetBody() *bytes.Buffer
IsStreamingResponseStarted() bool
}
// newResponseRecorder returns an initialized responseRecorder.
func newResponseRecorder(rw http.ResponseWriter, logger logrus.FieldLogger) responseRecorder {
recorder := &responseRecorderWithoutCloseNotify{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
Code: http.StatusOK,
responseWriter: rw,
logger: logger,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &responseRecorderWithCloseNotify{recorder}
}
return recorder
}
// responseRecorderWithoutCloseNotify is an implementation of http.ResponseWriter that
// records its mutations for later inspection.
type responseRecorderWithoutCloseNotify struct {
Code int // the HTTP response code from WriteHeader
HeaderMap http.Header // the HTTP response headers
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
responseWriter http.ResponseWriter
err error
streamingResponseStarted bool
logger logrus.FieldLogger
}
type responseRecorderWithCloseNotify struct {
*responseRecorderWithoutCloseNotify
}
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone away.
func (r *responseRecorderWithCloseNotify) CloseNotify() <-chan bool {
return r.responseWriter.(http.CloseNotifier).CloseNotify()
}
// Header returns the response headers.
func (r *responseRecorderWithoutCloseNotify) Header() http.Header {
if r.HeaderMap == nil {
r.HeaderMap = make(http.Header)
}
return r.HeaderMap
}
func (r *responseRecorderWithoutCloseNotify) GetCode() int {
return r.Code
}
func (r *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer {
return r.Body
}
func (r *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool {
return r.streamingResponseStarted
}
// Write always succeeds and writes to rw.Body, if not nil.
func (r *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) {
if r.err != nil {
return 0, r.err
}
return r.Body.Write(buf)
}
// WriteHeader sets rw.Code.
func (r *responseRecorderWithoutCloseNotify) WriteHeader(code int) {
r.Code = code
}
// Hijack hijacks the connection
func (r *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.responseWriter.(http.Hijacker).Hijack()
}
// Flush sends any buffered data to the client.
func (r *responseRecorderWithoutCloseNotify) Flush() {
if !r.streamingResponseStarted {
utils.CopyHeaders(r.responseWriter.Header(), r.Header())
r.responseWriter.WriteHeader(r.Code)
r.streamingResponseStarted = true
}
_, err := r.responseWriter.Write(r.Body.Bytes())
if err != nil {
r.logger.Errorf("Error writing response in responseRecorder: %v", err)
r.err = err
}
r.Body.Reset()
if flusher, ok := r.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

View file

@ -0,0 +1,176 @@
package customerrors
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestHandler(t *testing.T) {
testCases := []struct {
desc string
errorPage *config.ErrorPage
backendCode int
backendErrorHandler http.HandlerFunc
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
desc: "no error",
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusOK,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
},
},
{
desc: "in the range",
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusInternalServerError,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My error page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "not in the range",
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusBadGateway,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
},
},
{
desc: "query replacement",
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503-503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "Single code",
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
serviceBuilderMock := &mockServiceBuilder{handler: test.backendErrorHandler}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
recorder := httptest.NewRecorder()
errorPageHandler.ServeHTTP(recorder, req)
test.validate(t, recorder)
})
}
}
type mockServiceBuilder struct {
handler http.Handler
}
func (m *mockServiceBuilder) BuildHTTP(_ context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
return m.handler, nil
}
func TestNewResponseRecorder(t *testing.T) {
testCases := []struct {
desc string
rw http.ResponseWriter
expected http.ResponseWriter
}{
{
desc: "Without Close Notify",
rw: httptest.NewRecorder(),
expected: &responseRecorderWithoutCloseNotify{},
},
{
desc: "With Close Notify",
rw: &mockRWCloseNotify{},
expected: &responseRecorderWithCloseNotify{},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rec := newResponseRecorder(test.rw, middlewares.GetLogger(context.Background(), "test", typeName))
assert.IsType(t, rec, test.expected)
})
}
}
type mockRWCloseNotify struct{}
func (m *mockRWCloseNotify) CloseNotify() <-chan bool {
panic("implement me")
}
func (m *mockRWCloseNotify) Header() http.Header {
panic("implement me")
}
func (m *mockRWCloseNotify) Write([]byte) (int, error) {
panic("implement me")
}
func (m *mockRWCloseNotify) WriteHeader(int) {
panic("implement me")
}

View file

@ -0,0 +1,33 @@
package emptybackendhandler
import (
"net/http"
"github.com/containous/traefik/pkg/healthcheck"
)
// EmptyBackend is a middleware that checks whether the current Backend
// has at least one active Server in respect to the healthchecks and if this
// is not the case, it will stop the middleware chain and respond with 503.
type emptyBackend struct {
next healthcheck.BalancerHandler
}
// New creates a new EmptyBackend middleware.
func New(lb healthcheck.BalancerHandler) http.Handler {
return &emptyBackend{next: lb}
}
// ServeHTTP responds with 503 when there is no active Server and otherwise
// invokes the next handler in the middleware chain.
func (e *emptyBackend) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if len(e.next.Servers()) == 0 {
rw.WriteHeader(http.StatusServiceUnavailable)
_, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
} else {
e.next.ServeHTTP(rw, req)
}
}

View file

@ -0,0 +1,81 @@
package emptybackendhandler
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/vulcand/oxy/roundrobin"
)
func TestEmptyBackendHandler(t *testing.T) {
testCases := []struct {
amountServer int
expectedStatusCode int
}{
{
amountServer: 0,
expectedStatusCode: http.StatusServiceUnavailable,
},
{
amountServer: 1,
expectedStatusCode: http.StatusOK,
},
}
for _, test := range testCases {
test := test
t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) {
t.Parallel()
handler := New(&healthCheckLoadBalancer{amountServer: test.amountServer})
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler.ServeHTTP(recorder, req)
assert.Equal(t, test.expectedStatusCode, recorder.Result().StatusCode)
})
}
}
type healthCheckLoadBalancer struct {
amountServer int
}
func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func (lb *healthCheckLoadBalancer) Servers() []*url.URL {
servers := make([]*url.URL, lb.amountServer)
for i := 0; i < lb.amountServer; i++ {
servers = append(servers, testhelpers.MustParseURL("http://localhost"))
}
return servers
}
func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error {
return nil
}
func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
return nil
}
func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) {
return 0, false
}
func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) {
return nil, nil
}
func (lb *healthCheckLoadBalancer) Next() http.Handler {
return nil
}

View file

@ -0,0 +1,51 @@
package forwardedheaders
import (
"net/http"
"github.com/containous/traefik/pkg/ip"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
// XForwarded filter for XForwarded headers.
type XForwarded struct {
insecure bool
trustedIps []string
ipChecker *ip.Checker
next http.Handler
}
// NewXForwarded creates a new XForwarded.
func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XForwarded, error) {
var ipChecker *ip.Checker
if len(trustedIps) > 0 {
var err error
ipChecker, err = ip.NewChecker(trustedIps)
if err != nil {
return nil, err
}
}
return &XForwarded{
insecure: insecure,
trustedIps: trustedIps,
ipChecker: ipChecker,
next: next,
}, nil
}
func (x *XForwarded) isTrustedIP(ip string) bool {
if x.ipChecker == nil {
return false
}
return x.ipChecker.IsAuthorized(ip) == nil
}
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
utils.RemoveHeaders(r.Header, forward.XHeaders...)
}
x.next.ServeHTTP(w, r)
}

View file

@ -0,0 +1,128 @@
package forwardedheaders
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServeHTTP(t *testing.T) {
testCases := []struct {
desc string
insecure bool
trustedIps []string
incomingHeaders map[string]string
remoteAddr string
expectedHeaders map[string]string
}{
{
desc: "all Empty",
insecure: true,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure true with incoming X-Forwarded-For",
insecure: true,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For",
insecure: false,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips",
insecure: false,
trustedIps: []string{"10.0.1.100"},
remoteAddr: "10.0.1.100:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips",
insecure: false,
trustedIps: []string{"10.0.1.100"},
remoteAddr: "10.0.1.101:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips CIDR",
insecure: false,
trustedIps: []string{"1.2.3.4/24"},
remoteAddr: "1.2.3.156:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips CIDR",
insecure: false,
trustedIps: []string{"1.2.3.4/24"},
remoteAddr: "10.0.1.101:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)
req.RemoteAddr = test.remoteAddr
for k, v := range test.incomingHeaders {
req.Header.Set(k, v)
}
m, err := NewXForwarded(test.insecure, test.trustedIps, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
require.NoError(t, err)
m.ServeHTTP(nil, req)
for k, v := range test.expectedHeaders {
assert.Equal(t, v, req.Header.Get(k))
}
})
}
}

View file

@ -0,0 +1,35 @@
package middlewares
import (
"net/http"
"github.com/containous/traefik/pkg/safe"
)
// HTTPHandlerSwitcher allows hot switching of http.ServeMux
type HTTPHandlerSwitcher struct {
handler *safe.Safe
}
// NewHandlerSwitcher builds a new instance of HTTPHandlerSwitcher
func NewHandlerSwitcher(newHandler http.Handler) (hs *HTTPHandlerSwitcher) {
return &HTTPHandlerSwitcher{
handler: safe.New(newHandler),
}
}
func (h *HTTPHandlerSwitcher) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
handlerBackup := h.handler.Get().(http.Handler)
handlerBackup.ServeHTTP(rw, req)
}
// GetHandler returns the current http.ServeMux
func (h *HTTPHandlerSwitcher) GetHandler() (newHandler http.Handler) {
handler := h.handler.Get().(http.Handler)
return handler
}
// UpdateHandler safely updates the current http.ServeMux with a new one
func (h *HTTPHandlerSwitcher) UpdateHandler(newHandler http.Handler) {
h.handler.Set(newHandler)
}

View file

@ -0,0 +1,134 @@
// Package headers Middleware based on https://github.com/unrolled/secure.
package headers
import (
"context"
"errors"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/unrolled/secure"
)
const (
typeName = "Headers"
)
type headers struct {
name string
handler http.Handler
}
// New creates a Headers middleware.
func New(ctx context.Context, next http.Handler, config config.Headers, name string) (http.Handler, error) {
// HeaderMiddleware -> SecureMiddleWare -> next
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
if !config.HasSecureHeadersDefined() && !config.HasCustomHeadersDefined() {
return nil, errors.New("headers configuration not valid")
}
var handler http.Handler
nextHandler := next
if config.HasSecureHeadersDefined() {
logger.Debug("Setting up secureHeaders from %v", config)
handler = newSecure(next, config)
nextHandler = handler
}
if config.HasCustomHeadersDefined() {
logger.Debug("Setting up customHeaders from %v", config)
handler = newHeader(nextHandler, config)
}
return &headers{
handler: handler,
name: name,
}, nil
}
func (h *headers) GetTracingInformation() (string, ext.SpanKindEnum) {
return h.name, tracing.SpanKindNoneEnum
}
func (h *headers) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
h.handler.ServeHTTP(rw, req)
}
type secureHeader struct {
next http.Handler
secure *secure.Secure
}
// newSecure constructs a new secure instance with supplied options.
func newSecure(next http.Handler, headers config.Headers) *secureHeader {
opt := secure.Options{
BrowserXssFilter: headers.BrowserXSSFilter,
ContentTypeNosniff: headers.ContentTypeNosniff,
ForceSTSHeader: headers.ForceSTSHeader,
FrameDeny: headers.FrameDeny,
IsDevelopment: headers.IsDevelopment,
SSLRedirect: headers.SSLRedirect,
SSLForceHost: headers.SSLForceHost,
SSLTemporaryRedirect: headers.SSLTemporaryRedirect,
STSIncludeSubdomains: headers.STSIncludeSubdomains,
STSPreload: headers.STSPreload,
ContentSecurityPolicy: headers.ContentSecurityPolicy,
CustomBrowserXssValue: headers.CustomBrowserXSSValue,
CustomFrameOptionsValue: headers.CustomFrameOptionsValue,
PublicKey: headers.PublicKey,
ReferrerPolicy: headers.ReferrerPolicy,
SSLHost: headers.SSLHost,
AllowedHosts: headers.AllowedHosts,
HostsProxyHeaders: headers.HostsProxyHeaders,
SSLProxyHeaders: headers.SSLProxyHeaders,
STSSeconds: headers.STSSeconds,
}
return &secureHeader{
next: next,
secure: secure.New(opt),
}
}
func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP)
}
// Header is a middleware that helps setup a few basic security features. A single headerOptions struct can be
// provided to configure which features should be enabled, and the ability to override a few of the default values.
type header struct {
next http.Handler
// If Custom request headers are set, these will be added to the request
customRequestHeaders map[string]string
}
// NewHeader constructs a new header instance from supplied frontend header struct.
func newHeader(next http.Handler, headers config.Headers) *header {
return &header{
next: next,
customRequestHeaders: headers.CustomRequestHeaders,
}
}
func (s *header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
s.modifyRequestHeaders(req)
s.next.ServeHTTP(rw, req)
}
// modifyRequestHeaders set or delete request headers.
func (s *header) modifyRequestHeaders(req *http.Request) {
// Loop through Custom request headers
for header, value := range s.customRequestHeaders {
if value == "" {
req.Header.Del(header)
} else {
req.Header.Set(header, value)
}
}
}

View file

@ -0,0 +1,190 @@
package headers
// Middleware tests based on https://github.com/unrolled/secure
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCustomRequestHeader(t *testing.T) {
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
header := newHeader(emptyHandler, config.Headers{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
}
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
header := newHeader(emptyHandler, config.Headers{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
header = newHeader(emptyHandler, config.Headers{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "",
},
})
header.ServeHTTP(res, req)
assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"))
}
func TestSecureHeader(t *testing.T) {
testCases := []struct {
desc string
fromHost string
expected int
}{
{
desc: "Should accept the request when given a host that is in the list",
fromHost: "foo.com",
expected: http.StatusOK,
},
{
desc: "Should refuse the request when no host is given",
fromHost: "",
expected: http.StatusInternalServerError,
},
{
desc: "Should refuse the request when no matching host is given",
fromHost: "boo.com",
expected: http.StatusInternalServerError,
},
}
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
header, err := New(context.Background(), emptyHandler, config.Headers{
AllowedHosts: []string{"foo.com", "bar.com"},
}, "foo")
require.NoError(t, err)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
req.Host = test.fromHost
header.ServeHTTP(res, req)
assert.Equal(t, test.expected, res.Code)
})
}
}
func TestSSLForceHost(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
_, _ = rw.Write([]byte("OK"))
})
testCases := []struct {
desc string
host string
secureMiddleware *secureHeader
expected int
}{
{
desc: "http should return a 301",
host: "http://powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: true,
SSLHost: "powpow.example.com",
}),
expected: http.StatusMovedPermanently,
},
{
desc: "http sub domain should return a 301",
host: "http://www.powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: true,
SSLHost: "powpow.example.com",
}),
expected: http.StatusMovedPermanently,
},
{
desc: "https should return a 200",
host: "https://powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: true,
SSLHost: "powpow.example.com",
}),
expected: http.StatusOK,
},
{
desc: "https sub domain should return a 301",
host: "https://www.powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: true,
SSLHost: "powpow.example.com",
}),
expected: http.StatusMovedPermanently,
},
{
desc: "http without force host and sub domain should return a 301",
host: "http://www.powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: false,
SSLHost: "powpow.example.com",
}),
expected: http.StatusMovedPermanently,
},
{
desc: "https without force host and sub domain should return a 301",
host: "https://www.powpow.example.com",
secureMiddleware: newSecure(next, config.Headers{
SSLRedirect: true,
SSLForceHost: false,
SSLHost: "powpow.example.com",
}),
expected: http.StatusOK,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, test.host, nil)
rw := httptest.NewRecorder()
test.secureMiddleware.ServeHTTP(rw, req)
assert.Equal(t, test.expected, rw.Result().StatusCode)
})
}
}

View file

@ -0,0 +1,85 @@
package ipwhitelist
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/ip"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
typeName = "IPWhiteLister"
)
// ipWhiteLister is a middleware that provides Checks of the Requesting IP against a set of Whitelists
type ipWhiteLister struct {
next http.Handler
whiteLister *ip.Checker
strategy ip.Strategy
name string
}
// New builds a new IPWhiteLister given a list of CIDR-Strings to whitelist
func New(ctx context.Context, next http.Handler, config config.IPWhiteList, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
if len(config.SourceRange) == 0 {
return nil, errors.New("sourceRange is empty, IPWhiteLister not created")
}
checker, err := ip.NewChecker(config.SourceRange)
if err != nil {
return nil, fmt.Errorf("cannot parse CIDR whitelist %s: %v", config.SourceRange, err)
}
strategy, err := config.IPStrategy.Get()
if err != nil {
return nil, err
}
logger.Debugf("Setting up IPWhiteLister with sourceRange: %s", config.SourceRange)
return &ipWhiteLister{
strategy: strategy,
whiteLister: checker,
next: next,
name: name,
}, nil
}
func (wl *ipWhiteLister) GetTracingInformation() (string, ext.SpanKindEnum) {
return wl.name, tracing.SpanKindNoneEnum
}
func (wl *ipWhiteLister) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), wl.name, typeName)
err := wl.whiteLister.IsAuthorized(wl.strategy.GetIP(req))
if err != nil {
logMessage := fmt.Sprintf("rejecting request %+v: %v", req, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
reject(logger, rw)
return
}
logger.Debugf("Accept %s: %+v", wl.strategy.GetIP(req), req)
wl.next.ServeHTTP(rw, req)
}
func reject(logger logrus.FieldLogger, rw http.ResponseWriter) {
statusCode := http.StatusForbidden
rw.WriteHeader(statusCode)
_, err := rw.Write([]byte(http.StatusText(statusCode)))
if err != nil {
logger.Error(err)
}
}

View file

@ -0,0 +1,100 @@
package ipwhitelist
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewIPWhiteLister(t *testing.T) {
testCases := []struct {
desc string
whiteList config.IPWhiteList
expectedError bool
}{
{
desc: "invalid IP",
whiteList: config.IPWhiteList{
SourceRange: []string{"foo"},
},
expectedError: true,
},
{
desc: "valid IP",
whiteList: config.IPWhiteList{
SourceRange: []string{"10.10.10.10"},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
if test.expectedError {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.NotNil(t, whiteLister)
}
})
}
}
func TestIPWhiteLister_ServeHTTP(t *testing.T) {
testCases := []struct {
desc string
whiteList config.IPWhiteList
remoteAddr string
expected int
}{
{
desc: "authorized with remote address",
whiteList: config.IPWhiteList{
SourceRange: []string{"20.20.20.20"},
},
remoteAddr: "20.20.20.20:1234",
expected: 200,
},
{
desc: "non authorized with remote address",
whiteList: config.IPWhiteList{
SourceRange: []string{"20.20.20.20"},
},
remoteAddr: "20.20.20.21:1234",
expected: 403,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
require.NoError(t, err)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://10.10.10.10", nil)
if len(test.remoteAddr) > 0 {
req.RemoteAddr = test.remoteAddr
}
whiteLister.ServeHTTP(recorder, req)
assert.Equal(t, test.expected, recorder.Code)
})
}
}

View file

@ -0,0 +1,48 @@
package maxconnection
import (
"context"
"fmt"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/connlimit"
"github.com/vulcand/oxy/utils"
)
const (
typeName = "MaxConnection"
)
type maxConnection struct {
handler http.Handler
name string
}
// New creates a max connection middleware.
func New(ctx context.Context, next http.Handler, maxConns config.MaxConn, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
handler, err := connlimit.New(next, extractFunc, maxConns.Amount)
if err != nil {
return nil, fmt.Errorf("error creating connection limit: %v", err)
}
return &maxConnection{handler: handler, name: name}, nil
}
func (mc *maxConnection) GetTracingInformation() (string, ext.SpanKindEnum) {
return mc.name, tracing.SpanKindNoneEnum
}
func (mc *maxConnection) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
mc.handler.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,13 @@
package middlewares
import (
"context"
"github.com/containous/traefik/pkg/log"
"github.com/sirupsen/logrus"
)
// GetLogger creates a logger configured with the middleware fields.
func GetLogger(ctx context.Context, middleware string, middlewareType string) logrus.FieldLogger {
return log.FromContext(ctx).WithField(log.MiddlewareName, middleware).WithField(log.MiddlewareType, middlewareType)
}

View file

@ -0,0 +1,294 @@
package passtlsclientcert
import (
"context"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/sirupsen/logrus"
)
const (
xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
xForwardedTLSClientCertInfo = "X-Forwarded-Tls-Client-Cert-Info"
typeName = "PassClientTLSCert"
)
var attributeTypeNames = map[string]string{
"0.9.2342.19200300.100.1.25": "DC", // Domain component OID - RFC 2247
}
// DistinguishedNameOptions is a struct for specifying the configuration for the distinguished name info.
type DistinguishedNameOptions struct {
CommonName bool
CountryName bool
DomainComponent bool
LocalityName bool
OrganizationName bool
SerialNumber bool
StateOrProvinceName bool
}
func newDistinguishedNameOptions(info *config.TLSCLientCertificateDNInfo) *DistinguishedNameOptions {
if info == nil {
return nil
}
return &DistinguishedNameOptions{
CommonName: info.CommonName,
CountryName: info.Country,
DomainComponent: info.DomainComponent,
LocalityName: info.Locality,
OrganizationName: info.Organization,
SerialNumber: info.SerialNumber,
StateOrProvinceName: info.Province,
}
}
// passTLSClientCert is a middleware that helps setup a few tls info features.
type passTLSClientCert struct {
next http.Handler
name string
pem bool // pass the sanitized pem to the backend in a specific header
info *tlsClientCertificateInfo // pass selected information from the client certificate
}
// New constructs a new PassTLSClientCert instance from supplied frontend header struct.
func New(ctx context.Context, next http.Handler, config config.PassTLSClientCert, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &passTLSClientCert{
next: next,
name: name,
pem: config.PEM,
info: newTLSClientInfo(config.Info),
}, nil
}
// tlsClientCertificateInfo is a struct for specifying the configuration for the passTLSClientCert middleware.
type tlsClientCertificateInfo struct {
notAfter bool
notBefore bool
sans bool
subject *DistinguishedNameOptions
issuer *DistinguishedNameOptions
}
func newTLSClientInfo(info *config.TLSClientCertificateInfo) *tlsClientCertificateInfo {
if info == nil {
return nil
}
return &tlsClientCertificateInfo{
issuer: newDistinguishedNameOptions(info.Issuer),
notAfter: info.NotAfter,
notBefore: info.NotBefore,
subject: newDistinguishedNameOptions(info.Subject),
sans: info.Sans,
}
}
func (p *passTLSClientCert) GetTracingInformation() (string, ext.SpanKindEnum) {
return p.name, tracing.SpanKindNoneEnum
}
func (p *passTLSClientCert) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), p.name, typeName)
p.modifyRequestHeaders(logger, req)
p.next.ServeHTTP(rw, req)
}
func getDNInfo(prefix string, options *DistinguishedNameOptions, cs *pkix.Name) string {
if options == nil {
return ""
}
content := &strings.Builder{}
// Manage non standard attributes
for _, name := range cs.Names {
// Domain Component - RFC 2247
if options.DomainComponent && attributeTypeNames[name.Type.String()] == "DC" {
content.WriteString(fmt.Sprintf("DC=%s,", name.Value))
}
}
if options.CountryName {
writeParts(content, cs.Country, "C")
}
if options.StateOrProvinceName {
writeParts(content, cs.Province, "ST")
}
if options.LocalityName {
writeParts(content, cs.Locality, "L")
}
if options.OrganizationName {
writeParts(content, cs.Organization, "O")
}
if options.SerialNumber {
writePart(content, cs.SerialNumber, "SN")
}
if options.CommonName {
writePart(content, cs.CommonName, "CN")
}
if content.Len() > 0 {
return prefix + `="` + strings.TrimSuffix(content.String(), ",") + `"`
}
return ""
}
func writeParts(content io.StringWriter, entries []string, prefix string) {
for _, entry := range entries {
writePart(content, entry, prefix)
}
}
func writePart(content io.StringWriter, entry string, prefix string) {
if len(entry) > 0 {
_, err := content.WriteString(fmt.Sprintf("%s=%s,", prefix, entry))
if err != nil {
log.Error(err)
}
}
}
// getXForwardedTLSClientCertInfo Build a string with the wanted client certificates information
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
func (p *passTLSClientCert) getXForwardedTLSClientCertInfo(certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
var values []string
var sans string
var nb string
var na string
if p.info != nil {
subject := getDNInfo("Subject", p.info.subject, &peerCert.Subject)
if len(subject) > 0 {
values = append(values, subject)
}
issuer := getDNInfo("Issuer", p.info.issuer, &peerCert.Issuer)
if len(issuer) > 0 {
values = append(values, issuer)
}
}
ci := p.info
if ci != nil {
if ci.notBefore {
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
values = append(values, nb)
}
if ci.notAfter {
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
values = append(values, na)
}
if ci.sans {
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
values = append(values, sans)
}
}
value := strings.Join(values, ",")
headerValues = append(headerValues, value)
}
return strings.Join(headerValues, ";")
}
// modifyRequestHeaders set the wanted headers with the certificates information.
func (p *passTLSClientCert) modifyRequestHeaders(logger logrus.FieldLogger, r *http.Request) {
if p.pem {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(logger, r.TLS.PeerCertificates))
} else {
logger.Warn("Try to extract certificate on a request without TLS")
}
}
if p.info != nil {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
headerContent := p.getXForwardedTLSClientCertInfo(r.TLS.PeerCertificates)
r.Header.Set(xForwardedTLSClientCertInfo, url.QueryEscape(headerContent))
} else {
logger.Warn("Try to extract certificate on a request without TLS")
}
}
}
// sanitize As we pass the raw certificates, remove the useless data and make it http request compliant.
func sanitize(cert []byte) string {
s := string(cert)
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
"-----END CERTIFICATE-----", "",
"\n", "")
cleaned := r.Replace(s)
return url.QueryEscape(cleaned)
}
// extractCertificate extract the certificate from the request.
func extractCertificate(logger logrus.FieldLogger, cert *x509.Certificate) string {
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
certPEM := pem.EncodeToMemory(&b)
if certPEM == nil {
logger.Error("Cannot extract the certificate content")
return ""
}
return sanitize(certPEM)
}
// getXForwardedTLSClientCert Build a string with the client certificates.
func getXForwardedTLSClientCert(logger logrus.FieldLogger, certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
headerValues = append(headerValues, extractCertificate(logger, peerCert))
}
return strings.Join(headerValues, ",")
}
// getSANs get the Subject Alternate Name values.
func getSANs(cert *x509.Certificate) []string {
var sans []string
if cert == nil {
return sans
}
sans = append(sans, cert.DNSNames...)
sans = append(sans, cert.EmailAddresses...)
var ips []string
for _, ip := range cert.IPAddresses {
ips = append(ips, ip.String())
}
sans = append(sans, ips...)
var uris []string
for _, uri := range cert.URIs {
uris = append(uris, uri.String())
}
return append(sans, uris...)
}

View file

@ -0,0 +1,663 @@
package passtlsclientcert
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"net"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/require"
)
const (
signingCA = `Certificate:
Data:
Version: 3 (0x2)
Serial Number: 2 (0x2)
Signature Algorithm: sha1WithRSAEncryption
Issuer: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Cheese Section, OU=Cheese Section 2, CN=Simple Root CA, CN=Simple Root CA 2, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Root State, ST=Root State 2/emailAddress=root@signing.com/emailAddress=root2@signing.com
Validity
Not Before: Dec 6 11:10:09 2018 GMT
Not After : Dec 5 11:10:09 2028 GMT
Subject: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Simple Signing Section, OU=Simple Signing Section 2, CN=Simple Signing CA, CN=Simple Signing CA 2, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Signing State, ST=Signing State 2/emailAddress=simple@signing.com/emailAddress=simple2@signing.com
Subject Public Key Info:
Public Key Algorithm: rsaEncryption
RSA Public-Key: (2048 bit)
Modulus:
00:c3:9d:9f:61:15:57:3f:78:cc:e7:5d:20:e2:3e:
2e:79:4a:c3:3a:0c:26:40:18:db:87:08:85:c2:f7:
af:87:13:1a:ff:67:8a:b7:2b:58:a7:cc:89:dd:77:
ff:5e:27:65:11:80:82:8f:af:a0:af:25:86:ec:a2:
4f:20:0e:14:15:16:12:d7:74:5a:c3:99:bd:3b:81:
c8:63:6f:fc:90:14:86:d2:39:ee:87:b2:ff:6d:a5:
69:da:ab:5a:3a:97:cd:23:37:6a:4b:ba:63:cd:a1:
a9:e6:79:aa:37:b8:d1:90:c9:24:b5:e8:70:fc:15:
ad:39:97:28:73:47:66:f6:22:79:5a:b0:03:83:8a:
f1:ca:ae:8b:50:1e:c8:fa:0d:9f:76:2e:00:c2:0e:
75:bc:47:5a:b6:d8:05:ed:5a:bc:6d:50:50:36:6b:
ab:ab:69:f6:9b:1b:6c:7e:a8:9f:b2:33:3a:3c:8c:
6d:5e:83:ce:17:82:9e:10:51:a6:39:ec:98:4e:50:
b7:b1:aa:8b:ac:bb:a1:60:1b:ea:31:3b:b8:0a:ea:
63:41:79:b5:ec:ee:19:e9:85:8e:f3:6d:93:80:da:
98:58:a2:40:93:a5:53:eb:1d:24:b6:66:07:ec:58:
10:63:e7:fa:6e:18:60:74:76:15:39:3c:f4:95:95:
7e:df
Exponent: 65537 (0x10001)
X509v3 extensions:
X509v3 Key Usage: critical
Certificate Sign, CRL Sign
X509v3 Basic Constraints: critical
CA:TRUE, pathlen:0
X509v3 Subject Key Identifier:
1E:52:A2:E8:54:D5:37:EB:D5:A8:1D:E4:C2:04:1D:37:E2:F7:70:03
X509v3 Authority Key Identifier:
keyid:36:70:35:AA:F0:F6:93:B2:86:5D:32:73:F9:41:5A:3F:3B:C8:BC:8B
Signature Algorithm: sha1WithRSAEncryption
76:f3:16:21:27:6d:a2:2e:e8:18:49:aa:54:1e:f8:3b:07:fa:
65:50:d8:1f:a2:cf:64:6c:15:e0:0f:c8:46:b2:d7:b8:0e:cd:
05:3b:06:fb:dd:c6:2f:01:ae:bd:69:d3:bb:55:47:a9:f6:e5:
ba:be:4b:45:fb:2e:3c:33:e0:57:d4:3e:8e:3e:11:f2:0a:f1:
7d:06:ab:04:2e:a5:76:20:c2:db:a4:68:5a:39:00:62:2a:1d:
c2:12:b1:90:66:8c:36:a8:fd:83:d1:1b:da:23:a7:1d:5b:e6:
9b:40:c4:78:25:c7:b7:6b:75:35:cf:bb:37:4a:4f:fc:7e:32:
1f:8c:cf:12:d2:c9:c8:99:d9:4a:55:0a:1e:ac:de:b4:cb:7c:
bf:c4:fb:60:2c:a8:f7:e7:63:5c:b0:1c:62:af:01:3c:fe:4d:
3c:0b:18:37:4c:25:fc:d0:b2:f6:b2:f1:c3:f4:0f:53:d6:1e:
b5:fa:bc:d8:ad:dd:1c:f5:45:9f:af:fe:0a:01:79:92:9a:d8:
71:db:37:f3:1e:bd:fb:c7:1e:0a:0f:97:2a:61:f3:7b:19:93:
9c:a6:8a:69:cd:b0:f5:91:02:a5:1b:10:f4:80:5d:42:af:4e:
82:12:30:3e:d3:a7:11:14:ce:50:91:04:80:d7:2a:03:ef:71:
10:b8:db:a5
-----BEGIN CERTIFICATE-----
MIIFzTCCBLWgAwIBAgIBAjANBgkqhkiG9w0BAQUFADCCAWQxEzARBgoJkiaJk/Is
ZAEZFgNvcmcxFjAUBgoJkiaJk/IsZAEZFgZjaGVlc2UxDzANBgNVBAoMBkNoZWVz
ZTERMA8GA1UECgwIQ2hlZXNlIDIxFzAVBgNVBAsMDkNoZWVzZSBTZWN0aW9uMRkw
FwYDVQQLDBBDaGVlc2UgU2VjdGlvbiAyMRcwFQYDVQQDDA5TaW1wbGUgUm9vdCBD
QTEZMBcGA1UEAwwQU2ltcGxlIFJvb3QgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNV
BAYTAlVTMREwDwYDVQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjETMBEGA1UE
CAwKUm9vdCBTdGF0ZTEVMBMGA1UECAwMUm9vdCBTdGF0ZSAyMR8wHQYJKoZIhvcN
AQkBFhByb290QHNpZ25pbmcuY29tMSAwHgYJKoZIhvcNAQkBFhFyb290MkBzaWdu
aW5nLmNvbTAeFw0xODEyMDYxMTEwMDlaFw0yODEyMDUxMTEwMDlaMIIBhDETMBEG
CgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEPMA0GA1UE
CgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2ltcGxlIFNp
Z25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2VjdGlvbiAy
MRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2ltcGxlIFNp
Z25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYDVQQHDAhU
T1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBTdGF0ZTEY
MBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJzaW1wbGVA
c2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmluZy5jb20w
ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDDnZ9hFVc/eMznXSDiPi55
SsM6DCZAGNuHCIXC96+HExr/Z4q3K1inzIndd/9eJ2URgIKPr6CvJYbsok8gDhQV
FhLXdFrDmb07gchjb/yQFIbSOe6Hsv9tpWnaq1o6l80jN2pLumPNoanmeao3uNGQ
ySS16HD8Fa05lyhzR2b2InlasAODivHKrotQHsj6DZ92LgDCDnW8R1q22AXtWrxt
UFA2a6urafabG2x+qJ+yMzo8jG1eg84Xgp4QUaY57JhOULexqousu6FgG+oxO7gK
6mNBebXs7hnphY7zbZOA2phYokCTpVPrHSS2ZgfsWBBj5/puGGB0dhU5PPSVlX7f
AgMBAAGjZjBkMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEAMB0G
A1UdDgQWBBQeUqLoVNU369WoHeTCBB034vdwAzAfBgNVHSMEGDAWgBQ2cDWq8PaT
soZdMnP5QVo/O8i8izANBgkqhkiG9w0BAQUFAAOCAQEAdvMWISdtoi7oGEmqVB74
Owf6ZVDYH6LPZGwV4A/IRrLXuA7NBTsG+93GLwGuvWnTu1VHqfblur5LRfsuPDPg
V9Q+jj4R8grxfQarBC6ldiDC26RoWjkAYiodwhKxkGaMNqj9g9Eb2iOnHVvmm0DE
eCXHt2t1Nc+7N0pP/H4yH4zPEtLJyJnZSlUKHqzetMt8v8T7YCyo9+djXLAcYq8B
PP5NPAsYN0wl/NCy9rLxw/QPU9Yetfq82K3dHPVFn6/+CgF5kprYcds38x69+8ce
Cg+XKmHzexmTnKaKac2w9ZECpRsQ9IBdQq9OghIwPtOnERTOUJEEgNcqA+9xELjb
pQ==
-----END CERTIFICATE-----
`
minimalCheeseCrt = `-----BEGIN CERTIFICATE-----
MIIEQDCCAygCFFRY0OBk/L5Se0IZRj3CMljawL2UMA0GCSqGSIb3DQEBCwUAMIIB
hDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEP
MA0GA1UECgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2lt
cGxlIFNpZ25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2Vj
dGlvbiAyMRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2lt
cGxlIFNpZ25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYD
VQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBT
dGF0ZTEYMBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJz
aW1wbGVAc2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmlu
Zy5jb20wHhcNMTgxMjA2MTExMDM2WhcNMjEwOTI1MTExMDM2WjAzMQswCQYDVQQG
EwJGUjETMBEGA1UECAwKU29tZS1TdGF0ZTEPMA0GA1UECgwGQ2hlZXNlMIIBIjAN
BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAskX/bUtwFo1gF2BTPNaNcTUMaRFu
FMZozK8IgLjccZ4kZ0R9oFO6Yp8Zl/IvPaf7tE26PI7XP7eHriUdhnQzX7iioDd0
RZa68waIhAGc+xPzRFrP3b3yj3S2a9Rve3c0K+SCV+EtKAwsxMqQDhoo9PcBfo5B
RHfht07uD5MncUcGirwN+/pxHV5xzAGPcc7On0/5L7bq/G+63nhu78zw9XyuLaHC
PM5VbOUvpyIESJHbMMzTdFGL8ob9VKO+Kr1kVGdEA9i8FLGl3xz/GBKuW/JD0xyW
DrU29mri5vYWHmkuv7ZWHGXnpXjTtPHwveE9/0/ArnmpMyR9JtqFr1oEvQIDAQAB
MA0GCSqGSIb3DQEBCwUAA4IBAQBHta+NWXI08UHeOkGzOTGRiWXsOH2dqdX6gTe9
xF1AIjyoQ0gvpoGVvlnChSzmlUj+vnx/nOYGIt1poE3hZA3ZHZD/awsvGyp3GwWD
IfXrEViSCIyF+8tNNKYyUcEO3xdAsAUGgfUwwF/mZ6MBV5+A/ZEEILlTq8zFt9dV
vdKzIt7fZYxYBBHFSarl1x8pDgWXlf3hAufevGJXip9xGYmznF0T5cq1RbWJ4be3
/9K7yuWhuBYC3sbTbCneHBa91M82za+PIISc1ygCYtWSBoZKSAqLk0rkZpHaekDP
WqeUSNGYV//RunTeuRDAf5OxehERb1srzBXhRZ3cZdzXbgR/
-----END CERTIFICATE-----
`
completeCheeseCrt = `Certificate:
Data:
Version: 3 (0x2)
Serial Number: 1 (0x1)
Signature Algorithm: sha1WithRSAEncryption
Issuer: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Simple Signing Section, OU=Simple Signing Section 2, CN=Simple Signing CA, CN=Simple Signing CA 2, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Signing State, ST=Signing State 2/emailAddress=simple@signing.com/emailAddress=simple2@signing.com
Validity
Not Before: Dec 6 11:10:16 2018 GMT
Not After : Dec 5 11:10:16 2020 GMT
Subject: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Simple Signing Section, OU=Simple Signing Section 2, CN=*.cheese.org, CN=*.cheese.com, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Cheese org state, ST=Cheese com state/emailAddress=cert@cheese.org/emailAddress=cert@scheese.com
Subject Public Key Info:
Public Key Algorithm: rsaEncryption
RSA Public-Key: (2048 bit)
Modulus:
00:de:77:fa:8d:03:70:30:39:dd:51:1b:cc:60:db:
a9:5a:13:b1:af:fe:2c:c6:38:9b:88:0a:0f:8e:d9:
1b:a1:1d:af:0d:66:e4:13:5b:bc:5d:36:92:d7:5e:
d0:fa:88:29:d3:78:e1:81:de:98:b2:a9:22:3f:bf:
8a:af:12:92:63:d4:a9:c3:f2:e4:7e:d2:dc:a2:c5:
39:1c:7a:eb:d7:12:70:63:2e:41:47:e0:f0:08:e8:
dc:be:09:01:ec:28:09:af:35:d7:79:9c:50:35:d1:
6b:e5:87:7b:34:f6:d2:31:65:1d:18:42:69:6c:04:
11:83:fe:44:ae:90:92:2d:0b:75:39:57:62:e6:17:
2f:47:2b:c7:53:dd:10:2d:c9:e3:06:13:d2:b9:ba:
63:2e:3c:7d:83:6b:d6:89:c9:cc:9d:4d:bf:9f:e8:
a3:7b:da:c8:99:2b:ba:66:d6:8e:f8:41:41:a0:c9:
d0:5e:c8:11:a4:55:4a:93:83:87:63:04:63:41:9c:
fb:68:04:67:c2:71:2f:f2:65:1d:02:5d:15:db:2c:
d9:04:69:85:c2:7d:0d:ea:3b:ac:85:f8:d4:8f:0f:
c5:70:b2:45:e1:ec:b2:54:0b:e9:f7:82:b4:9b:1b:
2d:b9:25:d4:ab:ca:8f:5b:44:3e:15:dd:b8:7f:b7:
ee:f9
Exponent: 65537 (0x10001)
X509v3 extensions:
X509v3 Key Usage: critical
Digital Signature, Key Encipherment
X509v3 Basic Constraints:
CA:FALSE
X509v3 Extended Key Usage:
TLS Web Server Authentication, TLS Web Client Authentication
X509v3 Subject Key Identifier:
94:BA:73:78:A2:87:FB:58:28:28:CF:98:3B:C2:45:70:16:6E:29:2F
X509v3 Authority Key Identifier:
keyid:1E:52:A2:E8:54:D5:37:EB:D5:A8:1D:E4:C2:04:1D:37:E2:F7:70:03
X509v3 Subject Alternative Name:
DNS:*.cheese.org, DNS:*.cheese.net, DNS:*.cheese.com, IP Address:10.0.1.0, IP Address:10.0.1.2, email:test@cheese.org, email:test@cheese.net
Signature Algorithm: sha1WithRSAEncryption
76:6b:05:b0:0e:34:11:b1:83:99:91:dc:ae:1b:e2:08:15:8b:
16:b2:9b:27:1c:02:ac:b5:df:1b:d0:d0:75:a4:2b:2c:5c:65:
ed:99:ab:f7:cd:fe:38:3f:c3:9a:22:31:1b:ac:8c:1c:c2:f9:
5d:d4:75:7a:2e:72:c7:85:a9:04:af:9f:2a:cc:d3:96:75:f0:
8e:c7:c6:76:48:ac:45:a4:b9:02:1e:2f:c0:15:c4:07:08:92:
cb:27:50:67:a1:c8:05:c5:3a:b3:a6:48:be:eb:d5:59:ab:a2:
1b:95:30:71:13:5b:0a:9a:73:3b:60:cc:10:d0:6a:c7:e5:d7:
8b:2f:f9:2e:98:f2:ff:81:14:24:09:e3:4b:55:57:09:1a:22:
74:f1:f6:40:13:31:43:89:71:0a:96:1a:05:82:1f:83:3a:87:
9b:17:25:ef:5a:55:f2:2d:cd:0d:4d:e4:81:58:b6:e3:8d:09:
62:9a:0c:bd:e4:e5:5c:f0:95:da:cb:c7:34:2c:34:5f:6d:fc:
60:7b:12:5b:86:fd:df:21:89:3b:48:08:30:bf:67:ff:8c:e6:
9b:53:cc:87:36:47:70:40:3b:d9:90:2a:d2:d2:82:c6:9c:f5:
d1:d8:e0:e6:fd:aa:2f:95:7e:39:ac:fc:4e:d4:ce:65:b3:ec:
c6:98:8a:31
-----BEGIN CERTIFICATE-----
MIIGWjCCBUKgAwIBAgIBATANBgkqhkiG9w0BAQUFADCCAYQxEzARBgoJkiaJk/Is
ZAEZFgNvcmcxFjAUBgoJkiaJk/IsZAEZFgZjaGVlc2UxDzANBgNVBAoMBkNoZWVz
ZTERMA8GA1UECgwIQ2hlZXNlIDIxHzAdBgNVBAsMFlNpbXBsZSBTaWduaW5nIFNl
Y3Rpb24xITAfBgNVBAsMGFNpbXBsZSBTaWduaW5nIFNlY3Rpb24gMjEaMBgGA1UE
AwwRU2ltcGxlIFNpZ25pbmcgQ0ExHDAaBgNVBAMME1NpbXBsZSBTaWduaW5nIENB
IDIxCzAJBgNVBAYTAkZSMQswCQYDVQQGEwJVUzERMA8GA1UEBwwIVE9VTE9VU0Ux
DTALBgNVBAcMBExZT04xFjAUBgNVBAgMDVNpZ25pbmcgU3RhdGUxGDAWBgNVBAgM
D1NpZ25pbmcgU3RhdGUgMjEhMB8GCSqGSIb3DQEJARYSc2ltcGxlQHNpZ25pbmcu
Y29tMSIwIAYJKoZIhvcNAQkBFhNzaW1wbGUyQHNpZ25pbmcuY29tMB4XDTE4MTIw
NjExMTAxNloXDTIwMTIwNTExMTAxNlowggF2MRMwEQYKCZImiZPyLGQBGRYDb3Jn
MRYwFAYKCZImiZPyLGQBGRYGY2hlZXNlMQ8wDQYDVQQKDAZDaGVlc2UxETAPBgNV
BAoMCENoZWVzZSAyMR8wHQYDVQQLDBZTaW1wbGUgU2lnbmluZyBTZWN0aW9uMSEw
HwYDVQQLDBhTaW1wbGUgU2lnbmluZyBTZWN0aW9uIDIxFTATBgNVBAMMDCouY2hl
ZXNlLm9yZzEVMBMGA1UEAwwMKi5jaGVlc2UuY29tMQswCQYDVQQGEwJGUjELMAkG
A1UEBhMCVVMxETAPBgNVBAcMCFRPVUxPVVNFMQ0wCwYDVQQHDARMWU9OMRkwFwYD
VQQIDBBDaGVlc2Ugb3JnIHN0YXRlMRkwFwYDVQQIDBBDaGVlc2UgY29tIHN0YXRl
MR4wHAYJKoZIhvcNAQkBFg9jZXJ0QGNoZWVzZS5vcmcxHzAdBgkqhkiG9w0BCQEW
EGNlcnRAc2NoZWVzZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB
AQDed/qNA3AwOd1RG8xg26laE7Gv/izGOJuICg+O2RuhHa8NZuQTW7xdNpLXXtD6
iCnTeOGB3piyqSI/v4qvEpJj1KnD8uR+0tyixTkceuvXEnBjLkFH4PAI6Ny+CQHs
KAmvNdd5nFA10Wvlh3s09tIxZR0YQmlsBBGD/kSukJItC3U5V2LmFy9HK8dT3RAt
yeMGE9K5umMuPH2Da9aJycydTb+f6KN72siZK7pm1o74QUGgydBeyBGkVUqTg4dj
BGNBnPtoBGfCcS/yZR0CXRXbLNkEaYXCfQ3qO6yF+NSPD8VwskXh7LJUC+n3grSb
Gy25JdSryo9bRD4V3bh/t+75AgMBAAGjgeAwgd0wDgYDVR0PAQH/BAQDAgWgMAkG
A1UdEwQCMAAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQW
BBSUunN4oof7WCgoz5g7wkVwFm4pLzAfBgNVHSMEGDAWgBQeUqLoVNU369WoHeTC
BB034vdwAzBhBgNVHREEWjBYggwqLmNoZWVzZS5vcmeCDCouY2hlZXNlLm5ldIIM
Ki5jaGVlc2UuY29thwQKAAEAhwQKAAECgQ90ZXN0QGNoZWVzZS5vcmeBD3Rlc3RA
Y2hlZXNlLm5ldDANBgkqhkiG9w0BAQUFAAOCAQEAdmsFsA40EbGDmZHcrhviCBWL
FrKbJxwCrLXfG9DQdaQrLFxl7Zmr983+OD/DmiIxG6yMHML5XdR1ei5yx4WpBK+f
KszTlnXwjsfGdkisRaS5Ah4vwBXEBwiSyydQZ6HIBcU6s6ZIvuvVWauiG5UwcRNb
CppzO2DMENBqx+XXiy/5Lpjy/4EUJAnjS1VXCRoidPH2QBMxQ4lxCpYaBYIfgzqH
mxcl71pV8i3NDU3kgVi2440JYpoMveTlXPCV2svHNCw0X238YHsSW4b93yGJO0gI
ML9n/4zmm1PMhzZHcEA72ZAq0tKCxpz10djg5v2qL5V+Oaz8TtTOZbPsxpiKMQ==
-----END CERTIFICATE-----
`
minimalCert = `-----BEGIN CERTIFICATE-----
MIIDGTCCAgECCQCqLd75YLi2kDANBgkqhkiG9w0BAQsFADBYMQswCQYDVQQGEwJG
UjETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UEBwwIVG91bG91c2UxITAfBgNV
BAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xODA3MTgwODI4MTZaFw0x
ODA4MTcwODI4MTZaMEUxCzAJBgNVBAYTAkZSMRMwEQYDVQQIDApTb21lLVN0YXRl
MSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3
DQEBAQUAA4IBDwAwggEKAoIBAQC/+frDMMTLQyXG34F68BPhQq0kzK4LIq9Y0/gl
FjySZNn1C0QDWA1ubVCAcA6yY204I9cxcQDPNrhC7JlS5QA8Y5rhIBrqQlzZizAi
Rj3NTrRjtGUtOScnHuJaWjLy03DWD+aMwb7q718xt5SEABmmUvLwQK+EjW2MeDwj
y8/UEIpvrRDmdhGaqv7IFpIDkcIF7FowJ/hwDvx3PMc+z/JWK0ovzpvgbx69AVbw
ZxCimeha65rOqVi+lEetD26le+WnOdYsdJ2IkmpPNTXGdfb15xuAc+gFXfMCh7Iw
3Ynl6dZtZM/Ok2kiA7/OsmVnRKkWrtBfGYkI9HcNGb3zrk6nAgMBAAEwDQYJKoZI
hvcNAQELBQADggEBAC/R+Yvhh1VUhcbK49olWsk/JKqfS3VIDQYZg1Eo+JCPbwgS
I1BSYVfMcGzuJTX6ua3m/AHzGF3Tap4GhF4tX12jeIx4R4utnjj7/YKkTvuEM2f4
xT56YqI7zalGScIB0iMeyNz1QcimRl+M/49au8ow9hNX8C2tcA2cwd/9OIj/6T8q
SBRHc6ojvbqZSJCO0jziGDT1L3D+EDgTjED4nd77v/NRdP+egb0q3P0s4dnQ/5AV
aQlQADUn61j3ScbGJ4NSeZFFvsl38jeRi/MEzp0bGgNBcPj6JHi7qbbauZcZfQ05
jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=
-----END CERTIFICATE-----`
)
func getCleanCertContents(certContents []string) string {
var re = regexp.MustCompile("-----BEGIN CERTIFICATE-----(?s)(.*)")
var cleanedCertContent []string
for _, certContent := range certContents {
cert := re.FindString(certContent)
cleanedCertContent = append(cleanedCertContent, sanitize([]byte(cert)))
}
return strings.Join(cleanedCertContent, ",")
}
func getCertificate(certContent string) *x509.Certificate {
roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM([]byte(signingCA))
if !ok {
panic("failed to parse root certificate")
}
block, _ := pem.Decode([]byte(certContent))
if block == nil {
panic("failed to parse certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
panic("failed to parse certificate: " + err.Error())
}
return cert
}
func buildTLSWith(certContents []string) *tls.ConnectionState {
var peerCertificates []*x509.Certificate
for _, certContent := range certContents {
peerCertificates = append(peerCertificates, getCertificate(certContent))
}
return &tls.ConnectionState{PeerCertificates: peerCertificates}
}
var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("bar"))
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
})
func getExpectedSanitized(s string) string {
return url.QueryEscape(strings.Replace(s, "\n", "", -1))
}
func TestSanitize(t *testing.T) {
testCases := []struct {
desc string
toSanitize []byte
expected string
}{
{
desc: "Empty",
},
{
desc: "With a minimal cert",
toSanitize: []byte(minimalCheeseCrt),
expected: getExpectedSanitized(`MIIEQDCCAygCFFRY0OBk/L5Se0IZRj3CMljawL2UMA0GCSqGSIb3DQEBCwUAMIIB
hDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEP
MA0GA1UECgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2lt
cGxlIFNpZ25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2Vj
dGlvbiAyMRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2lt
cGxlIFNpZ25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYD
VQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBT
dGF0ZTEYMBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJz
aW1wbGVAc2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmlu
Zy5jb20wHhcNMTgxMjA2MTExMDM2WhcNMjEwOTI1MTExMDM2WjAzMQswCQYDVQQG
EwJGUjETMBEGA1UECAwKU29tZS1TdGF0ZTEPMA0GA1UECgwGQ2hlZXNlMIIBIjAN
BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAskX/bUtwFo1gF2BTPNaNcTUMaRFu
FMZozK8IgLjccZ4kZ0R9oFO6Yp8Zl/IvPaf7tE26PI7XP7eHriUdhnQzX7iioDd0
RZa68waIhAGc+xPzRFrP3b3yj3S2a9Rve3c0K+SCV+EtKAwsxMqQDhoo9PcBfo5B
RHfht07uD5MncUcGirwN+/pxHV5xzAGPcc7On0/5L7bq/G+63nhu78zw9XyuLaHC
PM5VbOUvpyIESJHbMMzTdFGL8ob9VKO+Kr1kVGdEA9i8FLGl3xz/GBKuW/JD0xyW
DrU29mri5vYWHmkuv7ZWHGXnpXjTtPHwveE9/0/ArnmpMyR9JtqFr1oEvQIDAQAB
MA0GCSqGSIb3DQEBCwUAA4IBAQBHta+NWXI08UHeOkGzOTGRiWXsOH2dqdX6gTe9
xF1AIjyoQ0gvpoGVvlnChSzmlUj+vnx/nOYGIt1poE3hZA3ZHZD/awsvGyp3GwWD
IfXrEViSCIyF+8tNNKYyUcEO3xdAsAUGgfUwwF/mZ6MBV5+A/ZEEILlTq8zFt9dV
vdKzIt7fZYxYBBHFSarl1x8pDgWXlf3hAufevGJXip9xGYmznF0T5cq1RbWJ4be3
/9K7yuWhuBYC3sbTbCneHBa91M82za+PIISc1ygCYtWSBoZKSAqLk0rkZpHaekDP
WqeUSNGYV//RunTeuRDAf5OxehERb1srzBXhRZ3cZdzXbgR/`),
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, test.expected, sanitize(test.toSanitize), "The sanitized certificates should be equal")
})
}
}
func TestTLSClientHeadersWithPEM(t *testing.T) {
testCases := []struct {
desc string
certContents []string // set the request TLS attribute if defined
config config.PassTLSClientCert
expectedHeader string
}{
{
desc: "No TLS, no option",
},
{
desc: "TLS, no option",
certContents: []string{minimalCheeseCrt},
},
{
desc: "No TLS, with pem option true",
config: config.PassTLSClientCert{PEM: true},
},
{
desc: "TLS with simple certificate, with pem option true",
certContents: []string{minimalCheeseCrt},
config: config.PassTLSClientCert{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert}),
},
{
desc: "TLS with complete certificate, with pem option true",
certContents: []string{minimalCheeseCrt},
config: config.PassTLSClientCert{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCheeseCrt}),
},
{
desc: "TLS with two certificate, with pem option true",
certContents: []string{minimalCert, minimalCheeseCrt},
config: config.PassTLSClientCert{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert, minimalCheeseCrt}),
},
}
for _, test := range testCases {
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
require.NoError(t, err)
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
if test.certContents != nil && len(test.certContents) > 0 {
req.TLS = buildTLSWith(test.certContents)
}
tlsClientHeaders.ServeHTTP(res, req)
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
if test.expectedHeader != "" {
require.Equal(t, getCleanCertContents(test.certContents), req.Header.Get(xForwardedTLSClientCert), "The request header should contain the cleaned certificate")
} else {
require.Empty(t, req.Header.Get(xForwardedTLSClientCert))
}
require.Empty(t, res.Header().Get(xForwardedTLSClientCert), "The response header should be always empty")
})
}
}
func TestGetSans(t *testing.T) {
urlFoo, err := url.Parse("my.foo.com")
require.NoError(t, err)
urlBar, err := url.Parse("my.bar.com")
require.NoError(t, err)
testCases := []struct {
desc string
cert *x509.Certificate // set the request TLS attribute if defined
expected []string
}{
{
desc: "With nil",
},
{
desc: "Certificate without Sans",
cert: &x509.Certificate{},
},
{
desc: "Certificate with all Sans",
cert: &x509.Certificate{
DNSNames: []string{"foo", "bar"},
EmailAddresses: []string{"test@test.com", "test2@test.com"},
IPAddresses: []net.IP{net.IPv4(10, 0, 0, 1), net.IPv4(10, 0, 0, 2)},
URIs: []*url.URL{urlFoo, urlBar},
},
expected: []string{"foo", "bar", "test@test.com", "test2@test.com", "10.0.0.1", "10.0.0.2", urlFoo.String(), urlBar.String()},
},
}
for _, test := range testCases {
sans := getSANs(test.cert)
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
if len(test.expected) > 0 {
for i, expected := range test.expected {
require.Equal(t, expected, sans[i])
}
} else {
require.Empty(t, sans)
}
})
}
}
func TestTLSClientHeadersWithCertInfo(t *testing.T) {
minimalCheeseCertAllInfo := `Subject="C=FR,ST=Some-State,O=Cheese",Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2",NB=1544094636,NA=1632568236,SAN=`
completeCertAllInfo := `Subject="DC=org,DC=cheese,C=FR,C=US,ST=Cheese org state,ST=Cheese com state,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=*.cheese.com",Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2",NB=1544094616,NA=1607166616,SAN=*.cheese.org,*.cheese.net,*.cheese.com,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2`
testCases := []struct {
desc string
certContents []string // set the request TLS attribute if defined
config config.PassTLSClientCert
expectedHeader string
}{
{
desc: "No TLS, no option",
},
{
desc: "TLS, no option",
certContents: []string{minimalCert},
},
{
desc: "No TLS, with subject info",
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
Subject: &config.TLSCLientCertificateDNInfo{
CommonName: true,
Organization: true,
Locality: true,
Province: true,
Country: true,
SerialNumber: true,
},
},
},
},
{
desc: "No TLS, with pem option false with empty subject info",
config: config.PassTLSClientCert{
PEM: false,
Info: &config.TLSClientCertificateInfo{
Subject: &config.TLSCLientCertificateDNInfo{},
},
},
},
{
desc: "TLS with simple certificate, with all info",
certContents: []string{minimalCheeseCrt},
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
NotAfter: true,
NotBefore: true,
Sans: true,
Subject: &config.TLSCLientCertificateDNInfo{
CommonName: true,
Country: true,
DomainComponent: true,
Locality: true,
Organization: true,
Province: true,
SerialNumber: true,
},
Issuer: &config.TLSCLientCertificateDNInfo{
CommonName: true,
Country: true,
DomainComponent: true,
Locality: true,
Organization: true,
Province: true,
SerialNumber: true,
},
},
},
expectedHeader: url.QueryEscape(minimalCheeseCertAllInfo),
},
{
desc: "TLS with simple certificate, with some info",
certContents: []string{minimalCheeseCrt},
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
NotAfter: true,
Sans: true,
Subject: &config.TLSCLientCertificateDNInfo{
Organization: true,
},
Issuer: &config.TLSCLientCertificateDNInfo{
Country: true,
},
},
},
expectedHeader: url.QueryEscape(`Subject="O=Cheese",Issuer="C=FR,C=US",NA=1632568236,SAN=`),
},
{
desc: "TLS with complete certificate, with all info",
certContents: []string{completeCheeseCrt},
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
NotAfter: true,
NotBefore: true,
Sans: true,
Subject: &config.TLSCLientCertificateDNInfo{
Country: true,
Province: true,
Locality: true,
Organization: true,
CommonName: true,
SerialNumber: true,
DomainComponent: true,
},
Issuer: &config.TLSCLientCertificateDNInfo{
Country: true,
Province: true,
Locality: true,
Organization: true,
CommonName: true,
SerialNumber: true,
DomainComponent: true,
},
},
},
expectedHeader: url.QueryEscape(completeCertAllInfo),
},
{
desc: "TLS with 2 certificates, with all info",
certContents: []string{minimalCheeseCrt, completeCheeseCrt},
config: config.PassTLSClientCert{
Info: &config.TLSClientCertificateInfo{
NotAfter: true,
NotBefore: true,
Sans: true,
Subject: &config.TLSCLientCertificateDNInfo{
Country: true,
Province: true,
Locality: true,
Organization: true,
CommonName: true,
SerialNumber: true,
DomainComponent: true,
},
Issuer: &config.TLSCLientCertificateDNInfo{
Country: true,
Province: true,
Locality: true,
Organization: true,
CommonName: true,
SerialNumber: true,
DomainComponent: true,
},
},
},
expectedHeader: url.QueryEscape(strings.Join([]string{minimalCheeseCertAllInfo, completeCertAllInfo}, ";")),
},
}
for _, test := range testCases {
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
require.NoError(t, err)
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
if test.certContents != nil && len(test.certContents) > 0 {
req.TLS = buildTLSWith(test.certContents)
}
tlsClientHeaders.ServeHTTP(res, req)
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
if test.expectedHeader != "" {
require.Equal(t, test.expectedHeader, req.Header.Get(xForwardedTLSClientCertInfo), "The request header should contain the cleaned certificate")
} else {
require.Empty(t, req.Header.Get(xForwardedTLSClientCertInfo))
}
require.Empty(t, res.Header().Get(xForwardedTLSClientCertInfo), "The response header should be always empty")
})
}
}

View file

@ -0,0 +1,54 @@
package ratelimiter
import (
"context"
"net/http"
"time"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/ratelimit"
"github.com/vulcand/oxy/utils"
)
const (
typeName = "RateLimiterType"
)
type rateLimiter struct {
handler http.Handler
name string
}
// New creates rate limiter middleware.
func New(ctx context.Context, next http.Handler, config config.RateLimit, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
extractFunc, err := utils.NewExtractor(config.ExtractorFunc)
if err != nil {
return nil, err
}
rateSet := ratelimit.NewRateSet()
for _, rate := range config.RateSet {
if err = rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
return nil, err
}
}
rl, err := ratelimit.New(next, extractFunc, rateSet)
if err != nil {
return nil, err
}
return &rateLimiter{handler: rl, name: name}, nil
}
func (r *rateLimiter) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
r.handler.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,40 @@
package recovery
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/middlewares"
"github.com/sirupsen/logrus"
)
const (
typeName = "Recovery"
)
type recovery struct {
next http.Handler
name string
}
// New creates recovery middleware.
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &recovery{
next: next,
name: name,
}, nil
}
func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
defer recoverFunc(middlewares.GetLogger(req.Context(), re.name, typeName), rw)
re.next.ServeHTTP(rw, req)
}
func recoverFunc(logger logrus.FieldLogger, rw http.ResponseWriter) {
if err := recover(); err != nil {
logger.Errorf("Recovered from panic in http handler: %+v", err)
http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}

View file

@ -0,0 +1,27 @@
package recovery
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRecoverHandler(t *testing.T) {
fn := func(w http.ResponseWriter, r *http.Request) {
panic("I love panicing!")
}
recovery, err := New(context.Background(), http.HandlerFunc(fn), "foo-recovery")
require.NoError(t, err)
server := httptest.NewServer(recovery)
defer server.Close()
resp, err := http.Get(server.URL)
require.NoError(t, err)
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
}

View file

@ -0,0 +1,158 @@
package redirect
import (
"bytes"
"context"
"html/template"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/utils"
)
type redirect struct {
next http.Handler
regex *regexp.Regexp
replacement string
permanent bool
errHandler utils.ErrorHandler
name string
}
// New creates a Redirect middleware.
func newRedirect(_ context.Context, next http.Handler, regex string, replacement string, permanent bool, name string) (http.Handler, error) {
re, err := regexp.Compile(regex)
if err != nil {
return nil, err
}
return &redirect{
regex: re,
replacement: replacement,
permanent: permanent,
errHandler: utils.DefaultHandler,
next: next,
name: name,
}, nil
}
func (r *redirect) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
oldURL := rawURL(req)
// If the Regexp doesn't match, skip to the next handler
if !r.regex.MatchString(oldURL) {
r.next.ServeHTTP(rw, req)
return
}
// apply a rewrite regexp to the URL
newURL := r.regex.ReplaceAllString(oldURL, r.replacement)
// replace any variables that may be in there
rewrittenURL := &bytes.Buffer{}
if err := applyString(newURL, rewrittenURL, req); err != nil {
r.errHandler.ServeHTTP(rw, req, err)
return
}
// parse the rewritten URL and replace request URL with it
parsedURL, err := url.Parse(rewrittenURL.String())
if err != nil {
r.errHandler.ServeHTTP(rw, req, err)
return
}
if newURL != oldURL {
handler := &moveHandler{location: parsedURL, permanent: r.permanent}
handler.ServeHTTP(rw, req)
return
}
req.URL = parsedURL
// make sure the request URI corresponds the rewritten URL
req.RequestURI = req.URL.RequestURI()
r.next.ServeHTTP(rw, req)
}
type moveHandler struct {
location *url.URL
permanent bool
}
func (m *moveHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Location", m.location.String())
status := http.StatusFound
if req.Method != http.MethodGet {
status = http.StatusTemporaryRedirect
}
if m.permanent {
status = http.StatusMovedPermanently
if req.Method != http.MethodGet {
status = http.StatusPermanentRedirect
}
}
rw.WriteHeader(status)
_, err := rw.Write([]byte(http.StatusText(status)))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
func rawURL(req *http.Request) string {
scheme := "http"
host := req.Host
port := ""
uri := req.RequestURI
schemeRegex := `^(https?):\/\/([\w\._-]+)(:\d+)?(.*)$`
re, _ := regexp.Compile(schemeRegex)
if re.Match([]byte(req.RequestURI)) {
match := re.FindStringSubmatch(req.RequestURI)
scheme = match[1]
if len(match[2]) > 0 {
host = match[2]
}
if len(match[3]) > 0 {
port = match[3]
}
uri = match[4]
}
if req.TLS != nil || isXForwardedHTTPS(req) {
scheme = "https"
}
return strings.Join([]string{scheme, "://", host, port, uri}, "")
}
func isXForwardedHTTPS(request *http.Request) bool {
xForwardedProto := request.Header.Get("X-Forwarded-Proto")
return len(xForwardedProto) > 0 && xForwardedProto == "https"
}
func applyString(in string, out io.Writer, req *http.Request) error {
t, err := template.New("t").Parse(in)
if err != nil {
return err
}
data := struct{ Request *http.Request }{Request: req}
return t.Execute(out, data)
}

View file

@ -0,0 +1,22 @@
package redirect
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
)
const (
typeRegexName = "RedirectRegex"
)
// NewRedirectRegex creates a redirect middleware.
func NewRedirectRegex(ctx context.Context, next http.Handler, conf config.RedirectRegex, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeRegexName)
logger.Debug("Creating middleware")
logger.Debugf("Setting up redirection from %s to %s", conf.Regex, conf.Replacement)
return newRedirect(ctx, next, conf.Regex, conf.Replacement, conf.Permanent, name)
}

View file

@ -0,0 +1,198 @@
package redirect
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRedirectRegexHandler(t *testing.T) {
testCases := []struct {
desc string
config config.RedirectRegex
method string
url string
secured bool
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "simple redirection",
config: config.RedirectRegex{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: "https://${1}bar$2:443$4",
},
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "use request header",
config: config.RedirectRegex{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
},
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "URL doesn't match regex",
config: config.RedirectRegex{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: "https://${1}bar$2:443$4",
},
url: "http://bar.com:80",
expectedStatus: http.StatusOK,
},
{
desc: "invalid rewritten URL",
config: config.RedirectRegex{
Regex: `^(.*)$`,
Replacement: "http://192.168.0.%31/",
},
url: "http://foo.com:80",
expectedStatus: http.StatusBadGateway,
},
{
desc: "invalid regex",
config: config.RedirectRegex{
Regex: `^(.*`,
Replacement: "$1",
},
url: "http://foo.com:80",
errorExpected: true,
},
{
desc: "HTTP to HTTPS permanent",
config: config.RedirectRegex{
Regex: `^http://`,
Replacement: "https://$1",
Permanent: true,
},
url: "http://foo",
expectedURL: "https://foo",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTPS to HTTP permanent",
config: config.RedirectRegex{
Regex: `https://foo`,
Replacement: "http://foo",
Permanent: true,
},
secured: true,
url: "https://foo",
expectedURL: "http://foo",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTP to HTTPS",
config: config.RedirectRegex{
Regex: `http://foo:80`,
Replacement: "https://foo:443",
},
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS to HTTP",
config: config.RedirectRegex{
Regex: `https://foo:443`,
Replacement: "http://foo:80",
},
secured: true,
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP",
config: config.RedirectRegex{
Regex: `http://foo:80`,
Replacement: "http://foo:88",
},
url: "http://foo:80",
expectedURL: "http://foo:88",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP POST",
config: config.RedirectRegex{
Regex: `^http://`,
Replacement: "https://$1",
},
url: "http://foo",
method: http.MethodPost,
expectedURL: "https://foo",
expectedStatus: http.StatusTemporaryRedirect,
},
{
desc: "HTTP to HTTP POST permanent",
config: config.RedirectRegex{
Regex: `^http://`,
Replacement: "https://$1",
Permanent: true,
},
url: "http://foo",
method: http.MethodPost,
expectedURL: "https://foo",
expectedStatus: http.StatusPermanentRedirect,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
handler, err := NewRedirectRegex(context.Background(), next, test.config, "traefikTest")
if test.errorExpected {
require.Error(t, err)
require.Nil(t, handler)
} else {
require.NoError(t, err)
require.NotNil(t, handler)
recorder := httptest.NewRecorder()
method := http.MethodGet
if test.method != "" {
method = test.method
}
r := testhelpers.MustNewRequest(method, test.url, nil)
if test.secured {
r.TLS = &tls.ConnectionState{}
}
r.Header.Set("X-Foo", "bar")
handler.ServeHTTP(recorder, r)
assert.Equal(t, test.expectedStatus, recorder.Code)
if test.expectedStatus == http.StatusMovedPermanently ||
test.expectedStatus == http.StatusFound ||
test.expectedStatus == http.StatusTemporaryRedirect ||
test.expectedStatus == http.StatusPermanentRedirect {
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
} else {
location, err := recorder.Result().Location()
require.Errorf(t, err, "Location %v", location)
}
}
})
}
}

View file

@ -0,0 +1,34 @@
package redirect
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/middlewares"
"github.com/pkg/errors"
"github.com/containous/traefik/pkg/config"
)
const (
typeSchemeName = "RedirectScheme"
schemeRedirectRegex = `^(https?:\/\/)?([\w\._-]+)(:\d+)?(.*)$`
)
// NewRedirectScheme creates a new RedirectScheme middleware.
func NewRedirectScheme(ctx context.Context, next http.Handler, conf config.RedirectScheme, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeSchemeName)
logger.Debug("Creating middleware")
logger.Debugf("Setting up redirection to %s %s", conf.Scheme, conf.Port)
if len(conf.Scheme) == 0 {
return nil, errors.New("you must provide a target scheme")
}
port := ""
if len(conf.Port) > 0 && !(conf.Scheme == "http" && conf.Port == "80" || conf.Scheme == "https" && conf.Port == "443") {
port = ":" + conf.Port
}
return newRedirect(ctx, next, schemeRedirectRegex, conf.Scheme+"://${2}"+port+"${4}", conf.Permanent, name)
}

View file

@ -0,0 +1,250 @@
package redirect
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"regexp"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRedirectSchemeHandler(t *testing.T) {
testCases := []struct {
desc string
config config.RedirectScheme
method string
url string
secured bool
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "Without scheme",
config: config.RedirectScheme{},
url: "http://foo",
errorExpected: true,
},
{
desc: "HTTP to HTTPS",
config: config.RedirectScheme{
Scheme: "https",
},
url: "http://foo",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP with port to HTTPS without port",
config: config.RedirectScheme{
Scheme: "https",
},
url: "http://foo:8080",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP without port to HTTPS with port",
config: config.RedirectScheme{
Scheme: "https",
Port: "8443",
},
url: "http://foo",
expectedURL: "https://foo:8443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP with port to HTTPS with port",
config: config.RedirectScheme{
Scheme: "https",
Port: "8443",
},
url: "http://foo:8000",
expectedURL: "https://foo:8443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS with port to HTTPS with port",
config: config.RedirectScheme{
Scheme: "https",
Port: "8443",
},
url: "https://foo:8000",
expectedURL: "https://foo:8443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS with port to HTTPS without port",
config: config.RedirectScheme{
Scheme: "https",
},
url: "https://foo:8000",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
{
desc: "redirection to HTTPS without port from an URL already in https",
config: config.RedirectScheme{
Scheme: "https",
},
url: "https://foo:8000/theother",
expectedURL: "https://foo/theother",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTPS permanent",
config: config.RedirectScheme{
Scheme: "https",
Port: "8443",
Permanent: true,
},
url: "http://foo",
expectedURL: "https://foo:8443",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "to HTTP 80",
config: config.RedirectScheme{
Scheme: "http",
Port: "80",
},
url: "http://foo:80",
expectedURL: "http://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to wss",
config: config.RedirectScheme{
Scheme: "wss",
Port: "9443",
},
url: "http://foo",
expectedURL: "wss://foo:9443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to wss without port",
config: config.RedirectScheme{
Scheme: "wss",
},
url: "http://foo",
expectedURL: "wss://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP with port to wss without port",
config: config.RedirectScheme{
Scheme: "wss",
},
url: "http://foo:5678",
expectedURL: "wss://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTPS without port",
config: config.RedirectScheme{
Scheme: "https",
},
url: "http://foo:443",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP port redirection",
config: config.RedirectScheme{
Scheme: "http",
Port: "8181",
},
url: "http://foo:8080",
expectedURL: "http://foo:8181",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS with port 80 to HTTPS without port",
config: config.RedirectScheme{
Scheme: "https",
},
url: "https://foo:80",
expectedURL: "https://foo",
expectedStatus: http.StatusFound,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
handler, err := NewRedirectScheme(context.Background(), next, test.config, "traefikTest")
if test.errorExpected {
require.Error(t, err)
require.Nil(t, handler)
} else {
require.NoError(t, err)
require.NotNil(t, handler)
recorder := httptest.NewRecorder()
method := http.MethodGet
if test.method != "" {
method = test.method
}
r := httptest.NewRequest(method, test.url, nil)
if test.secured {
r.TLS = &tls.ConnectionState{}
}
r.Header.Set("X-Foo", "bar")
handler.ServeHTTP(recorder, r)
assert.Equal(t, test.expectedStatus, recorder.Code)
if test.expectedStatus == http.StatusMovedPermanently ||
test.expectedStatus == http.StatusFound ||
test.expectedStatus == http.StatusTemporaryRedirect ||
test.expectedStatus == http.StatusPermanentRedirect {
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
} else {
location, err := recorder.Result().Location()
require.Errorf(t, err, "Location %v", location)
}
schemeRegex := `^(https?):\/\/([\w\._-]+)(:\d+)?(.*)$`
re, _ := regexp.Compile(schemeRegex)
if re.Match([]byte(test.url)) {
match := re.FindStringSubmatch(test.url)
r.RequestURI = match[4]
handler.ServeHTTP(recorder, r)
assert.Equal(t, test.expectedStatus, recorder.Code)
if test.expectedStatus == http.StatusMovedPermanently ||
test.expectedStatus == http.StatusFound ||
test.expectedStatus == http.StatusTemporaryRedirect ||
test.expectedStatus == http.StatusPermanentRedirect {
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
} else {
location, err := recorder.Result().Location()
require.Errorf(t, err, "Location %v", location)
}
}
}
})
}
}

View file

@ -0,0 +1,46 @@
package replacepath
import (
"context"
"net/http"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
// ReplacedPathHeader is the default header to set the old path to.
ReplacedPathHeader = "X-Replaced-Path"
typeName = "ReplacePath"
)
// ReplacePath is a middleware used to replace the path of a URL request.
type replacePath struct {
next http.Handler
path string
name string
}
// New creates a new replace path middleware.
func New(ctx context.Context, next http.Handler, config config.ReplacePath, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
return &replacePath{
next: next,
path: config.Path,
name: name,
}, nil
}
func (r *replacePath) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *replacePath) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
req.Header.Add(ReplacedPathHeader, req.URL.Path)
req.URL.Path = r.path
req.RequestURI = req.URL.RequestURI()
r.next.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,46 @@
package replacepath
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReplacePath(t *testing.T) {
var replacementConfig = config.ReplacePath{
Path: "/replacement-path",
}
paths := []string{
"/example",
"/some/really/long/path",
}
for _, path := range paths {
t.Run(path, func(t *testing.T) {
var expectedPath, actualHeader, requestURI string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath = r.URL.Path
actualHeader = r.Header.Get(ReplacedPathHeader)
requestURI = r.RequestURI
})
handler, err := New(context.Background(), next, replacementConfig, "foo-replace-path")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+path, nil)
handler.ServeHTTP(nil, req)
assert.Equal(t, expectedPath, replacementConfig.Path, "Unexpected path.")
assert.Equal(t, path, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
assert.Equal(t, expectedPath, requestURI, "Unexpected request URI.")
})
}
}

View file

@ -0,0 +1,57 @@
package replacepathregex
import (
"context"
"fmt"
"net/http"
"regexp"
"strings"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/replacepath"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
typeName = "ReplacePathRegex"
)
// ReplacePathRegex is a middleware used to replace the path of a URL request with a regular expression.
type replacePathRegex struct {
next http.Handler
regexp *regexp.Regexp
replacement string
name string
}
// New creates a new replace path regex middleware.
func New(ctx context.Context, next http.Handler, config config.ReplacePathRegex, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
exp, err := regexp.Compile(strings.TrimSpace(config.Regex))
if err != nil {
return nil, fmt.Errorf("error compiling regular expression %s: %s", config.Regex, err)
}
return &replacePathRegex{
regexp: exp,
replacement: strings.TrimSpace(config.Replacement),
next: next,
name: name,
}, nil
}
func (rp *replacePathRegex) GetTracingInformation() (string, ext.SpanKindEnum) {
return rp.name, tracing.SpanKindNoneEnum
}
func (rp *replacePathRegex) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if rp.regexp != nil && len(rp.replacement) > 0 && rp.regexp.MatchString(req.URL.Path) {
req.Header.Add(replacepath.ReplacedPathHeader, req.URL.Path)
req.URL.Path = rp.regexp.ReplaceAllString(req.URL.Path, rp.replacement)
req.RequestURI = req.URL.RequestURI()
}
rp.next.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,104 @@
package replacepathregex
import (
"context"
"net/http"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares/replacepath"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReplacePathRegex(t *testing.T) {
testCases := []struct {
desc string
path string
config config.ReplacePathRegex
expectedPath string
expectedHeader string
expectsError bool
}{
{
desc: "simple regex",
path: "/whoami/and/whoami",
config: config.ReplacePathRegex{
Replacement: "/who-am-i/$1",
Regex: `^/whoami/(.*)`,
},
expectedPath: "/who-am-i/and/whoami",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "simple replace (no regex)",
path: "/whoami/and/whoami",
config: config.ReplacePathRegex{
Replacement: "/who-am-i",
Regex: `/whoami`,
},
expectedPath: "/who-am-i/and/who-am-i",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "no match",
path: "/whoami/and/whoami",
config: config.ReplacePathRegex{
Replacement: "/whoami",
Regex: `/no-match`,
},
expectedPath: "/whoami/and/whoami",
},
{
desc: "multiple replacement",
path: "/downloads/src/source.go",
config: config.ReplacePathRegex{
Replacement: "/downloads/$1-$2",
Regex: `^(?i)/downloads/([^/]+)/([^/]+)$`,
},
expectedPath: "/downloads/src-source.go",
expectedHeader: "/downloads/src/source.go",
},
{
desc: "invalid regular expression",
path: "/invalid/regexp/test",
config: config.ReplacePathRegex{
Replacement: "/valid/regexp/$1",
Regex: `^(?err)/invalid/regexp/([^/]+)$`,
},
expectedPath: "/invalid/regexp/test",
expectsError: true,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
var actualPath, actualHeader, requestURI string
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualHeader = r.Header.Get(replacepath.ReplacedPathHeader)
requestURI = r.RequestURI
})
handler, err := New(context.Background(), next, test.config, "foo-replace-path-regexp")
if test.expectsError {
require.Error(t, err)
} else {
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
req.RequestURI = test.path
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, actualPath, requestURI, "Unexpected request URI.")
if test.expectedHeader != "" {
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", replacepath.ReplacedPathHeader)
}
}
})
}
}

View file

@ -0,0 +1,124 @@
package requestdecorator
import (
"context"
"fmt"
"net"
"sort"
"strings"
"time"
"github.com/containous/traefik/pkg/log"
"github.com/miekg/dns"
"github.com/patrickmn/go-cache"
)
type cnameResolv struct {
TTL time.Duration
Record string
}
type byTTL []*cnameResolv
func (a byTTL) Len() int { return len(a) }
func (a byTTL) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byTTL) Less(i, j int) bool { return a[i].TTL > a[j].TTL }
// Resolver used for host resolver.
type Resolver struct {
CnameFlattening bool
ResolvConfig string
ResolvDepth int
cache *cache.Cache
}
// CNAMEFlatten check if CNAME record exists, flatten if possible.
func (hr *Resolver) CNAMEFlatten(ctx context.Context, host string) string {
if hr.cache == nil {
hr.cache = cache.New(30*time.Minute, 5*time.Minute)
}
result := host
request := host
value, found := hr.cache.Get(host)
if found {
return value.(string)
}
logger := log.FromContext(ctx)
var cacheDuration = 0 * time.Second
for depth := 0; depth < hr.ResolvDepth; depth++ {
resolv, err := cnameResolve(ctx, request, hr.ResolvConfig)
if err != nil {
logger.Error(err)
break
}
if resolv == nil {
break
}
result = resolv.Record
if depth == 0 {
cacheDuration = resolv.TTL
}
request = resolv.Record
}
if err := hr.cache.Add(host, result, cacheDuration); err != nil {
logger.Error(err)
}
return result
}
// cnameResolve resolves CNAME if exists, and return with the highest TTL.
func cnameResolve(ctx context.Context, host string, resolvPath string) (*cnameResolv, error) {
config, err := dns.ClientConfigFromFile(resolvPath)
if err != nil {
return nil, fmt.Errorf("invalid resolver configuration file: %s", resolvPath)
}
client := &dns.Client{Timeout: 30 * time.Second}
m := &dns.Msg{}
m.SetQuestion(dns.Fqdn(host), dns.TypeCNAME)
var result []*cnameResolv
for _, server := range config.Servers {
tempRecord, err := getRecord(client, m, server, config.Port)
if err != nil {
log.FromContext(ctx).Errorf("Failed to resolve host %s: %v", host, err)
continue
}
result = append(result, tempRecord)
}
if len(result) == 0 {
return nil, nil
}
sort.Sort(byTTL(result))
return result[0], nil
}
func getRecord(client *dns.Client, msg *dns.Msg, server string, port string) (*cnameResolv, error) {
resp, _, err := client.Exchange(msg, net.JoinHostPort(server, port))
if err != nil {
return nil, fmt.Errorf("exchange error for server %s: %v", server, err)
}
if resp == nil || len(resp.Answer) == 0 {
return nil, fmt.Errorf("empty answer for server %s", server)
}
rr, ok := resp.Answer[0].(*dns.CNAME)
if !ok {
return nil, fmt.Errorf("invalid response type for server %s", server)
}
return &cnameResolv{
TTL: time.Duration(rr.Hdr.Ttl) * time.Second,
Record: strings.TrimSuffix(rr.Target, "."),
}, nil
}

View file

@ -0,0 +1,51 @@
package requestdecorator
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCNAMEFlatten(t *testing.T) {
testCases := []struct {
desc string
resolvFile string
domain string
expectedDomain string
}{
{
desc: "host request is CNAME record",
resolvFile: "/etc/resolv.conf",
domain: "www.github.com",
expectedDomain: "github.com",
},
{
desc: "resolve file not found",
resolvFile: "/etc/resolv.oops",
domain: "www.github.com",
expectedDomain: "www.github.com",
},
{
desc: "host request is not CNAME record",
resolvFile: "/etc/resolv.conf",
domain: "github.com",
expectedDomain: "github.com",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
hostResolver := &Resolver{
ResolvConfig: test.resolvFile,
ResolvDepth: 5,
}
flatH := hostResolver.CNAMEFlatten(context.Background(), test.domain)
assert.Equal(t, test.expectedDomain, flatH)
})
}
}

View file

@ -0,0 +1,87 @@
package requestdecorator
import (
"context"
"net"
"net/http"
"strings"
"github.com/containous/alice"
"github.com/containous/traefik/pkg/types"
)
const (
canonicalKey key = "canonical"
flattenKey key = "flatten"
)
type key string
// RequestDecorator is the struct for the middleware that adds the CanonicalDomain of the request Host into a context for later use.
type RequestDecorator struct {
hostResolver *Resolver
}
// New creates a new request host middleware.
func New(hostResolverConfig *types.HostResolverConfig) *RequestDecorator {
requestDecorator := &RequestDecorator{}
if hostResolverConfig != nil {
requestDecorator.hostResolver = &Resolver{
CnameFlattening: hostResolverConfig.CnameFlattening,
ResolvConfig: hostResolverConfig.ResolvConfig,
ResolvDepth: hostResolverConfig.ResolvDepth,
}
}
return requestDecorator
}
func (r *RequestDecorator) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
host := types.CanonicalDomain(parseHost(req.Host))
reqt := req.WithContext(context.WithValue(req.Context(), canonicalKey, host))
if r.hostResolver != nil && r.hostResolver.CnameFlattening {
flatHost := r.hostResolver.CNAMEFlatten(reqt.Context(), host)
reqt = reqt.WithContext(context.WithValue(reqt.Context(), flattenKey, flatHost))
}
next(rw, reqt)
}
func parseHost(addr string) string {
if !strings.Contains(addr, ":") {
return addr
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
return host
}
// GetCanonizedHost retrieves the canonized host from the given context (previously stored in the request context by the middleware).
func GetCanonizedHost(ctx context.Context) string {
if val, ok := ctx.Value(canonicalKey).(string); ok {
return val
}
return ""
}
// GetCNAMEFlatten return the flat name if it is present in the context.
func GetCNAMEFlatten(ctx context.Context) string {
if val, ok := ctx.Value(flattenKey).(string); ok {
return val
}
return ""
}
// WrapHandler Wraps a ServeHTTP with next to an alice.Constructor.
func WrapHandler(handler *RequestDecorator) alice.Constructor {
return func(next http.Handler) (http.Handler, error) {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
handler.ServeHTTP(rw, req, next.ServeHTTP)
}), nil
}
}

View file

@ -0,0 +1,146 @@
package requestdecorator
import (
"net/http"
"testing"
"github.com/containous/traefik/pkg/types"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestRequestHost(t *testing.T) {
testCases := []struct {
desc string
url string
expected string
}{
{
desc: "host without :",
url: "http://host",
expected: "host",
},
{
desc: "host with : and without port",
url: "http://host:",
expected: "host",
},
{
desc: "IP host with : and with port",
url: "http://127.0.0.1:123",
expected: "127.0.0.1",
},
{
desc: "IP host with : and without port",
url: "http://127.0.0.1:",
expected: "127.0.0.1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
host := GetCanonizedHost(r.Context())
assert.Equal(t, test.expected, host)
})
rh := New(nil)
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
rh.ServeHTTP(nil, req, next)
})
}
}
func TestRequestFlattening(t *testing.T) {
testCases := []struct {
desc string
url string
expected string
}{
{
desc: "host with flattening",
url: "http://www.github.com",
expected: "github.com",
},
{
desc: "host without flattening",
url: "http://github.com",
expected: "github.com",
},
{
desc: "ip without flattening",
url: "http://127.0.0.1",
expected: "127.0.0.1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
host := GetCNAMEFlatten(r.Context())
assert.Equal(t, test.expected, host)
})
rh := New(
&types.HostResolverConfig{
CnameFlattening: true,
ResolvConfig: "/etc/resolv.conf",
ResolvDepth: 5,
},
)
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
rh.ServeHTTP(nil, req, next)
})
}
}
func TestRequestHostParseHost(t *testing.T) {
testCases := []struct {
desc string
host string
expected string
}{
{
desc: "host without :",
host: "host",
expected: "host",
},
{
desc: "host with : and without port",
host: "host:",
expected: "host",
},
{
desc: "IP host with : and with port",
host: "127.0.0.1:123",
expected: "127.0.0.1",
},
{
desc: "IP host with : and without port",
host: "127.0.0.1:",
expected: "127.0.0.1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := parseHost(test.host)
assert.Equal(t, test.expected, actual)
})
}
}

View file

@ -0,0 +1,207 @@
package retry
import (
"bufio"
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
// Compile time validation that the response writer implements http interfaces correctly.
var _ middlewares.Stateful = &responseWriterWithCloseNotify{}
const (
typeName = "Retry"
)
// Listener is used to inform about retry attempts.
type Listener interface {
// Retried will be called when a retry happens, with the request attempt passed to it.
// For the first retry this will be attempt 2.
Retried(req *http.Request, attempt int)
}
// Listeners is a convenience type to construct a list of Listener and notify
// each of them about a retry attempt.
type Listeners []Listener
// retry is a middleware that retries requests.
type retry struct {
attempts int
next http.Handler
listener Listener
name string
}
// New returns a new retry middleware.
func New(ctx context.Context, next http.Handler, config config.Retry, listener Listener, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug("Creating middleware")
if config.Attempts <= 0 {
return nil, fmt.Errorf("incorrect (or empty) value for attempt (%d)", config.Attempts)
}
return &retry{
attempts: config.Attempts,
next: next,
listener: listener,
name: name,
}, nil
}
func (r *retry) GetTracingInformation() (string, ext.SpanKindEnum) {
return r.name, tracing.SpanKindNoneEnum
}
func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// if we might make multiple attempts, swap the body for an ioutil.NopCloser
// cf https://github.com/containous/traefik/issues/1008
if r.attempts > 1 {
body := req.Body
defer body.Close()
req.Body = ioutil.NopCloser(body)
}
attempts := 1
for {
shouldRetry := attempts < r.attempts
retryResponseWriter := newResponseWriter(rw, shouldRetry)
// Disable retries when the backend already received request data
trace := &httptrace.ClientTrace{
WroteHeaders: func() {
retryResponseWriter.DisableRetries()
},
WroteRequest: func(httptrace.WroteRequestInfo) {
retryResponseWriter.DisableRetries()
},
}
newCtx := httptrace.WithClientTrace(req.Context(), trace)
r.next.ServeHTTP(retryResponseWriter, req.WithContext(newCtx))
if !retryResponseWriter.ShouldRetry() {
break
}
attempts++
logger := middlewares.GetLogger(req.Context(), r.name, typeName)
logger.Debugf("New attempt %d for request: %v", attempts, req.URL)
r.listener.Retried(req, attempts)
}
}
// Retried exists to implement the Listener interface. It calls Retried on each of its slice entries.
func (l Listeners) Retried(req *http.Request, attempt int) {
for _, listener := range l {
listener.Retried(req, attempt)
}
}
type responseWriter interface {
http.ResponseWriter
http.Flusher
ShouldRetry() bool
DisableRetries()
}
func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter {
responseWriter := &responseWriterWithoutCloseNotify{
responseWriter: rw,
headers: make(http.Header),
shouldRetry: shouldRetry,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &responseWriterWithCloseNotify{
responseWriterWithoutCloseNotify: responseWriter,
}
}
return responseWriter
}
type responseWriterWithoutCloseNotify struct {
responseWriter http.ResponseWriter
headers http.Header
shouldRetry bool
written bool
}
func (r *responseWriterWithoutCloseNotify) ShouldRetry() bool {
return r.shouldRetry
}
func (r *responseWriterWithoutCloseNotify) DisableRetries() {
r.shouldRetry = false
}
func (r *responseWriterWithoutCloseNotify) Header() http.Header {
if r.written {
return r.responseWriter.Header()
}
return r.headers
}
func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
if r.ShouldRetry() {
return len(buf), nil
}
return r.responseWriter.Write(buf)
}
func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) {
if r.ShouldRetry() && code == http.StatusServiceUnavailable {
// We get a 503 HTTP Status Code when there is no backend server in the pool
// to which the request could be sent. Also, note that r.ShouldRetry()
// will never return true in case there was a connection established to
// the backend server and so we can be sure that the 503 was produced
// inside Traefik already and we don't have to retry in this cases.
r.DisableRetries()
}
if r.ShouldRetry() {
return
}
// In that case retry case is set to false which means we at least managed
// to write headers to the backend : we are not going to perform any further retry.
// So it is now safe to alter current response headers with headers collected during
// the latest try before writing headers to client.
headers := r.responseWriter.Header()
for header, value := range r.headers {
headers[header] = value
}
r.responseWriter.WriteHeader(code)
r.written = true
}
func (r *responseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := r.responseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", r.responseWriter)
}
return hijacker.Hijack()
}
func (r *responseWriterWithoutCloseNotify) Flush() {
if flusher, ok := r.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
type responseWriterWithCloseNotify struct {
*responseWriterWithoutCloseNotify
}
func (r *responseWriterWithCloseNotify) CloseNotify() <-chan bool {
return r.responseWriter.(http.CloseNotifier).CloseNotify()
}

View file

@ -0,0 +1,312 @@
package retry
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"net/http/httptrace"
"strings"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares/emptybackendhandler"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/roundrobin"
)
func TestRetry(t *testing.T) {
testCases := []struct {
desc string
config config.Retry
wantRetryAttempts int
wantResponseStatus int
amountFaultyEndpoints int
}{
{
desc: "no retry on success",
config: config.Retry{Attempts: 1},
wantRetryAttempts: 0,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 0,
},
{
desc: "no retry when max request attempts is one",
config: config.Retry{Attempts: 1},
wantRetryAttempts: 0,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 1,
},
{
desc: "one retry when one server is faulty",
config: config.Retry{Attempts: 2},
wantRetryAttempts: 1,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 1,
},
{
desc: "two retries when two servers are faulty",
config: config.Retry{Attempts: 3},
wantRetryAttempts: 2,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 2,
},
{
desc: "max attempts exhausted delivers the 5xx response",
config: config.Retry{Attempts: 3},
wantRetryAttempts: 2,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 3,
},
}
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
_, err := rw.Write([]byte("OK"))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}))
forwarder, err := forward.New()
require.NoError(t, err)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
loadBalancer, err := roundrobin.New(forwarder)
require.NoError(t, err)
basePort := 33444
for i := 0; i < test.amountFaultyEndpoints; i++ {
// 192.0.2.0 is a non-routable IP for testing purposes.
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
err = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
require.NoError(t, err)
}
// add the functioning server to the end of the load balancer list
err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
require.NoError(t, err)
retryListener := &countingRetryListener{}
retry, err := New(context.Background(), loadBalancer, test.config, retryListener, "traefikTest")
require.NoError(t, err)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
retry.ServeHTTP(recorder, req)
assert.Equal(t, test.wantResponseStatus, recorder.Code)
assert.Equal(t, test.wantRetryAttempts, retryListener.timesCalled)
})
}
}
func TestRetryEmptyServerList(t *testing.T) {
forwarder, err := forward.New()
require.NoError(t, err)
loadBalancer, err := roundrobin.New(forwarder)
require.NoError(t, err)
// The EmptyBackend middleware ensures that there is a 503
// response status set when there is no backend server in the pool.
next := emptybackendhandler.New(loadBalancer)
retryListener := &countingRetryListener{}
retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, retryListener, "traefikTest")
require.NoError(t, err)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
retry.ServeHTTP(recorder, req)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
assert.Equal(t, 0, retryListener.timesCalled)
}
func TestRetryListeners(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
retryListeners := Listeners{&countingRetryListener{}, &countingRetryListener{}}
retryListeners.Retried(req, 1)
retryListeners.Retried(req, 1)
for _, retryListener := range retryListeners {
listener := retryListener.(*countingRetryListener)
if listener.timesCalled != 2 {
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
}
}
}
func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) {
attempt := 0
expectedHeaderName := "X-Foo-Test-2"
expectedHeaderValue := "bar"
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
headerName := fmt.Sprintf("X-Foo-Test-%d", attempt)
rw.Header().Add(headerName, expectedHeaderValue)
if attempt < 2 {
attempt++
return
}
// Request has been successfully written to backend
trace := httptrace.ContextClientTrace(req.Context())
trace.WroteHeaders()
// And we decide to answer to client
rw.WriteHeader(http.StatusNoContent)
})
retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest")
require.NoError(t, err)
responseRecorder := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody))
headerValue := responseRecorder.Header().Get(expectedHeaderName)
// Validate if we have the correct header
if headerValue != expectedHeaderValue {
t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue)
}
// Validate that we don't have headers from previous attempts
for i := 0; i < attempt; i++ {
headerName := fmt.Sprintf("X-Foo-Test-%d", i)
headerValue = responseRecorder.Header().Get("headerName")
if headerValue != "" {
t.Errorf("Expected no value for header %s, got %s", headerName, headerValue)
}
}
}
// countingRetryListener is a Listener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
timesCalled int
}
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
l.timesCalled++
}
func TestRetryWithFlush(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(200)
_, err := rw.Write([]byte("FULL "))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
rw.(http.Flusher).Flush()
_, err = rw.Write([]byte("DATA"))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
})
retry, err := New(context.Background(), next, config.Retry{Attempts: 1}, &countingRetryListener{}, "traefikTest")
require.NoError(t, err)
responseRecorder := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, &http.Request{})
assert.Equal(t, "FULL DATA", responseRecorder.Body.String())
}
func TestRetryWebsocket(t *testing.T) {
testCases := []struct {
desc string
maxRequestAttempts int
expectedRetryAttempts int
expectedResponseStatus int
expectedError bool
amountFaultyEndpoints int
}{
{
desc: "Switching ok after 2 retries",
maxRequestAttempts: 3,
expectedRetryAttempts: 2,
amountFaultyEndpoints: 2,
expectedResponseStatus: http.StatusSwitchingProtocols,
},
{
desc: "Switching failed",
maxRequestAttempts: 2,
expectedRetryAttempts: 1,
amountFaultyEndpoints: 2,
expectedResponseStatus: http.StatusBadGateway,
expectedError: true,
},
}
forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %v", err)
}
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
upgrader := websocket.Upgrader{}
_, err := upgrader.Upgrade(rw, req, nil)
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}))
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
loadBalancer, err := roundrobin.New(forwarder)
if err != nil {
t.Fatalf("Error creating load balancer: %v", err)
}
basePort := 33444
for i := 0; i < test.amountFaultyEndpoints; i++ {
// 192.0.2.0 is a non-routable IP for testing purposes.
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
_ = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
}
// add the functioning server to the end of the load balancer list
err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
if err != nil {
t.Fatalf("Fail to upsert server: %v", err)
}
retryListener := &countingRetryListener{}
retryH, err := New(context.Background(), loadBalancer, config.Retry{Attempts: test.maxRequestAttempts}, retryListener, "traefikTest")
require.NoError(t, err)
retryServer := httptest.NewServer(retryH)
url := strings.Replace(retryServer.URL, "http", "ws", 1)
_, response, err := websocket.DefaultDialer.Dial(url, nil)
if !test.expectedError {
require.NoError(t, err)
}
assert.Equal(t, test.expectedResponseStatus, response.StatusCode)
assert.Equal(t, test.expectedRetryAttempts, retryListener.timesCalled)
})
}
}

Some files were not shown because too many files have changed in this diff Show more