Move code to pkg
Parent: bd4c822670
Commit: f1b085fa36
465 changed files with 656 additions and 680 deletions

pkg/anonymize/anonymize.go (Normal file, 136 lines added)
@@ -0,0 +1,136 @@
package anonymize

import (
	"encoding/json"
	"fmt"
	"reflect"
	"regexp"

	"github.com/mitchellh/copystructure"
	"github.com/mvdan/xurls"
)

const (
	maskShort = "xxxx"
	maskLarge = maskShort + maskShort + maskShort + maskShort + maskShort + maskShort + maskShort + maskShort
)

// Do configuration.
func Do(baseConfig interface{}, indent bool) (string, error) {
	anomConfig, err := copystructure.Copy(baseConfig)
	if err != nil {
		return "", err
	}

	val := reflect.ValueOf(anomConfig)

	err = doOnStruct(val)
	if err != nil {
		return "", err
	}

	configJSON, err := marshal(anomConfig, indent)
	if err != nil {
		return "", err
	}

	return doOnJSON(string(configJSON)), nil
}

func doOnJSON(input string) string {
	mailExp := regexp.MustCompile(`\w[-._\w]*\w@\w[-._\w]*\w\.\w{2,3}"`)
	return xurls.Relaxed.ReplaceAllString(mailExp.ReplaceAllString(input, maskLarge+"\""), maskLarge)
}

func doOnStruct(field reflect.Value) error {
	switch field.Kind() {
	case reflect.Ptr:
		if !field.IsNil() {
			if err := doOnStruct(field.Elem()); err != nil {
				return err
			}
		}
	case reflect.Struct:
		for i := 0; i < field.NumField(); i++ {
			fld := field.Field(i)
			stField := field.Type().Field(i)
			if !isExported(stField) {
				continue
			}
			if stField.Tag.Get("export") == "true" {
				if err := doOnStruct(fld); err != nil {
					return err
				}
			} else {
				if err := reset(fld, stField.Name); err != nil {
					return err
				}
			}
		}
	case reflect.Map:
		for _, key := range field.MapKeys() {
			if err := doOnStruct(field.MapIndex(key)); err != nil {
				return err
			}
		}
	case reflect.Slice:
		for j := 0; j < field.Len(); j++ {
			if err := doOnStruct(field.Index(j)); err != nil {
				return err
			}
		}
	}
	return nil
}

func reset(field reflect.Value, name string) error {
	if !field.CanSet() {
		return fmt.Errorf("cannot reset field %s", name)
	}

	switch field.Kind() {
	case reflect.Ptr:
		if !field.IsNil() {
			field.Set(reflect.Zero(field.Type()))
		}
	case reflect.Struct:
		if field.IsValid() {
			field.Set(reflect.Zero(field.Type()))
		}
	case reflect.String:
		if field.String() != "" {
			field.Set(reflect.ValueOf(maskShort))
		}
	case reflect.Map:
		if field.Len() > 0 {
			field.Set(reflect.MakeMap(field.Type()))
		}
	case reflect.Slice:
		if field.Len() > 0 {
			field.Set(reflect.MakeSlice(field.Type(), 0, 0))
		}
	case reflect.Interface:
		if !field.IsNil() {
			return reset(field.Elem(), "")
		}
	default:
		// Primitive type
		field.Set(reflect.Zero(field.Type()))
	}
	return nil
}

// isExported return true is a struct field is exported, else false
func isExported(f reflect.StructField) bool {
	if f.PkgPath != "" && !f.Anonymous {
		return false
	}
	return true
}

func marshal(anomConfig interface{}, indent bool) ([]byte, error) {
	if indent {
		return json.MarshalIndent(anomConfig, "", " ")
	}
	return json.Marshal(anomConfig)
}
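
Note on usage (not part of this commit): the package masks every struct field that is not tagged export:"true" (strings become "xxxx", maps, slices and pointers are zeroed), recurses into tagged fields, and then masks emails and URLs in the resulting JSON. A minimal sketch, with a hypothetical Example type:

	package main

	import (
		"fmt"

		"github.com/containous/traefik/pkg/anonymize"
	)

	// Example is a hypothetical configuration: Name is tagged for export,
	// Secret is not, so reset() replaces it with the short mask.
	type Example struct {
		Name   string `export:"true"`
		Secret string
	}

	func main() {
		out, err := anonymize.Do(&Example{Name: "prod", Secret: "hunter2"}, true)
		if err != nil {
			panic(err)
		}
		fmt.Println(out) // Name is kept, Secret is printed as "xxxx".
	}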

pkg/anonymize/anonymize_config_test.go (Normal file, 329 lines added)
@@ -0,0 +1,329 @@
package anonymize

import (
	"os"
	"testing"
	"time"

	"github.com/containous/flaeg/parse"
	"github.com/containous/traefik/pkg/config/static"
	"github.com/containous/traefik/pkg/ping"
	"github.com/containous/traefik/pkg/provider"
	"github.com/containous/traefik/pkg/provider/acme"
	acmeprovider "github.com/containous/traefik/pkg/provider/acme"
	"github.com/containous/traefik/pkg/provider/docker"
	"github.com/containous/traefik/pkg/provider/file"
	"github.com/containous/traefik/pkg/provider/kubernetes/crd"
	"github.com/containous/traefik/pkg/provider/kubernetes/ingress"
	traefiktls "github.com/containous/traefik/pkg/tls"
	"github.com/containous/traefik/pkg/tracing/datadog"
	"github.com/containous/traefik/pkg/tracing/instana"
	"github.com/containous/traefik/pkg/tracing/jaeger"
	"github.com/containous/traefik/pkg/tracing/zipkin"
	"github.com/containous/traefik/pkg/types"
	assetfs "github.com/elazarl/go-bindata-assetfs"
)

func TestDo_globalConfiguration(t *testing.T) {

	config := &static.Configuration{}

	sendAnonymousUsage := true
	config.Global = &static.Global{
		Debug:              true,
		CheckNewVersion:    true,
		SendAnonymousUsage: &sendAnonymousUsage,
	}

	config.AccessLog = &types.AccessLog{
		FilePath: "AccessLog FilePath",
		Format:   "AccessLog Format",
		Filters: &types.AccessLogFilters{
			StatusCodes:   types.StatusCodes{"200", "500"},
			RetryAttempts: true,
			MinDuration:   10,
		},
		Fields: &types.AccessLogFields{
			DefaultMode: "drop",
			Names: types.FieldNames{
				"RequestHost": "keep",
			},
			Headers: &types.FieldHeaders{
				DefaultMode: "drop",
				Names: types.FieldHeaderNames{
					"Referer": "keep",
				},
			},
		},
		BufferingSize: 4,
	}

	config.Log = &types.TraefikLog{
		LogLevel: "LogLevel",
		FilePath: "/foo/path",
		Format:   "json",
	}

	config.EntryPoints = static.EntryPoints{
		"foo": {
			Address: "foo Address",
			Transport: &static.EntryPointsTransport{
				RespondingTimeouts: &static.RespondingTimeouts{
					ReadTimeout:  parse.Duration(111 * time.Second),
					WriteTimeout: parse.Duration(111 * time.Second),
					IdleTimeout:  parse.Duration(111 * time.Second),
				},
			},
			ProxyProtocol: &static.ProxyProtocol{
				TrustedIPs: []string{"127.0.0.1/32", "192.168.0.1"},
			},
		},
		"fii": {
			Address: "fii Address",
			Transport: &static.EntryPointsTransport{
				RespondingTimeouts: &static.RespondingTimeouts{
					ReadTimeout:  parse.Duration(111 * time.Second),
					WriteTimeout: parse.Duration(111 * time.Second),
					IdleTimeout:  parse.Duration(111 * time.Second),
				},
			},
			ProxyProtocol: &static.ProxyProtocol{
				TrustedIPs: []string{"127.0.0.1/32", "192.168.0.1"},
			},
		},
	}
	config.ACME = &acme.Configuration{
		Email:         "acme Email",
		ACMELogging:   true,
		CAServer:      "CAServer",
		Storage:       "Storage",
		EntryPoint:    "EntryPoint",
		KeyType:       "MyKeyType",
		OnHostRule:    true,
		DNSChallenge:  &acmeprovider.DNSChallenge{Provider: "DNSProvider"},
		HTTPChallenge: &acmeprovider.HTTPChallenge{
			EntryPoint: "MyEntryPoint",
		},
		TLSChallenge: &acmeprovider.TLSChallenge{},
		Domains: []types.Domain{
			{
				Main: "Domains Main",
				SANs: []string{"Domains acme SANs 1", "Domains acme SANs 2", "Domains acme SANs 3"},
			},
		},
	}
	config.Providers = &static.Providers{
		ProvidersThrottleDuration: parse.Duration(111 * time.Second),
	}

	config.ServersTransport = &static.ServersTransport{
		InsecureSkipVerify:  true,
		RootCAs:             traefiktls.FilesOrContents{"RootCAs 1", "RootCAs 2", "RootCAs 3"},
		MaxIdleConnsPerHost: 111,
		ForwardingTimeouts: &static.ForwardingTimeouts{
			DialTimeout:           parse.Duration(111 * time.Second),
			ResponseHeaderTimeout: parse.Duration(111 * time.Second),
		},
	}

	config.API = &static.API{
		EntryPoint: "traefik",
		Dashboard:  true,
		Statistics: &types.Statistics{
			RecentErrors: 111,
		},
		DashboardAssets: &assetfs.AssetFS{
			Asset: func(path string) ([]byte, error) {
				return nil, nil
			},
			AssetDir: func(path string) ([]string, error) {
				return nil, nil
			},
			AssetInfo: func(path string) (os.FileInfo, error) {
				return nil, nil
			},
			Prefix: "fii",
		},
		Middlewares: []string{"first", "second"},
	}

	config.Providers.File = &file.Provider{
		BaseProvider: provider.BaseProvider{
			Watch:    true,
			Filename: "file Filename",
			Constraints: types.Constraints{
				{
					Key:       "file Constraints Key 1",
					Regex:     "file Constraints Regex 2",
					MustMatch: true,
				},
				{
					Key:       "file Constraints Key 1",
					Regex:     "file Constraints Regex 2",
					MustMatch: true,
				},
			},
			Trace:                     true,
			DebugLogGeneratedTemplate: true,
		},
		Directory: "file Directory",
	}

	config.Providers.Docker = &docker.Provider{
		BaseProvider: provider.BaseProvider{
			Watch:                     true,
			Filename:                  "myfilename",
			Constraints:               nil,
			Trace:                     true,
			DebugLogGeneratedTemplate: true,
		},
		Endpoint:    "MyEndPoint",
		DefaultRule: "PathPrefix(`/`)",
		TLS: &types.ClientTLS{
			CA:                 "myCa",
			CAOptional:         true,
			Cert:               "mycert.pem",
			Key:                "mycert.key",
			InsecureSkipVerify: true,
		},
		ExposedByDefault:        true,
		UseBindPortIP:           true,
		SwarmMode:               true,
		Network:                 "MyNetwork",
		SwarmModeRefreshSeconds: 42,
	}

	config.Providers.Kubernetes = &ingress.Provider{
		BaseProvider: provider.BaseProvider{
			Watch:    true,
			Filename: "myFileName",
			Constraints: types.Constraints{
				{
					Key:       "k8s Constraints Key 1",
					Regex:     "k8s Constraints Regex 2",
					MustMatch: true,
				},
				{
					Key:       "k8s Constraints Key 1",
					Regex:     "k8s Constraints Regex 2",
					MustMatch: true,
				},
			},
			Trace:                     true,
			DebugLogGeneratedTemplate: true,
		},
		Endpoint:               "MyEndpoint",
		Token:                  "MyToken",
		CertAuthFilePath:       "MyCertAuthPath",
		DisablePassHostHeaders: true,
		EnablePassTLSCert:      true,
		Namespaces:             []string{"a", "b"},
		LabelSelector:          "myLabelSelector",
		IngressClass:           "MyIngressClass",
	}

	config.Providers.KubernetesCRD = &crd.Provider{
		BaseProvider: provider.BaseProvider{
			Watch:    true,
			Filename: "myFileName",
			Constraints: types.Constraints{
				{
					Key:       "k8s Constraints Key 1",
					Regex:     "k8s Constraints Regex 2",
					MustMatch: true,
				},
				{
					Key:       "k8s Constraints Key 1",
					Regex:     "k8s Constraints Regex 2",
					MustMatch: true,
				},
			},
			Trace:                     true,
			DebugLogGeneratedTemplate: true,
		},
		Endpoint:               "MyEndpoint",
		Token:                  "MyToken",
		CertAuthFilePath:       "MyCertAuthPath",
		DisablePassHostHeaders: true,
		EnablePassTLSCert:      true,
		Namespaces:             []string{"a", "b"},
		LabelSelector:          "myLabelSelector",
		IngressClass:           "MyIngressClass",
	}

	// FIXME Test the other providers once they are migrated

	config.Metrics = &types.Metrics{
		Prometheus: &types.Prometheus{
			Buckets:     types.Buckets{0.1, 0.3, 1.2, 5},
			EntryPoint:  "MyEntryPoint",
			Middlewares: []string{"m1", "m2"},
		},
		Datadog: &types.Datadog{
			Address:      "localhost:8181",
			PushInterval: "12",
		},
		StatsD: &types.Statsd{
			Address:      "localhost:8182",
			PushInterval: "42",
		},
		InfluxDB: &types.InfluxDB{
			Address:         "localhost:8183",
			Protocol:        "http",
			PushInterval:    "22",
			Database:        "myDB",
			RetentionPolicy: "12",
			Username:        "a",
			Password:        "aaaa",
		},
	}

	config.Ping = &ping.Handler{
		EntryPoint:  "MyEntryPoint",
		Middlewares: []string{"m1", "m2", "m3"},
	}

	config.Tracing = &static.Tracing{
		Backend:       "myBackend",
		ServiceName:   "myServiceName",
		SpanNameLimit: 3,
		Jaeger: &jaeger.Config{
			SamplingServerURL:      "aaa",
			SamplingType:           "bbb",
			SamplingParam:          43,
			LocalAgentHostPort:     "ccc",
			Gen128Bit:              true,
			Propagation:            "ddd",
			TraceContextHeaderName: "eee",
		},
		Zipkin: &zipkin.Config{
			HTTPEndpoint: "fff",
			SameSpan:     true,
			ID128Bit:     true,
			Debug:        true,
			SampleRate:   53,
		},
		DataDog: &datadog.Config{
			LocalAgentHostPort: "ggg",
			GlobalTag:          "eee",
			Debug:              true,
			PrioritySampling:   true,
		},
		Instana: &instana.Config{
			LocalAgentHost: "fff",
			LocalAgentPort: 32,
			LogLevel:       "ggg",
		},
	}

	config.HostResolver = &types.HostResolverConfig{
		CnameFlattening: true,
		ResolvConfig:    "aaa",
		ResolvDepth:     3,
	}

	cleanJSON, err := Do(config, true)
	if err != nil {
		t.Fatal(err, cleanJSON)
	}
}

pkg/anonymize/anonymize_doOnJSON_test.go (Normal file, 225 lines added)
@@ -0,0 +1,225 @@
package anonymize

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func Test_doOnJSON(t *testing.T) {
	baseConfiguration := `
{
 "GraceTimeOut": 10000000000,
 "Debug": false,
 "CheckNewVersion": true,
 "AccessLogsFile": "",
 "TraefikLogsFile": "",
 "LogLevel": "ERROR",
 "EntryPoints": {
  "http": {
   "Network": "",
   "Address": ":80",
   "TLS": null,
   "Auth": null,
   "Compress": false
  },
  "https": {
   "Address": ":443",
   "TLS": {
    "MinVersion": "",
    "CipherSuites": null,
    "Certificates": null,
    "ClientCAFiles": null
   },
   "Auth": null,
   "Compress": false
  }
 },
 "Cluster": null,
 "Constraints": [],
 "ACME": {
  "Email": "foo@bar.com",
  "Domains": [
   {
    "Main": "foo@bar.com",
    "SANs": null
   },
   {
    "Main": "foo@bar.com",
    "SANs": null
   }
  ],
  "Storage": "",
  "StorageFile": "/acme/acme.json",
  "OnDemand": true,
  "OnHostRule": true,
  "CAServer": "",
  "EntryPoint": "https",
  "DNSProvider": "",
  "DelayDontCheckDNS": 0,
  "ACMELogging": false,
  "TLSOptions": null
 },
 "DefaultEntryPoints": [
  "https",
  "http"
 ],
 "ProvidersThrottleDuration": 2000000000,
 "MaxIdleConnsPerHost": 200,
 "IdleTimeout": 180000000000,
 "InsecureSkipVerify": false,
 "Retry": null,
 "HealthCheck": {
  "Interval": 30000000000
 },
 "Docker": null,
 "File": null,
 "Web": null,
 "Marathon": null,
 "Consul": null,
 "ConsulCatalog": null,
 "Etcd": null,
 "Zookeeper": null,
 "Boltdb": null,
 "Kubernetes": null,
 "Mesos": null,
 "Eureka": null,
 "ECS": null,
 "Rancher": null,
 "DynamoDB": null,
 "ConfigFile": "/etc/traefik/traefik.toml"
}
`
	expectedConfiguration := `
{
 "GraceTimeOut": 10000000000,
 "Debug": false,
 "CheckNewVersion": true,
 "AccessLogsFile": "",
 "TraefikLogsFile": "",
 "LogLevel": "ERROR",
 "EntryPoints": {
  "http": {
   "Network": "",
   "Address": ":80",
   "TLS": null,
   "Auth": null,
   "Compress": false
  },
  "https": {
   "Address": ":443",
   "TLS": {
    "MinVersion": "",
    "CipherSuites": null,
    "Certificates": null,
    "ClientCAFiles": null
   },
   "Auth": null,
   "Compress": false
  }
 },
 "Cluster": null,
 "Constraints": [],
 "ACME": {
  "Email": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
  "Domains": [
   {
    "Main": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
    "SANs": null
   },
   {
    "Main": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
    "SANs": null
   }
  ],
  "Storage": "",
  "StorageFile": "/acme/acme.json",
  "OnDemand": true,
  "OnHostRule": true,
  "CAServer": "",
  "EntryPoint": "https",
  "DNSProvider": "",
  "DelayDontCheckDNS": 0,
  "ACMELogging": false,
  "TLSOptions": null
 },
 "DefaultEntryPoints": [
  "https",
  "http"
 ],
 "ProvidersThrottleDuration": 2000000000,
 "MaxIdleConnsPerHost": 200,
 "IdleTimeout": 180000000000,
 "InsecureSkipVerify": false,
 "Retry": null,
 "HealthCheck": {
  "Interval": 30000000000
 },
 "Docker": null,
 "File": null,
 "Web": null,
 "Marathon": null,
 "Consul": null,
 "ConsulCatalog": null,
 "Etcd": null,
 "Zookeeper": null,
 "Boltdb": null,
 "Kubernetes": null,
 "Mesos": null,
 "Eureka": null,
 "ECS": null,
 "Rancher": null,
 "DynamoDB": null,
 "ConfigFile": "/etc/traefik/traefik.toml"
}
`
	anomConfiguration := doOnJSON(baseConfiguration)

	if anomConfiguration != expectedConfiguration {
		t.Errorf("Got %s, want %s.", anomConfiguration, expectedConfiguration)
	}
}

func Test_doOnJSON_simple(t *testing.T) {
	testCases := []struct {
		name           string
		input          string
		expectedOutput string
	}{
		{
			name: "email",
			input: `{
				"email1": "goo@example.com",
				"email2": "foo.bargoo@example.com",
				"email3": "foo.bargoo@example.com.us"
			}`,
			expectedOutput: `{
				"email1": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
				"email2": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
				"email3": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
			}`,
		},
		{
			name: "url",
			input: `{
				"URL": "foo domain.com foo",
				"URL": "foo sub.domain.com foo",
				"URL": "foo sub.sub.domain.com foo",
				"URL": "foo sub.sub.sub.domain.com.us foo"
			}`,
			expectedOutput: `{
				"URL": "foo xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx foo",
				"URL": "foo xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx foo",
				"URL": "foo xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx foo",
				"URL": "foo xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx foo"
			}`,
		},
	}

	for _, test := range testCases {
		t.Run(test.name, func(t *testing.T) {
			output := doOnJSON(test.input)
			assert.Equal(t, test.expectedOutput, output)
		})
	}
}

pkg/anonymize/anonymize_doOnStruct_test.go (Normal file, 176 lines added)
@@ -0,0 +1,176 @@
package anonymize

import (
	"reflect"
	"testing"

	"github.com/stretchr/testify/assert"
)

type Courgette struct {
	Ji string
	Ho string
}
type Tomate struct {
	Ji string
	Ho string
}

type Carotte struct {
	Name        string
	Value       int
	Courgette   Courgette
	ECourgette  Courgette `export:"true"`
	Pourgette   *Courgette
	EPourgette  *Courgette `export:"true"`
	Aubergine   map[string]string
	EAubergine  map[string]string `export:"true"`
	SAubergine  map[string]Tomate
	ESAubergine map[string]Tomate `export:"true"`
	PSAubergine map[string]*Tomate
	EPAubergine map[string]*Tomate `export:"true"`
}

func Test_doOnStruct(t *testing.T) {
	testCase := []struct {
		name     string
		base     *Carotte
		expected *Carotte
		hasError bool
	}{
		{
			name: "primitive",
			base: &Carotte{
				Name:  "koko",
				Value: 666,
			},
			expected: &Carotte{
				Name: "xxxx",
			},
		},
		{
			name: "struct",
			base: &Carotte{
				Name: "koko",
				Courgette: Courgette{
					Ji: "huu",
				},
			},
			expected: &Carotte{
				Name: "xxxx",
			},
		},
		{
			name: "pointer",
			base: &Carotte{
				Name: "koko",
				Pourgette: &Courgette{
					Ji: "hoo",
				},
			},
			expected: &Carotte{
				Name:      "xxxx",
				Pourgette: nil,
			},
		},
		{
			name: "export struct",
			base: &Carotte{
				Name: "koko",
				ECourgette: Courgette{
					Ji: "huu",
				},
			},
			expected: &Carotte{
				Name: "xxxx",
				ECourgette: Courgette{
					Ji: "xxxx",
				},
			},
		},
		{
			name: "export pointer struct",
			base: &Carotte{
				Name: "koko",
				ECourgette: Courgette{
					Ji: "huu",
				},
			},
			expected: &Carotte{
				Name: "xxxx",
				ECourgette: Courgette{
					Ji: "xxxx",
				},
			},
		},
		{
			name: "export map string/string",
			base: &Carotte{
				Name: "koko",
				EAubergine: map[string]string{
					"foo": "bar",
				},
			},
			expected: &Carotte{
				Name: "xxxx",
				EAubergine: map[string]string{
					"foo": "bar",
				},
			},
		},
		{
			name: "export map string/pointer",
			base: &Carotte{
				Name: "koko",
				EPAubergine: map[string]*Tomate{
					"foo": {
						Ji: "fdskljf",
					},
				},
			},
			expected: &Carotte{
				Name: "xxxx",
				EPAubergine: map[string]*Tomate{
					"foo": {
						Ji: "xxxx",
					},
				},
			},
		},
		{
			name: "export map string/struct (UNSAFE)",
			base: &Carotte{
				Name: "koko",
				ESAubergine: map[string]Tomate{
					"foo": {
						Ji: "JiJiJi",
					},
				},
			},
			expected: &Carotte{
				Name: "xxxx",
				ESAubergine: map[string]Tomate{
					"foo": {
						Ji: "JiJiJi",
					},
				},
			},
			hasError: true,
		},
	}

	for _, test := range testCase {
		t.Run(test.name, func(t *testing.T) {
			val := reflect.ValueOf(test.base).Elem()
			err := doOnStruct(val)
			if !test.hasError && err != nil {
				t.Fatal(err)
			}
			if test.hasError && err == nil {
				t.Fatal("Got no error but want an error.")
			}

			assert.EqualValues(t, test.expected, test.base)
		})
	}
}

pkg/api/dashboard.go (Normal file, 39 lines added)
@@ -0,0 +1,39 @@
package api

import (
	"net/http"

	"github.com/containous/mux"
	"github.com/containous/traefik/pkg/log"
	assetfs "github.com/elazarl/go-bindata-assetfs"
)

// DashboardHandler expose dashboard routes
type DashboardHandler struct {
	Assets *assetfs.AssetFS
}

// Append add dashboard routes on a router
func (g DashboardHandler) Append(router *mux.Router) {
	if g.Assets == nil {
		log.WithoutContext().Error("No assets for dashboard")
		return
	}

	// Expose dashboard
	router.Methods(http.MethodGet).
		Path("/").
		HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
			http.Redirect(response, request, request.Header.Get("X-Forwarded-Prefix")+"/dashboard/", http.StatusFound)
		})

	router.Methods(http.MethodGet).
		Path("/dashboard/status").
		HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
			http.Redirect(response, request, "/dashboard/", http.StatusFound)
		})

	router.Methods(http.MethodGet).
		PathPrefix("/dashboard/").
		Handler(http.StripPrefix("/dashboard/", http.FileServer(g.Assets)))
}

pkg/api/debug.go (Normal file, 49 lines added)
@@ -0,0 +1,49 @@
package api

import (
	"expvar"
	"fmt"
	"net/http"
	"net/http/pprof"
	"runtime"

	"github.com/containous/mux"
)

func init() {
	// FIXME Goroutines2 -> Goroutines
	expvar.Publish("Goroutines2", expvar.Func(goroutines))
}

func goroutines() interface{} {
	return runtime.NumGoroutine()
}

// DebugHandler expose debug routes
type DebugHandler struct{}

// Append add debug routes on a router
func (g DebugHandler) Append(router *mux.Router) {
	router.Methods(http.MethodGet).Path("/debug/vars").
		HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			w.Header().Set("Content-Type", "application/json; charset=utf-8")
			fmt.Fprint(w, "{\n")
			first := true
			expvar.Do(func(kv expvar.KeyValue) {
				if !first {
					fmt.Fprint(w, ",\n")
				}
				first = false
				fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
			})
			fmt.Fprint(w, "\n}\n")
		})

	runtime.SetBlockProfileRate(1)
	runtime.SetMutexProfileFraction(5)
	router.Methods(http.MethodGet).PathPrefix("/debug/pprof/cmdline").HandlerFunc(pprof.Cmdline)
	router.Methods(http.MethodGet).PathPrefix("/debug/pprof/profile").HandlerFunc(pprof.Profile)
	router.Methods(http.MethodGet).PathPrefix("/debug/pprof/symbol").HandlerFunc(pprof.Symbol)
	router.Methods(http.MethodGet).PathPrefix("/debug/pprof/trace").HandlerFunc(pprof.Trace)
	router.Methods(http.MethodGet).PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index)
}
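
For context, the Append helpers above are designed to be mounted on a containous/mux router and served like any other http.Handler. A minimal sketch of that wiring (assumed usage, not part of this commit):

	package main

	import (
		"net/http"

		"github.com/containous/mux"
		"github.com/containous/traefik/pkg/api"
	)

	func main() {
		router := mux.NewRouter()
		api.DebugHandler{}.Append(router) // registers /debug/vars and the /debug/pprof/* routes
		_ = http.ListenAndServe(":8080", router) // listen address is arbitrary here
	}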

pkg/api/handler.go (Normal file, 355 lines added)
@@ -0,0 +1,355 @@
package api

import (
	"io"
	"net/http"

	"github.com/containous/mux"
	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/log"
	"github.com/containous/traefik/pkg/safe"
	"github.com/containous/traefik/pkg/types"
	"github.com/containous/traefik/pkg/version"
	assetfs "github.com/elazarl/go-bindata-assetfs"
	thoasstats "github.com/thoas/stats"
	"github.com/unrolled/render"
)

// ResourceIdentifier a resource identifier
type ResourceIdentifier struct {
	ID   string `json:"id"`
	Path string `json:"path"`
}

// ProviderRepresentation a provider with resource identifiers
type ProviderRepresentation struct {
	Routers     []ResourceIdentifier `json:"routers,omitempty"`
	Middlewares []ResourceIdentifier `json:"middlewares,omitempty"`
	Services    []ResourceIdentifier `json:"services,omitempty"`
}

// RouterRepresentation extended version of a router configuration with an ID
type RouterRepresentation struct {
	*config.Router
	ID string `json:"id"`
}

// MiddlewareRepresentation extended version of a middleware configuration with an ID
type MiddlewareRepresentation struct {
	*config.Middleware
	ID string `json:"id"`
}

// ServiceRepresentation extended version of a service configuration with an ID
type ServiceRepresentation struct {
	*config.Service
	ID string `json:"id"`
}

// Handler expose api routes
type Handler struct {
	EntryPoint            string
	Dashboard             bool
	Debug                 bool
	CurrentConfigurations *safe.Safe
	Statistics            *types.Statistics
	Stats                 *thoasstats.Stats
	// StatsRecorder *middlewares.StatsRecorder // FIXME stats
	DashboardAssets *assetfs.AssetFS
}

var templateRenderer jsonRenderer = render.New(render.Options{Directory: "nowhere"})

type jsonRenderer interface {
	JSON(w io.Writer, status int, v interface{}) error
}

// Append add api routes on a router
func (h Handler) Append(router *mux.Router) {
	if h.Debug {
		DebugHandler{}.Append(router)
	}

	router.Methods(http.MethodGet).Path("/api/rawdata").HandlerFunc(h.getRawData)
	router.Methods(http.MethodGet).Path("/api/providers").HandlerFunc(h.getProvidersHandler)
	router.Methods(http.MethodGet).Path("/api/providers/{provider}").HandlerFunc(h.getProviderHandler)
	router.Methods(http.MethodGet).Path("/api/providers/{provider}/routers").HandlerFunc(h.getRoutersHandler)
	router.Methods(http.MethodGet).Path("/api/providers/{provider}/routers/{router}").HandlerFunc(h.getRouterHandler)
	router.Methods(http.MethodGet).Path("/api/providers/{provider}/middlewares").HandlerFunc(h.getMiddlewaresHandler)
	router.Methods(http.MethodGet).Path("/api/providers/{provider}/middlewares/{middleware}").HandlerFunc(h.getMiddlewareHandler)
	router.Methods(http.MethodGet).Path("/api/providers/{provider}/services").HandlerFunc(h.getServicesHandler)
	router.Methods(http.MethodGet).Path("/api/providers/{provider}/services/{service}").HandlerFunc(h.getServiceHandler)

	// FIXME stats
	// health route
	//router.Methods(http.MethodGet).Path("/health").HandlerFunc(p.getHealthHandler)

	version.Handler{}.Append(router)

	if h.Dashboard {
		DashboardHandler{Assets: h.DashboardAssets}.Append(router)
	}
}

func (h Handler) getRawData(rw http.ResponseWriter, request *http.Request) {
	if h.CurrentConfigurations != nil {
		currentConfigurations, ok := h.CurrentConfigurations.Get().(config.Configurations)
		if !ok {
			rw.WriteHeader(http.StatusOK)
			return
		}
		err := templateRenderer.JSON(rw, http.StatusOK, currentConfigurations)
		if err != nil {
			log.FromContext(request.Context()).Error(err)
			http.Error(rw, err.Error(), http.StatusInternalServerError)
		}
	}
}

func (h Handler) getProvidersHandler(rw http.ResponseWriter, request *http.Request) {
	// FIXME handle currentConfiguration
	if h.CurrentConfigurations != nil {
		currentConfigurations, ok := h.CurrentConfigurations.Get().(config.Configurations)
		if !ok {
			rw.WriteHeader(http.StatusOK)
			return
		}

		var providers []ResourceIdentifier
		for name := range currentConfigurations {
			providers = append(providers, ResourceIdentifier{
				ID:   name,
				Path: "/api/providers/" + name,
			})
		}

		err := templateRenderer.JSON(rw, http.StatusOK, providers)
		if err != nil {
			log.FromContext(request.Context()).Error(err)
			http.Error(rw, err.Error(), http.StatusInternalServerError)
		}
	}
}

func (h Handler) getProviderHandler(rw http.ResponseWriter, request *http.Request) {
	providerID := mux.Vars(request)["provider"]

	currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)

	provider, ok := currentConfigurations[providerID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	if provider.HTTP == nil {
		http.NotFound(rw, request)
		return
	}

	var routers []ResourceIdentifier
	for name := range provider.HTTP.Routers {
		routers = append(routers, ResourceIdentifier{
			ID:   name,
			Path: "/api/providers/" + providerID + "/routers",
		})
	}

	var services []ResourceIdentifier
	for name := range provider.HTTP.Services {
		services = append(services, ResourceIdentifier{
			ID:   name,
			Path: "/api/providers/" + providerID + "/services",
		})
	}

	var middlewares []ResourceIdentifier
	for name := range provider.HTTP.Middlewares {
		middlewares = append(middlewares, ResourceIdentifier{
			ID:   name,
			Path: "/api/providers/" + providerID + "/middlewares",
		})
	}

	providers := ProviderRepresentation{Routers: routers, Middlewares: middlewares, Services: services}

	err := templateRenderer.JSON(rw, http.StatusOK, providers)
	if err != nil {
		log.FromContext(request.Context()).Error(err)
		http.Error(rw, err.Error(), http.StatusInternalServerError)
	}
}

func (h Handler) getRoutersHandler(rw http.ResponseWriter, request *http.Request) {
	providerID := mux.Vars(request)["provider"]

	currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)

	provider, ok := currentConfigurations[providerID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	if provider.HTTP == nil {
		http.NotFound(rw, request)
		return
	}

	var routers []RouterRepresentation
	for name, router := range provider.HTTP.Routers {
		routers = append(routers, RouterRepresentation{Router: router, ID: name})
	}

	err := templateRenderer.JSON(rw, http.StatusOK, routers)
	if err != nil {
		log.FromContext(request.Context()).Error(err)
		http.Error(rw, err.Error(), http.StatusInternalServerError)
	}
}

func (h Handler) getRouterHandler(rw http.ResponseWriter, request *http.Request) {
	providerID := mux.Vars(request)["provider"]
	routerID := mux.Vars(request)["router"]

	currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)

	provider, ok := currentConfigurations[providerID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	if provider.HTTP == nil {
		http.NotFound(rw, request)
		return
	}

	router, ok := provider.HTTP.Routers[routerID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	err := templateRenderer.JSON(rw, http.StatusOK, router)
	if err != nil {
		log.FromContext(request.Context()).Error(err)
		http.Error(rw, err.Error(), http.StatusInternalServerError)
	}
}

func (h Handler) getMiddlewaresHandler(rw http.ResponseWriter, request *http.Request) {
	providerID := mux.Vars(request)["provider"]

	currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)

	provider, ok := currentConfigurations[providerID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	if provider.HTTP == nil {
		http.NotFound(rw, request)
		return
	}

	var middlewares []MiddlewareRepresentation
	for name, middleware := range provider.HTTP.Middlewares {
		middlewares = append(middlewares, MiddlewareRepresentation{Middleware: middleware, ID: name})
	}

	err := templateRenderer.JSON(rw, http.StatusOK, middlewares)
	if err != nil {
		log.FromContext(request.Context()).Error(err)
		http.Error(rw, err.Error(), http.StatusInternalServerError)
	}
}

func (h Handler) getMiddlewareHandler(rw http.ResponseWriter, request *http.Request) {
	providerID := mux.Vars(request)["provider"]
	middlewareID := mux.Vars(request)["middleware"]

	currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)

	provider, ok := currentConfigurations[providerID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	if provider.HTTP == nil {
		http.NotFound(rw, request)
		return
	}

	middleware, ok := provider.HTTP.Middlewares[middlewareID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	err := templateRenderer.JSON(rw, http.StatusOK, middleware)
	if err != nil {
		log.FromContext(request.Context()).Error(err)
		http.Error(rw, err.Error(), http.StatusInternalServerError)
	}
}

func (h Handler) getServicesHandler(rw http.ResponseWriter, request *http.Request) {
	providerID := mux.Vars(request)["provider"]

	currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)

	provider, ok := currentConfigurations[providerID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	if provider.HTTP == nil {
		http.NotFound(rw, request)
		return
	}

	var services []ServiceRepresentation
	for name, service := range provider.HTTP.Services {
		services = append(services, ServiceRepresentation{Service: service, ID: name})
	}

	err := templateRenderer.JSON(rw, http.StatusOK, services)
	if err != nil {
		log.FromContext(request.Context()).Error(err)
		http.Error(rw, err.Error(), http.StatusInternalServerError)
	}
}

func (h Handler) getServiceHandler(rw http.ResponseWriter, request *http.Request) {
	providerID := mux.Vars(request)["provider"]
	serviceID := mux.Vars(request)["service"]

	currentConfigurations := h.CurrentConfigurations.Get().(config.Configurations)

	provider, ok := currentConfigurations[providerID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	if provider.HTTP == nil {
		http.NotFound(rw, request)
		return
	}

	service, ok := provider.HTTP.Services[serviceID]
	if !ok {
		http.NotFound(rw, request)
		return
	}

	err := templateRenderer.JSON(rw, http.StatusOK, service)
	if err != nil {
		log.FromContext(request.Context()).Error(err)
		http.Error(rw, err.Error(), http.StatusInternalServerError)
	}
}

pkg/api/handler_test.go (Normal file, 226 lines added)
@@ -0,0 +1,226 @@
package api

import (
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/containous/mux"
	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/safe"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestHandler_Configuration(t *testing.T) {
	type expected struct {
		statusCode int
		body       string
	}

	testCases := []struct {
		desc          string
		path          string
		configuration config.Configurations
		expected      expected
	}{
		{
			desc: "Get all the providers",
			path: "/api/providers",
			configuration: config.Configurations{
				"foo": {
					HTTP: &config.HTTPConfiguration{
						Routers: map[string]*config.Router{
							"bar": {EntryPoints: []string{"foo", "bar"}},
						},
					},
				},
			},
			expected: expected{statusCode: http.StatusOK, body: `[{"id":"foo","path":"/api/providers/foo"}]`},
		},
		{
			desc: "Get a provider",
			path: "/api/providers/foo",
			configuration: config.Configurations{
				"foo": {
					HTTP: &config.HTTPConfiguration{
						Routers: map[string]*config.Router{
							"bar": {EntryPoints: []string{"foo", "bar"}},
						},
						Middlewares: map[string]*config.Middleware{
							"bar": {
								AddPrefix: &config.AddPrefix{Prefix: "bar"},
							},
						},
						Services: map[string]*config.Service{
							"foo": {
								LoadBalancer: &config.LoadBalancerService{
									Method: "wrr",
								},
							},
						},
					},
				},
			},
			expected: expected{statusCode: http.StatusOK, body: `{"routers":[{"id":"bar","path":"/api/providers/foo/routers"}],"middlewares":[{"id":"bar","path":"/api/providers/foo/middlewares"}],"services":[{"id":"foo","path":"/api/providers/foo/services"}]}`},
		},
		{
			desc:          "Provider not found",
			path:          "/api/providers/foo",
			configuration: config.Configurations{},
			expected:      expected{statusCode: http.StatusNotFound, body: "404 page not found\n"},
		},
		{
			desc: "Get all routers",
			path: "/api/providers/foo/routers",
			configuration: config.Configurations{
				"foo": {
					HTTP: &config.HTTPConfiguration{
						Routers: map[string]*config.Router{
							"bar": {EntryPoints: []string{"foo", "bar"}},
						},
					},
				},
			},
			expected: expected{statusCode: http.StatusOK, body: `[{"entryPoints":["foo","bar"],"id":"bar"}]`},
		},
		{
			desc: "Get a router",
			path: "/api/providers/foo/routers/bar",
			configuration: config.Configurations{
				"foo": {
					HTTP: &config.HTTPConfiguration{
						Routers: map[string]*config.Router{
							"bar": {EntryPoints: []string{"foo", "bar"}},
						},
					},
				},
			},
			expected: expected{statusCode: http.StatusOK, body: `{"entryPoints":["foo","bar"]}`},
		},
		{
			desc: "Router not found",
			path: "/api/providers/foo/routers/bar",
			configuration: config.Configurations{
				"foo": {},
			},
			expected: expected{statusCode: http.StatusNotFound, body: "404 page not found\n"},
		},
		{
			desc: "Get all services",
			path: "/api/providers/foo/services",
			configuration: config.Configurations{
				"foo": {
					HTTP: &config.HTTPConfiguration{
						Services: map[string]*config.Service{
							"foo": {
								LoadBalancer: &config.LoadBalancerService{
									Method: "wrr",
								},
							},
						},
					},
				},
			},
			expected: expected{statusCode: http.StatusOK, body: `[{"loadbalancer":{"method":"wrr","passHostHeader":false},"id":"foo"}]`},
		},
		{
			desc: "Get a service",
			path: "/api/providers/foo/services/foo",
			configuration: config.Configurations{
				"foo": {
					HTTP: &config.HTTPConfiguration{
						Services: map[string]*config.Service{
							"foo": {
								LoadBalancer: &config.LoadBalancerService{
									Method: "wrr",
								},
							},
						},
					},
				},
			},
			expected: expected{statusCode: http.StatusOK, body: `{"loadbalancer":{"method":"wrr","passHostHeader":false}}`},
		},
		{
			desc: "Service not found",
			path: "/api/providers/foo/services/bar",
			configuration: config.Configurations{
				"foo": {},
			},
			expected: expected{statusCode: http.StatusNotFound, body: "404 page not found\n"},
		},
		{
			desc: "Get all middlewares",
			path: "/api/providers/foo/middlewares",
			configuration: config.Configurations{
				"foo": {
					HTTP: &config.HTTPConfiguration{
						Middlewares: map[string]*config.Middleware{
							"bar": {
								AddPrefix: &config.AddPrefix{Prefix: "bar"},
							},
						},
					},
				},
			},
			expected: expected{statusCode: http.StatusOK, body: `[{"addPrefix":{"prefix":"bar"},"id":"bar"}]`},
		},
		{
			desc: "Get a middleware",
			path: "/api/providers/foo/middlewares/bar",
			configuration: config.Configurations{
				"foo": {
					HTTP: &config.HTTPConfiguration{
						Middlewares: map[string]*config.Middleware{
							"bar": {
								AddPrefix: &config.AddPrefix{Prefix: "bar"},
							},
						},
					},
				},
			},
			expected: expected{statusCode: http.StatusOK, body: `{"addPrefix":{"prefix":"bar"}}`},
		},
		{
			desc: "Middleware not found",
			path: "/api/providers/foo/middlewares/bar",
			configuration: config.Configurations{
				"foo": {},
			},
			expected: expected{statusCode: http.StatusNotFound, body: "404 page not found\n"},
		},
	}

	for _, test := range testCases {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			currentConfiguration := &safe.Safe{}
			currentConfiguration.Set(test.configuration)

			handler := Handler{
				CurrentConfigurations: currentConfiguration,
			}

			router := mux.NewRouter()
			handler.Append(router)

			server := httptest.NewServer(router)

			resp, err := http.DefaultClient.Get(server.URL + test.path)
			require.NoError(t, err)

			assert.Equal(t, test.expected.statusCode, resp.StatusCode)

			content, err := ioutil.ReadAll(resp.Body)
			require.NoError(t, err)
			err = resp.Body.Close()
			require.NoError(t, err)

			assert.Equal(t, test.expected.body, string(content))
		})
	}
}

pkg/collector/collector.go (Normal file, 80 lines added)
@@ -0,0 +1,80 @@
package collector

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"net"
	"net/http"
	"strconv"
	"time"

	"github.com/containous/traefik/old/configuration"
	"github.com/containous/traefik/pkg/anonymize"
	"github.com/containous/traefik/pkg/config/static"
	"github.com/containous/traefik/pkg/log"
	"github.com/containous/traefik/pkg/version"
	"github.com/mitchellh/hashstructure"
)

// collectorURL URL where the stats are send
const collectorURL = "https://collect.traefik.io/9vxmmkcdmalbdi635d4jgc5p5rx0h7h8"

// Collected data
type data struct {
	Version       string
	Codename      string
	BuildDate     string
	Configuration string
	Hash          string
}

// Collect anonymous data.
func Collect(staticConfiguration *static.Configuration) error {
	anonConfig, err := anonymize.Do(staticConfiguration, false)
	if err != nil {
		return err
	}

	log.Infof("Anonymous stats sent to %s: %s", collectorURL, anonConfig)

	hashConf, err := hashstructure.Hash(staticConfiguration, nil)
	if err != nil {
		return err
	}

	data := &data{
		Version:       version.Version,
		Codename:      version.Codename,
		BuildDate:     version.BuildDate,
		Hash:          strconv.FormatUint(hashConf, 10),
		Configuration: base64.StdEncoding.EncodeToString([]byte(anonConfig)),
	}

	buf := new(bytes.Buffer)
	err = json.NewEncoder(buf).Encode(data)
	if err != nil {
		return err
	}

	_, err = makeHTTPClient().Post(collectorURL, "application/json; charset=utf-8", buf)
	return err
}

func makeHTTPClient() *http.Client {
	dialer := &net.Dialer{
		Timeout:   configuration.DefaultDialTimeout,
		KeepAlive: 30 * time.Second,
		DualStack: true,
	}

	transport := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		DialContext:           dialer.DialContext,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}

	return &http.Client{Transport: transport}
}

pkg/config/dyn_config.go (Normal file, 230 lines added)
@@ -0,0 +1,230 @@
package config

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"os"
	"reflect"

	traefiktls "github.com/containous/traefik/pkg/tls"
)

// Router holds the router configuration.
type Router struct {
	EntryPoints []string         `json:"entryPoints"`
	Middlewares []string         `json:"middlewares,omitempty" toml:",omitempty"`
	Service     string           `json:"service,omitempty" toml:",omitempty"`
	Rule        string           `json:"rule,omitempty" toml:",omitempty"`
	Priority    int              `json:"priority,omitempty" toml:"priority,omitzero"`
	TLS         *RouterTLSConfig `json:"tls,omitempty" toml:"tls,omitzero" label:"allowEmpty"`
}

// RouterTLSConfig holds the TLS configuration for a router
type RouterTLSConfig struct{}

// TCPRouter holds the router configuration.
type TCPRouter struct {
	EntryPoints []string            `json:"entryPoints"`
	Service     string              `json:"service,omitempty" toml:",omitempty"`
	Rule        string              `json:"rule,omitempty" toml:",omitempty"`
	TLS         *RouterTCPTLSConfig `json:"tls,omitempty" toml:"tls,omitzero" label:"allowEmpty"`
}

// RouterTCPTLSConfig holds the TLS configuration for a router
type RouterTCPTLSConfig struct {
	Passthrough bool `json:"passthrough" toml:"passthrough,omitzero"`
}

// LoadBalancerService holds the LoadBalancerService configuration.
type LoadBalancerService struct {
	Stickiness         *Stickiness         `json:"stickiness,omitempty" toml:",omitempty" label:"allowEmpty"`
	Servers            []Server            `json:"servers,omitempty" toml:",omitempty" label-slice-as-struct:"server"`
	Method             string              `json:"method,omitempty" toml:",omitempty"`
	HealthCheck        *HealthCheck        `json:"healthCheck,omitempty" toml:",omitempty"`
	PassHostHeader     bool                `json:"passHostHeader" toml:",omitempty"`
	ResponseForwarding *ResponseForwarding `json:"forwardingResponse,omitempty" toml:",omitempty"`
}

// TCPLoadBalancerService holds the LoadBalancerService configuration.
type TCPLoadBalancerService struct {
	Servers []TCPServer `json:"servers,omitempty" toml:",omitempty" label-slice-as-struct:"server"`
	Method  string      `json:"method,omitempty" toml:",omitempty"`
}

// Mergeable tells if the given service is mergeable.
func (l *LoadBalancerService) Mergeable(loadBalancer *LoadBalancerService) bool {
	savedServers := l.Servers
	defer func() {
		l.Servers = savedServers
	}()
	l.Servers = nil

	savedServersLB := loadBalancer.Servers
	defer func() {
		loadBalancer.Servers = savedServersLB
	}()
	loadBalancer.Servers = nil

	return reflect.DeepEqual(l, loadBalancer)
}

// SetDefaults Default values for a LoadBalancerService.
func (l *LoadBalancerService) SetDefaults() {
	l.PassHostHeader = true
	l.Method = "wrr"
}

// ResponseForwarding holds configuration for the forward of the response.
type ResponseForwarding struct {
	FlushInterval string `json:"flushInterval,omitempty" toml:",omitempty"`
}

// Stickiness holds the stickiness configuration.
type Stickiness struct {
	CookieName string `json:"cookieName,omitempty" toml:",omitempty"`
}

// Server holds the server configuration.
type Server struct {
	URL    string `json:"url" label:"-"`
	Scheme string `toml:"-" json:"-"`
	Port   string `toml:"-" json:"-"`
	Weight int    `json:"weight"`
}

// TCPServer holds a TCP Server configuration
type TCPServer struct {
	Address string `json:"address" label:"-"`
	Weight  int    `json:"weight"`
}

// SetDefaults Default values for a Server.
func (s *Server) SetDefaults() {
	s.Weight = 1
	s.Scheme = "http"
}

// HealthCheck holds the HealthCheck configuration.
type HealthCheck struct {
	Scheme string `json:"scheme,omitempty" toml:",omitempty"`
	Path   string `json:"path,omitempty" toml:",omitempty"`
	Port   int    `json:"port,omitempty" toml:",omitempty,omitzero"`
	// FIXME change string to parse.Duration
	Interval string `json:"interval,omitempty" toml:",omitempty"`
	// FIXME change string to parse.Duration
	Timeout  string            `json:"timeout,omitempty" toml:",omitempty"`
	Hostname string            `json:"hostname,omitempty" toml:",omitempty"`
	Headers  map[string]string `json:"headers,omitempty" toml:",omitempty"`
}

// CreateTLSConfig creates a TLS config from ClientTLS structures.
func (clientTLS *ClientTLS) CreateTLSConfig() (*tls.Config, error) {
	if clientTLS == nil {
		return nil, nil
	}

	var err error
	caPool := x509.NewCertPool()
	clientAuth := tls.NoClientCert
	if clientTLS.CA != "" {
		var ca []byte
		if _, errCA := os.Stat(clientTLS.CA); errCA == nil {
			ca, err = ioutil.ReadFile(clientTLS.CA)
			if err != nil {
				return nil, fmt.Errorf("failed to read CA. %s", err)
			}
		} else {
			ca = []byte(clientTLS.CA)
		}

		if !caPool.AppendCertsFromPEM(ca) {
			return nil, fmt.Errorf("failed to parse CA")
		}

		if clientTLS.CAOptional {
			clientAuth = tls.VerifyClientCertIfGiven
		} else {
			clientAuth = tls.RequireAndVerifyClientCert
		}
	}

	cert := tls.Certificate{}
	_, errKeyIsFile := os.Stat(clientTLS.Key)

	if !clientTLS.InsecureSkipVerify && (len(clientTLS.Cert) == 0 || len(clientTLS.Key) == 0) {
		return nil, fmt.Errorf("TLS Certificate or Key file must be set when TLS configuration is created")
	}

	if len(clientTLS.Cert) > 0 && len(clientTLS.Key) > 0 {
		if _, errCertIsFile := os.Stat(clientTLS.Cert); errCertIsFile == nil {
			if errKeyIsFile == nil {
				cert, err = tls.LoadX509KeyPair(clientTLS.Cert, clientTLS.Key)
				if err != nil {
					return nil, fmt.Errorf("failed to load TLS keypair: %v", err)
				}
			} else {
				return nil, fmt.Errorf("tls cert is a file, but tls key is not")
			}
		} else {
			if errKeyIsFile != nil {
				cert, err = tls.X509KeyPair([]byte(clientTLS.Cert), []byte(clientTLS.Key))
				if err != nil {
					return nil, fmt.Errorf("failed to load TLS keypair: %v", err)
				}
			} else {
				return nil, fmt.Errorf("TLS key is a file, but tls cert is not")
			}
		}
	}

	return &tls.Config{
		Certificates:       []tls.Certificate{cert},
		RootCAs:            caPool,
		InsecureSkipVerify: clientTLS.InsecureSkipVerify,
		ClientAuth:         clientAuth,
	}, nil
}

// Message holds configuration information exchanged between parts of traefik.
type Message struct {
	ProviderName  string
	Configuration *Configuration
}

// Configuration is the root of the dynamic configuration
type Configuration struct {
	HTTP       *HTTPConfiguration
	TCP        *TCPConfiguration
	TLS        []*traefiktls.Configuration `json:"-" label:"-"`
	TLSOptions map[string]traefiktls.TLS
	TLSStores  map[string]traefiktls.Store
}

// Configurations is for currentConfigurations Map.
type Configurations map[string]*Configuration

// HTTPConfiguration FIXME better name?
type HTTPConfiguration struct {
	Routers     map[string]*Router     `json:"routers,omitempty" toml:",omitempty"`
	Middlewares map[string]*Middleware `json:"middlewares,omitempty" toml:",omitempty"`
	Services    map[string]*Service    `json:"services,omitempty" toml:",omitempty"`
}

// TCPConfiguration FIXME better name?
type TCPConfiguration struct {
	Routers  map[string]*TCPRouter  `json:"routers,omitempty" toml:",omitempty"`
	Services map[string]*TCPService `json:"services,omitempty" toml:",omitempty"`
}

// Service holds a service configuration (can only be of one type at the same time).
type Service struct {
	LoadBalancer *LoadBalancerService `json:"loadbalancer,omitempty" toml:",omitempty,omitzero"`
}

// TCPService holds a tcp service configuration (can only be of one type at the same time).
type TCPService struct {
	LoadBalancer *TCPLoadBalancerService `json:"loadbalancer,omitempty" toml:",omitempty,omitzero"`
}
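
For reference, a minimal dynamic configuration built from these types might look as follows. This is an illustrative sketch only: the router, service names and server URL are hypothetical, mirroring the shapes used in handler_test.go above.

	package main

	import "github.com/containous/traefik/pkg/config"

	// newExampleConfiguration returns an illustrative dynamic configuration
	// with a single HTTP router pointing at a single load-balanced server.
	func newExampleConfiguration() *config.Configuration {
		return &config.Configuration{
			HTTP: &config.HTTPConfiguration{
				Routers: map[string]*config.Router{
					"my-router": {
						EntryPoints: []string{"web"},
						Service:     "my-service",
						Rule:        "PathPrefix(`/`)",
					},
				},
				Services: map[string]*config.Service{
					"my-service": {
						LoadBalancer: &config.LoadBalancerService{
							Servers: []config.Server{{URL: "http://127.0.0.1:8080"}},
						},
					},
				},
			},
		}
	}

	func main() {
		_ = newExampleConfiguration()
	}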

pkg/config/middlewares.go (Normal file, 363 lines added; listing truncated below)
@@ -0,0 +1,363 @@
package config

import (
	"github.com/containous/flaeg/parse"
	"github.com/containous/traefik/pkg/ip"
)

// +k8s:deepcopy-gen=true

// Middleware holds the Middleware configuration.
type Middleware struct {
	AddPrefix         *AddPrefix         `json:"addPrefix,omitempty"`
	StripPrefix       *StripPrefix       `json:"stripPrefix,omitempty"`
	StripPrefixRegex  *StripPrefixRegex  `json:"stripPrefixRegex,omitempty"`
	ReplacePath       *ReplacePath       `json:"replacePath,omitempty"`
	ReplacePathRegex  *ReplacePathRegex  `json:"replacePathRegex,omitempty"`
	Chain             *Chain             `json:"chain,omitempty"`
	IPWhiteList       *IPWhiteList       `json:"ipWhiteList,omitempty"`
	Headers           *Headers           `json:"headers,omitempty"`
	Errors            *ErrorPage         `json:"errors,omitempty"`
	RateLimit         *RateLimit         `json:"rateLimit,omitempty"`
	RedirectRegex     *RedirectRegex     `json:"redirectregex,omitempty"`
	RedirectScheme    *RedirectScheme    `json:"redirectscheme,omitempty"`
	BasicAuth         *BasicAuth         `json:"basicAuth,omitempty"`
	DigestAuth        *DigestAuth        `json:"digestAuth,omitempty"`
	ForwardAuth       *ForwardAuth       `json:"forwardAuth,omitempty"`
	MaxConn           *MaxConn           `json:"maxConn,omitempty"`
	Buffering         *Buffering         `json:"buffering,omitempty"`
	CircuitBreaker    *CircuitBreaker    `json:"circuitBreaker,omitempty"`
	Compress          *Compress          `json:"compress,omitempty" label:"allowEmpty"`
	PassTLSClientCert *PassTLSClientCert `json:"passTLSClientCert,omitempty"`
	Retry             *Retry             `json:"retry,omitempty"`
}

// +k8s:deepcopy-gen=true

// AddPrefix holds the AddPrefix configuration.
type AddPrefix struct {
	Prefix string `json:"prefix,omitempty"`
}

// +k8s:deepcopy-gen=true

// Auth holds the authentication configuration (BASIC, DIGEST, users).
type Auth struct {
	Basic   *BasicAuth   `json:"basic,omitempty" export:"true"`
	Digest  *DigestAuth  `json:"digest,omitempty" export:"true"`
	Forward *ForwardAuth `json:"forward,omitempty" export:"true"`
}

// +k8s:deepcopy-gen=true

// BasicAuth holds the HTTP basic authentication configuration.
type BasicAuth struct {
	Users        `json:"users,omitempty" mapstructure:","`
	UsersFile    string `json:"usersFile,omitempty"`
	Realm        string `json:"realm,omitempty"`
	RemoveHeader bool   `json:"removeHeader,omitempty"`
	HeaderField  string `json:"headerField,omitempty" export:"true"`
}

// +k8s:deepcopy-gen=true

// Buffering holds the request/response buffering configuration.
type Buffering struct {
	MaxRequestBodyBytes  int64  `json:"maxRequestBodyBytes,omitempty"`
	MemRequestBodyBytes  int64  `json:"memRequestBodyBytes,omitempty"`
	MaxResponseBodyBytes int64  `json:"maxResponseBodyBytes,omitempty"`
	MemResponseBodyBytes int64  `json:"memResponseBodyBytes,omitempty"`
	RetryExpression      string `json:"retryExpression,omitempty"`
}

// +k8s:deepcopy-gen=true

// Chain holds a chain of middlewares
type Chain struct {
	Middlewares []string `json:"middlewares"`
}

// +k8s:deepcopy-gen=true

// CircuitBreaker holds the circuit breaker configuration.
type CircuitBreaker struct {
	Expression string `json:"expression,omitempty"`
}

// +k8s:deepcopy-gen=true

// Compress holds the compress configuration.
type Compress struct{}

// +k8s:deepcopy-gen=true

// DigestAuth holds the Digest HTTP authentication configuration.
type DigestAuth struct {
	Users     `json:"users,omitempty" mapstructure:","`
	UsersFile string `json:"usersFile,omitempty"`
|
||||
RemoveHeader bool `json:"removeHeader,omitempty"`
|
||||
Realm string `json:"realm,omitempty" mapstructure:","`
|
||||
HeaderField string `json:"headerField,omitempty" export:"true"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// ErrorPage holds the custom error page configuration.
|
||||
type ErrorPage struct {
|
||||
Status []string `json:"status,omitempty"`
|
||||
Service string `json:"service,omitempty"`
|
||||
Query string `json:"query,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// ForwardAuth holds the http forward authentication configuration.
|
||||
type ForwardAuth struct {
|
||||
Address string `description:"Authentication server address" json:"address,omitempty"`
|
||||
TLS *ClientTLS `description:"Enable TLS support" json:"tls,omitempty" export:"true"`
|
||||
TrustForwardHeader bool `description:"Trust X-Forwarded-* headers" json:"trustForwardHeader,omitempty" export:"true"`
|
||||
AuthResponseHeaders []string `description:"Headers to be forwarded from auth response" json:"authResponseHeaders,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// Headers holds the custom header configuration.
|
||||
type Headers struct {
|
||||
CustomRequestHeaders map[string]string `json:"customRequestHeaders,omitempty"`
|
||||
CustomResponseHeaders map[string]string `json:"customResponseHeaders,omitempty"`
|
||||
|
||||
AllowedHosts []string `json:"allowedHosts,omitempty"`
|
||||
HostsProxyHeaders []string `json:"hostsProxyHeaders,omitempty"`
|
||||
SSLRedirect bool `json:"sslRedirect,omitempty"`
|
||||
SSLTemporaryRedirect bool `json:"sslTemporaryRedirect,omitempty"`
|
||||
SSLHost string `json:"sslHost,omitempty"`
|
||||
SSLProxyHeaders map[string]string `json:"sslProxyHeaders,omitempty"`
|
||||
SSLForceHost bool `json:"sslForceHost,omitempty"`
|
||||
STSSeconds int64 `json:"stsSeconds,omitempty"`
|
||||
STSIncludeSubdomains bool `json:"stsIncludeSubdomains,omitempty"`
|
||||
STSPreload bool `json:"stsPreload,omitempty"`
|
||||
ForceSTSHeader bool `json:"forceSTSHeader,omitempty"`
|
||||
FrameDeny bool `json:"frameDeny,omitempty"`
|
||||
CustomFrameOptionsValue string `json:"customFrameOptionsValue,omitempty"`
|
||||
ContentTypeNosniff bool `json:"contentTypeNosniff,omitempty"`
|
||||
BrowserXSSFilter bool `json:"browserXssFilter,omitempty"`
|
||||
CustomBrowserXSSValue string `json:"customBrowserXSSValue,omitempty"`
|
||||
ContentSecurityPolicy string `json:"contentSecurityPolicy,omitempty"`
|
||||
PublicKey string `json:"publicKey,omitempty"`
|
||||
ReferrerPolicy string `json:"referrerPolicy,omitempty"`
|
||||
IsDevelopment bool `json:"isDevelopment,omitempty"`
|
||||
}
|
||||
|
||||
// HasCustomHeadersDefined checks to see if any of the custom header elements have been set
|
||||
func (h *Headers) HasCustomHeadersDefined() bool {
|
||||
return h != nil && (len(h.CustomResponseHeaders) != 0 ||
|
||||
len(h.CustomRequestHeaders) != 0)
|
||||
}
|
||||
|
||||
// HasSecureHeadersDefined checks to see if any of the secure header elements have been set
|
||||
func (h *Headers) HasSecureHeadersDefined() bool {
|
||||
return h != nil && (len(h.AllowedHosts) != 0 ||
|
||||
len(h.HostsProxyHeaders) != 0 ||
|
||||
h.SSLRedirect ||
|
||||
h.SSLTemporaryRedirect ||
|
||||
h.SSLForceHost ||
|
||||
h.SSLHost != "" ||
|
||||
len(h.SSLProxyHeaders) != 0 ||
|
||||
h.STSSeconds != 0 ||
|
||||
h.STSIncludeSubdomains ||
|
||||
h.STSPreload ||
|
||||
h.ForceSTSHeader ||
|
||||
h.FrameDeny ||
|
||||
h.CustomFrameOptionsValue != "" ||
|
||||
h.ContentTypeNosniff ||
|
||||
h.BrowserXSSFilter ||
|
||||
h.CustomBrowserXSSValue != "" ||
|
||||
h.ContentSecurityPolicy != "" ||
|
||||
h.PublicKey != "" ||
|
||||
h.ReferrerPolicy != "" ||
|
||||
h.IsDevelopment)
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// IPStrategy holds the ip strategy configuration.
|
||||
type IPStrategy struct {
|
||||
Depth int `json:"depth,omitempty" export:"true"`
|
||||
ExcludedIPs []string `json:"excludedIPs,omitempty"`
|
||||
}
|
||||
|
||||
// Get an IP selection strategy
|
||||
// if nil return the RemoteAddr strategy
|
||||
// else return a strategy based on the configuration using the X-Forwarded-For Header.
|
||||
// Depth overrides the ExcludedIPs.
|
||||
func (s *IPStrategy) Get() (ip.Strategy, error) {
|
||||
if s == nil {
|
||||
return &ip.RemoteAddrStrategy{}, nil
|
||||
}
|
||||
|
||||
if s.Depth > 0 {
|
||||
return &ip.DepthStrategy{
|
||||
Depth: s.Depth,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if len(s.ExcludedIPs) > 0 {
|
||||
checker, err := ip.NewChecker(s.ExcludedIPs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ip.CheckerStrategy{
|
||||
Checker: checker,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &ip.RemoteAddrStrategy{}, nil
|
||||
}
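A minimal sketch of the strategy selection implemented by Get; the GetIP method is assumed to be the single method of the ip.Strategy interface, which is not part of this hunk:

package main

import (
	"fmt"
	"net/http"

	"github.com/containous/traefik/pkg/config"
)

func main() {
	// With Depth set, Get returns an ip.DepthStrategy and ExcludedIPs is ignored.
	strategy, err := (&config.IPStrategy{Depth: 2}).Get()
	if err != nil {
		panic(err)
	}

	req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
	req.RemoteAddr = "192.0.2.1:1234"
	req.Header.Set("X-Forwarded-For", "10.0.0.1, 10.0.0.2, 203.0.113.7")

	// GetIP(*http.Request) string is the assumed ip.Strategy method.
	fmt.Println(strategy.GetIP(req))
}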
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// IPWhiteList holds the ip white list configuration.
|
||||
type IPWhiteList struct {
|
||||
SourceRange []string `json:"sourceRange,omitempty"`
|
||||
IPStrategy *IPStrategy `json:"ipStrategy,omitempty" label:"allowEmpty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// MaxConn holds maximum connection configuration.
|
||||
type MaxConn struct {
|
||||
Amount int64 `json:"amount,omitempty"`
|
||||
ExtractorFunc string `json:"extractorFunc,omitempty"`
|
||||
}
|
||||
|
||||
// SetDefaults Default values for a MaxConn.
|
||||
func (m *MaxConn) SetDefaults() {
|
||||
m.ExtractorFunc = "request.host"
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// PassTLSClientCert holds the TLS client cert headers configuration.
|
||||
type PassTLSClientCert struct {
|
||||
PEM bool `description:"Enable header with escaped client pem" json:"pem"`
|
||||
Info *TLSClientCertificateInfo `description:"Enable header with configured client cert info" json:"info,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// Rate holds the rate limiting configuration for a specific time period.
|
||||
type Rate struct {
|
||||
Period parse.Duration `json:"period,omitempty"`
|
||||
Average int64 `json:"average,omitempty"`
|
||||
Burst int64 `json:"burst,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// RateLimit holds the rate limiting configuration for a given frontend.
|
||||
type RateLimit struct {
|
||||
RateSet map[string]*Rate `json:"rateset,omitempty"`
|
||||
// FIXME replace by ipStrategy see oxy and replace
|
||||
ExtractorFunc string `json:"extractorFunc,omitempty"`
|
||||
}
|
||||
|
||||
// SetDefaults Default values for a RateLimit.
|
||||
func (r *RateLimit) SetDefaults() {
|
||||
r.ExtractorFunc = "request.host"
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// RedirectRegex holds the redirection configuration.
|
||||
type RedirectRegex struct {
|
||||
Regex string `json:"regex,omitempty"`
|
||||
Replacement string `json:"replacement,omitempty"`
|
||||
Permanent bool `json:"permanent,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// RedirectScheme holds the scheme redirection configuration.
|
||||
type RedirectScheme struct {
|
||||
Scheme string `json:"scheme,omitempty"`
|
||||
Port string `json:"port,omitempty"`
|
||||
Permanent bool `json:"permanent,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// ReplacePath holds the ReplacePath configuration.
|
||||
type ReplacePath struct {
|
||||
Path string `json:"path,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// ReplacePathRegex holds the ReplacePathRegex configuration.
|
||||
type ReplacePathRegex struct {
|
||||
Regex string `json:"regex,omitempty"`
|
||||
Replacement string `json:"replacement,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// Retry holds the retry configuration.
|
||||
type Retry struct {
|
||||
Attempts int `description:"Number of attempts" export:"true"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// StripPrefix holds the StripPrefix configuration.
|
||||
type StripPrefix struct {
|
||||
Prefixes []string `json:"prefixes,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// StripPrefixRegex holds the StripPrefixRegex configuration.
|
||||
type StripPrefixRegex struct {
|
||||
Regex []string `json:"regex,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// TLSClientCertificateInfo holds the client TLS certificate info configuration.
|
||||
type TLSClientCertificateInfo struct {
|
||||
NotAfter bool `description:"Add NotAfter info in header" json:"notAfter"`
|
||||
NotBefore bool `description:"Add NotBefore info in header" json:"notBefore"`
|
||||
Sans bool `description:"Add Sans info in header" json:"sans"`
|
||||
Subject *TLSCLientCertificateDNInfo `description:"Add Subject info in header" json:"subject,omitempty"`
|
||||
Issuer *TLSCLientCertificateDNInfo `description:"Add Issuer info in header" json:"issuer,omitempty"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// TLSCLientCertificateDNInfo holds the client TLS certificate distinguished name info configuration
|
||||
// cf https://tools.ietf.org/html/rfc3739
|
||||
type TLSCLientCertificateDNInfo struct {
|
||||
Country bool `description:"Add Country info in header" json:"country"`
|
||||
Province bool `description:"Add Province info in header" json:"province"`
|
||||
Locality bool `description:"Add Locality info in header" json:"locality"`
|
||||
Organization bool `description:"Add Organization info in header" json:"organization"`
|
||||
CommonName bool `description:"Add CommonName info in header" json:"commonName"`
|
||||
SerialNumber bool `description:"Add SerialNumber info in header" json:"serialNumber"`
|
||||
DomainComponent bool `description:"Add Domain Component info in header" json:"domainComponent"`
|
||||
}
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// Users holds a list of users
|
||||
type Users []string
|
||||
|
||||
// +k8s:deepcopy-gen=true
|
||||
|
||||
// ClientTLS holds the TLS specific configurations as client
|
||||
// CA, Cert and Key can be either path or file contents.
|
||||
type ClientTLS struct {
|
||||
CA string `description:"TLS CA" json:"ca,omitempty"`
|
||||
CAOptional bool `description:"TLS CA.Optional" json:"caOptional,omitempty"`
|
||||
Cert string `description:"TLS cert" json:"cert,omitempty"`
|
||||
Key string `description:"TLS key" json:"key,omitempty"`
|
||||
InsecureSkipVerify bool `description:"TLS insecure skip verify" json:"insecureSkipVerify,omitempty"`
|
||||
}
|
134
pkg/config/static/entrypoints.go
Normal file
@@ -0,0 +1,134 @@
package static

import (
	"fmt"
	"strings"

	"github.com/containous/traefik/pkg/log"
)

// EntryPoint holds the entry point configuration.
type EntryPoint struct {
	Address          string
	Transport        *EntryPointsTransport
	ProxyProtocol    *ProxyProtocol
	ForwardedHeaders *ForwardedHeaders
}

// ForwardedHeaders Trust client forwarding headers.
type ForwardedHeaders struct {
	Insecure   bool
	TrustedIPs []string
}

// ProxyProtocol contains Proxy-Protocol configuration.
type ProxyProtocol struct {
	Insecure   bool `export:"true"`
	TrustedIPs []string
}

// EntryPoints holds the HTTP entry point list.
type EntryPoints map[string]*EntryPoint

// EntryPointsTransport configures communication between clients and Traefik.
type EntryPointsTransport struct {
	LifeCycle          *LifeCycle          `description:"Timeouts influencing the server life cycle" export:"true"`
	RespondingTimeouts *RespondingTimeouts `description:"Timeouts for incoming requests to the Traefik instance" export:"true"`
}

// String is the method to format the flag's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics.
func (ep EntryPoints) String() string {
	return fmt.Sprintf("%+v", map[string]*EntryPoint(ep))
}

// Get return the EntryPoints map.
func (ep *EntryPoints) Get() interface{} {
	return *ep
}

// SetValue sets the EntryPoints map with val.
func (ep *EntryPoints) SetValue(val interface{}) {
	*ep = val.(EntryPoints)
}

// Type is type of the struct.
func (ep *EntryPoints) Type() string {
	return "entrypoints"
}

// Set is the method to set the flag value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the flag.
// It's a comma-separated list, so we split it.
func (ep *EntryPoints) Set(value string) error {
	result := parseEntryPointsConfiguration(value)

	(*ep)[result["name"]] = &EntryPoint{
		Address:          result["address"],
		ProxyProtocol:    makeEntryPointProxyProtocol(result),
		ForwardedHeaders: makeEntryPointForwardedHeaders(result),
	}

	return nil
}
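A small sketch of the flag.Value round trip implemented above: Set parses a space-separated list of Key:Value pairs into an EntryPoint keyed by its Name, as the test file further down also exercises:

package main

import (
	"fmt"

	"github.com/containous/traefik/pkg/config/static"
)

func main() {
	eps := static.EntryPoints{}
	if err := eps.Set("Name:web Address::8000 ProxyProtocol.TrustedIPs:192.168.0.1"); err != nil {
		panic(err)
	}
	fmt.Println(eps.String()) // shows the parsed "web" entry point
}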

func makeEntryPointProxyProtocol(result map[string]string) *ProxyProtocol {
	var proxyProtocol *ProxyProtocol

	ppTrustedIPs := result["proxyprotocol_trustedips"]
	if len(result["proxyprotocol_insecure"]) > 0 || len(ppTrustedIPs) > 0 {
		proxyProtocol = &ProxyProtocol{
			Insecure: toBool(result, "proxyprotocol_insecure"),
		}
		if len(ppTrustedIPs) > 0 {
			proxyProtocol.TrustedIPs = strings.Split(ppTrustedIPs, ",")
		}
	}

	if proxyProtocol != nil && proxyProtocol.Insecure {
		log.Warn("ProxyProtocol.insecure:true is dangerous. Please use 'ProxyProtocol.TrustedIPs:IPs' and remove 'ProxyProtocol.insecure:true'")
	}

	return proxyProtocol
}

func parseEntryPointsConfiguration(raw string) map[string]string {
	sections := strings.Fields(raw)

	config := make(map[string]string)
	for _, part := range sections {
		field := strings.SplitN(part, ":", 2)
		name := strings.ToLower(strings.Replace(field[0], ".", "_", -1))
		if len(field) > 1 {
			config[name] = field[1]
		} else {
			if strings.EqualFold(name, "TLS") {
				config["tls_acme"] = "TLS"
			} else {
				config[name] = ""
			}
		}
	}
	return config
}

func toBool(conf map[string]string, key string) bool {
	if val, ok := conf[key]; ok {
		return strings.EqualFold(val, "true") ||
			strings.EqualFold(val, "enable") ||
			strings.EqualFold(val, "on")
	}
	return false
}

func makeEntryPointForwardedHeaders(result map[string]string) *ForwardedHeaders {
	forwardedHeaders := &ForwardedHeaders{}
	forwardedHeaders.Insecure = toBool(result, "forwardedheaders_insecure")

	fhTrustedIPs := result["forwardedheaders_trustedips"]
	if len(fhTrustedIPs) > 0 {
		forwardedHeaders.TrustedIPs = strings.Split(fhTrustedIPs, ",")
	}

	return forwardedHeaders
}
257
pkg/config/static/entrypoints_test.go
Normal file
@@ -0,0 +1,257 @@
package static
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_parseEntryPointsConfiguration(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
value string
|
||||
expectedResult map[string]string
|
||||
}{
|
||||
{
|
||||
name: "all parameters",
|
||||
value: "Name:foo " +
|
||||
"Address::8000 " +
|
||||
"CA:car " +
|
||||
"CA.Optional:true " +
|
||||
"Redirect.EntryPoint:https " +
|
||||
"Redirect.Regex:http://localhost/(.*) " +
|
||||
"Redirect.Replacement:http://mydomain/$1 " +
|
||||
"Redirect.Permanent:true " +
|
||||
"Compress:true " +
|
||||
"ProxyProtocol.TrustedIPs:192.168.0.1 " +
|
||||
"ForwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24 " +
|
||||
"Auth.Basic.Realm:myRealm " +
|
||||
"Auth.Basic.Users:test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0 " +
|
||||
"Auth.Basic.RemoveHeader:true " +
|
||||
"Auth.Digest.Users:test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e " +
|
||||
"Auth.Digest.RemoveHeader:true " +
|
||||
"Auth.HeaderField:X-WebAuth-User " +
|
||||
"Auth.Forward.Address:https://authserver.com/auth " +
|
||||
"Auth.Forward.AuthResponseHeaders:X-Auth,X-Test,X-Secret " +
|
||||
"Auth.Forward.TrustForwardHeader:true " +
|
||||
"Auth.Forward.TLS.CA:path/to/local.crt " +
|
||||
"Auth.Forward.TLS.CAOptional:true " +
|
||||
"Auth.Forward.TLS.Cert:path/to/foo.cert " +
|
||||
"Auth.Forward.TLS.Key:path/to/foo.key " +
|
||||
"Auth.Forward.TLS.InsecureSkipVerify:true " +
|
||||
"WhiteList.SourceRange:10.42.0.0/16,152.89.1.33/32,afed:be44::/16 " +
|
||||
"WhiteList.IPStrategy.depth:3 " +
|
||||
"WhiteList.IPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 " +
|
||||
"ClientIPStrategy.depth:3 " +
|
||||
"ClientIPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 ",
|
||||
expectedResult: map[string]string{
|
||||
"address": ":8000",
|
||||
"auth_basic_realm": "myRealm",
|
||||
"auth_basic_users": "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0",
|
||||
"auth_basic_removeheader": "true",
|
||||
"auth_digest_users": "test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e",
|
||||
"auth_digest_removeheader": "true",
|
||||
"auth_forward_address": "https://authserver.com/auth",
|
||||
"auth_forward_authresponseheaders": "X-Auth,X-Test,X-Secret",
|
||||
"auth_forward_tls_ca": "path/to/local.crt",
|
||||
"auth_forward_tls_caoptional": "true",
|
||||
"auth_forward_tls_cert": "path/to/foo.cert",
|
||||
"auth_forward_tls_insecureskipverify": "true",
|
||||
"auth_forward_tls_key": "path/to/foo.key",
|
||||
"auth_forward_trustforwardheader": "true",
|
||||
"auth_headerfield": "X-WebAuth-User",
|
||||
"ca": "car",
|
||||
"ca_optional": "true",
|
||||
"compress": "true",
|
||||
"forwardedheaders_trustedips": "10.0.0.3/24,20.0.0.3/24",
|
||||
"name": "foo",
|
||||
"proxyprotocol_trustedips": "192.168.0.1",
|
||||
"redirect_entrypoint": "https",
|
||||
"redirect_permanent": "true",
|
||||
"redirect_regex": "http://localhost/(.*)",
|
||||
"redirect_replacement": "http://mydomain/$1",
|
||||
"whitelist_sourcerange": "10.42.0.0/16,152.89.1.33/32,afed:be44::/16",
|
||||
"whitelist_ipstrategy_depth": "3",
|
||||
"whitelist_ipstrategy_excludedips": "10.0.0.3/24,20.0.0.3/24",
|
||||
"clientipstrategy_depth": "3",
|
||||
"clientipstrategy_excludedips": "10.0.0.3/24,20.0.0.3/24",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "compress on",
|
||||
value: "name:foo Compress:on",
|
||||
expectedResult: map[string]string{
|
||||
"name": "foo",
|
||||
"compress": "on",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conf := parseEntryPointsConfiguration(test.value)
|
||||
|
||||
assert.Len(t, conf, len(test.expectedResult))
|
||||
assert.Equal(t, test.expectedResult, conf)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_toBool(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
value string
|
||||
key string
|
||||
expectedBool bool
|
||||
}{
|
||||
{
|
||||
name: "on",
|
||||
value: "on",
|
||||
key: "foo",
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "true",
|
||||
value: "true",
|
||||
key: "foo",
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "enable",
|
||||
value: "enable",
|
||||
key: "foo",
|
||||
expectedBool: true,
|
||||
},
|
||||
{
|
||||
name: "arbitrary string",
|
||||
value: "bar",
|
||||
key: "foo",
|
||||
expectedBool: false,
|
||||
},
|
||||
{
|
||||
name: "no existing entry",
|
||||
value: "bar",
|
||||
key: "fii",
|
||||
expectedBool: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conf := map[string]string{
|
||||
"foo": test.value,
|
||||
}
|
||||
|
||||
result := toBool(conf, test.key)
|
||||
|
||||
assert.Equal(t, test.expectedBool, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntryPoints_Set(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
expression string
|
||||
expectedEntryPointName string
|
||||
expectedEntryPoint *EntryPoint
|
||||
}{
|
||||
{
|
||||
name: "all parameters camelcase",
|
||||
expression: "Name:foo " +
|
||||
"Address::8000 " +
|
||||
"CA:car " +
|
||||
"CA.Optional:true " +
|
||||
"ProxyProtocol.TrustedIPs:192.168.0.1 ",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
Address: ":8000",
|
||||
ProxyProtocol: &ProxyProtocol{
|
||||
Insecure: false,
|
||||
TrustedIPs: []string{"192.168.0.1"},
|
||||
},
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
// FIXME Test ServersTransport
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all parameters lowercase",
|
||||
expression: "Name:foo " +
|
||||
"address::8000 " +
|
||||
"tls " +
|
||||
"tls.minversion:VersionTLS11 " +
|
||||
"tls.ciphersuites:TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA " +
|
||||
"ca:car " +
|
||||
"ca.Optional:true " +
|
||||
"proxyProtocol.TrustedIPs:192.168.0.1 ",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
Address: ":8000",
|
||||
ProxyProtocol: &ProxyProtocol{
|
||||
Insecure: false,
|
||||
TrustedIPs: []string{"192.168.0.1"},
|
||||
},
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
// FIXME Test ServersTransport
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default",
|
||||
expression: "Name:foo",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ProxyProtocol insecure true",
|
||||
expression: "Name:foo ProxyProtocol.insecure:true",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ProxyProtocol: &ProxyProtocol{Insecure: true},
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ProxyProtocol insecure false",
|
||||
expression: "Name:foo ProxyProtocol.insecure:false",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ProxyProtocol: &ProxyProtocol{},
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ProxyProtocol TrustedIPs",
|
||||
expression: "Name:foo ProxyProtocol.TrustedIPs:10.0.0.3/24,20.0.0.3/24",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ProxyProtocol: &ProxyProtocol{
|
||||
TrustedIPs: []string{"10.0.0.3/24", "20.0.0.3/24"},
|
||||
},
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
eps := EntryPoints{}
|
||||
err := eps.Set(test.expression)
|
||||
require.NoError(t, err)
|
||||
|
||||
ep := eps[test.expectedEntryPointName]
|
||||
assert.EqualValues(t, test.expectedEntryPoint, ep)
|
||||
})
|
||||
}
|
||||
}
|
401
pkg/config/static/static_config.go
Normal file
@@ -0,0 +1,401 @@
package static
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/old/provider/boltdb"
|
||||
"github.com/containous/traefik/old/provider/consul"
|
||||
"github.com/containous/traefik/old/provider/consulcatalog"
|
||||
"github.com/containous/traefik/old/provider/dynamodb"
|
||||
"github.com/containous/traefik/old/provider/ecs"
|
||||
"github.com/containous/traefik/old/provider/etcd"
|
||||
"github.com/containous/traefik/old/provider/eureka"
|
||||
"github.com/containous/traefik/old/provider/mesos"
|
||||
"github.com/containous/traefik/old/provider/rancher"
|
||||
"github.com/containous/traefik/old/provider/zk"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/ping"
|
||||
acmeprovider "github.com/containous/traefik/pkg/provider/acme"
|
||||
"github.com/containous/traefik/pkg/provider/docker"
|
||||
"github.com/containous/traefik/pkg/provider/file"
|
||||
"github.com/containous/traefik/pkg/provider/kubernetes/crd"
|
||||
"github.com/containous/traefik/pkg/provider/kubernetes/ingress"
|
||||
"github.com/containous/traefik/pkg/provider/marathon"
|
||||
"github.com/containous/traefik/pkg/provider/rest"
|
||||
"github.com/containous/traefik/pkg/tls"
|
||||
"github.com/containous/traefik/pkg/tracing/datadog"
|
||||
"github.com/containous/traefik/pkg/tracing/instana"
|
||||
"github.com/containous/traefik/pkg/tracing/jaeger"
|
||||
"github.com/containous/traefik/pkg/tracing/zipkin"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
assetfs "github.com/elazarl/go-bindata-assetfs"
|
||||
"github.com/go-acme/lego/challenge/dns01"
|
||||
jaegercli "github.com/uber/jaeger-client-go"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultInternalEntryPointName the name of the default internal entry point
|
||||
DefaultInternalEntryPointName = "traefik"
|
||||
|
||||
// DefaultGraceTimeout controls how long Traefik serves pending requests
|
||||
// prior to shutting down.
|
||||
DefaultGraceTimeout = 10 * time.Second
|
||||
|
||||
// DefaultIdleTimeout before closing an idle connection.
|
||||
DefaultIdleTimeout = 180 * time.Second
|
||||
|
||||
// DefaultAcmeCAServer is the default ACME API endpoint
|
||||
DefaultAcmeCAServer = "https://acme-v02.api.letsencrypt.org/directory"
|
||||
)
|
||||
|
||||
// Configuration is the static configuration
|
||||
type Configuration struct {
|
||||
Global *Global `description:"Global configuration options" export:"true"`
|
||||
|
||||
ServersTransport *ServersTransport `description:"Servers default transport" export:"true"`
|
||||
EntryPoints EntryPoints `description:"Entrypoints definition using format: --entryPoints='Name:http Address::8000 Redirect.EntryPoint:https' --entryPoints='Name:https Address::4442 TLS:tests/traefik.crt,tests/traefik.key;prod/traefik.crt,prod/traefik.key'" export:"true"`
|
||||
Providers *Providers `description:"Providers configuration" export:"true"`
|
||||
|
||||
API *API `description:"Enable api/dashboard" export:"true"`
|
||||
Metrics *types.Metrics `description:"Enable a metrics exporter" export:"true"`
|
||||
Ping *ping.Handler `description:"Enable ping" export:"true"`
|
||||
// Rest *rest.Provider `description:"Enable Rest backend with default settings" export:"true"`
|
||||
|
||||
Log *types.TraefikLog
|
||||
AccessLog *types.AccessLog `description:"Access log settings" export:"true"`
|
||||
Tracing *Tracing `description:"OpenTracing configuration" export:"true"`
|
||||
|
||||
HostResolver *types.HostResolverConfig `description:"Enable CNAME Flattening" export:"true"`
|
||||
|
||||
ACME *acmeprovider.Configuration `description:"Enable ACME (Let's Encrypt): automatic SSL" export:"true"`
|
||||
}
|
||||
|
||||
// Global holds the global configuration.
|
||||
type Global struct {
|
||||
Debug bool `short:"d" description:"Enable debug mode" export:"true"`
|
||||
CheckNewVersion bool `description:"Periodically check if a new version has been released" export:"true"`
|
||||
SendAnonymousUsage *bool `description:"send periodically anonymous usage statistics" export:"true"`
|
||||
}
|
||||
|
||||
// ServersTransport options to configure communication between Traefik and the servers
|
||||
type ServersTransport struct {
|
||||
InsecureSkipVerify bool `description:"Disable SSL certificate verification" export:"true"`
|
||||
RootCAs tls.FilesOrContents `description:"Add cert file for self-signed certificate"`
|
||||
MaxIdleConnsPerHost int `description:"If non-zero, controls the maximum idle (keep-alive) to keep per-host. If zero, DefaultMaxIdleConnsPerHost is used" export:"true"`
|
||||
ForwardingTimeouts *ForwardingTimeouts `description:"Timeouts for requests forwarded to the backend servers" export:"true"`
|
||||
}
|
||||
|
||||
// API holds the API configuration
|
||||
type API struct {
|
||||
EntryPoint string `description:"EntryPoint" export:"true"`
|
||||
Dashboard bool `description:"Activate dashboard" export:"true"`
|
||||
Statistics *types.Statistics `description:"Enable more detailed statistics" export:"true"`
|
||||
Middlewares []string `description:"Middleware list" export:"true"`
|
||||
DashboardAssets *assetfs.AssetFS `json:"-"`
|
||||
}
|
||||
|
||||
// RespondingTimeouts contains timeout configurations for incoming requests to the Traefik instance.
|
||||
type RespondingTimeouts struct {
|
||||
ReadTimeout parse.Duration `description:"ReadTimeout is the maximum duration for reading the entire request, including the body. If zero, no timeout is set" export:"true"`
|
||||
WriteTimeout parse.Duration `description:"WriteTimeout is the maximum duration before timing out writes of the response. If zero, no timeout is set" export:"true"`
|
||||
IdleTimeout parse.Duration `description:"IdleTimeout is the maximum amount duration an idle (keep-alive) connection will remain idle before closing itself. Defaults to 180 seconds. If zero, no timeout is set" export:"true"`
|
||||
}
|
||||
|
||||
// ForwardingTimeouts contains timeout configurations for forwarding requests to the backend servers.
|
||||
type ForwardingTimeouts struct {
|
||||
DialTimeout parse.Duration `description:"The amount of time to wait until a connection to a backend server can be established. Defaults to 30 seconds. If zero, no timeout exists" export:"true"`
|
||||
ResponseHeaderTimeout parse.Duration `description:"The amount of time to wait for a server's response headers after fully writing the request (including its body, if any). If zero, no timeout exists" export:"true"`
|
||||
}
|
||||
|
||||
// LifeCycle contains configurations relevant to the lifecycle (such as the shutdown phase) of Traefik.
|
||||
type LifeCycle struct {
|
||||
RequestAcceptGraceTimeout parse.Duration `description:"Duration to keep accepting requests before Traefik initiates the graceful shutdown procedure"`
|
||||
GraceTimeOut parse.Duration `description:"Duration to give active requests a chance to finish before Traefik stops"`
|
||||
}
|
||||
|
||||
// Tracing holds the tracing configuration.
|
||||
type Tracing struct {
|
||||
Backend string `description:"Selects the tracing backend ('jaeger','zipkin','datadog','instana')." export:"true"`
|
||||
ServiceName string `description:"Set the name for this service" export:"true"`
|
||||
SpanNameLimit int `description:"Set the maximum character limit for Span names (default 0 = no limit)" export:"true"`
|
||||
Jaeger *jaeger.Config `description:"Settings for jaeger"`
|
||||
Zipkin *zipkin.Config `description:"Settings for zipkin"`
|
||||
DataDog *datadog.Config `description:"Settings for DataDog"`
|
||||
Instana *instana.Config `description:"Settings for Instana"`
|
||||
}
|
||||
|
||||
// Providers contains providers configuration
|
||||
type Providers struct {
|
||||
ProvidersThrottleDuration parse.Duration `description:"Backends throttle duration: minimum duration between 2 events from providers before applying a new configuration. It avoids unnecessary reloads if multiples events are sent in a short amount of time." export:"true"`
|
||||
Docker *docker.Provider `description:"Enable Docker backend with default settings" export:"true"`
|
||||
File *file.Provider `description:"Enable File backend with default settings" export:"true"`
|
||||
Marathon *marathon.Provider `description:"Enable Marathon backend with default settings" export:"true"`
|
||||
Consul *consul.Provider `description:"Enable Consul backend with default settings" export:"true"`
|
||||
ConsulCatalog *consulcatalog.Provider `description:"Enable Consul catalog backend with default settings" export:"true"`
|
||||
Etcd *etcd.Provider `description:"Enable Etcd backend with default settings" export:"true"`
|
||||
Zookeeper *zk.Provider `description:"Enable Zookeeper backend with default settings" export:"true"`
|
||||
Boltdb *boltdb.Provider `description:"Enable Boltdb backend with default settings" export:"true"`
|
||||
Kubernetes *ingress.Provider `description:"Enable Kubernetes backend with default settings" export:"true"`
|
||||
KubernetesCRD *crd.Provider `description:"Enable Kubernetes backend with default settings" export:"true"`
|
||||
Mesos *mesos.Provider `description:"Enable Mesos backend with default settings" export:"true"`
|
||||
Eureka *eureka.Provider `description:"Enable Eureka backend with default settings" export:"true"`
|
||||
ECS *ecs.Provider `description:"Enable ECS backend with default settings" export:"true"`
|
||||
Rancher *rancher.Provider `description:"Enable Rancher backend with default settings" export:"true"`
|
||||
DynamoDB *dynamodb.Provider `description:"Enable DynamoDB backend with default settings" export:"true"`
|
||||
Rest *rest.Provider `description:"Enable Rest backend with default settings" export:"true"`
|
||||
}
|
||||
|
||||
// SetEffectiveConfiguration adds missing configuration parameters derived from existing ones.
|
||||
// It also takes care of maintaining backwards compatibility.
|
||||
func (c *Configuration) SetEffectiveConfiguration(configFile string) {
|
||||
if len(c.EntryPoints) == 0 {
|
||||
c.EntryPoints = EntryPoints{
|
||||
"http": &EntryPoint{
|
||||
Address: ":80",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if (c.API != nil && c.API.EntryPoint == DefaultInternalEntryPointName) ||
|
||||
(c.Ping != nil && c.Ping.EntryPoint == DefaultInternalEntryPointName) ||
|
||||
(c.Metrics != nil && c.Metrics.Prometheus != nil && c.Metrics.Prometheus.EntryPoint == DefaultInternalEntryPointName) ||
|
||||
(c.Providers.Rest != nil && c.Providers.Rest.EntryPoint == DefaultInternalEntryPointName) {
|
||||
if _, ok := c.EntryPoints[DefaultInternalEntryPointName]; !ok {
|
||||
c.EntryPoints[DefaultInternalEntryPointName] = &EntryPoint{Address: ":8080"}
|
||||
}
|
||||
}
|
||||
|
||||
for _, entryPoint := range c.EntryPoints {
|
||||
if entryPoint.Transport == nil {
|
||||
entryPoint.Transport = &EntryPointsTransport{}
|
||||
}
|
||||
|
||||
// Make sure LifeCycle isn't nil to spare nil checks elsewhere.
|
||||
if entryPoint.Transport.LifeCycle == nil {
|
||||
entryPoint.Transport.LifeCycle = &LifeCycle{
|
||||
GraceTimeOut: parse.Duration(DefaultGraceTimeout),
|
||||
}
|
||||
entryPoint.Transport.RespondingTimeouts = &RespondingTimeouts{
|
||||
IdleTimeout: parse.Duration(DefaultIdleTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
if entryPoint.ForwardedHeaders == nil {
|
||||
entryPoint.ForwardedHeaders = &ForwardedHeaders{}
|
||||
}
|
||||
}
|
||||
|
||||
if c.Providers.Rancher != nil {
|
||||
// Ensure backwards compatibility for now
|
||||
if len(c.Providers.Rancher.AccessKey) > 0 ||
|
||||
len(c.Providers.Rancher.Endpoint) > 0 ||
|
||||
len(c.Providers.Rancher.SecretKey) > 0 {
|
||||
|
||||
if c.Providers.Rancher.API == nil {
|
||||
c.Providers.Rancher.API = &rancher.APIConfiguration{
|
||||
AccessKey: c.Providers.Rancher.AccessKey,
|
||||
SecretKey: c.Providers.Rancher.SecretKey,
|
||||
Endpoint: c.Providers.Rancher.Endpoint,
|
||||
}
|
||||
}
|
||||
log.Warn("Deprecated configuration found: rancher.[accesskey|secretkey|endpoint]. " +
|
||||
"Please use rancher.api.[accesskey|secretkey|endpoint] instead.")
|
||||
}
|
||||
|
||||
if c.Providers.Rancher.Metadata != nil && len(c.Providers.Rancher.Metadata.Prefix) == 0 {
|
||||
c.Providers.Rancher.Metadata.Prefix = "latest"
|
||||
}
|
||||
}
|
||||
|
||||
if c.Providers.Docker != nil {
|
||||
if c.Providers.Docker.SwarmModeRefreshSeconds <= 0 {
|
||||
c.Providers.Docker.SwarmModeRefreshSeconds = 15
|
||||
}
|
||||
}
|
||||
|
||||
if c.Providers.File != nil {
|
||||
c.Providers.File.TraefikFile = configFile
|
||||
}
|
||||
|
||||
c.initACMEProvider()
|
||||
c.initTracing()
|
||||
}
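A brief sketch of the defaulting behaviour implemented by SetEffectiveConfiguration; Providers is set to an empty struct here only because the method dereferences it:

package main

import (
	"fmt"

	"github.com/containous/traefik/pkg/config/static"
)

func main() {
	cfg := static.Configuration{Providers: &static.Providers{}}
	cfg.SetEffectiveConfiguration("")

	// With no entry points configured, an "http" entry point on :80 is created.
	fmt.Println(cfg.EntryPoints["http"].Address)
}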
|
||||
|
||||
func (c *Configuration) initTracing() {
|
||||
if c.Tracing != nil {
|
||||
switch c.Tracing.Backend {
|
||||
case jaeger.Name:
|
||||
if c.Tracing.Jaeger == nil {
|
||||
c.Tracing.Jaeger = &jaeger.Config{
|
||||
SamplingServerURL: "http://localhost:5778/sampling",
|
||||
SamplingType: "const",
|
||||
SamplingParam: 1.0,
|
||||
LocalAgentHostPort: "127.0.0.1:6831",
|
||||
Propagation: "jaeger",
|
||||
Gen128Bit: false,
|
||||
TraceContextHeaderName: jaegercli.TraceContextHeaderName,
|
||||
}
|
||||
}
|
||||
if c.Tracing.Zipkin != nil {
|
||||
log.Warn("Zipkin configuration will be ignored")
|
||||
c.Tracing.Zipkin = nil
|
||||
}
|
||||
if c.Tracing.DataDog != nil {
|
||||
log.Warn("DataDog configuration will be ignored")
|
||||
c.Tracing.DataDog = nil
|
||||
}
|
||||
if c.Tracing.Instana != nil {
|
||||
log.Warn("Instana configuration will be ignored")
|
||||
c.Tracing.Instana = nil
|
||||
}
|
||||
case zipkin.Name:
|
||||
if c.Tracing.Zipkin == nil {
|
||||
c.Tracing.Zipkin = &zipkin.Config{
|
||||
HTTPEndpoint: "http://localhost:9411/api/v1/spans",
|
||||
SameSpan: false,
|
||||
ID128Bit: true,
|
||||
Debug: false,
|
||||
SampleRate: 1.0,
|
||||
}
|
||||
}
|
||||
if c.Tracing.Jaeger != nil {
|
||||
log.Warn("Jaeger configuration will be ignored")
|
||||
c.Tracing.Jaeger = nil
|
||||
}
|
||||
if c.Tracing.DataDog != nil {
|
||||
log.Warn("DataDog configuration will be ignored")
|
||||
c.Tracing.DataDog = nil
|
||||
}
|
||||
if c.Tracing.Instana != nil {
|
||||
log.Warn("Instana configuration will be ignored")
|
||||
c.Tracing.Instana = nil
|
||||
}
|
||||
case datadog.Name:
|
||||
if c.Tracing.DataDog == nil {
|
||||
c.Tracing.DataDog = &datadog.Config{
|
||||
LocalAgentHostPort: "localhost:8126",
|
||||
GlobalTag: "",
|
||||
Debug: false,
|
||||
}
|
||||
}
|
||||
if c.Tracing.Zipkin != nil {
|
||||
log.Warn("Zipkin configuration will be ignored")
|
||||
c.Tracing.Zipkin = nil
|
||||
}
|
||||
if c.Tracing.Jaeger != nil {
|
||||
log.Warn("Jaeger configuration will be ignored")
|
||||
c.Tracing.Jaeger = nil
|
||||
}
|
||||
if c.Tracing.Instana != nil {
|
||||
log.Warn("Instana configuration will be ignored")
|
||||
c.Tracing.Instana = nil
|
||||
}
|
||||
case instana.Name:
|
||||
if c.Tracing.Instana == nil {
|
||||
c.Tracing.Instana = &instana.Config{
|
||||
LocalAgentHost: "localhost",
|
||||
LocalAgentPort: 42699,
|
||||
LogLevel: "info",
|
||||
}
|
||||
}
|
||||
if c.Tracing.Zipkin != nil {
|
||||
log.Warn("Zipkin configuration will be ignored")
|
||||
c.Tracing.Zipkin = nil
|
||||
}
|
||||
if c.Tracing.Jaeger != nil {
|
||||
log.Warn("Jaeger configuration will be ignored")
|
||||
c.Tracing.Jaeger = nil
|
||||
}
|
||||
if c.Tracing.DataDog != nil {
|
||||
log.Warn("DataDog configuration will be ignored")
|
||||
c.Tracing.DataDog = nil
|
||||
}
|
||||
default:
|
||||
log.Warnf("Unknown tracer %q", c.Tracing.Backend)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME handle on new configuration ACME struct
|
||||
func (c *Configuration) initACMEProvider() {
|
||||
if c.ACME != nil {
|
||||
c.ACME.CAServer = getSafeACMECAServer(c.ACME.CAServer)
|
||||
|
||||
if c.ACME.DNSChallenge != nil && c.ACME.HTTPChallenge != nil {
|
||||
log.Warn("Unable to use DNS challenge and HTTP challenge at the same time. Fallback to DNS challenge.")
|
||||
c.ACME.HTTPChallenge = nil
|
||||
}
|
||||
|
||||
if c.ACME.DNSChallenge != nil && c.ACME.TLSChallenge != nil {
|
||||
log.Warn("Unable to use DNS challenge and TLS challenge at the same time. Fallback to DNS challenge.")
|
||||
c.ACME.TLSChallenge = nil
|
||||
}
|
||||
|
||||
if c.ACME.HTTPChallenge != nil && c.ACME.TLSChallenge != nil {
|
||||
log.Warn("Unable to use HTTP challenge and TLS challenge at the same time. Fallback to TLS challenge.")
|
||||
c.ACME.HTTPChallenge = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InitACMEProvider create an acme provider from the ACME part of globalConfiguration
|
||||
func (c *Configuration) InitACMEProvider() (*acmeprovider.Provider, error) {
|
||||
if c.ACME != nil {
|
||||
if len(c.ACME.Storage) == 0 {
|
||||
return nil, errors.New("unable to initialize ACME provider with no storage location for the certificates")
|
||||
}
|
||||
return &acmeprovider.Provider{
|
||||
Configuration: c.ACME,
|
||||
}, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
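And a sketch of the guard in InitACMEProvider: without a storage location for the certificates the provider is refused (acmeprovider is the pkg/provider/acme package imported at the top of this file):

package main

import (
	"fmt"

	"github.com/containous/traefik/pkg/config/static"
	acmeprovider "github.com/containous/traefik/pkg/provider/acme"
)

func main() {
	cfg := static.Configuration{ACME: &acmeprovider.Configuration{}}
	if _, err := cfg.InitACMEProvider(); err != nil {
		fmt.Println(err) // a storage location for the certificates is required
	}
}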
|
||||
|
||||
// ValidateConfiguration validate that configuration is coherent
|
||||
func (c *Configuration) ValidateConfiguration() {
|
||||
if c.ACME != nil {
|
||||
for _, domain := range c.ACME.Domains {
|
||||
if domain.Main != dns01.UnFqdn(domain.Main) {
|
||||
log.Warnf("FQDN detected, please remove the trailing dot: %s", domain.Main)
|
||||
}
|
||||
for _, san := range domain.SANs {
|
||||
if san != dns01.UnFqdn(san) {
|
||||
log.Warnf("FQDN detected, please remove the trailing dot: %s", san)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// FIXME Validate store config?
|
||||
// if c.ACME != nil {
|
||||
// if _, ok := c.EntryPoints[c.ACME.EntryPoint]; !ok {
|
||||
// log.Fatalf("Unknown entrypoint %q for ACME configuration", c.ACME.EntryPoint)
|
||||
// }
|
||||
// else if c.EntryPoints[c.ACME.EntryPoint].TLS == nil {
|
||||
// log.Fatalf("Entrypoint %q has no TLS configuration for ACME configuration", c.ACME.EntryPoint)
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
func getSafeACMECAServer(caServerSrc string) string {
|
||||
if len(caServerSrc) == 0 {
|
||||
return DefaultAcmeCAServer
|
||||
}
|
||||
|
||||
if strings.HasPrefix(caServerSrc, "https://acme-v01.api.letsencrypt.org") {
|
||||
caServer := strings.Replace(caServerSrc, "v01", "v02", 1)
|
||||
log.Warnf("The CA server %[1]q refers to a v01 endpoint of the ACME API, please change to %[2]q. Fallback to %[2]q.", caServerSrc, caServer)
|
||||
return caServer
|
||||
}
|
||||
|
||||
if strings.HasPrefix(caServerSrc, "https://acme-staging.api.letsencrypt.org") {
|
||||
caServer := strings.Replace(caServerSrc, "https://acme-staging.api.letsencrypt.org", "https://acme-staging-v02.api.letsencrypt.org", 1)
|
||||
log.Warnf("The CA server %[1]q refers to a v01 endpoint of the ACME API, please change to %[2]q. Fallback to %[2]q.", caServerSrc, caServer)
|
||||
return caServer
|
||||
}
|
||||
|
||||
return caServerSrc
|
||||
}
|
733
pkg/config/zz_generated.deepcopy.go
Normal file
@@ -0,0 +1,733 @@
// +build !ignore_autogenerated
|
||||
|
||||
/*
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2016-2019 Containous SAS
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||
THE SOFTWARE.
|
||||
*/
|
||||
|
||||
// Code generated by deepcopy-gen. DO NOT EDIT.
|
||||
|
||||
package config
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *AddPrefix) DeepCopyInto(out *AddPrefix) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AddPrefix.
|
||||
func (in *AddPrefix) DeepCopy() *AddPrefix {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(AddPrefix)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *Auth) DeepCopyInto(out *Auth) {
|
||||
*out = *in
|
||||
if in.Basic != nil {
|
||||
in, out := &in.Basic, &out.Basic
|
||||
*out = new(BasicAuth)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.Digest != nil {
|
||||
in, out := &in.Digest, &out.Digest
|
||||
*out = new(DigestAuth)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.Forward != nil {
|
||||
in, out := &in.Forward, &out.Forward
|
||||
*out = new(ForwardAuth)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Auth.
|
||||
func (in *Auth) DeepCopy() *Auth {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(Auth)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *BasicAuth) DeepCopyInto(out *BasicAuth) {
|
||||
*out = *in
|
||||
if in.Users != nil {
|
||||
in, out := &in.Users, &out.Users
|
||||
*out = make(Users, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new BasicAuth.
|
||||
func (in *BasicAuth) DeepCopy() *BasicAuth {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(BasicAuth)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *Buffering) DeepCopyInto(out *Buffering) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Buffering.
|
||||
func (in *Buffering) DeepCopy() *Buffering {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(Buffering)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *Chain) DeepCopyInto(out *Chain) {
|
||||
*out = *in
|
||||
if in.Middlewares != nil {
|
||||
in, out := &in.Middlewares, &out.Middlewares
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Chain.
|
||||
func (in *Chain) DeepCopy() *Chain {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(Chain)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *CircuitBreaker) DeepCopyInto(out *CircuitBreaker) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new CircuitBreaker.
|
||||
func (in *CircuitBreaker) DeepCopy() *CircuitBreaker {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(CircuitBreaker)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *ClientTLS) DeepCopyInto(out *ClientTLS) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ClientTLS.
|
||||
func (in *ClientTLS) DeepCopy() *ClientTLS {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(ClientTLS)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *Compress) DeepCopyInto(out *Compress) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Compress.
|
||||
func (in *Compress) DeepCopy() *Compress {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(Compress)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *DigestAuth) DeepCopyInto(out *DigestAuth) {
|
||||
*out = *in
|
||||
if in.Users != nil {
|
||||
in, out := &in.Users, &out.Users
|
||||
*out = make(Users, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new DigestAuth.
|
||||
func (in *DigestAuth) DeepCopy() *DigestAuth {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(DigestAuth)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *ErrorPage) DeepCopyInto(out *ErrorPage) {
|
||||
*out = *in
|
||||
if in.Status != nil {
|
||||
in, out := &in.Status, &out.Status
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ErrorPage.
|
||||
func (in *ErrorPage) DeepCopy() *ErrorPage {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(ErrorPage)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *ForwardAuth) DeepCopyInto(out *ForwardAuth) {
|
||||
*out = *in
|
||||
if in.TLS != nil {
|
||||
in, out := &in.TLS, &out.TLS
|
||||
*out = new(ClientTLS)
|
||||
**out = **in
|
||||
}
|
||||
if in.AuthResponseHeaders != nil {
|
||||
in, out := &in.AuthResponseHeaders, &out.AuthResponseHeaders
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ForwardAuth.
|
||||
func (in *ForwardAuth) DeepCopy() *ForwardAuth {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(ForwardAuth)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *Headers) DeepCopyInto(out *Headers) {
|
||||
*out = *in
|
||||
if in.CustomRequestHeaders != nil {
|
||||
in, out := &in.CustomRequestHeaders, &out.CustomRequestHeaders
|
||||
*out = make(map[string]string, len(*in))
|
||||
for key, val := range *in {
|
||||
(*out)[key] = val
|
||||
}
|
||||
}
|
||||
if in.CustomResponseHeaders != nil {
|
||||
in, out := &in.CustomResponseHeaders, &out.CustomResponseHeaders
|
||||
*out = make(map[string]string, len(*in))
|
||||
for key, val := range *in {
|
||||
(*out)[key] = val
|
||||
}
|
||||
}
|
||||
if in.AllowedHosts != nil {
|
||||
in, out := &in.AllowedHosts, &out.AllowedHosts
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
if in.HostsProxyHeaders != nil {
|
||||
in, out := &in.HostsProxyHeaders, &out.HostsProxyHeaders
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
if in.SSLProxyHeaders != nil {
|
||||
in, out := &in.SSLProxyHeaders, &out.SSLProxyHeaders
|
||||
*out = make(map[string]string, len(*in))
|
||||
for key, val := range *in {
|
||||
(*out)[key] = val
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Headers.
|
||||
func (in *Headers) DeepCopy() *Headers {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(Headers)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *IPStrategy) DeepCopyInto(out *IPStrategy) {
|
||||
*out = *in
|
||||
if in.ExcludedIPs != nil {
|
||||
in, out := &in.ExcludedIPs, &out.ExcludedIPs
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new IPStrategy.
|
||||
func (in *IPStrategy) DeepCopy() *IPStrategy {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(IPStrategy)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *IPWhiteList) DeepCopyInto(out *IPWhiteList) {
|
||||
*out = *in
|
||||
if in.SourceRange != nil {
|
||||
in, out := &in.SourceRange, &out.SourceRange
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
if in.IPStrategy != nil {
|
||||
in, out := &in.IPStrategy, &out.IPStrategy
|
||||
*out = new(IPStrategy)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new IPWhiteList.
|
||||
func (in *IPWhiteList) DeepCopy() *IPWhiteList {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(IPWhiteList)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *MaxConn) DeepCopyInto(out *MaxConn) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new MaxConn.
|
||||
func (in *MaxConn) DeepCopy() *MaxConn {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(MaxConn)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *Middleware) DeepCopyInto(out *Middleware) {
|
||||
*out = *in
|
||||
if in.AddPrefix != nil {
|
||||
in, out := &in.AddPrefix, &out.AddPrefix
|
||||
*out = new(AddPrefix)
|
||||
**out = **in
|
||||
}
|
||||
if in.StripPrefix != nil {
|
||||
in, out := &in.StripPrefix, &out.StripPrefix
|
||||
*out = new(StripPrefix)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.StripPrefixRegex != nil {
|
||||
in, out := &in.StripPrefixRegex, &out.StripPrefixRegex
|
||||
*out = new(StripPrefixRegex)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.ReplacePath != nil {
|
||||
in, out := &in.ReplacePath, &out.ReplacePath
|
||||
*out = new(ReplacePath)
|
||||
**out = **in
|
||||
}
|
||||
if in.ReplacePathRegex != nil {
|
||||
in, out := &in.ReplacePathRegex, &out.ReplacePathRegex
|
||||
*out = new(ReplacePathRegex)
|
||||
**out = **in
|
||||
}
|
||||
if in.Chain != nil {
|
||||
in, out := &in.Chain, &out.Chain
|
||||
*out = new(Chain)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.IPWhiteList != nil {
|
||||
in, out := &in.IPWhiteList, &out.IPWhiteList
|
||||
*out = new(IPWhiteList)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.Headers != nil {
|
||||
in, out := &in.Headers, &out.Headers
|
||||
*out = new(Headers)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.Errors != nil {
|
||||
in, out := &in.Errors, &out.Errors
|
||||
*out = new(ErrorPage)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.RateLimit != nil {
|
||||
in, out := &in.RateLimit, &out.RateLimit
|
||||
*out = new(RateLimit)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.RedirectRegex != nil {
|
||||
in, out := &in.RedirectRegex, &out.RedirectRegex
|
||||
*out = new(RedirectRegex)
|
||||
**out = **in
|
||||
}
|
||||
if in.RedirectScheme != nil {
|
||||
in, out := &in.RedirectScheme, &out.RedirectScheme
|
||||
*out = new(RedirectScheme)
|
||||
**out = **in
|
||||
}
|
||||
if in.BasicAuth != nil {
|
||||
in, out := &in.BasicAuth, &out.BasicAuth
|
||||
*out = new(BasicAuth)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.DigestAuth != nil {
|
||||
in, out := &in.DigestAuth, &out.DigestAuth
|
||||
*out = new(DigestAuth)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.ForwardAuth != nil {
|
||||
in, out := &in.ForwardAuth, &out.ForwardAuth
|
||||
*out = new(ForwardAuth)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.MaxConn != nil {
|
||||
in, out := &in.MaxConn, &out.MaxConn
|
||||
*out = new(MaxConn)
|
||||
**out = **in
|
||||
}
|
||||
if in.Buffering != nil {
|
||||
in, out := &in.Buffering, &out.Buffering
|
||||
*out = new(Buffering)
|
||||
**out = **in
|
||||
}
|
||||
if in.CircuitBreaker != nil {
|
||||
in, out := &in.CircuitBreaker, &out.CircuitBreaker
|
||||
*out = new(CircuitBreaker)
|
||||
**out = **in
|
||||
}
|
||||
if in.Compress != nil {
|
||||
in, out := &in.Compress, &out.Compress
|
||||
*out = new(Compress)
|
||||
**out = **in
|
||||
}
|
||||
if in.PassTLSClientCert != nil {
|
||||
in, out := &in.PassTLSClientCert, &out.PassTLSClientCert
|
||||
*out = new(PassTLSClientCert)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
if in.Retry != nil {
|
||||
in, out := &in.Retry, &out.Retry
|
||||
*out = new(Retry)
|
||||
**out = **in
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Middleware.
|
||||
func (in *Middleware) DeepCopy() *Middleware {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(Middleware)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *PassTLSClientCert) DeepCopyInto(out *PassTLSClientCert) {
|
||||
*out = *in
|
||||
if in.Info != nil {
|
||||
in, out := &in.Info, &out.Info
|
||||
*out = new(TLSClientCertificateInfo)
|
||||
(*in).DeepCopyInto(*out)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PassTLSClientCert.
|
||||
func (in *PassTLSClientCert) DeepCopy() *PassTLSClientCert {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(PassTLSClientCert)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *Rate) DeepCopyInto(out *Rate) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Rate.
|
||||
func (in *Rate) DeepCopy() *Rate {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(Rate)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *RateLimit) DeepCopyInto(out *RateLimit) {
|
||||
*out = *in
|
||||
if in.RateSet != nil {
|
||||
in, out := &in.RateSet, &out.RateSet
|
||||
*out = make(map[string]*Rate, len(*in))
|
||||
for key, val := range *in {
|
||||
var outVal *Rate
|
||||
if val == nil {
|
||||
(*out)[key] = nil
|
||||
} else {
|
||||
in, out := &val, &outVal
|
||||
*out = new(Rate)
|
||||
**out = **in
|
||||
}
|
||||
(*out)[key] = outVal
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RateLimit.
|
||||
func (in *RateLimit) DeepCopy() *RateLimit {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(RateLimit)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *RedirectRegex) DeepCopyInto(out *RedirectRegex) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RedirectRegex.
|
||||
func (in *RedirectRegex) DeepCopy() *RedirectRegex {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(RedirectRegex)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *RedirectScheme) DeepCopyInto(out *RedirectScheme) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RedirectScheme.
|
||||
func (in *RedirectScheme) DeepCopy() *RedirectScheme {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(RedirectScheme)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *ReplacePath) DeepCopyInto(out *ReplacePath) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ReplacePath.
|
||||
func (in *ReplacePath) DeepCopy() *ReplacePath {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(ReplacePath)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *ReplacePathRegex) DeepCopyInto(out *ReplacePathRegex) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ReplacePathRegex.
|
||||
func (in *ReplacePathRegex) DeepCopy() *ReplacePathRegex {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(ReplacePathRegex)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *Retry) DeepCopyInto(out *Retry) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Retry.
|
||||
func (in *Retry) DeepCopy() *Retry {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(Retry)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *StripPrefix) DeepCopyInto(out *StripPrefix) {
|
||||
*out = *in
|
||||
if in.Prefixes != nil {
|
||||
in, out := &in.Prefixes, &out.Prefixes
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new StripPrefix.
|
||||
func (in *StripPrefix) DeepCopy() *StripPrefix {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(StripPrefix)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *StripPrefixRegex) DeepCopyInto(out *StripPrefixRegex) {
|
||||
*out = *in
|
||||
if in.Regex != nil {
|
||||
in, out := &in.Regex, &out.Regex
|
||||
*out = make([]string, len(*in))
|
||||
copy(*out, *in)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new StripPrefixRegex.
|
||||
func (in *StripPrefixRegex) DeepCopy() *StripPrefixRegex {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(StripPrefixRegex)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *TLSCLientCertificateDNInfo) DeepCopyInto(out *TLSCLientCertificateDNInfo) {
|
||||
*out = *in
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TLSCLientCertificateDNInfo.
|
||||
func (in *TLSCLientCertificateDNInfo) DeepCopy() *TLSCLientCertificateDNInfo {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(TLSCLientCertificateDNInfo)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *TLSClientCertificateInfo) DeepCopyInto(out *TLSClientCertificateInfo) {
|
||||
*out = *in
|
||||
if in.Subject != nil {
|
||||
in, out := &in.Subject, &out.Subject
|
||||
*out = new(TLSCLientCertificateDNInfo)
|
||||
**out = **in
|
||||
}
|
||||
if in.Issuer != nil {
|
||||
in, out := &in.Issuer, &out.Issuer
|
||||
*out = new(TLSCLientCertificateDNInfo)
|
||||
**out = **in
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TLSClientCertificateInfo.
|
||||
func (in *TLSClientCertificateInfo) DeepCopy() *TLSClientCertificateInfo {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(TLSClientCertificateInfo)
|
||||
in.DeepCopyInto(out)
|
||||
return out
|
||||
}
|
||||
|
||||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in Users) DeepCopyInto(out *Users) {
|
||||
{
|
||||
in := &in
|
||||
*out = make(Users, len(*in))
|
||||
copy(*out, *in)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Users.
|
||||
func (in Users) DeepCopy() Users {
|
||||
if in == nil {
|
||||
return nil
|
||||
}
|
||||
out := new(Users)
|
||||
in.DeepCopyInto(out)
|
||||
return *out
|
||||
}
|
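The generated DeepCopy/DeepCopyInto helpers above are typically used to clone dynamic configuration before mutating it. A minimal sketch, assuming the generated types live in a config package under pkg/ (the import path and field values below are illustrative):

package main

import (
	"fmt"

	"github.com/containous/traefik/pkg/config" // assumed location of the generated types
)

func main() {
	original := &config.Middleware{
		ForwardAuth: &config.ForwardAuth{AuthResponseHeaders: []string{"X-Auth-User"}},
	}

	// DeepCopy returns a fully independent clone: nested pointers, slices and maps
	// are duplicated, so mutating the clone leaves the original untouched.
	clone := original.DeepCopy()
	clone.ForwardAuth.AuthResponseHeaders[0] = "X-Other"

	fmt.Println(original.ForwardAuth.AuthResponseHeaders[0]) // still "X-Auth-User"
}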
476
pkg/h2c/h2c.go
Normal file
|
@@ -0,0 +1,476 @@
|
|||
// Copyright 2018 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package h2c implements the h2c part of HTTP/2.
|
||||
//
|
||||
// The h2c protocol is the non-TLS secured version of HTTP/2 which is not
|
||||
// available from net/http.
|
||||
package h2c
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
|
||||
"golang.org/x/net/http/httpguts"
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/hpack"
|
||||
)
|
||||
|
||||
var (
|
||||
http2VerboseLogs bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
e := os.Getenv("GODEBUG")
|
||||
if strings.Contains(e, "http2debug=1") || strings.Contains(e, "http2debug=2") {
|
||||
http2VerboseLogs = true
|
||||
}
|
||||
}
|
||||
|
||||
// Server wraps an http.Server and enables h2c. Users who want h2c just need
|
||||
// to provide an http.Server.
|
||||
type Server struct {
|
||||
*http.Server
|
||||
}
|
||||
|
||||
// Serve puts a middleware around the original handler to handle h2c upgrades.
|
||||
func (s Server) Serve(l net.Listener) error {
|
||||
originalHandler := s.Server.Handler
|
||||
if originalHandler == nil {
|
||||
originalHandler = http.DefaultServeMux
|
||||
}
|
||||
s.Server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "PRI" && r.URL.Path == "*" && r.Proto == "HTTP/2.0" {
|
||||
if http2VerboseLogs {
|
||||
log.Debugf("Attempting h2c with prior knowledge.")
|
||||
}
|
||||
conn, err := initH2CWithPriorKnowledge(w)
|
||||
if err != nil {
|
||||
if http2VerboseLogs {
|
||||
log.Debugf("Error h2c with prior knowledge: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
h2cSrv := &http2.Server{}
|
||||
h2cSrv.ServeConn(conn, &http2.ServeConnOpts{Handler: originalHandler})
|
||||
return
|
||||
}
|
||||
if conn, err := h2cUpgrade(w, r); err == nil {
|
||||
defer conn.Close()
|
||||
h2cSrv := &http2.Server{}
|
||||
h2cSrv.ServeConn(conn, &http2.ServeConnOpts{Handler: originalHandler})
|
||||
return
|
||||
}
|
||||
originalHandler.ServeHTTP(w, r)
|
||||
})
|
||||
return s.Server.Serve(l)
|
||||
}
|
||||
|
||||
// initH2CWithPriorKnowledge creates an h2c connection with prior
|
||||
// knowledge (Section 3.4) and returns a net.Conn suitable for http2.ServeConn.
|
||||
// All we have to do is look for the client preface that is supposed to be part
|
||||
// of the body, and reforward the client preface on the net.Conn this function
|
||||
// creates.
|
||||
func initH2CWithPriorKnowledge(w http.ResponseWriter) (net.Conn, error) {
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, errors.New("hijack not supported")
|
||||
}
|
||||
conn, rw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hijack failed: %v", err)
|
||||
}
|
||||
|
||||
expectedBody := "SM\r\n\r\n"
|
||||
|
||||
buf := make([]byte, len(expectedBody))
|
||||
n, err := io.ReadFull(rw, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fail to read body: %v", err)
|
||||
}
|
||||
|
||||
if bytes.Equal(buf[0:n], []byte(expectedBody)) {
|
||||
c := &rwConn{
|
||||
Conn: conn,
|
||||
Reader: io.MultiReader(bytes.NewBuffer([]byte(http2.ClientPreface)), rw),
|
||||
BufWriter: rw.Writer,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
conn.Close()
|
||||
if http2VerboseLogs {
|
||||
log.Infof(
|
||||
"Missing the request body portion of the client preface. Wanted: %v Got: %v",
|
||||
[]byte(expectedBody),
|
||||
buf[0:n],
|
||||
)
|
||||
}
|
||||
return nil, errors.New("invalid client preface")
|
||||
}
|
||||
|
||||
// drainClientPreface reads a single instance of the HTTP/2 client preface from
|
||||
// the supplied reader.
|
||||
func drainClientPreface(r io.Reader) error {
|
||||
var buf bytes.Buffer
|
||||
prefaceLen := int64(len([]byte(http2.ClientPreface)))
|
||||
n, err := io.CopyN(&buf, r, prefaceLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if n != prefaceLen || buf.String() != http2.ClientPreface {
|
||||
return fmt.Errorf("client never sent: %s", http2.ClientPreface)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// h2cUpgrade establishes a h2c connection using the HTTP/1 upgrade (Section 3.2).
|
||||
func h2cUpgrade(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
|
||||
if !isH2CUpgrade(r.Header) {
|
||||
return nil, errors.New("non-conforming h2c headers")
|
||||
}
|
||||
|
||||
// Initial bytes we put into conn to fool http2 server
|
||||
initBytes, _, err := convertH1ReqToH2(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hijacker, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, errors.New("hijack not supported")
|
||||
}
|
||||
conn, rw, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hijack failed: %v", err)
|
||||
}
|
||||
|
||||
rw.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n" +
|
||||
"Connection: Upgrade\r\n" +
|
||||
"Upgrade: h2c\r\n\r\n"))
|
||||
rw.Flush()
|
||||
|
||||
// A conforming client will now send an H2 client preface, which we need to drain
|
||||
// since one is already included in initBytes.
|
||||
if err := drainClientPreface(rw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &rwConn{
|
||||
Conn: conn,
|
||||
Reader: io.MultiReader(initBytes, rw),
|
||||
BufWriter: newSettingsAckSwallowWriter(rw.Writer),
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// convertH1ReqToH2 converts the data contained in the HTTP/1 upgrade request into the HTTP/2
|
||||
// version in byte form.
|
||||
func convertH1ReqToH2(r *http.Request) (*bytes.Buffer, []http2.Setting, error) {
|
||||
h2Bytes := bytes.NewBuffer([]byte((http2.ClientPreface)))
|
||||
framer := http2.NewFramer(h2Bytes, nil)
|
||||
settings, err := getH2Settings(r.Header)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := framer.WriteSettings(settings...); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
headerBytes, err := getH2HeaderBytes(r, getMaxHeaderTableSize(settings))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
maxFrameSize := int(getMaxFrameSize(settings))
|
||||
needOneHeader := len(headerBytes) < maxFrameSize
|
||||
err = framer.WriteHeaders(http2.HeadersFrameParam{
|
||||
StreamID: 1,
|
||||
BlockFragment: headerBytes,
|
||||
EndHeaders: needOneHeader,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for i := maxFrameSize; i < len(headerBytes); i += maxFrameSize {
|
||||
if len(headerBytes)-i > maxFrameSize {
|
||||
if err := framer.WriteContinuation(1,
|
||||
false, // endHeaders
|
||||
headerBytes[i:i+maxFrameSize]); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
} else {
|
||||
if err := framer.WriteContinuation(1,
|
||||
true, // endHeaders
|
||||
headerBytes[i:]); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return h2Bytes, settings, nil
|
||||
}
|
||||
|
||||
// getMaxFrameSize returns the SETTINGS_MAX_FRAME_SIZE. If not present, the default
|
||||
// value is 16384 as specified by RFC 7540 Section 6.5.2.
|
||||
func getMaxFrameSize(settings []http2.Setting) uint32 {
|
||||
for _, setting := range settings {
|
||||
if setting.ID == http2.SettingMaxFrameSize {
|
||||
return setting.Val
|
||||
}
|
||||
}
|
||||
return 16384
|
||||
}
|
||||
|
||||
// getMaxHeaderTableSize returns the SETTINGS_HEADER_TABLE_SIZE. If not present, the
|
||||
// default value is 4096 as specified by RFC 7540 Section 6.5.2.
|
||||
func getMaxHeaderTableSize(settings []http2.Setting) uint32 {
|
||||
for _, setting := range settings {
|
||||
if setting.ID == http2.SettingHeaderTableSize {
|
||||
return setting.Val
|
||||
}
|
||||
}
|
||||
return 4096
|
||||
}
|
||||
|
||||
// bufWriter is a Writer interface that also has a Flush method.
|
||||
type bufWriter interface {
|
||||
io.Writer
|
||||
Flush() error
|
||||
}
|
||||
|
||||
// rwConn implements net.Conn but overrides Read and Write so that reads and
|
||||
// writes are forwarded to the provided io.Reader and bufWriter.
|
||||
type rwConn struct {
|
||||
net.Conn
|
||||
io.Reader
|
||||
BufWriter bufWriter
|
||||
}
|
||||
|
||||
// Read forwards reads to the underlying Reader.
|
||||
func (c *rwConn) Read(p []byte) (int, error) {
|
||||
return c.Reader.Read(p)
|
||||
}
|
||||
|
||||
// Write forwards writes to the underlying bufWriter and immediately flushes.
|
||||
func (c *rwConn) Write(p []byte) (int, error) {
|
||||
n, err := c.BufWriter.Write(p)
|
||||
if err := c.BufWriter.Flush(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// settingsAckSwallowWriter is a writer that normally forwards bytes to its
|
||||
// underlying Writer, but swallows the first SettingsAck frame that it sees.
|
||||
type settingsAckSwallowWriter struct {
|
||||
Writer *bufio.Writer
|
||||
buf []byte
|
||||
didSwallow bool
|
||||
}
|
||||
|
||||
// newSettingsAckSwallowWriter returns a new settingsAckSwallowWriter.
|
||||
func newSettingsAckSwallowWriter(w *bufio.Writer) *settingsAckSwallowWriter {
|
||||
return &settingsAckSwallowWriter{
|
||||
Writer: w,
|
||||
buf: make([]byte, 0),
|
||||
didSwallow: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Write implements io.Writer interface. Normally forwards bytes to w.Writer,
|
||||
// except for the first Settings ACK frame that it sees.
|
||||
func (w *settingsAckSwallowWriter) Write(p []byte) (int, error) {
|
||||
if !w.didSwallow {
|
||||
w.buf = append(w.buf, p...)
|
||||
// Process all the frames we have collected into w.buf
|
||||
for {
|
||||
// Append until we get full frame header which is 9 bytes
|
||||
if len(w.buf) < 9 {
|
||||
break
|
||||
}
|
||||
// Check if we have collected a whole frame.
|
||||
fh, err := http2.ReadFrameHeader(bytes.NewBuffer(w.buf))
|
||||
if err != nil {
|
||||
// Corrupted frame, fail current Write
|
||||
return 0, err
|
||||
}
|
||||
fSize := fh.Length + 9
|
||||
if uint32(len(w.buf)) < fSize {
|
||||
// Have not collected a whole frame. Stop processing buf, and hold off on
|
||||
// forwarding bytes to w.Writer until we get the full frame.
|
||||
break
|
||||
}
|
||||
|
||||
// We have now collected a whole frame.
|
||||
if fh.Type == http2.FrameSettings && fh.Flags.Has(http2.FlagSettingsAck) {
|
||||
// If Settings ACK frame, do not forward to underlying writer, remove
|
||||
// bytes from w.buf, and record that we have swallowed Settings Ack
|
||||
// frame.
|
||||
w.didSwallow = true
|
||||
w.buf = w.buf[fSize:]
|
||||
continue
|
||||
}
|
||||
|
||||
// Not settings ack frame. Forward bytes to w.Writer.
|
||||
if _, err := w.Writer.Write(w.buf[:fSize]); err != nil {
|
||||
// Couldn't forward bytes. Fail current Write.
|
||||
return 0, err
|
||||
}
|
||||
w.buf = w.buf[fSize:]
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
return w.Writer.Write(p)
|
||||
}
|
||||
|
||||
// Flush calls w.Writer.Flush.
|
||||
func (w *settingsAckSwallowWriter) Flush() error {
|
||||
return w.Writer.Flush()
|
||||
}
|
||||
|
||||
// isH2CUpgrade returns true if the header properly requests an upgrade to h2c
|
||||
// as specified by Section 3.2.
|
||||
func isH2CUpgrade(h http.Header) bool {
|
||||
return httpguts.HeaderValuesContainsToken(h[textproto.CanonicalMIMEHeaderKey("Upgrade")], "h2c") &&
|
||||
httpguts.HeaderValuesContainsToken(h[textproto.CanonicalMIMEHeaderKey("Connection")], "HTTP2-Settings")
|
||||
}
|
||||
|
||||
// getH2Settings returns the []http2.Setting that are encoded in the
|
||||
// HTTP2-Settings header.
|
||||
func getH2Settings(h http.Header) ([]http2.Setting, error) {
|
||||
vals, ok := h[textproto.CanonicalMIMEHeaderKey("HTTP2-Settings")]
|
||||
if !ok {
|
||||
return nil, errors.New("missing HTTP2-Settings header")
|
||||
}
|
||||
if len(vals) != 1 {
|
||||
return nil, fmt.Errorf("expected 1 HTTP2-Settings. Got: %v", vals)
|
||||
}
|
||||
settings, err := decodeSettings(vals[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid HTTP2-Settings: %q", vals[0])
|
||||
}
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// decodeSettings decodes the base64url header value of the HTTP2-Settings
|
||||
// header. RFC 7540 Section 3.2.1.
|
||||
func decodeSettings(headerVal string) ([]http2.Setting, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(headerVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(b)%6 != 0 {
|
||||
return nil, fmt.Errorf("invalid HTTP2-Settings: length %d is not a multiple of 6", len(b))
|
||||
}
|
||||
settings := make([]http2.Setting, 0)
|
||||
for i := 0; i < len(b)/6; i++ {
|
||||
settings = append(settings, http2.Setting{
|
||||
ID: http2.SettingID(binary.BigEndian.Uint16(b[i*6 : i*6+2])),
|
||||
Val: binary.BigEndian.Uint32(b[i*6+2 : i*6+6]),
|
||||
})
|
||||
}
|
||||
|
||||
return settings, nil
|
||||
}
|
||||
|
||||
// getH2HeaderBytes returns the headers in r as a []byte encoded by HPACK.
|
||||
func getH2HeaderBytes(r *http.Request, maxHeaderTableSize uint32) ([]byte, error) {
|
||||
headerBytes := bytes.NewBuffer(nil)
|
||||
hpackEnc := hpack.NewEncoder(headerBytes)
|
||||
hpackEnc.SetMaxDynamicTableSize(maxHeaderTableSize)
|
||||
|
||||
// Section 8.1.2.3
|
||||
err := hpackEnc.WriteField(hpack.HeaderField{
|
||||
Name: ":method",
|
||||
Value: r.Method,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = hpackEnc.WriteField(hpack.HeaderField{
|
||||
Name: ":scheme",
|
||||
Value: "http",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = hpackEnc.WriteField(hpack.HeaderField{
|
||||
Name: ":authority",
|
||||
Value: r.Host,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
if r.URL.RawQuery != "" {
|
||||
path = strings.Join([]string{path, r.URL.RawQuery}, "?")
|
||||
}
|
||||
err = hpackEnc.WriteField(hpack.HeaderField{
|
||||
Name: ":path",
|
||||
Value: path,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO Implement Section 8.3
|
||||
|
||||
for header, values := range r.Header {
|
||||
// Skip non h2 headers
|
||||
if isNonH2Header(header) {
|
||||
continue
|
||||
}
|
||||
for _, v := range values {
|
||||
err := hpackEnc.WriteField(hpack.HeaderField{
|
||||
Name: strings.ToLower(header),
|
||||
Value: v,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return headerBytes.Bytes(), nil
|
||||
}
|
||||
|
||||
// Connection specific headers listed in RFC 7540 Section 8.1.2.2 that are not
|
||||
// supposed to be transferred to HTTP/2. The Http2-Settings header is skipped
|
||||
// since it is already used to create the HTTP/2 SETTINGS frame.
|
||||
var nonH2Headers = []string{
|
||||
"Connection",
|
||||
"Keep-Alive",
|
||||
"Proxy-Connection",
|
||||
"Transfer-Encoding",
|
||||
"Upgrade",
|
||||
"Http2-Settings",
|
||||
}
|
||||
|
||||
// isNonH2Header returns true if header should not be transferred to HTTP/2.
|
||||
func isNonH2Header(header string) bool {
|
||||
for _, nonH2h := range nonH2Headers {
|
||||
if header == nonH2h {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
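To serve h2c with the Server type above, wrap an existing http.Server and hand it a listener. A minimal sketch (the address and handler are illustrative):

package main

import (
	"log"
	"net"
	"net/http"

	"github.com/containous/traefik/pkg/h2c"
)

func main() {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// r.Proto reports "HTTP/2.0" for upgraded or prior-knowledge h2c requests.
		_, _ = w.Write([]byte("proto: " + r.Proto))
	})

	ln, err := net.Listen("tcp", ":8080")
	if err != nil {
		log.Fatal(err)
	}

	srv := h2c.Server{Server: &http.Server{Handler: handler}}
	log.Fatal(srv.Serve(ln))
}

A client can then connect without TLS, for example with curl --http2-prior-knowledge, and still negotiate HTTP/2.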
223
pkg/healthcheck/healthcheck.go
Normal file
|
@@ -0,0 +1,223 @@
|
|||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
"github.com/go-kit/kit/metrics"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
var singleton *HealthCheck
|
||||
var once sync.Once
|
||||
|
||||
// BalancerHandler includes functionality for load-balancing management.
|
||||
type BalancerHandler interface {
|
||||
ServeHTTP(w http.ResponseWriter, req *http.Request)
|
||||
Servers() []*url.URL
|
||||
RemoveServer(u *url.URL) error
|
||||
UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error
|
||||
}
|
||||
|
||||
// metricsRegistry is a local interface in the health check package, exposing only the metrics
|
||||
// required by the health check itself. This makes testing easier.
|
||||
type metricsRegistry interface {
|
||||
BackendServerUpGauge() metrics.Gauge
|
||||
}
|
||||
|
||||
// Options are the public health check options.
|
||||
type Options struct {
|
||||
Headers map[string]string
|
||||
Hostname string
|
||||
Scheme string
|
||||
Path string
|
||||
Port int
|
||||
Transport http.RoundTripper
|
||||
Interval time.Duration
|
||||
Timeout time.Duration
|
||||
LB BalancerHandler
|
||||
}
|
||||
|
||||
func (opt Options) String() string {
|
||||
return fmt.Sprintf("[Hostname: %s Headers: %v Path: %s Port: %d Interval: %s Timeout: %s]", opt.Hostname, opt.Headers, opt.Path, opt.Port, opt.Interval, opt.Timeout)
|
||||
}
|
||||
|
||||
// BackendConfig is the health check configuration for a backend.
|
||||
type BackendConfig struct {
|
||||
Options
|
||||
name string
|
||||
disabledURLs []*url.URL
|
||||
}
|
||||
|
||||
func (b *BackendConfig) newRequest(serverURL *url.URL) (*http.Request, error) {
|
||||
u, err := serverURL.Parse(b.Path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(b.Scheme) > 0 {
|
||||
u.Scheme = b.Scheme
|
||||
}
|
||||
|
||||
if b.Port != 0 {
|
||||
u.Host = net.JoinHostPort(u.Hostname(), strconv.Itoa(b.Port))
|
||||
}
|
||||
|
||||
return http.NewRequest(http.MethodGet, u.String(), http.NoBody)
|
||||
}
|
||||
|
||||
// addHeadersAndHost sets the configured HTTP headers and hostname on the health check request.
|
||||
func (b *BackendConfig) addHeadersAndHost(req *http.Request) *http.Request {
|
||||
if b.Options.Hostname != "" {
|
||||
req.Host = b.Options.Hostname
|
||||
}
|
||||
|
||||
for k, v := range b.Options.Headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// HealthCheck manages the health checks of all registered backends.
|
||||
type HealthCheck struct {
|
||||
Backends map[string]*BackendConfig
|
||||
metrics metricsRegistry
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// SetBackendsConfiguration sets the backend configurations and restarts the health check goroutines.
|
||||
func (hc *HealthCheck) SetBackendsConfiguration(parentCtx context.Context, backends map[string]*BackendConfig) {
|
||||
hc.Backends = backends
|
||||
if hc.cancel != nil {
|
||||
hc.cancel()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
hc.cancel = cancel
|
||||
|
||||
for _, backend := range backends {
|
||||
currentBackend := backend
|
||||
safe.Go(func() {
|
||||
hc.execute(ctx, currentBackend)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *HealthCheck) execute(ctx context.Context, backend *BackendConfig) {
|
||||
log.Debugf("Initial health check for backend: %q", backend.name)
|
||||
hc.checkBackend(backend)
|
||||
ticker := time.NewTicker(backend.Interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debugf("Stopping current health check goroutines of backend: %s", backend.name)
|
||||
return
|
||||
case <-ticker.C:
|
||||
log.Debugf("Refreshing health check for backend: %s", backend.name)
|
||||
hc.checkBackend(backend)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hc *HealthCheck) checkBackend(backend *BackendConfig) {
|
||||
enabledURLs := backend.LB.Servers()
|
||||
var newDisabledURLs []*url.URL
|
||||
// FIXME re enable metrics
|
||||
for _, disableURL := range backend.disabledURLs {
|
||||
// FIXME serverUpMetricValue := float64(0)
|
||||
if err := checkHealth(disableURL, backend); err == nil {
|
||||
log.Warnf("Health check up: Returning to server list. Backend: %q URL: %q", backend.name, disableURL.String())
|
||||
if err = backend.LB.UpsertServer(disableURL, roundrobin.Weight(1)); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
// FIXME serverUpMetricValue = 1
|
||||
} else {
|
||||
log.Warnf("Health check still failing. Backend: %q URL: %q Reason: %s", backend.name, disableURL.String(), err)
|
||||
newDisabledURLs = append(newDisabledURLs, disableURL)
|
||||
}
|
||||
// FIXME labelValues := []string{"backend", backend.name, "url", disableURL.String()}
|
||||
// FIXME hc.metrics.BackendServerUpGauge().With(labelValues...).Set(serverUpMetricValue)
|
||||
}
|
||||
backend.disabledURLs = newDisabledURLs
|
||||
|
||||
// FIXME re enable metrics
|
||||
for _, enableURL := range enabledURLs {
|
||||
// FIXME serverUpMetricValue := float64(1)
|
||||
if err := checkHealth(enableURL, backend); err != nil {
|
||||
log.Warnf("Health check failed: Remove from server list. Backend: %q URL: %q Reason: %s", backend.name, enableURL.String(), err)
|
||||
if err := backend.LB.RemoveServer(enableURL); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
backend.disabledURLs = append(backend.disabledURLs, enableURL)
|
||||
// FIXME serverUpMetricValue = 0
|
||||
}
|
||||
// FIXME labelValues := []string{"backend", backend.name, "url", enableURL.String()}
|
||||
// FIXME hc.metrics.BackendServerUpGauge().With(labelValues...).Set(serverUpMetricValue)
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME re add metrics
|
||||
//func GetHealthCheck(metrics metricsRegistry) *HealthCheck {
|
||||
|
||||
// GetHealthCheck returns the health check which is guaranteed to be a singleton.
|
||||
func GetHealthCheck() *HealthCheck {
|
||||
once.Do(func() {
|
||||
singleton = newHealthCheck()
|
||||
//singleton = newHealthCheck(metrics)
|
||||
})
|
||||
return singleton
|
||||
}
|
||||
|
||||
// FIXME re add metrics
|
||||
//func newHealthCheck(metrics metricsRegistry) *HealthCheck {
|
||||
func newHealthCheck() *HealthCheck {
|
||||
return &HealthCheck{
|
||||
Backends: make(map[string]*BackendConfig),
|
||||
//metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
// NewBackendConfig instantiates a new BackendConfig.
|
||||
func NewBackendConfig(options Options, backendName string) *BackendConfig {
|
||||
return &BackendConfig{
|
||||
Options: options,
|
||||
name: backendName,
|
||||
}
|
||||
}
|
||||
|
||||
// checkHealth returns a nil error if the health check was successful, and otherwise
|
||||
// a non-nil error with a meaningful description of why the health check failed.
|
||||
func checkHealth(serverURL *url.URL, backend *BackendConfig) error {
|
||||
req, err := backend.newRequest(serverURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create HTTP request: %s", err)
|
||||
}
|
||||
|
||||
req = backend.addHeadersAndHost(req)
|
||||
|
||||
client := http.Client{
|
||||
Timeout: backend.Options.Timeout,
|
||||
Transport: backend.Options.Transport,
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("HTTP request failed: %s", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest {
|
||||
return fmt.Errorf("received error status code: %v", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
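Wiring the pieces above together: a minimal sketch, assuming an oxy round-robin load balancer as the BalancerHandler (the backend URL, path, and intervals are illustrative):

package main

import (
	"context"
	"net/http"
	"net/url"
	"time"

	"github.com/containous/traefik/pkg/healthcheck"
	"github.com/vulcand/oxy/forward"
	"github.com/vulcand/oxy/roundrobin"
)

func main() {
	fwd, _ := forward.New()
	lb, _ := roundrobin.New(fwd) // *roundrobin.RoundRobin satisfies BalancerHandler

	server, _ := url.Parse("http://10.0.0.1:8080") // illustrative backend server
	_ = lb.UpsertServer(server)

	backend := healthcheck.NewBackendConfig(healthcheck.Options{
		Path:     "/ping",
		Interval: 10 * time.Second,
		Timeout:  3 * time.Second,
		LB:       lb,
	}, "my-backend")

	// Starts one goroutine per backend; canceling the context stops them all.
	healthcheck.GetHealthCheck().SetBackendsConfiguration(context.Background(),
		map[string]*healthcheck.BackendConfig{"my-backend": backend})

	_ = http.ListenAndServe(":80", lb)
}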
429
pkg/healthcheck/healthcheck_test.go
Normal file
|
@@ -0,0 +1,429 @@
|
|||
package healthcheck
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
const healthCheckInterval = 200 * time.Millisecond
|
||||
const healthCheckTimeout = 100 * time.Millisecond
|
||||
|
||||
type testHandler struct {
|
||||
done func()
|
||||
healthSequence []int
|
||||
}
|
||||
|
||||
func TestSetBackendsConfiguration(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
startHealthy bool
|
||||
healthSequence []int
|
||||
expectedNumRemovedServers int
|
||||
expectedNumUpsertedServers int
|
||||
expectedGaugeValue float64
|
||||
}{
|
||||
{
|
||||
desc: "healthy server staying healthy",
|
||||
startHealthy: true,
|
||||
healthSequence: []int{http.StatusOK},
|
||||
expectedNumRemovedServers: 0,
|
||||
expectedNumUpsertedServers: 0,
|
||||
expectedGaugeValue: 1,
|
||||
},
|
||||
{
|
||||
desc: "healthy server staying healthy (StatusNoContent)",
|
||||
startHealthy: true,
|
||||
healthSequence: []int{http.StatusNoContent},
|
||||
expectedNumRemovedServers: 0,
|
||||
expectedNumUpsertedServers: 0,
|
||||
expectedGaugeValue: 1,
|
||||
},
|
||||
{
|
||||
desc: "healthy server staying healthy (StatusPermanentRedirect)",
|
||||
startHealthy: true,
|
||||
healthSequence: []int{http.StatusPermanentRedirect},
|
||||
expectedNumRemovedServers: 0,
|
||||
expectedNumUpsertedServers: 0,
|
||||
expectedGaugeValue: 1,
|
||||
},
|
||||
{
|
||||
desc: "healthy server becoming sick",
|
||||
startHealthy: true,
|
||||
healthSequence: []int{http.StatusServiceUnavailable},
|
||||
expectedNumRemovedServers: 1,
|
||||
expectedNumUpsertedServers: 0,
|
||||
expectedGaugeValue: 0,
|
||||
},
|
||||
{
|
||||
desc: "sick server becoming healthy",
|
||||
startHealthy: false,
|
||||
healthSequence: []int{http.StatusOK},
|
||||
expectedNumRemovedServers: 0,
|
||||
expectedNumUpsertedServers: 1,
|
||||
expectedGaugeValue: 1,
|
||||
},
|
||||
{
|
||||
desc: "sick server staying sick",
|
||||
startHealthy: false,
|
||||
healthSequence: []int{http.StatusServiceUnavailable},
|
||||
expectedNumRemovedServers: 0,
|
||||
expectedNumUpsertedServers: 0,
|
||||
expectedGaugeValue: 0,
|
||||
},
|
||||
{
|
||||
desc: "healthy server toggling to sick and back to healthy",
|
||||
startHealthy: true,
|
||||
healthSequence: []int{http.StatusServiceUnavailable, http.StatusOK},
|
||||
expectedNumRemovedServers: 1,
|
||||
expectedNumUpsertedServers: 1,
|
||||
expectedGaugeValue: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The context is passed to the health check and canonically canceled by
|
||||
// the test server once all expected requests have been received.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ts := newTestServer(cancel, test.healthSequence)
|
||||
defer ts.Close()
|
||||
|
||||
lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}}
|
||||
backend := NewBackendConfig(Options{
|
||||
Path: "/path",
|
||||
Interval: healthCheckInterval,
|
||||
Timeout: healthCheckTimeout,
|
||||
LB: lb,
|
||||
}, "backendName")
|
||||
|
||||
serverURL := testhelpers.MustParseURL(ts.URL)
|
||||
if test.startHealthy {
|
||||
lb.servers = append(lb.servers, serverURL)
|
||||
} else {
|
||||
backend.disabledURLs = append(backend.disabledURLs, serverURL)
|
||||
}
|
||||
|
||||
collectingMetrics := testhelpers.NewCollectingHealthCheckMetrics()
|
||||
check := HealthCheck{
|
||||
Backends: make(map[string]*BackendConfig),
|
||||
metrics: collectingMetrics,
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
check.execute(ctx, backend)
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
// Make test timeout dependent on number of expected requests, health
|
||||
// check interval, and a safety margin.
|
||||
timeout := time.Duration(len(test.healthSequence))*healthCheckInterval + 500*time.Millisecond
|
||||
select {
|
||||
case <-time.After(timeout):
|
||||
t.Fatal("test did not complete in time")
|
||||
case <-ctx.Done():
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
lb.Lock()
|
||||
defer lb.Unlock()
|
||||
|
||||
assert.Equal(t, test.expectedNumRemovedServers, lb.numRemovedServers, "removed servers")
|
||||
assert.Equal(t, test.expectedNumUpsertedServers, lb.numUpsertedServers, "upserted servers")
|
||||
// FIXME re add metrics
|
||||
//assert.Equal(t, test.expectedGaugeValue, collectingMetrics.Gauge.GaugeValue, "ServerUp Gauge")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRequest(t *testing.T) {
|
||||
type expected struct {
|
||||
err bool
|
||||
value string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
serverURL string
|
||||
options Options
|
||||
expected expected
|
||||
}{
|
||||
{
|
||||
desc: "no port override",
|
||||
serverURL: "http://backend1:80",
|
||||
options: Options{
|
||||
Path: "/test",
|
||||
Port: 0,
|
||||
},
|
||||
expected: expected{
|
||||
err: false,
|
||||
value: "http://backend1:80/test",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "port override",
|
||||
serverURL: "http://backend2:80",
|
||||
options: Options{
|
||||
Path: "/test",
|
||||
Port: 8080,
|
||||
},
|
||||
expected: expected{
|
||||
err: false,
|
||||
value: "http://backend2:8080/test",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "no port override with no port in server URL",
|
||||
serverURL: "http://backend1",
|
||||
options: Options{
|
||||
Path: "/health",
|
||||
Port: 0,
|
||||
},
|
||||
expected: expected{
|
||||
err: false,
|
||||
value: "http://backend1/health",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "port override with no port in server URL",
|
||||
serverURL: "http://backend2",
|
||||
options: Options{
|
||||
Path: "/health",
|
||||
Port: 8080,
|
||||
},
|
||||
expected: expected{
|
||||
err: false,
|
||||
value: "http://backend2:8080/health",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "scheme override",
|
||||
serverURL: "https://backend1:80",
|
||||
options: Options{
|
||||
Scheme: "http",
|
||||
Path: "/test",
|
||||
Port: 0,
|
||||
},
|
||||
expected: expected{
|
||||
err: false,
|
||||
value: "http://backend1:80/test",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "path with param",
|
||||
serverURL: "http://backend1:80",
|
||||
options: Options{
|
||||
Path: "/health?powpow=do",
|
||||
Port: 0,
|
||||
},
|
||||
expected: expected{
|
||||
err: false,
|
||||
value: "http://backend1:80/health?powpow=do",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "path with params",
|
||||
serverURL: "http://backend1:80",
|
||||
options: Options{
|
||||
Path: "/health?powpow=do&do=powpow",
|
||||
Port: 0,
|
||||
},
|
||||
expected: expected{
|
||||
err: false,
|
||||
value: "http://backend1:80/health?powpow=do&do=powpow",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "path with invalid path",
|
||||
serverURL: "http://backend1:80",
|
||||
options: Options{
|
||||
Path: ":",
|
||||
Port: 0,
|
||||
},
|
||||
expected: expected{
|
||||
err: true,
|
||||
value: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := NewBackendConfig(test.options, "backendName")
|
||||
|
||||
u := testhelpers.MustParseURL(test.serverURL)
|
||||
|
||||
req, err := backend.newRequest(u)
|
||||
|
||||
if test.expected.err {
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, req)
|
||||
} else {
|
||||
require.NoError(t, err, "failed to create new backend request")
|
||||
require.NotNil(t, req)
|
||||
assert.Equal(t, test.expected.value, req.URL.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddHeadersAndHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
serverURL string
|
||||
options Options
|
||||
expectedHostname string
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "override hostname",
|
||||
serverURL: "http://backend1:80",
|
||||
options: Options{
|
||||
Hostname: "myhost",
|
||||
Path: "/",
|
||||
},
|
||||
expectedHostname: "myhost",
|
||||
expectedHeader: "",
|
||||
},
|
||||
{
|
||||
desc: "not override hostname",
|
||||
serverURL: "http://backend1:80",
|
||||
options: Options{
|
||||
Hostname: "",
|
||||
Path: "/",
|
||||
},
|
||||
expectedHostname: "backend1:80",
|
||||
expectedHeader: "",
|
||||
},
|
||||
{
|
||||
desc: "custom header",
|
||||
serverURL: "http://backend1:80",
|
||||
options: Options{
|
||||
Headers: map[string]string{"Custom-Header": "foo"},
|
||||
Hostname: "",
|
||||
Path: "/",
|
||||
},
|
||||
expectedHostname: "backend1:80",
|
||||
expectedHeader: "foo",
|
||||
},
|
||||
{
|
||||
desc: "custom header with hostname override",
|
||||
serverURL: "http://backend1:80",
|
||||
options: Options{
|
||||
Headers: map[string]string{"Custom-Header": "foo"},
|
||||
Hostname: "myhost",
|
||||
Path: "/",
|
||||
},
|
||||
expectedHostname: "myhost",
|
||||
expectedHeader: "foo",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
backend := NewBackendConfig(test.options, "backendName")
|
||||
|
||||
u, err := url.Parse(test.serverURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := backend.newRequest(u)
|
||||
require.NoError(t, err, "failed to create new backend request")
|
||||
|
||||
req = backend.addHeadersAndHost(req)
|
||||
|
||||
assert.Equal(t, "http://backend1:80/", req.URL.String())
|
||||
assert.Equal(t, test.expectedHostname, req.Host)
|
||||
assert.Equal(t, test.expectedHeader, req.Header.Get("Custom-Header"))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testLoadBalancer struct {
|
||||
// RWMutex needed due to parallel test execution: Both the system-under-test
|
||||
// and the test assertions reference the counters.
|
||||
*sync.RWMutex
|
||||
numRemovedServers int
|
||||
numUpsertedServers int
|
||||
servers []*url.URL
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// noop
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) RemoveServer(u *url.URL) error {
|
||||
lb.Lock()
|
||||
defer lb.Unlock()
|
||||
lb.numRemovedServers++
|
||||
lb.removeServer(u)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
|
||||
lb.Lock()
|
||||
defer lb.Unlock()
|
||||
lb.numUpsertedServers++
|
||||
lb.servers = append(lb.servers, u)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) Servers() []*url.URL {
|
||||
return lb.servers
|
||||
}
|
||||
|
||||
func (lb *testLoadBalancer) removeServer(u *url.URL) {
|
||||
var i int
|
||||
var serverURL *url.URL
|
||||
for i, serverURL = range lb.servers {
|
||||
if *serverURL == *u {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
lb.servers = append(lb.servers[:i], lb.servers[i+1:]...)
|
||||
}
|
||||
|
||||
func newTestServer(done func(), healthSequence []int) *httptest.Server {
|
||||
handler := &testHandler{
|
||||
done: done,
|
||||
healthSequence: healthSequence,
|
||||
}
|
||||
return httptest.NewServer(handler)
|
||||
}
|
||||
|
||||
// ServeHTTP returns HTTP response codes following a status sequence.
|
||||
// It calls the given 'done' function once all request health indicators have been depleted.
|
||||
func (th *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if len(th.healthSequence) == 0 {
|
||||
panic("received unexpected request")
|
||||
}
|
||||
|
||||
w.WriteHeader(th.healthSequence[0])
|
||||
|
||||
th.healthSequence = th.healthSequence[1:]
|
||||
if len(th.healthSequence) == 0 {
|
||||
th.done()
|
||||
}
|
||||
}
|
123
pkg/hostresolver/hostresolver.go
Normal file
|
@@ -0,0 +1,123 @@
|
|||
package hostresolver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
type cnameResolv struct {
|
||||
TTL time.Duration
|
||||
Record string
|
||||
}
|
||||
|
||||
type byTTL []*cnameResolv
|
||||
|
||||
func (a byTTL) Len() int { return len(a) }
|
||||
func (a byTTL) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a byTTL) Less(i, j int) bool { return a[i].TTL > a[j].TTL }
|
||||
|
||||
// Resolver used for host resolver
|
||||
type Resolver struct {
|
||||
CnameFlattening bool
|
||||
ResolvConfig string
|
||||
ResolvDepth int
|
||||
cache *cache.Cache
|
||||
}
|
||||
|
||||
// CNAMEFlatten checks if a CNAME record exists and flattens it if possible.
|
||||
func (hr *Resolver) CNAMEFlatten(host string) (string, string) {
|
||||
if hr.cache == nil {
|
||||
hr.cache = cache.New(30*time.Minute, 5*time.Minute)
|
||||
}
|
||||
|
||||
result := []string{host}
|
||||
request := host
|
||||
|
||||
value, found := hr.cache.Get(host)
|
||||
if found {
|
||||
result = strings.Split(value.(string), ",")
|
||||
} else {
|
||||
var cacheDuration = 0 * time.Second
|
||||
|
||||
for depth := 0; depth < hr.ResolvDepth; depth++ {
|
||||
resolv, err := cnameResolve(request, hr.ResolvConfig)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
break
|
||||
}
|
||||
if resolv == nil {
|
||||
break
|
||||
}
|
||||
|
||||
result = append(result, resolv.Record)
|
||||
if depth == 0 {
|
||||
cacheDuration = resolv.TTL
|
||||
}
|
||||
request = resolv.Record
|
||||
}
|
||||
|
||||
if err := hr.cache.Add(host, strings.Join(result, ","), cacheDuration); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
return result[0], result[len(result)-1]
|
||||
}
|
||||
|
||||
// cnameResolve resolves the CNAME if it exists, and returns the record with the highest TTL.
|
||||
func cnameResolve(host string, resolvPath string) (*cnameResolv, error) {
|
||||
config, err := dns.ClientConfigFromFile(resolvPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid resolver configuration file: %s", resolvPath)
|
||||
}
|
||||
|
||||
client := &dns.Client{Timeout: 30 * time.Second}
|
||||
|
||||
m := &dns.Msg{}
|
||||
m.SetQuestion(dns.Fqdn(host), dns.TypeCNAME)
|
||||
|
||||
var result []*cnameResolv
|
||||
for _, server := range config.Servers {
|
||||
tempRecord, err := getRecord(client, m, server, config.Port)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to resolve host %s: %v", host, err)
|
||||
continue
|
||||
}
|
||||
result = append(result, tempRecord)
|
||||
}
|
||||
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sort.Sort(byTTL(result))
|
||||
return result[0], nil
|
||||
}
|
||||
|
||||
func getRecord(client *dns.Client, msg *dns.Msg, server string, port string) (*cnameResolv, error) {
|
||||
resp, _, err := client.Exchange(msg, net.JoinHostPort(server, port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("exchange error for server %s: %v", server, err)
|
||||
}
|
||||
|
||||
if resp == nil || len(resp.Answer) == 0 {
|
||||
return nil, fmt.Errorf("empty answer for server %s", server)
|
||||
}
|
||||
|
||||
rr, ok := resp.Answer[0].(*dns.CNAME)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid response type for server %s", server)
|
||||
}
|
||||
|
||||
return &cnameResolv{
|
||||
TTL: time.Duration(rr.Hdr.Ttl) * time.Second,
|
||||
Record: strings.TrimSuffix(rr.Target, "."),
|
||||
}, nil
|
||||
}
|
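A minimal sketch of using the Resolver above (the domain and resolv.conf path are illustrative):

package main

import (
	"fmt"

	"github.com/containous/traefik/pkg/hostresolver"
)

func main() {
	r := &hostresolver.Resolver{
		CnameFlattening: true,
		ResolvConfig:    "/etc/resolv.conf",
		ResolvDepth:     5,
	}

	// Returns the requested host and the final (flattened) target; results are cached.
	requested, flattened := r.CNAMEFlatten("www.github.com")
	fmt.Println(requested, flattened)
}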
61
pkg/hostresolver/hostresolver_test.go
Normal file
|
@@ -0,0 +1,61 @@
|
|||
package hostresolver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCNAMEFlatten(t *testing.T) {
|
||||
testCase := []struct {
|
||||
desc string
|
||||
resolvFile string
|
||||
domain string
|
||||
expectedDomain string
|
||||
isCNAME bool
|
||||
}{
|
||||
{
|
||||
desc: "host request is CNAME record",
|
||||
resolvFile: "/etc/resolv.conf",
|
||||
domain: "www.github.com",
|
||||
expectedDomain: "github.com",
|
||||
isCNAME: true,
|
||||
},
|
||||
{
|
||||
desc: "resolve file not found",
|
||||
resolvFile: "/etc/resolv.oops",
|
||||
domain: "www.github.com",
|
||||
expectedDomain: "www.github.com",
|
||||
isCNAME: false,
|
||||
},
|
||||
{
|
||||
desc: "host request is not CNAME record",
|
||||
resolvFile: "/etc/resolv.conf",
|
||||
domain: "github.com",
|
||||
expectedDomain: "github.com",
|
||||
isCNAME: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCase {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
hostResolver := &Resolver{
|
||||
ResolvConfig: test.resolvFile,
|
||||
ResolvDepth: 5,
|
||||
}
|
||||
|
||||
reqH, flatH := hostResolver.CNAMEFlatten(test.domain)
|
||||
assert.Equal(t, test.domain, reqH)
|
||||
assert.Equal(t, test.expectedDomain, flatH)
|
||||
|
||||
if test.isCNAME {
|
||||
assert.NotEqual(t, test.expectedDomain, reqH)
|
||||
} else {
|
||||
assert.Equal(t, test.expectedDomain, reqH)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
99
pkg/ip/checker.go
Normal file
|
@@ -0,0 +1,99 @@
|
|||
package ip
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Checker allows checking whether addresses belong to a set of trusted IPs.
|
||||
type Checker struct {
|
||||
authorizedIPs []*net.IP
|
||||
authorizedIPsNet []*net.IPNet
|
||||
}
|
||||
|
||||
// NewChecker builds a new Checker from a list of trusted IPs or CIDR strings.
|
||||
func NewChecker(trustedIPs []string) (*Checker, error) {
|
||||
if len(trustedIPs) == 0 {
|
||||
return nil, errors.New("no trusted IPs provided")
|
||||
}
|
||||
|
||||
checker := &Checker{}
|
||||
|
||||
for _, ipMask := range trustedIPs {
|
||||
if ipAddr := net.ParseIP(ipMask); ipAddr != nil {
|
||||
checker.authorizedIPs = append(checker.authorizedIPs, &ipAddr)
|
||||
} else {
|
||||
_, ipAddr, err := net.ParseCIDR(ipMask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing CIDR trusted IPs %s: %v", ipAddr, err)
|
||||
}
|
||||
checker.authorizedIPsNet = append(checker.authorizedIPsNet, ipAddr)
|
||||
}
|
||||
}
|
||||
|
||||
return checker, nil
|
||||
}
|
||||
|
||||
// IsAuthorized checks if the provided address is authorized by the trusted IPs.
|
||||
func (ip *Checker) IsAuthorized(addr string) error {
|
||||
var invalidMatches []string
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
|
||||
ok, err := ip.Contains(host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
invalidMatches = append(invalidMatches, addr)
|
||||
return fmt.Errorf("%q matched none of the trusted IPs", strings.Join(invalidMatches, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Contains checks if the provided address is in the trusted IPs.
|
||||
func (ip *Checker) Contains(addr string) (bool, error) {
|
||||
if len(addr) == 0 {
|
||||
return false, errors.New("empty IP address")
|
||||
}
|
||||
|
||||
ipAddr, err := parseIP(addr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("unable to parse address: %s: %s", addr, err)
|
||||
}
|
||||
|
||||
return ip.ContainsIP(ipAddr), nil
|
||||
}
|
||||
|
||||
// ContainsIP checks if the provided address is in the trusted IPs.
|
||||
func (ip *Checker) ContainsIP(addr net.IP) bool {
|
||||
for _, authorizedIP := range ip.authorizedIPs {
|
||||
if authorizedIP.Equal(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, authorizedNet := range ip.authorizedIPsNet {
|
||||
if authorizedNet.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func parseIP(addr string) (net.IP, error) {
|
||||
userIP := net.ParseIP(addr)
|
||||
if userIP == nil {
|
||||
return nil, fmt.Errorf("can't parse IP from address %s", addr)
|
||||
}
|
||||
|
||||
return userIP, nil
|
||||
}
|
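An illustrative usage sketch of the Checker above (assuming the import path github.com/containous/traefik/pkg/ip introduced by this commit; the addresses and CIDR ranges are placeholders, not values from the diff):

package main

import (
	"fmt"

	"github.com/containous/traefik/pkg/ip"
)

func main() {
	// Build a checker from a mix of a plain IP and a CIDR range (example values).
	checker, err := ip.NewChecker([]string{"10.0.0.1", "192.168.0.0/16"})
	if err != nil {
		panic(err)
	}

	// IsAuthorized accepts "host:port" or a bare host.
	if err := checker.IsAuthorized("192.168.1.42:443"); err != nil {
		fmt.Println("rejected:", err)
	} else {
		fmt.Println("authorized")
	}
}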
326
pkg/ip/checker_test.go
Normal file
@@ -0,0 +1,326 @@
package ip
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsAuthorized(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whiteList []string
|
||||
remoteAddr string
|
||||
authorized bool
|
||||
}{
|
||||
{
|
||||
desc: "remoteAddr not in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "10.2.3.1:123",
|
||||
authorized: false,
|
||||
},
|
||||
{
|
||||
desc: "remoteAddr in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "1.2.3.1:123",
|
||||
authorized: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipChecker, err := NewChecker(test.whiteList)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ipChecker.IsAuthorized(test.remoteAddr)
|
||||
if test.authorized {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
trustedIPs []string
|
||||
expectedAuthorizedIPs []*net.IPNet
|
||||
errMessage string
|
||||
}{
|
||||
{
|
||||
desc: "nil trusted IPs",
|
||||
trustedIPs: nil,
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "no trusted IPs provided",
|
||||
}, {
|
||||
desc: "empty trusted IPs",
|
||||
trustedIPs: []string{},
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "no trusted IPs provided",
|
||||
}, {
|
||||
desc: "trusted IPs containing empty string",
|
||||
trustedIPs: []string{
|
||||
"1.2.3.4/24",
|
||||
"",
|
||||
"fe80::/16",
|
||||
},
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "parsing CIDR trusted IPs <nil>: invalid CIDR address: ",
|
||||
}, {
|
||||
desc: "trusted IPs containing only an empty string",
|
||||
trustedIPs: []string{
|
||||
"",
|
||||
},
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "parsing CIDR trusted IPs <nil>: invalid CIDR address: ",
|
||||
}, {
|
||||
desc: "trusted IPs containing an invalid string",
|
||||
trustedIPs: []string{
|
||||
"foo",
|
||||
},
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "parsing CIDR trusted IPs <nil>: invalid CIDR address: foo",
|
||||
}, {
|
||||
desc: "IPv4 & IPv6 trusted IPs",
|
||||
trustedIPs: []string{
|
||||
"1.2.3.4/24",
|
||||
"fe80::/16",
|
||||
},
|
||||
expectedAuthorizedIPs: []*net.IPNet{
|
||||
{IP: net.IPv4(1, 2, 3, 0).To4(), Mask: net.IPv4Mask(255, 255, 255, 0)},
|
||||
{IP: net.ParseIP("fe80::"), Mask: net.IPMask(net.ParseIP("ffff::"))},
|
||||
},
|
||||
errMessage: "",
|
||||
}, {
|
||||
desc: "IPv4 only",
|
||||
trustedIPs: []string{
|
||||
"127.0.0.1/8",
|
||||
},
|
||||
expectedAuthorizedIPs: []*net.IPNet{
|
||||
{IP: net.IPv4(127, 0, 0, 0).To4(), Mask: net.IPv4Mask(255, 0, 0, 0)},
|
||||
},
|
||||
errMessage: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipChecker, err := NewChecker(test.trustedIPs)
|
||||
if test.errMessage != "" {
|
||||
require.EqualError(t, err, test.errMessage)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
for index, actual := range ipChecker.authorizedIPsNet {
|
||||
expected := test.expectedAuthorizedIPs[index]
|
||||
assert.Equal(t, expected.IP, actual.IP)
|
||||
assert.Equal(t, expected.Mask.String(), actual.Mask.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsIsAllowed(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
trustedIPs []string
|
||||
passIPs []string
|
||||
rejectIPs []string
|
||||
}{
|
||||
{
|
||||
desc: "IPv4",
|
||||
trustedIPs: []string{"1.2.3.4/24"},
|
||||
passIPs: []string{
|
||||
"1.2.3.1",
|
||||
"1.2.3.32",
|
||||
"1.2.3.156",
|
||||
"1.2.3.255",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"1.2.16.1",
|
||||
"1.2.32.1",
|
||||
"127.0.0.1",
|
||||
"8.8.8.8",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv4 single IP",
|
||||
trustedIPs: []string{"8.8.8.8"},
|
||||
passIPs: []string{"8.8.8.8"},
|
||||
rejectIPs: []string{
|
||||
"8.8.8.7",
|
||||
"8.8.8.9",
|
||||
"8.8.8.0",
|
||||
"8.8.8.255",
|
||||
"4.4.4.4",
|
||||
"127.0.0.1",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv4 Net single IP",
|
||||
trustedIPs: []string{"8.8.8.8/32"},
|
||||
passIPs: []string{"8.8.8.8"},
|
||||
rejectIPs: []string{
|
||||
"8.8.8.7",
|
||||
"8.8.8.9",
|
||||
"8.8.8.0",
|
||||
"8.8.8.255",
|
||||
"4.4.4.4",
|
||||
"127.0.0.1",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple IPv4",
|
||||
trustedIPs: []string{"1.2.3.4/24", "8.8.8.8/8"},
|
||||
passIPs: []string{
|
||||
"1.2.3.1",
|
||||
"1.2.3.32",
|
||||
"1.2.3.156",
|
||||
"1.2.3.255",
|
||||
"8.8.4.4",
|
||||
"8.0.0.1",
|
||||
"8.32.42.128",
|
||||
"8.255.255.255",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"1.2.16.1",
|
||||
"1.2.32.1",
|
||||
"127.0.0.1",
|
||||
"4.4.4.4",
|
||||
"4.8.8.8",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv6",
|
||||
trustedIPs: []string{"2a03:4000:6:d080::/64"},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::",
|
||||
"2a03:4000:6:d080::1",
|
||||
"2a03:4000:6:d080:dead:beef:ffff:ffff",
|
||||
"2a03:4000:6:d080::42",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"2a03:4000:7:d080::",
|
||||
"2a03:4000:7:d080::1",
|
||||
"fe80::",
|
||||
"4242::1",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv6 single IP",
|
||||
trustedIPs: []string{"2a03:4000:6:d080::42/128"},
|
||||
passIPs: []string{"2a03:4000:6:d080::42"},
|
||||
rejectIPs: []string{
|
||||
"2a03:4000:6:d080::1",
|
||||
"2a03:4000:6:d080:dead:beef:ffff:ffff",
|
||||
"2a03:4000:6:d080::43",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple IPv6",
|
||||
trustedIPs: []string{"2a03:4000:6:d080::/64", "fe80::/16"},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::",
|
||||
"2a03:4000:6:d080::1",
|
||||
"2a03:4000:6:d080:dead:beef:ffff:ffff",
|
||||
"2a03:4000:6:d080::42",
|
||||
"fe80::1",
|
||||
"fe80:aa00:00bb:4232:ff00:eeee:00ff:1111",
|
||||
"fe80::fe80",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"2a03:4000:7:d080::",
|
||||
"2a03:4000:7:d080::1",
|
||||
"4242::1",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple IPv6 & IPv4",
|
||||
trustedIPs: []string{"2a03:4000:6:d080::/64", "fe80::/16", "1.2.3.4/24", "8.8.8.8/8"},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::",
|
||||
"2a03:4000:6:d080::1",
|
||||
"2a03:4000:6:d080:dead:beef:ffff:ffff",
|
||||
"2a03:4000:6:d080::42",
|
||||
"fe80::1",
|
||||
"fe80:aa00:00bb:4232:ff00:eeee:00ff:1111",
|
||||
"fe80::fe80",
|
||||
"1.2.3.1",
|
||||
"1.2.3.32",
|
||||
"1.2.3.156",
|
||||
"1.2.3.255",
|
||||
"8.8.4.4",
|
||||
"8.0.0.1",
|
||||
"8.32.42.128",
|
||||
"8.255.255.255",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"2a03:4000:7:d080::",
|
||||
"2a03:4000:7:d080::1",
|
||||
"4242::1",
|
||||
"1.2.16.1",
|
||||
"1.2.32.1",
|
||||
"127.0.0.1",
|
||||
"4.4.4.4",
|
||||
"4.8.8.8",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "broken IP-addresses",
|
||||
trustedIPs: []string{"127.0.0.1/32"},
|
||||
passIPs: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipChecker, err := NewChecker(test.trustedIPs)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ipChecker)
|
||||
|
||||
for _, testIP := range test.passIPs {
|
||||
allowed, err := ipChecker.Contains(testIP)
|
||||
require.NoError(t, err)
|
||||
assert.Truef(t, allowed, "%s should have passed.", testIP)
|
||||
}
|
||||
|
||||
for _, testIP := range test.rejectIPs {
|
||||
allowed, err := ipChecker.Contains(testIP)
|
||||
require.NoError(t, err)
|
||||
assert.Falsef(t, allowed, "%s should not have passed.", testIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsBrokenIPs(t *testing.T) {
|
||||
brokenIPs := []string{
|
||||
"foo",
|
||||
"10.0.0.350",
|
||||
"fe:::80",
|
||||
"",
|
||||
"\\&$§&/(",
|
||||
}
|
||||
|
||||
ipChecker, err := NewChecker([]string{"1.2.3.4/24"})
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, testIP := range brokenIPs {
|
||||
_, err := ipChecker.Contains(testIP)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
63
pkg/ip/strategy.go
Normal file
@@ -0,0 +1,63 @@
package ip

import (
	"net/http"
	"strings"
)

const (
	xForwardedFor = "X-Forwarded-For"
)

// Strategy is a strategy for IP selection.
type Strategy interface {
	GetIP(req *http.Request) string
}

// RemoteAddrStrategy is a strategy that always returns the remote address.
type RemoteAddrStrategy struct{}

// GetIP returns the selected IP.
func (s *RemoteAddrStrategy) GetIP(req *http.Request) string {
	return req.RemoteAddr
}

// DepthStrategy is a strategy based on the depth inside the X-Forwarded-For header, counted from right to left.
type DepthStrategy struct {
	Depth int
}

// GetIP returns the selected IP.
func (s *DepthStrategy) GetIP(req *http.Request) string {
	xff := req.Header.Get(xForwardedFor)
	xffs := strings.Split(xff, ",")

	if len(xffs) < s.Depth {
		return ""
	}
	return strings.TrimSpace(xffs[len(xffs)-s.Depth])
}

// CheckerStrategy is a strategy based on an IP Checker:
// it returns the first X-Forwarded-For entry (from right to left) that is not in the trusted IPs.
type CheckerStrategy struct {
	Checker *Checker
}

// GetIP returns the selected IP.
func (s *CheckerStrategy) GetIP(req *http.Request) string {
	if s.Checker == nil {
		return ""
	}

	xff := req.Header.Get(xForwardedFor)
	xffs := strings.Split(xff, ",")

	for i := len(xffs) - 1; i >= 0; i-- {
		xffTrimmed := strings.TrimSpace(xffs[i])
		if contain, _ := s.Checker.Contains(xffTrimmed); !contain {
			return xffTrimmed
		}
	}
	return ""
}
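A minimal sketch showing the two non-trivial strategies side by side (import path and header values are illustrative; the X-Forwarded-For chain is a made-up example):

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/containous/traefik/pkg/ip"
)

func main() {
	req := httptest.NewRequest(http.MethodGet, "http://example.com", nil)
	req.Header.Set("X-Forwarded-For", "203.0.113.7, 10.0.0.2, 10.0.0.1")

	// DepthStrategy picks the Nth entry counted from the right.
	depth := ip.DepthStrategy{Depth: 2}
	fmt.Println(depth.GetIP(req)) // 10.0.0.2

	// CheckerStrategy walks from the right and returns the first address
	// that is NOT in the excluded/trusted set.
	checker, _ := ip.NewChecker([]string{"10.0.0.0/8"})
	excluded := ip.CheckerStrategy{Checker: checker}
	fmt.Println(excluded.GetIP(req)) // 203.0.113.7
}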
125
pkg/ip/strategy_test.go
Normal file
@@ -0,0 +1,125 @@
package ip
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRemoteAddrStrategy_GetIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Use RemoteAddr",
|
||||
expected: "192.0.2.1:1234",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strategy := RemoteAddrStrategy{}
|
||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||
actual := strategy.GetIP(req)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDepthStrategy_GetIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
depth int
|
||||
xForwardedFor string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Use depth",
|
||||
depth: 3,
|
||||
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "10.0.0.3",
|
||||
},
|
||||
{
|
||||
desc: "Use non existing depth in XForwardedFor",
|
||||
depth: 2,
|
||||
xForwardedFor: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
desc: "Use depth that match the first IP in XForwardedFor",
|
||||
depth: 2,
|
||||
xForwardedFor: "10.0.0.2,10.0.0.1",
|
||||
expected: "10.0.0.2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strategy := DepthStrategy{Depth: test.depth}
|
||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||
req.Header.Set(xForwardedFor, test.xForwardedFor)
|
||||
actual := strategy.GetIP(req)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExcludedIPsStrategy_GetIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
excludedIPs []string
|
||||
xForwardedFor string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Use excluded all IPs",
|
||||
excludedIPs: []string{"10.0.0.4", "10.0.0.3", "10.0.0.2", "10.0.0.1"},
|
||||
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
desc: "Use excluded IPs",
|
||||
excludedIPs: []string{"10.0.0.2", "10.0.0.1"},
|
||||
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "10.0.0.3",
|
||||
},
|
||||
{
|
||||
desc: "Use excluded IPs CIDR",
|
||||
excludedIPs: []string{"10.0.0.1/24"},
|
||||
xForwardedFor: "127.0.0.1,10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
desc: "Use excluded all IPs CIDR",
|
||||
excludedIPs: []string{"10.0.0.1/24"},
|
||||
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checker, err := NewChecker(test.excludedIPs)
|
||||
require.NoError(t, err)
|
||||
|
||||
strategy := CheckerStrategy{Checker: checker}
|
||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||
req.Header.Set(xForwardedFor, test.xForwardedFor)
|
||||
actual := strategy.GetIP(req)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
40
pkg/job/job.go
Normal file
@@ -0,0 +1,40 @@
package job

import (
	"time"

	"github.com/cenkalti/backoff"
)

var (
	_ backoff.BackOff = (*BackOff)(nil)
)

const (
	defaultMinJobInterval = 30 * time.Second
)

// BackOff is an exponential backoff implementation for long running jobs.
// In long running jobs, an operation() that fails after a long Duration should not increment the backoff period.
// If operation() takes more than MinJobInterval, Reset() is called in NextBackOff().
type BackOff struct {
	*backoff.ExponentialBackOff
	MinJobInterval time.Duration
}

// NewBackOff creates an instance of BackOff using default values.
func NewBackOff(backOff *backoff.ExponentialBackOff) *BackOff {
	backOff.MaxElapsedTime = 0
	return &BackOff{
		ExponentialBackOff: backOff,
		MinJobInterval:     defaultMinJobInterval,
	}
}

// NextBackOff calculates the next backoff interval.
func (b *BackOff) NextBackOff() time.Duration {
	if b.GetElapsedTime() >= b.MinJobInterval {
		b.Reset()
	}
	return b.ExponentialBackOff.NextBackOff()
}
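A sketch of how this BackOff could be plugged into cenkalti/backoff's retry helper for a long-running provider loop (the operation, counter, and log messages are illustrative, not from the commit):

package main

import (
	"errors"
	"log"
	"time"

	"github.com/cenkalti/backoff"
	"github.com/containous/traefik/pkg/job"
)

func main() {
	attempts := 0
	operation := func() error {
		attempts++
		if attempts < 3 {
			return errors.New("transient failure") // retried under the job.BackOff policy
		}
		return nil
	}

	notify := func(err error, d time.Duration) {
		log.Printf("retrying in %s after error: %v", d, err)
	}

	// job.NewBackOff never gives up (MaxElapsedTime = 0) and resets the interval
	// whenever the previous attempt ran longer than MinJobInterval.
	if err := backoff.RetryNotify(operation, job.NewBackOff(backoff.NewExponentialBackOff()), notify); err != nil {
		log.Fatal(err)
	}
	log.Println("operation succeeded")
}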
45
pkg/job/job_test.go
Normal file
@@ -0,0 +1,45 @@
package job

import (
	"testing"
	"time"

	"github.com/cenkalti/backoff"
)

func TestJobBackOff(t *testing.T) {
	var (
		testInitialInterval     = 500 * time.Millisecond
		testRandomizationFactor = 0.1
		testMultiplier          = 2.0
		testMaxInterval         = 5 * time.Second
		testMinJobInterval      = 1 * time.Second
	)

	exp := NewBackOff(backoff.NewExponentialBackOff())
	exp.InitialInterval = testInitialInterval
	exp.RandomizationFactor = testRandomizationFactor
	exp.Multiplier = testMultiplier
	exp.MaxInterval = testMaxInterval
	exp.MinJobInterval = testMinJobInterval
	exp.Reset()

	var expectedResults = []time.Duration{500, 500, 500, 1000, 2000, 4000, 5000, 5000, 500, 1000, 2000, 4000, 5000, 5000}
	for i, d := range expectedResults {
		expectedResults[i] = d * time.Millisecond
	}

	for i, expected := range expectedResults {
		// Assert that the next backoff falls in the expected range.
		var minInterval = expected - time.Duration(testRandomizationFactor*float64(expected))
		var maxInterval = expected + time.Duration(testRandomizationFactor*float64(expected))
		if i < 3 || i == 8 {
			time.Sleep(2 * time.Second)
		}
		var actualInterval = exp.NextBackOff()
		if !(minInterval <= actualInterval && actualInterval <= maxInterval) {
			t.Errorf("backoff interval %v not in expected range [%v, %v]", actualInterval, minInterval, maxInterval)
		}
		// assertEquals(t, expected, exp.currentInterval)
	}
}
139
pkg/log/deprecated.go
Normal file
@@ -0,0 +1,139 @@
package log
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"runtime"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Debug logs a message at level Debug on the standard logger.
|
||||
// Deprecated
|
||||
func Debug(args ...interface{}) {
|
||||
mainLogger.Debug(args...)
|
||||
}
|
||||
|
||||
// Debugf logs a message at level Debug on the standard logger.
|
||||
// Deprecated
|
||||
func Debugf(format string, args ...interface{}) {
|
||||
mainLogger.Debugf(format, args...)
|
||||
}
|
||||
|
||||
// Info logs a message at level Info on the standard logger.
|
||||
// Deprecated
|
||||
func Info(args ...interface{}) {
|
||||
mainLogger.Info(args...)
|
||||
}
|
||||
|
||||
// Infof logs a message at level Info on the standard logger.
|
||||
// Deprecated
|
||||
func Infof(format string, args ...interface{}) {
|
||||
mainLogger.Infof(format, args...)
|
||||
}
|
||||
|
||||
// Warn logs a message at level Warn on the standard logger.
|
||||
// Deprecated
|
||||
func Warn(args ...interface{}) {
|
||||
mainLogger.Warn(args...)
|
||||
}
|
||||
|
||||
// Warnf logs a message at level Warn on the standard logger.
|
||||
// Deprecated
|
||||
func Warnf(format string, args ...interface{}) {
|
||||
mainLogger.Warnf(format, args...)
|
||||
}
|
||||
|
||||
// Error logs a message at level Error on the standard logger.
|
||||
// Deprecated
|
||||
func Error(args ...interface{}) {
|
||||
mainLogger.Error(args...)
|
||||
}
|
||||
|
||||
// Errorf logs a message at level Error on the standard logger.
|
||||
// Deprecated
|
||||
func Errorf(format string, args ...interface{}) {
|
||||
mainLogger.Errorf(format, args...)
|
||||
}
|
||||
|
||||
// Panic logs a message at level Panic on the standard logger.
|
||||
// Deprecated
|
||||
func Panic(args ...interface{}) {
|
||||
mainLogger.Panic(args...)
|
||||
}
|
||||
|
||||
// Panicf logs a message at level Panic on the standard logger.
|
||||
// Deprecated
|
||||
func Panicf(format string, args ...interface{}) {
|
||||
mainLogger.Panicf(format, args...)
|
||||
}
|
||||
|
||||
// Fatal logs a message at level Fatal on the standard logger.
|
||||
// Deprecated
|
||||
func Fatal(args ...interface{}) {
|
||||
mainLogger.Fatal(args...)
|
||||
}
|
||||
|
||||
// Fatalf logs a message at level Fatal on the standard logger.
|
||||
// Deprecated
|
||||
func Fatalf(format string, args ...interface{}) {
|
||||
mainLogger.Fatalf(format, args...)
|
||||
}
|
||||
|
||||
// AddHook adds a hook to the standard logger hooks.
|
||||
func AddHook(hook logrus.Hook) {
|
||||
logrus.AddHook(hook)
|
||||
}
|
||||
|
||||
// CustomWriterLevel logs writer for a specific level. (with a custom scanner buffer size.)
|
||||
// adapted from github.com/Sirupsen/logrus/writer.go
|
||||
func CustomWriterLevel(level logrus.Level, maxScanTokenSize int) *io.PipeWriter {
|
||||
reader, writer := io.Pipe()
|
||||
|
||||
var printFunc func(args ...interface{})
|
||||
|
||||
switch level {
|
||||
case logrus.DebugLevel:
|
||||
printFunc = Debug
|
||||
case logrus.InfoLevel:
|
||||
printFunc = Info
|
||||
case logrus.WarnLevel:
|
||||
printFunc = Warn
|
||||
case logrus.ErrorLevel:
|
||||
printFunc = Error
|
||||
case logrus.FatalLevel:
|
||||
printFunc = Fatal
|
||||
case logrus.PanicLevel:
|
||||
printFunc = Panic
|
||||
default:
|
||||
printFunc = mainLogger.Print
|
||||
}
|
||||
|
||||
go writerScanner(reader, maxScanTokenSize, printFunc)
|
||||
runtime.SetFinalizer(writer, writerFinalizer)
|
||||
|
||||
return writer
|
||||
}
|
||||
|
||||
// extract from github.com/Sirupsen/logrus/writer.go
|
||||
// Hack the buffer size
|
||||
func writerScanner(reader io.ReadCloser, scanTokenSize int, printFunc func(args ...interface{})) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
|
||||
if scanTokenSize > bufio.MaxScanTokenSize {
|
||||
buf := make([]byte, bufio.MaxScanTokenSize)
|
||||
scanner.Buffer(buf, scanTokenSize)
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
printFunc(scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
Errorf("Error while reading from Writer: %s", err)
|
||||
}
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
func writerFinalizer(writer *io.PipeWriter) {
|
||||
writer.Close()
|
||||
}
|
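A sketch of CustomWriterLevel in use, piping a standard-library logger into the Traefik logger at debug level with an enlarged scanner buffer (the buffer size and log line are illustrative):

package main

import (
	stdlog "log"

	"github.com/containous/traefik/pkg/log"
	"github.com/sirupsen/logrus"
)

func main() {
	// Lines written to w are scanned and re-emitted at level=debug;
	// a 1 MiB buffer lets very long lines through (default cap is 64 KiB).
	w := log.CustomWriterLevel(logrus.DebugLevel, 1*1024*1024)
	defer w.Close()

	legacy := stdlog.New(w, "", 0)
	legacy.Println("this line ends up in the Traefik log at level=debug")
}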
15
pkg/log/fields.go
Normal file
@@ -0,0 +1,15 @@
package log

// Log entry name
const (
	EntryPointName      = "entryPointName"
	RouterName          = "routerName"
	Rule                = "rule"
	MiddlewareName      = "middlewareName"
	MiddlewareType      = "middlewareType"
	ProviderName        = "providerName"
	ServiceName         = "serviceName"
	MetricsProviderName = "metricsProviderName"
	TracingProviderName = "tracingProviderName"
	ServerName          = "serverName"
)
145
pkg/log/log.go
Normal file
@@ -0,0 +1,145 @@
package log
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
loggerKey contextKey = iota
|
||||
)
|
||||
|
||||
// Logger the Traefik logger
|
||||
type Logger interface {
|
||||
logrus.FieldLogger
|
||||
WriterLevel(logrus.Level) *io.PipeWriter
|
||||
}
|
||||
|
||||
var (
|
||||
mainLogger Logger
|
||||
logFilePath string
|
||||
logFile *os.File
|
||||
)
|
||||
|
||||
func init() {
|
||||
mainLogger = logrus.StandardLogger()
|
||||
logrus.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
// SetLogger sets the logger.
|
||||
func SetLogger(l Logger) {
|
||||
mainLogger = l
|
||||
}
|
||||
|
||||
// SetOutput sets the standard logger output.
|
||||
func SetOutput(out io.Writer) {
|
||||
logrus.SetOutput(out)
|
||||
}
|
||||
|
||||
// SetFormatter sets the standard logger formatter.
|
||||
func SetFormatter(formatter logrus.Formatter) {
|
||||
logrus.SetFormatter(formatter)
|
||||
}
|
||||
|
||||
// SetLevel sets the standard logger level.
|
||||
func SetLevel(level logrus.Level) {
|
||||
logrus.SetLevel(level)
|
||||
}
|
||||
|
||||
// GetLevel returns the standard logger level.
|
||||
func GetLevel() logrus.Level {
|
||||
return logrus.GetLevel()
|
||||
}
|
||||
|
||||
// Str adds a string field
|
||||
func Str(key, value string) func(logrus.Fields) {
|
||||
return func(fields logrus.Fields) {
|
||||
fields[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
// With Adds fields
|
||||
func With(ctx context.Context, opts ...func(logrus.Fields)) context.Context {
|
||||
logger := FromContext(ctx)
|
||||
|
||||
fields := make(logrus.Fields)
|
||||
for _, opt := range opts {
|
||||
opt(fields)
|
||||
}
|
||||
logger = logger.WithFields(fields)
|
||||
|
||||
return context.WithValue(ctx, loggerKey, logger)
|
||||
}
|
||||
|
||||
// FromContext Gets the logger from context
|
||||
func FromContext(ctx context.Context) Logger {
|
||||
if ctx == nil {
|
||||
panic("nil context")
|
||||
}
|
||||
|
||||
logger, ok := ctx.Value(loggerKey).(Logger)
|
||||
if !ok {
|
||||
logger = mainLogger
|
||||
}
|
||||
|
||||
return logger
|
||||
}
|
||||
|
||||
// WithoutContext Gets the main logger
|
||||
func WithoutContext() Logger {
|
||||
return mainLogger
|
||||
}
|
||||
|
||||
// OpenFile opens the log file using the specified path
|
||||
func OpenFile(path string) error {
|
||||
logFilePath = path
|
||||
|
||||
var err error
|
||||
logFile, err = os.OpenFile(logFilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
SetOutput(logFile)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseFile closes the log and sets the Output to stdout
|
||||
func CloseFile() error {
|
||||
logrus.SetOutput(os.Stdout)
|
||||
|
||||
if logFile != nil {
|
||||
return logFile.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RotateFile closes and reopens the log file to allow for rotation
|
||||
// by an external source. If the log isn't backed by a file then
|
||||
// it does nothing.
|
||||
func RotateFile() error {
|
||||
logger := FromContext(context.Background())
|
||||
|
||||
if logFile == nil && logFilePath == "" {
|
||||
logger.Debug("Traefik log is not writing to a file, ignoring rotate request")
|
||||
return nil
|
||||
}
|
||||
|
||||
if logFile != nil {
|
||||
defer func(f *os.File) {
|
||||
_ = f.Close()
|
||||
}(logFile)
|
||||
}
|
||||
|
||||
if err := OpenFile(logFilePath); err != nil {
|
||||
return fmt.Errorf("error opening log file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
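A short sketch of the context-scoped logging pattern defined above, combining log.With, log.Str and the field-name constants from fields.go (the field values are illustrative):

package main

import (
	"context"

	"github.com/containous/traefik/pkg/log"
)

func main() {
	// Attach structured fields to a context once, then retrieve a
	// field-aware logger anywhere downstream with FromContext.
	ctx := log.With(context.Background(),
		log.Str(log.EntryPointName, "web"),
		log.Str(log.RouterName, "my-router"))

	log.FromContext(ctx).Info("starting request handling")
	log.FromContext(ctx).WithField("status", 200).Debug("request done")
}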
58
pkg/log/log_test.go
Normal file
@@ -0,0 +1,58 @@
package log
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestLog(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fields map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Log with one field",
|
||||
fields: map[string]string{
|
||||
"foo": "bar",
|
||||
},
|
||||
expected: ` level=error msg="message test" foo=bar$`,
|
||||
},
|
||||
{
|
||||
desc: "Log with two fields",
|
||||
fields: map[string]string{
|
||||
"foo": "bar",
|
||||
"oof": "rab",
|
||||
},
|
||||
expected: ` level=error msg="message test" foo=bar oof=rab$`,
|
||||
},
|
||||
{
|
||||
desc: "Log without field",
|
||||
fields: map[string]string{},
|
||||
expected: ` level=error msg="message test"$`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
var buffer bytes.Buffer
|
||||
SetOutput(&buffer)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
for key, value := range test.fields {
|
||||
ctx = With(ctx, Str(key, value))
|
||||
}
|
||||
|
||||
FromContext(ctx).Error("message test")
|
||||
|
||||
assert.Regexp(t, test.expected, strings.TrimSpace(buffer.String()))
|
||||
})
|
||||
}
|
||||
}
|
88
pkg/metrics/datadog.go
Normal file
@@ -0,0 +1,88 @@
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/metrics/dogstatsd"
|
||||
)
|
||||
|
||||
var datadogClient = dogstatsd.New("traefik.", kitlog.LoggerFunc(func(keyvals ...interface{}) error {
|
||||
log.WithoutContext().WithField(log.MetricsProviderName, "datadog").Info(keyvals)
|
||||
return nil
|
||||
}))
|
||||
|
||||
var datadogTicker *time.Ticker
|
||||
|
||||
// Metric names consistent with https://github.com/DataDog/integrations-extras/pull/64
|
||||
const (
|
||||
ddMetricsBackendReqsName = "backend.request.total"
|
||||
ddMetricsBackendLatencyName = "backend.request.duration"
|
||||
ddRetriesTotalName = "backend.retries.total"
|
||||
ddConfigReloadsName = "config.reload.total"
|
||||
ddConfigReloadsFailureTagName = "failure"
|
||||
ddLastConfigReloadSuccessName = "config.reload.lastSuccessTimestamp"
|
||||
ddLastConfigReloadFailureName = "config.reload.lastFailureTimestamp"
|
||||
ddEntrypointReqsName = "entrypoint.request.total"
|
||||
ddEntrypointReqDurationName = "entrypoint.request.duration"
|
||||
ddEntrypointOpenConnsName = "entrypoint.connections.open"
|
||||
ddOpenConnsName = "backend.connections.open"
|
||||
ddServerUpName = "backend.server.up"
|
||||
)
|
||||
|
||||
// RegisterDatadog registers the metrics pusher if this didn't happen yet and creates a datadog Registry instance.
|
||||
func RegisterDatadog(ctx context.Context, config *types.Datadog) Registry {
|
||||
if datadogTicker == nil {
|
||||
datadogTicker = initDatadogClient(ctx, config)
|
||||
}
|
||||
|
||||
registry := &standardRegistry{
|
||||
enabled: true,
|
||||
configReloadsCounter: datadogClient.NewCounter(ddConfigReloadsName, 1.0),
|
||||
configReloadsFailureCounter: datadogClient.NewCounter(ddConfigReloadsName, 1.0).With(ddConfigReloadsFailureTagName, "true"),
|
||||
lastConfigReloadSuccessGauge: datadogClient.NewGauge(ddLastConfigReloadSuccessName),
|
||||
lastConfigReloadFailureGauge: datadogClient.NewGauge(ddLastConfigReloadFailureName),
|
||||
entrypointReqsCounter: datadogClient.NewCounter(ddEntrypointReqsName, 1.0),
|
||||
entrypointReqDurationHistogram: datadogClient.NewHistogram(ddEntrypointReqDurationName, 1.0),
|
||||
entrypointOpenConnsGauge: datadogClient.NewGauge(ddEntrypointOpenConnsName),
|
||||
backendReqsCounter: datadogClient.NewCounter(ddMetricsBackendReqsName, 1.0),
|
||||
backendReqDurationHistogram: datadogClient.NewHistogram(ddMetricsBackendLatencyName, 1.0),
|
||||
backendRetriesCounter: datadogClient.NewCounter(ddRetriesTotalName, 1.0),
|
||||
backendOpenConnsGauge: datadogClient.NewGauge(ddOpenConnsName),
|
||||
backendServerUpGauge: datadogClient.NewGauge(ddServerUpName),
|
||||
}
|
||||
|
||||
return registry
|
||||
}
|
||||
|
||||
func initDatadogClient(ctx context.Context, config *types.Datadog) *time.Ticker {
|
||||
address := config.Address
|
||||
if len(address) == 0 {
|
||||
address = "localhost:8125"
|
||||
}
|
||||
pushInterval, err := time.ParseDuration(config.PushInterval)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Warnf("Unable to parse %s from config.PushInterval: using 10s as the default value", config.PushInterval)
|
||||
pushInterval = 10 * time.Second
|
||||
}
|
||||
|
||||
report := time.NewTicker(pushInterval)
|
||||
|
||||
safe.Go(func() {
|
||||
datadogClient.SendLoop(report.C, "udp", address)
|
||||
})
|
||||
|
||||
return report
|
||||
}
|
||||
|
||||
// StopDatadog stops internal datadogTicker which controls the pushing of metrics to DD Agent and resets it to `nil`.
|
||||
func StopDatadog() {
|
||||
if datadogTicker != nil {
|
||||
datadogTicker.Stop()
|
||||
}
|
||||
datadogTicker = nil
|
||||
}
|
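A sketch of registering the Datadog pusher and incrementing a couple of the metrics it exposes (the agent address, push interval, and label values are illustrative; import paths follow this commit's pkg layout):

package main

import (
	"context"
	"net/http"
	"strconv"

	"github.com/containous/traefik/pkg/metrics"
	"github.com/containous/traefik/pkg/types"
)

func main() {
	// Push metrics to a local dogstatsd agent every 10 seconds.
	registry := metrics.RegisterDatadog(context.Background(), &types.Datadog{
		Address:      "localhost:8125",
		PushInterval: "10s",
	})
	defer metrics.StopDatadog()

	registry.EntrypointReqsCounter().With("entrypoint", "web").Add(1)
	registry.BackendReqsCounter().
		With("service", "my-service", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet).
		Add(1)
}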
53
pkg/metrics/datadog_test.go
Normal file
@@ -0,0 +1,53 @@
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
"github.com/stvp/go-udp-testing"
|
||||
)
|
||||
|
||||
func TestDatadog(t *testing.T) {
|
||||
udp.SetAddr(":18125")
|
||||
// This is needed to make sure that UDP Listener listens for data a bit longer, otherwise it will quit after a millisecond
|
||||
udp.Timeout = 5 * time.Second
|
||||
|
||||
datadogRegistry := RegisterDatadog(context.Background(), &types.Datadog{Address: ":18125", PushInterval: "1s"})
|
||||
defer StopDatadog()
|
||||
|
||||
if !datadogRegistry.IsEnabled() {
|
||||
t.Errorf("DatadogRegistry should return true for IsEnabled()")
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
// We are only validating counts, as it is nearly impossible to validate latency, since it varies every run
|
||||
"traefik.backend.request.total:1.000000|c|#service:test,code:404,method:GET\n",
|
||||
"traefik.backend.request.total:1.000000|c|#service:test,code:200,method:GET\n",
|
||||
"traefik.backend.retries.total:2.000000|c|#service:test\n",
|
||||
"traefik.backend.request.duration:10000.000000|h|#service:test,code:200\n",
|
||||
"traefik.config.reload.total:1.000000|c\n",
|
||||
"traefik.config.reload.total:1.000000|c|#failure:true\n",
|
||||
"traefik.entrypoint.request.total:1.000000|c|#entrypoint:test\n",
|
||||
"traefik.entrypoint.request.duration:10000.000000|h|#entrypoint:test\n",
|
||||
"traefik.entrypoint.connections.open:1.000000|g|#entrypoint:test\n",
|
||||
"traefik.backend.server.up:1.000000|g|#backend:test,url:http://127.0.0.1,one:two\n",
|
||||
}
|
||||
|
||||
udp.ShouldReceiveAll(t, expected, func() {
|
||||
datadogRegistry.BackendReqsCounter().With("service", "test", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet).Add(1)
|
||||
datadogRegistry.BackendReqsCounter().With("service", "test", "code", strconv.Itoa(http.StatusNotFound), "method", http.MethodGet).Add(1)
|
||||
datadogRegistry.BackendReqDurationHistogram().With("service", "test", "code", strconv.Itoa(http.StatusOK)).Observe(10000)
|
||||
datadogRegistry.BackendRetriesCounter().With("service", "test").Add(1)
|
||||
datadogRegistry.BackendRetriesCounter().With("service", "test").Add(1)
|
||||
datadogRegistry.ConfigReloadsCounter().Add(1)
|
||||
datadogRegistry.ConfigReloadsFailureCounter().Add(1)
|
||||
datadogRegistry.EntrypointReqsCounter().With("entrypoint", "test").Add(1)
|
||||
datadogRegistry.EntrypointReqDurationHistogram().With("entrypoint", "test").Observe(10000)
|
||||
datadogRegistry.EntrypointOpenConnsGauge().With("entrypoint", "test").Set(1)
|
||||
datadogRegistry.BackendServerUpGauge().With("backend", "test", "url", "http://127.0.0.1", "one", "two").Set(1)
|
||||
})
|
||||
}
|
213
pkg/metrics/influxdb.go
Normal file
@@ -0,0 +1,213 @@
package metrics
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/metrics/influx"
|
||||
influxdb "github.com/influxdata/influxdb/client/v2"
|
||||
)
|
||||
|
||||
var influxDBClient *influx.Influx
|
||||
|
||||
type influxDBWriter struct {
|
||||
buf bytes.Buffer
|
||||
config *types.InfluxDB
|
||||
}
|
||||
|
||||
var influxDBTicker *time.Ticker
|
||||
|
||||
const (
|
||||
influxDBMetricsBackendReqsName = "traefik.backend.requests.total"
|
||||
influxDBMetricsBackendLatencyName = "traefik.backend.request.duration"
|
||||
influxDBRetriesTotalName = "traefik.backend.retries.total"
|
||||
influxDBConfigReloadsName = "traefik.config.reload.total"
|
||||
influxDBConfigReloadsFailureName = influxDBConfigReloadsName + ".failure"
|
||||
influxDBLastConfigReloadSuccessName = "traefik.config.reload.lastSuccessTimestamp"
|
||||
influxDBLastConfigReloadFailureName = "traefik.config.reload.lastFailureTimestamp"
|
||||
influxDBEntrypointReqsName = "traefik.entrypoint.requests.total"
|
||||
influxDBEntrypointReqDurationName = "traefik.entrypoint.request.duration"
|
||||
influxDBEntrypointOpenConnsName = "traefik.entrypoint.connections.open"
|
||||
influxDBOpenConnsName = "traefik.backend.connections.open"
|
||||
influxDBServerUpName = "traefik.backend.server.up"
|
||||
)
|
||||
|
||||
const (
|
||||
protocolHTTP = "http"
|
||||
protocolUDP = "udp"
|
||||
)
|
||||
|
||||
// RegisterInfluxDB registers the metrics pusher if this didn't happen yet and creates a InfluxDB Registry instance.
|
||||
func RegisterInfluxDB(ctx context.Context, config *types.InfluxDB) Registry {
|
||||
if influxDBClient == nil {
|
||||
influxDBClient = initInfluxDBClient(ctx, config)
|
||||
}
|
||||
if influxDBTicker == nil {
|
||||
influxDBTicker = initInfluxDBTicker(ctx, config)
|
||||
}
|
||||
|
||||
return &standardRegistry{
|
||||
enabled: true,
|
||||
configReloadsCounter: influxDBClient.NewCounter(influxDBConfigReloadsName),
|
||||
configReloadsFailureCounter: influxDBClient.NewCounter(influxDBConfigReloadsFailureName),
|
||||
lastConfigReloadSuccessGauge: influxDBClient.NewGauge(influxDBLastConfigReloadSuccessName),
|
||||
lastConfigReloadFailureGauge: influxDBClient.NewGauge(influxDBLastConfigReloadFailureName),
|
||||
entrypointReqsCounter: influxDBClient.NewCounter(influxDBEntrypointReqsName),
|
||||
entrypointReqDurationHistogram: influxDBClient.NewHistogram(influxDBEntrypointReqDurationName),
|
||||
entrypointOpenConnsGauge: influxDBClient.NewGauge(influxDBEntrypointOpenConnsName),
|
||||
backendReqsCounter: influxDBClient.NewCounter(influxDBMetricsBackendReqsName),
|
||||
backendReqDurationHistogram: influxDBClient.NewHistogram(influxDBMetricsBackendLatencyName),
|
||||
backendRetriesCounter: influxDBClient.NewCounter(influxDBRetriesTotalName),
|
||||
backendOpenConnsGauge: influxDBClient.NewGauge(influxDBOpenConnsName),
|
||||
backendServerUpGauge: influxDBClient.NewGauge(influxDBServerUpName),
|
||||
}
|
||||
}
|
||||
|
||||
// initInfluxDBTicker creates a influxDBClient
|
||||
func initInfluxDBClient(ctx context.Context, config *types.InfluxDB) *influx.Influx {
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
// TODO deprecated: move this switch into configuration.SetEffectiveConfiguration when web provider will be removed.
|
||||
switch config.Protocol {
|
||||
case protocolUDP:
|
||||
if len(config.Database) > 0 || len(config.RetentionPolicy) > 0 {
|
||||
logger.Warn("Database and RetentionPolicy options have no effect with UDP.")
|
||||
config.Database = ""
|
||||
config.RetentionPolicy = ""
|
||||
}
|
||||
case protocolHTTP:
|
||||
if u, err := url.Parse(config.Address); err == nil {
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
logger.Warnf("InfluxDB address %s should specify a scheme (http or https): falling back on HTTP.", config.Address)
|
||||
config.Address = "http://" + config.Address
|
||||
}
|
||||
} else {
|
||||
logger.Errorf("Unable to parse the InfluxDB address %v: falling back on UDP.", err)
|
||||
config.Protocol = protocolUDP
|
||||
config.Database = ""
|
||||
config.RetentionPolicy = ""
|
||||
}
|
||||
default:
|
||||
logger.Warnf("Unsupported protocol %s: falling back on UDP.", config.Protocol)
|
||||
config.Protocol = protocolUDP
|
||||
config.Database = ""
|
||||
config.RetentionPolicy = ""
|
||||
}
|
||||
|
||||
return influx.New(
|
||||
map[string]string{},
|
||||
influxdb.BatchPointsConfig{
|
||||
Database: config.Database,
|
||||
RetentionPolicy: config.RetentionPolicy,
|
||||
},
|
||||
kitlog.LoggerFunc(func(keyvals ...interface{}) error {
|
||||
log.WithoutContext().WithField(log.MetricsProviderName, "influxdb").Info(keyvals)
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
// initInfluxDBTicker initializes metrics pusher
|
||||
func initInfluxDBTicker(ctx context.Context, config *types.InfluxDB) *time.Ticker {
|
||||
pushInterval, err := time.ParseDuration(config.PushInterval)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Warnf("Unable to parse %s from config.PushInterval: using 10s as the default value", config.PushInterval)
|
||||
pushInterval = 10 * time.Second
|
||||
}
|
||||
|
||||
report := time.NewTicker(pushInterval)
|
||||
|
||||
safe.Go(func() {
|
||||
var buf bytes.Buffer
|
||||
influxDBClient.WriteLoop(report.C, &influxDBWriter{buf: buf, config: config})
|
||||
})
|
||||
|
||||
return report
|
||||
}
|
||||
|
||||
// StopInfluxDB stops internal influxDBTicker which controls the pushing of metrics to InfluxDB Agent and resets it to `nil`
|
||||
func StopInfluxDB() {
|
||||
if influxDBTicker != nil {
|
||||
influxDBTicker.Stop()
|
||||
}
|
||||
influxDBTicker = nil
|
||||
}
|
||||
|
||||
// Write creates a http or udp client and attempts to write BatchPoints.
|
||||
// If a "database not found" error is encountered, a CREATE DATABASE
|
||||
// query is attempted when using protocol http.
|
||||
func (w *influxDBWriter) Write(bp influxdb.BatchPoints) error {
|
||||
c, err := w.initWriteClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer c.Close()
|
||||
|
||||
if writeErr := c.Write(bp); writeErr != nil {
|
||||
ctx := log.With(context.Background(), log.Str(log.MetricsProviderName, "influxdb"))
|
||||
log.FromContext(ctx).Errorf("Error while writing to InfluxDB: %s", writeErr.Error())
|
||||
|
||||
if handleErr := w.handleWriteError(ctx, c, writeErr); handleErr != nil {
|
||||
return handleErr
|
||||
}
|
||||
// Retry write after successful handling of writeErr
|
||||
return c.Write(bp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *influxDBWriter) initWriteClient() (influxdb.Client, error) {
|
||||
if w.config.Protocol == "http" {
|
||||
return influxdb.NewHTTPClient(influxdb.HTTPConfig{
|
||||
Addr: w.config.Address,
|
||||
Username: w.config.Username,
|
||||
Password: w.config.Password,
|
||||
})
|
||||
}
|
||||
|
||||
return influxdb.NewUDPClient(influxdb.UDPConfig{
|
||||
Addr: w.config.Address,
|
||||
})
|
||||
}
|
||||
|
||||
func (w *influxDBWriter) handleWriteError(ctx context.Context, c influxdb.Client, writeErr error) error {
|
||||
if w.config.Protocol != protocolHTTP {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
match, matchErr := regexp.MatchString("database not found", writeErr.Error())
|
||||
|
||||
if matchErr != nil || !match {
|
||||
return writeErr
|
||||
}
|
||||
|
||||
qStr := fmt.Sprintf("CREATE DATABASE \"%s\"", w.config.Database)
|
||||
if w.config.RetentionPolicy != "" {
|
||||
qStr = fmt.Sprintf("%s WITH NAME \"%s\"", qStr, w.config.RetentionPolicy)
|
||||
}
|
||||
|
||||
logger := log.FromContext(ctx)
|
||||
|
||||
logger.Debugf("InfluxDB database not found: attempting to create one with %s", qStr)
|
||||
|
||||
q := influxdb.NewQuery(qStr, "", "")
|
||||
response, queryErr := c.Query(q)
|
||||
if queryErr == nil && response.Error() != nil {
|
||||
queryErr = response.Error()
|
||||
}
|
||||
if queryErr != nil {
|
||||
logger.Errorf("Error while creating the InfluxDB database %s", queryErr)
|
||||
return queryErr
|
||||
}
|
||||
|
||||
logger.Debugf("Successfully created the InfluxDB database %s", w.config.Database)
|
||||
return nil
|
||||
}
|
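A sketch of the InfluxDB registry in HTTP mode; with UDP, Database and RetentionPolicy are ignored, as the initInfluxDBClient switch above shows (the address, database, and interval are illustrative values):

package main

import (
	"context"

	"github.com/containous/traefik/pkg/metrics"
	"github.com/containous/traefik/pkg/types"
)

func main() {
	registry := metrics.RegisterInfluxDB(context.Background(), &types.InfluxDB{
		Address:         "http://localhost:8086",
		Protocol:        "http",
		PushInterval:    "10s",
		Database:        "traefik",
		RetentionPolicy: "autogen",
	})
	defer metrics.StopInfluxDB()

	// Missing databases are created on the first "database not found" write error.
	registry.ConfigReloadsCounter().Add(1)
	registry.LastConfigReloadSuccessGauge().Set(1)
}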
135
pkg/metrics/influxdb_test.go
Normal file
@@ -0,0 +1,135 @@
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
"github.com/stvp/go-udp-testing"
|
||||
)
|
||||
|
||||
func TestInfluxDB(t *testing.T) {
|
||||
udp.SetAddr(":8089")
|
||||
// This is needed to make sure that UDP Listener listens for data a bit longer, otherwise it will quit after a millisecond
|
||||
udp.Timeout = 5 * time.Second
|
||||
|
||||
influxDBRegistry := RegisterInfluxDB(context.Background(), &types.InfluxDB{Address: ":8089", PushInterval: "1s"})
|
||||
defer StopInfluxDB()
|
||||
|
||||
if !influxDBRegistry.IsEnabled() {
|
||||
t.Fatalf("InfluxDB registry must be enabled")
|
||||
}
|
||||
|
||||
expectedBackend := []string{
|
||||
`(traefik\.backend\.requests\.total,backend=test,code=200,method=GET count=1) [\d]{19}`,
|
||||
`(traefik\.backend\.requests\.total,backend=test,code=404,method=GET count=1) [\d]{19}`,
|
||||
`(traefik\.backend\.request\.duration,backend=test,code=200 p50=10000,p90=10000,p95=10000,p99=10000) [\d]{19}`,
|
||||
`(traefik\.backend\.retries\.total(?:,code=[\d]{3},method=GET)?,backend=test count=2) [\d]{19}`,
|
||||
`(traefik\.config\.reload\.total(?:[a-z=0-9A-Z,]+)? count=1) [\d]{19}`,
|
||||
`(traefik\.config\.reload\.total\.failure(?:[a-z=0-9A-Z,]+)? count=1) [\d]{19}`,
|
||||
`(traefik\.backend\.server\.up,backend=test(?:[a-z=0-9A-Z,]+)?,url=http://127.0.0.1 value=1) [\d]{19}`,
|
||||
}
|
||||
|
||||
msgBackend := udp.ReceiveString(t, func() {
|
||||
influxDBRegistry.BackendReqsCounter().With("backend", "test", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet).Add(1)
|
||||
influxDBRegistry.BackendReqsCounter().With("backend", "test", "code", strconv.Itoa(http.StatusNotFound), "method", http.MethodGet).Add(1)
|
||||
influxDBRegistry.BackendRetriesCounter().With("backend", "test").Add(1)
|
||||
influxDBRegistry.BackendRetriesCounter().With("backend", "test").Add(1)
|
||||
influxDBRegistry.BackendReqDurationHistogram().With("backend", "test", "code", strconv.Itoa(http.StatusOK)).Observe(10000)
|
||||
influxDBRegistry.ConfigReloadsCounter().Add(1)
|
||||
influxDBRegistry.ConfigReloadsFailureCounter().Add(1)
|
||||
influxDBRegistry.BackendServerUpGauge().With("backend", "test", "url", "http://127.0.0.1").Set(1)
|
||||
})
|
||||
|
||||
assertMessage(t, msgBackend, expectedBackend)
|
||||
|
||||
expectedEntrypoint := []string{
|
||||
`(traefik\.entrypoint\.requests\.total,entrypoint=test(?:[a-z=0-9A-Z,:/.]+)? count=1) [\d]{19}`,
|
||||
`(traefik\.entrypoint\.request\.duration(?:,code=[\d]{3})?,entrypoint=test(?:[a-z=0-9A-Z,:/.]+)? p50=10000,p90=10000,p95=10000,p99=10000) [\d]{19}`,
|
||||
`(traefik\.entrypoint\.connections\.open,entrypoint=test value=1) [\d]{19}`,
|
||||
}
|
||||
|
||||
msgEntrypoint := udp.ReceiveString(t, func() {
|
||||
influxDBRegistry.EntrypointReqsCounter().With("entrypoint", "test").Add(1)
|
||||
influxDBRegistry.EntrypointReqDurationHistogram().With("entrypoint", "test").Observe(10000)
|
||||
influxDBRegistry.EntrypointOpenConnsGauge().With("entrypoint", "test").Set(1)
|
||||
|
||||
})
|
||||
|
||||
assertMessage(t, msgEntrypoint, expectedEntrypoint)
|
||||
}
|
||||
|
||||
func TestInfluxDBHTTP(t *testing.T) {
|
||||
c := make(chan *string)
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "can't read body "+err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
bodyStr := string(body)
|
||||
c <- &bodyStr
|
||||
fmt.Fprintln(w, "ok")
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
influxDBRegistry := RegisterInfluxDB(context.Background(), &types.InfluxDB{Address: ts.URL, Protocol: "http", PushInterval: "1s", Database: "test", RetentionPolicy: "autogen"})
|
||||
defer StopInfluxDB()
|
||||
|
||||
if !influxDBRegistry.IsEnabled() {
|
||||
t.Fatalf("InfluxDB registry must be enabled")
|
||||
}
|
||||
|
||||
expectedBackend := []string{
|
||||
`(traefik\.backend\.requests\.total,backend=test,code=200,method=GET count=1) [\d]{19}`,
|
||||
`(traefik\.backend\.requests\.total,backend=test,code=404,method=GET count=1) [\d]{19}`,
|
||||
`(traefik\.backend\.request\.duration,backend=test,code=200 p50=10000,p90=10000,p95=10000,p99=10000) [\d]{19}`,
|
||||
`(traefik\.backend\.retries\.total(?:,code=[\d]{3},method=GET)?,backend=test count=2) [\d]{19}`,
|
||||
`(traefik\.config\.reload\.total(?:[a-z=0-9A-Z,]+)? count=1) [\d]{19}`,
|
||||
`(traefik\.config\.reload\.total\.failure(?:[a-z=0-9A-Z,]+)? count=1) [\d]{19}`,
|
||||
`(traefik\.backend\.server\.up,backend=test(?:[a-z=0-9A-Z,]+)?,url=http://127.0.0.1 value=1) [\d]{19}`,
|
||||
}
|
||||
|
||||
influxDBRegistry.BackendReqsCounter().With("backend", "test", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet).Add(1)
|
||||
influxDBRegistry.BackendReqsCounter().With("backend", "test", "code", strconv.Itoa(http.StatusNotFound), "method", http.MethodGet).Add(1)
|
||||
influxDBRegistry.BackendRetriesCounter().With("backend", "test").Add(1)
|
||||
influxDBRegistry.BackendRetriesCounter().With("backend", "test").Add(1)
|
||||
influxDBRegistry.BackendReqDurationHistogram().With("backend", "test", "code", strconv.Itoa(http.StatusOK)).Observe(10000)
|
||||
influxDBRegistry.ConfigReloadsCounter().Add(1)
|
||||
influxDBRegistry.ConfigReloadsFailureCounter().Add(1)
|
||||
influxDBRegistry.BackendServerUpGauge().With("backend", "test", "url", "http://127.0.0.1").Set(1)
|
||||
msgBackend := <-c
|
||||
|
||||
assertMessage(t, *msgBackend, expectedBackend)
|
||||
|
||||
expectedEntrypoint := []string{
|
||||
`(traefik\.entrypoint\.requests\.total,entrypoint=test(?:[a-z=0-9A-Z,:/.]+)? count=1) [\d]{19}`,
|
||||
`(traefik\.entrypoint\.request\.duration(?:,code=[\d]{3})?,entrypoint=test(?:[a-z=0-9A-Z,:/.]+)? p50=10000,p90=10000,p95=10000,p99=10000) [\d]{19}`,
|
||||
`(traefik\.entrypoint\.connections\.open,entrypoint=test value=1) [\d]{19}`,
|
||||
}
|
||||
|
||||
influxDBRegistry.EntrypointReqsCounter().With("entrypoint", "test").Add(1)
|
||||
influxDBRegistry.EntrypointReqDurationHistogram().With("entrypoint", "test").Observe(10000)
|
||||
influxDBRegistry.EntrypointOpenConnsGauge().With("entrypoint", "test").Set(1)
|
||||
msgEntrypoint := <-c
|
||||
|
||||
assertMessage(t, *msgEntrypoint, expectedEntrypoint)
|
||||
}
|
||||
|
||||
func assertMessage(t *testing.T, msg string, patterns []string) {
|
||||
t.Helper()
|
||||
for _, pattern := range patterns {
|
||||
re := regexp.MustCompile(pattern)
|
||||
match := re.FindStringSubmatch(msg)
|
||||
if len(match) != 2 {
|
||||
t.Errorf("Got %q %v, want %q", msg, match, pattern)
|
||||
}
|
||||
}
|
||||
}
|
177
pkg/metrics/metrics.go
Normal file
@@ -0,0 +1,177 @@
package metrics
|
||||
|
||||
import (
|
||||
"github.com/go-kit/kit/metrics"
|
||||
"github.com/go-kit/kit/metrics/multi"
|
||||
)
|
||||
|
||||
// Registry has to implemented by any system that wants to monitor and expose metrics.
|
||||
type Registry interface {
|
||||
// IsEnabled shows whether metrics instrumentation is enabled.
|
||||
IsEnabled() bool
|
||||
|
||||
// server metrics
|
||||
ConfigReloadsCounter() metrics.Counter
|
||||
ConfigReloadsFailureCounter() metrics.Counter
|
||||
LastConfigReloadSuccessGauge() metrics.Gauge
|
||||
LastConfigReloadFailureGauge() metrics.Gauge
|
||||
|
||||
// entry point metrics
|
||||
EntrypointReqsCounter() metrics.Counter
|
||||
EntrypointReqDurationHistogram() metrics.Histogram
|
||||
EntrypointOpenConnsGauge() metrics.Gauge
|
||||
|
||||
// backend metrics
|
||||
BackendReqsCounter() metrics.Counter
|
||||
BackendReqDurationHistogram() metrics.Histogram
|
||||
BackendOpenConnsGauge() metrics.Gauge
|
||||
BackendRetriesCounter() metrics.Counter
|
||||
BackendServerUpGauge() metrics.Gauge
|
||||
}
|
||||
|
||||
// NewVoidRegistry is a noop implementation of metrics.Registry.
|
||||
// It is used to avoid nil checking in components that do metric collections.
|
||||
func NewVoidRegistry() Registry {
|
||||
return NewMultiRegistry([]Registry{})
|
||||
}
|
||||
|
||||
// NewMultiRegistry is an implementation of metrics.Registry that wraps multiple registries.
|
||||
// It handles the case when a registry hasn't registered some metric and returns nil.
|
||||
// This allows for feature imparity between the different metric implementations.
|
||||
func NewMultiRegistry(registries []Registry) Registry {
|
||||
var configReloadsCounter []metrics.Counter
|
||||
var configReloadsFailureCounter []metrics.Counter
|
||||
var lastConfigReloadSuccessGauge []metrics.Gauge
|
||||
var lastConfigReloadFailureGauge []metrics.Gauge
|
||||
var entrypointReqsCounter []metrics.Counter
|
||||
var entrypointReqDurationHistogram []metrics.Histogram
|
||||
var entrypointOpenConnsGauge []metrics.Gauge
|
||||
var backendReqsCounter []metrics.Counter
|
||||
var backendReqDurationHistogram []metrics.Histogram
|
||||
var backendOpenConnsGauge []metrics.Gauge
|
||||
var backendRetriesCounter []metrics.Counter
|
||||
var backendServerUpGauge []metrics.Gauge
|
||||
|
||||
for _, r := range registries {
|
||||
if r.ConfigReloadsCounter() != nil {
|
||||
configReloadsCounter = append(configReloadsCounter, r.ConfigReloadsCounter())
|
||||
}
|
||||
if r.ConfigReloadsFailureCounter() != nil {
|
||||
configReloadsFailureCounter = append(configReloadsFailureCounter, r.ConfigReloadsFailureCounter())
|
||||
}
|
||||
if r.LastConfigReloadSuccessGauge() != nil {
|
||||
lastConfigReloadSuccessGauge = append(lastConfigReloadSuccessGauge, r.LastConfigReloadSuccessGauge())
|
||||
}
|
||||
if r.LastConfigReloadFailureGauge() != nil {
|
||||
lastConfigReloadFailureGauge = append(lastConfigReloadFailureGauge, r.LastConfigReloadFailureGauge())
|
||||
}
|
||||
if r.EntrypointReqsCounter() != nil {
|
||||
entrypointReqsCounter = append(entrypointReqsCounter, r.EntrypointReqsCounter())
|
||||
}
|
||||
if r.EntrypointReqDurationHistogram() != nil {
|
||||
entrypointReqDurationHistogram = append(entrypointReqDurationHistogram, r.EntrypointReqDurationHistogram())
|
||||
}
|
||||
if r.EntrypointOpenConnsGauge() != nil {
|
||||
entrypointOpenConnsGauge = append(entrypointOpenConnsGauge, r.EntrypointOpenConnsGauge())
|
||||
}
|
||||
if r.BackendReqsCounter() != nil {
|
||||
backendReqsCounter = append(backendReqsCounter, r.BackendReqsCounter())
|
||||
}
|
||||
if r.BackendReqDurationHistogram() != nil {
|
||||
backendReqDurationHistogram = append(backendReqDurationHistogram, r.BackendReqDurationHistogram())
|
||||
}
|
||||
if r.BackendOpenConnsGauge() != nil {
|
||||
backendOpenConnsGauge = append(backendOpenConnsGauge, r.BackendOpenConnsGauge())
|
||||
}
|
||||
if r.BackendRetriesCounter() != nil {
|
||||
backendRetriesCounter = append(backendRetriesCounter, r.BackendRetriesCounter())
|
||||
}
|
||||
if r.BackendServerUpGauge() != nil {
|
||||
backendServerUpGauge = append(backendServerUpGauge, r.BackendServerUpGauge())
|
||||
}
|
||||
}
|
||||
|
||||
return &standardRegistry{
|
||||
enabled: len(registries) > 0,
|
||||
configReloadsCounter: multi.NewCounter(configReloadsCounter...),
|
||||
configReloadsFailureCounter: multi.NewCounter(configReloadsFailureCounter...),
|
||||
lastConfigReloadSuccessGauge: multi.NewGauge(lastConfigReloadSuccessGauge...),
|
||||
lastConfigReloadFailureGauge: multi.NewGauge(lastConfigReloadFailureGauge...),
|
||||
entrypointReqsCounter: multi.NewCounter(entrypointReqsCounter...),
|
||||
entrypointReqDurationHistogram: multi.NewHistogram(entrypointReqDurationHistogram...),
|
||||
entrypointOpenConnsGauge: multi.NewGauge(entrypointOpenConnsGauge...),
|
||||
backendReqsCounter: multi.NewCounter(backendReqsCounter...),
|
||||
backendReqDurationHistogram: multi.NewHistogram(backendReqDurationHistogram...),
|
||||
backendOpenConnsGauge: multi.NewGauge(backendOpenConnsGauge...),
|
||||
backendRetriesCounter: multi.NewCounter(backendRetriesCounter...),
|
||||
backendServerUpGauge: multi.NewGauge(backendServerUpGauge...),
|
||||
}
|
||||
}
|
||||
|
||||
type standardRegistry struct {
|
||||
enabled bool
|
||||
configReloadsCounter metrics.Counter
|
||||
configReloadsFailureCounter metrics.Counter
|
||||
lastConfigReloadSuccessGauge metrics.Gauge
|
||||
lastConfigReloadFailureGauge metrics.Gauge
|
||||
entrypointReqsCounter metrics.Counter
|
||||
entrypointReqDurationHistogram metrics.Histogram
|
||||
entrypointOpenConnsGauge metrics.Gauge
|
||||
backendReqsCounter metrics.Counter
|
||||
backendReqDurationHistogram metrics.Histogram
|
||||
backendOpenConnsGauge metrics.Gauge
|
||||
backendRetriesCounter metrics.Counter
|
||||
backendServerUpGauge metrics.Gauge
|
||||
}
|
||||
|
||||
func (r *standardRegistry) IsEnabled() bool {
|
||||
return r.enabled
|
||||
}
|
||||
|
||||
func (r *standardRegistry) ConfigReloadsCounter() metrics.Counter {
|
||||
return r.configReloadsCounter
|
||||
}
|
||||
|
||||
func (r *standardRegistry) ConfigReloadsFailureCounter() metrics.Counter {
|
||||
return r.configReloadsFailureCounter
|
||||
}
|
||||
|
||||
func (r *standardRegistry) LastConfigReloadSuccessGauge() metrics.Gauge {
|
||||
return r.lastConfigReloadSuccessGauge
|
||||
}
|
||||
|
||||
func (r *standardRegistry) LastConfigReloadFailureGauge() metrics.Gauge {
|
||||
return r.lastConfigReloadFailureGauge
|
||||
}
|
||||
|
||||
func (r *standardRegistry) EntrypointReqsCounter() metrics.Counter {
|
||||
return r.entrypointReqsCounter
|
||||
}
|
||||
|
||||
func (r *standardRegistry) EntrypointReqDurationHistogram() metrics.Histogram {
|
||||
return r.entrypointReqDurationHistogram
|
||||
}
|
||||
|
||||
func (r *standardRegistry) EntrypointOpenConnsGauge() metrics.Gauge {
|
||||
return r.entrypointOpenConnsGauge
|
||||
}
|
||||
|
||||
func (r *standardRegistry) BackendReqsCounter() metrics.Counter {
|
||||
return r.backendReqsCounter
|
||||
}
|
||||
|
||||
func (r *standardRegistry) BackendReqDurationHistogram() metrics.Histogram {
|
||||
return r.backendReqDurationHistogram
|
||||
}
|
||||
|
||||
func (r *standardRegistry) BackendOpenConnsGauge() metrics.Gauge {
|
||||
return r.backendOpenConnsGauge
|
||||
}
|
||||
|
||||
func (r *standardRegistry) BackendRetriesCounter() metrics.Counter {
|
||||
return r.backendRetriesCounter
|
||||
}
|
||||
|
||||
func (r *standardRegistry) BackendServerUpGauge() metrics.Gauge {
|
||||
return r.backendServerUpGauge
|
||||
}
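// Example (illustrative sketch, names assumed for illustration): combining
// several registries so that every metric update fans out to all of them.
func exampleMultiRegistry(prometheusRegistry, statsdRegistry Registry) Registry {
	combined := NewMultiRegistry([]Registry{prometheusRegistry, statsdRegistry})
	combined.ConfigReloadsCounter().Add(1) // recorded by both underlying registries
	return combined
}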
|
76
pkg/metrics/metrics_test.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/go-kit/kit/metrics"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewMultiRegistry(t *testing.T) {
|
||||
registries := []Registry{newCollectingRetryMetrics(), newCollectingRetryMetrics()}
|
||||
registry := NewMultiRegistry(registries)
|
||||
|
||||
registry.BackendReqsCounter().With("key", "requests").Add(1)
|
||||
registry.BackendReqDurationHistogram().With("key", "durations").Observe(2)
|
||||
registry.BackendRetriesCounter().With("key", "retries").Add(3)
|
||||
|
||||
for _, collectingRegistry := range registries {
|
||||
cReqsCounter := collectingRegistry.BackendReqsCounter().(*counterMock)
|
||||
cReqDurationHistogram := collectingRegistry.BackendReqDurationHistogram().(*histogramMock)
|
||||
cRetriesCounter := collectingRegistry.BackendRetriesCounter().(*counterMock)
|
||||
|
||||
wantCounterValue := float64(1)
|
||||
if cReqsCounter.counterValue != wantCounterValue {
|
||||
t.Errorf("Got value %f for ReqsCounter, want %f", cReqsCounter.counterValue, wantCounterValue)
|
||||
}
|
||||
wantHistogramValue := float64(2)
|
||||
if cReqDurationHistogram.lastHistogramValue != wantHistogramValue {
|
||||
t.Errorf("Got last observation %f for ReqDurationHistogram, want %f", cReqDurationHistogram.lastHistogramValue, wantHistogramValue)
|
||||
}
|
||||
wantCounterValue = float64(3)
|
||||
if cRetriesCounter.counterValue != wantCounterValue {
|
||||
t.Errorf("Got value %f for RetriesCounter, want %f", cRetriesCounter.counterValue, wantCounterValue)
|
||||
}
|
||||
|
||||
assert.Equal(t, []string{"key", "requests"}, cReqsCounter.lastLabelValues)
|
||||
assert.Equal(t, []string{"key", "durations"}, cReqDurationHistogram.lastLabelValues)
|
||||
assert.Equal(t, []string{"key", "retries"}, cRetriesCounter.lastLabelValues)
|
||||
}
|
||||
}
|
||||
|
||||
func newCollectingRetryMetrics() Registry {
|
||||
return &standardRegistry{
|
||||
backendReqsCounter: &counterMock{},
|
||||
backendReqDurationHistogram: &histogramMock{},
|
||||
backendRetriesCounter: &counterMock{},
|
||||
}
|
||||
}
|
||||
|
||||
type counterMock struct {
|
||||
counterValue float64
|
||||
lastLabelValues []string
|
||||
}
|
||||
|
||||
func (c *counterMock) With(labelValues ...string) metrics.Counter {
|
||||
c.lastLabelValues = labelValues
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *counterMock) Add(delta float64) {
|
||||
c.counterValue += delta
|
||||
}
|
||||
|
||||
type histogramMock struct {
|
||||
lastHistogramValue float64
|
||||
lastLabelValues []string
|
||||
}
|
||||
|
||||
func (c *histogramMock) With(labelValues ...string) metrics.Histogram {
|
||||
c.lastLabelValues = labelValues
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *histogramMock) Observe(value float64) {
|
||||
c.lastHistogramValue = value
|
||||
}
|
512
pkg/metrics/prometheus.go
Normal file
|
@ -0,0 +1,512 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
"github.com/go-kit/kit/metrics"
|
||||
stdprometheus "github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
const (
|
||||
// MetricNamePrefix is the prefix of all metric names.
|
||||
MetricNamePrefix = "traefik_"
|
||||
|
||||
// server meta information
|
||||
metricConfigPrefix = MetricNamePrefix + "config_"
|
||||
configReloadsTotalName = metricConfigPrefix + "reloads_total"
|
||||
configReloadsFailuresTotalName = metricConfigPrefix + "reloads_failure_total"
|
||||
configLastReloadSuccessName = metricConfigPrefix + "last_reload_success"
|
||||
configLastReloadFailureName = metricConfigPrefix + "last_reload_failure"
|
||||
|
||||
// entrypoint
|
||||
metricEntryPointPrefix = MetricNamePrefix + "entrypoint_"
|
||||
entrypointReqsTotalName = metricEntryPointPrefix + "requests_total"
|
||||
entrypointReqDurationName = metricEntryPointPrefix + "request_duration_seconds"
|
||||
entrypointOpenConnsName = metricEntryPointPrefix + "open_connections"
|
||||
|
||||
// backend level.
|
||||
|
||||
// MetricBackendPrefix is the prefix of all backend metric names.
|
||||
MetricBackendPrefix = MetricNamePrefix + "backend_"
|
||||
backendReqsTotalName = MetricBackendPrefix + "requests_total"
|
||||
backendReqDurationName = MetricBackendPrefix + "request_duration_seconds"
|
||||
backendOpenConnsName = MetricBackendPrefix + "open_connections"
|
||||
backendRetriesTotalName = MetricBackendPrefix + "retries_total"
|
||||
backendServerUpName = MetricBackendPrefix + "server_up"
|
||||
)
|
||||
|
||||
// promState holds all metric state internally and acts as the only Collector we register for Prometheus.
|
||||
//
|
||||
// This makes it possible to remove metrics that belong to an outdated configuration.
|
||||
// As an example of why this is required, consider the case where Traefik learns about a new service.
|
||||
// It populates the 'traefik_server_backend_up' metric for it with a value of 1 (alive).
|
||||
// When the backend is later undeployed, the metric is still present in the client library
|
||||
// and will be returned on the metrics endpoint until Traefik is restarted.
|
||||
//
|
||||
// To solve this problem promState keeps track of Traefik's dynamic configuration.
|
||||
// Metrics that "belong" to a dynamic configuration part like backends or entrypoints
|
||||
// are removed, after they have been scraped at least once, when the corresponding object
|
||||
// doesn't exist anymore.
|
||||
var promState = newPrometheusState()
|
||||
|
||||
// PrometheusHandler exposes Prometheus routes.
|
||||
type PrometheusHandler struct{}
|
||||
|
||||
// Append adds Prometheus routes on a router.
|
||||
func (h PrometheusHandler) Append(router *mux.Router) {
|
||||
router.Methods(http.MethodGet).Path("/metrics").Handler(promhttp.Handler())
|
||||
}
|
||||
|
||||
// RegisterPrometheus registers all Prometheus metrics.
|
||||
// It must be called only once and failing to register the metrics will lead to a panic.
|
||||
func RegisterPrometheus(ctx context.Context, config *types.Prometheus) Registry {
|
||||
standardRegistry := initStandardRegistry(config)
|
||||
|
||||
if !registerPromState(ctx) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return standardRegistry
|
||||
}
|
||||
|
||||
func initStandardRegistry(config *types.Prometheus) Registry {
|
||||
buckets := []float64{0.1, 0.3, 1.2, 5.0}
|
||||
if config.Buckets != nil {
|
||||
buckets = config.Buckets
|
||||
}
|
||||
|
||||
safe.Go(func() {
|
||||
promState.ListenValueUpdates()
|
||||
})
|
||||
|
||||
configReloads := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
|
||||
Name: configReloadsTotalName,
|
||||
Help: "Config reloads",
|
||||
}, []string{})
|
||||
configReloadsFailures := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
|
||||
Name: configReloadsFailuresTotalName,
|
||||
Help: "Config failure reloads",
|
||||
}, []string{})
|
||||
lastConfigReloadSuccess := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
|
||||
Name: configLastReloadSuccessName,
|
||||
Help: "Last config reload success",
|
||||
}, []string{})
|
||||
lastConfigReloadFailure := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
|
||||
Name: configLastReloadFailureName,
|
||||
Help: "Last config reload failure",
|
||||
}, []string{})
|
||||
|
||||
entrypointReqs := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
|
||||
Name: entrypointReqsTotalName,
|
||||
Help: "How many HTTP requests processed on an entrypoint, partitioned by status code, protocol, and method.",
|
||||
}, []string{"code", "method", "protocol", "entrypoint"})
|
||||
entrypointReqDurations := newHistogramFrom(promState.collectors, stdprometheus.HistogramOpts{
|
||||
Name: entrypointReqDurationName,
|
||||
Help: "How long it took to process the request on an entrypoint, partitioned by status code, protocol, and method.",
|
||||
Buckets: buckets,
|
||||
}, []string{"code", "method", "protocol", "entrypoint"})
|
||||
entrypointOpenConns := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
|
||||
Name: entrypointOpenConnsName,
|
||||
Help: "How many open connections exist on an entrypoint, partitioned by method and protocol.",
|
||||
}, []string{"method", "protocol", "entrypoint"})
|
||||
|
||||
backendReqs := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
|
||||
Name: backendReqsTotalName,
|
||||
Help: "How many HTTP requests processed on a backend, partitioned by status code, protocol, and method.",
|
||||
}, []string{"code", "method", "protocol", "backend"})
|
||||
backendReqDurations := newHistogramFrom(promState.collectors, stdprometheus.HistogramOpts{
|
||||
Name: backendReqDurationName,
|
||||
Help: "How long it took to process the request on a backend, partitioned by status code, protocol, and method.",
|
||||
Buckets: buckets,
|
||||
}, []string{"code", "method", "protocol", "backend"})
|
||||
backendOpenConns := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
|
||||
Name: backendOpenConnsName,
|
||||
Help: "How many open connections exist on a backend, partitioned by method and protocol.",
|
||||
}, []string{"method", "protocol", "backend"})
|
||||
backendRetries := newCounterFrom(promState.collectors, stdprometheus.CounterOpts{
|
||||
Name: backendRetriesTotalName,
|
||||
Help: "How many request retries happened on a backend.",
|
||||
}, []string{"backend"})
|
||||
backendServerUp := newGaugeFrom(promState.collectors, stdprometheus.GaugeOpts{
|
||||
Name: backendServerUpName,
|
||||
Help: "Backend server is up, described by gauge value of 0 or 1.",
|
||||
}, []string{"backend", "url"})
|
||||
|
||||
promState.describers = []func(chan<- *stdprometheus.Desc){
|
||||
configReloads.cv.Describe,
|
||||
configReloadsFailures.cv.Describe,
|
||||
lastConfigReloadSuccess.gv.Describe,
|
||||
lastConfigReloadFailure.gv.Describe,
|
||||
entrypointReqs.cv.Describe,
|
||||
entrypointReqDurations.hv.Describe,
|
||||
entrypointOpenConns.gv.Describe,
|
||||
backendReqs.cv.Describe,
|
||||
backendReqDurations.hv.Describe,
|
||||
backendOpenConns.gv.Describe,
|
||||
backendRetries.cv.Describe,
|
||||
backendServerUp.gv.Describe,
|
||||
}
|
||||
|
||||
return &standardRegistry{
|
||||
enabled: true,
|
||||
configReloadsCounter: configReloads,
|
||||
configReloadsFailureCounter: configReloadsFailures,
|
||||
lastConfigReloadSuccessGauge: lastConfigReloadSuccess,
|
||||
lastConfigReloadFailureGauge: lastConfigReloadFailure,
|
||||
entrypointReqsCounter: entrypointReqs,
|
||||
entrypointReqDurationHistogram: entrypointReqDurations,
|
||||
entrypointOpenConnsGauge: entrypointOpenConns,
|
||||
backendReqsCounter: backendReqs,
|
||||
backendReqDurationHistogram: backendReqDurations,
|
||||
backendOpenConnsGauge: backendOpenConns,
|
||||
backendRetriesCounter: backendRetries,
|
||||
backendServerUpGauge: backendServerUp,
|
||||
}
|
||||
}
|
||||
|
||||
func registerPromState(ctx context.Context) bool {
|
||||
if err := stdprometheus.Register(promState); err != nil {
|
||||
logger := log.FromContext(ctx)
|
||||
if _, ok := err.(stdprometheus.AlreadyRegisteredError); !ok {
|
||||
logger.Errorf("Unable to register Traefik to Prometheus: %v", err)
|
||||
return false
|
||||
}
|
||||
logger.Debug("Prometheus collector already registered.")
|
||||
}
|
||||
return true
|
||||
}
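// Example (illustrative sketch): how a caller might register the Prometheus
// registry and expose it on a router. The router value and the entrypoint
// name "web" are assumptions made for illustration.
func examplePrometheusWiring(ctx context.Context, router *mux.Router) {
	registry := RegisterPrometheus(ctx, &types.Prometheus{})
	if registry == nil {
		// The promState collector could not be registered (see registerPromState above).
		return
	}

	PrometheusHandler{}.Append(router)

	registry.EntrypointReqsCounter().
		With("code", "200", "method", "GET", "protocol", "http", "entrypoint", "web").
		Add(1)
}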
|
||||
|
||||
// OnConfigurationUpdate receives the current configuration from Traefik.
|
||||
// It then converts the configuration to the optimized package internal format
|
||||
// and sets it to the promState.
|
||||
func OnConfigurationUpdate(configurations config.Configurations) {
|
||||
dynamicConfig := newDynamicConfig()
|
||||
|
||||
// FIXME metrics
|
||||
// for _, config := range configurations {
|
||||
// for _, frontend := range config.Frontends {
|
||||
// for _, entrypointName := range frontend.EntryPoints {
|
||||
// dynamicConfig.entrypoints[entrypointName] = true
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// for backendName, backend := range config.Backends {
|
||||
// dynamicConfig.backends[backendName] = make(map[string]bool)
|
||||
// for _, server := range backend.Servers {
|
||||
// dynamicConfig.backends[backendName][server.URL] = true
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
promState.SetDynamicConfig(dynamicConfig)
|
||||
}
|
||||
|
||||
func newPrometheusState() *prometheusState {
|
||||
return &prometheusState{
|
||||
collectors: make(chan *collector),
|
||||
dynamicConfig: newDynamicConfig(),
|
||||
state: make(map[string]*collector),
|
||||
}
|
||||
}
|
||||
|
||||
type prometheusState struct {
|
||||
collectors chan *collector
|
||||
describers []func(ch chan<- *stdprometheus.Desc)
|
||||
|
||||
mtx sync.Mutex
|
||||
dynamicConfig *dynamicConfig
|
||||
state map[string]*collector
|
||||
}
|
||||
|
||||
func (ps *prometheusState) SetDynamicConfig(dynamicConfig *dynamicConfig) {
|
||||
ps.mtx.Lock()
|
||||
defer ps.mtx.Unlock()
|
||||
ps.dynamicConfig = dynamicConfig
|
||||
}
|
||||
|
||||
func (ps *prometheusState) ListenValueUpdates() {
|
||||
for collector := range ps.collectors {
|
||||
ps.mtx.Lock()
|
||||
ps.state[collector.id] = collector
|
||||
ps.mtx.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Describe implements prometheus.Collector and simply calls
|
||||
// the registered describer functions.
|
||||
func (ps *prometheusState) Describe(ch chan<- *stdprometheus.Desc) {
|
||||
for _, desc := range ps.describers {
|
||||
desc(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// Collect implements prometheus.Collector. It calls the Collect
|
||||
// method of all metrics it received on the collectors channel.
|
||||
// It is also responsible for removing metrics that belong to an outdated configuration.
|
||||
// The removal happens only after their Collect method has been called, to ensure that
|
||||
// those metrics are still exported on the current scrape.
|
||||
func (ps *prometheusState) Collect(ch chan<- stdprometheus.Metric) {
|
||||
ps.mtx.Lock()
|
||||
defer ps.mtx.Unlock()
|
||||
|
||||
var outdatedKeys []string
|
||||
for key, cs := range ps.state {
|
||||
cs.collector.Collect(ch)
|
||||
|
||||
if ps.isOutdated(cs) {
|
||||
outdatedKeys = append(outdatedKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range outdatedKeys {
|
||||
ps.state[key].delete()
|
||||
delete(ps.state, key)
|
||||
}
|
||||
}
|
||||
|
||||
// isOutdated checks whether the passed collector has labels that mark
|
||||
// it as belonging to an outdated configuration of Traefik.
|
||||
func (ps *prometheusState) isOutdated(collector *collector) bool {
|
||||
labels := collector.labels
|
||||
|
||||
if entrypointName, ok := labels["entrypoint"]; ok && !ps.dynamicConfig.hasEntrypoint(entrypointName) {
|
||||
return true
|
||||
}
|
||||
|
||||
if backendName, ok := labels["backend"]; ok {
|
||||
if !ps.dynamicConfig.hasBackend(backendName) {
|
||||
return true
|
||||
}
|
||||
if url, ok := labels["url"]; ok && !ps.dynamicConfig.hasServerURL(backendName, url) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func newDynamicConfig() *dynamicConfig {
|
||||
return &dynamicConfig{
|
||||
entrypoints: make(map[string]bool),
|
||||
backends: make(map[string]map[string]bool),
|
||||
}
|
||||
}
|
||||
|
||||
// dynamicConfig holds the current configuration for entrypoints, backends,
|
||||
// and server URLs in an optimized way to check for existence. This provides
|
||||
// a performant way to check whether the collected metrics belong to the
|
||||
// current configuration or to an outdated one.
|
||||
type dynamicConfig struct {
|
||||
entrypoints map[string]bool
|
||||
backends map[string]map[string]bool
|
||||
}
|
||||
|
||||
func (d *dynamicConfig) hasEntrypoint(entrypointName string) bool {
|
||||
_, ok := d.entrypoints[entrypointName]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (d *dynamicConfig) hasBackend(backendName string) bool {
|
||||
_, ok := d.backends[backendName]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (d *dynamicConfig) hasServerURL(backendName, serverURL string) bool {
|
||||
if backend, hasBackend := d.backends[backendName]; hasBackend {
|
||||
_, ok := backend[serverURL]
|
||||
return ok
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func newCollector(metricName string, labels stdprometheus.Labels, c stdprometheus.Collector, delete func()) *collector {
|
||||
return &collector{
|
||||
id: buildMetricID(metricName, labels),
|
||||
labels: labels,
|
||||
collector: c,
|
||||
delete: delete,
|
||||
}
|
||||
}
|
||||
|
||||
// collector wraps a Collector object from the Prometheus client library.
|
||||
// It adds information on how many generations this metric should be present
|
||||
// in the /metrics output, relative to the time it was last tracked.
|
||||
type collector struct {
|
||||
id string
|
||||
labels stdprometheus.Labels
|
||||
collector stdprometheus.Collector
|
||||
delete func()
|
||||
}
|
||||
|
||||
func buildMetricID(metricName string, labels stdprometheus.Labels) string {
|
||||
var labelNamesValues []string
|
||||
for name, value := range labels {
|
||||
labelNamesValues = append(labelNamesValues, name, value)
|
||||
}
|
||||
sort.Strings(labelNamesValues)
|
||||
return metricName + ":" + strings.Join(labelNamesValues, "|")
|
||||
}
|
||||
|
||||
func newCounterFrom(collectors chan<- *collector, opts stdprometheus.CounterOpts, labelNames []string) *counter {
|
||||
cv := stdprometheus.NewCounterVec(opts, labelNames)
|
||||
c := &counter{
|
||||
name: opts.Name,
|
||||
cv: cv,
|
||||
collectors: collectors,
|
||||
}
|
||||
if len(labelNames) == 0 {
|
||||
c.Add(0)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
type counter struct {
|
||||
name string
|
||||
cv *stdprometheus.CounterVec
|
||||
labelNamesValues labelNamesValues
|
||||
collectors chan<- *collector
|
||||
}
|
||||
|
||||
func (c *counter) With(labelValues ...string) metrics.Counter {
|
||||
return &counter{
|
||||
name: c.name,
|
||||
cv: c.cv,
|
||||
labelNamesValues: c.labelNamesValues.With(labelValues...),
|
||||
collectors: c.collectors,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *counter) Add(delta float64) {
|
||||
labels := c.labelNamesValues.ToLabels()
|
||||
collector := c.cv.With(labels)
|
||||
collector.Add(delta)
|
||||
c.collectors <- newCollector(c.name, labels, collector, func() {
|
||||
c.cv.Delete(labels)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *counter) Describe(ch chan<- *stdprometheus.Desc) {
|
||||
c.cv.Describe(ch)
|
||||
}
|
||||
|
||||
func newGaugeFrom(collectors chan<- *collector, opts stdprometheus.GaugeOpts, labelNames []string) *gauge {
|
||||
gv := stdprometheus.NewGaugeVec(opts, labelNames)
|
||||
g := &gauge{
|
||||
name: opts.Name,
|
||||
gv: gv,
|
||||
collectors: collectors,
|
||||
}
|
||||
if len(labelNames) == 0 {
|
||||
g.Set(0)
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
type gauge struct {
|
||||
name string
|
||||
gv *stdprometheus.GaugeVec
|
||||
labelNamesValues labelNamesValues
|
||||
collectors chan<- *collector
|
||||
}
|
||||
|
||||
func (g *gauge) With(labelValues ...string) metrics.Gauge {
|
||||
return &gauge{
|
||||
name: g.name,
|
||||
gv: g.gv,
|
||||
labelNamesValues: g.labelNamesValues.With(labelValues...),
|
||||
collectors: g.collectors,
|
||||
}
|
||||
}
|
||||
|
||||
func (g *gauge) Add(delta float64) {
|
||||
labels := g.labelNamesValues.ToLabels()
|
||||
collector := g.gv.With(labels)
|
||||
collector.Add(delta)
|
||||
g.collectors <- newCollector(g.name, labels, collector, func() {
|
||||
g.gv.Delete(labels)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *gauge) Set(value float64) {
|
||||
labels := g.labelNamesValues.ToLabels()
|
||||
collector := g.gv.With(labels)
|
||||
collector.Set(value)
|
||||
g.collectors <- newCollector(g.name, labels, collector, func() {
|
||||
g.gv.Delete(labels)
|
||||
})
|
||||
}
|
||||
|
||||
func (g *gauge) Describe(ch chan<- *stdprometheus.Desc) {
|
||||
g.gv.Describe(ch)
|
||||
}
|
||||
|
||||
func newHistogramFrom(collectors chan<- *collector, opts stdprometheus.HistogramOpts, labelNames []string) *histogram {
|
||||
hv := stdprometheus.NewHistogramVec(opts, labelNames)
|
||||
return &histogram{
|
||||
name: opts.Name,
|
||||
hv: hv,
|
||||
collectors: collectors,
|
||||
}
|
||||
}
|
||||
|
||||
type histogram struct {
|
||||
name string
|
||||
hv *stdprometheus.HistogramVec
|
||||
labelNamesValues labelNamesValues
|
||||
collectors chan<- *collector
|
||||
}
|
||||
|
||||
func (h *histogram) With(labelValues ...string) metrics.Histogram {
|
||||
return &histogram{
|
||||
name: h.name,
|
||||
hv: h.hv,
|
||||
labelNamesValues: h.labelNamesValues.With(labelValues...),
|
||||
collectors: h.collectors,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *histogram) Observe(value float64) {
|
||||
labels := h.labelNamesValues.ToLabels()
|
||||
collector := h.hv.With(labels)
|
||||
collector.Observe(value)
|
||||
h.collectors <- newCollector(h.name, labels, collector, func() {
|
||||
h.hv.Delete(labels)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *histogram) Describe(ch chan<- *stdprometheus.Desc) {
|
||||
h.hv.Describe(ch)
|
||||
}
|
||||
|
||||
// labelNamesValues is a type alias that provides validation on its With method.
|
||||
// Metrics may include it as a member to help them satisfy With semantics and
|
||||
// save some code duplication.
|
||||
type labelNamesValues []string
|
||||
|
||||
// With validates the input, and returns a new aggregate labelNamesValues.
|
||||
func (lvs labelNamesValues) With(labelValues ...string) labelNamesValues {
|
||||
if len(labelValues)%2 != 0 {
|
||||
labelValues = append(labelValues, "unknown")
|
||||
}
|
||||
return append(lvs, labelValues...)
|
||||
}
|
||||
|
||||
// ToLabels is a convenience method to convert a labelNamesValues
|
||||
// to the native prometheus.Labels.
|
||||
func (lvs labelNamesValues) ToLabels() stdprometheus.Labels {
|
||||
labels := stdprometheus.Labels{}
|
||||
for i := 0; i < len(lvs); i += 2 {
|
||||
labels[lvs[i]] = lvs[i+1]
|
||||
}
|
||||
return labels
|
||||
}
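// Example (illustrative sketch, label values assumed for illustration):
// labelNamesValues accumulates name/value pairs across successive With calls
// and converts them to prometheus.Labels when the metric is finally updated.
func exampleLabelNamesValues() stdprometheus.Labels {
	var lvs labelNamesValues
	lvs = lvs.With("backend", "backend1").With("code", "200")
	return lvs.ToLabels() // {"backend": "backend1", "code": "200"}
}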
|
504
pkg/metrics/prometheus_test.go
Normal file
|
@ -0,0 +1,504 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
th "github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRegisterPromState(t *testing.T) {
|
||||
// Reset state of global promState.
|
||||
defer promState.reset()
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
prometheusSlice []*types.Prometheus
|
||||
initPromState bool
|
||||
unregisterPromState bool
|
||||
expectedNbRegistries int
|
||||
}{
|
||||
{
|
||||
desc: "Register once",
|
||||
prometheusSlice: []*types.Prometheus{{}},
|
||||
expectedNbRegistries: 1,
|
||||
initPromState: true,
|
||||
},
|
||||
{
|
||||
desc: "Register once with no promState init",
|
||||
prometheusSlice: []*types.Prometheus{{}},
|
||||
expectedNbRegistries: 0,
|
||||
},
|
||||
{
|
||||
desc: "Register twice",
|
||||
prometheusSlice: []*types.Prometheus{{}, {}},
|
||||
expectedNbRegistries: 2,
|
||||
initPromState: true,
|
||||
},
|
||||
{
|
||||
desc: "Register twice with no promstate init",
|
||||
prometheusSlice: []*types.Prometheus{{}, {}},
|
||||
expectedNbRegistries: 0,
|
||||
},
|
||||
{
|
||||
desc: "Register twice with unregister",
|
||||
prometheusSlice: []*types.Prometheus{{}, {}},
|
||||
unregisterPromState: true,
|
||||
expectedNbRegistries: 2,
|
||||
initPromState: true,
|
||||
},
|
||||
{
|
||||
desc: "Register twice with unregister but no promstate init",
|
||||
prometheusSlice: []*types.Prometheus{{}, {}},
|
||||
unregisterPromState: true,
|
||||
expectedNbRegistries: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
actualNbRegistries := 0
|
||||
for _, prom := range test.prometheusSlice {
|
||||
if test.initPromState {
|
||||
initStandardRegistry(prom)
|
||||
}
|
||||
|
||||
if registerPromState(context.Background()) {
|
||||
actualNbRegistries++
|
||||
}
|
||||
|
||||
if test.unregisterPromState {
|
||||
prometheus.Unregister(promState)
|
||||
}
|
||||
|
||||
promState.reset()
|
||||
}
|
||||
|
||||
prometheus.Unregister(promState)
|
||||
|
||||
assert.Equal(t, test.expectedNbRegistries, actualNbRegistries)
|
||||
}
|
||||
}
|
||||
|
||||
// reset is a utility method for unit testing. It should be called after each
|
||||
// test run that changes promState internally in order to avoid dependencies
|
||||
// between unit tests.
|
||||
func (ps *prometheusState) reset() {
|
||||
ps.collectors = make(chan *collector)
|
||||
ps.describers = []func(ch chan<- *prometheus.Desc){}
|
||||
ps.dynamicConfig = newDynamicConfig()
|
||||
ps.state = make(map[string]*collector)
|
||||
}
|
||||
|
||||
func TestPrometheus(t *testing.T) {
|
||||
// Reset state of global promState.
|
||||
defer promState.reset()
|
||||
|
||||
prometheusRegistry := RegisterPrometheus(context.Background(), &types.Prometheus{})
|
||||
defer prometheus.Unregister(promState)
|
||||
|
||||
if !prometheusRegistry.IsEnabled() {
|
||||
t.Errorf("PrometheusRegistry should return true for IsEnabled()")
|
||||
}
|
||||
|
||||
prometheusRegistry.ConfigReloadsCounter().Add(1)
|
||||
prometheusRegistry.ConfigReloadsFailureCounter().Add(1)
|
||||
prometheusRegistry.LastConfigReloadSuccessGauge().Set(float64(time.Now().Unix()))
|
||||
prometheusRegistry.LastConfigReloadFailureGauge().Set(float64(time.Now().Unix()))
|
||||
|
||||
prometheusRegistry.
|
||||
EntrypointReqsCounter().
|
||||
With("code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http", "entrypoint", "http").
|
||||
Add(1)
|
||||
prometheusRegistry.
|
||||
EntrypointReqDurationHistogram().
|
||||
With("code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http", "entrypoint", "http").
|
||||
Observe(1)
|
||||
prometheusRegistry.
|
||||
EntrypointOpenConnsGauge().
|
||||
With("method", http.MethodGet, "protocol", "http", "entrypoint", "http").
|
||||
Set(1)
|
||||
|
||||
prometheusRegistry.
|
||||
BackendReqsCounter().
|
||||
With("backend", "backend1", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
|
||||
Add(1)
|
||||
prometheusRegistry.
|
||||
BackendReqDurationHistogram().
|
||||
With("backend", "backend1", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
|
||||
Observe(10000)
|
||||
prometheusRegistry.
|
||||
BackendOpenConnsGauge().
|
||||
With("backend", "backend1", "method", http.MethodGet, "protocol", "http").
|
||||
Set(1)
|
||||
prometheusRegistry.
|
||||
BackendRetriesCounter().
|
||||
With("backend", "backend1").
|
||||
Add(1)
|
||||
prometheusRegistry.
|
||||
BackendServerUpGauge().
|
||||
With("backend", "backend1", "url", "http://127.0.0.10:80").
|
||||
Set(1)
|
||||
|
||||
delayForTrackingCompletion()
|
||||
|
||||
metricsFamilies := mustScrape()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
labels map[string]string
|
||||
assert func(*dto.MetricFamily)
|
||||
}{
|
||||
{
|
||||
name: configReloadsTotalName,
|
||||
assert: buildCounterAssert(t, configReloadsTotalName, 1),
|
||||
},
|
||||
{
|
||||
name: configReloadsFailuresTotalName,
|
||||
assert: buildCounterAssert(t, configReloadsFailuresTotalName, 1),
|
||||
},
|
||||
{
|
||||
name: configLastReloadSuccessName,
|
||||
assert: buildTimestampAssert(t, configLastReloadSuccessName),
|
||||
},
|
||||
{
|
||||
name: configLastReloadFailureName,
|
||||
assert: buildTimestampAssert(t, configLastReloadFailureName),
|
||||
},
|
||||
{
|
||||
name: entrypointReqsTotalName,
|
||||
labels: map[string]string{
|
||||
"code": "200",
|
||||
"method": http.MethodGet,
|
||||
"protocol": "http",
|
||||
"entrypoint": "http",
|
||||
},
|
||||
assert: buildCounterAssert(t, entrypointReqsTotalName, 1),
|
||||
},
|
||||
{
|
||||
name: entrypointReqDurationName,
|
||||
labels: map[string]string{
|
||||
"code": "200",
|
||||
"method": http.MethodGet,
|
||||
"protocol": "http",
|
||||
"entrypoint": "http",
|
||||
},
|
||||
assert: buildHistogramAssert(t, entrypointReqDurationName, 1),
|
||||
},
|
||||
{
|
||||
name: entrypointOpenConnsName,
|
||||
labels: map[string]string{
|
||||
"method": http.MethodGet,
|
||||
"protocol": "http",
|
||||
"entrypoint": "http",
|
||||
},
|
||||
assert: buildGaugeAssert(t, entrypointOpenConnsName, 1),
|
||||
},
|
||||
{
|
||||
name: backendReqsTotalName,
|
||||
labels: map[string]string{
|
||||
"code": "200",
|
||||
"method": http.MethodGet,
|
||||
"protocol": "http",
|
||||
"backend": "backend1",
|
||||
},
|
||||
assert: buildCounterAssert(t, backendReqsTotalName, 1),
|
||||
},
|
||||
{
|
||||
name: backendReqDurationName,
|
||||
labels: map[string]string{
|
||||
"code": "200",
|
||||
"method": http.MethodGet,
|
||||
"protocol": "http",
|
||||
"backend": "backend1",
|
||||
},
|
||||
assert: buildHistogramAssert(t, backendReqDurationName, 1),
|
||||
},
|
||||
{
|
||||
name: backendOpenConnsName,
|
||||
labels: map[string]string{
|
||||
"method": http.MethodGet,
|
||||
"protocol": "http",
|
||||
"backend": "backend1",
|
||||
},
|
||||
assert: buildGaugeAssert(t, backendOpenConnsName, 1),
|
||||
},
|
||||
{
|
||||
name: backendRetriesTotalName,
|
||||
labels: map[string]string{
|
||||
"backend": "backend1",
|
||||
},
|
||||
assert: buildGreaterThanCounterAssert(t, backendRetriesTotalName, 1),
|
||||
},
|
||||
{
|
||||
name: backendServerUpName,
|
||||
labels: map[string]string{
|
||||
"backend": "backend1",
|
||||
"url": "http://127.0.0.10:80",
|
||||
},
|
||||
assert: buildGaugeAssert(t, backendServerUpName, 1),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
family := findMetricFamily(test.name, metricsFamilies)
|
||||
if family == nil {
|
||||
t.Errorf("gathered metrics do not contain %q", test.name)
|
||||
continue
|
||||
}
|
||||
for _, label := range family.Metric[0].Label {
|
||||
val, ok := test.labels[*label.Name]
|
||||
if !ok {
|
||||
t.Errorf("%q metric contains unexpected label %q", test.name, *label.Name)
|
||||
} else if val != *label.Value {
|
||||
t.Errorf("label %q in metric %q has wrong value %q, expected %q", *label.Name, test.name, *label.Value, val)
|
||||
}
|
||||
}
|
||||
test.assert(family)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrometheusMetricRemoval(t *testing.T) {
|
||||
// FIXME metrics
|
||||
t.Skip("waiting for metrics")
|
||||
|
||||
// Reset state of global promState.
|
||||
defer promState.reset()
|
||||
|
||||
prometheusRegistry := RegisterPrometheus(context.Background(), &types.Prometheus{})
|
||||
defer prometheus.Unregister(promState)
|
||||
|
||||
configurations := make(config.Configurations)
|
||||
configurations["providerName"] = &config.Configuration{
|
||||
HTTP: th.BuildConfiguration(
|
||||
th.WithRouters(
|
||||
th.WithRouter("foo",
|
||||
th.WithServiceName("bar")),
|
||||
),
|
||||
th.WithLoadBalancerServices(th.WithService("bar",
|
||||
th.WithLBMethod("wrr"),
|
||||
th.WithServers(th.WithServer("http://localhost:9000"))),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
OnConfigurationUpdate(configurations)
|
||||
|
||||
// Register some metrics manually that are not part of the active configuration.
|
||||
// Those metrics should be part of the /metrics output on the first scrape but
|
||||
// should be removed after that scrape.
|
||||
prometheusRegistry.
|
||||
EntrypointReqsCounter().
|
||||
With("entrypoint", "entrypoint2", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
|
||||
Add(1)
|
||||
prometheusRegistry.
|
||||
BackendReqsCounter().
|
||||
With("backend", "backend2", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
|
||||
Add(1)
|
||||
prometheusRegistry.
|
||||
BackendServerUpGauge().
|
||||
With("backend", "backend1", "url", "http://localhost:9999").
|
||||
Set(1)
|
||||
|
||||
delayForTrackingCompletion()
|
||||
|
||||
assertMetricsExist(t, mustScrape(), entrypointReqsTotalName, backendReqsTotalName, backendServerUpName)
|
||||
assertMetricsAbsent(t, mustScrape(), entrypointReqsTotalName, backendReqsTotalName, backendServerUpName)
|
||||
|
||||
// To verify that metrics belonging to active configurations are not removed
|
||||
// here are the counter examples.
|
||||
prometheusRegistry.
|
||||
EntrypointReqsCounter().
|
||||
With("entrypoint", "entrypoint1", "code", strconv.Itoa(http.StatusOK), "method", http.MethodGet, "protocol", "http").
|
||||
Add(1)
|
||||
|
||||
delayForTrackingCompletion()
|
||||
|
||||
assertMetricsExist(t, mustScrape(), entrypointReqsTotalName)
|
||||
assertMetricsExist(t, mustScrape(), entrypointReqsTotalName)
|
||||
}
|
||||
|
||||
func TestPrometheusRemovedMetricsReset(t *testing.T) {
|
||||
// Reset state of global promState.
|
||||
defer promState.reset()
|
||||
|
||||
prometheusRegistry := RegisterPrometheus(context.Background(), &types.Prometheus{})
|
||||
defer prometheus.Unregister(promState)
|
||||
|
||||
labelNamesValues := []string{
|
||||
"backend", "backend",
|
||||
"code", strconv.Itoa(http.StatusOK),
|
||||
"method", http.MethodGet,
|
||||
"protocol", "http",
|
||||
}
|
||||
prometheusRegistry.
|
||||
BackendReqsCounter().
|
||||
With(labelNamesValues...).
|
||||
Add(3)
|
||||
|
||||
delayForTrackingCompletion()
|
||||
|
||||
metricsFamilies := mustScrape()
|
||||
assertCounterValue(t, 3, findMetricFamily(backendReqsTotalName, metricsFamilies), labelNamesValues...)
|
||||
|
||||
// There is no dynamic configuration and so this metric will be deleted
|
||||
// after the first scrape.
|
||||
assertMetricsAbsent(t, mustScrape(), backendReqsTotalName)
|
||||
|
||||
prometheusRegistry.
|
||||
BackendReqsCounter().
|
||||
With(labelNamesValues...).
|
||||
Add(1)
|
||||
|
||||
delayForTrackingCompletion()
|
||||
|
||||
metricsFamilies = mustScrape()
|
||||
assertCounterValue(t, 1, findMetricFamily(backendReqsTotalName, metricsFamilies), labelNamesValues...)
|
||||
}
|
||||
|
||||
// Tracking and gathering the metrics happens concurrently.
|
||||
// In practice this is no problem, because if a tracked metric misses
|
||||
// the current scrape, it will simply show up in the next one.
|
||||
// So that we can reliably test the tracking of all metrics here, we sleep
|
||||
// for a short amount of time, to make sure the metric will be present
|
||||
// in the next scrape.
|
||||
func delayForTrackingCompletion() {
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
}
|
||||
|
||||
func mustScrape() []*dto.MetricFamily {
|
||||
families, err := prometheus.DefaultGatherer.Gather()
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("could not gather metrics families: %s", err))
|
||||
}
|
||||
return families
|
||||
}
|
||||
|
||||
func assertMetricsExist(t *testing.T, families []*dto.MetricFamily, metricNames ...string) {
|
||||
t.Helper()
|
||||
|
||||
for _, metricName := range metricNames {
|
||||
if findMetricFamily(metricName, families) == nil {
|
||||
t.Errorf("gathered metrics should contain %q", metricName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertMetricsAbsent(t *testing.T, families []*dto.MetricFamily, metricNames ...string) {
|
||||
t.Helper()
|
||||
|
||||
for _, metricName := range metricNames {
|
||||
if findMetricFamily(metricName, families) != nil {
|
||||
t.Errorf("gathered metrics should not contain %q", metricName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func findMetricFamily(name string, families []*dto.MetricFamily) *dto.MetricFamily {
|
||||
for _, family := range families {
|
||||
if family.GetName() == name {
|
||||
return family
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findMetricByLabelNamesValues(family *dto.MetricFamily, labelNamesValues ...string) *dto.Metric {
|
||||
if family == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, metric := range family.Metric {
|
||||
if hasMetricAllLabelPairs(metric, labelNamesValues...) {
|
||||
return metric
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasMetricAllLabelPairs(metric *dto.Metric, labelNamesValues ...string) bool {
|
||||
for i := 0; i < len(labelNamesValues); i += 2 {
|
||||
name, val := labelNamesValues[i], labelNamesValues[i+1]
|
||||
if !hasMetricLabelPair(metric, name, val) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func hasMetricLabelPair(metric *dto.Metric, labelName, labelValue string) bool {
|
||||
for _, labelPair := range metric.Label {
|
||||
if labelPair.GetName() == labelName && labelPair.GetValue() == labelValue {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func assertCounterValue(t *testing.T, want float64, family *dto.MetricFamily, labelNamesValues ...string) {
|
||||
t.Helper()
|
||||
|
||||
metric := findMetricByLabelNamesValues(family, labelNamesValues...)
|
||||
|
||||
if metric == nil {
|
||||
t.Error("metric must not be nil")
|
||||
return
|
||||
}
|
||||
if metric.Counter == nil {
|
||||
t.Errorf("metric %s must be a counter", family.GetName())
|
||||
return
|
||||
}
|
||||
|
||||
if cv := metric.Counter.GetValue(); cv != want {
|
||||
t.Errorf("metric %s has value %v, want %v", family.GetName(), cv, want)
|
||||
}
|
||||
}
|
||||
|
||||
func buildCounterAssert(t *testing.T, metricName string, expectedValue int) func(family *dto.MetricFamily) {
|
||||
return func(family *dto.MetricFamily) {
|
||||
if cv := int(family.Metric[0].Counter.GetValue()); cv != expectedValue {
|
||||
t.Errorf("metric %s has value %d, want %d", metricName, cv, expectedValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildGreaterThanCounterAssert(t *testing.T, metricName string, expectedMinValue int) func(family *dto.MetricFamily) {
|
||||
return func(family *dto.MetricFamily) {
|
||||
if cv := int(family.Metric[0].Counter.GetValue()); cv < expectedMinValue {
|
||||
t.Errorf("metric %s has value %d, want at least %d", metricName, cv, expectedMinValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildHistogramAssert(t *testing.T, metricName string, expectedSampleCount int) func(family *dto.MetricFamily) {
|
||||
return func(family *dto.MetricFamily) {
|
||||
if sc := int(family.Metric[0].Histogram.GetSampleCount()); sc != expectedSampleCount {
|
||||
t.Errorf("metric %s has sample count value %d, want %d", metricName, sc, expectedSampleCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildGaugeAssert(t *testing.T, metricName string, expectedValue int) func(family *dto.MetricFamily) {
|
||||
return func(family *dto.MetricFamily) {
|
||||
if gv := int(family.Metric[0].Gauge.GetValue()); gv != expectedValue {
|
||||
t.Errorf("metric %s has value %d, want %d", metricName, gv, expectedValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildTimestampAssert(t *testing.T, metricName string) func(family *dto.MetricFamily) {
|
||||
return func(family *dto.MetricFamily) {
|
||||
if ts := time.Unix(int64(family.Metric[0].Gauge.GetValue()), 0); time.Since(ts) > time.Minute {
|
||||
t.Errorf("metric %s has wrong timestamp %v", metricName, ts)
|
||||
}
|
||||
}
|
||||
}
|
86
pkg/metrics/statsd.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/metrics/statsd"
|
||||
)
|
||||
|
||||
var statsdClient = statsd.New("traefik.", kitlog.LoggerFunc(func(keyvals ...interface{}) error {
|
||||
log.WithoutContext().WithField(log.MetricsProviderName, "statsd").Info(keyvals)
|
||||
return nil
|
||||
}))
|
||||
|
||||
var statsdTicker *time.Ticker
|
||||
|
||||
const (
|
||||
statsdMetricsBackendReqsName = "backend.request.total"
|
||||
statsdMetricsBackendLatencyName = "backend.request.duration"
|
||||
statsdRetriesTotalName = "backend.retries.total"
|
||||
statsdConfigReloadsName = "config.reload.total"
|
||||
statsdConfigReloadsFailureName = statsdConfigReloadsName + ".failure"
|
||||
statsdLastConfigReloadSuccessName = "config.reload.lastSuccessTimestamp"
|
||||
statsdLastConfigReloadFailureName = "config.reload.lastFailureTimestamp"
|
||||
statsdEntrypointReqsName = "entrypoint.request.total"
|
||||
statsdEntrypointReqDurationName = "entrypoint.request.duration"
|
||||
statsdEntrypointOpenConnsName = "entrypoint.connections.open"
|
||||
statsdOpenConnsName = "backend.connections.open"
|
||||
statsdServerUpName = "backend.server.up"
|
||||
)
|
||||
|
||||
// RegisterStatsd registers the metrics pusher if this has not happened yet and creates a statsd Registry instance.
|
||||
func RegisterStatsd(ctx context.Context, config *types.Statsd) Registry {
|
||||
if statsdTicker == nil {
|
||||
statsdTicker = initStatsdTicker(ctx, config)
|
||||
}
|
||||
|
||||
return &standardRegistry{
|
||||
enabled: true,
|
||||
configReloadsCounter: statsdClient.NewCounter(statsdConfigReloadsName, 1.0),
|
||||
configReloadsFailureCounter: statsdClient.NewCounter(statsdConfigReloadsFailureName, 1.0),
|
||||
lastConfigReloadSuccessGauge: statsdClient.NewGauge(statsdLastConfigReloadSuccessName),
|
||||
lastConfigReloadFailureGauge: statsdClient.NewGauge(statsdLastConfigReloadFailureName),
|
||||
entrypointReqsCounter: statsdClient.NewCounter(statsdEntrypointReqsName, 1.0),
|
||||
entrypointReqDurationHistogram: statsdClient.NewTiming(statsdEntrypointReqDurationName, 1.0),
|
||||
entrypointOpenConnsGauge: statsdClient.NewGauge(statsdEntrypointOpenConnsName),
|
||||
backendReqsCounter: statsdClient.NewCounter(statsdMetricsBackendReqsName, 1.0),
|
||||
backendReqDurationHistogram: statsdClient.NewTiming(statsdMetricsBackendLatencyName, 1.0),
|
||||
backendRetriesCounter: statsdClient.NewCounter(statsdRetriesTotalName, 1.0),
|
||||
backendOpenConnsGauge: statsdClient.NewGauge(statsdOpenConnsName),
|
||||
backendServerUpGauge: statsdClient.NewGauge(statsdServerUpName),
|
||||
}
|
||||
}
|
||||
|
||||
// initStatsdTicker initializes metrics pusher and creates a statsdClient if not created already
|
||||
func initStatsdTicker(ctx context.Context, config *types.Statsd) *time.Ticker {
|
||||
address := config.Address
|
||||
if len(address) == 0 {
|
||||
address = "localhost:8125"
|
||||
}
|
||||
pushInterval, err := time.ParseDuration(config.PushInterval)
|
||||
if err != nil {
|
||||
log.FromContext(ctx).Warnf("Unable to parse %s from config.PushInterval: using 10s as the default value", config.PushInterval)
|
||||
pushInterval = 10 * time.Second
|
||||
}
|
||||
|
||||
report := time.NewTicker(pushInterval)
|
||||
|
||||
safe.Go(func() {
|
||||
statsdClient.SendLoop(report.C, "udp", address)
|
||||
})
|
||||
|
||||
return report
|
||||
}
|
||||
|
||||
// StopStatsd stops the internal statsdTicker which controls the pushing of metrics to the StatsD agent, and resets it to `nil`.
|
||||
func StopStatsd() {
|
||||
if statsdTicker != nil {
|
||||
statsdTicker.Stop()
|
||||
}
|
||||
statsdTicker = nil
|
||||
}
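// Example (illustrative sketch): registering the statsd registry and pushing
// a request count. The address and push interval shown are assumptions made
// for illustration.
func exampleStatsdUsage(ctx context.Context) {
	registry := RegisterStatsd(ctx, &types.Statsd{Address: "localhost:8125", PushInterval: "10s"})
	defer StopStatsd()

	registry.BackendReqsCounter().With("code", "200", "method", "GET").Add(1)
}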
|
51
pkg/metrics/statsd_test.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
"github.com/stvp/go-udp-testing"
|
||||
)
|
||||
|
||||
func TestStatsD(t *testing.T) {
|
||||
udp.SetAddr(":18125")
|
||||
// This is needed to make sure that the UDP listener listens for data a bit longer; otherwise it will quit after a millisecond.
|
||||
udp.Timeout = 5 * time.Second
|
||||
|
||||
statsdRegistry := RegisterStatsd(context.Background(), &types.Statsd{Address: ":18125", PushInterval: "1s"})
|
||||
defer StopStatsd()
|
||||
|
||||
if !statsdRegistry.IsEnabled() {
|
||||
t.Errorf("Statsd registry should return true for IsEnabled()")
|
||||
}
|
||||
|
||||
expected := []string{
|
||||
// We are only validating counts, as it is nearly impossible to validate latency, since it varies every run
|
||||
"traefik.backend.request.total:2.000000|c\n",
|
||||
"traefik.backend.retries.total:2.000000|c\n",
|
||||
"traefik.backend.request.duration:10000.000000|ms",
|
||||
"traefik.config.reload.total:1.000000|c\n",
|
||||
"traefik.config.reload.total:1.000000|c\n",
|
||||
"traefik.entrypoint.request.total:1.000000|c\n",
|
||||
"traefik.entrypoint.request.duration:10000.000000|ms",
|
||||
"traefik.entrypoint.connections.open:1.000000|g\n",
|
||||
"traefik.backend.server.up:1.000000|g\n",
|
||||
}
|
||||
|
||||
udp.ShouldReceiveAll(t, expected, func() {
|
||||
statsdRegistry.BackendReqsCounter().With("service", "test", "code", string(http.StatusOK), "method", http.MethodGet).Add(1)
|
||||
statsdRegistry.BackendReqsCounter().With("service", "test", "code", string(http.StatusNotFound), "method", http.MethodGet).Add(1)
|
||||
statsdRegistry.BackendRetriesCounter().With("service", "test").Add(1)
|
||||
statsdRegistry.BackendRetriesCounter().With("service", "test").Add(1)
|
||||
statsdRegistry.BackendReqDurationHistogram().With("service", "test", "code", string(http.StatusOK)).Observe(10000)
|
||||
statsdRegistry.ConfigReloadsCounter().Add(1)
|
||||
statsdRegistry.ConfigReloadsFailureCounter().Add(1)
|
||||
statsdRegistry.EntrypointReqsCounter().With("entrypoint", "test").Add(1)
|
||||
statsdRegistry.EntrypointReqDurationHistogram().With("entrypoint", "test").Observe(10000)
|
||||
statsdRegistry.EntrypointOpenConnsGauge().With("entrypoint", "test").Set(1)
|
||||
statsdRegistry.BackendServerUpGauge().With("backend:test", "url", "http://127.0.0.1").Set(1)
|
||||
})
|
||||
}
|
18
pkg/middlewares/accesslog/capture_request_reader.go
Normal file
|
@ -0,0 +1,18 @@
|
|||
package accesslog
|
||||
|
||||
import "io"
|
||||
|
||||
type captureRequestReader struct {
|
||||
source io.ReadCloser
|
||||
count int64
|
||||
}
|
||||
|
||||
func (r *captureRequestReader) Read(p []byte) (int, error) {
|
||||
n, err := r.source.Read(p)
|
||||
r.count += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *captureRequestReader) Close() error {
|
||||
return r.source.Close()
|
||||
}
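// Example (illustrative sketch): wrapping an io.ReadCloser (for instance an
// HTTP request body) so the number of bytes read can be reported once the
// body has been consumed.
func exampleCaptureRequestSize(body io.ReadCloser) *captureRequestReader {
	crr := &captureRequestReader{source: body}
	// Use crr in place of the original body; crr.count then holds the bytes read.
	return crr
}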
|
68
pkg/middlewares/accesslog/capture_response_writer.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/old/middlewares"
|
||||
)
|
||||
|
||||
var (
|
||||
_ middlewares.Stateful = &captureResponseWriter{}
|
||||
)
|
||||
|
||||
// captureResponseWriter is a wrapper around http.ResponseWriter
|
||||
// that tracks the response status and size.
|
||||
type captureResponseWriter struct {
|
||||
rw http.ResponseWriter
|
||||
status int
|
||||
size int64
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Header() http.Header {
|
||||
return crw.rw.Header()
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Write(b []byte) (int, error) {
|
||||
if crw.status == 0 {
|
||||
crw.status = http.StatusOK
|
||||
}
|
||||
size, err := crw.rw.Write(b)
|
||||
crw.size += int64(size)
|
||||
return size, err
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) WriteHeader(s int) {
|
||||
crw.rw.WriteHeader(s)
|
||||
crw.status = s
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Flush() {
|
||||
if f, ok := crw.rw.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := crw.rw.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("not a hijacker: %T", crw.rw)
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) CloseNotify() <-chan bool {
|
||||
if c, ok := crw.rw.(http.CloseNotifier); ok {
|
||||
return c.CloseNotify()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Status() int {
|
||||
return crw.status
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Size() int64 {
|
||||
return crw.size
|
||||
}
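// Example (illustrative sketch): wrapping a ResponseWriter so the status and
// size of the response can be read back after the next handler has run.
func exampleCaptureResponse(rw http.ResponseWriter, req *http.Request, next http.Handler) (int, int64) {
	crw := &captureResponseWriter{rw: rw}
	next.ServeHTTP(crw, req)
	return crw.Status(), crw.Size()
}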
|
65
pkg/middlewares/accesslog/field_middleware.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// FieldApply is a function hook used to add data to the access log.
|
||||
type FieldApply func(rw http.ResponseWriter, r *http.Request, next http.Handler, data *LogData)
|
||||
|
||||
// FieldHandler sends a new field to the logger.
|
||||
type FieldHandler struct {
|
||||
next http.Handler
|
||||
name string
|
||||
value string
|
||||
applyFn FieldApply
|
||||
}
|
||||
|
||||
// NewFieldHandler creates a Field handler.
|
||||
func NewFieldHandler(next http.Handler, name string, value string, applyFn FieldApply) http.Handler {
|
||||
return &FieldHandler{next: next, name: name, value: value, applyFn: applyFn}
|
||||
}
|
||||
|
||||
func (f *FieldHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
table := GetLogData(req)
|
||||
if table == nil {
|
||||
f.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
table.Core[f.name] = f.value
|
||||
|
||||
if f.applyFn != nil {
|
||||
f.applyFn(rw, req, f.next, table)
|
||||
} else {
|
||||
f.next.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
// AddServiceFields adds the service fields.
|
||||
func AddServiceFields(rw http.ResponseWriter, req *http.Request, next http.Handler, data *LogData) {
|
||||
data.Core[ServiceURL] = req.URL // note that this is *not* the original incoming URL
|
||||
data.Core[ServiceAddr] = req.URL.Host
|
||||
|
||||
next.ServeHTTP(rw, req)
|
||||
|
||||
}
|
||||
|
||||
// AddOriginFields adds the origin fields.
|
||||
func AddOriginFields(rw http.ResponseWriter, req *http.Request, next http.Handler, data *LogData) {
|
||||
crw := &captureResponseWriter{rw: rw}
|
||||
start := time.Now().UTC()
|
||||
|
||||
next.ServeHTTP(crw, req)
|
||||
|
||||
// use UTC to handle daylight saving time switchover correctly
|
||||
data.Core[OriginDuration] = time.Now().UTC().Sub(start)
|
||||
data.Core[OriginStatus] = crw.Status()
|
||||
// make a copy of the headers so we can ensure there is no subsequent mutation during response processing
|
||||
data.OriginResponse = make(http.Header)
|
||||
utils.CopyHeaders(data.OriginResponse, crw.Header())
|
||||
data.Core[OriginContentSize] = crw.Size()
|
||||
}
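// Example (illustrative sketch): chaining field handlers so a static router
// name and the service fields end up in the access log. The "my-router" and
// "my-service" values are assumptions made for illustration.
func exampleFieldHandlers(next http.Handler) http.Handler {
	withService := NewFieldHandler(next, ServiceName, "my-service", AddServiceFields)
	return NewFieldHandler(withService, RouterName, "my-router", nil)
}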
|
122
pkg/middlewares/accesslog/logdata.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// StartUTC is the map key used for the time at which request processing started.
|
||||
StartUTC = "StartUTC"
|
||||
// StartLocal is the map key used for the local time at which request processing started.
|
||||
StartLocal = "StartLocal"
|
||||
// Duration is the map key used for the total time taken to process the request, including the origin server's time but
|
||||
// not the log writing time.
|
||||
Duration = "Duration"
|
||||
|
||||
// RouterName is the map key used for the name of the Traefik router.
|
||||
RouterName = "RouterName"
|
||||
// ServiceName is the map key used for the name of the Traefik backend.
|
||||
ServiceName = "ServiceName"
|
||||
// ServiceURL is the map key used for the URL of the Traefik backend.
|
||||
ServiceURL = "ServiceURL"
|
||||
// ServiceAddr is the map key used for the IP:port of the Traefik backend (extracted from ServiceURL).
|
||||
ServiceAddr = "ServiceAddr"
|
||||
|
||||
// ClientAddr is the map key used for the remote address in its original form (usually IP:port).
|
||||
ClientAddr = "ClientAddr"
|
||||
// ClientHost is the map key used for the remote IP address from which the client request was received.
|
||||
ClientHost = "ClientHost"
|
||||
// ClientPort is the map key used for the remote TCP port from which the client request was received.
|
||||
ClientPort = "ClientPort"
|
||||
// ClientUsername is the map key used for the username provided in the URL, if present.
|
||||
ClientUsername = "ClientUsername"
|
||||
// RequestAddr is the map key used for the HTTP Host header (usually IP:port). The Go API does not treat this as a header.
|
||||
RequestAddr = "RequestAddr"
|
||||
// RequestHost is the map key used for the HTTP Host server name (not including port).
|
||||
RequestHost = "RequestHost"
|
||||
// RequestPort is the map key used for the TCP port from the HTTP Host.
|
||||
RequestPort = "RequestPort"
|
||||
// RequestMethod is the map key used for the HTTP method.
|
||||
RequestMethod = "RequestMethod"
|
||||
// RequestPath is the map key used for the HTTP request URI, not including the scheme, host or port.
|
||||
RequestPath = "RequestPath"
|
||||
// RequestProtocol is the map key used for the version of HTTP requested.
|
||||
RequestProtocol = "RequestProtocol"
|
||||
// RequestContentSize is the map key used for the number of bytes in the request entity (a.k.a. body) sent by the client.
|
||||
RequestContentSize = "RequestContentSize"
|
||||
// RequestRefererHeader is the Referer header in the request
|
||||
RequestRefererHeader = "request_Referer"
|
||||
// RequestUserAgentHeader is the User-Agent header in the request
|
||||
RequestUserAgentHeader = "request_User-Agent"
|
||||
// OriginDuration is the map key used for the time taken by the origin server ('upstream') to return its response.
|
||||
OriginDuration = "OriginDuration"
|
||||
// OriginContentSize is the map key used for the content length specified by the origin server, or 0 if unspecified.
|
||||
OriginContentSize = "OriginContentSize"
|
||||
// OriginStatus is the map key used for the HTTP status code returned by the origin server.
|
||||
// If the request was handled by this Traefik instance (e.g. with a redirect), then this value will be absent.
|
||||
OriginStatus = "OriginStatus"
|
||||
// DownstreamStatus is the map key used for the HTTP status code returned to the client.
|
||||
DownstreamStatus = "DownstreamStatus"
|
||||
// DownstreamContentSize is the map key used for the number of bytes in the response entity returned to the client.
|
||||
// This is in addition to the "Content-Length" header, which may be present in the origin response.
|
||||
DownstreamContentSize = "DownstreamContentSize"
|
||||
// RequestCount is the map key used for the number of requests received since the Traefik instance started.
|
||||
RequestCount = "RequestCount"
|
||||
// GzipRatio is the map key used for the response body compression ratio achieved.
|
||||
GzipRatio = "GzipRatio"
|
||||
// Overhead is the map key used for the processing time overhead caused by Traefik.
|
||||
Overhead = "Overhead"
|
||||
// RetryAttempts is the map key used for the number of times the request was retried.
|
||||
RetryAttempts = "RetryAttempts"
|
||||
)
|
||||
|
||||
// These are written out in the default case when no config is provided to specify keys of interest.
|
||||
var defaultCoreKeys = [...]string{
|
||||
StartUTC,
|
||||
Duration,
|
||||
RouterName,
|
||||
ServiceName,
|
||||
ServiceURL,
|
||||
ClientHost,
|
||||
ClientPort,
|
||||
ClientUsername,
|
||||
RequestHost,
|
||||
RequestPort,
|
||||
RequestMethod,
|
||||
RequestPath,
|
||||
RequestProtocol,
|
||||
RequestContentSize,
|
||||
OriginDuration,
|
||||
OriginContentSize,
|
||||
OriginStatus,
|
||||
DownstreamStatus,
|
||||
DownstreamContentSize,
|
||||
RequestCount,
|
||||
}
|
||||
|
||||
// This contains the set of all keys, i.e. all the default keys plus all non-default keys.
|
||||
var allCoreKeys = make(map[string]struct{})
|
||||
|
||||
func init() {
|
||||
for _, k := range defaultCoreKeys {
|
||||
allCoreKeys[k] = struct{}{}
|
||||
}
|
||||
allCoreKeys[ServiceAddr] = struct{}{}
|
||||
allCoreKeys[ClientAddr] = struct{}{}
|
||||
allCoreKeys[RequestAddr] = struct{}{}
|
||||
allCoreKeys[GzipRatio] = struct{}{}
|
||||
allCoreKeys[StartLocal] = struct{}{}
|
||||
allCoreKeys[Overhead] = struct{}{}
|
||||
allCoreKeys[RetryAttempts] = struct{}{}
|
||||
}
|
||||
|
||||
// CoreLogData holds the fields computed from the request/response.
|
||||
type CoreLogData map[string]interface{}
|
||||
|
||||
// LogData is the data captured by the middleware so that it can be logged.
|
||||
type LogData struct {
|
||||
Core CoreLogData
|
||||
Request http.Header
|
||||
OriginResponse http.Header
|
||||
DownstreamResponse http.Header
|
||||
}
|
344
pkg/middlewares/accesslog/logger.go
Normal file
@@ -0,0 +1,344 @@
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type key string
|
||||
|
||||
const (
|
||||
// DataTableKey is the key within the request context used to store the Log Data Table.
|
||||
DataTableKey key = "LogDataTable"
|
||||
|
||||
// CommonFormat is the common logging format (CLF).
|
||||
CommonFormat string = "common"
|
||||
|
||||
// JSONFormat is the JSON logging format.
|
||||
JSONFormat string = "json"
|
||||
)
|
||||
|
||||
type handlerParams struct {
|
||||
logDataTable *LogData
|
||||
crr *captureRequestReader
|
||||
crw *captureResponseWriter
|
||||
}
|
||||
|
||||
// Handler will write each request and its response to the access log.
|
||||
type Handler struct {
|
||||
config *types.AccessLog
|
||||
logger *logrus.Logger
|
||||
file *os.File
|
||||
mu sync.Mutex
|
||||
httpCodeRanges types.HTTPCodeRanges
|
||||
logHandlerChan chan handlerParams
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// WrapHandler wraps the access log handler into an Alice constructor.
|
||||
func WrapHandler(handler *Handler) alice.Constructor {
|
||||
return func(next http.Handler) (http.Handler, error) {
|
||||
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
handler.ServeHTTP(rw, req, next.ServeHTTP)
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewHandler creates a new Handler.
|
||||
func NewHandler(config *types.AccessLog) (*Handler, error) {
|
||||
file := os.Stdout
|
||||
if len(config.FilePath) > 0 {
|
||||
f, err := openAccessLogFile(config.FilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening access log file: %s", err)
|
||||
}
|
||||
file = f
|
||||
}
|
||||
logHandlerChan := make(chan handlerParams, config.BufferingSize)
|
||||
|
||||
var formatter logrus.Formatter
|
||||
|
||||
switch config.Format {
|
||||
case CommonFormat:
|
||||
formatter = new(CommonLogFormatter)
|
||||
case JSONFormat:
|
||||
formatter = new(logrus.JSONFormatter)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported access log format: %s", config.Format)
|
||||
}
|
||||
|
||||
logger := &logrus.Logger{
|
||||
Out: file,
|
||||
Formatter: formatter,
|
||||
Hooks: make(logrus.LevelHooks),
|
||||
Level: logrus.InfoLevel,
|
||||
}
|
||||
|
||||
logHandler := &Handler{
|
||||
config: config,
|
||||
logger: logger,
|
||||
file: file,
|
||||
logHandlerChan: logHandlerChan,
|
||||
}
|
||||
|
||||
if config.Filters != nil {
|
||||
if httpCodeRanges, err := types.NewHTTPCodeRanges(config.Filters.StatusCodes); err != nil {
|
||||
log.WithoutContext().Errorf("Failed to create new HTTP code ranges: %s", err)
|
||||
} else {
|
||||
logHandler.httpCodeRanges = httpCodeRanges
|
||||
}
|
||||
}
|
||||
|
||||
if config.BufferingSize > 0 {
|
||||
logHandler.wg.Add(1)
|
||||
go func() {
|
||||
defer logHandler.wg.Done()
|
||||
for handlerParams := range logHandler.logHandlerChan {
|
||||
logHandler.logTheRoundTrip(handlerParams.logDataTable, handlerParams.crr, handlerParams.crw)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return logHandler, nil
|
||||
}
|
||||
|
||||
func openAccessLogFile(filePath string) (*os.File, error) {
|
||||
dir := filepath.Dir(filePath)
|
||||
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create log path %s: %s", dir, err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening file %s: %s", filePath, err)
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// GetLogData gets the request context object that contains logging data.
|
||||
// The data is populated as the request passes through the middleware chain.
|
||||
func GetLogData(req *http.Request) *LogData {
|
||||
if ld, ok := req.Context().Value(DataTableKey).(*LogData); ok {
|
||||
return ld
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
core := CoreLogData{
|
||||
StartUTC: now,
|
||||
StartLocal: now.Local(),
|
||||
}
|
||||
|
||||
logDataTable := &LogData{Core: core, Request: req.Header}
|
||||
|
||||
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
|
||||
|
||||
var crr *captureRequestReader
|
||||
if req.Body != nil {
|
||||
crr = &captureRequestReader{source: req.Body, count: 0}
|
||||
reqWithDataTable.Body = crr
|
||||
}
|
||||
|
||||
core[RequestCount] = nextRequestCount()
|
||||
if req.Host != "" {
|
||||
core[RequestAddr] = req.Host
|
||||
core[RequestHost], core[RequestPort] = silentSplitHostPort(req.Host)
|
||||
}
|
||||
// copy the URL without the scheme, hostname etc
|
||||
urlCopy := &url.URL{
|
||||
Path: req.URL.Path,
|
||||
RawPath: req.URL.RawPath,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
ForceQuery: req.URL.ForceQuery,
|
||||
Fragment: req.URL.Fragment,
|
||||
}
|
||||
urlCopyString := urlCopy.String()
|
||||
core[RequestMethod] = req.Method
|
||||
core[RequestPath] = urlCopyString
|
||||
core[RequestProtocol] = req.Proto
|
||||
|
||||
core[ClientAddr] = req.RemoteAddr
|
||||
core[ClientHost], core[ClientPort] = silentSplitHostPort(req.RemoteAddr)
|
||||
|
||||
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
|
||||
core[ClientHost] = forwardedFor
|
||||
}
|
||||
|
||||
crw := &captureResponseWriter{rw: rw}
|
||||
|
||||
next.ServeHTTP(crw, reqWithDataTable)
|
||||
|
||||
if _, ok := core[ClientUsername]; !ok {
|
||||
core[ClientUsername] = usernameIfPresent(reqWithDataTable.URL)
|
||||
}
|
||||
|
||||
logDataTable.DownstreamResponse = crw.Header()
|
||||
|
||||
if h.config.BufferingSize > 0 {
|
||||
h.logHandlerChan <- handlerParams{
|
||||
logDataTable: logDataTable,
|
||||
crr: crr,
|
||||
crw: crw,
|
||||
}
|
||||
} else {
|
||||
h.logTheRoundTrip(logDataTable, crr, crw)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the Logger (i.e. closes the file, drains logHandlerChan, etc.).
|
||||
func (h *Handler) Close() error {
|
||||
close(h.logHandlerChan)
|
||||
h.wg.Wait()
|
||||
return h.file.Close()
|
||||
}
|
||||
|
||||
// Rotate closes and reopens the log file to allow for rotation by an external source.
|
||||
func (h *Handler) Rotate() error {
|
||||
var err error
|
||||
|
||||
if h.file != nil {
|
||||
defer func(f *os.File) {
|
||||
f.Close()
|
||||
}(h.file)
|
||||
}
|
||||
|
||||
h.file, err = os.OpenFile(h.config.FilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.logger.Out = h.file
|
||||
return nil
|
||||
}
|
||||
|
||||
func silentSplitHostPort(value string) (host string, port string) {
|
||||
host, port, err := net.SplitHostPort(value)
|
||||
if err != nil {
|
||||
return value, "-"
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func usernameIfPresent(theURL *url.URL) string {
|
||||
if theURL.User != nil {
|
||||
if name := theURL.User.Username(); name != "" {
|
||||
return name
|
||||
}
|
||||
}
|
||||
return "-"
|
||||
}
|
||||
|
||||
// logTheRoundTrip logs the router name, service URL, and elapsed time for a request.
|
||||
func (h *Handler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestReader, crw *captureResponseWriter) {
|
||||
core := logDataTable.Core
|
||||
|
||||
retryAttempts, ok := core[RetryAttempts].(int)
|
||||
if !ok {
|
||||
retryAttempts = 0
|
||||
}
|
||||
core[RetryAttempts] = retryAttempts
|
||||
|
||||
if crr != nil {
|
||||
core[RequestContentSize] = crr.count
|
||||
}
|
||||
|
||||
core[DownstreamStatus] = crw.Status()
|
||||
|
||||
// n.b. take care to perform time arithmetic using UTC to avoid errors at DST boundaries.
|
||||
totalDuration := time.Now().UTC().Sub(core[StartUTC].(time.Time))
|
||||
core[Duration] = totalDuration
|
||||
|
||||
if h.keepAccessLog(crw.Status(), retryAttempts, totalDuration) {
|
||||
core[DownstreamContentSize] = crw.Size()
|
||||
if original, ok := core[OriginContentSize]; ok {
|
||||
o64 := original.(int64)
|
||||
if crw.Size() != o64 && crw.Size() != 0 {
|
||||
core[GzipRatio] = float64(o64) / float64(crw.Size())
|
||||
}
|
||||
}
|
||||
|
||||
core[Overhead] = totalDuration
|
||||
if origin, ok := core[OriginDuration]; ok {
|
||||
core[Overhead] = totalDuration - origin.(time.Duration)
|
||||
}
|
||||
|
||||
fields := logrus.Fields{}
|
||||
|
||||
for k, v := range logDataTable.Core {
|
||||
if h.config.Fields.Keep(k) {
|
||||
fields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
h.redactHeaders(logDataTable.Request, fields, "request_")
|
||||
h.redactHeaders(logDataTable.OriginResponse, fields, "origin_")
|
||||
h.redactHeaders(logDataTable.DownstreamResponse, fields, "downstream_")
|
||||
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.logger.WithFields(fields).Println()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) redactHeaders(headers http.Header, fields logrus.Fields, prefix string) {
|
||||
for k := range headers {
|
||||
v := h.config.Fields.KeepHeader(k)
|
||||
if v == types.AccessLogKeep {
|
||||
fields[prefix+k] = headers.Get(k)
|
||||
} else if v == types.AccessLogRedact {
|
||||
fields[prefix+k] = "REDACTED"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) keepAccessLog(statusCode, retryAttempts int, duration time.Duration) bool {
|
||||
if h.config.Filters == nil {
|
||||
// no filters were specified
|
||||
return true
|
||||
}
|
||||
|
||||
if len(h.httpCodeRanges) == 0 && !h.config.Filters.RetryAttempts && h.config.Filters.MinDuration == 0 {
|
||||
// empty filters were specified, e.g. by passing --accessLog.filters only (without other filter options)
|
||||
return true
|
||||
}
|
||||
|
||||
if h.httpCodeRanges.Contains(statusCode) {
|
||||
return true
|
||||
}
|
||||
|
||||
if h.config.Filters.RetryAttempts && retryAttempts > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if h.config.Filters.MinDuration > 0 && (parse.Duration(duration) > h.config.Filters.MinDuration) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var requestCounter uint64 // Request ID
|
||||
|
||||
func nextRequestCount() uint64 {
|
||||
return atomic.AddUint64(&requestCounter, 1)
|
||||
}
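A minimal sketch of driving the handler above outside a full Traefik router: it builds a CLF logger writing to stdout, lets a stand-in next handler populate the data table, and shows WrapHandler for chain composition. The router name and status below are invented for illustration.

package main

import (
	"net/http"
	"net/http/httptest"

	"github.com/containous/traefik/pkg/middlewares/accesslog"
	"github.com/containous/traefik/pkg/types"
)

func main() {
	// Empty FilePath means the handler logs to stdout, in CLF here.
	handler, err := accesslog.NewHandler(&types.AccessLog{Format: accesslog.CommonFormat})
	if err != nil {
		panic(err)
	}
	defer handler.Close()

	// A stand-in "next" handler filling the data table, as a router/service would.
	next := func(rw http.ResponseWriter, req *http.Request) {
		if data := accesslog.GetLogData(req); data != nil {
			data.Core[accesslog.RouterName] = "example-router" // invented value
			data.Core[accesslog.OriginStatus] = http.StatusOK
		}
		rw.WriteHeader(http.StatusOK)
	}

	req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil)
	handler.ServeHTTP(httptest.NewRecorder(), req, next)

	// For middleware chains, the same handler can be turned into an Alice constructor.
	_ = accesslog.WrapHandler(handler)
}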
83
pkg/middlewares/accesslog/logger_formatters.go
Normal file
@@ -0,0 +1,83 @@
package accesslog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// default format for time presentation.
|
||||
const (
|
||||
commonLogTimeFormat = "02/Jan/2006:15:04:05 -0700"
|
||||
defaultValue = "-"
|
||||
)
|
||||
|
||||
// CommonLogFormatter provides formatting in the Traefik common log format.
|
||||
type CommonLogFormatter struct{}
|
||||
|
||||
// Format formats the log entry in the Traefik common log format.
|
||||
func (f *CommonLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
var timestamp = defaultValue
|
||||
if v, ok := entry.Data[StartUTC]; ok {
|
||||
timestamp = v.(time.Time).Format(commonLogTimeFormat)
|
||||
}
|
||||
|
||||
var elapsedMillis int64
|
||||
if v, ok := entry.Data[Duration]; ok {
|
||||
elapsedMillis = v.(time.Duration).Nanoseconds() / 1000000
|
||||
}
|
||||
|
||||
_, err := fmt.Fprintf(b, "%s - %s [%s] \"%s %s %s\" %v %v %s %s %v %s %s %dms\n",
|
||||
toLog(entry.Data, ClientHost, defaultValue, false),
|
||||
toLog(entry.Data, ClientUsername, defaultValue, false),
|
||||
timestamp,
|
||||
toLog(entry.Data, RequestMethod, defaultValue, false),
|
||||
toLog(entry.Data, RequestPath, defaultValue, false),
|
||||
toLog(entry.Data, RequestProtocol, defaultValue, false),
|
||||
toLog(entry.Data, OriginStatus, defaultValue, true),
|
||||
toLog(entry.Data, OriginContentSize, defaultValue, true),
|
||||
toLog(entry.Data, "request_Referer", `"-"`, true),
|
||||
toLog(entry.Data, "request_User-Agent", `"-"`, true),
|
||||
toLog(entry.Data, RequestCount, defaultValue, true),
|
||||
toLog(entry.Data, RouterName, defaultValue, true),
|
||||
toLog(entry.Data, ServiceURL, defaultValue, true),
|
||||
elapsedMillis)
|
||||
|
||||
return b.Bytes(), err
|
||||
}
|
||||
|
||||
func toLog(fields logrus.Fields, key string, defaultValue string, quoted bool) interface{} {
|
||||
if v, ok := fields[key]; ok {
|
||||
if v == nil {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
switch s := v.(type) {
|
||||
case string:
|
||||
return toLogEntry(s, defaultValue, quoted)
|
||||
|
||||
case fmt.Stringer:
|
||||
return toLogEntry(s.String(), defaultValue, quoted)
|
||||
|
||||
default:
|
||||
return v
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
|
||||
}
|
||||
|
||||
func toLogEntry(s string, defaultValue string, quote bool) string {
|
||||
if len(s) == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
if quote {
|
||||
return `"` + s + `"`
|
||||
}
|
||||
return s
|
||||
}
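For a concrete picture of the format string above, a small sketch feeding CommonLogFormatter a hand-built logrus entry; every field left out of the entry falls back to "-" via toLog. The values are made up.

package main

import (
	"fmt"
	"time"

	"github.com/containous/traefik/pkg/middlewares/accesslog"
	"github.com/sirupsen/logrus"
)

func main() {
	entry := &logrus.Entry{Data: logrus.Fields{
		accesslog.StartUTC:        time.Date(2019, time.March, 1, 12, 0, 0, 0, time.UTC),
		accesslog.Duration:        42 * time.Millisecond,
		accesslog.ClientHost:      "10.0.0.1",
		accesslog.ClientUsername:  "-",
		accesslog.RequestMethod:   "GET",
		accesslog.RequestPath:     "/foo",
		accesslog.RequestProtocol: "HTTP/1.1",
	}}

	out, err := (&accesslog.CommonLogFormatter{}).Format(entry)
	if err != nil {
		panic(err)
	}

	// Prints: 10.0.0.1 - - [01/Mar/2019:12:00:00 +0000] "GET /foo HTTP/1.1" - - "-" "-" - - - 42ms
	fmt.Print(string(out))
}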
140
pkg/middlewares/accesslog/logger_formatters_test.go
Normal file
@@ -0,0 +1,140 @@
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCommonLogFormatter_Format(t *testing.T) {
|
||||
clf := CommonLogFormatter{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
data map[string]interface{}
|
||||
expectedLog string
|
||||
}{
|
||||
{
|
||||
name: "OriginStatus & OriginContentSize are nil",
|
||||
data: map[string]interface{}{
|
||||
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
|
||||
Duration: 123 * time.Second,
|
||||
ClientHost: "10.0.0.1",
|
||||
ClientUsername: "Client",
|
||||
RequestMethod: http.MethodGet,
|
||||
RequestPath: "/foo",
|
||||
RequestProtocol: "http",
|
||||
OriginStatus: nil,
|
||||
OriginContentSize: nil,
|
||||
RequestRefererHeader: "",
|
||||
RequestUserAgentHeader: "",
|
||||
RequestCount: 0,
|
||||
RouterName: "",
|
||||
ServiceURL: "",
|
||||
},
|
||||
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" - - "-" "-" 0 - - 123000ms
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "all data",
|
||||
data: map[string]interface{}{
|
||||
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
|
||||
Duration: 123 * time.Second,
|
||||
ClientHost: "10.0.0.1",
|
||||
ClientUsername: "Client",
|
||||
RequestMethod: http.MethodGet,
|
||||
RequestPath: "/foo",
|
||||
RequestProtocol: "http",
|
||||
OriginStatus: 123,
|
||||
OriginContentSize: 132,
|
||||
RequestRefererHeader: "referer",
|
||||
RequestUserAgentHeader: "agent",
|
||||
RequestCount: nil,
|
||||
RouterName: "foo",
|
||||
ServiceURL: "http://10.0.0.2/toto",
|
||||
},
|
||||
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" 123 132 "referer" "agent" - "foo" "http://10.0.0.2/toto" 123000ms
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
entry := &logrus.Entry{Data: test.data}
|
||||
|
||||
raw, err := clf.Format(entry)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedLog, string(raw))
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func Test_toLog(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fields logrus.Fields
|
||||
fieldName string
|
||||
defaultValue string
|
||||
quoted bool
|
||||
expectedLog interface{}
|
||||
}{
|
||||
{
|
||||
desc: "Should return int 1",
|
||||
fields: logrus.Fields{
|
||||
"Powpow": 1,
|
||||
},
|
||||
fieldName: "Powpow",
|
||||
defaultValue: defaultValue,
|
||||
quoted: false,
|
||||
expectedLog: 1,
|
||||
},
|
||||
{
|
||||
desc: "Should return string foo",
|
||||
fields: logrus.Fields{
|
||||
"Powpow": "foo",
|
||||
},
|
||||
fieldName: "Powpow",
|
||||
defaultValue: defaultValue,
|
||||
quoted: true,
|
||||
expectedLog: `"foo"`,
|
||||
},
|
||||
{
|
||||
desc: "Should return defaultValue if fieldName does not exist",
|
||||
fields: logrus.Fields{
|
||||
"Powpow": "foo",
|
||||
},
|
||||
fieldName: "",
|
||||
defaultValue: defaultValue,
|
||||
quoted: false,
|
||||
expectedLog: "-",
|
||||
},
|
||||
{
|
||||
desc: "Should return defaultValue if fields is nil",
|
||||
fields: nil,
|
||||
fieldName: "",
|
||||
defaultValue: defaultValue,
|
||||
quoted: false,
|
||||
expectedLog: "-",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lg := toLog(test.fields, test.fieldName, test.defaultValue, test.quoted)
|
||||
|
||||
assert.Equal(t, test.expectedLog, lg)
|
||||
})
|
||||
}
|
||||
}
649
pkg/middlewares/accesslog/logger_test.go
Normal file
@@ -0,0 +1,649 @@
package accesslog
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/pkg/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logFileNameSuffix = "/traefik/logger/test.log"
|
||||
testContent = "Hello, World"
|
||||
testServiceName = "http://127.0.0.1/testService"
|
||||
testRouterName = "testRouter"
|
||||
testStatus = 123
|
||||
testContentSize int64 = 12
|
||||
testHostname = "TestHost"
|
||||
testUsername = "TestUser"
|
||||
testPath = "testpath"
|
||||
testPort = 8181
|
||||
testProto = "HTTP/0.0"
|
||||
testMethod = http.MethodPost
|
||||
testReferer = "testReferer"
|
||||
testUserAgent = "testUserAgent"
|
||||
testRetryAttempts = 2
|
||||
testStart = time.Now()
|
||||
)
|
||||
|
||||
func TestLogRotation(t *testing.T) {
|
||||
tempDir, err := ioutil.TempDir("", "traefik_")
|
||||
if err != nil {
|
||||
t.Fatalf("Error setting up temporary directory: %s", err)
|
||||
}
|
||||
|
||||
fileName := tempDir + "traefik.log"
|
||||
rotatedFileName := fileName + ".rotated"
|
||||
|
||||
config := &types.AccessLog{FilePath: fileName, Format: CommonFormat}
|
||||
logHandler, err := NewHandler(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating new log handler: %s", err)
|
||||
}
|
||||
defer logHandler.Close()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
next := func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
iterations := 20
|
||||
halfDone := make(chan bool)
|
||||
writeDone := make(chan bool)
|
||||
go func() {
|
||||
for i := 0; i < iterations; i++ {
|
||||
logHandler.ServeHTTP(recorder, req, next)
|
||||
if i == iterations/2 {
|
||||
halfDone <- true
|
||||
}
|
||||
}
|
||||
writeDone <- true
|
||||
}()
|
||||
|
||||
<-halfDone
|
||||
err = os.Rename(fileName, rotatedFileName)
|
||||
if err != nil {
|
||||
t.Fatalf("Error renaming file: %s", err)
|
||||
}
|
||||
|
||||
err = logHandler.Rotate()
|
||||
if err != nil {
|
||||
t.Fatalf("Error rotating file: %s", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-writeDone:
|
||||
gotLineCount := lineCount(t, fileName) + lineCount(t, rotatedFileName)
|
||||
if iterations != gotLineCount {
|
||||
t.Errorf("Wanted %d written log lines, got %d", iterations, gotLineCount)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatalf("test timed out")
|
||||
}
|
||||
|
||||
close(halfDone)
|
||||
close(writeDone)
|
||||
}
|
||||
|
||||
func lineCount(t *testing.T, fileName string) int {
|
||||
t.Helper()
|
||||
fileContents, err := ioutil.ReadFile(fileName)
|
||||
if err != nil {
|
||||
t.Fatalf("Error reading from file %s: %s", fileName, err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, line := range strings.Split(string(fileContents), "\n") {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func TestLoggerCLF(t *testing.T) {
|
||||
tmpDir := createTempDir(t, CommonFormat)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
|
||||
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat}
|
||||
doLogging(t, config)
|
||||
|
||||
logData, err := ioutil.ReadFile(logFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
|
||||
assertValidLogData(t, expectedLog, logData)
|
||||
}
|
||||
|
||||
func TestAsyncLoggerCLF(t *testing.T) {
|
||||
tmpDir := createTempDir(t, CommonFormat)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
|
||||
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat, BufferingSize: 1024}
|
||||
doLogging(t, config)
|
||||
|
||||
logData, err := ioutil.ReadFile(logFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`
|
||||
assertValidLogData(t, expectedLog, logData)
|
||||
}
|
||||
|
||||
func assertString(exp string) func(t *testing.T, actual interface{}) {
|
||||
return func(t *testing.T, actual interface{}) {
|
||||
t.Helper()
|
||||
|
||||
assert.Equal(t, exp, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func assertNotEmpty() func(t *testing.T, actual interface{}) {
|
||||
return func(t *testing.T, actual interface{}) {
|
||||
t.Helper()
|
||||
|
||||
assert.NotEqual(t, "", actual)
|
||||
}
|
||||
}
|
||||
|
||||
func assertFloat64(exp float64) func(t *testing.T, actual interface{}) {
|
||||
return func(t *testing.T, actual interface{}) {
|
||||
t.Helper()
|
||||
|
||||
assert.Equal(t, exp, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func assertFloat64NotZero() func(t *testing.T, actual interface{}) {
|
||||
return func(t *testing.T, actual interface{}) {
|
||||
t.Helper()
|
||||
|
||||
assert.NotZero(t, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config *types.AccessLog
|
||||
expected map[string]func(t *testing.T, value interface{})
|
||||
}{
|
||||
{
|
||||
desc: "default config",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
RequestHost: assertString(testHostname),
|
||||
RequestAddr: assertString(testHostname),
|
||||
RequestMethod: assertString(testMethod),
|
||||
RequestPath: assertString(testPath),
|
||||
RequestProtocol: assertString(testProto),
|
||||
RequestPort: assertString("-"),
|
||||
DownstreamStatus: assertFloat64(float64(testStatus)),
|
||||
DownstreamContentSize: assertFloat64(float64(len(testContent))),
|
||||
OriginContentSize: assertFloat64(float64(len(testContent))),
|
||||
OriginStatus: assertFloat64(float64(testStatus)),
|
||||
RequestRefererHeader: assertString(testReferer),
|
||||
RequestUserAgentHeader: assertString(testUserAgent),
|
||||
RouterName: assertString(testRouterName),
|
||||
ServiceURL: assertString(testServiceName),
|
||||
ClientUsername: assertString(testUsername),
|
||||
ClientHost: assertString(testHostname),
|
||||
ClientPort: assertString(fmt.Sprintf("%d", testPort)),
|
||||
ClientAddr: assertString(fmt.Sprintf("%s:%d", testHostname, testPort)),
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
|
||||
RequestCount: assertFloat64NotZero(),
|
||||
Duration: assertFloat64NotZero(),
|
||||
Overhead: assertFloat64NotZero(),
|
||||
RetryAttempts: assertFloat64(float64(testRetryAttempts)),
|
||||
"time": assertNotEmpty(),
|
||||
"StartLocal": assertNotEmpty(),
|
||||
"StartUTC": assertNotEmpty(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "default config drop all fields",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
},
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEmpty(),
|
||||
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
|
||||
RequestRefererHeader: assertString(testReferer),
|
||||
RequestUserAgentHeader: assertString(testUserAgent),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "default config drop all fields and headers",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "drop",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEmpty(),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "default config drop all fields and redact headers",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "redact",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEmpty(),
|
||||
"downstream_Content-Type": assertString("REDACTED"),
|
||||
RequestRefererHeader: assertString("REDACTED"),
|
||||
RequestUserAgentHeader: assertString("REDACTED"),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "default config drop all fields and headers but kept someone",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
RequestHost: "keep",
|
||||
},
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldHeaderNames{
|
||||
"Referer": "keep",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
RequestHost: assertString(testHostname),
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEmpty(),
|
||||
RequestRefererHeader: assertString(testReferer),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tmpDir := createTempDir(t, JSONFormat)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
|
||||
|
||||
test.config.FilePath = logFilePath
|
||||
doLogging(t, test.config)
|
||||
|
||||
logData, err := ioutil.ReadFile(logFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
jsonData := make(map[string]interface{})
|
||||
err = json.Unmarshal(logData, &jsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, len(test.expected), len(jsonData))
|
||||
|
||||
for field, assertion := range test.expected {
|
||||
assertion(t, jsonData[field])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogHandlerOutputStdout(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config *types.AccessLog
|
||||
expectedLog string
|
||||
}{
|
||||
{
|
||||
desc: "default config",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "default config with empty filters",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Status code filter not matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
StatusCodes: []string{"200"},
|
||||
},
|
||||
},
|
||||
expectedLog: ``,
|
||||
},
|
||||
{
|
||||
desc: "Status code filter matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
StatusCodes: []string{"123"},
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Duration filter not matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
MinDuration: parse.Duration(1 * time.Hour),
|
||||
},
|
||||
},
|
||||
expectedLog: ``,
|
||||
},
|
||||
{
|
||||
desc: "Duration filter matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
MinDuration: parse.Duration(1 * time.Millisecond),
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Retry attempts filter matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
RetryAttempts: true,
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode keep",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "keep",
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode keep with override",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "keep",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
},
|
||||
},
|
||||
expectedLog: `- - - [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop with override",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
ClientUsername: "keep",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop with header dropped",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
ClientUsername: "keep",
|
||||
},
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "drop",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [-] "- - -" - - "-" "-" - - - 0ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop with header redacted",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
ClientUsername: "keep",
|
||||
},
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "redact",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "REDACTED" - - - 0ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop with header redacted",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
ClientUsername: "keep",
|
||||
},
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "keep",
|
||||
Names: types.FieldHeaderNames{
|
||||
"Referer": "redact",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "testUserAgent" - - - 0ms`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
// NOTE: It is not possible to run these cases in parallel because we capture Stdout
|
||||
|
||||
file, restoreStdout := captureStdout(t)
|
||||
defer restoreStdout()
|
||||
|
||||
doLogging(t, test.config)
|
||||
|
||||
written, err := ioutil.ReadFile(file.Name())
|
||||
require.NoError(t, err, "unable to read captured stdout from file")
|
||||
assertValidLogData(t, test.expectedLog, written)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func assertValidLogData(t *testing.T, expected string, logData []byte) {
|
||||
|
||||
if len(expected) == 0 {
|
||||
assert.Zero(t, len(logData))
|
||||
t.Log(string(logData))
|
||||
return
|
||||
}
|
||||
|
||||
result, err := ParseAccessLog(string(logData))
|
||||
require.NoError(t, err)
|
||||
|
||||
resultExpected, err := ParseAccessLog(expected)
|
||||
require.NoError(t, err)
|
||||
|
||||
formatErrMessage := fmt.Sprintf(`
|
||||
Expected: %s
|
||||
Actual: %s`, expected, string(logData))
|
||||
|
||||
require.Equal(t, len(resultExpected), len(result), formatErrMessage)
|
||||
assert.Equal(t, resultExpected[ClientHost], result[ClientHost], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[ClientUsername], result[ClientUsername], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestMethod], result[RequestMethod], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestPath], result[RequestPath], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestProtocol], result[RequestProtocol], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[OriginStatus], result[OriginStatus], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[OriginContentSize], result[OriginContentSize], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestRefererHeader], result[RequestRefererHeader], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestUserAgentHeader], result[RequestUserAgentHeader], formatErrMessage)
|
||||
assert.Regexp(t, regexp.MustCompile("[0-9]*"), result[RequestCount], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RouterName], result[RouterName], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[ServiceURL], result[ServiceURL], formatErrMessage)
|
||||
assert.Regexp(t, regexp.MustCompile("[0-9]*ms"), result[Duration], formatErrMessage)
|
||||
}
|
||||
|
||||
func captureStdout(t *testing.T) (out *os.File, restoreStdout func()) {
|
||||
file, err := ioutil.TempFile("", "testlogger")
|
||||
require.NoError(t, err, "failed to create temp file")
|
||||
|
||||
original := os.Stdout
|
||||
os.Stdout = file
|
||||
|
||||
restoreStdout = func() {
|
||||
os.Stdout = original
|
||||
}
|
||||
|
||||
return file, restoreStdout
|
||||
}
|
||||
|
||||
func createTempDir(t *testing.T, prefix string) string {
|
||||
tmpDir, err := ioutil.TempDir("", prefix)
|
||||
require.NoError(t, err, "failed to create temp dir")
|
||||
|
||||
return tmpDir
|
||||
}
|
||||
|
||||
func doLogging(t *testing.T, config *types.AccessLog) {
|
||||
logger, err := NewHandler(config)
|
||||
require.NoError(t, err)
|
||||
defer logger.Close()
|
||||
|
||||
if config.FilePath != "" {
|
||||
_, err = os.Stat(config.FilePath)
|
||||
require.NoError(t, err, fmt.Sprintf("logger should create %s", config.FilePath))
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Header: map[string][]string{
|
||||
"User-Agent": {testUserAgent},
|
||||
"Referer": {testReferer},
|
||||
},
|
||||
Proto: testProto,
|
||||
Host: testHostname,
|
||||
Method: testMethod,
|
||||
RemoteAddr: fmt.Sprintf("%s:%d", testHostname, testPort),
|
||||
URL: &url.URL{
|
||||
User: url.UserPassword(testUsername, ""),
|
||||
Path: testPath,
|
||||
},
|
||||
}
|
||||
|
||||
logger.ServeHTTP(httptest.NewRecorder(), req, logWriterTestHandlerFunc)
|
||||
}
|
||||
|
||||
func logWriterTestHandlerFunc(rw http.ResponseWriter, r *http.Request) {
|
||||
if _, err := rw.Write([]byte(testContent)); err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
logData := GetLogData(r)
|
||||
if logData != nil {
|
||||
logData.Core[RouterName] = testRouterName
|
||||
logData.Core[ServiceURL] = testServiceName
|
||||
logData.Core[OriginStatus] = testStatus
|
||||
logData.Core[OriginContentSize] = testContentSize
|
||||
logData.Core[RetryAttempts] = testRetryAttempts
|
||||
logData.Core[StartUTC] = testStart.UTC()
|
||||
logData.Core[StartLocal] = testStart.Local()
|
||||
} else {
|
||||
http.Error(rw, "LogData is nil", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
rw.WriteHeader(testStatus)
|
||||
}
54
pkg/middlewares/accesslog/parser.go
Normal file
@@ -0,0 +1,54 @@
package accesslog
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
// ParseAccessLog parses a line of the access log and returns a map with each field.
|
||||
func ParseAccessLog(data string) (map[string]string, error) {
|
||||
var buffer bytes.Buffer
|
||||
buffer.WriteString(`(\S+)`) // 1 - ClientHost
|
||||
buffer.WriteString(`\s-\s`) // - - Spaces
|
||||
buffer.WriteString(`(\S+)\s`) // 2 - ClientUsername
|
||||
buffer.WriteString(`\[([^]]+)\]\s`) // 3 - StartUTC
|
||||
buffer.WriteString(`"(\S*)\s?`) // 4 - RequestMethod
|
||||
buffer.WriteString(`((?:[^"]*(?:\\")?)*)\s`) // 5 - RequestPath
|
||||
buffer.WriteString(`([^"]*)"\s`) // 6 - RequestProtocol
|
||||
buffer.WriteString(`(\S+)\s`) // 7 - OriginStatus
|
||||
buffer.WriteString(`(\S+)\s`) // 8 - OriginContentSize
|
||||
buffer.WriteString(`("?\S+"?)\s`) // 9 - Referrer
|
||||
buffer.WriteString(`("\S+")\s`) // 10 - User-Agent
|
||||
buffer.WriteString(`(\S+)\s`) // 11 - RequestCount
|
||||
buffer.WriteString(`("[^"]*"|-)\s`) // 12 - FrontendName
|
||||
buffer.WriteString(`("[^"]*"|-)\s`) // 13 - BackendURL
|
||||
buffer.WriteString(`(\S+)`) // 14 - Duration
|
||||
|
||||
regex, err := regexp.Compile(buffer.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
submatch := regex.FindStringSubmatch(data)
|
||||
result := make(map[string]string)
|
||||
|
||||
// Need to be > 13 to match CLF format
|
||||
if len(submatch) > 13 {
|
||||
result[ClientHost] = submatch[1]
|
||||
result[ClientUsername] = submatch[2]
|
||||
result[StartUTC] = submatch[3]
|
||||
result[RequestMethod] = submatch[4]
|
||||
result[RequestPath] = submatch[5]
|
||||
result[RequestProtocol] = submatch[6]
|
||||
result[OriginStatus] = submatch[7]
|
||||
result[OriginContentSize] = submatch[8]
|
||||
result[RequestRefererHeader] = submatch[9]
|
||||
result[RequestUserAgentHeader] = submatch[10]
|
||||
result[RequestCount] = submatch[11]
|
||||
result[RouterName] = submatch[12]
|
||||
result[ServiceURL] = submatch[13]
|
||||
result[Duration] = submatch[14]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
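A quick sketch of the parser in use on a synthetic CLF line (all values invented); note that quoted fields keep their surrounding quotes, exactly as the capture groups above are written.

package main

import (
	"fmt"

	"github.com/containous/traefik/pkg/middlewares/accesslog"
)

func main() {
	line := `10.0.0.1 - alice [10/Nov/2009:23:00:00 +0000] "GET /foo HTTP/1.1" 200 42 "-" "curl/7.64.0" 7 "my-router" "http://10.0.0.2:8080" 3ms`

	fields, err := accesslog.ParseAccessLog(line)
	if err != nil {
		panic(err)
	}

	fmt.Println(fields[accesslog.ClientHost])   // 10.0.0.1
	fmt.Println(fields[accesslog.OriginStatus]) // 200
	fmt.Println(fields[accesslog.RouterName])   // "my-router" (quotes included)
}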
75
pkg/middlewares/accesslog/parser_test.go
Normal file
@@ -0,0 +1,75 @@
package accesslog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseAccessLog(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
value string
|
||||
expected map[string]string
|
||||
}{
|
||||
{
|
||||
desc: "full log",
|
||||
value: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testRouter" "http://127.0.0.1/testService" 1ms`,
|
||||
expected: map[string]string{
|
||||
ClientHost: "TestHost",
|
||||
ClientUsername: "TestUser",
|
||||
StartUTC: "13/Apr/2016:07:14:19 -0700",
|
||||
RequestMethod: "POST",
|
||||
RequestPath: "testpath",
|
||||
RequestProtocol: "HTTP/0.0",
|
||||
OriginStatus: "123",
|
||||
OriginContentSize: "12",
|
||||
RequestRefererHeader: `"testReferer"`,
|
||||
RequestUserAgentHeader: `"testUserAgent"`,
|
||||
RequestCount: "1",
|
||||
RouterName: `"testRouter"`,
|
||||
ServiceURL: `"http://127.0.0.1/testService"`,
|
||||
Duration: "1ms",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "log with space",
|
||||
value: `127.0.0.1 - - [09/Mar/2018:10:51:32 +0000] "GET / HTTP/1.1" 401 17 "-" "Go-http-client/1.1" 1 "testRouter with space" - 0ms`,
|
||||
expected: map[string]string{
|
||||
ClientHost: "127.0.0.1",
|
||||
ClientUsername: "-",
|
||||
StartUTC: "09/Mar/2018:10:51:32 +0000",
|
||||
RequestMethod: "GET",
|
||||
RequestPath: "/",
|
||||
RequestProtocol: "HTTP/1.1",
|
||||
OriginStatus: "401",
|
||||
OriginContentSize: "17",
|
||||
RequestRefererHeader: `"-"`,
|
||||
RequestUserAgentHeader: `"Go-http-client/1.1"`,
|
||||
RequestCount: "1",
|
||||
RouterName: `"testRouter with space"`,
|
||||
ServiceURL: `-`,
|
||||
Duration: "0ms",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "bad log",
|
||||
value: `bad`,
|
||||
expected: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := ParseAccessLog(test.value)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(test.expected), len(result))
|
||||
for key, value := range test.expected {
|
||||
assert.Equal(t, value, result[key])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
21
pkg/middlewares/accesslog/save_retries.go
Normal file
@@ -0,0 +1,21 @@
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// SaveRetries is an implementation of RetryListener that stores RetryAttempts in the LogDataTable.
|
||||
type SaveRetries struct{}
|
||||
|
||||
// Retried implements the RetryListener interface and will be called for each retry that happens.
|
||||
func (s *SaveRetries) Retried(req *http.Request, attempt int) {
|
||||
// it is the request attempt x, but the retry attempt is x-1
|
||||
if attempt > 0 {
|
||||
attempt--
|
||||
}
|
||||
|
||||
table := GetLogData(req)
|
||||
if table != nil {
|
||||
table.Core[RetryAttempts] = attempt
|
||||
}
|
||||
}
48
pkg/middlewares/accesslog/save_retries_test.go
Normal file
@@ -0,0 +1,48 @@
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSaveRetries(t *testing.T) {
|
||||
tests := []struct {
|
||||
requestAttempt int
|
||||
wantRetryAttemptsInLog int
|
||||
}{
|
||||
{
|
||||
requestAttempt: 0,
|
||||
wantRetryAttemptsInLog: 0,
|
||||
},
|
||||
{
|
||||
requestAttempt: 1,
|
||||
wantRetryAttemptsInLog: 0,
|
||||
},
|
||||
{
|
||||
requestAttempt: 3,
|
||||
wantRetryAttemptsInLog: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(fmt.Sprintf("%d retries", test.requestAttempt), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
saveRetries := &SaveRetries{}
|
||||
|
||||
logDataTable := &LogData{Core: make(CoreLogData)}
|
||||
req := httptest.NewRequest(http.MethodGet, "/some/path", nil)
|
||||
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
|
||||
|
||||
saveRetries.Retried(reqWithDataTable, test.requestAttempt)
|
||||
|
||||
if logDataTable.Core[RetryAttempts] != test.wantRetryAttemptsInLog {
|
||||
t.Errorf("got %v in logDataTable, want %v", logDataTable.Core[RetryAttempts], test.wantRetryAttemptsInLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
62
pkg/middlewares/addprefix/add_prefix.go
Normal file
@@ -0,0 +1,62 @@
package addprefix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "AddPrefix"
|
||||
)
|
||||
|
||||
// addPrefix is a middleware used to add a prefix to the request URL.
|
||||
type addPrefix struct {
|
||||
next http.Handler
|
||||
prefix string
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a new handler.
|
||||
func New(ctx context.Context, next http.Handler, config config.AddPrefix, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
var result *addPrefix
|
||||
|
||||
if len(config.Prefix) > 0 {
|
||||
result = &addPrefix{
|
||||
prefix: config.Prefix,
|
||||
next: next,
|
||||
name: name,
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("prefix cannot be empty")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (ap *addPrefix) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return ap.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (ap *addPrefix) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), ap.name, typeName)
|
||||
|
||||
oldURLPath := req.URL.Path
|
||||
req.URL.Path = ap.prefix + req.URL.Path
|
||||
logger.Debugf("URL.Path is now %s (was %s).", req.URL.Path, oldURLPath)
|
||||
|
||||
if req.URL.RawPath != "" {
|
||||
oldURLRawPath := req.URL.RawPath
|
||||
req.URL.RawPath = ap.prefix + req.URL.RawPath
|
||||
logger.Debugf("URL.RawPath is now %s (was %s).", req.URL.RawPath, oldURLRawPath)
|
||||
}
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
|
||||
ap.next.ServeHTTP(rw, req)
|
||||
}
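A short sketch of the prefix middleware handling an escaped path, to make the Path/RawPath handling above concrete; the prefix, URL, and middleware name are invented, while addprefix.New and config.AddPrefix come from the file above.

package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/middlewares/addprefix"
)

func main() {
	// Echo back what the next handler sees after the prefix is applied.
	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		fmt.Println(req.URL.Path, req.URL.RawPath, req.RequestURI)
	})

	handler, err := addprefix.New(context.Background(), next, config.AddPrefix{Prefix: "/api"}, "example-add-prefix")
	if err != nil {
		panic(err)
	}

	// An escaped slash keeps RawPath populated, so both fields get the prefix.
	req := httptest.NewRequest(http.MethodGet, "http://localhost/b%2Fc", nil)
	handler.ServeHTTP(httptest.NewRecorder(), req)
	// Prints: /api/b/c /api/b%2Fc /api/b%2Fc
}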
104
pkg/middlewares/addprefix/add_prefix_test.go
Normal file
@@ -0,0 +1,104 @@
package addprefix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewAddPrefix(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
prefix config.AddPrefix
|
||||
expectsError bool
|
||||
}{
|
||||
{
|
||||
desc: "Works with a non empty prefix",
|
||||
prefix: config.AddPrefix{Prefix: "/a"},
|
||||
},
|
||||
{
|
||||
desc: "Fails if prefix is empty",
|
||||
prefix: config.AddPrefix{Prefix: ""},
|
||||
expectsError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
_, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
|
||||
if test.expectsError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddPrefix(t *testing.T) {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
testCases := []struct {
|
||||
desc string
|
||||
prefix config.AddPrefix
|
||||
path string
|
||||
expectedPath string
|
||||
expectedRawPath string
|
||||
}{
|
||||
{
|
||||
desc: "Works with a regular path",
|
||||
prefix: config.AddPrefix{Prefix: "/a"},
|
||||
path: "/b",
|
||||
expectedPath: "/a/b",
|
||||
},
|
||||
{
|
||||
desc: "Works with a raw path",
|
||||
prefix: config.AddPrefix{Prefix: "/a"},
|
||||
path: "/b%2Fc",
|
||||
expectedPath: "/a/b/c",
|
||||
expectedRawPath: "/a/b%2Fc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualPath, actualRawPath, requestURI string
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualRawPath = r.URL.RawPath
|
||||
requestURI = r.RequestURI
|
||||
})
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
|
||||
handler, err := New(context.Background(), next, test.prefix, "foo-add-prefix")
|
||||
require.NoError(t, err)
|
||||
|
||||
handler.ServeHTTP(nil, req)
|
||||
|
||||
assert.Equal(t, test.expectedPath, actualPath)
|
||||
assert.Equal(t, test.expectedRawPath, actualRawPath)
|
||||
|
||||
expectedURI := test.expectedPath
|
||||
if test.expectedRawPath != "" {
|
||||
// Go's net/http uses the raw path, when present, in the RequestURI
|
||||
expectedURI = test.expectedRawPath
|
||||
}
|
||||
assert.Equal(t, expectedURI, requestURI)
|
||||
})
|
||||
}
|
||||
}
65
pkg/middlewares/auth/auth.go
Normal file
@@ -0,0 +1,65 @@
package auth
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// UserParser parses a string and returns a userName/userHash pair, or an error if the string is not formatted correctly.
|
||||
type UserParser func(user string) (string, string, error)
|
||||
|
||||
const (
|
||||
defaultRealm = "traefik"
|
||||
authorizationHeader = "Authorization"
|
||||
)
|
||||
|
||||
func getUsers(fileName string, appendUsers []string, parser UserParser) (map[string]string, error) {
|
||||
users, err := loadUsers(fileName, appendUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userMap := make(map[string]string)
|
||||
for _, user := range users {
|
||||
userName, userHash, err := parser(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userMap[userName] = userHash
|
||||
}
|
||||
|
||||
return userMap, nil
|
||||
}
|
||||
|
||||
func loadUsers(fileName string, appendUsers []string) ([]string, error) {
|
||||
var users []string
|
||||
var err error
|
||||
|
||||
if fileName != "" {
|
||||
users, err = getLinesFromFile(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return append(users, appendUsers...), nil
|
||||
}
|
||||
|
||||
func getLinesFromFile(filename string) ([]string, error) {
|
||||
dat, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Trim lines and filter out blanks
|
||||
rawLines := strings.Split(string(dat), "\n")
|
||||
var filteredLines []string
|
||||
for _, rawLine := range rawLines {
|
||||
line := strings.TrimSpace(rawLine)
|
||||
if line != "" && !strings.HasPrefix(line, "#") {
|
||||
filteredLines = append(filteredLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredLines, nil
|
||||
}
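Since getUsers and basicUserParser are unexported, a usage sketch has to live in the auth package itself; the file content and hashes below are invented, and illustrate that blank lines and '#' comments are filtered out before parsing.

package auth

import (
	"fmt"
	"io/ioutil"
	"os"
)

// exampleGetUsers is an illustrative helper, not part of this commit.
func exampleGetUsers() error {
	// htpasswd-style content: comments and blank lines are skipped by getLinesFromFile.
	content := "# administrators\nalice:$apr1$examplehash1\n\nbob:$apr1$examplehash2\n"

	f, err := ioutil.TempFile("", "users")
	if err != nil {
		return err
	}
	defer os.Remove(f.Name())

	if _, err := f.WriteString(content); err != nil {
		return err
	}
	if err := f.Close(); err != nil {
		return err
	}

	// Combine the file with one inline user; basicUserParser splits each "name:hash" entry.
	users, err := getUsers(f.Name(), []string{"carol:$apr1$examplehash3"}, basicUserParser)
	if err != nil {
		return err
	}

	fmt.Println(users["alice"], users["bob"], users["carol"])
	return nil
}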
102
pkg/middlewares/auth/basic_auth.go
Normal file
@@ -0,0 +1,102 @@
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
goauth "github.com/abbot/go-http-auth"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/middlewares/accesslog"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
basicTypeName = "BasicAuth"
|
||||
)
|
||||
|
||||
type basicAuth struct {
|
||||
next http.Handler
|
||||
auth *goauth.BasicAuth
|
||||
users map[string]string
|
||||
headerField string
|
||||
removeHeader bool
|
||||
name string
|
||||
}
|
||||
|
||||
// NewBasic creates a basicAuth middleware.
|
||||
func NewBasic(ctx context.Context, next http.Handler, authConfig config.BasicAuth, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, basicTypeName).Debug("Creating middleware")
|
||||
users, err := getUsers(authConfig.UsersFile, authConfig.Users, basicUserParser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ba := &basicAuth{
|
||||
next: next,
|
||||
users: users,
|
||||
headerField: authConfig.HeaderField,
|
||||
removeHeader: authConfig.RemoveHeader,
|
||||
name: name,
|
||||
}
|
||||
|
||||
realm := defaultRealm
|
||||
if len(authConfig.Realm) > 0 {
|
||||
realm = authConfig.Realm
|
||||
}
|
||||
ba.auth = goauth.NewBasicAuthenticator(realm, ba.secretBasic)
|
||||
|
||||
return ba, nil
|
||||
}
|
||||
|
||||
func (b *basicAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return b.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), b.name, basicTypeName)
|
||||
|
||||
if username := b.auth.CheckAuth(req); username == "" {
|
||||
logger.Debug("Authentication failed")
|
||||
tracing.SetErrorWithEvent(req, "Authentication failed")
|
||||
b.auth.RequireAuth(rw, req)
|
||||
} else {
|
||||
logger.Debug("Authentication succeeded")
|
||||
req.URL.User = url.User(username)
|
||||
|
||||
logData := accesslog.GetLogData(req)
|
||||
if logData != nil {
|
||||
logData.Core[accesslog.ClientUsername] = username
|
||||
}
|
||||
|
||||
if b.headerField != "" {
|
||||
req.Header[b.headerField] = []string{username}
|
||||
}
|
||||
|
||||
if b.removeHeader {
|
||||
logger.Debug("Removing authorization header")
|
||||
req.Header.Del(authorizationHeader)
|
||||
}
|
||||
b.next.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *basicAuth) secretBasic(user, realm string) string {
|
||||
if secret, ok := b.users[user]; ok {
|
||||
return secret
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func basicUserParser(user string) (string, string, error) {
|
||||
split := strings.Split(user, ":")
|
||||
if len(split) != 2 {
|
||||
return "", "", fmt.Errorf("error parsing BasicUser: %v", user)
|
||||
}
|
||||
return split[0], split[1], nil
|
||||
}
284
pkg/middlewares/auth/basic_auth_test.go
Normal file
@@ -0,0 +1,284 @@
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBasicAuthFail(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
Users: []string{"test"},
|
||||
}
|
||||
_, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.Error(t, err)
|
||||
|
||||
auth2 := config.BasicAuth{
|
||||
Users: []string{"test:test"},
|
||||
}
|
||||
authMiddleware, err := NewBasic(context.Background(), next, auth2, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthSuccess(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
}
|
||||
authMiddleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthUserHeader(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
HeaderField: "X-Webauth-User",
|
||||
}
|
||||
middleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestBasicAuthHeaderRemoved(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Empty(t, r.Header.Get(authorizationHeader))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
RemoveHeader: true,
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
}
|
||||
middleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestBasicAuthHeaderPresent(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.BasicAuth{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
}
|
||||
middleware, err := NewBasic(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestBasicAuthUsersFromFile(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
userFileContent string
|
||||
expectedUsers map[string]string
|
||||
givenUsers []string
|
||||
realm string
|
||||
}{
|
||||
{
|
||||
desc: "Finds the users in the file",
|
||||
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
|
||||
givenUsers: []string{},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
|
||||
},
|
||||
{
|
||||
desc: "Merges given users with users from the file",
|
||||
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\n",
|
||||
givenUsers: []string{"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0", "test3:$apr1$3rJbDP0q$RfzJiorTk78jQ1EcKqWso0"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
|
||||
},
|
||||
{
|
||||
desc: "Given users have priority over users in the file",
|
||||
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
|
||||
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
|
||||
},
|
||||
{
|
||||
desc: "Should authenticate the correct user based on the realm",
|
||||
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
|
||||
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
|
||||
realm: "traefik",
|
||||
},
|
||||
{
|
||||
desc: "Should skip comments",
|
||||
userFileContent: "#Comment\ntest:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
|
||||
givenUsers: []string{},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
|
||||
realm: "traefiker",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Creates the temporary configuration file with the users
|
||||
usersFile, err := ioutil.TempFile("", "auth-users")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(usersFile.Name())
|
||||
|
||||
_, err = usersFile.Write([]byte(test.userFileContent))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Creates the configuration for our Authenticator
|
||||
authenticatorConfiguration := config.BasicAuth{
|
||||
Users: test.givenUsers,
|
||||
UsersFile: usersFile.Name(),
|
||||
Realm: test.realm,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
authenticator, err := NewBasic(context.Background(), next, authenticatorConfiguration, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authenticator)
|
||||
defer ts.Close()
|
||||
|
||||
for userName, userPwd := range test.expectedUsers {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth(userName, userPwd)
|
||||
|
||||
var res *http.Response
|
||||
res, err = http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
|
||||
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
// Checks that user foo doesn't work
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("foo", "foo")
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
if len(test.realm) > 0 {
|
||||
require.Equal(t, `Basic realm="`+test.realm+`"`, res.Header.Get("WWW-Authenticate"))
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotContains(t, "traefik", string(body))
|
||||
})
|
||||
}
|
||||
}
|
102
pkg/middlewares/auth/digest_auth.go
Normal file
@@ -0,0 +1,102 @@
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
goauth "github.com/abbot/go-http-auth"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/middlewares/accesslog"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
const (
|
||||
digestTypeName = "digestAuth"
|
||||
)
|
||||
|
||||
type digestAuth struct {
|
||||
next http.Handler
|
||||
auth *goauth.DigestAuth
|
||||
users map[string]string
|
||||
headerField string
|
||||
removeHeader bool
|
||||
name string
|
||||
}
|
||||
|
||||
// NewDigest creates a digest auth middleware.
|
||||
func NewDigest(ctx context.Context, next http.Handler, authConfig config.DigestAuth, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, digestTypeName).Debug("Creating middleware")
|
||||
users, err := getUsers(authConfig.UsersFile, authConfig.Users, digestUserParser)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
da := &digestAuth{
|
||||
next: next,
|
||||
users: users,
|
||||
headerField: authConfig.HeaderField,
|
||||
removeHeader: authConfig.RemoveHeader,
|
||||
name: name,
|
||||
}
|
||||
|
||||
realm := defaultRealm
|
||||
if len(authConfig.Realm) > 0 {
|
||||
realm = authConfig.Realm
|
||||
}
|
||||
da.auth = goauth.NewDigestAuthenticator(realm, da.secretDigest)
|
||||
|
||||
return da, nil
|
||||
}
|
||||
|
||||
func (d *digestAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return d.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (d *digestAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), d.name, digestTypeName)
|
||||
|
||||
if username, _ := d.auth.CheckAuth(req); username == "" {
|
||||
logger.Debug("Digest authentication failed")
|
||||
tracing.SetErrorWithEvent(req, "Digest authentication failed")
|
||||
d.auth.RequireAuth(rw, req)
|
||||
} else {
|
||||
logger.Debug("Digest authentication succeeded")
|
||||
req.URL.User = url.User(username)
|
||||
|
||||
logData := accesslog.GetLogData(req)
|
||||
if logData != nil {
|
||||
logData.Core[accesslog.ClientUsername] = username
|
||||
}
|
||||
|
||||
if d.headerField != "" {
|
||||
req.Header[d.headerField] = []string{username}
|
||||
}
|
||||
|
||||
if d.removeHeader {
|
||||
logger.Debug("Removing the Authorization header")
|
||||
req.Header.Del(authorizationHeader)
|
||||
}
|
||||
d.next.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *digestAuth) secretDigest(user, realm string) string {
|
||||
if secret, ok := d.users[user+":"+realm]; ok {
|
||||
return secret
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func digestUserParser(user string) (string, string, error) {
|
||||
split := strings.Split(user, ":")
|
||||
if len(split) != 3 {
|
||||
return "", "", fmt.Errorf("error parsing DigestUser: %v", user)
|
||||
}
|
||||
return split[0] + ":" + split[1], split[2], nil
|
||||
}
|
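digestUserParser above accepts entries in the htdigest format, user:realm:MD5(user:realm:password). A minimal sketch of producing such an entry with the standard library; the user, realm, and password values are illustrative only:

package main

import (
	"crypto/md5"
	"fmt"
)

func main() {
	// Illustrative values; an htdigest entry hashes "user:realm:password" with MD5.
	user, realm, password := "test", "traefik", "test"
	sum := md5.Sum([]byte(user + ":" + realm + ":" + password))

	// Prints an entry in the user:realm:hash form expected by digestUserParser.
	fmt.Printf("%s:%s:%x\n", user, realm, sum)
}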
141
pkg/middlewares/auth/digest_auth_request_test.go
Normal file
@@ -0,0 +1,141 @@
package auth
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
algorithm = "algorithm"
|
||||
authorization = "Authorization"
|
||||
nonce = "nonce"
|
||||
opaque = "opaque"
|
||||
qop = "qop"
|
||||
realm = "realm"
|
||||
wwwAuthenticate = "Www-Authenticate"
|
||||
)
|
||||
|
||||
// digestRequest is a client for digest authentication requests
|
||||
type digestRequest struct {
|
||||
client *http.Client
|
||||
username, password string
|
||||
nonceCount nonceCount
|
||||
}
|
||||
|
||||
type nonceCount int
|
||||
|
||||
func (nc nonceCount) String() string {
|
||||
return fmt.Sprintf("%08x", int(nc))
|
||||
}
|
||||
|
||||
var wanted = []string{algorithm, nonce, opaque, qop, realm}
|
||||
|
||||
// newDigestRequest makes a digestRequest instance
|
||||
func newDigestRequest(username, password string, client *http.Client) *digestRequest {
|
||||
return &digestRequest{
|
||||
client: client,
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
}
|
||||
|
||||
// Do performs the request like (*http.Client).Do, adding a digest Authorization header when the server requires one
|
||||
func (r *digestRequest) Do(req *http.Request) (*http.Response, error) {
|
||||
parts, err := r.makeParts(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parts != nil {
|
||||
req.Header.Set(authorization, r.makeAuthorization(req, parts))
|
||||
}
|
||||
|
||||
return r.client.Do(req)
|
||||
}
|
||||
|
||||
func (r *digestRequest) makeParts(req *http.Request) (map[string]string, error) {
|
||||
authReq, err := http.NewRequest(req.Method, req.URL.String(), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := r.client.Do(authReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(resp.Header[wwwAuthenticate]) == 0 {
|
||||
return nil, fmt.Errorf("headers do not have %s", wwwAuthenticate)
|
||||
}
|
||||
|
||||
headers := strings.Split(resp.Header[wwwAuthenticate][0], ",")
|
||||
parts := make(map[string]string, len(wanted))
|
||||
for _, r := range headers {
|
||||
for _, w := range wanted {
|
||||
if strings.Contains(r, w) {
|
||||
parts[w] = strings.Split(r, `"`)[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) != len(wanted) {
|
||||
return nil, fmt.Errorf("header is invalid: %+v", parts)
|
||||
}
|
||||
|
||||
return parts, nil
|
||||
}
|
||||
|
||||
func getMD5(texts []string) string {
|
||||
h := md5.New()
|
||||
_, _ = io.WriteString(h, strings.Join(texts, ":"))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
func (r *digestRequest) getNonceCount() string {
|
||||
r.nonceCount++
|
||||
return r.nonceCount.String()
|
||||
}
|
||||
|
||||
func (r *digestRequest) makeAuthorization(req *http.Request, parts map[string]string) string {
|
||||
ha1 := getMD5([]string{r.username, parts[realm], r.password})
|
||||
ha2 := getMD5([]string{req.Method, req.URL.String()})
|
||||
cnonce := generateRandom(16)
|
||||
nc := r.getNonceCount()
|
||||
response := getMD5([]string{
|
||||
ha1,
|
||||
parts[nonce],
|
||||
nc,
|
||||
cnonce,
|
||||
parts[qop],
|
||||
ha2,
|
||||
})
|
||||
return fmt.Sprintf(
|
||||
`Digest username="%s", realm="%s", nonce="%s", uri="%s", algorithm=%s, qop=%s, nc=%s, cnonce="%s", response="%s", opaque="%s"`,
|
||||
r.username,
|
||||
parts[realm],
|
||||
parts[nonce],
|
||||
req.URL.String(),
|
||||
parts[algorithm],
|
||||
parts[qop],
|
||||
nc,
|
||||
cnonce,
|
||||
response,
|
||||
parts[opaque],
|
||||
)
|
||||
}
|
||||
|
||||
// generateRandom generates a random hex string of length n (n must be at most 16)
|
||||
func generateRandom(n int) string {
|
||||
b := make([]byte, 8)
|
||||
_, _ = io.ReadFull(rand.Reader, b)
|
||||
return fmt.Sprintf("%x", b)[:n]
|
||||
}
|
156
pkg/middlewares/auth/digest_auth_test.go
Normal file
@@ -0,0 +1,156 @@
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDigestAuthError(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.DigestAuth{
|
||||
Users: []string{"test"},
|
||||
}
|
||||
_, err := NewDigest(context.Background(), next, auth, "authName")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestDigestAuthFail(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.DigestAuth{
|
||||
Users: []string{"test:traefik:a2688e031edb4be6a3797f3882655c05"},
|
||||
}
|
||||
authMiddleware, err := NewDigest(context.Background(), next, auth, "authName")
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authMiddleware, "this should not be nil")
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
client := http.DefaultClient
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestDigestAuthUsersFromFile(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
userFileContent string
|
||||
expectedUsers map[string]string
|
||||
givenUsers []string
|
||||
realm string
|
||||
}{
|
||||
{
|
||||
desc: "Finds the users in the file",
|
||||
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
|
||||
givenUsers: []string{},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
|
||||
},
|
||||
{
|
||||
desc: "Merges given users with users from the file",
|
||||
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\n",
|
||||
givenUsers: []string{"test2:traefik:518845800f9e2bfb1f1f740ec24f074e", "test3:traefik:c8e9f57ce58ecb4424407f665a91646c"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
|
||||
},
|
||||
{
|
||||
desc: "Given users have priority over users in the file",
|
||||
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
|
||||
givenUsers: []string{"test2:traefik:8de60a1c52da68ccf41f0c0ffb7c51a0"},
|
||||
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
|
||||
},
|
||||
{
|
||||
desc: "Should authenticate the correct user based on the realm",
|
||||
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest:traefiker:a3d334dff2645b914918de78bec50bf4\n",
|
||||
givenUsers: []string{},
|
||||
expectedUsers: map[string]string{"test": "test2"},
|
||||
realm: "traefiker",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Creates the temporary configuration file with the users
|
||||
usersFile, err := ioutil.TempFile("", "auth-users")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(usersFile.Name())
|
||||
|
||||
_, err = usersFile.Write([]byte(test.userFileContent))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Creates the configuration for our Authenticator
|
||||
authenticatorConfiguration := config.DigestAuth{
|
||||
Users: test.givenUsers,
|
||||
UsersFile: usersFile.Name(),
|
||||
Realm: test.realm,
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
authenticator, err := NewDigest(context.Background(), next, authenticatorConfiguration, "authName")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authenticator)
|
||||
defer ts.Close()
|
||||
|
||||
for userName, userPwd := range test.expectedUsers {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
digestRequest := newDigestRequest(userName, userPwd, http.DefaultClient)
|
||||
|
||||
var res *http.Response
|
||||
res, err = digestRequest.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
|
||||
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
// Checks that user foo doesn't work
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
digestRequest := newDigestRequest("foo", "foo", http.DefaultClient)
|
||||
|
||||
var res *http.Response
|
||||
res, err = digestRequest.Do(req)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
var body []byte
|
||||
body, err = ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotContains(t, "traefik", string(body))
|
||||
})
|
||||
}
|
||||
}
|
213
pkg/middlewares/auth/forward.go
Normal file
@@ -0,0 +1,213 @@
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
xForwardedURI = "X-Forwarded-Uri"
|
||||
xForwardedMethod = "X-Forwarded-Method"
|
||||
forwardedTypeName = "ForwardedAuthType"
|
||||
)
|
||||
|
||||
type forwardAuth struct {
|
||||
address string
|
||||
authResponseHeaders []string
|
||||
next http.Handler
|
||||
name string
|
||||
tlsConfig *tls.Config
|
||||
trustForwardHeader bool
|
||||
}
|
||||
|
||||
// NewForward creates a forward auth middleware.
|
||||
func NewForward(ctx context.Context, next http.Handler, config config.ForwardAuth, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, forwardedTypeName).Debug("Creating middleware")
|
||||
|
||||
fa := &forwardAuth{
|
||||
address: config.Address,
|
||||
authResponseHeaders: config.AuthResponseHeaders,
|
||||
next: next,
|
||||
name: name,
|
||||
trustForwardHeader: config.TrustForwardHeader,
|
||||
}
|
||||
|
||||
if config.TLS != nil {
|
||||
tlsConfig, err := config.TLS.CreateTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fa.tlsConfig = tlsConfig
|
||||
}
|
||||
|
||||
return fa, nil
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return fa.name, ext.SpanKindRPCClientEnum
|
||||
}
|
||||
|
||||
func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), fa.name, forwardedTypeName)
|
||||
|
||||
// Ensure our request client does not follow redirects
|
||||
httpClient := http.Client{
|
||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
if fa.tlsConfig != nil {
|
||||
httpClient.Transport = &http.Transport{
|
||||
TLSClientConfig: fa.tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
forwardReq, err := http.NewRequest(http.MethodGet, fa.address, nil)
|
||||
tracing.LogRequest(tracing.GetSpan(req), forwardReq)
|
||||
if err != nil {
|
||||
logMessage := fmt.Sprintf("Error calling %s. Cause %s", fa.address, err)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeHeader(req, forwardReq, fa.trustForwardHeader)
|
||||
|
||||
tracing.InjectRequestHeaders(forwardReq)
|
||||
|
||||
forwardResponse, forwardErr := httpClient.Do(forwardReq)
|
||||
if forwardErr != nil {
|
||||
logMessage := fmt.Sprintf("Error calling %s. Cause: %s", fa.address, forwardErr)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
defer forwardResponse.Body.Close()

body, readError := ioutil.ReadAll(forwardResponse.Body)
|
||||
if readError != nil {
|
||||
logMessage := fmt.Sprintf("Error reading body %s. Cause: %s", fa.address, readError)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Pass the forward response's body and selected headers if it
|
||||
// didn't return a response within the range of [200, 300).
|
||||
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
|
||||
logger.Debugf("Remote error %s. StatusCode: %d", fa.address, forwardResponse.StatusCode)
|
||||
|
||||
utils.CopyHeaders(rw.Header(), forwardResponse.Header)
|
||||
utils.RemoveHeaders(rw.Header(), forward.HopHeaders...)
|
||||
|
||||
// Grab the location header, if any.
|
||||
redirectURL, err := forwardResponse.Location()
|
||||
|
||||
if err != nil {
|
||||
if err != http.ErrNoLocation {
|
||||
logMessage := fmt.Sprintf("Error reading response location header %s. Cause: %s", fa.address, err)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else if redirectURL.String() != "" {
|
||||
// Set the location in our response if one was sent back.
|
||||
rw.Header().Set("Location", redirectURL.String())
|
||||
}
|
||||
|
||||
tracing.LogResponseCode(tracing.GetSpan(req), forwardResponse.StatusCode)
|
||||
rw.WriteHeader(forwardResponse.StatusCode)
|
||||
|
||||
if _, err = rw.Write(body); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, headerName := range fa.authResponseHeaders {
|
||||
req.Header.Set(headerName, forwardResponse.Header.Get(headerName))
|
||||
}
|
||||
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
fa.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
|
||||
utils.CopyHeaders(forwardReq.Header, req.Header)
|
||||
utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)
|
||||
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
if trustForwardHeader {
|
||||
if prior, ok := req.Header[forward.XForwardedFor]; ok {
|
||||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||||
}
|
||||
}
|
||||
forwardReq.Header.Set(forward.XForwardedFor, clientIP)
|
||||
}
|
||||
|
||||
xMethod := req.Header.Get(xForwardedMethod)
|
||||
switch {
|
||||
case xMethod != "" && trustForwardHeader:
|
||||
forwardReq.Header.Set(xForwardedMethod, xMethod)
|
||||
case req.Method != "":
|
||||
forwardReq.Header.Set(xForwardedMethod, req.Method)
|
||||
default:
|
||||
forwardReq.Header.Del(xForwardedMethod)
|
||||
}
|
||||
|
||||
xfp := req.Header.Get(forward.XForwardedProto)
|
||||
switch {
|
||||
case xfp != "" && trustForwardHeader:
|
||||
forwardReq.Header.Set(forward.XForwardedProto, xfp)
|
||||
case req.TLS != nil:
|
||||
forwardReq.Header.Set(forward.XForwardedProto, "https")
|
||||
default:
|
||||
forwardReq.Header.Set(forward.XForwardedProto, "http")
|
||||
}
|
||||
|
||||
if xfp := req.Header.Get(forward.XForwardedPort); xfp != "" && trustForwardHeader {
|
||||
forwardReq.Header.Set(forward.XForwardedPort, xfp)
|
||||
}
|
||||
|
||||
xfh := req.Header.Get(forward.XForwardedHost)
|
||||
switch {
|
||||
case xfh != "" && trustForwardHeader:
|
||||
forwardReq.Header.Set(forward.XForwardedHost, xfh)
|
||||
case req.Host != "":
|
||||
forwardReq.Header.Set(forward.XForwardedHost, req.Host)
|
||||
default:
|
||||
forwardReq.Header.Del(forward.XForwardedHost)
|
||||
}
|
||||
|
||||
xfURI := req.Header.Get(xForwardedURI)
|
||||
switch {
|
||||
case xfURI != "" && trustForwardHeader:
|
||||
forwardReq.Header.Set(xForwardedURI, xfURI)
|
||||
case req.URL.RequestURI() != "":
|
||||
forwardReq.Header.Set(xForwardedURI, req.URL.RequestURI())
|
||||
default:
|
||||
forwardReq.Header.Del(xForwardedURI)
|
||||
}
|
||||
}
|
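ServeHTTP above treats any non-2xx answer from the configured address as a denial and copies the auth server's body and headers back to the caller, while headers listed in AuthResponseHeaders are copied onto the forwarded request on success. A minimal sketch of such an auth server, with a purely hypothetical X-Api-Key check standing in for real credential validation:

package main

import (
	"log"
	"net/http"
)

func main() {
	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
		// Hypothetical check; any non-2xx status makes the middleware deny the request.
		if r.Header.Get("X-Api-Key") != "secret" {
			http.Error(w, "Forbidden", http.StatusForbidden)
			return
		}
		// A header like this can be copied back via AuthResponseHeaders.
		w.Header().Set("X-Auth-User", "demo")
		w.WriteHeader(http.StatusOK)
	})

	log.Fatal(http.ListenAndServe(":9000", nil))
}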
393
pkg/middlewares/auth/forward_test.go
Normal file
@@ -0,0 +1,393 @@
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
)
|
||||
|
||||
func TestForwardAuthFail(t *testing.T) {
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
middleware, err := NewForward(context.Background(), next, config.ForwardAuth{
|
||||
Address: server.URL,
|
||||
}, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Forbidden\n", string(body))
|
||||
}
|
||||
|
||||
func TestForwardAuthSuccess(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Auth-User", "user@example.com")
|
||||
w.Header().Set("X-Auth-Secret", "secret")
|
||||
fmt.Fprintln(w, "Success")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
|
||||
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.ForwardAuth{
|
||||
Address: server.URL,
|
||||
AuthResponseHeaders: []string{"X-Auth-User"},
|
||||
}
|
||||
middleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestForwardAuthRedirect(t *testing.T) {
|
||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
|
||||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.ForwardAuth{
|
||||
Address: authTs.URL,
|
||||
}
|
||||
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode)
|
||||
|
||||
location, err := res.Location()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://example.com/redirect-test", location.String())
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, string(body))
|
||||
}
|
||||
|
||||
func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
|
||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
headers := w.Header()
|
||||
for _, header := range forward.HopHeaders {
|
||||
if header == forward.TransferEncoding {
|
||||
headers.Add(header, "identity")
|
||||
} else {
|
||||
headers.Add(header, "test")
|
||||
}
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
|
||||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
auth := config.ForwardAuth{
|
||||
Address: authTs.URL,
|
||||
}
|
||||
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := client.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
|
||||
|
||||
for _, header := range forward.HopHeaders {
|
||||
assert.Equal(t, "", res.Header.Get(header), "hop-by-hop header '%s' mustn't be set", header)
|
||||
}
|
||||
|
||||
location, err := res.Location()
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.NotEmpty(t, string(body), "there should be something in the body")
|
||||
}
|
||||
|
||||
func TestForwardAuthFailResponseHeaders(t *testing.T) {
|
||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
|
||||
http.SetCookie(w, cookie)
|
||||
w.Header().Add("X-Foo", "bar")
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := config.ForwardAuth{
|
||||
Address: authTs.URL,
|
||||
}
|
||||
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(authMiddleware)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
|
||||
require.Len(t, res.Cookies(), 1)
|
||||
for _, cookie := range res.Cookies() {
|
||||
assert.Equal(t, "testing", cookie.Value)
|
||||
}
|
||||
|
||||
expectedHeaders := http.Header{
|
||||
"Content-Length": []string{"10"},
|
||||
"Content-Type": []string{"text/plain; charset=utf-8"},
|
||||
"X-Foo": []string{"bar"},
|
||||
"Set-Cookie": []string{"example=testing; Path=/"},
|
||||
"X-Content-Type-Options": []string{"nosniff"},
|
||||
}
|
||||
|
||||
assert.Len(t, res.Header, 6)
|
||||
for key, value := range expectedHeaders {
|
||||
assert.Equal(t, value, res.Header[key])
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
err = res.Body.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "Forbidden\n", string(body))
|
||||
}
|
||||
|
||||
func Test_writeHeader(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
trustForwardHeader bool
|
||||
emptyHost bool
|
||||
expectedHeaders map[string]string
|
||||
checkForUnexpectedHeaders bool
|
||||
}{
|
||||
{
|
||||
name: "trust Forward Header",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
trustForwardHeader: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not trust Forward Header",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "foo.bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "trust Forward Header with empty Host",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
trustForwardHeader: true,
|
||||
emptyHost: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not trust Forward Header with empty Host",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
emptyHost: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "trust Forward Header with forwarded URI",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
"X-Forwarded-Uri": "/forward?q=1",
|
||||
},
|
||||
trustForwardHeader: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
"X-Forwarded-Uri": "/forward?q=1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not trust Forward Header with forward requested URI",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
"X-Forwarded-Uri": "/forward?q=1",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "foo.bar",
|
||||
"X-Forwarded-Uri": "/path?q=1",
|
||||
},
|
||||
},
{
|
||||
name: "trust Forward Header with forwarded request Method",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Method": "OPTIONS",
|
||||
},
|
||||
trustForwardHeader: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-Method": "OPTIONS",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not trust Forward Header with forward request Method",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Method": "OPTIONS",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-Method": "GET",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove hop-by-hop headers",
|
||||
headers: map[string]string{
|
||||
forward.Connection: "Connection",
|
||||
forward.KeepAlive: "KeepAlive",
|
||||
forward.ProxyAuthenticate: "ProxyAuthenticate",
|
||||
forward.ProxyAuthorization: "ProxyAuthorization",
|
||||
forward.Te: "Te",
|
||||
forward.Trailers: "Trailers",
|
||||
forward.TransferEncoding: "TransferEncoding",
|
||||
forward.Upgrade: "Upgrade",
|
||||
"X-CustomHeader": "CustomHeader",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-CustomHeader": "CustomHeader",
|
||||
"X-Forwarded-Proto": "http",
|
||||
"X-Forwarded-Host": "foo.bar",
|
||||
"X-Forwarded-Uri": "/path?q=1",
|
||||
"X-Forwarded-Method": "GET",
|
||||
},
|
||||
checkForUnexpectedHeaders: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
|
||||
for key, value := range test.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
if test.emptyHost {
|
||||
req.Host = ""
|
||||
}
|
||||
|
||||
forwardReq := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
|
||||
|
||||
writeHeader(req, forwardReq, test.trustForwardHeader)
|
||||
|
||||
actualHeaders := forwardReq.Header
|
||||
expectedHeaders := test.expectedHeaders
|
||||
for key, value := range expectedHeaders {
|
||||
assert.Equal(t, value, actualHeaders.Get(key))
|
||||
actualHeaders.Del(key)
|
||||
}
|
||||
if test.checkForUnexpectedHeaders {
|
||||
for key := range actualHeaders {
|
||||
assert.Fail(t, "Unexpected header found", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
54
pkg/middlewares/buffering/buffering.go
Normal file
@@ -0,0 +1,54 @@
package buffering
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
oxybuffer "github.com/vulcand/oxy/buffer"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "Buffer"
|
||||
)
|
||||
|
||||
type buffer struct {
|
||||
name string
|
||||
buffer *oxybuffer.Buffer
|
||||
}
|
||||
|
||||
// New creates a buffering middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.Buffering, name string) (http.Handler, error) {
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
logger.Debug("Setting up buffering: request limits: %d (mem), %d (max), response limits: %d (mem), %d (max) with retry: '%s'",
|
||||
config.MemRequestBodyBytes, config.MaxRequestBodyBytes, config.MemResponseBodyBytes, config.MaxResponseBodyBytes, config.RetryExpression)
|
||||
|
||||
oxyBuffer, err := oxybuffer.New(
|
||||
next,
|
||||
oxybuffer.MemRequestBodyBytes(config.MemRequestBodyBytes),
|
||||
oxybuffer.MaxRequestBodyBytes(config.MaxRequestBodyBytes),
|
||||
oxybuffer.MemResponseBodyBytes(config.MemResponseBodyBytes),
|
||||
oxybuffer.MaxResponseBodyBytes(config.MaxResponseBodyBytes),
|
||||
oxybuffer.CondSetter(len(config.RetryExpression) > 0, oxybuffer.Retry(config.RetryExpression)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &buffer{
|
||||
name: name,
|
||||
buffer: oxyBuffer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *buffer) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return b.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (b *buffer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
b.buffer.ServeHTTP(rw, req)
|
||||
}
|
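A minimal sketch of wiring the buffering middleware defined above; the limits are arbitrary, and the RetryExpression syntax is assumed to follow oxy/buffer rather than being taken from this change:

package main

import (
	"context"
	"log"
	"net/http"

	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/middlewares/buffering"
)

func main() {
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	handler, err := buffering.New(context.Background(), next, config.Buffering{
		MaxRequestBodyBytes:  2 * 1024 * 1024, // reject request bodies over 2 MiB
		MemRequestBodyBytes:  256 * 1024,      // keep up to 256 KiB in memory before buffering elsewhere
		MaxResponseBodyBytes: 2 * 1024 * 1024,
		MemResponseBodyBytes: 256 * 1024,
		RetryExpression:      "IsNetworkError() && Attempts() < 2", // assumed oxy/buffer retry syntax
	}, "demo-buffering")
	if err != nil {
		log.Fatal(err)
	}

	log.Fatal(http.ListenAndServe(":8080", handler))
}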
26
pkg/middlewares/chain/chain.go
Normal file
@@ -0,0 +1,26 @@
package chain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "Chain"
|
||||
)
|
||||
|
||||
type chainBuilder interface {
|
||||
BuildChain(ctx context.Context, middlewares []string) *alice.Chain
|
||||
}
|
||||
|
||||
// New creates a chain middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.Chain, builder chainBuilder, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
middlewareChain := builder.BuildChain(ctx, config.Middlewares)
|
||||
return middlewareChain.Then(next)
|
||||
}
|
61
pkg/middlewares/circuitbreaker/circuit_breaker.go
Normal file
@@ -0,0 +1,61 @@
package circuitbreaker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/cbreaker"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "CircuitBreaker"
|
||||
)
|
||||
|
||||
type circuitBreaker struct {
|
||||
circuitBreaker *cbreaker.CircuitBreaker
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a new circuit breaker middleware.
|
||||
func New(ctx context.Context, next http.Handler, confCircuitBreaker config.CircuitBreaker, name string) (http.Handler, error) {
|
||||
expression := confCircuitBreaker.Expression
|
||||
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
logger.Debug("Setting up with expression: %s", expression)
|
||||
|
||||
oxyCircuitBreaker, err := cbreaker.New(next, expression, createCircuitBreakerOptions(expression))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &circuitBreaker{
|
||||
circuitBreaker: oxyCircuitBreaker,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createCircuitBreakerOptions returns a CircuitBreakerOption whose fallback answers 503 Service Unavailable
|
||||
func createCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
|
||||
return cbreaker.Fallback(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
tracing.SetErrorWithEvent(req, "blocked by circuit-breaker (%q)", expression)
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
|
||||
if _, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
|
||||
log.FromContext(req.Context()).Error(err)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func (c *circuitBreaker) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return c.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (c *circuitBreaker) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
middlewares.GetLogger(req.Context(), c.name, typeName).Debug("Entering middleware")
|
||||
c.circuitBreaker.ServeHTTP(rw, req)
|
||||
}
|
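A minimal sketch of wiring the circuit breaker defined above; the expression shown is an assumed example of the oxy/cbreaker syntax, not a value taken from this change:

package main

import (
	"context"
	"log"
	"net/http"

	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/middlewares/circuitbreaker"
)

func main() {
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// Once the expression matches, the fallback installed by createCircuitBreakerOptions
	// answers 503 Service Unavailable instead of calling next.
	handler, err := circuitbreaker.New(context.Background(), next,
		config.CircuitBreaker{Expression: "NetworkErrorRatio() > 0.30"}, "demo-cb")
	if err != nil {
		log.Fatal(err)
	}

	log.Fatal(http.ListenAndServe(":8080", handler))
}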
58
pkg/middlewares/compress/compress.go
Normal file
@@ -0,0 +1,58 @@
package compress
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "Compress"
|
||||
)
|
||||
|
||||
// compress is a middleware that gzip-compresses responses; gRPC traffic is passed through untouched.
|
||||
type compress struct {
|
||||
next http.Handler
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a new compress middleware.
|
||||
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
return &compress{
|
||||
next: next,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *compress) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
contentType := req.Header.Get("Content-Type")
|
||||
if strings.HasPrefix(contentType, "application/grpc") {
|
||||
c.next.ServeHTTP(rw, req)
|
||||
} else {
|
||||
gzipHandler(c.next, middlewares.GetLogger(req.Context(), c.name, typeName)).ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *compress) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return c.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func gzipHandler(h http.Handler, logger logrus.FieldLogger) http.Handler {
|
||||
wrapper, err := gziphandler.GzipHandlerWithOpts(
|
||||
gziphandler.CompressionLevel(gzip.DefaultCompression),
|
||||
gziphandler.MinSize(gziphandler.DefaultMinSize))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
|
||||
return wrapper(h)
|
||||
}
|
260
pkg/middlewares/compress/compress_test.go
Normal file
@@ -0,0 +1,260 @@
package compress
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
acceptEncodingHeader = "Accept-Encoding"
|
||||
contentEncodingHeader = "Content-Encoding"
|
||||
contentTypeHeader = "Content-Type"
|
||||
varyHeader = "Vary"
|
||||
gzipValue = "gzip"
|
||||
)
|
||||
|
||||
func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
baseBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write(baseBody)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
handler := &compress{next: next}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
|
||||
|
||||
if assert.ObjectsAreEqualValues(rw.Body.Bytes(), baseBody) {
|
||||
assert.Fail(t, "expected a compressed body", "got %v", rw.Body.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
fakeCompressedBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
_, err := rw.Write(fakeCompressedBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
handler := &compress{next: next}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
|
||||
|
||||
assert.EqualValues(t, rw.Body.Bytes(), fakeCompressedBody)
|
||||
}
|
||||
|
||||
func TestShouldNotCompressWhenNoAcceptEncodingHeader(t *testing.T) {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
|
||||
fakeBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write(fakeBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
handler := &compress{next: next}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
|
||||
assert.EqualValues(t, rw.Body.Bytes(), fakeBody)
|
||||
}
|
||||
|
||||
func TestShouldNotCompressWhenGRPC(t *testing.T) {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
req.Header.Add(contentTypeHeader, "application/grpc")
|
||||
|
||||
baseBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write(baseBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
handler := &compress{next: next}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req)
|
||||
|
||||
assert.Empty(t, rw.Header().Get(acceptEncodingHeader))
|
||||
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
|
||||
assert.EqualValues(t, rw.Body.Bytes(), baseBody)
|
||||
}
|
||||
|
||||
func TestIntegrationShouldNotCompress(t *testing.T) {
|
||||
fakeCompressedBody := generateBytes(100000)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "when content already compressed",
|
||||
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
_, err := rw.Write(fakeCompressedBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}),
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "when content already compressed and status code Created",
|
||||
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
_, err := rw.Write(fakeCompressedBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}),
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
compress := &compress{next: test.handler}
|
||||
ts := httptest.NewServer(compress)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, fakeCompressedBody, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldWriteHeaderWhenFlush(t *testing.T) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
rw.(http.Flusher).Flush()
|
||||
_, err := rw.Write([]byte("short"))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
handler := &compress{next: next}
|
||||
ts := httptest.NewServer(handler)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
|
||||
}
|
||||
|
||||
func TestIntegrationShouldCompress(t *testing.T) {
|
||||
fakeBody := generateBytes(100000)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "when AcceptEncoding header is present",
|
||||
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write(fakeBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}),
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "when AcceptEncoding header is present and status code Created",
|
||||
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
_, err := rw.Write(fakeBody)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}),
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
compress := &compress{next: test.handler}
|
||||
ts := httptest.NewServer(compress)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
if assert.ObjectsAreEqualValues(body, fakeBody) {
|
||||
assert.Fail(t, "expected a compressed body", "got %v", body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateBytes(length int) []byte {
|
||||
var value []byte
|
||||
for i := 0; i < length; i++ {
|
||||
value = append(value, 0x61+byte(i))
|
||||
}
|
||||
return value
|
||||
}
|
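The integration test above only asserts that the compressed body differs from the input. Below is a minimal, self-contained sketch (not part of this commit; helper name and package clause are illustrative) of a test helper that would also verify the decompressed content, using the standard library gzip package and the testify assertions already used in these tests.

package compressexample

import (
	"compress/gzip"
	"io/ioutil"
	"net/http"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// assertGzipRoundTrip decodes a gzip-encoded response body and compares it
// with the original payload.
func assertGzipRoundTrip(t *testing.T, resp *http.Response, want []byte) {
	t.Helper()

	// The response must announce gzip before we try to decode it.
	require.Equal(t, "gzip", resp.Header.Get("Content-Encoding"))

	zr, err := gzip.NewReader(resp.Body)
	require.NoError(t, err)
	defer zr.Close()

	got, err := ioutil.ReadAll(zr)
	require.NoError(t, err)
	assert.Equal(t, want, got)
}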
248
pkg/middlewares/customerrors/custom_errors.go
Normal file
|
@ -0,0 +1,248 @@
|
|||
package customerrors
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// Compile time validation that the response recorder implements http interfaces correctly.
|
||||
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{}
|
||||
|
||||
const (
|
||||
typeName = "customError"
|
||||
backendURL = "http://0.0.0.0"
|
||||
)
|
||||
|
||||
type serviceBuilder interface {
|
||||
BuildHTTP(ctx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error)
|
||||
}
|
||||
|
||||
// customErrors is a middleware that provides the custom error pages.
|
||||
type customErrors struct {
|
||||
name string
|
||||
next http.Handler
|
||||
backendHandler http.Handler
|
||||
httpCodeRanges types.HTTPCodeRanges
|
||||
backendQuery string
|
||||
}
|
||||
|
||||
// New creates a new custom error pages middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.ErrorPage, serviceBuilder serviceBuilder, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
httpCodeRanges, err := types.NewHTTPCodeRanges(config.Status)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backend, err := serviceBuilder.BuildHTTP(ctx, config.Service, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &customErrors{
|
||||
name: name,
|
||||
next: next,
|
||||
backendHandler: backend,
|
||||
httpCodeRanges: httpCodeRanges,
|
||||
backendQuery: config.Query,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *customErrors) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return c.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (c *customErrors) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), c.name, typeName)
|
||||
|
||||
if c.backendHandler == nil {
|
||||
logger.Error("Error pages: no backend handler.")
|
||||
tracing.SetErrorWithEvent(req, "Error pages: no backend handler.")
|
||||
c.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
recorder := newResponseRecorder(rw, logger)
|
||||
c.next.ServeHTTP(recorder, req)
|
||||
|
||||
// check the recorder code against the configured http status code ranges
|
||||
for _, block := range c.httpCodeRanges {
|
||||
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
|
||||
logger.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
|
||||
|
||||
var query string
|
||||
if len(c.backendQuery) > 0 {
|
||||
query = "/" + strings.TrimPrefix(c.backendQuery, "/")
|
||||
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
|
||||
}
|
||||
|
||||
pageReq, err := newRequest(backendURL + query)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
rw.WriteHeader(recorder.GetCode())
|
||||
_, err = fmt.Fprint(rw, http.StatusText(recorder.GetCode()))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
recorderErrorPage := newResponseRecorder(rw, logger)
|
||||
utils.CopyHeaders(pageReq.Header, req.Header)
|
||||
|
||||
c.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
|
||||
|
||||
utils.CopyHeaders(rw.Header(), recorderErrorPage.Header())
|
||||
rw.WriteHeader(recorder.GetCode())
|
||||
|
||||
if _, err = rw.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// did not catch a configured status code so proceed with the request
|
||||
utils.CopyHeaders(rw.Header(), recorder.Header())
|
||||
rw.WriteHeader(recorder.GetCode())
|
||||
_, err := rw.Write(recorder.GetBody().Bytes())
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func newRequest(baseURL string) (*http.Request, error) {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error pages: error when parse URL: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error pages: error when create query: %v", err)
|
||||
}
|
||||
|
||||
req.RequestURI = u.RequestURI()
|
||||
return req, nil
|
||||
}
|
||||
|
||||
type responseRecorder interface {
|
||||
http.ResponseWriter
|
||||
http.Flusher
|
||||
GetCode() int
|
||||
GetBody() *bytes.Buffer
|
||||
IsStreamingResponseStarted() bool
|
||||
}
|
||||
|
||||
// newResponseRecorder returns an initialized responseRecorder.
|
||||
func newResponseRecorder(rw http.ResponseWriter, logger logrus.FieldLogger) responseRecorder {
|
||||
recorder := &responseRecorderWithoutCloseNotify{
|
||||
HeaderMap: make(http.Header),
|
||||
Body: new(bytes.Buffer),
|
||||
Code: http.StatusOK,
|
||||
responseWriter: rw,
|
||||
logger: logger,
|
||||
}
|
||||
if _, ok := rw.(http.CloseNotifier); ok {
|
||||
return &responseRecorderWithCloseNotify{recorder}
|
||||
}
|
||||
return recorder
|
||||
}
|
||||
|
||||
// responseRecorderWithoutCloseNotify is an implementation of http.ResponseWriter that
|
||||
// records its mutations for later inspection.
|
||||
type responseRecorderWithoutCloseNotify struct {
|
||||
Code int // the HTTP response code from WriteHeader
|
||||
HeaderMap http.Header // the HTTP response headers
|
||||
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
|
||||
|
||||
responseWriter http.ResponseWriter
|
||||
err error
|
||||
streamingResponseStarted bool
|
||||
logger logrus.FieldLogger
|
||||
}
|
||||
|
||||
type responseRecorderWithCloseNotify struct {
|
||||
*responseRecorderWithoutCloseNotify
|
||||
}
|
||||
|
||||
// CloseNotify returns a channel that receives at most a
|
||||
// single value (true) when the client connection has gone away.
|
||||
func (r *responseRecorderWithCloseNotify) CloseNotify() <-chan bool {
|
||||
return r.responseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
||||
// Header returns the response headers.
|
||||
func (r *responseRecorderWithoutCloseNotify) Header() http.Header {
|
||||
if r.HeaderMap == nil {
|
||||
r.HeaderMap = make(http.Header)
|
||||
}
|
||||
|
||||
return r.HeaderMap
|
||||
}
|
||||
|
||||
func (r *responseRecorderWithoutCloseNotify) GetCode() int {
|
||||
return r.Code
|
||||
}
|
||||
|
||||
func (r *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer {
|
||||
return r.Body
|
||||
}
|
||||
|
||||
func (r *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool {
|
||||
return r.streamingResponseStarted
|
||||
}
|
||||
|
||||
// Write always succeeds and writes to rw.Body, if not nil.
|
||||
func (r *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) {
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
return r.Body.Write(buf)
|
||||
}
|
||||
|
||||
// WriteHeader sets rw.Code.
|
||||
func (r *responseRecorderWithoutCloseNotify) WriteHeader(code int) {
|
||||
r.Code = code
|
||||
}
|
||||
|
||||
// Hijack hijacks the connection
|
||||
func (r *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return r.responseWriter.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client.
|
||||
func (r *responseRecorderWithoutCloseNotify) Flush() {
|
||||
if !r.streamingResponseStarted {
|
||||
utils.CopyHeaders(r.responseWriter.Header(), r.Header())
|
||||
r.responseWriter.WriteHeader(r.Code)
|
||||
r.streamingResponseStarted = true
|
||||
}
|
||||
|
||||
_, err := r.responseWriter.Write(r.Body.Bytes())
|
||||
if err != nil {
|
||||
r.logger.Errorf("Error writing response in responseRecorder: %v", err)
|
||||
r.err = err
|
||||
}
|
||||
r.Body.Reset()
|
||||
|
||||
if flusher, ok := r.responseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
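For reference, a small self-contained sketch (standard library only; function and package names are illustrative, not part of this commit) of the {status} query substitution that ServeHTTP performs before calling the error-page backend:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// buildErrorPageQuery mirrors the {status} substitution done in ServeHTTP:
// the configured query is normalised to start with "/" and the {status}
// placeholder is replaced by the recorded status code.
func buildErrorPageQuery(backendQuery string, code int) string {
	if backendQuery == "" {
		return ""
	}
	query := "/" + strings.TrimPrefix(backendQuery, "/")
	return strings.Replace(query, "{status}", strconv.Itoa(code), -1)
}

func main() {
	fmt.Println(buildErrorPageQuery("/{status}", 503)) // /503
	fmt.Println(buildErrorPageQuery("oops.html", 500)) // /oops.html
}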
176
pkg/middlewares/customerrors/custom_errors_test.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
package customerrors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
errorPage *config.ErrorPage
|
||||
backendCode int
|
||||
backendErrorHandler http.HandlerFunc
|
||||
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
desc: "no error",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusOK,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "in the range",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusInternalServerError,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My error page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "not in the range",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusBadGateway,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
|
||||
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "query replacement",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503-503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Single code",
|
||||
errorPage: &config.ErrorPage{Service: "error", Query: "/{status}", Status: []string{"503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
serviceBuilderMock := &mockServiceBuilder{handler: test.backendErrorHandler}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(test.backendCode)
|
||||
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
||||
})
|
||||
errorPageHandler, err := New(context.Background(), handler, *test.errorPage, serviceBuilderMock, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
errorPageHandler.ServeHTTP(recorder, req)
|
||||
|
||||
test.validate(t, recorder)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockServiceBuilder struct {
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
func (m *mockServiceBuilder) BuildHTTP(_ context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
return m.handler, nil
|
||||
}
|
||||
|
||||
func TestNewResponseRecorder(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
rw http.ResponseWriter
|
||||
expected http.ResponseWriter
|
||||
}{
|
||||
{
|
||||
desc: "Without Close Notify",
|
||||
rw: httptest.NewRecorder(),
|
||||
expected: &responseRecorderWithoutCloseNotify{},
|
||||
},
|
||||
{
|
||||
desc: "With Close Notify",
|
||||
rw: &mockRWCloseNotify{},
|
||||
expected: &responseRecorderWithCloseNotify{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := newResponseRecorder(test.rw, middlewares.GetLogger(context.Background(), "test", typeName))
|
||||
assert.IsType(t, test.expected, rec)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockRWCloseNotify struct{}
|
||||
|
||||
func (m *mockRWCloseNotify) CloseNotify() <-chan bool {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) Header() http.Header {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) Write([]byte) (int, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) WriteHeader(int) {
|
||||
panic("implement me")
|
||||
}
|
33
pkg/middlewares/emptybackendhandler/empty_backend_handler.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package emptybackendhandler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/healthcheck"
|
||||
)
|
||||
|
||||
// EmptyBackend is a middleware that checks whether the current Backend
|
||||
// has at least one active Server with respect to the healthchecks and if this
|
||||
// is not the case, it will stop the middleware chain and respond with 503.
|
||||
type emptyBackend struct {
|
||||
next healthcheck.BalancerHandler
|
||||
}
|
||||
|
||||
// New creates a new EmptyBackend middleware.
|
||||
func New(lb healthcheck.BalancerHandler) http.Handler {
|
||||
return &emptyBackend{next: lb}
|
||||
}
|
||||
|
||||
// ServeHTTP responds with 503 when there is no active Server and otherwise
|
||||
// invokes the next handler in the middleware chain.
|
||||
func (e *emptyBackend) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if len(e.next.Servers()) == 0 {
|
||||
rw.WriteHeader(http.StatusServiceUnavailable)
|
||||
_, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
} else {
|
||||
e.next.ServeHTTP(rw, req)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
package emptybackendhandler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
func TestEmptyBackendHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
amountServer int
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
amountServer: 0,
|
||||
expectedStatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
{
|
||||
amountServer: 1,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := New(&healthCheckLoadBalancer{amountServer: test.amountServer})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, recorder.Result().StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type healthCheckLoadBalancer struct {
|
||||
amountServer int
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) Servers() []*url.URL {
|
||||
servers := make([]*url.URL, 0, lb.amountServer)
|
||||
for i := 0; i < lb.amountServer; i++ {
|
||||
servers = append(servers, testhelpers.MustParseURL("http://localhost"))
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) Next() http.Handler {
|
||||
return nil
|
||||
}
|
51
pkg/middlewares/forwardedheaders/forwarded_header.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package forwardedheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/ip"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// XForwarded is a filter for the X-Forwarded-* headers.
|
||||
type XForwarded struct {
|
||||
insecure bool
|
||||
trustedIps []string
|
||||
ipChecker *ip.Checker
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// NewXForwarded creates a new XForwarded.
|
||||
func NewXForwarded(insecure bool, trustedIps []string, next http.Handler) (*XForwarded, error) {
|
||||
var ipChecker *ip.Checker
|
||||
if len(trustedIps) > 0 {
|
||||
var err error
|
||||
ipChecker, err = ip.NewChecker(trustedIps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &XForwarded{
|
||||
insecure: insecure,
|
||||
trustedIps: trustedIps,
|
||||
ipChecker: ipChecker,
|
||||
next: next,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (x *XForwarded) isTrustedIP(ip string) bool {
|
||||
if x.ipChecker == nil {
|
||||
return false
|
||||
}
|
||||
return x.ipChecker.IsAuthorized(ip) == nil
|
||||
}
|
||||
|
||||
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
|
||||
utils.RemoveHeaders(r.Header, forward.XHeaders...)
|
||||
}
|
||||
|
||||
x.next.ServeHTTP(w, r)
|
||||
}
|
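A hedged usage sketch of the filter above (import path taken from this diff; everything else is illustrative): with insecure=false and no trusted IPs, incoming X-Forwarded-* headers are removed before the next handler runs.

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/containous/traefik/pkg/middlewares/forwardedheaders"
)

func main() {
	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		// With insecure=false and no trusted IPs, the header is empty here.
		fmt.Fprintf(rw, "X-Forwarded-For: %q\n", req.Header.Get("X-Forwarded-For"))
	})

	xf, err := forwardedheaders.NewXForwarded(false, nil, next)
	if err != nil {
		panic(err)
	}

	req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
	req.Header.Set("X-Forwarded-For", "10.0.1.1")
	req.RemoteAddr = "192.0.2.10:1234"

	rec := httptest.NewRecorder()
	xf.ServeHTTP(rec, req)
	fmt.Print(rec.Body.String())
}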
128
pkg/middlewares/forwardedheaders/forwarded_header_test.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package forwardedheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
insecure bool
|
||||
trustedIps []string
|
||||
incomingHeaders map[string]string
|
||||
remoteAddr string
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
desc: "all Empty",
|
||||
insecure: true,
|
||||
trustedIps: nil,
|
||||
remoteAddr: "",
|
||||
incomingHeaders: map[string]string{},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure true with incoming X-Forwarded-For",
|
||||
insecure: true,
|
||||
trustedIps: nil,
|
||||
remoteAddr: "",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For",
|
||||
insecure: false,
|
||||
trustedIps: nil,
|
||||
remoteAddr: "",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips",
|
||||
insecure: false,
|
||||
trustedIps: []string{"10.0.1.100"},
|
||||
remoteAddr: "10.0.1.100:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips",
|
||||
insecure: false,
|
||||
trustedIps: []string{"10.0.1.100"},
|
||||
remoteAddr: "10.0.1.101:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips CIDR",
|
||||
insecure: false,
|
||||
trustedIps: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "1.2.3.156:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips CIDR",
|
||||
insecure: false,
|
||||
trustedIps: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "10.0.1.101:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.RemoteAddr = test.remoteAddr
|
||||
|
||||
for k, v := range test.incomingHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
m, err := NewXForwarded(test.insecure, test.trustedIps, http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {}))
|
||||
require.NoError(t, err)
|
||||
|
||||
m.ServeHTTP(nil, req)
|
||||
|
||||
for k, v := range test.expectedHeaders {
|
||||
assert.Equal(t, v, req.Header.Get(k))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
35
pkg/middlewares/handler_switcher.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/safe"
|
||||
)
|
||||
|
||||
// HTTPHandlerSwitcher allows hot switching of http.ServeMux
|
||||
type HTTPHandlerSwitcher struct {
|
||||
handler *safe.Safe
|
||||
}
|
||||
|
||||
// NewHandlerSwitcher builds a new instance of HTTPHandlerSwitcher
|
||||
func NewHandlerSwitcher(newHandler http.Handler) (hs *HTTPHandlerSwitcher) {
|
||||
return &HTTPHandlerSwitcher{
|
||||
handler: safe.New(newHandler),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HTTPHandlerSwitcher) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
handlerBackup := h.handler.Get().(http.Handler)
|
||||
handlerBackup.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
// GetHandler returns the current http.ServeMux
|
||||
func (h *HTTPHandlerSwitcher) GetHandler() (newHandler http.Handler) {
|
||||
handler := h.handler.Get().(http.Handler)
|
||||
return handler
|
||||
}
|
||||
|
||||
// UpdateHandler safely updates the current http.ServeMux with a new one
|
||||
func (h *HTTPHandlerSwitcher) UpdateHandler(newHandler http.Handler) {
|
||||
h.handler.Set(newHandler)
|
||||
}
|
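A short sketch of how the switcher can be used to hot-swap handlers at runtime (import path as in this diff; the surrounding program is illustrative only):

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/containous/traefik/pkg/middlewares"
)

func main() {
	v1 := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { fmt.Fprint(rw, "v1") })
	v2 := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { fmt.Fprint(rw, "v2") })

	switcher := middlewares.NewHandlerSwitcher(v1)

	rec := httptest.NewRecorder()
	switcher.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	fmt.Println(rec.Body.String()) // v1

	// Requests served after the swap see the new handler; requests already
	// in flight keep the handler they read before the swap.
	switcher.UpdateHandler(v2)

	rec = httptest.NewRecorder()
	switcher.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
	fmt.Println(rec.Body.String()) // v2
}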
134
pkg/middlewares/headers/headers.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
// Package headers implements middleware based on https://github.com/unrolled/secure.
|
||||
package headers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/unrolled/secure"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "Headers"
|
||||
)
|
||||
|
||||
type headers struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
// New creates a Headers middleware.
|
||||
func New(ctx context.Context, next http.Handler, config config.Headers, name string) (http.Handler, error) {
|
||||
// HeaderMiddleware -> SecureMiddleWare -> next
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
|
||||
if !config.HasSecureHeadersDefined() && !config.HasCustomHeadersDefined() {
|
||||
return nil, errors.New("headers configuration not valid")
|
||||
}
|
||||
|
||||
var handler http.Handler
|
||||
nextHandler := next
|
||||
|
||||
if config.HasSecureHeadersDefined() {
|
||||
logger.Debug("Setting up secureHeaders from %v", config)
|
||||
handler = newSecure(next, config)
|
||||
nextHandler = handler
|
||||
}
|
||||
|
||||
if config.HasCustomHeadersDefined() {
|
||||
logger.Debug("Setting up customHeaders from %v", config)
|
||||
handler = newHeader(nextHandler, config)
|
||||
}
|
||||
|
||||
return &headers{
|
||||
handler: handler,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *headers) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return h.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (h *headers) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
h.handler.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
type secureHeader struct {
|
||||
next http.Handler
|
||||
secure *secure.Secure
|
||||
}
|
||||
|
||||
// newSecure constructs a new secure instance with supplied options.
|
||||
func newSecure(next http.Handler, headers config.Headers) *secureHeader {
|
||||
opt := secure.Options{
|
||||
BrowserXssFilter: headers.BrowserXSSFilter,
|
||||
ContentTypeNosniff: headers.ContentTypeNosniff,
|
||||
ForceSTSHeader: headers.ForceSTSHeader,
|
||||
FrameDeny: headers.FrameDeny,
|
||||
IsDevelopment: headers.IsDevelopment,
|
||||
SSLRedirect: headers.SSLRedirect,
|
||||
SSLForceHost: headers.SSLForceHost,
|
||||
SSLTemporaryRedirect: headers.SSLTemporaryRedirect,
|
||||
STSIncludeSubdomains: headers.STSIncludeSubdomains,
|
||||
STSPreload: headers.STSPreload,
|
||||
ContentSecurityPolicy: headers.ContentSecurityPolicy,
|
||||
CustomBrowserXssValue: headers.CustomBrowserXSSValue,
|
||||
CustomFrameOptionsValue: headers.CustomFrameOptionsValue,
|
||||
PublicKey: headers.PublicKey,
|
||||
ReferrerPolicy: headers.ReferrerPolicy,
|
||||
SSLHost: headers.SSLHost,
|
||||
AllowedHosts: headers.AllowedHosts,
|
||||
HostsProxyHeaders: headers.HostsProxyHeaders,
|
||||
SSLProxyHeaders: headers.SSLProxyHeaders,
|
||||
STSSeconds: headers.STSSeconds,
|
||||
}
|
||||
|
||||
return &secureHeader{
|
||||
next: next,
|
||||
secure: secure.New(opt),
|
||||
}
|
||||
}
|
||||
|
||||
func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, s.next.ServeHTTP)
|
||||
}
|
||||
|
||||
// header is a middleware that helps set up a few basic security features. A single config.Headers struct can be
|
||||
// provided to configure which features should be enabled, and to override a few of the default values.
|
||||
type header struct {
|
||||
next http.Handler
|
||||
// If Custom request headers are set, these will be added to the request
|
||||
customRequestHeaders map[string]string
|
||||
}
|
||||
|
||||
// newHeader constructs a new header instance from the supplied header configuration.
|
||||
func newHeader(next http.Handler, headers config.Headers) *header {
|
||||
return &header{
|
||||
next: next,
|
||||
customRequestHeaders: headers.CustomRequestHeaders,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
s.modifyRequestHeaders(req)
|
||||
s.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
// modifyRequestHeaders set or delete request headers.
|
||||
func (s *header) modifyRequestHeaders(req *http.Request) {
|
||||
// Loop through Custom request headers
|
||||
for header, value := range s.customRequestHeaders {
|
||||
if value == "" {
|
||||
req.Header.Del(header)
|
||||
} else {
|
||||
req.Header.Set(header, value)
|
||||
}
|
||||
}
|
||||
}
|
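A hedged usage sketch of the exported constructor (import paths as shown in this diff; the header name and value are illustrative), showing the custom-request-header branch:

package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/middlewares/headers"
)

func main() {
	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		fmt.Fprint(rw, req.Header.Get("X-Custom-Request-Header"))
	})

	h, err := headers.New(context.Background(), next, config.Headers{
		CustomRequestHeaders: map[string]string{
			"X-Custom-Request-Header": "test_request",
		},
	}, "example")
	if err != nil {
		panic(err)
	}

	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "http://localhost/foo", nil))
	fmt.Println(rec.Body.String()) // test_request
}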
190
pkg/middlewares/headers/headers_test.go
Normal file
|
@ -0,0 +1,190 @@
|
|||
package headers
|
||||
|
||||
// Middleware tests based on https://github.com/unrolled/secure
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomRequestHeader(t *testing.T) {
|
||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
header := newHeader(emptyHandler, config.Headers{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "test_request",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
|
||||
}
|
||||
|
||||
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
|
||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
header := newHeader(emptyHandler, config.Headers{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "test_request",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"))
|
||||
|
||||
header = newHeader(emptyHandler, config.Headers{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "",
|
||||
},
|
||||
})
|
||||
|
||||
header.ServeHTTP(res, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code)
|
||||
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"))
|
||||
}
|
||||
|
||||
func TestSecureHeader(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fromHost string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "Should accept the request when given a host that is in the list",
|
||||
fromHost: "foo.com",
|
||||
expected: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "Should refuse the request when no host is given",
|
||||
fromHost: "",
|
||||
expected: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
desc: "Should refuse the request when no matching host is given",
|
||||
fromHost: "boo.com",
|
||||
expected: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
emptyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
header, err := New(context.Background(), emptyHandler, config.Headers{
|
||||
AllowedHosts: []string{"foo.com", "bar.com"},
|
||||
}, "foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
req.Host = test.fromHost
|
||||
header.ServeHTTP(res, req)
|
||||
assert.Equal(t, test.expected, res.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSLForceHost(t *testing.T) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
_, _ = rw.Write([]byte("OK"))
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
host string
|
||||
secureMiddleware *secureHeader
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "http should return a 301",
|
||||
host: "http://powpow.example.com",
|
||||
secureMiddleware: newSecure(next, config.Headers{
|
||||
SSLRedirect: true,
|
||||
SSLForceHost: true,
|
||||
SSLHost: "powpow.example.com",
|
||||
}),
|
||||
expected: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "http sub domain should return a 301",
|
||||
host: "http://www.powpow.example.com",
|
||||
secureMiddleware: newSecure(next, config.Headers{
|
||||
SSLRedirect: true,
|
||||
SSLForceHost: true,
|
||||
SSLHost: "powpow.example.com",
|
||||
}),
|
||||
expected: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "https should return a 200",
|
||||
host: "https://powpow.example.com",
|
||||
secureMiddleware: newSecure(next, config.Headers{
|
||||
SSLRedirect: true,
|
||||
SSLForceHost: true,
|
||||
SSLHost: "powpow.example.com",
|
||||
}),
|
||||
expected: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "https sub domain should return a 301",
|
||||
host: "https://www.powpow.example.com",
|
||||
secureMiddleware: newSecure(next, config.Headers{
|
||||
SSLRedirect: true,
|
||||
SSLForceHost: true,
|
||||
SSLHost: "powpow.example.com",
|
||||
}),
|
||||
expected: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "http without force host and sub domain should return a 301",
|
||||
host: "http://www.powpow.example.com",
|
||||
secureMiddleware: newSecure(next, config.Headers{
|
||||
SSLRedirect: true,
|
||||
SSLForceHost: false,
|
||||
SSLHost: "powpow.example.com",
|
||||
}),
|
||||
expected: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "https without force host and sub domain should return a 301",
|
||||
host: "https://www.powpow.example.com",
|
||||
secureMiddleware: newSecure(next, config.Headers{
|
||||
SSLRedirect: true,
|
||||
SSLForceHost: false,
|
||||
SSLHost: "powpow.example.com",
|
||||
}),
|
||||
expected: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, test.host, nil)
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
test.secureMiddleware.ServeHTTP(rw, req)
|
||||
|
||||
assert.Equal(t, test.expected, rw.Result().StatusCode)
|
||||
})
|
||||
}
|
||||
}
|
85
pkg/middlewares/ipwhitelist/ip_whitelist.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package ipwhitelist
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/ip"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "IPWhiteLister"
|
||||
)
|
||||
|
||||
// ipWhiteLister is a middleware that checks the requesting IP against a set of whitelisted CIDR ranges.
|
||||
type ipWhiteLister struct {
|
||||
next http.Handler
|
||||
whiteLister *ip.Checker
|
||||
strategy ip.Strategy
|
||||
name string
|
||||
}
|
||||
|
||||
// New builds a new IPWhiteLister given a list of CIDR strings to whitelist.
|
||||
func New(ctx context.Context, next http.Handler, config config.IPWhiteList, name string) (http.Handler, error) {
|
||||
logger := middlewares.GetLogger(ctx, name, typeName)
|
||||
logger.Debug("Creating middleware")
|
||||
|
||||
if len(config.SourceRange) == 0 {
|
||||
return nil, errors.New("sourceRange is empty, IPWhiteLister not created")
|
||||
}
|
||||
|
||||
checker, err := ip.NewChecker(config.SourceRange)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse CIDR whitelist %s: %v", config.SourceRange, err)
|
||||
}
|
||||
|
||||
strategy, err := config.IPStrategy.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Debugf("Setting up IPWhiteLister with sourceRange: %s", config.SourceRange)
|
||||
return &ipWhiteLister{
|
||||
strategy: strategy,
|
||||
whiteLister: checker,
|
||||
next: next,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (wl *ipWhiteLister) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return wl.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (wl *ipWhiteLister) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), wl.name, typeName)
|
||||
|
||||
err := wl.whiteLister.IsAuthorized(wl.strategy.GetIP(req))
|
||||
if err != nil {
|
||||
logMessage := fmt.Sprintf("rejecting request %+v: %v", req, err)
|
||||
logger.Debug(logMessage)
|
||||
tracing.SetErrorWithEvent(req, logMessage)
|
||||
reject(logger, rw)
|
||||
return
|
||||
}
|
||||
logger.Debugf("Accept %s: %+v", wl.strategy.GetIP(req), req)
|
||||
|
||||
wl.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func reject(logger logrus.FieldLogger, rw http.ResponseWriter) {
|
||||
statusCode := http.StatusForbidden
|
||||
|
||||
rw.WriteHeader(statusCode)
|
||||
_, err := rw.Write([]byte(http.StatusText(statusCode)))
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
100
pkg/middlewares/ipwhitelist/ip_whitelist_test.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package ipwhitelist
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewIPWhiteLister(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whiteList config.IPWhiteList
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
desc: "invalid IP",
|
||||
whiteList: config.IPWhiteList{
|
||||
SourceRange: []string{"foo"},
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
desc: "valid IP",
|
||||
whiteList: config.IPWhiteList{
|
||||
SourceRange: []string{"10.10.10.10"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
|
||||
|
||||
if test.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, whiteLister)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPWhiteLister_ServeHTTP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whiteList config.IPWhiteList
|
||||
remoteAddr string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "authorized with remote address",
|
||||
whiteList: config.IPWhiteList{
|
||||
SourceRange: []string{"20.20.20.20"},
|
||||
},
|
||||
remoteAddr: "20.20.20.20:1234",
|
||||
expected: 200,
|
||||
},
|
||||
{
|
||||
desc: "non authorized with remote address",
|
||||
whiteList: config.IPWhiteList{
|
||||
SourceRange: []string{"20.20.20.20"},
|
||||
},
|
||||
remoteAddr: "20.20.20.21:1234",
|
||||
expected: 403,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
whiteLister, err := New(context.Background(), next, test.whiteList, "traefikTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://10.10.10.10", nil)
|
||||
|
||||
if len(test.remoteAddr) > 0 {
|
||||
req.RemoteAddr = test.remoteAddr
|
||||
}
|
||||
|
||||
whiteLister.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, test.expected, recorder.Code)
|
||||
})
|
||||
}
|
||||
}
|
48
pkg/middlewares/maxconnection/max_connection.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package maxconnection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/connlimit"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
typeName = "MaxConnection"
|
||||
)
|
||||
|
||||
type maxConnection struct {
|
||||
handler http.Handler
|
||||
name string
|
||||
}
|
||||
|
||||
// New creates a max connection middleware.
|
||||
func New(ctx context.Context, next http.Handler, maxConns config.MaxConn, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
extractFunc, err := utils.NewExtractor(maxConns.ExtractorFunc)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
handler, err := connlimit.New(next, extractFunc, maxConns.Amount)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating connection limit: %v", err)
|
||||
}
|
||||
|
||||
return &maxConnection{handler: handler, name: name}, nil
|
||||
}
|
||||
|
||||
func (mc *maxConnection) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return mc.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (mc *maxConnection) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
mc.handler.ServeHTTP(rw, req)
|
||||
}
|
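A hedged usage sketch (import paths as shown in this diff; the ExtractorFunc value is an assumption about what oxy's utils.NewExtractor accepts, e.g. "request.host"):

package main

import (
	"context"
	"log"
	"net/http"

	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/middlewares/maxconnection"
)

func main() {
	next := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusOK)
	})

	// At most 10 in-flight requests per extracted key; requests above the
	// limit are rejected by the underlying oxy connlimit handler.
	mc, err := maxconnection.New(context.Background(), next, config.MaxConn{
		Amount:        10,
		ExtractorFunc: "request.host", // assumed to be a valid oxy extractor spec
	}, "example")
	if err != nil {
		log.Fatal(err)
	}

	log.Fatal(http.ListenAndServe("127.0.0.1:8080", mc))
}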
13
pkg/middlewares/middleware.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GetLogger creates a logger configured with the middleware fields.
|
||||
func GetLogger(ctx context.Context, middleware string, middlewareType string) logrus.FieldLogger {
|
||||
return log.FromContext(ctx).WithField(log.MiddlewareName, middleware).WithField(log.MiddlewareType, middlewareType)
|
||||
}
|
294
pkg/middlewares/passtlsclientcert/pass_tls_client_cert.go
Normal file
|
@ -0,0 +1,294 @@
|
|||
package passtlsclientcert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/log"
|
||||
"github.com/containous/traefik/pkg/middlewares"
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
|
||||
xForwardedTLSClientCertInfo = "X-Forwarded-Tls-Client-Cert-Info"
|
||||
typeName = "PassClientTLSCert"
|
||||
)
|
||||
|
||||
var attributeTypeNames = map[string]string{
|
||||
"0.9.2342.19200300.100.1.25": "DC", // Domain component OID - RFC 2247
|
||||
}
|
||||
|
||||
// DistinguishedNameOptions is a struct for specifying the configuration for the distinguished name info.
|
||||
type DistinguishedNameOptions struct {
|
||||
CommonName bool
|
||||
CountryName bool
|
||||
DomainComponent bool
|
||||
LocalityName bool
|
||||
OrganizationName bool
|
||||
SerialNumber bool
|
||||
StateOrProvinceName bool
|
||||
}
|
||||
|
||||
func newDistinguishedNameOptions(info *config.TLSCLientCertificateDNInfo) *DistinguishedNameOptions {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &DistinguishedNameOptions{
|
||||
CommonName: info.CommonName,
|
||||
CountryName: info.Country,
|
||||
DomainComponent: info.DomainComponent,
|
||||
LocalityName: info.Locality,
|
||||
OrganizationName: info.Organization,
|
||||
SerialNumber: info.SerialNumber,
|
||||
StateOrProvinceName: info.Province,
|
||||
}
|
||||
}
|
||||
|
||||
// passTLSClientCert is a middleware that helps set up a few TLS info features.
|
||||
type passTLSClientCert struct {
|
||||
next http.Handler
|
||||
name string
|
||||
pem bool // pass the sanitized pem to the backend in a specific header
|
||||
info *tlsClientCertificateInfo // pass selected information from the client certificate
|
||||
}
|
||||
|
||||
// New constructs a new passTLSClientCert instance from the supplied configuration.
|
||||
func New(ctx context.Context, next http.Handler, config config.PassTLSClientCert, name string) (http.Handler, error) {
|
||||
middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")
|
||||
|
||||
return &passTLSClientCert{
|
||||
next: next,
|
||||
name: name,
|
||||
pem: config.PEM,
|
||||
info: newTLSClientInfo(config.Info),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// tlsClientCertificateInfo is a struct for specifying the configuration for the passTLSClientCert middleware.
|
||||
type tlsClientCertificateInfo struct {
|
||||
notAfter bool
|
||||
notBefore bool
|
||||
sans bool
|
||||
subject *DistinguishedNameOptions
|
||||
issuer *DistinguishedNameOptions
|
||||
}
|
||||
|
||||
func newTLSClientInfo(info *config.TLSClientCertificateInfo) *tlsClientCertificateInfo {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &tlsClientCertificateInfo{
|
||||
issuer: newDistinguishedNameOptions(info.Issuer),
|
||||
notAfter: info.NotAfter,
|
||||
notBefore: info.NotBefore,
|
||||
subject: newDistinguishedNameOptions(info.Subject),
|
||||
sans: info.Sans,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *passTLSClientCert) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return p.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (p *passTLSClientCert) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
logger := middlewares.GetLogger(req.Context(), p.name, typeName)
|
||||
p.modifyRequestHeaders(logger, req)
|
||||
p.next.ServeHTTP(rw, req)
|
||||
}
|
||||
func getDNInfo(prefix string, options *DistinguishedNameOptions, cs *pkix.Name) string {
|
||||
if options == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
content := &strings.Builder{}
|
||||
|
||||
// Manage non standard attributes
|
||||
for _, name := range cs.Names {
|
||||
// Domain Component - RFC 2247
|
||||
if options.DomainComponent && attributeTypeNames[name.Type.String()] == "DC" {
|
||||
content.WriteString(fmt.Sprintf("DC=%s,", name.Value))
|
||||
}
|
||||
}
|
||||
|
||||
if options.CountryName {
|
||||
writeParts(content, cs.Country, "C")
|
||||
}
|
||||
|
||||
if options.StateOrProvinceName {
|
||||
writeParts(content, cs.Province, "ST")
|
||||
}
|
||||
|
||||
if options.LocalityName {
|
||||
writeParts(content, cs.Locality, "L")
|
||||
}
|
||||
|
||||
if options.OrganizationName {
|
||||
writeParts(content, cs.Organization, "O")
|
||||
}
|
||||
|
||||
if options.SerialNumber {
|
||||
writePart(content, cs.SerialNumber, "SN")
|
||||
}
|
||||
|
||||
if options.CommonName {
|
||||
writePart(content, cs.CommonName, "CN")
|
||||
}
|
||||
|
||||
if content.Len() > 0 {
|
||||
return prefix + `="` + strings.TrimSuffix(content.String(), ",") + `"`
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func writeParts(content io.StringWriter, entries []string, prefix string) {
|
||||
for _, entry := range entries {
|
||||
writePart(content, entry, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func writePart(content io.StringWriter, entry string, prefix string) {
|
||||
if len(entry) > 0 {
|
||||
_, err := content.WriteString(fmt.Sprintf("%s=%s,", prefix, entry))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCertInfo builds a string with the selected client certificate information,
|
||||
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
|
||||
func (p *passTLSClientCert) getXForwardedTLSClientCertInfo(certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
var values []string
|
||||
var sans string
|
||||
var nb string
|
||||
var na string
|
||||
|
||||
if p.info != nil {
|
||||
subject := getDNInfo("Subject", p.info.subject, &peerCert.Subject)
|
||||
if len(subject) > 0 {
|
||||
values = append(values, subject)
|
||||
}
|
||||
|
||||
issuer := getDNInfo("Issuer", p.info.issuer, &peerCert.Issuer)
|
||||
if len(issuer) > 0 {
|
||||
values = append(values, issuer)
|
||||
}
|
||||
}
|
||||
|
||||
ci := p.info
|
||||
if ci != nil {
|
||||
if ci.notBefore {
|
||||
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
|
||||
values = append(values, nb)
|
||||
}
|
||||
if ci.notAfter {
|
||||
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
|
||||
values = append(values, na)
|
||||
}
|
||||
|
||||
if ci.sans {
|
||||
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
|
||||
values = append(values, sans)
|
||||
}
|
||||
}
|
||||
|
||||
value := strings.Join(values, ",")
|
||||
headerValues = append(headerValues, value)
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ";")
|
||||
}
|
||||
|
||||
// modifyRequestHeaders sets the wanted headers with the certificate information.
|
||||
func (p *passTLSClientCert) modifyRequestHeaders(logger logrus.FieldLogger, r *http.Request) {
|
||||
if p.pem {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(logger, r.TLS.PeerCertificates))
|
||||
} else {
|
||||
logger.Warn("Try to extract certificate on a request without TLS")
|
||||
}
|
||||
}
|
||||
|
||||
if p.info != nil {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
headerContent := p.getXForwardedTLSClientCertInfo(r.TLS.PeerCertificates)
|
||||
r.Header.Set(xForwardedTLSClientCertInfo, url.QueryEscape(headerContent))
|
||||
} else {
|
||||
logger.Warn("Try to extract certificate on a request without TLS")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sanitize strips the PEM markers and newlines from the raw certificate and escapes it so it can travel in an HTTP header.
|
||||
func sanitize(cert []byte) string {
|
||||
s := string(cert)
|
||||
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
|
||||
"-----END CERTIFICATE-----", "",
|
||||
"\n", "")
|
||||
cleaned := r.Replace(s)
|
||||
|
||||
return url.QueryEscape(cleaned)
|
||||
}
|
||||
|
||||
// extractCertificate extracts the certificate content as a sanitized string.
|
||||
func extractCertificate(logger logrus.FieldLogger, cert *x509.Certificate) string {
|
||||
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
|
||||
certPEM := pem.EncodeToMemory(&b)
|
||||
if certPEM == nil {
|
||||
logger.Error("Cannot extract the certificate content")
|
||||
return ""
|
||||
}
|
||||
return sanitize(certPEM)
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCert builds a string with the client certificates.
|
||||
func getXForwardedTLSClientCert(logger logrus.FieldLogger, certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
headerValues = append(headerValues, extractCertificate(logger, peerCert))
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ",")
|
||||
}
|
||||
|
||||
// getSANs gets the Subject Alternative Name values.
|
||||
func getSANs(cert *x509.Certificate) []string {
|
||||
var sans []string
|
||||
if cert == nil {
|
||||
return sans
|
||||
}
|
||||
|
||||
sans = append(sans, cert.DNSNames...)
|
||||
sans = append(sans, cert.EmailAddresses...)
|
||||
|
||||
var ips []string
|
||||
for _, ip := range cert.IPAddresses {
|
||||
ips = append(ips, ip.String())
|
||||
}
|
||||
sans = append(sans, ips...)
|
||||
|
||||
var uris []string
|
||||
for _, uri := range cert.URIs {
|
||||
uris = append(uris, uri.String())
|
||||
}
|
||||
|
||||
return append(sans, uris...)
|
||||
}
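A minimal sketch of how the helpers above fit together; the function below is hypothetical, not part of this commit, and only reuses identifiers defined in this file.

// demoClientCertValues is a hypothetical illustration: given an already parsed
// certificate, it builds the same values that feed the two forwarded headers.
func demoClientCertValues(logger logrus.FieldLogger, cert *x509.Certificate) (pemValue, sanValue string) {
    // URL-escaped PEM body without the BEGIN/END delimiters, as placed under the
    // xForwardedTLSClientCert header.
    pemValue = extractCertificate(logger, cert)
    // Comma-separated SANs (DNS names, e-mail addresses, IPs, URIs), as used for the
    // SAN= part of the xForwardedTLSClientCertInfo header.
    sanValue = strings.Join(getSANs(cert), ",")
    return pemValue, sanValue
}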
663  pkg/middlewares/passtlsclientcert/pass_tls_client_cert_test.go  Normal file
@@ -0,0 +1,663 @@
package passtlsclientcert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
signingCA = `Certificate:
|
||||
Data:
|
||||
Version: 3 (0x2)
|
||||
Serial Number: 2 (0x2)
|
||||
Signature Algorithm: sha1WithRSAEncryption
|
||||
Issuer: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Cheese Section, OU=Cheese Section 2, CN=Simple Root CA, CN=Simple Root CA 2, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Root State, ST=Root State 2/emailAddress=root@signing.com/emailAddress=root2@signing.com
|
||||
Validity
|
||||
Not Before: Dec 6 11:10:09 2018 GMT
|
||||
Not After : Dec 5 11:10:09 2028 GMT
|
||||
Subject: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Simple Signing Section, OU=Simple Signing Section 2, CN=Simple Signing CA, CN=Simple Signing CA 2, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Signing State, ST=Signing State 2/emailAddress=simple@signing.com/emailAddress=simple2@signing.com
|
||||
Subject Public Key Info:
|
||||
Public Key Algorithm: rsaEncryption
|
||||
RSA Public-Key: (2048 bit)
|
||||
Modulus:
|
||||
00:c3:9d:9f:61:15:57:3f:78:cc:e7:5d:20:e2:3e:
|
||||
2e:79:4a:c3:3a:0c:26:40:18:db:87:08:85:c2:f7:
|
||||
af:87:13:1a:ff:67:8a:b7:2b:58:a7:cc:89:dd:77:
|
||||
ff:5e:27:65:11:80:82:8f:af:a0:af:25:86:ec:a2:
|
||||
4f:20:0e:14:15:16:12:d7:74:5a:c3:99:bd:3b:81:
|
||||
c8:63:6f:fc:90:14:86:d2:39:ee:87:b2:ff:6d:a5:
|
||||
69:da:ab:5a:3a:97:cd:23:37:6a:4b:ba:63:cd:a1:
|
||||
a9:e6:79:aa:37:b8:d1:90:c9:24:b5:e8:70:fc:15:
|
||||
ad:39:97:28:73:47:66:f6:22:79:5a:b0:03:83:8a:
|
||||
f1:ca:ae:8b:50:1e:c8:fa:0d:9f:76:2e:00:c2:0e:
|
||||
75:bc:47:5a:b6:d8:05:ed:5a:bc:6d:50:50:36:6b:
|
||||
ab:ab:69:f6:9b:1b:6c:7e:a8:9f:b2:33:3a:3c:8c:
|
||||
6d:5e:83:ce:17:82:9e:10:51:a6:39:ec:98:4e:50:
|
||||
b7:b1:aa:8b:ac:bb:a1:60:1b:ea:31:3b:b8:0a:ea:
|
||||
63:41:79:b5:ec:ee:19:e9:85:8e:f3:6d:93:80:da:
|
||||
98:58:a2:40:93:a5:53:eb:1d:24:b6:66:07:ec:58:
|
||||
10:63:e7:fa:6e:18:60:74:76:15:39:3c:f4:95:95:
|
||||
7e:df
|
||||
Exponent: 65537 (0x10001)
|
||||
X509v3 extensions:
|
||||
X509v3 Key Usage: critical
|
||||
Certificate Sign, CRL Sign
|
||||
X509v3 Basic Constraints: critical
|
||||
CA:TRUE, pathlen:0
|
||||
X509v3 Subject Key Identifier:
|
||||
1E:52:A2:E8:54:D5:37:EB:D5:A8:1D:E4:C2:04:1D:37:E2:F7:70:03
|
||||
X509v3 Authority Key Identifier:
|
||||
keyid:36:70:35:AA:F0:F6:93:B2:86:5D:32:73:F9:41:5A:3F:3B:C8:BC:8B
|
||||
|
||||
Signature Algorithm: sha1WithRSAEncryption
|
||||
76:f3:16:21:27:6d:a2:2e:e8:18:49:aa:54:1e:f8:3b:07:fa:
|
||||
65:50:d8:1f:a2:cf:64:6c:15:e0:0f:c8:46:b2:d7:b8:0e:cd:
|
||||
05:3b:06:fb:dd:c6:2f:01:ae:bd:69:d3:bb:55:47:a9:f6:e5:
|
||||
ba:be:4b:45:fb:2e:3c:33:e0:57:d4:3e:8e:3e:11:f2:0a:f1:
|
||||
7d:06:ab:04:2e:a5:76:20:c2:db:a4:68:5a:39:00:62:2a:1d:
|
||||
c2:12:b1:90:66:8c:36:a8:fd:83:d1:1b:da:23:a7:1d:5b:e6:
|
||||
9b:40:c4:78:25:c7:b7:6b:75:35:cf:bb:37:4a:4f:fc:7e:32:
|
||||
1f:8c:cf:12:d2:c9:c8:99:d9:4a:55:0a:1e:ac:de:b4:cb:7c:
|
||||
bf:c4:fb:60:2c:a8:f7:e7:63:5c:b0:1c:62:af:01:3c:fe:4d:
|
||||
3c:0b:18:37:4c:25:fc:d0:b2:f6:b2:f1:c3:f4:0f:53:d6:1e:
|
||||
b5:fa:bc:d8:ad:dd:1c:f5:45:9f:af:fe:0a:01:79:92:9a:d8:
|
||||
71:db:37:f3:1e:bd:fb:c7:1e:0a:0f:97:2a:61:f3:7b:19:93:
|
||||
9c:a6:8a:69:cd:b0:f5:91:02:a5:1b:10:f4:80:5d:42:af:4e:
|
||||
82:12:30:3e:d3:a7:11:14:ce:50:91:04:80:d7:2a:03:ef:71:
|
||||
10:b8:db:a5
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIFzTCCBLWgAwIBAgIBAjANBgkqhkiG9w0BAQUFADCCAWQxEzARBgoJkiaJk/Is
|
||||
ZAEZFgNvcmcxFjAUBgoJkiaJk/IsZAEZFgZjaGVlc2UxDzANBgNVBAoMBkNoZWVz
|
||||
ZTERMA8GA1UECgwIQ2hlZXNlIDIxFzAVBgNVBAsMDkNoZWVzZSBTZWN0aW9uMRkw
|
||||
FwYDVQQLDBBDaGVlc2UgU2VjdGlvbiAyMRcwFQYDVQQDDA5TaW1wbGUgUm9vdCBD
|
||||
QTEZMBcGA1UEAwwQU2ltcGxlIFJvb3QgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNV
|
||||
BAYTAlVTMREwDwYDVQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjETMBEGA1UE
|
||||
CAwKUm9vdCBTdGF0ZTEVMBMGA1UECAwMUm9vdCBTdGF0ZSAyMR8wHQYJKoZIhvcN
|
||||
AQkBFhByb290QHNpZ25pbmcuY29tMSAwHgYJKoZIhvcNAQkBFhFyb290MkBzaWdu
|
||||
aW5nLmNvbTAeFw0xODEyMDYxMTEwMDlaFw0yODEyMDUxMTEwMDlaMIIBhDETMBEG
|
||||
CgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEPMA0GA1UE
|
||||
CgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2ltcGxlIFNp
|
||||
Z25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2VjdGlvbiAy
|
||||
MRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2ltcGxlIFNp
|
||||
Z25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYDVQQHDAhU
|
||||
T1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBTdGF0ZTEY
|
||||
MBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJzaW1wbGVA
|
||||
c2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmluZy5jb20w
|
||||
ggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDDnZ9hFVc/eMznXSDiPi55
|
||||
SsM6DCZAGNuHCIXC96+HExr/Z4q3K1inzIndd/9eJ2URgIKPr6CvJYbsok8gDhQV
|
||||
FhLXdFrDmb07gchjb/yQFIbSOe6Hsv9tpWnaq1o6l80jN2pLumPNoanmeao3uNGQ
|
||||
ySS16HD8Fa05lyhzR2b2InlasAODivHKrotQHsj6DZ92LgDCDnW8R1q22AXtWrxt
|
||||
UFA2a6urafabG2x+qJ+yMzo8jG1eg84Xgp4QUaY57JhOULexqousu6FgG+oxO7gK
|
||||
6mNBebXs7hnphY7zbZOA2phYokCTpVPrHSS2ZgfsWBBj5/puGGB0dhU5PPSVlX7f
|
||||
AgMBAAGjZjBkMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgEAMB0G
|
||||
A1UdDgQWBBQeUqLoVNU369WoHeTCBB034vdwAzAfBgNVHSMEGDAWgBQ2cDWq8PaT
|
||||
soZdMnP5QVo/O8i8izANBgkqhkiG9w0BAQUFAAOCAQEAdvMWISdtoi7oGEmqVB74
|
||||
Owf6ZVDYH6LPZGwV4A/IRrLXuA7NBTsG+93GLwGuvWnTu1VHqfblur5LRfsuPDPg
|
||||
V9Q+jj4R8grxfQarBC6ldiDC26RoWjkAYiodwhKxkGaMNqj9g9Eb2iOnHVvmm0DE
|
||||
eCXHt2t1Nc+7N0pP/H4yH4zPEtLJyJnZSlUKHqzetMt8v8T7YCyo9+djXLAcYq8B
|
||||
PP5NPAsYN0wl/NCy9rLxw/QPU9Yetfq82K3dHPVFn6/+CgF5kprYcds38x69+8ce
|
||||
Cg+XKmHzexmTnKaKac2w9ZECpRsQ9IBdQq9OghIwPtOnERTOUJEEgNcqA+9xELjb
|
||||
pQ==
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
minimalCheeseCrt = `-----BEGIN CERTIFICATE-----
|
||||
MIIEQDCCAygCFFRY0OBk/L5Se0IZRj3CMljawL2UMA0GCSqGSIb3DQEBCwUAMIIB
|
||||
hDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEP
|
||||
MA0GA1UECgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2lt
|
||||
cGxlIFNpZ25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2Vj
|
||||
dGlvbiAyMRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2lt
|
||||
cGxlIFNpZ25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYD
|
||||
VQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBT
|
||||
dGF0ZTEYMBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJz
|
||||
aW1wbGVAc2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmlu
|
||||
Zy5jb20wHhcNMTgxMjA2MTExMDM2WhcNMjEwOTI1MTExMDM2WjAzMQswCQYDVQQG
|
||||
EwJGUjETMBEGA1UECAwKU29tZS1TdGF0ZTEPMA0GA1UECgwGQ2hlZXNlMIIBIjAN
|
||||
BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAskX/bUtwFo1gF2BTPNaNcTUMaRFu
|
||||
FMZozK8IgLjccZ4kZ0R9oFO6Yp8Zl/IvPaf7tE26PI7XP7eHriUdhnQzX7iioDd0
|
||||
RZa68waIhAGc+xPzRFrP3b3yj3S2a9Rve3c0K+SCV+EtKAwsxMqQDhoo9PcBfo5B
|
||||
RHfht07uD5MncUcGirwN+/pxHV5xzAGPcc7On0/5L7bq/G+63nhu78zw9XyuLaHC
|
||||
PM5VbOUvpyIESJHbMMzTdFGL8ob9VKO+Kr1kVGdEA9i8FLGl3xz/GBKuW/JD0xyW
|
||||
DrU29mri5vYWHmkuv7ZWHGXnpXjTtPHwveE9/0/ArnmpMyR9JtqFr1oEvQIDAQAB
|
||||
MA0GCSqGSIb3DQEBCwUAA4IBAQBHta+NWXI08UHeOkGzOTGRiWXsOH2dqdX6gTe9
|
||||
xF1AIjyoQ0gvpoGVvlnChSzmlUj+vnx/nOYGIt1poE3hZA3ZHZD/awsvGyp3GwWD
|
||||
IfXrEViSCIyF+8tNNKYyUcEO3xdAsAUGgfUwwF/mZ6MBV5+A/ZEEILlTq8zFt9dV
|
||||
vdKzIt7fZYxYBBHFSarl1x8pDgWXlf3hAufevGJXip9xGYmznF0T5cq1RbWJ4be3
|
||||
/9K7yuWhuBYC3sbTbCneHBa91M82za+PIISc1ygCYtWSBoZKSAqLk0rkZpHaekDP
|
||||
WqeUSNGYV//RunTeuRDAf5OxehERb1srzBXhRZ3cZdzXbgR/
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
|
||||
completeCheeseCrt = `Certificate:
|
||||
Data:
|
||||
Version: 3 (0x2)
|
||||
Serial Number: 1 (0x1)
|
||||
Signature Algorithm: sha1WithRSAEncryption
|
||||
Issuer: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Simple Signing Section, OU=Simple Signing Section 2, CN=Simple Signing CA, CN=Simple Signing CA 2, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Signing State, ST=Signing State 2/emailAddress=simple@signing.com/emailAddress=simple2@signing.com
|
||||
Validity
|
||||
Not Before: Dec 6 11:10:16 2018 GMT
|
||||
Not After : Dec 5 11:10:16 2020 GMT
|
||||
Subject: DC=org, DC=cheese, O=Cheese, O=Cheese 2, OU=Simple Signing Section, OU=Simple Signing Section 2, CN=*.cheese.org, CN=*.cheese.com, C=FR, C=US, L=TOULOUSE, L=LYON, ST=Cheese org state, ST=Cheese com state/emailAddress=cert@cheese.org/emailAddress=cert@scheese.com
|
||||
Subject Public Key Info:
|
||||
Public Key Algorithm: rsaEncryption
|
||||
RSA Public-Key: (2048 bit)
|
||||
Modulus:
|
||||
00:de:77:fa:8d:03:70:30:39:dd:51:1b:cc:60:db:
|
||||
a9:5a:13:b1:af:fe:2c:c6:38:9b:88:0a:0f:8e:d9:
|
||||
1b:a1:1d:af:0d:66:e4:13:5b:bc:5d:36:92:d7:5e:
|
||||
d0:fa:88:29:d3:78:e1:81:de:98:b2:a9:22:3f:bf:
|
||||
8a:af:12:92:63:d4:a9:c3:f2:e4:7e:d2:dc:a2:c5:
|
||||
39:1c:7a:eb:d7:12:70:63:2e:41:47:e0:f0:08:e8:
|
||||
dc:be:09:01:ec:28:09:af:35:d7:79:9c:50:35:d1:
|
||||
6b:e5:87:7b:34:f6:d2:31:65:1d:18:42:69:6c:04:
|
||||
11:83:fe:44:ae:90:92:2d:0b:75:39:57:62:e6:17:
|
||||
2f:47:2b:c7:53:dd:10:2d:c9:e3:06:13:d2:b9:ba:
|
||||
63:2e:3c:7d:83:6b:d6:89:c9:cc:9d:4d:bf:9f:e8:
|
||||
a3:7b:da:c8:99:2b:ba:66:d6:8e:f8:41:41:a0:c9:
|
||||
d0:5e:c8:11:a4:55:4a:93:83:87:63:04:63:41:9c:
|
||||
fb:68:04:67:c2:71:2f:f2:65:1d:02:5d:15:db:2c:
|
||||
d9:04:69:85:c2:7d:0d:ea:3b:ac:85:f8:d4:8f:0f:
|
||||
c5:70:b2:45:e1:ec:b2:54:0b:e9:f7:82:b4:9b:1b:
|
||||
2d:b9:25:d4:ab:ca:8f:5b:44:3e:15:dd:b8:7f:b7:
|
||||
ee:f9
|
||||
Exponent: 65537 (0x10001)
|
||||
X509v3 extensions:
|
||||
X509v3 Key Usage: critical
|
||||
Digital Signature, Key Encipherment
|
||||
X509v3 Basic Constraints:
|
||||
CA:FALSE
|
||||
X509v3 Extended Key Usage:
|
||||
TLS Web Server Authentication, TLS Web Client Authentication
|
||||
X509v3 Subject Key Identifier:
|
||||
94:BA:73:78:A2:87:FB:58:28:28:CF:98:3B:C2:45:70:16:6E:29:2F
|
||||
X509v3 Authority Key Identifier:
|
||||
keyid:1E:52:A2:E8:54:D5:37:EB:D5:A8:1D:E4:C2:04:1D:37:E2:F7:70:03
|
||||
|
||||
X509v3 Subject Alternative Name:
|
||||
DNS:*.cheese.org, DNS:*.cheese.net, DNS:*.cheese.com, IP Address:10.0.1.0, IP Address:10.0.1.2, email:test@cheese.org, email:test@cheese.net
|
||||
Signature Algorithm: sha1WithRSAEncryption
|
||||
76:6b:05:b0:0e:34:11:b1:83:99:91:dc:ae:1b:e2:08:15:8b:
|
||||
16:b2:9b:27:1c:02:ac:b5:df:1b:d0:d0:75:a4:2b:2c:5c:65:
|
||||
ed:99:ab:f7:cd:fe:38:3f:c3:9a:22:31:1b:ac:8c:1c:c2:f9:
|
||||
5d:d4:75:7a:2e:72:c7:85:a9:04:af:9f:2a:cc:d3:96:75:f0:
|
||||
8e:c7:c6:76:48:ac:45:a4:b9:02:1e:2f:c0:15:c4:07:08:92:
|
||||
cb:27:50:67:a1:c8:05:c5:3a:b3:a6:48:be:eb:d5:59:ab:a2:
|
||||
1b:95:30:71:13:5b:0a:9a:73:3b:60:cc:10:d0:6a:c7:e5:d7:
|
||||
8b:2f:f9:2e:98:f2:ff:81:14:24:09:e3:4b:55:57:09:1a:22:
|
||||
74:f1:f6:40:13:31:43:89:71:0a:96:1a:05:82:1f:83:3a:87:
|
||||
9b:17:25:ef:5a:55:f2:2d:cd:0d:4d:e4:81:58:b6:e3:8d:09:
|
||||
62:9a:0c:bd:e4:e5:5c:f0:95:da:cb:c7:34:2c:34:5f:6d:fc:
|
||||
60:7b:12:5b:86:fd:df:21:89:3b:48:08:30:bf:67:ff:8c:e6:
|
||||
9b:53:cc:87:36:47:70:40:3b:d9:90:2a:d2:d2:82:c6:9c:f5:
|
||||
d1:d8:e0:e6:fd:aa:2f:95:7e:39:ac:fc:4e:d4:ce:65:b3:ec:
|
||||
c6:98:8a:31
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIGWjCCBUKgAwIBAgIBATANBgkqhkiG9w0BAQUFADCCAYQxEzARBgoJkiaJk/Is
|
||||
ZAEZFgNvcmcxFjAUBgoJkiaJk/IsZAEZFgZjaGVlc2UxDzANBgNVBAoMBkNoZWVz
|
||||
ZTERMA8GA1UECgwIQ2hlZXNlIDIxHzAdBgNVBAsMFlNpbXBsZSBTaWduaW5nIFNl
|
||||
Y3Rpb24xITAfBgNVBAsMGFNpbXBsZSBTaWduaW5nIFNlY3Rpb24gMjEaMBgGA1UE
|
||||
AwwRU2ltcGxlIFNpZ25pbmcgQ0ExHDAaBgNVBAMME1NpbXBsZSBTaWduaW5nIENB
|
||||
IDIxCzAJBgNVBAYTAkZSMQswCQYDVQQGEwJVUzERMA8GA1UEBwwIVE9VTE9VU0Ux
|
||||
DTALBgNVBAcMBExZT04xFjAUBgNVBAgMDVNpZ25pbmcgU3RhdGUxGDAWBgNVBAgM
|
||||
D1NpZ25pbmcgU3RhdGUgMjEhMB8GCSqGSIb3DQEJARYSc2ltcGxlQHNpZ25pbmcu
|
||||
Y29tMSIwIAYJKoZIhvcNAQkBFhNzaW1wbGUyQHNpZ25pbmcuY29tMB4XDTE4MTIw
|
||||
NjExMTAxNloXDTIwMTIwNTExMTAxNlowggF2MRMwEQYKCZImiZPyLGQBGRYDb3Jn
|
||||
MRYwFAYKCZImiZPyLGQBGRYGY2hlZXNlMQ8wDQYDVQQKDAZDaGVlc2UxETAPBgNV
|
||||
BAoMCENoZWVzZSAyMR8wHQYDVQQLDBZTaW1wbGUgU2lnbmluZyBTZWN0aW9uMSEw
|
||||
HwYDVQQLDBhTaW1wbGUgU2lnbmluZyBTZWN0aW9uIDIxFTATBgNVBAMMDCouY2hl
|
||||
ZXNlLm9yZzEVMBMGA1UEAwwMKi5jaGVlc2UuY29tMQswCQYDVQQGEwJGUjELMAkG
|
||||
A1UEBhMCVVMxETAPBgNVBAcMCFRPVUxPVVNFMQ0wCwYDVQQHDARMWU9OMRkwFwYD
|
||||
VQQIDBBDaGVlc2Ugb3JnIHN0YXRlMRkwFwYDVQQIDBBDaGVlc2UgY29tIHN0YXRl
|
||||
MR4wHAYJKoZIhvcNAQkBFg9jZXJ0QGNoZWVzZS5vcmcxHzAdBgkqhkiG9w0BCQEW
|
||||
EGNlcnRAc2NoZWVzZS5jb20wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB
|
||||
AQDed/qNA3AwOd1RG8xg26laE7Gv/izGOJuICg+O2RuhHa8NZuQTW7xdNpLXXtD6
|
||||
iCnTeOGB3piyqSI/v4qvEpJj1KnD8uR+0tyixTkceuvXEnBjLkFH4PAI6Ny+CQHs
|
||||
KAmvNdd5nFA10Wvlh3s09tIxZR0YQmlsBBGD/kSukJItC3U5V2LmFy9HK8dT3RAt
|
||||
yeMGE9K5umMuPH2Da9aJycydTb+f6KN72siZK7pm1o74QUGgydBeyBGkVUqTg4dj
|
||||
BGNBnPtoBGfCcS/yZR0CXRXbLNkEaYXCfQ3qO6yF+NSPD8VwskXh7LJUC+n3grSb
|
||||
Gy25JdSryo9bRD4V3bh/t+75AgMBAAGjgeAwgd0wDgYDVR0PAQH/BAQDAgWgMAkG
|
||||
A1UdEwQCMAAwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMB0GA1UdDgQW
|
||||
BBSUunN4oof7WCgoz5g7wkVwFm4pLzAfBgNVHSMEGDAWgBQeUqLoVNU369WoHeTC
|
||||
BB034vdwAzBhBgNVHREEWjBYggwqLmNoZWVzZS5vcmeCDCouY2hlZXNlLm5ldIIM
|
||||
Ki5jaGVlc2UuY29thwQKAAEAhwQKAAECgQ90ZXN0QGNoZWVzZS5vcmeBD3Rlc3RA
|
||||
Y2hlZXNlLm5ldDANBgkqhkiG9w0BAQUFAAOCAQEAdmsFsA40EbGDmZHcrhviCBWL
|
||||
FrKbJxwCrLXfG9DQdaQrLFxl7Zmr983+OD/DmiIxG6yMHML5XdR1ei5yx4WpBK+f
|
||||
KszTlnXwjsfGdkisRaS5Ah4vwBXEBwiSyydQZ6HIBcU6s6ZIvuvVWauiG5UwcRNb
|
||||
CppzO2DMENBqx+XXiy/5Lpjy/4EUJAnjS1VXCRoidPH2QBMxQ4lxCpYaBYIfgzqH
|
||||
mxcl71pV8i3NDU3kgVi2440JYpoMveTlXPCV2svHNCw0X238YHsSW4b93yGJO0gI
|
||||
ML9n/4zmm1PMhzZHcEA72ZAq0tKCxpz10djg5v2qL5V+Oaz8TtTOZbPsxpiKMQ==
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
|
||||
minimalCert = `-----BEGIN CERTIFICATE-----
|
||||
MIIDGTCCAgECCQCqLd75YLi2kDANBgkqhkiG9w0BAQsFADBYMQswCQYDVQQGEwJG
|
||||
UjETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UEBwwIVG91bG91c2UxITAfBgNV
|
||||
BAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xODA3MTgwODI4MTZaFw0x
|
||||
ODA4MTcwODI4MTZaMEUxCzAJBgNVBAYTAkZSMRMwEQYDVQQIDApTb21lLVN0YXRl
|
||||
MSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3
|
||||
DQEBAQUAA4IBDwAwggEKAoIBAQC/+frDMMTLQyXG34F68BPhQq0kzK4LIq9Y0/gl
|
||||
FjySZNn1C0QDWA1ubVCAcA6yY204I9cxcQDPNrhC7JlS5QA8Y5rhIBrqQlzZizAi
|
||||
Rj3NTrRjtGUtOScnHuJaWjLy03DWD+aMwb7q718xt5SEABmmUvLwQK+EjW2MeDwj
|
||||
y8/UEIpvrRDmdhGaqv7IFpIDkcIF7FowJ/hwDvx3PMc+z/JWK0ovzpvgbx69AVbw
|
||||
ZxCimeha65rOqVi+lEetD26le+WnOdYsdJ2IkmpPNTXGdfb15xuAc+gFXfMCh7Iw
|
||||
3Ynl6dZtZM/Ok2kiA7/OsmVnRKkWrtBfGYkI9HcNGb3zrk6nAgMBAAEwDQYJKoZI
|
||||
hvcNAQELBQADggEBAC/R+Yvhh1VUhcbK49olWsk/JKqfS3VIDQYZg1Eo+JCPbwgS
|
||||
I1BSYVfMcGzuJTX6ua3m/AHzGF3Tap4GhF4tX12jeIx4R4utnjj7/YKkTvuEM2f4
|
||||
xT56YqI7zalGScIB0iMeyNz1QcimRl+M/49au8ow9hNX8C2tcA2cwd/9OIj/6T8q
|
||||
SBRHc6ojvbqZSJCO0jziGDT1L3D+EDgTjED4nd77v/NRdP+egb0q3P0s4dnQ/5AV
|
||||
aQlQADUn61j3ScbGJ4NSeZFFvsl38jeRi/MEzp0bGgNBcPj6JHi7qbbauZcZfQ05
|
||||
jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=
|
||||
-----END CERTIFICATE-----`
|
||||
)
|
||||
|
||||
func getCleanCertContents(certContents []string) string {
|
||||
var re = regexp.MustCompile("-----BEGIN CERTIFICATE-----(?s)(.*)")
|
||||
|
||||
var cleanedCertContent []string
|
||||
for _, certContent := range certContents {
|
||||
cert := re.FindString(certContent)
|
||||
cleanedCertContent = append(cleanedCertContent, sanitize([]byte(cert)))
|
||||
}
|
||||
|
||||
return strings.Join(cleanedCertContent, ",")
|
||||
}
|
||||
|
||||
func getCertificate(certContent string) *x509.Certificate {
|
||||
roots := x509.NewCertPool()
|
||||
ok := roots.AppendCertsFromPEM([]byte(signingCA))
|
||||
if !ok {
|
||||
panic("failed to parse root certificate")
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(certContent))
|
||||
if block == nil {
|
||||
panic("failed to parse certificate PEM")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
panic("failed to parse certificate: " + err.Error())
|
||||
}
|
||||
|
||||
return cert
|
||||
}
|
||||
|
||||
func buildTLSWith(certContents []string) *tls.ConnectionState {
|
||||
var peerCertificates []*x509.Certificate
|
||||
|
||||
for _, certContent := range certContents {
|
||||
peerCertificates = append(peerCertificates, getCertificate(certContent))
|
||||
}
|
||||
|
||||
return &tls.ConnectionState{PeerCertificates: peerCertificates}
|
||||
}
|
||||
|
||||
var next = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("bar"))
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
})
|
||||
|
||||
func getExpectedSanitized(s string) string {
|
||||
return url.QueryEscape(strings.Replace(s, "\n", "", -1))
|
||||
}
|
||||
|
||||
func TestSanitize(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
toSanitize []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Empty",
|
||||
},
|
||||
{
|
||||
desc: "With a minimal cert",
|
||||
toSanitize: []byte(minimalCheeseCrt),
|
||||
expected: getExpectedSanitized(`MIIEQDCCAygCFFRY0OBk/L5Se0IZRj3CMljawL2UMA0GCSqGSIb3DQEBCwUAMIIB
|
||||
hDETMBEGCgmSJomT8ixkARkWA29yZzEWMBQGCgmSJomT8ixkARkWBmNoZWVzZTEP
|
||||
MA0GA1UECgwGQ2hlZXNlMREwDwYDVQQKDAhDaGVlc2UgMjEfMB0GA1UECwwWU2lt
|
||||
cGxlIFNpZ25pbmcgU2VjdGlvbjEhMB8GA1UECwwYU2ltcGxlIFNpZ25pbmcgU2Vj
|
||||
dGlvbiAyMRowGAYDVQQDDBFTaW1wbGUgU2lnbmluZyBDQTEcMBoGA1UEAwwTU2lt
|
||||
cGxlIFNpZ25pbmcgQ0EgMjELMAkGA1UEBhMCRlIxCzAJBgNVBAYTAlVTMREwDwYD
|
||||
VQQHDAhUT1VMT1VTRTENMAsGA1UEBwwETFlPTjEWMBQGA1UECAwNU2lnbmluZyBT
|
||||
dGF0ZTEYMBYGA1UECAwPU2lnbmluZyBTdGF0ZSAyMSEwHwYJKoZIhvcNAQkBFhJz
|
||||
aW1wbGVAc2lnbmluZy5jb20xIjAgBgkqhkiG9w0BCQEWE3NpbXBsZTJAc2lnbmlu
|
||||
Zy5jb20wHhcNMTgxMjA2MTExMDM2WhcNMjEwOTI1MTExMDM2WjAzMQswCQYDVQQG
|
||||
EwJGUjETMBEGA1UECAwKU29tZS1TdGF0ZTEPMA0GA1UECgwGQ2hlZXNlMIIBIjAN
|
||||
BgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAskX/bUtwFo1gF2BTPNaNcTUMaRFu
|
||||
FMZozK8IgLjccZ4kZ0R9oFO6Yp8Zl/IvPaf7tE26PI7XP7eHriUdhnQzX7iioDd0
|
||||
RZa68waIhAGc+xPzRFrP3b3yj3S2a9Rve3c0K+SCV+EtKAwsxMqQDhoo9PcBfo5B
|
||||
RHfht07uD5MncUcGirwN+/pxHV5xzAGPcc7On0/5L7bq/G+63nhu78zw9XyuLaHC
|
||||
PM5VbOUvpyIESJHbMMzTdFGL8ob9VKO+Kr1kVGdEA9i8FLGl3xz/GBKuW/JD0xyW
|
||||
DrU29mri5vYWHmkuv7ZWHGXnpXjTtPHwveE9/0/ArnmpMyR9JtqFr1oEvQIDAQAB
|
||||
MA0GCSqGSIb3DQEBCwUAA4IBAQBHta+NWXI08UHeOkGzOTGRiWXsOH2dqdX6gTe9
|
||||
xF1AIjyoQ0gvpoGVvlnChSzmlUj+vnx/nOYGIt1poE3hZA3ZHZD/awsvGyp3GwWD
|
||||
IfXrEViSCIyF+8tNNKYyUcEO3xdAsAUGgfUwwF/mZ6MBV5+A/ZEEILlTq8zFt9dV
|
||||
vdKzIt7fZYxYBBHFSarl1x8pDgWXlf3hAufevGJXip9xGYmznF0T5cq1RbWJ4be3
|
||||
/9K7yuWhuBYC3sbTbCneHBa91M82za+PIISc1ygCYtWSBoZKSAqLk0rkZpHaekDP
|
||||
WqeUSNGYV//RunTeuRDAf5OxehERb1srzBXhRZ3cZdzXbgR/`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, test.expected, sanitize(test.toSanitize), "The sanitized certificates should be equal")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestTLSClientHeadersWithPEM(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
certContents []string // set the request TLS attribute if defined
|
||||
config config.PassTLSClientCert
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "No TLS, no option",
|
||||
},
|
||||
{
|
||||
desc: "TLS, no option",
|
||||
certContents: []string{minimalCheeseCrt},
|
||||
},
|
||||
{
|
||||
desc: "No TLS, with pem option true",
|
||||
config: config.PassTLSClientCert{PEM: true},
|
||||
},
|
||||
{
|
||||
desc: "TLS with simple certificate, with pem option true",
|
||||
certContents: []string{minimalCheeseCrt},
|
||||
config: config.PassTLSClientCert{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{minimalCert}),
|
||||
},
|
||||
{
|
||||
desc: "TLS with complete certificate, with pem option true",
|
||||
certContents: []string{minimalCheeseCrt},
|
||||
config: config.PassTLSClientCert{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{minimalCheeseCrt}),
|
||||
},
|
||||
{
|
||||
desc: "TLS with two certificate, with pem option true",
|
||||
certContents: []string{minimalCert, minimalCheeseCrt},
|
||||
config: config.PassTLSClientCert{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{minimalCert, minimalCheeseCrt}),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
|
||||
if test.certContents != nil && len(test.certContents) > 0 {
|
||||
req.TLS = buildTLSWith(test.certContents)
|
||||
}
|
||||
|
||||
tlsClientHeaders.ServeHTTP(res, req)
|
||||
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
|
||||
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
|
||||
|
||||
if test.expectedHeader != "" {
|
||||
require.Equal(t, getCleanCertContents(test.certContents), req.Header.Get(xForwardedTLSClientCert), "The request header should contain the cleaned certificate")
|
||||
} else {
|
||||
require.Empty(t, req.Header.Get(xForwardedTLSClientCert))
|
||||
}
|
||||
require.Empty(t, res.Header().Get(xForwardedTLSClientCert), "The response header should be always empty")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGetSans(t *testing.T) {
|
||||
urlFoo, err := url.Parse("my.foo.com")
|
||||
require.NoError(t, err)
|
||||
urlBar, err := url.Parse("my.bar.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
cert *x509.Certificate // set the request TLS attribute if defined
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "With nil",
|
||||
},
|
||||
{
|
||||
desc: "Certificate without Sans",
|
||||
cert: &x509.Certificate{},
|
||||
},
|
||||
{
|
||||
desc: "Certificate with all Sans",
|
||||
cert: &x509.Certificate{
|
||||
DNSNames: []string{"foo", "bar"},
|
||||
EmailAddresses: []string{"test@test.com", "test2@test.com"},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 0, 0, 1), net.IPv4(10, 0, 0, 2)},
|
||||
URIs: []*url.URL{urlFoo, urlBar},
|
||||
},
|
||||
expected: []string{"foo", "bar", "test@test.com", "test2@test.com", "10.0.0.1", "10.0.0.2", urlFoo.String(), urlBar.String()},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
sans := getSANs(test.cert)
|
||||
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if len(test.expected) > 0 {
|
||||
for i, expected := range test.expected {
|
||||
require.Equal(t, expected, sans[i])
|
||||
}
|
||||
} else {
|
||||
require.Empty(t, sans)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestTLSClientHeadersWithCertInfo(t *testing.T) {
|
||||
minimalCheeseCertAllInfo := `Subject="C=FR,ST=Some-State,O=Cheese",Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2",NB=1544094636,NA=1632568236,SAN=`
|
||||
completeCertAllInfo := `Subject="DC=org,DC=cheese,C=FR,C=US,ST=Cheese org state,ST=Cheese com state,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=*.cheese.com",Issuer="DC=org,DC=cheese,C=FR,C=US,ST=Signing State,ST=Signing State 2,L=TOULOUSE,L=LYON,O=Cheese,O=Cheese 2,CN=Simple Signing CA 2",NB=1544094616,NA=1607166616,SAN=*.cheese.org,*.cheese.net,*.cheese.com,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2`
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
certContents []string // set the request TLS attribute if defined
|
||||
config config.PassTLSClientCert
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "No TLS, no option",
|
||||
},
|
||||
{
|
||||
desc: "TLS, no option",
|
||||
certContents: []string{minimalCert},
|
||||
},
|
||||
{
|
||||
desc: "No TLS, with subject info",
|
||||
config: config.PassTLSClientCert{
|
||||
Info: &config.TLSClientCertificateInfo{
|
||||
Subject: &config.TLSCLientCertificateDNInfo{
|
||||
CommonName: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
Province: true,
|
||||
Country: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "No TLS, with pem option false with empty subject info",
|
||||
config: config.PassTLSClientCert{
|
||||
PEM: false,
|
||||
Info: &config.TLSClientCertificateInfo{
|
||||
Subject: &config.TLSCLientCertificateDNInfo{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "TLS with simple certificate, with all info",
|
||||
certContents: []string{minimalCheeseCrt},
|
||||
config: config.PassTLSClientCert{
|
||||
Info: &config.TLSClientCertificateInfo{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Sans: true,
|
||||
Subject: &config.TLSCLientCertificateDNInfo{
|
||||
CommonName: true,
|
||||
Country: true,
|
||||
DomainComponent: true,
|
||||
Locality: true,
|
||||
Organization: true,
|
||||
Province: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
Issuer: &config.TLSCLientCertificateDNInfo{
|
||||
CommonName: true,
|
||||
Country: true,
|
||||
DomainComponent: true,
|
||||
Locality: true,
|
||||
Organization: true,
|
||||
Province: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedHeader: url.QueryEscape(minimalCheeseCertAllInfo),
|
||||
},
|
||||
{
|
||||
desc: "TLS with simple certificate, with some info",
|
||||
certContents: []string{minimalCheeseCrt},
|
||||
config: config.PassTLSClientCert{
|
||||
Info: &config.TLSClientCertificateInfo{
|
||||
NotAfter: true,
|
||||
Sans: true,
|
||||
Subject: &config.TLSCLientCertificateDNInfo{
|
||||
Organization: true,
|
||||
},
|
||||
Issuer: &config.TLSCLientCertificateDNInfo{
|
||||
Country: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedHeader: url.QueryEscape(`Subject="O=Cheese",Issuer="C=FR,C=US",NA=1632568236,SAN=`),
|
||||
},
|
||||
{
|
||||
desc: "TLS with complete certificate, with all info",
|
||||
certContents: []string{completeCheeseCrt},
|
||||
config: config.PassTLSClientCert{
|
||||
Info: &config.TLSClientCertificateInfo{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Sans: true,
|
||||
Subject: &config.TLSCLientCertificateDNInfo{
|
||||
Country: true,
|
||||
Province: true,
|
||||
Locality: true,
|
||||
Organization: true,
|
||||
CommonName: true,
|
||||
SerialNumber: true,
|
||||
DomainComponent: true,
|
||||
},
|
||||
Issuer: &config.TLSCLientCertificateDNInfo{
|
||||
Country: true,
|
||||
Province: true,
|
||||
Locality: true,
|
||||
Organization: true,
|
||||
CommonName: true,
|
||||
SerialNumber: true,
|
||||
DomainComponent: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedHeader: url.QueryEscape(completeCertAllInfo),
|
||||
},
|
||||
{
|
||||
desc: "TLS with 2 certificates, with all info",
|
||||
certContents: []string{minimalCheeseCrt, completeCheeseCrt},
|
||||
config: config.PassTLSClientCert{
|
||||
Info: &config.TLSClientCertificateInfo{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Sans: true,
|
||||
Subject: &config.TLSCLientCertificateDNInfo{
|
||||
Country: true,
|
||||
Province: true,
|
||||
Locality: true,
|
||||
Organization: true,
|
||||
CommonName: true,
|
||||
SerialNumber: true,
|
||||
DomainComponent: true,
|
||||
},
|
||||
Issuer: &config.TLSCLientCertificateDNInfo{
|
||||
Country: true,
|
||||
Province: true,
|
||||
Locality: true,
|
||||
Organization: true,
|
||||
CommonName: true,
|
||||
SerialNumber: true,
|
||||
DomainComponent: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedHeader: url.QueryEscape(strings.Join([]string{minimalCheeseCertAllInfo, completeCertAllInfo}, ";")),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
tlsClientHeaders, err := New(context.Background(), next, test.config, "foo")
|
||||
require.NoError(t, err)
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
|
||||
if test.certContents != nil && len(test.certContents) > 0 {
|
||||
req.TLS = buildTLSWith(test.certContents)
|
||||
}
|
||||
|
||||
tlsClientHeaders.ServeHTTP(res, req)
|
||||
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
|
||||
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
|
||||
|
||||
if test.expectedHeader != "" {
|
||||
require.Equal(t, test.expectedHeader, req.Header.Get(xForwardedTLSClientCertInfo), "The request header should contain the cleaned certificate")
|
||||
} else {
|
||||
require.Empty(t, req.Header.Get(xForwardedTLSClientCertInfo))
|
||||
}
|
||||
require.Empty(t, res.Header().Get(xForwardedTLSClientCertInfo), "The response header should be always empty")
|
||||
})
|
||||
}
|
||||
|
||||
}
54  pkg/middlewares/ratelimiter/rate_limiter.go  Normal file
@@ -0,0 +1,54 @@
package ratelimiter

import (
    "context"
    "net/http"
    "time"

    "github.com/containous/traefik/pkg/config"
    "github.com/containous/traefik/pkg/middlewares"
    "github.com/containous/traefik/pkg/tracing"
    "github.com/opentracing/opentracing-go/ext"
    "github.com/vulcand/oxy/ratelimit"
    "github.com/vulcand/oxy/utils"
)

const (
    typeName = "RateLimiterType"
)

type rateLimiter struct {
    handler http.Handler
    name    string
}

// New creates rate limiter middleware.
func New(ctx context.Context, next http.Handler, config config.RateLimit, name string) (http.Handler, error) {
    middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")

    extractFunc, err := utils.NewExtractor(config.ExtractorFunc)
    if err != nil {
        return nil, err
    }

    rateSet := ratelimit.NewRateSet()
    for _, rate := range config.RateSet {
        if err = rateSet.Add(time.Duration(rate.Period), rate.Average, rate.Burst); err != nil {
            return nil, err
        }
    }

    rl, err := ratelimit.New(next, extractFunc, rateSet)
    if err != nil {
        return nil, err
    }
    return &rateLimiter{handler: rl, name: name}, nil
}

func (r *rateLimiter) GetTracingInformation() (string, ext.SpanKindEnum) {
    return r.name, tracing.SpanKindNoneEnum
}

func (r *rateLimiter) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    r.handler.ServeHTTP(rw, req)
}
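For context, a hypothetical sketch of the oxy primitives New relies on; it is not part of this commit, and the "client.ip" extractor variable is an assumption.

// demoOxyRateSet shows, under the assumptions above, how a rate of 100 requests per
// second with bursts up to 200, keyed by client IP, would be expressed with the same
// oxy helpers New uses.
func demoOxyRateSet() error {
    rateSet := ratelimit.NewRateSet()
    if err := rateSet.Add(time.Second, 100, 200); err != nil {
        return err
    }
    _, err := utils.NewExtractor("client.ip")
    return err
}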
40  pkg/middlewares/recovery/recovery.go  Normal file
@@ -0,0 +1,40 @@
package recovery

import (
    "context"
    "net/http"

    "github.com/containous/traefik/pkg/middlewares"
    "github.com/sirupsen/logrus"
)

const (
    typeName = "Recovery"
)

type recovery struct {
    next http.Handler
    name string
}

// New creates recovery middleware.
func New(ctx context.Context, next http.Handler, name string) (http.Handler, error) {
    middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")

    return &recovery{
        next: next,
        name: name,
    }, nil
}

func (re *recovery) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    defer recoverFunc(middlewares.GetLogger(req.Context(), re.name, typeName), rw)
    re.next.ServeHTTP(rw, req)
}

func recoverFunc(logger logrus.FieldLogger, rw http.ResponseWriter) {
    if err := recover(); err != nil {
        logger.Errorf("Recovered from panic in http handler: %+v", err)
        http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
    }
}
27  pkg/middlewares/recovery/recovery_test.go  Normal file
@@ -0,0 +1,27 @@
package recovery

import (
    "context"
    "net/http"
    "net/http/httptest"
    "testing"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func TestRecoverHandler(t *testing.T) {
    fn := func(w http.ResponseWriter, r *http.Request) {
        panic("I love panicking!")
    }
    recovery, err := New(context.Background(), http.HandlerFunc(fn), "foo-recovery")
    require.NoError(t, err)

    server := httptest.NewServer(recovery)
    defer server.Close()

    resp, err := http.Get(server.URL)
    require.NoError(t, err)

    assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
}
158  pkg/middlewares/redirect/redirect.go  Normal file
@@ -0,0 +1,158 @@
package redirect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"html/template"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/pkg/tracing"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
type redirect struct {
|
||||
next http.Handler
|
||||
regex *regexp.Regexp
|
||||
replacement string
|
||||
permanent bool
|
||||
errHandler utils.ErrorHandler
|
||||
name string
|
||||
}
|
||||
|
||||
// newRedirect creates a Redirect middleware.
|
||||
func newRedirect(_ context.Context, next http.Handler, regex string, replacement string, permanent bool, name string) (http.Handler, error) {
|
||||
re, err := regexp.Compile(regex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &redirect{
|
||||
regex: re,
|
||||
replacement: replacement,
|
||||
permanent: permanent,
|
||||
errHandler: utils.DefaultHandler,
|
||||
next: next,
|
||||
name: name,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *redirect) GetTracingInformation() (string, ext.SpanKindEnum) {
|
||||
return r.name, tracing.SpanKindNoneEnum
|
||||
}
|
||||
|
||||
func (r *redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
oldURL := rawURL(req)
|
||||
|
||||
// If the Regexp doesn't match, skip to the next handler
|
||||
if !r.regex.MatchString(oldURL) {
|
||||
r.next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// apply a rewrite regexp to the URL
|
||||
newURL := r.regex.ReplaceAllString(oldURL, r.replacement)
|
||||
|
||||
// replace any variables that may be in there
|
||||
rewrittenURL := &bytes.Buffer{}
|
||||
if err := applyString(newURL, rewrittenURL, req); err != nil {
|
||||
r.errHandler.ServeHTTP(rw, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
// parse the rewritten URL and replace request URL with it
|
||||
parsedURL, err := url.Parse(rewrittenURL.String())
|
||||
if err != nil {
|
||||
r.errHandler.ServeHTTP(rw, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
if newURL != oldURL {
|
||||
handler := &moveHandler{location: parsedURL, permanent: r.permanent}
|
||||
handler.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
req.URL = parsedURL
|
||||
|
||||
// make sure the request URI corresponds to the rewritten URL
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
r.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
type moveHandler struct {
|
||||
location *url.URL
|
||||
permanent bool
|
||||
}
|
||||
|
||||
func (m *moveHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Location", m.location.String())
|
||||
|
||||
status := http.StatusFound
|
||||
if req.Method != http.MethodGet {
|
||||
status = http.StatusTemporaryRedirect
|
||||
}
|
||||
|
||||
if m.permanent {
|
||||
status = http.StatusMovedPermanently
|
||||
if req.Method != http.MethodGet {
|
||||
status = http.StatusPermanentRedirect
|
||||
}
|
||||
}
|
||||
rw.WriteHeader(status)
|
||||
_, err := rw.Write([]byte(http.StatusText(status)))
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func rawURL(req *http.Request) string {
|
||||
scheme := "http"
|
||||
host := req.Host
|
||||
port := ""
|
||||
uri := req.RequestURI
|
||||
|
||||
schemeRegex := `^(https?):\/\/([\w\._-]+)(:\d+)?(.*)$`
|
||||
re, _ := regexp.Compile(schemeRegex)
|
||||
if re.Match([]byte(req.RequestURI)) {
|
||||
match := re.FindStringSubmatch(req.RequestURI)
|
||||
scheme = match[1]
|
||||
|
||||
if len(match[2]) > 0 {
|
||||
host = match[2]
|
||||
}
|
||||
|
||||
if len(match[3]) > 0 {
|
||||
port = match[3]
|
||||
}
|
||||
|
||||
uri = match[4]
|
||||
}
|
||||
|
||||
if req.TLS != nil || isXForwardedHTTPS(req) {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
return strings.Join([]string{scheme, "://", host, port, uri}, "")
|
||||
}
|
||||
|
||||
func isXForwardedHTTPS(request *http.Request) bool {
|
||||
xForwardedProto := request.Header.Get("X-Forwarded-Proto")
|
||||
|
||||
return len(xForwardedProto) > 0 && xForwardedProto == "https"
|
||||
}
|
||||
|
||||
func applyString(in string, out io.Writer, req *http.Request) error {
|
||||
t, err := template.New("t").Parse(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := struct{ Request *http.Request }{Request: req}
|
||||
|
||||
return t.Execute(out, data)
|
||||
}
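A hypothetical illustration, not part of this commit, of the template step performed by applyString: the replacement string may reference the incoming request, for example one of its headers.

// demoApplyString renders a replacement template against a request; with the header set
// below, the resulting URL is "https://bar.example".
func demoApplyString() (string, error) {
    req, err := http.NewRequest(http.MethodGet, "http://foo.com", nil)
    if err != nil {
        return "", err
    }
    req.Header.Set("X-Foo", "bar")

    out := &bytes.Buffer{}
    if err := applyString(`https://{{ .Request.Header.Get "X-Foo" }}.example`, out, req); err != nil {
        return "", err
    }
    return out.String(), nil
}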
22  pkg/middlewares/redirect/redirect_regex.go  Normal file
@@ -0,0 +1,22 @@
package redirect

import (
    "context"
    "net/http"

    "github.com/containous/traefik/pkg/config"
    "github.com/containous/traefik/pkg/middlewares"
)

const (
    typeRegexName = "RedirectRegex"
)

// NewRedirectRegex creates a redirect middleware.
func NewRedirectRegex(ctx context.Context, next http.Handler, conf config.RedirectRegex, name string) (http.Handler, error) {
    logger := middlewares.GetLogger(ctx, name, typeRegexName)
    logger.Debug("Creating middleware")
    logger.Debugf("Setting up redirection from %s to %s", conf.Regex, conf.Replacement)

    return newRedirect(ctx, next, conf.Regex, conf.Replacement, conf.Permanent, name)
}
198  pkg/middlewares/redirect/redirect_regex_test.go  Normal file
@@ -0,0 +1,198 @@
package redirect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/containous/traefik/pkg/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRedirectRegexHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config config.RedirectRegex
|
||||
method string
|
||||
url string
|
||||
secured bool
|
||||
expectedURL string
|
||||
expectedStatus int
|
||||
errorExpected bool
|
||||
}{
|
||||
{
|
||||
desc: "simple redirection",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
Replacement: "https://${1}bar$2:443$4",
|
||||
},
|
||||
url: "http://foo.com:80",
|
||||
expectedURL: "https://foobar.com:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "use request header",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
Replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
|
||||
},
|
||||
url: "http://foo.com:80",
|
||||
expectedURL: "https://foobar.com:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "URL doesn't match regex",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
Replacement: "https://${1}bar$2:443$4",
|
||||
},
|
||||
url: "http://bar.com:80",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "invalid rewritten URL",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `^(.*)$`,
|
||||
Replacement: "http://192.168.0.%31/",
|
||||
},
|
||||
url: "http://foo.com:80",
|
||||
expectedStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
desc: "invalid regex",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `^(.*`,
|
||||
Replacement: "$1",
|
||||
},
|
||||
url: "http://foo.com:80",
|
||||
errorExpected: true,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTPS permanent",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `^http://`,
|
||||
Replacement: "https://$1",
|
||||
Permanent: true,
|
||||
},
|
||||
url: "http://foo",
|
||||
expectedURL: "https://foo",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS to HTTP permanent",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `https://foo`,
|
||||
Replacement: "http://foo",
|
||||
Permanent: true,
|
||||
},
|
||||
secured: true,
|
||||
url: "https://foo",
|
||||
expectedURL: "http://foo",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTPS",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `http://foo:80`,
|
||||
Replacement: "https://foo:443",
|
||||
},
|
||||
url: "http://foo:80",
|
||||
expectedURL: "https://foo:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS to HTTP",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `https://foo:443`,
|
||||
Replacement: "http://foo:80",
|
||||
},
|
||||
secured: true,
|
||||
url: "https://foo:443",
|
||||
expectedURL: "http://foo:80",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTP",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `http://foo:80`,
|
||||
Replacement: "http://foo:88",
|
||||
},
|
||||
url: "http://foo:80",
|
||||
expectedURL: "http://foo:88",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTP POST",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `^http://`,
|
||||
Replacement: "https://$1",
|
||||
},
|
||||
url: "http://foo",
|
||||
method: http.MethodPost,
|
||||
expectedURL: "https://foo",
|
||||
expectedStatus: http.StatusTemporaryRedirect,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTP POST permanent",
|
||||
config: config.RedirectRegex{
|
||||
Regex: `^http://`,
|
||||
Replacement: "https://$1",
|
||||
Permanent: true,
|
||||
},
|
||||
url: "http://foo",
|
||||
method: http.MethodPost,
|
||||
expectedURL: "https://foo",
|
||||
expectedStatus: http.StatusPermanentRedirect,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
handler, err := NewRedirectRegex(context.Background(), next, test.config, "traefikTest")
|
||||
|
||||
if test.errorExpected {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, handler)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
method := http.MethodGet
|
||||
if test.method != "" {
|
||||
method = test.method
|
||||
}
|
||||
r := testhelpers.MustNewRequest(method, test.url, nil)
|
||||
if test.secured {
|
||||
r.TLS = &tls.ConnectionState{}
|
||||
}
|
||||
r.Header.Set("X-Foo", "bar")
|
||||
handler.ServeHTTP(recorder, r)
|
||||
|
||||
assert.Equal(t, test.expectedStatus, recorder.Code)
|
||||
if test.expectedStatus == http.StatusMovedPermanently ||
|
||||
test.expectedStatus == http.StatusFound ||
|
||||
test.expectedStatus == http.StatusTemporaryRedirect ||
|
||||
test.expectedStatus == http.StatusPermanentRedirect {
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedURL, location.String())
|
||||
} else {
|
||||
location, err := recorder.Result().Location()
|
||||
require.Errorf(t, err, "Location %v", location)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
34  pkg/middlewares/redirect/redirect_scheme.go  Normal file
@@ -0,0 +1,34 @@
package redirect

import (
    "context"
    "net/http"

    "github.com/containous/traefik/pkg/middlewares"
    "github.com/pkg/errors"

    "github.com/containous/traefik/pkg/config"
)

const (
    typeSchemeName      = "RedirectScheme"
    schemeRedirectRegex = `^(https?:\/\/)?([\w\._-]+)(:\d+)?(.*)$`
)

// NewRedirectScheme creates a new RedirectScheme middleware.
func NewRedirectScheme(ctx context.Context, next http.Handler, conf config.RedirectScheme, name string) (http.Handler, error) {
    logger := middlewares.GetLogger(ctx, name, typeSchemeName)
    logger.Debug("Creating middleware")
    logger.Debugf("Setting up redirection to %s %s", conf.Scheme, conf.Port)

    if len(conf.Scheme) == 0 {
        return nil, errors.New("you must provide a target scheme")
    }

    port := ""
    if len(conf.Port) > 0 && !(conf.Scheme == "http" && conf.Port == "80" || conf.Scheme == "https" && conf.Port == "443") {
        port = ":" + conf.Port
    }

    return newRedirect(ctx, next, schemeRedirectRegex, conf.Scheme+"://${2}"+port+"${4}", conf.Permanent, name)
}
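To make the port handling above concrete, a few hypothetical configurations and the replacement string they produce (the scheme's default port is elided):

// Hypothetical examples, not part of this commit:
//
//  config.RedirectScheme{Scheme: "https"}               -> "https://${2}${4}"
//  config.RedirectScheme{Scheme: "https", Port: "443"}  -> "https://${2}${4}"      (default port dropped)
//  config.RedirectScheme{Scheme: "https", Port: "8443"} -> "https://${2}:8443${4}"
//  config.RedirectScheme{Scheme: "http", Port: "80"}    -> "http://${2}${4}"       (default port dropped)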
250  pkg/middlewares/redirect/redirect_scheme_test.go  Normal file
@@ -0,0 +1,250 @@
package redirect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/pkg/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRedirectSchemeHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config config.RedirectScheme
|
||||
method string
|
||||
url string
|
||||
secured bool
|
||||
expectedURL string
|
||||
expectedStatus int
|
||||
errorExpected bool
|
||||
}{
|
||||
{
|
||||
desc: "Without scheme",
|
||||
config: config.RedirectScheme{},
|
||||
url: "http://foo",
|
||||
errorExpected: true,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTPS",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
},
|
||||
url: "http://foo",
|
||||
expectedURL: "https://foo",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP with port to HTTPS without port",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
},
|
||||
url: "http://foo:8080",
|
||||
expectedURL: "https://foo",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP without port to HTTPS with port",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
Port: "8443",
|
||||
},
|
||||
url: "http://foo",
|
||||
expectedURL: "https://foo:8443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP with port to HTTPS with port",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
Port: "8443",
|
||||
},
|
||||
url: "http://foo:8000",
|
||||
expectedURL: "https://foo:8443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS with port to HTTPS with port",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
Port: "8443",
|
||||
},
|
||||
url: "https://foo:8000",
|
||||
expectedURL: "https://foo:8443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS with port to HTTPS without port",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
},
|
||||
url: "https://foo:8000",
|
||||
expectedURL: "https://foo",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "redirection to HTTPS without port from a URL already in https",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
},
|
||||
url: "https://foo:8000/theother",
|
||||
expectedURL: "https://foo/theother",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTPS permanent",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
Port: "8443",
|
||||
Permanent: true,
|
||||
},
|
||||
url: "http://foo",
|
||||
expectedURL: "https://foo:8443",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "to HTTP 80",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "http",
|
||||
Port: "80",
|
||||
},
|
||||
url: "http://foo:80",
|
||||
expectedURL: "http://foo",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to wss",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "wss",
|
||||
Port: "9443",
|
||||
},
|
||||
url: "http://foo",
|
||||
expectedURL: "wss://foo:9443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to wss without port",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "wss",
|
||||
},
|
||||
url: "http://foo",
|
||||
expectedURL: "wss://foo",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP with port to wss without port",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "wss",
|
||||
},
|
||||
url: "http://foo:5678",
|
||||
expectedURL: "wss://foo",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTPS without port",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
},
|
||||
url: "http://foo:443",
|
||||
expectedURL: "https://foo",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP port redirection",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "http",
|
||||
Port: "8181",
|
||||
},
|
||||
url: "http://foo:8080",
|
||||
expectedURL: "http://foo:8181",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS with port 80 to HTTPS without port",
|
||||
config: config.RedirectScheme{
|
||||
Scheme: "https",
|
||||
},
|
||||
url: "https://foo:80",
|
||||
expectedURL: "https://foo",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
handler, err := NewRedirectScheme(context.Background(), next, test.config, "traefikTest")
|
||||
|
||||
if test.errorExpected {
|
||||
require.Error(t, err)
|
||||
require.Nil(t, handler)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
method := http.MethodGet
|
||||
if test.method != "" {
|
||||
method = test.method
|
||||
}
|
||||
r := httptest.NewRequest(method, test.url, nil)
|
||||
|
||||
if test.secured {
|
||||
r.TLS = &tls.ConnectionState{}
|
||||
}
|
||||
r.Header.Set("X-Foo", "bar")
|
||||
handler.ServeHTTP(recorder, r)
|
||||
|
||||
assert.Equal(t, test.expectedStatus, recorder.Code)
|
||||
if test.expectedStatus == http.StatusMovedPermanently ||
|
||||
test.expectedStatus == http.StatusFound ||
|
||||
test.expectedStatus == http.StatusTemporaryRedirect ||
|
||||
test.expectedStatus == http.StatusPermanentRedirect {
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedURL, location.String())
|
||||
} else {
|
||||
location, err := recorder.Result().Location()
|
||||
require.Errorf(t, err, "Location %v", location)
|
||||
}
|
||||
|
||||
schemeRegex := `^(https?):\/\/([\w\._-]+)(:\d+)?(.*)$`
|
||||
re, _ := regexp.Compile(schemeRegex)
|
||||
|
||||
if re.Match([]byte(test.url)) {
|
||||
match := re.FindStringSubmatch(test.url)
|
||||
r.RequestURI = match[4]
|
||||
|
||||
handler.ServeHTTP(recorder, r)
|
||||
|
||||
assert.Equal(t, test.expectedStatus, recorder.Code)
|
||||
if test.expectedStatus == http.StatusMovedPermanently ||
|
||||
test.expectedStatus == http.StatusFound ||
|
||||
test.expectedStatus == http.StatusTemporaryRedirect ||
|
||||
test.expectedStatus == http.StatusPermanentRedirect {
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedURL, location.String())
|
||||
} else {
|
||||
location, err := recorder.Result().Location()
|
||||
require.Errorf(t, err, "Location %v", location)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
46  pkg/middlewares/replacepath/replace_path.go  Normal file
@@ -0,0 +1,46 @@
package replacepath

import (
    "context"
    "net/http"

    "github.com/containous/traefik/pkg/config"
    "github.com/containous/traefik/pkg/middlewares"
    "github.com/containous/traefik/pkg/tracing"
    "github.com/opentracing/opentracing-go/ext"
)

const (
    // ReplacedPathHeader is the default header to set the old path to.
    ReplacedPathHeader = "X-Replaced-Path"
    typeName           = "ReplacePath"
)

// ReplacePath is a middleware used to replace the path of a URL request.
type replacePath struct {
    next http.Handler
    path string
    name string
}

// New creates a new replace path middleware.
func New(ctx context.Context, next http.Handler, config config.ReplacePath, name string) (http.Handler, error) {
    middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")

    return &replacePath{
        next: next,
        path: config.Path,
        name: name,
    }, nil
}

func (r *replacePath) GetTracingInformation() (string, ext.SpanKindEnum) {
    return r.name, tracing.SpanKindNoneEnum
}

func (r *replacePath) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    req.Header.Add(ReplacedPathHeader, req.URL.Path)
    req.URL.Path = r.path
    req.RequestURI = req.URL.RequestURI()
    r.next.ServeHTTP(rw, req)
}
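A hypothetical before/after of the rewrite performed in ServeHTTP, mirroring the test below; it is not part of this commit.

// With Path "/replacement-path", a request for "/some/really/long/path" leaves
// ServeHTTP with:
//
//  req.URL.Path                        == "/replacement-path"
//  req.RequestURI                      == "/replacement-path"
//  req.Header.Get(ReplacedPathHeader)  == "/some/really/long/path"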
46  pkg/middlewares/replacepath/replace_path_test.go  Normal file
@@ -0,0 +1,46 @@
package replacepath

import (
    "context"
    "net/http"
    "testing"

    "github.com/containous/traefik/pkg/config"
    "github.com/containous/traefik/pkg/testhelpers"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func TestReplacePath(t *testing.T) {
    var replacementConfig = config.ReplacePath{
        Path: "/replacement-path",
    }

    paths := []string{
        "/example",
        "/some/really/long/path",
    }

    for _, path := range paths {
        t.Run(path, func(t *testing.T) {

            var expectedPath, actualHeader, requestURI string
            next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                expectedPath = r.URL.Path
                actualHeader = r.Header.Get(ReplacedPathHeader)
                requestURI = r.RequestURI
            })

            handler, err := New(context.Background(), next, replacementConfig, "foo-replace-path")
            require.NoError(t, err)

            req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+path, nil)

            handler.ServeHTTP(nil, req)

            assert.Equal(t, expectedPath, replacementConfig.Path, "Unexpected path.")
            assert.Equal(t, path, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
            assert.Equal(t, expectedPath, requestURI, "Unexpected request URI.")
        })
    }
}
57  pkg/middlewares/replacepathregex/replace_path_regex.go  Normal file
@@ -0,0 +1,57 @@
package replacepathregex

import (
    "context"
    "fmt"
    "net/http"
    "regexp"
    "strings"

    "github.com/containous/traefik/pkg/config"
    "github.com/containous/traefik/pkg/middlewares"
    "github.com/containous/traefik/pkg/middlewares/replacepath"
    "github.com/containous/traefik/pkg/tracing"
    "github.com/opentracing/opentracing-go/ext"
)

const (
    typeName = "ReplacePathRegex"
)

// ReplacePathRegex is a middleware used to replace the path of a URL request with a regular expression.
type replacePathRegex struct {
    next        http.Handler
    regexp      *regexp.Regexp
    replacement string
    name        string
}

// New creates a new replace path regex middleware.
func New(ctx context.Context, next http.Handler, config config.ReplacePathRegex, name string) (http.Handler, error) {
    middlewares.GetLogger(ctx, name, typeName).Debug("Creating middleware")

    exp, err := regexp.Compile(strings.TrimSpace(config.Regex))
    if err != nil {
        return nil, fmt.Errorf("error compiling regular expression %s: %s", config.Regex, err)
    }

    return &replacePathRegex{
        regexp:      exp,
        replacement: strings.TrimSpace(config.Replacement),
        next:        next,
        name:        name,
    }, nil
}

func (rp *replacePathRegex) GetTracingInformation() (string, ext.SpanKindEnum) {
    return rp.name, tracing.SpanKindNoneEnum
}

func (rp *replacePathRegex) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    if rp.regexp != nil && len(rp.replacement) > 0 && rp.regexp.MatchString(req.URL.Path) {
        req.Header.Add(replacepath.ReplacedPathHeader, req.URL.Path)
        req.URL.Path = rp.regexp.ReplaceAllString(req.URL.Path, rp.replacement)
        req.RequestURI = req.URL.RequestURI()
    }
    rp.next.ServeHTTP(rw, req)
}
|
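As an aside (illustration only, not part of this commit; the paths below are arbitrary examples): the Replacement string is handed to the standard library's (*regexp.Regexp).ReplaceAllString, so capture groups can be referenced with $1, $2, and so on.

package main

import (
	"fmt"
	"regexp"
)

func main() {
	// Same mechanics the middleware relies on: $1 expands to the first capture group.
	re := regexp.MustCompile(`^/whoami/(.*)`)
	fmt.Println(re.ReplaceAllString("/whoami/and/whoami", "/who-am-i/$1"))
	// Output: /who-am-i/and/whoami
}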
104
pkg/middlewares/replacepathregex/replace_path_regex_test.go
Normal file
@@ -0,0 +1,104 @@
package replacepathregex

import (
	"context"
	"net/http"
	"testing"

	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/middlewares/replacepath"
	"github.com/containous/traefik/pkg/testhelpers"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestReplacePathRegex(t *testing.T) {
	testCases := []struct {
		desc           string
		path           string
		config         config.ReplacePathRegex
		expectedPath   string
		expectedHeader string
		expectsError   bool
	}{
		{
			desc: "simple regex",
			path: "/whoami/and/whoami",
			config: config.ReplacePathRegex{
				Replacement: "/who-am-i/$1",
				Regex:       `^/whoami/(.*)`,
			},
			expectedPath:   "/who-am-i/and/whoami",
			expectedHeader: "/whoami/and/whoami",
		},
		{
			desc: "simple replace (no regex)",
			path: "/whoami/and/whoami",
			config: config.ReplacePathRegex{
				Replacement: "/who-am-i",
				Regex:       `/whoami`,
			},
			expectedPath:   "/who-am-i/and/who-am-i",
			expectedHeader: "/whoami/and/whoami",
		},
		{
			desc: "no match",
			path: "/whoami/and/whoami",
			config: config.ReplacePathRegex{
				Replacement: "/whoami",
				Regex:       `/no-match`,
			},
			expectedPath: "/whoami/and/whoami",
		},
		{
			desc: "multiple replacement",
			path: "/downloads/src/source.go",
			config: config.ReplacePathRegex{
				Replacement: "/downloads/$1-$2",
				Regex:       `^(?i)/downloads/([^/]+)/([^/]+)$`,
			},
			expectedPath:   "/downloads/src-source.go",
			expectedHeader: "/downloads/src/source.go",
		},
		{
			desc: "invalid regular expression",
			path: "/invalid/regexp/test",
			config: config.ReplacePathRegex{
				Replacement: "/valid/regexp/$1",
				Regex:       `^(?err)/invalid/regexp/([^/]+)$`,
			},
			expectedPath: "/invalid/regexp/test",
			expectsError: true,
		},
	}

	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			var actualPath, actualHeader, requestURI string
			next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				actualPath = r.URL.Path
				actualHeader = r.Header.Get(replacepath.ReplacedPathHeader)
				requestURI = r.RequestURI
			})

			handler, err := New(context.Background(), next, test.config, "foo-replace-path-regexp")
			if test.expectsError {
				require.Error(t, err)
			} else {
				require.NoError(t, err)

				req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
				req.RequestURI = test.path

				handler.ServeHTTP(nil, req)

				assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
				assert.Equal(t, actualPath, requestURI, "Unexpected request URI.")
				if test.expectedHeader != "" {
					assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", replacepath.ReplacedPathHeader)
				}
			}
		})
	}
}
124
pkg/middlewares/requestdecorator/hostresolver.go
Normal file
@@ -0,0 +1,124 @@
package requestdecorator

import (
	"context"
	"fmt"
	"net"
	"sort"
	"strings"
	"time"

	"github.com/containous/traefik/pkg/log"
	"github.com/miekg/dns"
	"github.com/patrickmn/go-cache"
)

type cnameResolv struct {
	TTL    time.Duration
	Record string
}

type byTTL []*cnameResolv

func (a byTTL) Len() int           { return len(a) }
func (a byTTL) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
func (a byTTL) Less(i, j int) bool { return a[i].TTL > a[j].TTL }

// Resolver is used for host (CNAME) resolution.
type Resolver struct {
	CnameFlattening bool
	ResolvConfig    string
	ResolvDepth     int
	cache           *cache.Cache
}

// CNAMEFlatten checks whether a CNAME record exists for the host and flattens it if possible.
func (hr *Resolver) CNAMEFlatten(ctx context.Context, host string) string {
	if hr.cache == nil {
		hr.cache = cache.New(30*time.Minute, 5*time.Minute)
	}

	result := host
	request := host

	value, found := hr.cache.Get(host)
	if found {
		return value.(string)
	}

	logger := log.FromContext(ctx)
	var cacheDuration = 0 * time.Second
	for depth := 0; depth < hr.ResolvDepth; depth++ {
		resolv, err := cnameResolve(ctx, request, hr.ResolvConfig)
		if err != nil {
			logger.Error(err)
			break
		}
		if resolv == nil {
			break
		}

		result = resolv.Record
		if depth == 0 {
			cacheDuration = resolv.TTL
		}
		request = resolv.Record
	}

	if err := hr.cache.Add(host, result, cacheDuration); err != nil {
		logger.Error(err)
	}

	return result
}

// cnameResolve resolves the CNAME record for the host, if any, and returns the answer with the highest TTL.
func cnameResolve(ctx context.Context, host string, resolvPath string) (*cnameResolv, error) {
	config, err := dns.ClientConfigFromFile(resolvPath)
	if err != nil {
		return nil, fmt.Errorf("invalid resolver configuration file: %s", resolvPath)
	}

	client := &dns.Client{Timeout: 30 * time.Second}

	m := &dns.Msg{}
	m.SetQuestion(dns.Fqdn(host), dns.TypeCNAME)

	var result []*cnameResolv
	for _, server := range config.Servers {
		tempRecord, err := getRecord(client, m, server, config.Port)
		if err != nil {
			log.FromContext(ctx).Errorf("Failed to resolve host %s: %v", host, err)
			continue
		}
		result = append(result, tempRecord)
	}

	if len(result) == 0 {
		return nil, nil
	}

	sort.Sort(byTTL(result))
	return result[0], nil
}

func getRecord(client *dns.Client, msg *dns.Msg, server string, port string) (*cnameResolv, error) {
	resp, _, err := client.Exchange(msg, net.JoinHostPort(server, port))
	if err != nil {
		return nil, fmt.Errorf("exchange error for server %s: %v", server, err)
	}

	if resp == nil || len(resp.Answer) == 0 {
		return nil, fmt.Errorf("empty answer for server %s", server)
	}

	rr, ok := resp.Answer[0].(*dns.CNAME)
	if !ok {
		return nil, fmt.Errorf("invalid response type for server %s", server)
	}

	return &cnameResolv{
		TTL:    time.Duration(rr.Hdr.Ttl) * time.Second,
		Record: strings.TrimSuffix(rr.Target, "."),
	}, nil
}
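A small sketch of direct use (not part of this commit; the host name and depth are assumptions): the first call performs the DNS lookups and caches the flattened result using the first record's TTL, so an immediate second call for the same host is served from the in-memory cache rather than new DNS queries.

package main

import (
	"context"
	"fmt"

	"github.com/containous/traefik/pkg/middlewares/requestdecorator"
)

func main() {
	resolver := &requestdecorator.Resolver{
		ResolvConfig: "/etc/resolv.conf",
		ResolvDepth:  5,
	}

	// First call queries DNS and fills the cache.
	fmt.Println(resolver.CNAMEFlatten(context.Background(), "www.example.org"))
	// Second call within the cached TTL returns the cached value.
	fmt.Println(resolver.CNAMEFlatten(context.Background(), "www.example.org"))
}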
51
pkg/middlewares/requestdecorator/hostresolver_test.go
Normal file
@@ -0,0 +1,51 @@
package requestdecorator

import (
	"context"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestCNAMEFlatten(t *testing.T) {
	testCases := []struct {
		desc           string
		resolvFile     string
		domain         string
		expectedDomain string
	}{
		{
			desc:           "host request is CNAME record",
			resolvFile:     "/etc/resolv.conf",
			domain:         "www.github.com",
			expectedDomain: "github.com",
		},
		{
			desc:           "resolve file not found",
			resolvFile:     "/etc/resolv.oops",
			domain:         "www.github.com",
			expectedDomain: "www.github.com",
		},
		{
			desc:           "host request is not CNAME record",
			resolvFile:     "/etc/resolv.conf",
			domain:         "github.com",
			expectedDomain: "github.com",
		},
	}

	for _, test := range testCases {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			hostResolver := &Resolver{
				ResolvConfig: test.resolvFile,
				ResolvDepth:  5,
			}

			flatH := hostResolver.CNAMEFlatten(context.Background(), test.domain)
			assert.Equal(t, test.expectedDomain, flatH)
		})
	}
}
87
pkg/middlewares/requestdecorator/request_decorator.go
Normal file
@@ -0,0 +1,87 @@
package requestdecorator

import (
	"context"
	"net"
	"net/http"
	"strings"

	"github.com/containous/alice"
	"github.com/containous/traefik/pkg/types"
)

const (
	canonicalKey key = "canonical"
	flattenKey   key = "flatten"
)

type key string

// RequestDecorator is the middleware that stores the canonical domain of the request Host in the request context for later use.
type RequestDecorator struct {
	hostResolver *Resolver
}

// New creates a new request host middleware.
func New(hostResolverConfig *types.HostResolverConfig) *RequestDecorator {
	requestDecorator := &RequestDecorator{}
	if hostResolverConfig != nil {
		requestDecorator.hostResolver = &Resolver{
			CnameFlattening: hostResolverConfig.CnameFlattening,
			ResolvConfig:    hostResolverConfig.ResolvConfig,
			ResolvDepth:     hostResolverConfig.ResolvDepth,
		}
	}
	return requestDecorator
}

func (r *RequestDecorator) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
	host := types.CanonicalDomain(parseHost(req.Host))
	reqt := req.WithContext(context.WithValue(req.Context(), canonicalKey, host))

	if r.hostResolver != nil && r.hostResolver.CnameFlattening {
		flatHost := r.hostResolver.CNAMEFlatten(reqt.Context(), host)
		reqt = reqt.WithContext(context.WithValue(reqt.Context(), flattenKey, flatHost))
	}

	next(rw, reqt)
}

func parseHost(addr string) string {
	if !strings.Contains(addr, ":") {
		return addr
	}

	host, _, err := net.SplitHostPort(addr)
	if err != nil {
		return addr
	}
	return host
}

// GetCanonizedHost retrieves the canonized host from the given context (previously stored in the request context by the middleware).
func GetCanonizedHost(ctx context.Context) string {
	if val, ok := ctx.Value(canonicalKey).(string); ok {
		return val
	}

	return ""
}

// GetCNAMEFlatten returns the flattened host if it is present in the context.
func GetCNAMEFlatten(ctx context.Context) string {
	if val, ok := ctx.Value(flattenKey).(string); ok {
		return val
	}

	return ""
}

// WrapHandler wraps a RequestDecorator's ServeHTTP (including the next handler) into an alice.Constructor.
func WrapHandler(handler *RequestDecorator) alice.Constructor {
	return func(next http.Handler) (http.Handler, error) {
		return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
			handler.ServeHTTP(rw, req, next.ServeHTTP)
		}), nil
	}
}
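An illustrative wiring sketch (not part of this commit; the final handler and server address are assumptions): the alice.Constructor returned by WrapHandler is applied to the next handler, and downstream code reads the canonical host back from the request context via GetCanonizedHost.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/traefik/pkg/middlewares/requestdecorator"
)

func main() {
	// Final handler: reads the canonical host stored by the middleware.
	final := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		fmt.Fprintln(rw, requestdecorator.GetCanonizedHost(req.Context()))
	})

	// nil config: no CNAME flattening, only host canonicalization.
	decorator := requestdecorator.New(nil)
	wrapped, err := requestdecorator.WrapHandler(decorator)(final)
	if err != nil {
		log.Fatal(err)
	}

	log.Fatal(http.ListenAndServe(":8080", wrapped))
}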
146
pkg/middlewares/requestdecorator/request_decorator_test.go
Normal file
@@ -0,0 +1,146 @@
package requestdecorator

import (
	"net/http"
	"testing"

	"github.com/containous/traefik/pkg/testhelpers"
	"github.com/containous/traefik/pkg/types"
	"github.com/stretchr/testify/assert"
)

func TestRequestHost(t *testing.T) {
	testCases := []struct {
		desc     string
		url      string
		expected string
	}{
		{
			desc:     "host without :",
			url:      "http://host",
			expected: "host",
		},
		{
			desc:     "host with : and without port",
			url:      "http://host:",
			expected: "host",
		},
		{
			desc:     "IP host with : and with port",
			url:      "http://127.0.0.1:123",
			expected: "127.0.0.1",
		},
		{
			desc:     "IP host with : and without port",
			url:      "http://127.0.0.1:",
			expected: "127.0.0.1",
		},
	}

	for _, test := range testCases {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
				host := GetCanonizedHost(r.Context())
				assert.Equal(t, test.expected, host)
			})

			rh := New(nil)

			req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)

			rh.ServeHTTP(nil, req, next)
		})
	}
}

func TestRequestFlattening(t *testing.T) {
	testCases := []struct {
		desc     string
		url      string
		expected string
	}{
		{
			desc:     "host with flattening",
			url:      "http://www.github.com",
			expected: "github.com",
		},
		{
			desc:     "host without flattening",
			url:      "http://github.com",
			expected: "github.com",
		},
		{
			desc:     "ip without flattening",
			url:      "http://127.0.0.1",
			expected: "127.0.0.1",
		},
	}

	for _, test := range testCases {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
				host := GetCNAMEFlatten(r.Context())
				assert.Equal(t, test.expected, host)
			})

			rh := New(
				&types.HostResolverConfig{
					CnameFlattening: true,
					ResolvConfig:    "/etc/resolv.conf",
					ResolvDepth:     5,
				},
			)

			req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)

			rh.ServeHTTP(nil, req, next)
		})
	}
}

func TestRequestHostParseHost(t *testing.T) {
	testCases := []struct {
		desc     string
		host     string
		expected string
	}{
		{
			desc:     "host without :",
			host:     "host",
			expected: "host",
		},
		{
			desc:     "host with : and without port",
			host:     "host:",
			expected: "host",
		},
		{
			desc:     "IP host with : and with port",
			host:     "127.0.0.1:123",
			expected: "127.0.0.1",
		},
		{
			desc:     "IP host with : and without port",
			host:     "127.0.0.1:",
			expected: "127.0.0.1",
		},
	}

	for _, test := range testCases {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			actual := parseHost(test.host)

			assert.Equal(t, test.expected, actual)
		})
	}
}
207
pkg/middlewares/retry/retry.go
Normal file
@@ -0,0 +1,207 @@
package retry

import (
	"bufio"
	"context"
	"fmt"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptrace"

	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/middlewares"
	"github.com/containous/traefik/pkg/tracing"
	"github.com/opentracing/opentracing-go/ext"
)

// Compile-time validation that the response writer implements http interfaces correctly.
var _ middlewares.Stateful = &responseWriterWithCloseNotify{}

const (
	typeName = "Retry"
)

// Listener is used to inform about retry attempts.
type Listener interface {
	// Retried will be called when a retry happens, with the request attempt passed to it.
	// For the first retry this will be attempt 2.
	Retried(req *http.Request, attempt int)
}

// Listeners is a convenience type to construct a list of Listener and notify
// each of them about a retry attempt.
type Listeners []Listener

// retry is a middleware that retries requests.
type retry struct {
	attempts int
	next     http.Handler
	listener Listener
	name     string
}

// New returns a new retry middleware.
func New(ctx context.Context, next http.Handler, config config.Retry, listener Listener, name string) (http.Handler, error) {
	logger := middlewares.GetLogger(ctx, name, typeName)
	logger.Debug("Creating middleware")

	if config.Attempts <= 0 {
		return nil, fmt.Errorf("incorrect (or empty) value for attempt (%d)", config.Attempts)
	}

	return &retry{
		attempts: config.Attempts,
		next:     next,
		listener: listener,
		name:     name,
	}, nil
}

func (r *retry) GetTracingInformation() (string, ext.SpanKindEnum) {
	return r.name, tracing.SpanKindNoneEnum
}

func (r *retry) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	// If we might make multiple attempts, swap the body for an ioutil.NopCloser.
	// cf https://github.com/containous/traefik/issues/1008
	if r.attempts > 1 {
		body := req.Body
		defer body.Close()
		req.Body = ioutil.NopCloser(body)
	}

	attempts := 1
	for {
		shouldRetry := attempts < r.attempts
		retryResponseWriter := newResponseWriter(rw, shouldRetry)

		// Disable retries when the backend already received request data.
		trace := &httptrace.ClientTrace{
			WroteHeaders: func() {
				retryResponseWriter.DisableRetries()
			},
			WroteRequest: func(httptrace.WroteRequestInfo) {
				retryResponseWriter.DisableRetries()
			},
		}
		newCtx := httptrace.WithClientTrace(req.Context(), trace)

		r.next.ServeHTTP(retryResponseWriter, req.WithContext(newCtx))

		if !retryResponseWriter.ShouldRetry() {
			break
		}

		attempts++
		logger := middlewares.GetLogger(req.Context(), r.name, typeName)
		logger.Debugf("New attempt %d for request: %v", attempts, req.URL)
		r.listener.Retried(req, attempts)
	}
}

// Retried exists to implement the Listener interface. It calls Retried on each of its slice entries.
func (l Listeners) Retried(req *http.Request, attempt int) {
	for _, listener := range l {
		listener.Retried(req, attempt)
	}
}

type responseWriter interface {
	http.ResponseWriter
	http.Flusher
	ShouldRetry() bool
	DisableRetries()
}

func newResponseWriter(rw http.ResponseWriter, shouldRetry bool) responseWriter {
	responseWriter := &responseWriterWithoutCloseNotify{
		responseWriter: rw,
		headers:        make(http.Header),
		shouldRetry:    shouldRetry,
	}
	if _, ok := rw.(http.CloseNotifier); ok {
		return &responseWriterWithCloseNotify{
			responseWriterWithoutCloseNotify: responseWriter,
		}
	}
	return responseWriter
}

type responseWriterWithoutCloseNotify struct {
	responseWriter http.ResponseWriter
	headers        http.Header
	shouldRetry    bool
	written        bool
}

func (r *responseWriterWithoutCloseNotify) ShouldRetry() bool {
	return r.shouldRetry
}

func (r *responseWriterWithoutCloseNotify) DisableRetries() {
	r.shouldRetry = false
}

func (r *responseWriterWithoutCloseNotify) Header() http.Header {
	if r.written {
		return r.responseWriter.Header()
	}
	return r.headers
}

func (r *responseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
	if r.ShouldRetry() {
		return len(buf), nil
	}
	return r.responseWriter.Write(buf)
}

func (r *responseWriterWithoutCloseNotify) WriteHeader(code int) {
	if r.ShouldRetry() && code == http.StatusServiceUnavailable {
		// We get a 503 HTTP status code when there is no backend server in the pool
		// to which the request could be sent. Also, note that r.ShouldRetry()
		// will never return true in case there was a connection established to
		// the backend server, so we can be sure that the 503 was produced
		// inside Traefik and we don't have to retry in these cases.
		r.DisableRetries()
	}

	if r.ShouldRetry() {
		return
	}

	// At this point the retry flag is false, which means we at least managed to write
	// headers to the backend: no further retry will be performed. It is therefore safe
	// to merge the headers collected during the latest attempt into the current response
	// headers before writing them to the client.
	headers := r.responseWriter.Header()
	for header, value := range r.headers {
		headers[header] = value
	}

	r.responseWriter.WriteHeader(code)
	r.written = true
}

func (r *responseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hijacker, ok := r.responseWriter.(http.Hijacker)
	if !ok {
		return nil, nil, fmt.Errorf("%T is not a http.Hijacker", r.responseWriter)
	}
	return hijacker.Hijack()
}

func (r *responseWriterWithoutCloseNotify) Flush() {
	if flusher, ok := r.responseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
}

type responseWriterWithCloseNotify struct {
	*responseWriterWithoutCloseNotify
}

func (r *responseWriterWithCloseNotify) CloseNotify() <-chan bool {
	return r.responseWriter.(http.CloseNotifier).CloseNotify()
}
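A hypothetical Listener implementation (not part of this commit; the log format, handler, and middleware name are assumptions) to show how retry attempts can be observed through the interface defined above.

package main

import (
	"context"
	"log"
	"net/http"

	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/middlewares/retry"
)

// loggingListener logs every retry attempt it is notified about.
type loggingListener struct{}

func (loggingListener) Retried(req *http.Request, attempt int) {
	log.Printf("retrying %s %s (attempt %d)", req.Method, req.URL, attempt)
}

func main() {
	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.WriteHeader(http.StatusOK)
	})

	handler, err := retry.New(context.Background(), next, config.Retry{Attempts: 3}, loggingListener{}, "example-retry")
	if err != nil {
		log.Fatal(err)
	}

	log.Fatal(http.ListenAndServe(":8080", handler))
}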
312
pkg/middlewares/retry/retry_test.go
Normal file
@@ -0,0 +1,312 @@
package retry

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"
	"net/http/httptrace"
	"strconv"
	"strings"
	"testing"

	"github.com/containous/traefik/pkg/config"
	"github.com/containous/traefik/pkg/middlewares/emptybackendhandler"
	"github.com/containous/traefik/pkg/testhelpers"
	"github.com/gorilla/websocket"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/vulcand/oxy/forward"
	"github.com/vulcand/oxy/roundrobin"
)

func TestRetry(t *testing.T) {
	testCases := []struct {
		desc                  string
		config                config.Retry
		wantRetryAttempts     int
		wantResponseStatus    int
		amountFaultyEndpoints int
	}{
		{
			desc:                  "no retry on success",
			config:                config.Retry{Attempts: 1},
			wantRetryAttempts:     0,
			wantResponseStatus:    http.StatusOK,
			amountFaultyEndpoints: 0,
		},
		{
			desc:                  "no retry when max request attempts is one",
			config:                config.Retry{Attempts: 1},
			wantRetryAttempts:     0,
			wantResponseStatus:    http.StatusInternalServerError,
			amountFaultyEndpoints: 1,
		},
		{
			desc:                  "one retry when one server is faulty",
			config:                config.Retry{Attempts: 2},
			wantRetryAttempts:     1,
			wantResponseStatus:    http.StatusOK,
			amountFaultyEndpoints: 1,
		},
		{
			desc:                  "two retries when two servers are faulty",
			config:                config.Retry{Attempts: 3},
			wantRetryAttempts:     2,
			wantResponseStatus:    http.StatusOK,
			amountFaultyEndpoints: 2,
		},
		{
			desc:                  "max attempts exhausted delivers the 5xx response",
			config:                config.Retry{Attempts: 3},
			wantRetryAttempts:     2,
			wantResponseStatus:    http.StatusInternalServerError,
			amountFaultyEndpoints: 3,
		},
	}

	backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.WriteHeader(http.StatusOK)
		_, err := rw.Write([]byte("OK"))
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
		}
	}))

	forwarder, err := forward.New()
	require.NoError(t, err)

	for _, test := range testCases {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			loadBalancer, err := roundrobin.New(forwarder)
			require.NoError(t, err)

			basePort := 33444
			for i := 0; i < test.amountFaultyEndpoints; i++ {
				// 192.0.2.0 is a non-routable IP for testing purposes.
				// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
				// We only use the port specification here because the URL is used as identifier
				// in the load balancer and using the exact same URL would not add a new server.
				err = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + strconv.Itoa(basePort+i)))
				require.NoError(t, err)
			}

			// add the functioning server to the end of the load balancer list
			err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
			require.NoError(t, err)

			retryListener := &countingRetryListener{}
			retry, err := New(context.Background(), loadBalancer, test.config, retryListener, "traefikTest")
			require.NoError(t, err)

			recorder := httptest.NewRecorder()
			req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)

			retry.ServeHTTP(recorder, req)

			assert.Equal(t, test.wantResponseStatus, recorder.Code)
			assert.Equal(t, test.wantRetryAttempts, retryListener.timesCalled)
		})
	}
}

func TestRetryEmptyServerList(t *testing.T) {
	forwarder, err := forward.New()
	require.NoError(t, err)

	loadBalancer, err := roundrobin.New(forwarder)
	require.NoError(t, err)

	// The EmptyBackend middleware ensures that there is a 503
	// response status set when there is no backend server in the pool.
	next := emptybackendhandler.New(loadBalancer)

	retryListener := &countingRetryListener{}
	retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, retryListener, "traefikTest")
	require.NoError(t, err)

	recorder := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)

	retry.ServeHTTP(recorder, req)

	assert.Equal(t, http.StatusServiceUnavailable, recorder.Code)
	assert.Equal(t, 0, retryListener.timesCalled)
}

func TestRetryListeners(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	retryListeners := Listeners{&countingRetryListener{}, &countingRetryListener{}}

	retryListeners.Retried(req, 1)
	retryListeners.Retried(req, 1)

	for _, retryListener := range retryListeners {
		listener := retryListener.(*countingRetryListener)
		if listener.timesCalled != 2 {
			t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
		}
	}
}

func TestMultipleRetriesShouldNotLooseHeaders(t *testing.T) {
	attempt := 0
	expectedHeaderName := "X-Foo-Test-2"
	expectedHeaderValue := "bar"

	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		headerName := fmt.Sprintf("X-Foo-Test-%d", attempt)
		rw.Header().Add(headerName, expectedHeaderValue)
		if attempt < 2 {
			attempt++
			return
		}

		// Request has been successfully written to the backend.
		trace := httptrace.ContextClientTrace(req.Context())
		trace.WroteHeaders()

		// And we decide to answer to the client.
		rw.WriteHeader(http.StatusNoContent)
	})

	retry, err := New(context.Background(), next, config.Retry{Attempts: 3}, &countingRetryListener{}, "traefikTest")
	require.NoError(t, err)

	responseRecorder := httptest.NewRecorder()
	retry.ServeHTTP(responseRecorder, testhelpers.MustNewRequest(http.MethodGet, "http://test", http.NoBody))

	headerValue := responseRecorder.Header().Get(expectedHeaderName)

	// Validate that we have the correct header.
	if headerValue != expectedHeaderValue {
		t.Errorf("Expected to have %s for header %s, got %s", expectedHeaderValue, expectedHeaderName, headerValue)
	}

	// Validate that we don't have headers from previous attempts.
	for i := 0; i < attempt; i++ {
		headerName := fmt.Sprintf("X-Foo-Test-%d", i)
		headerValue = responseRecorder.Header().Get(headerName)
		if headerValue != "" {
			t.Errorf("Expected no value for header %s, got %s", headerName, headerValue)
		}
	}
}

// countingRetryListener is a Listener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
	timesCalled int
}

func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
	l.timesCalled++
}

func TestRetryWithFlush(t *testing.T) {
	next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		rw.WriteHeader(http.StatusOK)
		_, err := rw.Write([]byte("FULL "))
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
		}
		rw.(http.Flusher).Flush()
		_, err = rw.Write([]byte("DATA"))
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
		}
	})

	retry, err := New(context.Background(), next, config.Retry{Attempts: 1}, &countingRetryListener{}, "traefikTest")
	require.NoError(t, err)

	responseRecorder := httptest.NewRecorder()

	retry.ServeHTTP(responseRecorder, &http.Request{})

	assert.Equal(t, "FULL DATA", responseRecorder.Body.String())
}

func TestRetryWebsocket(t *testing.T) {
	testCases := []struct {
		desc                   string
		maxRequestAttempts     int
		expectedRetryAttempts  int
		expectedResponseStatus int
		expectedError          bool
		amountFaultyEndpoints  int
	}{
		{
			desc:                   "Switching ok after 2 retries",
			maxRequestAttempts:     3,
			expectedRetryAttempts:  2,
			amountFaultyEndpoints:  2,
			expectedResponseStatus: http.StatusSwitchingProtocols,
		},
		{
			desc:                   "Switching failed",
			maxRequestAttempts:     2,
			expectedRetryAttempts:  1,
			amountFaultyEndpoints:  2,
			expectedResponseStatus: http.StatusBadGateway,
			expectedError:          true,
		},
	}

	forwarder, err := forward.New()
	if err != nil {
		t.Fatalf("Error creating forwarder: %v", err)
	}

	backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		upgrader := websocket.Upgrader{}
		_, err := upgrader.Upgrade(rw, req, nil)
		if err != nil {
			http.Error(rw, err.Error(), http.StatusInternalServerError)
		}
	}))

	for _, test := range testCases {
		test := test
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()

			loadBalancer, err := roundrobin.New(forwarder)
			if err != nil {
				t.Fatalf("Error creating load balancer: %v", err)
			}

			basePort := 33444
			for i := 0; i < test.amountFaultyEndpoints; i++ {
				// 192.0.2.0 is a non-routable IP for testing purposes.
				// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
				// We only use the port specification here because the URL is used as identifier
				// in the load balancer and using the exact same URL would not add a new server.
				_ = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + strconv.Itoa(basePort+i)))
			}

			// add the functioning server to the end of the load balancer list
			err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
			if err != nil {
				t.Fatalf("Failed to upsert server: %v", err)
			}

			retryListener := &countingRetryListener{}
			retryH, err := New(context.Background(), loadBalancer, config.Retry{Attempts: test.maxRequestAttempts}, retryListener, "traefikTest")
			require.NoError(t, err)

			retryServer := httptest.NewServer(retryH)

			url := strings.Replace(retryServer.URL, "http", "ws", 1)
			_, response, err := websocket.DefaultDialer.Dial(url, nil)

			if !test.expectedError {
				require.NoError(t, err)
			}

			assert.Equal(t, test.expectedResponseStatus, response.StatusCode)
			assert.Equal(t, test.expectedRetryAttempts, retryListener.timesCalled)
		})
	}
}
Some files were not shown because too many files have changed in this diff.