1
0
Fork 0

Dynamic Configuration Refactoring

This commit is contained in:
Ludovic Fernandez 2018-11-14 10:18:03 +01:00 committed by Traefiker Bot
parent d3ae88f108
commit a09dfa3ce1
452 changed files with 21023 additions and 9419 deletions

39
old/api/dashboard.go Normal file
View file

@ -0,0 +1,39 @@
package api
import (
"net/http"
"github.com/containous/mux"
"github.com/containous/traefik/old/log"
"github.com/elazarl/go-bindata-assetfs"
)
// DashboardHandler expose dashboard routes
type DashboardHandler struct {
Assets *assetfs.AssetFS
}
// AddRoutes add dashboard routes on a router
func (g DashboardHandler) AddRoutes(router *mux.Router) {
if g.Assets == nil {
log.Error("No assets for dashboard")
return
}
// Expose dashboard
router.Methods(http.MethodGet).
Path("/").
HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
http.Redirect(response, request, request.Header.Get("X-Forwarded-Prefix")+"/dashboard/", 302)
})
router.Methods(http.MethodGet).
Path("/dashboard/status").
HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
http.Redirect(response, request, "/dashboard/", 302)
})
router.Methods(http.MethodGet).
PathPrefix("/dashboard/").
Handler(http.StripPrefix("/dashboard/", http.FileServer(g.Assets)))
}

48
old/api/debug.go Normal file
View file

@ -0,0 +1,48 @@
package api
import (
"expvar"
"fmt"
"net/http"
"net/http/pprof"
"runtime"
"github.com/containous/mux"
)
func init() {
expvar.Publish("Goroutines", expvar.Func(goroutines))
}
func goroutines() interface{} {
return runtime.NumGoroutine()
}
// DebugHandler expose debug routes
type DebugHandler struct{}
// AddRoutes add debug routes on a router
func (g DebugHandler) AddRoutes(router *mux.Router) {
router.Methods(http.MethodGet).Path("/debug/vars").
HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
fmt.Fprint(w, "{\n")
first := true
expvar.Do(func(kv expvar.KeyValue) {
if !first {
fmt.Fprint(w, ",\n")
}
first = false
fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
})
fmt.Fprint(w, "\n}\n")
})
runtime.SetBlockProfileRate(1)
runtime.SetMutexProfileFraction(5)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/cmdline").HandlerFunc(pprof.Cmdline)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/profile").HandlerFunc(pprof.Profile)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/symbol").HandlerFunc(pprof.Symbol)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/trace").HandlerFunc(pprof.Trace)
router.Methods(http.MethodGet).PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index)
}

252
old/api/handler.go Normal file
View file

@ -0,0 +1,252 @@
package api
import (
"net/http"
"github.com/containous/mux"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/middlewares"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/safe"
"github.com/containous/traefik/version"
"github.com/elazarl/go-bindata-assetfs"
thoas_stats "github.com/thoas/stats"
"github.com/unrolled/render"
)
// Handler expose api routes
type Handler struct {
EntryPoint string `description:"EntryPoint" export:"true"`
Dashboard bool `description:"Activate dashboard" export:"true"`
Debug bool `export:"true"`
CurrentConfigurations *safe.Safe
Statistics *types.Statistics `description:"Enable more detailed statistics" export:"true"`
Stats *thoas_stats.Stats `json:"-"`
StatsRecorder *middlewares.StatsRecorder `json:"-"`
DashboardAssets *assetfs.AssetFS `json:"-"`
}
var (
templatesRenderer = render.New(render.Options{
Directory: "nowhere",
})
)
// AddRoutes add api routes on a router
func (p Handler) AddRoutes(router *mux.Router) {
if p.Debug {
DebugHandler{}.AddRoutes(router)
}
router.Methods(http.MethodGet).Path("/api").HandlerFunc(p.getConfigHandler)
router.Methods(http.MethodGet).Path("/api/providers").HandlerFunc(p.getConfigHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}").HandlerFunc(p.getProviderHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/backends").HandlerFunc(p.getBackendsHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/backends/{backend}").HandlerFunc(p.getBackendHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/backends/{backend}/servers").HandlerFunc(p.getServersHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/backends/{backend}/servers/{server}").HandlerFunc(p.getServerHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/frontends").HandlerFunc(p.getFrontendsHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/frontends/{frontend}").HandlerFunc(p.getFrontendHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/frontends/{frontend}/routes").HandlerFunc(p.getRoutesHandler)
router.Methods(http.MethodGet).Path("/api/providers/{provider}/frontends/{frontend}/routes/{route}").HandlerFunc(p.getRouteHandler)
// health route
router.Methods(http.MethodGet).Path("/health").HandlerFunc(p.getHealthHandler)
version.Handler{}.Append(router)
if p.Dashboard {
DashboardHandler{Assets: p.DashboardAssets}.AddRoutes(router)
}
}
func getProviderIDFromVars(vars map[string]string) string {
providerID := vars["provider"]
// TODO: Deprecated
if providerID == "rest" {
providerID = "web"
}
return providerID
}
func (p Handler) getConfigHandler(response http.ResponseWriter, request *http.Request) {
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
err := templatesRenderer.JSON(response, http.StatusOK, currentConfigurations)
if err != nil {
log.Error(err)
}
}
func (p Handler) getProviderHandler(response http.ResponseWriter, request *http.Request) {
providerID := getProviderIDFromVars(mux.Vars(request))
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
if provider, ok := currentConfigurations[providerID]; ok {
err := templatesRenderer.JSON(response, http.StatusOK, provider)
if err != nil {
log.Error(err)
}
} else {
http.NotFound(response, request)
}
}
func (p Handler) getBackendsHandler(response http.ResponseWriter, request *http.Request) {
providerID := getProviderIDFromVars(mux.Vars(request))
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
if provider, ok := currentConfigurations[providerID]; ok {
err := templatesRenderer.JSON(response, http.StatusOK, provider.Backends)
if err != nil {
log.Error(err)
}
} else {
http.NotFound(response, request)
}
}
func (p Handler) getBackendHandler(response http.ResponseWriter, request *http.Request) {
vars := mux.Vars(request)
providerID := getProviderIDFromVars(vars)
backendID := vars["backend"]
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
if provider, ok := currentConfigurations[providerID]; ok {
if backend, ok := provider.Backends[backendID]; ok {
err := templatesRenderer.JSON(response, http.StatusOK, backend)
if err != nil {
log.Error(err)
}
return
}
}
http.NotFound(response, request)
}
func (p Handler) getServersHandler(response http.ResponseWriter, request *http.Request) {
vars := mux.Vars(request)
providerID := getProviderIDFromVars(vars)
backendID := vars["backend"]
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
if provider, ok := currentConfigurations[providerID]; ok {
if backend, ok := provider.Backends[backendID]; ok {
err := templatesRenderer.JSON(response, http.StatusOK, backend.Servers)
if err != nil {
log.Error(err)
}
return
}
}
http.NotFound(response, request)
}
func (p Handler) getServerHandler(response http.ResponseWriter, request *http.Request) {
vars := mux.Vars(request)
providerID := getProviderIDFromVars(vars)
backendID := vars["backend"]
serverID := vars["server"]
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
if provider, ok := currentConfigurations[providerID]; ok {
if backend, ok := provider.Backends[backendID]; ok {
if server, ok := backend.Servers[serverID]; ok {
err := templatesRenderer.JSON(response, http.StatusOK, server)
if err != nil {
log.Error(err)
}
return
}
}
}
http.NotFound(response, request)
}
func (p Handler) getFrontendsHandler(response http.ResponseWriter, request *http.Request) {
providerID := getProviderIDFromVars(mux.Vars(request))
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
if provider, ok := currentConfigurations[providerID]; ok {
err := templatesRenderer.JSON(response, http.StatusOK, provider.Frontends)
if err != nil {
log.Error(err)
}
} else {
http.NotFound(response, request)
}
}
func (p Handler) getFrontendHandler(response http.ResponseWriter, request *http.Request) {
vars := mux.Vars(request)
providerID := getProviderIDFromVars(vars)
frontendID := vars["frontend"]
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
if provider, ok := currentConfigurations[providerID]; ok {
if frontend, ok := provider.Frontends[frontendID]; ok {
err := templatesRenderer.JSON(response, http.StatusOK, frontend)
if err != nil {
log.Error(err)
}
return
}
}
http.NotFound(response, request)
}
func (p Handler) getRoutesHandler(response http.ResponseWriter, request *http.Request) {
vars := mux.Vars(request)
providerID := getProviderIDFromVars(vars)
frontendID := vars["frontend"]
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
if provider, ok := currentConfigurations[providerID]; ok {
if frontend, ok := provider.Frontends[frontendID]; ok {
err := templatesRenderer.JSON(response, http.StatusOK, frontend.Routes)
if err != nil {
log.Error(err)
}
return
}
}
http.NotFound(response, request)
}
func (p Handler) getRouteHandler(response http.ResponseWriter, request *http.Request) {
vars := mux.Vars(request)
providerID := getProviderIDFromVars(vars)
frontendID := vars["frontend"]
routeID := vars["route"]
currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
if provider, ok := currentConfigurations[providerID]; ok {
if frontend, ok := provider.Frontends[frontendID]; ok {
if route, ok := frontend.Routes[routeID]; ok {
err := templatesRenderer.JSON(response, http.StatusOK, route)
if err != nil {
log.Error(err)
}
return
}
}
}
http.NotFound(response, request)
}
// healthResponse combines data returned by thoas/stats with statistics (if
// they are enabled).
type healthResponse struct {
*thoas_stats.Data
*middlewares.Stats
}
func (p *Handler) getHealthHandler(response http.ResponseWriter, request *http.Request) {
health := &healthResponse{Data: p.Stats.Data()}
if p.StatsRecorder != nil {
health.Stats = p.StatsRecorder.Data()
}
err := templatesRenderer.JSON(response, http.StatusOK, health)
if err != nil {
log.Error(err)
}
}

View file

@ -0,0 +1,445 @@
package configuration
import (
"fmt"
"strings"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/acme"
"github.com/containous/traefik/old/api"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/middlewares/tracing"
"github.com/containous/traefik/old/middlewares/tracing/datadog"
"github.com/containous/traefik/old/middlewares/tracing/jaeger"
"github.com/containous/traefik/old/middlewares/tracing/zipkin"
"github.com/containous/traefik/old/ping"
"github.com/containous/traefik/old/provider/boltdb"
"github.com/containous/traefik/old/provider/consul"
"github.com/containous/traefik/old/provider/consulcatalog"
"github.com/containous/traefik/old/provider/docker"
"github.com/containous/traefik/old/provider/dynamodb"
"github.com/containous/traefik/old/provider/ecs"
"github.com/containous/traefik/old/provider/etcd"
"github.com/containous/traefik/old/provider/eureka"
"github.com/containous/traefik/old/provider/file"
"github.com/containous/traefik/old/provider/kubernetes"
"github.com/containous/traefik/old/provider/marathon"
"github.com/containous/traefik/old/provider/mesos"
"github.com/containous/traefik/old/provider/rancher"
"github.com/containous/traefik/old/provider/rest"
"github.com/containous/traefik/old/provider/zk"
"github.com/containous/traefik/old/types"
acmeprovider "github.com/containous/traefik/provider/acme"
"github.com/containous/traefik/tls"
newtypes "github.com/containous/traefik/types"
"github.com/pkg/errors"
lego "github.com/xenolf/lego/acme"
)
const (
// DefaultInternalEntryPointName the name of the default internal entry point
DefaultInternalEntryPointName = "traefik"
// DefaultHealthCheckInterval is the default health check interval.
DefaultHealthCheckInterval = 30 * time.Second
// DefaultHealthCheckTimeout is the default health check request timeout.
DefaultHealthCheckTimeout = 5 * time.Second
// DefaultDialTimeout when connecting to a backend server.
DefaultDialTimeout = 30 * time.Second
// DefaultIdleTimeout before closing an idle connection.
DefaultIdleTimeout = 180 * time.Second
// DefaultGraceTimeout controls how long Traefik serves pending requests
// prior to shutting down.
DefaultGraceTimeout = 10 * time.Second
// DefaultAcmeCAServer is the default ACME API endpoint
DefaultAcmeCAServer = "https://acme-v02.api.letsencrypt.org/directory"
)
// GlobalConfiguration holds global configuration (with providers, etc.).
// It's populated from the traefik configuration file passed as an argument to the binary.
type GlobalConfiguration struct {
LifeCycle *LifeCycle `description:"Timeouts influencing the server life cycle" export:"true"`
Debug bool `short:"d" description:"Enable debug mode" export:"true"`
CheckNewVersion bool `description:"Periodically check if a new version has been released" export:"true"`
SendAnonymousUsage bool `description:"send periodically anonymous usage statistics" export:"true"`
AccessLog *types.AccessLog `description:"Access log settings" export:"true"`
TraefikLog *types.TraefikLog `description:"Traefik log settings" export:"true"`
Tracing *tracing.Tracing `description:"OpenTracing configuration" export:"true"`
LogLevel string `short:"l" description:"Log level" export:"true"`
EntryPoints EntryPoints `description:"Entrypoints definition using format: --entryPoints='Name:http Address::8000 Redirect.EntryPoint:https' --entryPoints='Name:https Address::4442 TLS:tests/traefik.crt,tests/traefik.key;prod/traefik.crt,prod/traefik.key'" export:"true"`
Cluster *types.Cluster
Constraints types.Constraints `description:"Filter services by constraint, matching with service tags" export:"true"`
ACME *acme.ACME `description:"Enable ACME (Let's Encrypt): automatic SSL" export:"true"`
DefaultEntryPoints DefaultEntryPoints `description:"Entrypoints to be used by frontends that do not specify any entrypoint" export:"true"`
ProvidersThrottleDuration parse.Duration `description:"Backends throttle duration: minimum duration between 2 events from providers before applying a new configuration. It avoids unnecessary reloads if multiples events are sent in a short amount of time." export:"true"`
MaxIdleConnsPerHost int `description:"If non-zero, controls the maximum idle (keep-alive) to keep per-host. If zero, DefaultMaxIdleConnsPerHost is used" export:"true"`
InsecureSkipVerify bool `description:"Disable SSL certificate verification" export:"true"`
RootCAs tls.FilesOrContents `description:"Add cert file for self-signed certificate"`
Retry *Retry `description:"Enable retry sending request if network error" export:"true"`
HealthCheck *HealthCheckConfig `description:"Health check parameters" export:"true"`
RespondingTimeouts *RespondingTimeouts `description:"Timeouts for incoming requests to the Traefik instance" export:"true"`
ForwardingTimeouts *ForwardingTimeouts `description:"Timeouts for requests forwarded to the backend servers" export:"true"`
KeepTrailingSlash bool `description:"Do not remove trailing slash." export:"true"` // Deprecated
Docker *docker.Provider `description:"Enable Docker backend with default settings" export:"true"`
File *file.Provider `description:"Enable File backend with default settings" export:"true"`
Marathon *marathon.Provider `description:"Enable Marathon backend with default settings" export:"true"`
Consul *consul.Provider `description:"Enable Consul backend with default settings" export:"true"`
ConsulCatalog *consulcatalog.Provider `description:"Enable Consul catalog backend with default settings" export:"true"`
Etcd *etcd.Provider `description:"Enable Etcd backend with default settings" export:"true"`
Zookeeper *zk.Provider `description:"Enable Zookeeper backend with default settings" export:"true"`
Boltdb *boltdb.Provider `description:"Enable Boltdb backend with default settings" export:"true"`
Kubernetes *kubernetes.Provider `description:"Enable Kubernetes backend with default settings" export:"true"`
Mesos *mesos.Provider `description:"Enable Mesos backend with default settings" export:"true"`
Eureka *eureka.Provider `description:"Enable Eureka backend with default settings" export:"true"`
ECS *ecs.Provider `description:"Enable ECS backend with default settings" export:"true"`
Rancher *rancher.Provider `description:"Enable Rancher backend with default settings" export:"true"`
DynamoDB *dynamodb.Provider `description:"Enable DynamoDB backend with default settings" export:"true"`
Rest *rest.Provider `description:"Enable Rest backend with default settings" export:"true"`
API *api.Handler `description:"Enable api/dashboard" export:"true"`
Metrics *types.Metrics `description:"Enable a metrics exporter" export:"true"`
Ping *ping.Handler `description:"Enable ping" export:"true"`
HostResolver *HostResolverConfig `description:"Enable CNAME Flattening" export:"true"`
}
// SetEffectiveConfiguration adds missing configuration parameters derived from existing ones.
// It also takes care of maintaining backwards compatibility.
func (gc *GlobalConfiguration) SetEffectiveConfiguration(configFile string) {
if len(gc.EntryPoints) == 0 {
gc.EntryPoints = map[string]*EntryPoint{"http": {
Address: ":80",
ForwardedHeaders: &ForwardedHeaders{},
}}
gc.DefaultEntryPoints = []string{"http"}
}
if (gc.API != nil && gc.API.EntryPoint == DefaultInternalEntryPointName) ||
(gc.Ping != nil && gc.Ping.EntryPoint == DefaultInternalEntryPointName) ||
(gc.Metrics != nil && gc.Metrics.Prometheus != nil && gc.Metrics.Prometheus.EntryPoint == DefaultInternalEntryPointName) ||
(gc.Rest != nil && gc.Rest.EntryPoint == DefaultInternalEntryPointName) {
if _, ok := gc.EntryPoints[DefaultInternalEntryPointName]; !ok {
gc.EntryPoints[DefaultInternalEntryPointName] = &EntryPoint{Address: ":8080"}
}
}
for entryPointName := range gc.EntryPoints {
entryPoint := gc.EntryPoints[entryPointName]
// ForwardedHeaders must be remove in the next breaking version
if entryPoint.ForwardedHeaders == nil {
entryPoint.ForwardedHeaders = &ForwardedHeaders{}
}
if entryPoint.TLS != nil && entryPoint.TLS.DefaultCertificate == nil && len(entryPoint.TLS.Certificates) > 0 {
log.Infof("No tls.defaultCertificate given for %s: using the first item in tls.certificates as a fallback.", entryPointName)
entryPoint.TLS.DefaultCertificate = &entryPoint.TLS.Certificates[0]
}
}
// Make sure LifeCycle isn't nil to spare nil checks elsewhere.
if gc.LifeCycle == nil {
gc.LifeCycle = &LifeCycle{}
}
if gc.Rancher != nil {
// Ensure backwards compatibility for now
if len(gc.Rancher.AccessKey) > 0 ||
len(gc.Rancher.Endpoint) > 0 ||
len(gc.Rancher.SecretKey) > 0 {
if gc.Rancher.API == nil {
gc.Rancher.API = &rancher.APIConfiguration{
AccessKey: gc.Rancher.AccessKey,
SecretKey: gc.Rancher.SecretKey,
Endpoint: gc.Rancher.Endpoint,
}
}
log.Warn("Deprecated configuration found: rancher.[accesskey|secretkey|endpoint]. " +
"Please use rancher.api.[accesskey|secretkey|endpoint] instead.")
}
if gc.Rancher.Metadata != nil && len(gc.Rancher.Metadata.Prefix) == 0 {
gc.Rancher.Metadata.Prefix = "latest"
}
}
if gc.API != nil {
gc.API.Debug = gc.Debug
}
if gc.File != nil {
gc.File.TraefikFile = configFile
}
gc.initACMEProvider()
gc.initTracing()
}
func (gc *GlobalConfiguration) initTracing() {
if gc.Tracing != nil {
switch gc.Tracing.Backend {
case jaeger.Name:
if gc.Tracing.Jaeger == nil {
gc.Tracing.Jaeger = &jaeger.Config{
SamplingServerURL: "http://localhost:5778/sampling",
SamplingType: "const",
SamplingParam: 1.0,
LocalAgentHostPort: "127.0.0.1:6831",
Propagation: "jaeger",
Gen128Bit: false,
}
}
if gc.Tracing.Zipkin != nil {
log.Warn("Zipkin configuration will be ignored")
gc.Tracing.Zipkin = nil
}
if gc.Tracing.DataDog != nil {
log.Warn("DataDog configuration will be ignored")
gc.Tracing.DataDog = nil
}
case zipkin.Name:
if gc.Tracing.Zipkin == nil {
gc.Tracing.Zipkin = &zipkin.Config{
HTTPEndpoint: "http://localhost:9411/api/v1/spans",
SameSpan: false,
ID128Bit: true,
Debug: false,
SampleRate: 1.0,
}
}
if gc.Tracing.Jaeger != nil {
log.Warn("Jaeger configuration will be ignored")
gc.Tracing.Jaeger = nil
}
if gc.Tracing.DataDog != nil {
log.Warn("DataDog configuration will be ignored")
gc.Tracing.DataDog = nil
}
case datadog.Name:
if gc.Tracing.DataDog == nil {
gc.Tracing.DataDog = &datadog.Config{
LocalAgentHostPort: "localhost:8126",
GlobalTag: "",
Debug: false,
}
}
if gc.Tracing.Zipkin != nil {
log.Warn("Zipkin configuration will be ignored")
gc.Tracing.Zipkin = nil
}
if gc.Tracing.Jaeger != nil {
log.Warn("Jaeger configuration will be ignored")
gc.Tracing.Jaeger = nil
}
default:
log.Warnf("Unknown tracer %q", gc.Tracing.Backend)
return
}
}
}
func (gc *GlobalConfiguration) initACMEProvider() {
if gc.ACME != nil {
gc.ACME.CAServer = getSafeACMECAServer(gc.ACME.CAServer)
if gc.ACME.DNSChallenge != nil && gc.ACME.HTTPChallenge != nil {
log.Warn("Unable to use DNS challenge and HTTP challenge at the same time. Fallback to DNS challenge.")
gc.ACME.HTTPChallenge = nil
}
if gc.ACME.DNSChallenge != nil && gc.ACME.TLSChallenge != nil {
log.Warn("Unable to use DNS challenge and TLS challenge at the same time. Fallback to DNS challenge.")
gc.ACME.TLSChallenge = nil
}
if gc.ACME.HTTPChallenge != nil && gc.ACME.TLSChallenge != nil {
log.Warn("Unable to use HTTP challenge and TLS challenge at the same time. Fallback to TLS challenge.")
gc.ACME.HTTPChallenge = nil
}
if gc.ACME.OnDemand {
log.Warn("ACME.OnDemand is deprecated")
}
}
}
// InitACMEProvider create an acme provider from the ACME part of globalConfiguration
func (gc *GlobalConfiguration) InitACMEProvider() (*acmeprovider.Provider, error) {
if gc.ACME != nil {
if len(gc.ACME.Storage) == 0 {
// Delete the ACME configuration to avoid starting ACME in cluster mode
gc.ACME = nil
return nil, errors.New("unable to initialize ACME provider with no storage location for the certificates")
}
// TODO: Remove when Provider ACME will replace totally ACME
// If provider file, use Provider ACME instead of ACME
if gc.Cluster == nil {
provider := &acmeprovider.Provider{}
provider.Configuration = convertACMEChallenge(gc.ACME)
store := acmeprovider.NewLocalStore(provider.Storage)
provider.Store = store
acme.ConvertToNewFormat(provider.Storage)
gc.ACME = nil
return provider, nil
}
}
return nil, nil
}
func getSafeACMECAServer(caServerSrc string) string {
if len(caServerSrc) == 0 {
return DefaultAcmeCAServer
}
if strings.HasPrefix(caServerSrc, "https://acme-v01.api.letsencrypt.org") {
caServer := strings.Replace(caServerSrc, "v01", "v02", 1)
log.Warnf("The CA server %[1]q refers to a v01 endpoint of the ACME API, please change to %[2]q. Fallback to %[2]q.", caServerSrc, caServer)
return caServer
}
if strings.HasPrefix(caServerSrc, "https://acme-staging.api.letsencrypt.org") {
caServer := strings.Replace(caServerSrc, "https://acme-staging.api.letsencrypt.org", "https://acme-staging-v02.api.letsencrypt.org", 1)
log.Warnf("The CA server %[1]q refers to a v01 endpoint of the ACME API, please change to %[2]q. Fallback to %[2]q.", caServerSrc, caServer)
return caServer
}
return caServerSrc
}
// ValidateConfiguration validate that configuration is coherent
func (gc *GlobalConfiguration) ValidateConfiguration() {
if gc.ACME != nil {
if _, ok := gc.EntryPoints[gc.ACME.EntryPoint]; !ok {
log.Fatalf("Unknown entrypoint %q for ACME configuration", gc.ACME.EntryPoint)
} else {
if gc.EntryPoints[gc.ACME.EntryPoint].TLS == nil {
log.Fatalf("Entrypoint %q has no TLS configuration for ACME configuration", gc.ACME.EntryPoint)
}
}
}
}
// DefaultEntryPoints holds default entry points
type DefaultEntryPoints []string
// String is the method to format the flag's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics.
func (dep *DefaultEntryPoints) String() string {
return strings.Join(*dep, ",")
}
// Set is the method to set the flag value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the flag.
// It's a comma-separated list, so we split it.
func (dep *DefaultEntryPoints) Set(value string) error {
entrypoints := strings.Split(value, ",")
if len(entrypoints) == 0 {
return fmt.Errorf("bad DefaultEntryPoints format: %s", value)
}
for _, entrypoint := range entrypoints {
*dep = append(*dep, entrypoint)
}
return nil
}
// Get return the EntryPoints map
func (dep *DefaultEntryPoints) Get() interface{} {
return *dep
}
// SetValue sets the EntryPoints map with val
func (dep *DefaultEntryPoints) SetValue(val interface{}) {
*dep = val.(DefaultEntryPoints)
}
// Type is type of the struct
func (dep *DefaultEntryPoints) Type() string {
return "defaultentrypoints"
}
// Retry contains request retry config
type Retry struct {
Attempts int `description:"Number of attempts" export:"true"`
}
// HealthCheckConfig contains health check configuration parameters.
type HealthCheckConfig struct {
Interval parse.Duration `description:"Default periodicity of enabled health checks" export:"true"`
Timeout parse.Duration `description:"Default request timeout of enabled health checks" export:"true"`
}
// RespondingTimeouts contains timeout configurations for incoming requests to the Traefik instance.
type RespondingTimeouts struct {
ReadTimeout parse.Duration `description:"ReadTimeout is the maximum duration for reading the entire request, including the body. If zero, no timeout is set" export:"true"`
WriteTimeout parse.Duration `description:"WriteTimeout is the maximum duration before timing out writes of the response. If zero, no timeout is set" export:"true"`
IdleTimeout parse.Duration `description:"IdleTimeout is the maximum amount duration an idle (keep-alive) connection will remain idle before closing itself. Defaults to 180 seconds. If zero, no timeout is set" export:"true"`
}
// ForwardingTimeouts contains timeout configurations for forwarding requests to the backend servers.
type ForwardingTimeouts struct {
DialTimeout parse.Duration `description:"The amount of time to wait until a connection to a backend server can be established. Defaults to 30 seconds. If zero, no timeout exists" export:"true"`
ResponseHeaderTimeout parse.Duration `description:"The amount of time to wait for a server's response headers after fully writing the request (including its body, if any). If zero, no timeout exists" export:"true"`
}
// LifeCycle contains configurations relevant to the lifecycle (such as the
// shutdown phase) of Traefik.
type LifeCycle struct {
RequestAcceptGraceTimeout parse.Duration `description:"Duration to keep accepting requests before Traefik initiates the graceful shutdown procedure"`
GraceTimeOut parse.Duration `description:"Duration to give active requests a chance to finish before Traefik stops"`
}
// HostResolverConfig contain configuration for CNAME Flattening
type HostResolverConfig struct {
CnameFlattening bool `description:"A flag to enable/disable CNAME flattening" export:"true"`
ResolvConfig string `description:"resolv.conf used for DNS resolving" export:"true"`
ResolvDepth int `description:"The maximal depth of DNS recursive resolving" export:"true"`
}
// Deprecated
func convertACMEChallenge(oldACMEChallenge *acme.ACME) *acmeprovider.Configuration {
conf := &acmeprovider.Configuration{
KeyType: oldACMEChallenge.KeyType,
OnHostRule: oldACMEChallenge.OnHostRule,
OnDemand: oldACMEChallenge.OnDemand,
Email: oldACMEChallenge.Email,
Storage: oldACMEChallenge.Storage,
ACMELogging: oldACMEChallenge.ACMELogging,
CAServer: oldACMEChallenge.CAServer,
EntryPoint: oldACMEChallenge.EntryPoint,
}
for _, domain := range oldACMEChallenge.Domains {
if domain.Main != lego.UnFqdn(domain.Main) {
log.Warnf("FQDN detected, please remove the trailing dot: %s", domain.Main)
}
for _, san := range domain.SANs {
if san != lego.UnFqdn(san) {
log.Warnf("FQDN detected, please remove the trailing dot: %s", san)
}
}
conf.Domains = append(conf.Domains, newtypes.Domain(domain))
}
if oldACMEChallenge.HTTPChallenge != nil {
conf.HTTPChallenge = &acmeprovider.HTTPChallenge{
EntryPoint: oldACMEChallenge.HTTPChallenge.EntryPoint,
}
}
if oldACMEChallenge.DNSChallenge != nil {
conf.DNSChallenge = &acmeprovider.DNSChallenge{
Provider: oldACMEChallenge.DNSChallenge.Provider,
DelayBeforeCheck: oldACMEChallenge.DNSChallenge.DelayBeforeCheck,
}
}
if oldACMEChallenge.TLSChallenge != nil {
conf.TLSChallenge = &acmeprovider.TLSChallenge{}
}
return conf
}

View file

@ -0,0 +1,224 @@
package configuration
import (
"testing"
"github.com/containous/traefik/acme"
"github.com/containous/traefik/old/middlewares/tracing"
"github.com/containous/traefik/old/middlewares/tracing/jaeger"
"github.com/containous/traefik/old/middlewares/tracing/zipkin"
"github.com/containous/traefik/old/provider"
acmeprovider "github.com/containous/traefik/old/provider/acme"
"github.com/containous/traefik/old/provider/file"
"github.com/stretchr/testify/assert"
)
const defaultConfigFile = "traefik.toml"
func TestSetEffectiveConfigurationFileProviderFilename(t *testing.T) {
testCases := []struct {
desc string
fileProvider *file.Provider
wantFileProviderFilename string
wantFileProviderTraefikFile string
}{
{
desc: "no filename for file provider given",
fileProvider: &file.Provider{},
wantFileProviderFilename: "",
wantFileProviderTraefikFile: defaultConfigFile,
},
{
desc: "filename for file provider given",
fileProvider: &file.Provider{BaseProvider: provider.BaseProvider{Filename: "other.toml"}},
wantFileProviderFilename: "other.toml",
wantFileProviderTraefikFile: defaultConfigFile,
},
{
desc: "directory for file provider given",
fileProvider: &file.Provider{Directory: "/"},
wantFileProviderFilename: "",
wantFileProviderTraefikFile: defaultConfigFile,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
gc := &GlobalConfiguration{
File: test.fileProvider,
}
gc.SetEffectiveConfiguration(defaultConfigFile)
assert.Equal(t, test.wantFileProviderFilename, gc.File.Filename)
assert.Equal(t, test.wantFileProviderTraefikFile, gc.File.TraefikFile)
})
}
}
func TestSetEffectiveConfigurationTracing(t *testing.T) {
testCases := []struct {
desc string
tracing *tracing.Tracing
expected *tracing.Tracing
}{
{
desc: "no tracing configuration",
tracing: &tracing.Tracing{},
expected: &tracing.Tracing{},
},
{
desc: "tracing bad backend name",
tracing: &tracing.Tracing{
Backend: "powpow",
},
expected: &tracing.Tracing{
Backend: "powpow",
},
},
{
desc: "tracing jaeger backend name",
tracing: &tracing.Tracing{
Backend: "jaeger",
Zipkin: &zipkin.Config{
HTTPEndpoint: "http://localhost:9411/api/v1/spans",
SameSpan: false,
ID128Bit: true,
Debug: false,
},
},
expected: &tracing.Tracing{
Backend: "jaeger",
Jaeger: &jaeger.Config{
SamplingServerURL: "http://localhost:5778/sampling",
SamplingType: "const",
SamplingParam: 1.0,
LocalAgentHostPort: "127.0.0.1:6831",
Propagation: "jaeger",
Gen128Bit: false,
},
Zipkin: nil,
},
},
{
desc: "tracing zipkin backend name",
tracing: &tracing.Tracing{
Backend: "zipkin",
Jaeger: &jaeger.Config{
SamplingServerURL: "http://localhost:5778/sampling",
SamplingType: "const",
SamplingParam: 1.0,
LocalAgentHostPort: "127.0.0.1:6831",
},
},
expected: &tracing.Tracing{
Backend: "zipkin",
Jaeger: nil,
Zipkin: &zipkin.Config{
HTTPEndpoint: "http://localhost:9411/api/v1/spans",
SameSpan: false,
ID128Bit: true,
Debug: false,
SampleRate: 1.0,
},
},
},
{
desc: "tracing zipkin backend name value override",
tracing: &tracing.Tracing{
Backend: "zipkin",
Jaeger: &jaeger.Config{
SamplingServerURL: "http://localhost:5778/sampling",
SamplingType: "const",
SamplingParam: 1.0,
LocalAgentHostPort: "127.0.0.1:6831",
},
Zipkin: &zipkin.Config{
HTTPEndpoint: "http://powpow:9411/api/v1/spans",
SameSpan: true,
ID128Bit: true,
Debug: true,
SampleRate: 0.02,
},
},
expected: &tracing.Tracing{
Backend: "zipkin",
Jaeger: nil,
Zipkin: &zipkin.Config{
HTTPEndpoint: "http://powpow:9411/api/v1/spans",
SameSpan: true,
ID128Bit: true,
Debug: true,
SampleRate: 0.02,
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
gc := &GlobalConfiguration{
Tracing: test.tracing,
}
gc.SetEffectiveConfiguration(defaultConfigFile)
assert.Equal(t, test.expected, gc.Tracing)
})
}
}
func TestInitACMEProvider(t *testing.T) {
testCases := []struct {
desc string
acmeConfiguration *acme.ACME
expectedConfiguration *acmeprovider.Provider
noError bool
}{
{
desc: "No ACME configuration",
acmeConfiguration: nil,
expectedConfiguration: nil,
noError: true,
},
{
desc: "ACME configuration with storage",
acmeConfiguration: &acme.ACME{Storage: "foo/acme.json"},
expectedConfiguration: &acmeprovider.Provider{Configuration: &acmeprovider.Configuration{Storage: "foo/acme.json"}},
noError: true,
},
{
desc: "ACME configuration with no storage",
acmeConfiguration: &acme.ACME{},
expectedConfiguration: nil,
noError: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
gc := &GlobalConfiguration{
ACME: test.acmeConfiguration,
}
configuration, err := gc.InitACMEProvider()
assert.True(t, (err == nil) == test.noError)
if test.expectedConfiguration == nil {
assert.Nil(t, configuration)
} else {
assert.Equal(t, test.expectedConfiguration.Storage, configuration.Storage)
}
})
}
}

View file

@ -0,0 +1,322 @@
package configuration
import (
"fmt"
"strconv"
"strings"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/tls"
)
// EntryPoint holds an entry point configuration of the reverse proxy (ip, port, TLS...)
type EntryPoint struct {
Address string
TLS *tls.TLS `export:"true"`
Redirect *types.Redirect `export:"true"`
Auth *types.Auth `export:"true"`
WhiteList *types.WhiteList `export:"true"`
Compress *Compress `export:"true"`
ProxyProtocol *ProxyProtocol `export:"true"`
ForwardedHeaders *ForwardedHeaders `export:"true"`
ClientIPStrategy *types.IPStrategy `export:"true"`
}
// Compress contains compress configuration
type Compress struct{}
// ProxyProtocol contains Proxy-Protocol configuration
type ProxyProtocol struct {
Insecure bool `export:"true"`
TrustedIPs []string
}
// ForwardedHeaders Trust client forwarding headers
type ForwardedHeaders struct {
Insecure bool `export:"true"`
TrustedIPs []string
}
// EntryPoints holds entry points configuration of the reverse proxy (ip, port, TLS...)
type EntryPoints map[string]*EntryPoint
// String is the method to format the flag's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics.
func (ep EntryPoints) String() string {
return fmt.Sprintf("%+v", map[string]*EntryPoint(ep))
}
// Get return the EntryPoints map
func (ep *EntryPoints) Get() interface{} {
return *ep
}
// SetValue sets the EntryPoints map with val
func (ep *EntryPoints) SetValue(val interface{}) {
*ep = val.(EntryPoints)
}
// Type is type of the struct
func (ep *EntryPoints) Type() string {
return "entrypoints"
}
// Set is the method to set the flag value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the flag.
// It's a comma-separated list, so we split it.
func (ep *EntryPoints) Set(value string) error {
result := parseEntryPointsConfiguration(value)
var compress *Compress
if len(result["compress"]) > 0 {
compress = &Compress{}
}
configTLS, err := makeEntryPointTLS(result)
if err != nil {
return err
}
(*ep)[result["name"]] = &EntryPoint{
Address: result["address"],
TLS: configTLS,
Auth: makeEntryPointAuth(result),
Redirect: makeEntryPointRedirect(result),
Compress: compress,
WhiteList: makeWhiteList(result),
ProxyProtocol: makeEntryPointProxyProtocol(result),
ForwardedHeaders: makeEntryPointForwardedHeaders(result),
ClientIPStrategy: makeIPStrategy("clientipstrategy", result),
}
return nil
}
func makeWhiteList(result map[string]string) *types.WhiteList {
if rawRange, ok := result["whitelist_sourcerange"]; ok {
return &types.WhiteList{
SourceRange: strings.Split(rawRange, ","),
IPStrategy: makeIPStrategy("whitelist_ipstrategy", result),
}
}
return nil
}
func makeIPStrategy(prefix string, result map[string]string) *types.IPStrategy {
depth := toInt(result, prefix+"_depth")
excludedIPs := result[prefix+"_excludedips"]
if depth == 0 && len(excludedIPs) == 0 {
return nil
}
return &types.IPStrategy{
Depth: depth,
ExcludedIPs: strings.Split(excludedIPs, ","),
}
}
func makeEntryPointAuth(result map[string]string) *types.Auth {
var basic *types.Basic
if v, ok := result["auth_basic_users"]; ok {
basic = &types.Basic{
Realm: result["auth_basic_realm"],
Users: strings.Split(v, ","),
RemoveHeader: toBool(result, "auth_basic_removeheader"),
}
}
var digest *types.Digest
if v, ok := result["auth_digest_users"]; ok {
digest = &types.Digest{
Users: strings.Split(v, ","),
RemoveHeader: toBool(result, "auth_digest_removeheader"),
}
}
var forward *types.Forward
if address, ok := result["auth_forward_address"]; ok {
var clientTLS *types.ClientTLS
cert := result["auth_forward_tls_cert"]
key := result["auth_forward_tls_key"]
insecureSkipVerify := toBool(result, "auth_forward_tls_insecureskipverify")
if len(cert) > 0 && len(key) > 0 || insecureSkipVerify {
clientTLS = &types.ClientTLS{
CA: result["auth_forward_tls_ca"],
CAOptional: toBool(result, "auth_forward_tls_caoptional"),
Cert: cert,
Key: key,
InsecureSkipVerify: insecureSkipVerify,
}
}
var authResponseHeaders []string
if v, ok := result["auth_forward_authresponseheaders"]; ok {
authResponseHeaders = strings.Split(v, ",")
}
forward = &types.Forward{
Address: address,
TLS: clientTLS,
TrustForwardHeader: toBool(result, "auth_forward_trustforwardheader"),
AuthResponseHeaders: authResponseHeaders,
}
}
var auth *types.Auth
if basic != nil || digest != nil || forward != nil {
auth = &types.Auth{
Basic: basic,
Digest: digest,
Forward: forward,
HeaderField: result["auth_headerfield"],
}
}
return auth
}
func makeEntryPointProxyProtocol(result map[string]string) *ProxyProtocol {
var proxyProtocol *ProxyProtocol
ppTrustedIPs := result["proxyprotocol_trustedips"]
if len(result["proxyprotocol_insecure"]) > 0 || len(ppTrustedIPs) > 0 {
proxyProtocol = &ProxyProtocol{
Insecure: toBool(result, "proxyprotocol_insecure"),
}
if len(ppTrustedIPs) > 0 {
proxyProtocol.TrustedIPs = strings.Split(ppTrustedIPs, ",")
}
}
if proxyProtocol != nil && proxyProtocol.Insecure {
log.Warn("ProxyProtocol.insecure:true is dangerous. Please use 'ProxyProtocol.TrustedIPs:IPs' and remove 'ProxyProtocol.insecure:true'")
}
return proxyProtocol
}
func makeEntryPointForwardedHeaders(result map[string]string) *ForwardedHeaders {
forwardedHeaders := &ForwardedHeaders{}
if _, ok := result["forwardedheaders_insecure"]; ok {
forwardedHeaders.Insecure = toBool(result, "forwardedheaders_insecure")
}
fhTrustedIPs := result["forwardedheaders_trustedips"]
if len(fhTrustedIPs) > 0 {
// TODO must be removed in the next breaking version.
forwardedHeaders.Insecure = toBool(result, "forwardedheaders_insecure")
forwardedHeaders.TrustedIPs = strings.Split(fhTrustedIPs, ",")
}
return forwardedHeaders
}
func makeEntryPointRedirect(result map[string]string) *types.Redirect {
var redirect *types.Redirect
if len(result["redirect_entrypoint"]) > 0 || len(result["redirect_regex"]) > 0 || len(result["redirect_replacement"]) > 0 {
redirect = &types.Redirect{
EntryPoint: result["redirect_entrypoint"],
Regex: result["redirect_regex"],
Replacement: result["redirect_replacement"],
Permanent: toBool(result, "redirect_permanent"),
}
}
return redirect
}
func makeEntryPointTLS(result map[string]string) (*tls.TLS, error) {
var configTLS *tls.TLS
if len(result["tls"]) > 0 {
certs := tls.Certificates{}
if err := certs.Set(result["tls"]); err != nil {
return nil, err
}
configTLS = &tls.TLS{
Certificates: certs,
}
} else if len(result["tls_acme"]) > 0 {
configTLS = &tls.TLS{
Certificates: tls.Certificates{},
}
}
if configTLS != nil {
if len(result["ca"]) > 0 {
files := tls.FilesOrContents{}
files.Set(result["ca"])
optional := toBool(result, "ca_optional")
configTLS.ClientCA = tls.ClientCA{
Files: files,
Optional: optional,
}
}
if len(result["tls_minversion"]) > 0 {
configTLS.MinVersion = result["tls_minversion"]
}
if len(result["tls_ciphersuites"]) > 0 {
configTLS.CipherSuites = strings.Split(result["tls_ciphersuites"], ",")
}
if len(result["tls_snistrict"]) > 0 {
configTLS.SniStrict = toBool(result, "tls_snistrict")
}
if len(result["tls_defaultcertificate_cert"]) > 0 && len(result["tls_defaultcertificate_key"]) > 0 {
configTLS.DefaultCertificate = &tls.Certificate{
CertFile: tls.FileOrContent(result["tls_defaultcertificate_cert"]),
KeyFile: tls.FileOrContent(result["tls_defaultcertificate_key"]),
}
}
}
return configTLS, nil
}
func parseEntryPointsConfiguration(raw string) map[string]string {
sections := strings.Fields(raw)
config := make(map[string]string)
for _, part := range sections {
field := strings.SplitN(part, ":", 2)
name := strings.ToLower(strings.Replace(field[0], ".", "_", -1))
if len(field) > 1 {
config[name] = field[1]
} else {
if strings.EqualFold(name, "TLS") {
config["tls_acme"] = "TLS"
} else {
config[name] = ""
}
}
}
return config
}
func toBool(conf map[string]string, key string) bool {
if val, ok := conf[key]; ok {
return strings.EqualFold(val, "true") ||
strings.EqualFold(val, "enable") ||
strings.EqualFold(val, "on")
}
return false
}
func toInt(conf map[string]string, key string) int {
if val, ok := conf[key]; ok {
intVal, err := strconv.Atoi(val)
if err != nil {
return 0
}
return intVal
}
return 0
}

View file

@ -0,0 +1,515 @@
package configuration
import (
"testing"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/tls"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_parseEntryPointsConfiguration(t *testing.T) {
testCases := []struct {
name string
value string
expectedResult map[string]string
}{
{
name: "all parameters",
value: "Name:foo " +
"Address::8000 " +
"TLS:goo,gii " +
"TLS " +
"TLS.MinVersion:VersionTLS11 " +
"TLS.CipherSuites:TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA " +
"CA:car " +
"CA.Optional:true " +
"Redirect.EntryPoint:https " +
"Redirect.Regex:http://localhost/(.*) " +
"Redirect.Replacement:http://mydomain/$1 " +
"Redirect.Permanent:true " +
"Compress:true " +
"ProxyProtocol.TrustedIPs:192.168.0.1 " +
"ForwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24 " +
"Auth.Basic.Realm:myRealm " +
"Auth.Basic.Users:test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0 " +
"Auth.Basic.RemoveHeader:true " +
"Auth.Digest.Users:test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e " +
"Auth.Digest.RemoveHeader:true " +
"Auth.HeaderField:X-WebAuth-User " +
"Auth.Forward.Address:https://authserver.com/auth " +
"Auth.Forward.AuthResponseHeaders:X-Auth,X-Test,X-Secret " +
"Auth.Forward.TrustForwardHeader:true " +
"Auth.Forward.TLS.CA:path/to/local.crt " +
"Auth.Forward.TLS.CAOptional:true " +
"Auth.Forward.TLS.Cert:path/to/foo.cert " +
"Auth.Forward.TLS.Key:path/to/foo.key " +
"Auth.Forward.TLS.InsecureSkipVerify:true " +
"WhiteList.SourceRange:10.42.0.0/16,152.89.1.33/32,afed:be44::/16 " +
"WhiteList.IPStrategy.depth:3 " +
"WhiteList.IPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 " +
"ClientIPStrategy.depth:3 " +
"ClientIPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 ",
expectedResult: map[string]string{
"address": ":8000",
"auth_basic_realm": "myRealm",
"auth_basic_users": "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0",
"auth_basic_removeheader": "true",
"auth_digest_users": "test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e",
"auth_digest_removeheader": "true",
"auth_forward_address": "https://authserver.com/auth",
"auth_forward_authresponseheaders": "X-Auth,X-Test,X-Secret",
"auth_forward_tls_ca": "path/to/local.crt",
"auth_forward_tls_caoptional": "true",
"auth_forward_tls_cert": "path/to/foo.cert",
"auth_forward_tls_insecureskipverify": "true",
"auth_forward_tls_key": "path/to/foo.key",
"auth_forward_trustforwardheader": "true",
"auth_headerfield": "X-WebAuth-User",
"ca": "car",
"ca_optional": "true",
"compress": "true",
"forwardedheaders_trustedips": "10.0.0.3/24,20.0.0.3/24",
"name": "foo",
"proxyprotocol_trustedips": "192.168.0.1",
"redirect_entrypoint": "https",
"redirect_permanent": "true",
"redirect_regex": "http://localhost/(.*)",
"redirect_replacement": "http://mydomain/$1",
"tls": "goo,gii",
"tls_acme": "TLS",
"tls_ciphersuites": "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
"tls_minversion": "VersionTLS11",
"whitelist_sourcerange": "10.42.0.0/16,152.89.1.33/32,afed:be44::/16",
"whitelist_ipstrategy_depth": "3",
"whitelist_ipstrategy_excludedips": "10.0.0.3/24,20.0.0.3/24",
"clientipstrategy_depth": "3",
"clientipstrategy_excludedips": "10.0.0.3/24,20.0.0.3/24",
},
},
{
name: "compress on",
value: "name:foo Compress:on",
expectedResult: map[string]string{
"name": "foo",
"compress": "on",
},
},
{
name: "TLS",
value: "Name:foo TLS:goo TLS",
expectedResult: map[string]string{
"name": "foo",
"tls": "goo",
"tls_acme": "TLS",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
conf := parseEntryPointsConfiguration(test.value)
assert.Len(t, conf, len(test.expectedResult))
assert.Equal(t, test.expectedResult, conf)
})
}
}
func Test_toBool(t *testing.T) {
testCases := []struct {
name string
value string
key string
expectedBool bool
}{
{
name: "on",
value: "on",
key: "foo",
expectedBool: true,
},
{
name: "true",
value: "true",
key: "foo",
expectedBool: true,
},
{
name: "enable",
value: "enable",
key: "foo",
expectedBool: true,
},
{
name: "arbitrary string",
value: "bar",
key: "foo",
expectedBool: false,
},
{
name: "no existing entry",
value: "bar",
key: "fii",
expectedBool: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
conf := map[string]string{
"foo": test.value,
}
result := toBool(conf, test.key)
assert.Equal(t, test.expectedBool, result)
})
}
}
func TestEntryPoints_Set(t *testing.T) {
testCases := []struct {
name string
expression string
expectedEntryPointName string
expectedEntryPoint *EntryPoint
}{
{
name: "all parameters camelcase",
expression: "Name:foo " +
"Address::8000 " +
"TLS:goo,gii;foo,fii " +
"TLS " +
"TLS.MinVersion:VersionTLS11 " +
"TLS.CipherSuites:TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA " +
"CA:car " +
"CA.Optional:true " +
"Redirect.EntryPoint:https " +
"Redirect.Regex:http://localhost/(.*) " +
"Redirect.Replacement:http://mydomain/$1 " +
"Redirect.Permanent:true " +
"Compress:true " +
"ProxyProtocol.TrustedIPs:192.168.0.1 " +
"ForwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24 " +
"Auth.Basic.Realm:myRealm " +
"Auth.Basic.Users:test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0 " +
"Auth.Basic.RemoveHeader:true " +
"Auth.Digest.Users:test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e " +
"Auth.Digest.RemoveHeader:true " +
"Auth.HeaderField:X-WebAuth-User " +
"Auth.Forward.Address:https://authserver.com/auth " +
"Auth.Forward.AuthResponseHeaders:X-Auth,X-Test,X-Secret " +
"Auth.Forward.TrustForwardHeader:true " +
"Auth.Forward.TLS.CA:path/to/local.crt " +
"Auth.Forward.TLS.CAOptional:true " +
"Auth.Forward.TLS.Cert:path/to/foo.cert " +
"Auth.Forward.TLS.Key:path/to/foo.key " +
"Auth.Forward.TLS.InsecureSkipVerify:true " +
"WhiteList.SourceRange:10.42.0.0/16,152.89.1.33/32,afed:be44::/16 " +
"WhiteList.IPStrategy.depth:3 " +
"WhiteList.IPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 " +
"ClientIPStrategy.depth:3 " +
"ClientIPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 ",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
Address: ":8000",
TLS: &tls.TLS{
MinVersion: "VersionTLS11",
CipherSuites: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384", "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"},
Certificates: tls.Certificates{
{
CertFile: tls.FileOrContent("goo"),
KeyFile: tls.FileOrContent("gii"),
},
{
CertFile: tls.FileOrContent("foo"),
KeyFile: tls.FileOrContent("fii"),
},
},
ClientCA: tls.ClientCA{
Files: tls.FilesOrContents{"car"},
Optional: true,
},
},
Redirect: &types.Redirect{
EntryPoint: "https",
Regex: "http://localhost/(.*)",
Replacement: "http://mydomain/$1",
Permanent: true,
},
Auth: &types.Auth{
Basic: &types.Basic{
Realm: "myRealm",
RemoveHeader: true,
Users: types.Users{
"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/",
"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0",
},
},
Digest: &types.Digest{
RemoveHeader: true,
Users: types.Users{
"test:traefik:a2688e031edb4be6a3797f3882655c05",
"test2:traefik:518845800f9e2bfb1f1f740ec24f074e",
},
},
Forward: &types.Forward{
Address: "https://authserver.com/auth",
AuthResponseHeaders: []string{"X-Auth", "X-Test", "X-Secret"},
TLS: &types.ClientTLS{
CA: "path/to/local.crt",
CAOptional: true,
Cert: "path/to/foo.cert",
Key: "path/to/foo.key",
InsecureSkipVerify: true,
},
TrustForwardHeader: true,
},
HeaderField: "X-WebAuth-User",
},
WhiteList: &types.WhiteList{
SourceRange: []string{
"10.42.0.0/16",
"152.89.1.33/32",
"afed:be44::/16",
},
IPStrategy: &types.IPStrategy{
Depth: 3,
ExcludedIPs: []string{
"10.0.0.3/24",
"20.0.0.3/24",
},
},
},
Compress: &Compress{},
ProxyProtocol: &ProxyProtocol{
Insecure: false,
TrustedIPs: []string{"192.168.0.1"},
},
ForwardedHeaders: &ForwardedHeaders{
Insecure: false,
TrustedIPs: []string{
"10.0.0.3/24",
"20.0.0.3/24",
},
},
ClientIPStrategy: &types.IPStrategy{
Depth: 3,
ExcludedIPs: []string{
"10.0.0.3/24",
"20.0.0.3/24",
},
},
},
},
{
name: "all parameters lowercase",
expression: "Name:foo " +
"address::8000 " +
"tls:goo,gii;foo,fii " +
"tls " +
"tls.minversion:VersionTLS11 " +
"tls.ciphersuites:TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA " +
"ca:car " +
"ca.Optional:true " +
"redirect.entryPoint:https " +
"redirect.regex:http://localhost/(.*) " +
"redirect.replacement:http://mydomain/$1 " +
"redirect.permanent:true " +
"compress:true " +
"whiteList.sourceRange:10.42.0.0/16,152.89.1.33/32,afed:be44::/16 " +
"proxyProtocol.TrustedIPs:192.168.0.1 " +
"forwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24 " +
"auth.basic.realm:myRealm " +
"auth.basic.users:test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0 " +
"auth.digest.users:test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e " +
"auth.headerField:X-WebAuth-User " +
"auth.forward.address:https://authserver.com/auth " +
"auth.forward.authResponseHeaders:X-Auth,X-Test,X-Secret " +
"auth.forward.trustForwardHeader:true " +
"auth.forward.tls.ca:path/to/local.crt " +
"auth.forward.tls.caOptional:true " +
"auth.forward.tls.cert:path/to/foo.cert " +
"auth.forward.tls.key:path/to/foo.key " +
"auth.forward.tls.insecureSkipVerify:true ",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
Address: ":8000",
TLS: &tls.TLS{
MinVersion: "VersionTLS11",
CipherSuites: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384", "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"},
Certificates: tls.Certificates{
{
CertFile: tls.FileOrContent("goo"),
KeyFile: tls.FileOrContent("gii"),
},
{
CertFile: tls.FileOrContent("foo"),
KeyFile: tls.FileOrContent("fii"),
},
},
ClientCA: tls.ClientCA{
Files: tls.FilesOrContents{"car"},
Optional: true,
},
},
Redirect: &types.Redirect{
EntryPoint: "https",
Regex: "http://localhost/(.*)",
Replacement: "http://mydomain/$1",
Permanent: true,
},
Auth: &types.Auth{
Basic: &types.Basic{
Realm: "myRealm",
Users: types.Users{
"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/",
"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0",
},
},
Digest: &types.Digest{
Users: types.Users{
"test:traefik:a2688e031edb4be6a3797f3882655c05",
"test2:traefik:518845800f9e2bfb1f1f740ec24f074e",
},
},
Forward: &types.Forward{
Address: "https://authserver.com/auth",
AuthResponseHeaders: []string{"X-Auth", "X-Test", "X-Secret"},
TLS: &types.ClientTLS{
CA: "path/to/local.crt",
CAOptional: true,
Cert: "path/to/foo.cert",
Key: "path/to/foo.key",
InsecureSkipVerify: true,
},
TrustForwardHeader: true,
},
HeaderField: "X-WebAuth-User",
},
WhiteList: &types.WhiteList{
SourceRange: []string{
"10.42.0.0/16",
"152.89.1.33/32",
"afed:be44::/16",
},
},
Compress: &Compress{},
ProxyProtocol: &ProxyProtocol{
Insecure: false,
TrustedIPs: []string{"192.168.0.1"},
},
ForwardedHeaders: &ForwardedHeaders{
Insecure: false,
TrustedIPs: []string{
"10.0.0.3/24",
"20.0.0.3/24",
},
},
},
},
{
name: "default",
expression: "Name:foo",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ForwardedHeaders: &ForwardedHeaders{Insecure: false},
},
},
{
name: "ForwardedHeaders insecure true",
expression: "Name:foo ForwardedHeaders.insecure:true",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ForwardedHeaders: &ForwardedHeaders{Insecure: true},
},
},
{
name: "ForwardedHeaders insecure false",
expression: "Name:foo ForwardedHeaders.insecure:false",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ForwardedHeaders: &ForwardedHeaders{Insecure: false},
},
},
{
name: "ForwardedHeaders TrustedIPs",
expression: "Name:foo ForwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ForwardedHeaders: &ForwardedHeaders{
TrustedIPs: []string{"10.0.0.3/24", "20.0.0.3/24"},
},
},
},
{
name: "ProxyProtocol insecure true",
expression: "Name:foo ProxyProtocol.insecure:true",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ForwardedHeaders: &ForwardedHeaders{},
ProxyProtocol: &ProxyProtocol{Insecure: true},
},
},
{
name: "ProxyProtocol insecure false",
expression: "Name:foo ProxyProtocol.insecure:false",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ForwardedHeaders: &ForwardedHeaders{},
ProxyProtocol: &ProxyProtocol{},
},
},
{
name: "ProxyProtocol TrustedIPs",
expression: "Name:foo ProxyProtocol.TrustedIPs:10.0.0.3/24,20.0.0.3/24",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
ForwardedHeaders: &ForwardedHeaders{},
ProxyProtocol: &ProxyProtocol{
TrustedIPs: []string{"10.0.0.3/24", "20.0.0.3/24"},
},
},
},
{
name: "compress on",
expression: "Name:foo Compress:on",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
Compress: &Compress{},
ForwardedHeaders: &ForwardedHeaders{},
},
},
{
name: "compress true",
expression: "Name:foo Compress:true",
expectedEntryPointName: "foo",
expectedEntryPoint: &EntryPoint{
Compress: &Compress{},
ForwardedHeaders: &ForwardedHeaders{},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
eps := EntryPoints{}
err := eps.Set(test.expression)
require.NoError(t, err)
ep := eps[test.expectedEntryPointName]
assert.EqualValues(t, test.expectedEntryPoint, ep)
})
}
}

View file

@ -0,0 +1,111 @@
package configuration
import (
"encoding/json"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/provider"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/safe"
)
// ProviderAggregator aggregate providers
type ProviderAggregator struct {
providers []provider.Provider
constraints types.Constraints
}
// NewProviderAggregator return an aggregate of all the providers configured in GlobalConfiguration
func NewProviderAggregator(gc *GlobalConfiguration) ProviderAggregator {
provider := ProviderAggregator{
constraints: gc.Constraints,
}
if gc.Docker != nil {
provider.quietAddProvider(gc.Docker)
}
if gc.Marathon != nil {
provider.quietAddProvider(gc.Marathon)
}
if gc.File != nil {
provider.quietAddProvider(gc.File)
}
if gc.Rest != nil {
provider.quietAddProvider(gc.Rest)
}
if gc.Consul != nil {
provider.quietAddProvider(gc.Consul)
}
if gc.ConsulCatalog != nil {
provider.quietAddProvider(gc.ConsulCatalog)
}
if gc.Etcd != nil {
provider.quietAddProvider(gc.Etcd)
}
if gc.Zookeeper != nil {
provider.quietAddProvider(gc.Zookeeper)
}
if gc.Boltdb != nil {
provider.quietAddProvider(gc.Boltdb)
}
if gc.Kubernetes != nil {
provider.quietAddProvider(gc.Kubernetes)
}
if gc.Mesos != nil {
provider.quietAddProvider(gc.Mesos)
}
if gc.Eureka != nil {
provider.quietAddProvider(gc.Eureka)
}
if gc.ECS != nil {
provider.quietAddProvider(gc.ECS)
}
if gc.Rancher != nil {
provider.quietAddProvider(gc.Rancher)
}
if gc.DynamoDB != nil {
provider.quietAddProvider(gc.DynamoDB)
}
return provider
}
func (p *ProviderAggregator) quietAddProvider(provider provider.Provider) {
err := p.AddProvider(provider)
if err != nil {
log.Errorf("Error initializing provider %T: %v", provider, err)
}
}
// AddProvider add a provider in the providers map
func (p *ProviderAggregator) AddProvider(provider provider.Provider) error {
err := provider.Init(p.constraints)
if err != nil {
return err
}
p.providers = append(p.providers, provider)
return nil
}
// Init the provider
func (p ProviderAggregator) Init(_ types.Constraints) error {
return nil
}
// Provide call the provide method of every providers
func (p ProviderAggregator) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
for _, p := range p.providers {
jsonConf, err := json.Marshal(p)
if err != nil {
log.Debugf("Unable to marshal provider conf %T with error: %v", p, err)
}
log.Infof("Starting provider %T %s", p, jsonConf)
currentProvider := p
safe.Go(func() {
err := currentProvider.Provide(configurationChan, pool)
if err != nil {
log.Errorf("Error starting provider %T: %v", p, err)
}
})
}
return nil
}

View file

@ -0,0 +1,116 @@
package router
import (
"github.com/containous/mux"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/middlewares"
mauth "github.com/containous/traefik/old/middlewares/auth"
"github.com/containous/traefik/old/types"
"github.com/urfave/negroni"
)
// NewInternalRouterAggregator Create a new internalRouterAggregator
func NewInternalRouterAggregator(globalConfiguration configuration.GlobalConfiguration, entryPointName string) *InternalRouterAggregator {
var serverMiddlewares []negroni.Handler
if globalConfiguration.EntryPoints[entryPointName].WhiteList != nil {
ipStrategy := globalConfiguration.EntryPoints[entryPointName].ClientIPStrategy
if globalConfiguration.EntryPoints[entryPointName].WhiteList.IPStrategy != nil {
ipStrategy = globalConfiguration.EntryPoints[entryPointName].WhiteList.IPStrategy
}
strategy, err := ipStrategy.Get()
if err != nil {
log.Fatalf("Error creating whitelist middleware: %s", err)
}
ipWhitelistMiddleware, err := middlewares.NewIPWhiteLister(globalConfiguration.EntryPoints[entryPointName].WhiteList.SourceRange, strategy)
if err != nil {
log.Fatalf("Error creating whitelist middleware: %s", err)
}
if ipWhitelistMiddleware != nil {
serverMiddlewares = append(serverMiddlewares, ipWhitelistMiddleware)
}
}
if globalConfiguration.EntryPoints[entryPointName].Auth != nil {
authMiddleware, err := mauth.NewAuthenticator(globalConfiguration.EntryPoints[entryPointName].Auth, nil)
if err != nil {
log.Fatalf("Error creating authenticator middleware: %s", err)
}
serverMiddlewares = append(serverMiddlewares, authMiddleware)
}
router := InternalRouterAggregator{}
routerWithMiddleware := InternalRouterAggregator{}
if globalConfiguration.Metrics != nil && globalConfiguration.Metrics.Prometheus != nil && globalConfiguration.Metrics.Prometheus.EntryPoint == entryPointName {
// routerWithMiddleware.AddRouter(metrics.PrometheusHandler{})
}
if globalConfiguration.Rest != nil && globalConfiguration.Rest.EntryPoint == entryPointName {
routerWithMiddleware.AddRouter(globalConfiguration.Rest)
}
if globalConfiguration.API != nil && globalConfiguration.API.EntryPoint == entryPointName {
routerWithMiddleware.AddRouter(globalConfiguration.API)
}
if globalConfiguration.Ping != nil && globalConfiguration.Ping.EntryPoint == entryPointName {
router.AddRouter(globalConfiguration.Ping)
}
if globalConfiguration.ACME != nil && globalConfiguration.ACME.HTTPChallenge != nil && globalConfiguration.ACME.HTTPChallenge.EntryPoint == entryPointName {
router.AddRouter(globalConfiguration.ACME)
}
router.AddRouter(&WithMiddleware{router: &routerWithMiddleware, routerMiddlewares: serverMiddlewares})
return &router
}
// WithMiddleware router with internal middleware
type WithMiddleware struct {
router types.InternalRouter
routerMiddlewares []negroni.Handler
}
// AddRoutes Add routes to the router
func (wm *WithMiddleware) AddRoutes(systemRouter *mux.Router) {
realRouter := systemRouter.PathPrefix("/").Subrouter()
wm.router.AddRoutes(realRouter)
if len(wm.routerMiddlewares) > 0 {
if err := realRouter.Walk(wrapRoute(wm.routerMiddlewares)); err != nil {
log.Error(err)
}
}
}
// InternalRouterAggregator InternalRouter that aggregate other internalRouter
type InternalRouterAggregator struct {
internalRouters []types.InternalRouter
}
// AddRouter add a router in the aggregator
func (r *InternalRouterAggregator) AddRouter(router types.InternalRouter) {
r.internalRouters = append(r.internalRouters, router)
}
// AddRoutes Add routes to the router
func (r *InternalRouterAggregator) AddRoutes(systemRouter *mux.Router) {
for _, router := range r.internalRouters {
router.AddRoutes(systemRouter)
}
}
// wrapRoute with middlewares
func wrapRoute(middlewares []negroni.Handler) func(*mux.Route, *mux.Router, []*mux.Route) error {
return func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
middles := append(middlewares, negroni.Wrap(route.GetHandler()))
route.Handler(negroni.New(middles...))
return nil
}
}

View file

@ -0,0 +1,151 @@
package router
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/mux"
"github.com/containous/traefik/acme"
"github.com/containous/traefik/old/api"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/ping"
acmeprovider "github.com/containous/traefik/old/provider/acme"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/safe"
"github.com/stretchr/testify/assert"
"github.com/urfave/negroni"
)
func TestNewInternalRouterAggregatorWithAuth(t *testing.T) {
currentConfiguration := &safe.Safe{}
currentConfiguration.Set(types.Configurations{})
globalConfiguration := configuration.GlobalConfiguration{
API: &api.Handler{
EntryPoint: "traefik",
CurrentConfigurations: currentConfiguration,
},
Ping: &ping.Handler{
EntryPoint: "traefik",
},
ACME: &acme.ACME{
HTTPChallenge: &acmeprovider.HTTPChallenge{
EntryPoint: "traefik",
},
},
EntryPoints: configuration.EntryPoints{
"traefik": &configuration.EntryPoint{
Auth: &types.Auth{
Basic: &types.Basic{
Users: types.Users{"test:test"},
},
},
},
},
}
testCases := []struct {
desc string
testedURL string
expectedStatusCode int
}{
{
desc: "Wrong url",
testedURL: "/wrong",
expectedStatusCode: 502,
},
{
desc: "Ping without auth",
testedURL: "/ping",
expectedStatusCode: 200,
},
{
desc: "acme without auth",
testedURL: "/.well-known/acme-challenge/token",
expectedStatusCode: 404,
},
{
desc: "api with auth",
testedURL: "/api",
expectedStatusCode: 401,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
router := NewInternalRouterAggregator(globalConfiguration, "traefik")
internalMuxRouter := mux.NewRouter()
router.AddRoutes(internalMuxRouter)
internalMuxRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadGateway)
})
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, test.testedURL, nil)
internalMuxRouter.ServeHTTP(recorder, request)
assert.Equal(t, test.expectedStatusCode, recorder.Code)
})
}
}
type MockInternalRouterFunc func(systemRouter *mux.Router)
func (m MockInternalRouterFunc) AddRoutes(systemRouter *mux.Router) {
m(systemRouter)
}
func TestWithMiddleware(t *testing.T) {
router := WithMiddleware{
router: MockInternalRouterFunc(func(systemRouter *mux.Router) {
systemRouter.Handle("/test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if _, err := w.Write([]byte("router")); err != nil {
log.Error(err)
}
}))
}),
routerMiddlewares: []negroni.Handler{
negroni.HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if _, err := rw.Write([]byte("before middleware1|")); err != nil {
log.Error(err)
}
next.ServeHTTP(rw, r)
if _, err := rw.Write([]byte("|after middleware1")); err != nil {
log.Error(err)
}
}),
negroni.HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if _, err := rw.Write([]byte("before middleware2|")); err != nil {
log.Error(err)
}
next.ServeHTTP(rw, r)
if _, err := rw.Write([]byte("|after middleware2")); err != nil {
log.Error(err)
}
}),
},
}
internalMuxRouter := mux.NewRouter()
router.AddRoutes(internalMuxRouter)
recorder := httptest.NewRecorder()
request := httptest.NewRequest(http.MethodGet, "/test", nil)
internalMuxRouter.ServeHTTP(recorder, request)
obtained := recorder.Body.String()
assert.Equal(t, "before middleware1|before middleware2|router|after middleware2|after middleware1", obtained)
}

315
old/log/logger.go Normal file
View file

@ -0,0 +1,315 @@
package log
import (
"bufio"
"fmt"
"io"
"os"
"runtime"
"github.com/sirupsen/logrus"
)
// Logger allows overriding the logrus logger behavior
type Logger interface {
logrus.FieldLogger
WriterLevel(logrus.Level) *io.PipeWriter
}
var (
logger Logger
logFilePath string
logFile *os.File
)
func init() {
logger = logrus.StandardLogger().WithFields(logrus.Fields{})
logrus.SetOutput(os.Stdout)
}
// Context sets the Context of the logger
func Context(context interface{}) *logrus.Entry {
return logger.WithField("context", context)
}
// SetOutput sets the standard logger output.
func SetOutput(out io.Writer) {
logrus.SetOutput(out)
}
// SetFormatter sets the standard logger formatter.
func SetFormatter(formatter logrus.Formatter) {
logrus.SetFormatter(formatter)
}
// SetLevel sets the standard logger level.
func SetLevel(level logrus.Level) {
logrus.SetLevel(level)
}
// SetLogger sets the logger.
func SetLogger(l Logger) {
logger = l
}
// GetLevel returns the standard logger level.
func GetLevel() logrus.Level {
return logrus.GetLevel()
}
// AddHook adds a hook to the standard logger hooks.
func AddHook(hook logrus.Hook) {
logrus.AddHook(hook)
}
// WithError creates an entry from the standard logger and adds an error to it, using the value defined in ErrorKey as key.
func WithError(err error) *logrus.Entry {
return logger.WithError(err)
}
// WithField creates an entry from the standard logger and adds a field to
// it. If you want multiple fields, use `WithFields`.
//
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
// or Panic on the Entry it returns.
func WithField(key string, value interface{}) *logrus.Entry {
return logger.WithField(key, value)
}
// WithFields creates an entry from the standard logger and adds multiple
// fields to it. This is simply a helper for `WithField`, invoking it
// once for each field.
//
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
// or Panic on the Entry it returns.
func WithFields(fields logrus.Fields) *logrus.Entry {
return logger.WithFields(fields)
}
// Debug logs a message at level Debug on the standard logger.
func Debug(args ...interface{}) {
logger.Debug(args...)
}
// Print logs a message at level Info on the standard logger.
func Print(args ...interface{}) {
logger.Print(args...)
}
// Info logs a message at level Info on the standard logger.
func Info(args ...interface{}) {
logger.Info(args...)
}
// Warn logs a message at level Warn on the standard logger.
func Warn(args ...interface{}) {
logger.Warn(args...)
}
// Warning logs a message at level Warn on the standard logger.
func Warning(args ...interface{}) {
logger.Warning(args...)
}
// Error logs a message at level Error on the standard logger.
func Error(args ...interface{}) {
logger.Error(args...)
}
// Panic logs a message at level Panic on the standard logger.
func Panic(args ...interface{}) {
logger.Panic(args...)
}
// Fatal logs a message at level Fatal on the standard logger.
func Fatal(args ...interface{}) {
logger.Fatal(args...)
}
// Debugf logs a message at level Debug on the standard logger.
func Debugf(format string, args ...interface{}) {
logger.Debugf(format, args...)
}
// Printf logs a message at level Info on the standard logger.
func Printf(format string, args ...interface{}) {
logger.Printf(format, args...)
}
// Infof logs a message at level Info on the standard logger.
func Infof(format string, args ...interface{}) {
logger.Infof(format, args...)
}
// Warnf logs a message at level Warn on the standard logger.
func Warnf(format string, args ...interface{}) {
logger.Warnf(format, args...)
}
// Warningf logs a message at level Warn on the standard logger.
func Warningf(format string, args ...interface{}) {
logger.Warningf(format, args...)
}
// Errorf logs a message at level Error on the standard logger.
func Errorf(format string, args ...interface{}) {
logger.Errorf(format, args...)
}
// Panicf logs a message at level Panic on the standard logger.
func Panicf(format string, args ...interface{}) {
logger.Panicf(format, args...)
}
// Fatalf logs a message at level Fatal on the standard logger.
func Fatalf(format string, args ...interface{}) {
logger.Fatalf(format, args...)
}
// Debugln logs a message at level Debug on the standard logger.
func Debugln(args ...interface{}) {
logger.Debugln(args...)
}
// Println logs a message at level Info on the standard logger.
func Println(args ...interface{}) {
logger.Println(args...)
}
// Infoln logs a message at level Info on the standard logger.
func Infoln(args ...interface{}) {
logger.Infoln(args...)
}
// Warnln logs a message at level Warn on the standard logger.
func Warnln(args ...interface{}) {
logger.Warnln(args...)
}
// Warningln logs a message at level Warn on the standard logger.
func Warningln(args ...interface{}) {
logger.Warningln(args...)
}
// Errorln logs a message at level Error on the standard logger.
func Errorln(args ...interface{}) {
logger.Errorln(args...)
}
// Panicln logs a message at level Panic on the standard logger.
func Panicln(args ...interface{}) {
logger.Panicln(args...)
}
// Fatalln logs a message at level Fatal on the standard logger.
func Fatalln(args ...interface{}) {
logger.Fatalln(args...)
}
// OpenFile opens the log file using the specified path
func OpenFile(path string) error {
logFilePath = path
var err error
logFile, err = os.OpenFile(logFilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
if err == nil {
SetOutput(logFile)
}
return err
}
// CloseFile closes the log and sets the Output to stdout
func CloseFile() error {
logrus.SetOutput(os.Stdout)
if logFile != nil {
return logFile.Close()
}
return nil
}
// RotateFile closes and reopens the log file to allow for rotation
// by an external source. If the log isn't backed by a file then
// it does nothing.
func RotateFile() error {
if logFile == nil && logFilePath == "" {
Debug("Traefik log is not writing to a file, ignoring rotate request")
return nil
}
if logFile != nil {
defer func(f *os.File) {
f.Close()
}(logFile)
}
if err := OpenFile(logFilePath); err != nil {
return fmt.Errorf("error opening log file: %s", err)
}
return nil
}
// Writer logs writer (Level Info)
func Writer() *io.PipeWriter {
return WriterLevel(logrus.InfoLevel)
}
// WriterLevel logs writer for a specific level.
func WriterLevel(level logrus.Level) *io.PipeWriter {
return logger.WriterLevel(level)
}
// CustomWriterLevel logs writer for a specific level. (with a custom scanner buffer size.)
// adapted from github.com/Sirupsen/logrus/writer.go
func CustomWriterLevel(level logrus.Level, maxScanTokenSize int) *io.PipeWriter {
reader, writer := io.Pipe()
var printFunc func(args ...interface{})
switch level {
case logrus.DebugLevel:
printFunc = Debug
case logrus.InfoLevel:
printFunc = Info
case logrus.WarnLevel:
printFunc = Warn
case logrus.ErrorLevel:
printFunc = Error
case logrus.FatalLevel:
printFunc = Fatal
case logrus.PanicLevel:
printFunc = Panic
default:
printFunc = Print
}
go writerScanner(reader, maxScanTokenSize, printFunc)
runtime.SetFinalizer(writer, writerFinalizer)
return writer
}
// extract from github.com/Sirupsen/logrus/writer.go
// Hack the buffer size
func writerScanner(reader io.ReadCloser, scanTokenSize int, printFunc func(args ...interface{})) {
scanner := bufio.NewScanner(reader)
if scanTokenSize > bufio.MaxScanTokenSize {
buf := make([]byte, bufio.MaxScanTokenSize)
scanner.Buffer(buf, scanTokenSize)
}
for scanner.Scan() {
printFunc(scanner.Text())
}
if err := scanner.Err(); err != nil {
Errorf("Error while reading from Writer: %s", err)
}
reader.Close()
}
func writerFinalizer(writer *io.PipeWriter) {
writer.Close()
}

79
old/log/logger_test.go Normal file
View file

@ -0,0 +1,79 @@
package log
import (
"io/ioutil"
"os"
"strings"
"testing"
"time"
)
func TestLogRotation(t *testing.T) {
tempDir, err := ioutil.TempDir("", "traefik_")
if err != nil {
t.Fatalf("Error setting up temporary directory: %s", err)
}
fileName := tempDir + "traefik.log"
if err := OpenFile(fileName); err != nil {
t.Fatalf("Error opening temporary file %s: %s", fileName, err)
}
defer CloseFile()
rotatedFileName := fileName + ".rotated"
iterations := 20
halfDone := make(chan bool)
writeDone := make(chan bool)
go func() {
for i := 0; i < iterations; i++ {
Println("Test log line")
if i == iterations/2 {
halfDone <- true
}
}
writeDone <- true
}()
<-halfDone
err = os.Rename(fileName, rotatedFileName)
if err != nil {
t.Fatalf("Error renaming file: %s", err)
}
err = RotateFile()
if err != nil {
t.Fatalf("Error rotating file: %s", err)
}
select {
case <-writeDone:
gotLineCount := lineCount(t, fileName) + lineCount(t, rotatedFileName)
if iterations != gotLineCount {
t.Errorf("Wanted %d written log lines, got %d", iterations, gotLineCount)
}
case <-time.After(500 * time.Millisecond):
t.Fatalf("test timed out")
}
close(halfDone)
close(writeDone)
}
func lineCount(t *testing.T, fileName string) int {
t.Helper()
fileContents, err := ioutil.ReadFile(fileName)
if err != nil {
t.Fatalf("Error reading from file %s: %s", fileName, err)
}
count := 0
for _, line := range strings.Split(string(fileContents), "\n") {
if strings.TrimSpace(line) == "" {
continue
}
count++
}
return count
}

View file

@ -0,0 +1,18 @@
package accesslog
import "io"
type captureRequestReader struct {
source io.ReadCloser
count int64
}
func (r *captureRequestReader) Read(p []byte) (int, error) {
n, err := r.source.Read(p)
r.count += int64(n)
return n, err
}
func (r *captureRequestReader) Close() error {
return r.source.Close()
}

View file

@ -0,0 +1,68 @@
package accesslog
import (
"bufio"
"fmt"
"net"
"net/http"
"github.com/containous/traefik/old/middlewares"
)
var (
_ middlewares.Stateful = &captureResponseWriter{}
)
// captureResponseWriter is a wrapper of type http.ResponseWriter
// that tracks request status and size
type captureResponseWriter struct {
rw http.ResponseWriter
status int
size int64
}
func (crw *captureResponseWriter) Header() http.Header {
return crw.rw.Header()
}
func (crw *captureResponseWriter) Write(b []byte) (int, error) {
if crw.status == 0 {
crw.status = http.StatusOK
}
size, err := crw.rw.Write(b)
crw.size += int64(size)
return size, err
}
func (crw *captureResponseWriter) WriteHeader(s int) {
crw.rw.WriteHeader(s)
crw.status = s
}
func (crw *captureResponseWriter) Flush() {
if f, ok := crw.rw.(http.Flusher); ok {
f.Flush()
}
}
func (crw *captureResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := crw.rw.(http.Hijacker); ok {
return h.Hijack()
}
return nil, nil, fmt.Errorf("not a hijacker: %T", crw.rw)
}
func (crw *captureResponseWriter) CloseNotify() <-chan bool {
if c, ok := crw.rw.(http.CloseNotifier); ok {
return c.CloseNotify()
}
return nil
}
func (crw *captureResponseWriter) Status() int {
return crw.status
}
func (crw *captureResponseWriter) Size() int64 {
return crw.size
}

View file

@ -0,0 +1,120 @@
package accesslog
import (
"net/http"
)
const (
// StartUTC is the map key used for the time at which request processing started.
StartUTC = "StartUTC"
// StartLocal is the map key used for the local time at which request processing started.
StartLocal = "StartLocal"
// Duration is the map key used for the total time taken by processing the response, including the origin server's time but
// not the log writing time.
Duration = "Duration"
// FrontendName is the map key used for the name of the Traefik frontend.
FrontendName = "FrontendName"
// BackendName is the map key used for the name of the Traefik backend.
BackendName = "BackendName"
// BackendURL is the map key used for the URL of the Traefik backend.
BackendURL = "BackendURL"
// BackendAddr is the map key used for the IP:port of the Traefik backend (extracted from BackendURL)
BackendAddr = "BackendAddr"
// ClientAddr is the map key used for the remote address in its original form (usually IP:port).
ClientAddr = "ClientAddr"
// ClientHost is the map key used for the remote IP address from which the client request was received.
ClientHost = "ClientHost"
// ClientPort is the map key used for the remote TCP port from which the client request was received.
ClientPort = "ClientPort"
// ClientUsername is the map key used for the username provided in the URL, if present.
ClientUsername = "ClientUsername"
// RequestAddr is the map key used for the HTTP Host header (usually IP:port). This is treated as not a header by the Go API.
RequestAddr = "RequestAddr"
// RequestHost is the map key used for the HTTP Host server name (not including port).
RequestHost = "RequestHost"
// RequestPort is the map key used for the TCP port from the HTTP Host.
RequestPort = "RequestPort"
// RequestMethod is the map key used for the HTTP method.
RequestMethod = "RequestMethod"
// RequestPath is the map key used for the HTTP request URI, not including the scheme, host or port.
RequestPath = "RequestPath"
// RequestProtocol is the map key used for the version of HTTP requested.
RequestProtocol = "RequestProtocol"
// RequestContentSize is the map key used for the number of bytes in the request entity (a.k.a. body) sent by the client.
RequestContentSize = "RequestContentSize"
// RequestRefererHeader is the Referer header in the request
RequestRefererHeader = "request_Referer"
// RequestUserAgentHeader is the User-Agent header in the request
RequestUserAgentHeader = "request_User-Agent"
// OriginDuration is the map key used for the time taken by the origin server ('upstream') to return its response.
OriginDuration = "OriginDuration"
// OriginContentSize is the map key used for the content length specified by the origin server, or 0 if unspecified.
OriginContentSize = "OriginContentSize"
// OriginStatus is the map key used for the HTTP status code returned by the origin server.
// If the request was handled by this Traefik instance (e.g. with a redirect), then this value will be absent.
OriginStatus = "OriginStatus"
// DownstreamStatus is the map key used for the HTTP status code returned to the client.
DownstreamStatus = "DownstreamStatus"
// DownstreamContentSize is the map key used for the number of bytes in the response entity returned to the client.
// This is in addition to the "Content-Length" header, which may be present in the origin response.
DownstreamContentSize = "DownstreamContentSize"
// RequestCount is the map key used for the number of requests received since the Traefik instance started.
RequestCount = "RequestCount"
// GzipRatio is the map key used for the response body compression ratio achieved.
GzipRatio = "GzipRatio"
// Overhead is the map key used for the processing time overhead caused by Traefik.
Overhead = "Overhead"
// RetryAttempts is the map key used for the amount of attempts the request was retried.
RetryAttempts = "RetryAttempts"
)
// These are written out in the default case when no config is provided to specify keys of interest.
var defaultCoreKeys = [...]string{
StartUTC,
Duration,
FrontendName,
BackendName,
BackendURL,
ClientHost,
ClientPort,
ClientUsername,
RequestHost,
RequestPort,
RequestMethod,
RequestPath,
RequestProtocol,
RequestContentSize,
OriginDuration,
OriginContentSize,
OriginStatus,
DownstreamStatus,
DownstreamContentSize,
RequestCount,
}
// This contains the set of all keys, i.e. all the default keys plus all non-default keys.
var allCoreKeys = make(map[string]struct{})
func init() {
for _, k := range defaultCoreKeys {
allCoreKeys[k] = struct{}{}
}
allCoreKeys[BackendAddr] = struct{}{}
allCoreKeys[ClientAddr] = struct{}{}
allCoreKeys[RequestAddr] = struct{}{}
allCoreKeys[GzipRatio] = struct{}{}
allCoreKeys[StartLocal] = struct{}{}
allCoreKeys[Overhead] = struct{}{}
allCoreKeys[RetryAttempts] = struct{}{}
}
// CoreLogData holds the fields computed from the request/response.
type CoreLogData map[string]interface{}
// LogData is the data captured by the middleware so that it can be logged.
type LogData struct {
Core CoreLogData
Request http.Header
OriginResponse http.Header
DownstreamResponse http.Header
}

View file

@ -0,0 +1,334 @@
package accesslog
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"sync"
"sync/atomic"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/types"
"github.com/sirupsen/logrus"
)
type key string
const (
// DataTableKey is the key within the request context used to
// store the Log Data Table
DataTableKey key = "LogDataTable"
// CommonFormat is the common logging format (CLF)
CommonFormat string = "common"
// JSONFormat is the JSON logging format
JSONFormat string = "json"
)
type logHandlerParams struct {
logDataTable *LogData
crr *captureRequestReader
crw *captureResponseWriter
}
// LogHandler will write each request and its response to the access log.
type LogHandler struct {
config *types.AccessLog
logger *logrus.Logger
file *os.File
mu sync.Mutex
httpCodeRanges types.HTTPCodeRanges
logHandlerChan chan logHandlerParams
wg sync.WaitGroup
}
// NewLogHandler creates a new LogHandler
func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
file := os.Stdout
if len(config.FilePath) > 0 {
f, err := openAccessLogFile(config.FilePath)
if err != nil {
return nil, fmt.Errorf("error opening access log file: %s", err)
}
file = f
}
logHandlerChan := make(chan logHandlerParams, config.BufferingSize)
var formatter logrus.Formatter
switch config.Format {
case CommonFormat:
formatter = new(CommonLogFormatter)
case JSONFormat:
formatter = new(logrus.JSONFormatter)
default:
return nil, fmt.Errorf("unsupported access log format: %s", config.Format)
}
logger := &logrus.Logger{
Out: file,
Formatter: formatter,
Hooks: make(logrus.LevelHooks),
Level: logrus.InfoLevel,
}
logHandler := &LogHandler{
config: config,
logger: logger,
file: file,
logHandlerChan: logHandlerChan,
}
if config.Filters != nil {
if httpCodeRanges, err := types.NewHTTPCodeRanges(config.Filters.StatusCodes); err != nil {
log.Errorf("Failed to create new HTTP code ranges: %s", err)
} else {
logHandler.httpCodeRanges = httpCodeRanges
}
}
if config.BufferingSize > 0 {
logHandler.wg.Add(1)
go func() {
defer logHandler.wg.Done()
for handlerParams := range logHandler.logHandlerChan {
logHandler.logTheRoundTrip(handlerParams.logDataTable, handlerParams.crr, handlerParams.crw)
}
}()
}
return logHandler, nil
}
func openAccessLogFile(filePath string) (*os.File, error) {
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0755); err != nil {
return nil, fmt.Errorf("failed to create log path %s: %s", dir, err)
}
file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
if err != nil {
return nil, fmt.Errorf("error opening file %s: %s", filePath, err)
}
return file, nil
}
// GetLogDataTable gets the request context object that contains logging data.
// This creates data as the request passes through the middleware chain.
func GetLogDataTable(req *http.Request) *LogData {
if ld, ok := req.Context().Value(DataTableKey).(*LogData); ok {
return ld
}
log.Errorf("%s is nil", DataTableKey)
return &LogData{Core: make(CoreLogData)}
}
func (l *LogHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
now := time.Now().UTC()
core := CoreLogData{
StartUTC: now,
StartLocal: now.Local(),
}
logDataTable := &LogData{Core: core, Request: req.Header}
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
var crr *captureRequestReader
if req.Body != nil {
crr = &captureRequestReader{source: req.Body, count: 0}
reqWithDataTable.Body = crr
}
core[RequestCount] = nextRequestCount()
if req.Host != "" {
core[RequestAddr] = req.Host
core[RequestHost], core[RequestPort] = silentSplitHostPort(req.Host)
}
// copy the URL without the scheme, hostname etc
urlCopy := &url.URL{
Path: req.URL.Path,
RawPath: req.URL.RawPath,
RawQuery: req.URL.RawQuery,
ForceQuery: req.URL.ForceQuery,
Fragment: req.URL.Fragment,
}
urlCopyString := urlCopy.String()
core[RequestMethod] = req.Method
core[RequestPath] = urlCopyString
core[RequestProtocol] = req.Proto
core[ClientAddr] = req.RemoteAddr
core[ClientHost], core[ClientPort] = silentSplitHostPort(req.RemoteAddr)
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
core[ClientHost] = forwardedFor
}
crw := &captureResponseWriter{rw: rw}
next.ServeHTTP(crw, reqWithDataTable)
core[ClientUsername] = formatUsernameForLog(core[ClientUsername])
logDataTable.DownstreamResponse = crw.Header()
if l.config.BufferingSize > 0 {
l.logHandlerChan <- logHandlerParams{
logDataTable: logDataTable,
crr: crr,
crw: crw,
}
} else {
l.logTheRoundTrip(logDataTable, crr, crw)
}
}
// Close closes the Logger (i.e. the file, drain logHandlerChan, etc).
func (l *LogHandler) Close() error {
close(l.logHandlerChan)
l.wg.Wait()
return l.file.Close()
}
// Rotate closes and reopens the log file to allow for rotation
// by an external source.
func (l *LogHandler) Rotate() error {
var err error
if l.file != nil {
defer func(f *os.File) {
f.Close()
}(l.file)
}
l.file, err = os.OpenFile(l.config.FilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
if err != nil {
return err
}
l.mu.Lock()
defer l.mu.Unlock()
l.logger.Out = l.file
return nil
}
func silentSplitHostPort(value string) (host string, port string) {
host, port, err := net.SplitHostPort(value)
if err != nil {
return value, "-"
}
return host, port
}
func formatUsernameForLog(usernameField interface{}) string {
username, ok := usernameField.(string)
if ok && len(username) != 0 {
return username
}
return "-"
}
// Logging handler to log frontend name, backend name, and elapsed time
func (l *LogHandler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestReader, crw *captureResponseWriter) {
core := logDataTable.Core
retryAttempts, ok := core[RetryAttempts].(int)
if !ok {
retryAttempts = 0
}
core[RetryAttempts] = retryAttempts
if crr != nil {
core[RequestContentSize] = crr.count
}
core[DownstreamStatus] = crw.Status()
// n.b. take care to perform time arithmetic using UTC to avoid errors at DST boundaries
totalDuration := time.Now().UTC().Sub(core[StartUTC].(time.Time))
core[Duration] = totalDuration
if l.keepAccessLog(crw.Status(), retryAttempts, totalDuration) {
core[DownstreamContentSize] = crw.Size()
if original, ok := core[OriginContentSize]; ok {
o64 := original.(int64)
if o64 != crw.Size() && 0 != crw.Size() {
core[GzipRatio] = float64(o64) / float64(crw.Size())
}
}
core[Overhead] = totalDuration
if origin, ok := core[OriginDuration]; ok {
core[Overhead] = totalDuration - origin.(time.Duration)
}
fields := logrus.Fields{}
for k, v := range logDataTable.Core {
if l.config.Fields.Keep(k) {
fields[k] = v
}
}
l.redactHeaders(logDataTable.Request, fields, "request_")
l.redactHeaders(logDataTable.OriginResponse, fields, "origin_")
l.redactHeaders(logDataTable.DownstreamResponse, fields, "downstream_")
l.mu.Lock()
defer l.mu.Unlock()
l.logger.WithFields(fields).Println()
}
}
func (l *LogHandler) redactHeaders(headers http.Header, fields logrus.Fields, prefix string) {
for k := range headers {
v := l.config.Fields.KeepHeader(k)
if v == types.AccessLogKeep {
fields[prefix+k] = headers.Get(k)
} else if v == types.AccessLogRedact {
fields[prefix+k] = "REDACTED"
}
}
}
func (l *LogHandler) keepAccessLog(statusCode, retryAttempts int, duration time.Duration) bool {
if l.config.Filters == nil {
// no filters were specified
return true
}
if len(l.httpCodeRanges) == 0 && !l.config.Filters.RetryAttempts && l.config.Filters.MinDuration == 0 {
// empty filters were specified, e.g. by passing --accessLog.filters only (without other filter options)
return true
}
if l.httpCodeRanges.Contains(statusCode) {
return true
}
if l.config.Filters.RetryAttempts && retryAttempts > 0 {
return true
}
if l.config.Filters.MinDuration > 0 && (parse.Duration(duration) > l.config.Filters.MinDuration) {
return true
}
return false
}
var requestCounter uint64 // Request ID
func nextRequestCount() uint64 {
return atomic.AddUint64(&requestCounter, 1)
}

View file

@ -0,0 +1,82 @@
package accesslog
import (
"bytes"
"fmt"
"time"
"github.com/sirupsen/logrus"
)
// default format for time presentation
const (
commonLogTimeFormat = "02/Jan/2006:15:04:05 -0700"
defaultValue = "-"
)
// CommonLogFormatter provides formatting in the Traefik common log format
type CommonLogFormatter struct{}
// Format formats the log entry in the Traefik common log format
func (f *CommonLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
b := &bytes.Buffer{}
var timestamp = defaultValue
if v, ok := entry.Data[StartUTC]; ok {
timestamp = v.(time.Time).Format(commonLogTimeFormat)
}
var elapsedMillis int64
if v, ok := entry.Data[Duration]; ok {
elapsedMillis = v.(time.Duration).Nanoseconds() / 1000000
}
_, err := fmt.Fprintf(b, "%s - %s [%s] \"%s %s %s\" %v %v %s %s %v %s %s %dms\n",
toLog(entry.Data, ClientHost, defaultValue, false),
toLog(entry.Data, ClientUsername, defaultValue, false),
timestamp,
toLog(entry.Data, RequestMethod, defaultValue, false),
toLog(entry.Data, RequestPath, defaultValue, false),
toLog(entry.Data, RequestProtocol, defaultValue, false),
toLog(entry.Data, OriginStatus, defaultValue, true),
toLog(entry.Data, OriginContentSize, defaultValue, true),
toLog(entry.Data, "request_Referer", `"-"`, true),
toLog(entry.Data, "request_User-Agent", `"-"`, true),
toLog(entry.Data, RequestCount, defaultValue, true),
toLog(entry.Data, FrontendName, defaultValue, true),
toLog(entry.Data, BackendURL, defaultValue, true),
elapsedMillis)
return b.Bytes(), err
}
func toLog(fields logrus.Fields, key string, defaultValue string, quoted bool) interface{} {
if v, ok := fields[key]; ok {
if v == nil {
return defaultValue
}
switch s := v.(type) {
case string:
return toLogEntry(s, defaultValue, quoted)
case fmt.Stringer:
return toLogEntry(s.String(), defaultValue, quoted)
default:
return v
}
}
return defaultValue
}
func toLogEntry(s string, defaultValue string, quote bool) string {
if len(s) == 0 {
return defaultValue
}
if quote {
return `"` + s + `"`
}
return s
}

View file

@ -0,0 +1,140 @@
package accesslog
import (
"net/http"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func TestCommonLogFormatter_Format(t *testing.T) {
clf := CommonLogFormatter{}
testCases := []struct {
name string
data map[string]interface{}
expectedLog string
}{
{
name: "OriginStatus & OriginContentSize are nil",
data: map[string]interface{}{
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
Duration: 123 * time.Second,
ClientHost: "10.0.0.1",
ClientUsername: "Client",
RequestMethod: http.MethodGet,
RequestPath: "/foo",
RequestProtocol: "http",
OriginStatus: nil,
OriginContentSize: nil,
RequestRefererHeader: "",
RequestUserAgentHeader: "",
RequestCount: 0,
FrontendName: "",
BackendURL: "",
},
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" - - "-" "-" 0 - - 123000ms
`,
},
{
name: "all data",
data: map[string]interface{}{
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
Duration: 123 * time.Second,
ClientHost: "10.0.0.1",
ClientUsername: "Client",
RequestMethod: http.MethodGet,
RequestPath: "/foo",
RequestProtocol: "http",
OriginStatus: 123,
OriginContentSize: 132,
RequestRefererHeader: "referer",
RequestUserAgentHeader: "agent",
RequestCount: nil,
FrontendName: "foo",
BackendURL: "http://10.0.0.2/toto",
},
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" 123 132 "referer" "agent" - "foo" "http://10.0.0.2/toto" 123000ms
`,
},
}
for _, test := range testCases {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
entry := &logrus.Entry{Data: test.data}
raw, err := clf.Format(entry)
assert.NoError(t, err)
assert.Equal(t, test.expectedLog, string(raw))
})
}
}
func Test_toLog(t *testing.T) {
testCases := []struct {
desc string
fields logrus.Fields
fieldName string
defaultValue string
quoted bool
expectedLog interface{}
}{
{
desc: "Should return int 1",
fields: logrus.Fields{
"Powpow": 1,
},
fieldName: "Powpow",
defaultValue: defaultValue,
quoted: false,
expectedLog: 1,
},
{
desc: "Should return string foo",
fields: logrus.Fields{
"Powpow": "foo",
},
fieldName: "Powpow",
defaultValue: defaultValue,
quoted: true,
expectedLog: `"foo"`,
},
{
desc: "Should return defaultValue if fieldName does not exist",
fields: logrus.Fields{
"Powpow": "foo",
},
fieldName: "",
defaultValue: defaultValue,
quoted: false,
expectedLog: "-",
},
{
desc: "Should return defaultValue if fields is nil",
fields: nil,
fieldName: "",
defaultValue: defaultValue,
quoted: false,
expectedLog: "-",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
lg := toLog(test.fields, test.fieldName, defaultValue, test.quoted)
assert.Equal(t, test.expectedLog, lg)
})
}
}

View file

@ -0,0 +1,644 @@
package accesslog
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"testing"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
logFileNameSuffix = "/traefik/logger/test.log"
testContent = "Hello, World"
testBackendName = "http://127.0.0.1/testBackend"
testFrontendName = "testFrontend"
testStatus = 123
testContentSize int64 = 12
testHostname = "TestHost"
testUsername = "TestUser"
testPath = "testpath"
testPort = 8181
testProto = "HTTP/0.0"
testMethod = http.MethodPost
testReferer = "testReferer"
testUserAgent = "testUserAgent"
testRetryAttempts = 2
testStart = time.Now()
)
func TestLogRotation(t *testing.T) {
tempDir, err := ioutil.TempDir("", "traefik_")
if err != nil {
t.Fatalf("Error setting up temporary directory: %s", err)
}
fileName := tempDir + "traefik.log"
rotatedFileName := fileName + ".rotated"
config := &types.AccessLog{FilePath: fileName, Format: CommonFormat}
logHandler, err := NewLogHandler(config)
if err != nil {
t.Fatalf("Error creating new log handler: %s", err)
}
defer logHandler.Close()
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
next := func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}
iterations := 20
halfDone := make(chan bool)
writeDone := make(chan bool)
go func() {
for i := 0; i < iterations; i++ {
logHandler.ServeHTTP(recorder, req, next)
if i == iterations/2 {
halfDone <- true
}
}
writeDone <- true
}()
<-halfDone
err = os.Rename(fileName, rotatedFileName)
if err != nil {
t.Fatalf("Error renaming file: %s", err)
}
err = logHandler.Rotate()
if err != nil {
t.Fatalf("Error rotating file: %s", err)
}
select {
case <-writeDone:
gotLineCount := lineCount(t, fileName) + lineCount(t, rotatedFileName)
if iterations != gotLineCount {
t.Errorf("Wanted %d written log lines, got %d", iterations, gotLineCount)
}
case <-time.After(500 * time.Millisecond):
t.Fatalf("test timed out")
}
close(halfDone)
close(writeDone)
}
func lineCount(t *testing.T, fileName string) int {
t.Helper()
fileContents, err := ioutil.ReadFile(fileName)
if err != nil {
t.Fatalf("Error reading from file %s: %s", fileName, err)
}
count := 0
for _, line := range strings.Split(string(fileContents), "\n") {
if strings.TrimSpace(line) == "" {
continue
}
count++
}
return count
}
func TestLoggerCLF(t *testing.T) {
tmpDir := createTempDir(t, CommonFormat)
defer os.RemoveAll(tmpDir)
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat}
doLogging(t, config)
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`
assertValidLogData(t, expectedLog, logData)
}
func TestAsyncLoggerCLF(t *testing.T) {
tmpDir := createTempDir(t, CommonFormat)
defer os.RemoveAll(tmpDir)
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat, BufferingSize: 1024}
doLogging(t, config)
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`
assertValidLogData(t, expectedLog, logData)
}
func assertString(exp string) func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.Equal(t, exp, actual)
}
}
func assertNotEqual(exp string) func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.NotEqual(t, exp, actual)
}
}
func assertFloat64(exp float64) func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.Equal(t, exp, actual)
}
}
func assertFloat64NotZero() func(t *testing.T, actual interface{}) {
return func(t *testing.T, actual interface{}) {
t.Helper()
assert.NotZero(t, actual)
}
}
func TestLoggerJSON(t *testing.T) {
testCases := []struct {
desc string
config *types.AccessLog
expected map[string]func(t *testing.T, value interface{})
}{
{
desc: "default config",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
},
expected: map[string]func(t *testing.T, value interface{}){
RequestHost: assertString(testHostname),
RequestAddr: assertString(testHostname),
RequestMethod: assertString(testMethod),
RequestPath: assertString(testPath),
RequestProtocol: assertString(testProto),
RequestPort: assertString("-"),
DownstreamStatus: assertFloat64(float64(testStatus)),
DownstreamContentSize: assertFloat64(float64(len(testContent))),
OriginContentSize: assertFloat64(float64(len(testContent))),
OriginStatus: assertFloat64(float64(testStatus)),
RequestRefererHeader: assertString(testReferer),
RequestUserAgentHeader: assertString(testUserAgent),
FrontendName: assertString(testFrontendName),
BackendURL: assertString(testBackendName),
ClientUsername: assertString(testUsername),
ClientHost: assertString(testHostname),
ClientPort: assertString(fmt.Sprintf("%d", testPort)),
ClientAddr: assertString(fmt.Sprintf("%s:%d", testHostname, testPort)),
"level": assertString("info"),
"msg": assertString(""),
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
RequestCount: assertFloat64NotZero(),
Duration: assertFloat64NotZero(),
Overhead: assertFloat64NotZero(),
RetryAttempts: assertFloat64(float64(testRetryAttempts)),
"time": assertNotEqual(""),
"StartLocal": assertNotEqual(""),
"StartUTC": assertNotEqual(""),
},
},
{
desc: "default config drop all fields",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
},
},
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEqual(""),
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
RequestRefererHeader: assertString(testReferer),
RequestUserAgentHeader: assertString(testUserAgent),
},
},
{
desc: "default config drop all fields and headers",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Headers: &types.FieldHeaders{
DefaultMode: "drop",
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEqual(""),
},
},
{
desc: "default config drop all fields and redact headers",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Headers: &types.FieldHeaders{
DefaultMode: "redact",
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEqual(""),
"downstream_Content-Type": assertString("REDACTED"),
RequestRefererHeader: assertString("REDACTED"),
RequestUserAgentHeader: assertString("REDACTED"),
},
},
{
desc: "default config drop all fields and headers but kept someone",
config: &types.AccessLog{
FilePath: "",
Format: JSONFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
RequestHost: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "drop",
Names: types.FieldHeaderNames{
"Referer": "keep",
},
},
},
},
expected: map[string]func(t *testing.T, value interface{}){
RequestHost: assertString(testHostname),
"level": assertString("info"),
"msg": assertString(""),
"time": assertNotEqual(""),
RequestRefererHeader: assertString(testReferer),
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
tmpDir := createTempDir(t, JSONFormat)
defer os.RemoveAll(tmpDir)
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
test.config.FilePath = logFilePath
doLogging(t, test.config)
logData, err := ioutil.ReadFile(logFilePath)
require.NoError(t, err)
jsonData := make(map[string]interface{})
err = json.Unmarshal(logData, &jsonData)
require.NoError(t, err)
assert.Equal(t, len(test.expected), len(jsonData))
for field, assertion := range test.expected {
assertion(t, jsonData[field])
}
})
}
}
func TestNewLogHandlerOutputStdout(t *testing.T) {
testCases := []struct {
desc string
config *types.AccessLog
expectedLog string
}{
{
desc: "default config",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
},
{
desc: "default config with empty filters",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
},
{
desc: "Status code filter not matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
StatusCodes: []string{"200"},
},
},
expectedLog: ``,
},
{
desc: "Status code filter matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
StatusCodes: []string{"123"},
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
},
{
desc: "Duration filter not matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
MinDuration: parse.Duration(1 * time.Hour),
},
},
expectedLog: ``,
},
{
desc: "Duration filter matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
MinDuration: parse.Duration(1 * time.Millisecond),
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
},
{
desc: "Retry attempts filter matching",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Filters: &types.AccessLogFilters{
RetryAttempts: true,
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
},
{
desc: "Default mode keep",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "keep",
},
},
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
},
{
desc: "Default mode keep with override",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "keep",
Names: types.FieldNames{
ClientHost: "drop",
},
},
},
expectedLog: `- - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
},
{
desc: "Default mode drop",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
},
},
expectedLog: `- - - [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
},
{
desc: "Default mode drop with override",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
},
{
desc: "Default mode drop with header dropped",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "drop",
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "-" "-" - - - 0ms`,
},
{
desc: "Default mode drop with header redacted",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "redact",
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "REDACTED" - - - 0ms`,
},
{
desc: "Default mode drop with header redacted",
config: &types.AccessLog{
FilePath: "",
Format: CommonFormat,
Fields: &types.AccessLogFields{
DefaultMode: "drop",
Names: types.FieldNames{
ClientHost: "drop",
ClientUsername: "keep",
},
Headers: &types.FieldHeaders{
DefaultMode: "keep",
Names: types.FieldHeaderNames{
"Referer": "redact",
},
},
},
},
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "testUserAgent" - - - 0ms`,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
// NOTE: It is not possible to run these cases in parallel because we capture Stdout
file, restoreStdout := captureStdout(t)
defer restoreStdout()
doLogging(t, test.config)
written, err := ioutil.ReadFile(file.Name())
require.NoError(t, err, "unable to read captured stdout from file")
assertValidLogData(t, test.expectedLog, written)
})
}
}
func assertValidLogData(t *testing.T, expected string, logData []byte) {
if len(expected) == 0 {
assert.Zero(t, len(logData))
t.Log(string(logData))
return
}
result, err := ParseAccessLog(string(logData))
require.NoError(t, err)
resultExpected, err := ParseAccessLog(expected)
require.NoError(t, err)
formatErrMessage := fmt.Sprintf(`
Expected: %s
Actual: %s`, expected, string(logData))
require.Equal(t, len(resultExpected), len(result), formatErrMessage)
assert.Equal(t, resultExpected[ClientHost], result[ClientHost], formatErrMessage)
assert.Equal(t, resultExpected[ClientUsername], result[ClientUsername], formatErrMessage)
assert.Equal(t, resultExpected[RequestMethod], result[RequestMethod], formatErrMessage)
assert.Equal(t, resultExpected[RequestPath], result[RequestPath], formatErrMessage)
assert.Equal(t, resultExpected[RequestProtocol], result[RequestProtocol], formatErrMessage)
assert.Equal(t, resultExpected[OriginStatus], result[OriginStatus], formatErrMessage)
assert.Equal(t, resultExpected[OriginContentSize], result[OriginContentSize], formatErrMessage)
assert.Equal(t, resultExpected[RequestRefererHeader], result[RequestRefererHeader], formatErrMessage)
assert.Equal(t, resultExpected[RequestUserAgentHeader], result[RequestUserAgentHeader], formatErrMessage)
assert.Regexp(t, regexp.MustCompile("[0-9]*"), result[RequestCount], formatErrMessage)
assert.Equal(t, resultExpected[FrontendName], result[FrontendName], formatErrMessage)
assert.Equal(t, resultExpected[BackendURL], result[BackendURL], formatErrMessage)
assert.Regexp(t, regexp.MustCompile("[0-9]*ms"), result[Duration], formatErrMessage)
}
func captureStdout(t *testing.T) (out *os.File, restoreStdout func()) {
file, err := ioutil.TempFile("", "testlogger")
require.NoError(t, err, "failed to create temp file")
original := os.Stdout
os.Stdout = file
restoreStdout = func() {
os.Stdout = original
}
return file, restoreStdout
}
func createTempDir(t *testing.T, prefix string) string {
tmpDir, err := ioutil.TempDir("", prefix)
require.NoError(t, err, "failed to create temp dir")
return tmpDir
}
func doLogging(t *testing.T, config *types.AccessLog) {
logger, err := NewLogHandler(config)
require.NoError(t, err)
defer logger.Close()
if config.FilePath != "" {
_, err = os.Stat(config.FilePath)
require.NoError(t, err, fmt.Sprintf("logger should create %s", config.FilePath))
}
req := &http.Request{
Header: map[string][]string{
"User-Agent": {testUserAgent},
"Referer": {testReferer},
},
Proto: testProto,
Host: testHostname,
Method: testMethod,
RemoteAddr: fmt.Sprintf("%s:%d", testHostname, testPort),
URL: &url.URL{
Path: testPath,
},
}
logger.ServeHTTP(httptest.NewRecorder(), req, logWriterTestHandlerFunc)
}
func logWriterTestHandlerFunc(rw http.ResponseWriter, r *http.Request) {
if _, err := rw.Write([]byte(testContent)); err != nil {
log.Error(err)
}
rw.WriteHeader(testStatus)
logDataTable := GetLogDataTable(r)
logDataTable.Core[FrontendName] = testFrontendName
logDataTable.Core[BackendURL] = testBackendName
logDataTable.Core[OriginStatus] = testStatus
logDataTable.Core[OriginContentSize] = testContentSize
logDataTable.Core[RetryAttempts] = testRetryAttempts
logDataTable.Core[StartUTC] = testStart.UTC()
logDataTable.Core[StartLocal] = testStart.Local()
logDataTable.Core[ClientUsername] = testUsername
}

View file

@ -0,0 +1,54 @@
package accesslog
import (
"bytes"
"regexp"
)
// ParseAccessLog parse line of access log and return a map with each fields
func ParseAccessLog(data string) (map[string]string, error) {
var buffer bytes.Buffer
buffer.WriteString(`(\S+)`) // 1 - ClientHost
buffer.WriteString(`\s-\s`) // - - Spaces
buffer.WriteString(`(\S+)\s`) // 2 - ClientUsername
buffer.WriteString(`\[([^]]+)\]\s`) // 3 - StartUTC
buffer.WriteString(`"(\S*)\s?`) // 4 - RequestMethod
buffer.WriteString(`((?:[^"]*(?:\\")?)*)\s`) // 5 - RequestPath
buffer.WriteString(`([^"]*)"\s`) // 6 - RequestProtocol
buffer.WriteString(`(\S+)\s`) // 7 - OriginStatus
buffer.WriteString(`(\S+)\s`) // 8 - OriginContentSize
buffer.WriteString(`("?\S+"?)\s`) // 9 - Referrer
buffer.WriteString(`("\S+")\s`) // 10 - User-Agent
buffer.WriteString(`(\S+)\s`) // 11 - RequestCount
buffer.WriteString(`("[^"]*"|-)\s`) // 12 - FrontendName
buffer.WriteString(`("[^"]*"|-)\s`) // 13 - BackendURL
buffer.WriteString(`(\S+)`) // 14 - Duration
regex, err := regexp.Compile(buffer.String())
if err != nil {
return nil, err
}
submatch := regex.FindStringSubmatch(data)
result := make(map[string]string)
// Need to be > 13 to match CLF format
if len(submatch) > 13 {
result[ClientHost] = submatch[1]
result[ClientUsername] = submatch[2]
result[StartUTC] = submatch[3]
result[RequestMethod] = submatch[4]
result[RequestPath] = submatch[5]
result[RequestProtocol] = submatch[6]
result[OriginStatus] = submatch[7]
result[OriginContentSize] = submatch[8]
result[RequestRefererHeader] = submatch[9]
result[RequestUserAgentHeader] = submatch[10]
result[RequestCount] = submatch[11]
result[FrontendName] = submatch[12]
result[BackendURL] = submatch[13]
result[Duration] = submatch[14]
}
return result, nil
}

View file

@ -0,0 +1,75 @@
package accesslog
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestParseAccessLog(t *testing.T) {
testCases := []struct {
desc string
value string
expected map[string]string
}{
{
desc: "full log",
value: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
expected: map[string]string{
ClientHost: "TestHost",
ClientUsername: "TestUser",
StartUTC: "13/Apr/2016:07:14:19 -0700",
RequestMethod: "POST",
RequestPath: "testpath",
RequestProtocol: "HTTP/0.0",
OriginStatus: "123",
OriginContentSize: "12",
RequestRefererHeader: `"testReferer"`,
RequestUserAgentHeader: `"testUserAgent"`,
RequestCount: "1",
FrontendName: `"testFrontend"`,
BackendURL: `"http://127.0.0.1/testBackend"`,
Duration: "1ms",
},
},
{
desc: "log with space",
value: `127.0.0.1 - - [09/Mar/2018:10:51:32 +0000] "GET / HTTP/1.1" 401 17 "-" "Go-http-client/1.1" 1 "testFrontend with space" - 0ms`,
expected: map[string]string{
ClientHost: "127.0.0.1",
ClientUsername: "-",
StartUTC: "09/Mar/2018:10:51:32 +0000",
RequestMethod: "GET",
RequestPath: "/",
RequestProtocol: "HTTP/1.1",
OriginStatus: "401",
OriginContentSize: "17",
RequestRefererHeader: `"-"`,
RequestUserAgentHeader: `"Go-http-client/1.1"`,
RequestCount: "1",
FrontendName: `"testFrontend with space"`,
BackendURL: `-`,
Duration: "0ms",
},
},
{
desc: "bad log",
value: `bad`,
expected: map[string]string{},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
result, err := ParseAccessLog(test.value)
assert.NoError(t, err)
assert.Equal(t, len(test.expected), len(result))
for key, value := range test.expected {
assert.Equal(t, value, result[key])
}
})
}
}

View file

@ -0,0 +1,64 @@
package accesslog
import (
"net/http"
"time"
"github.com/urfave/negroni"
"github.com/vulcand/oxy/utils"
)
// SaveBackend sends the backend name to the logger.
// These are always used with a corresponding SaveFrontend handler.
type SaveBackend struct {
next http.Handler
backendName string
}
// NewSaveBackend creates a SaveBackend handler.
func NewSaveBackend(next http.Handler, backendName string) http.Handler {
return &SaveBackend{next, backendName}
}
func (sb *SaveBackend) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
serveSaveBackend(rw, r, sb.backendName, func(crw *captureResponseWriter) {
sb.next.ServeHTTP(crw, r)
})
}
// SaveNegroniBackend sends the backend name to the logger.
type SaveNegroniBackend struct {
next negroni.Handler
backendName string
}
// NewSaveNegroniBackend creates a SaveBackend handler.
func NewSaveNegroniBackend(next negroni.Handler, backendName string) negroni.Handler {
return &SaveNegroniBackend{next, backendName}
}
func (sb *SaveNegroniBackend) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
serveSaveBackend(rw, r, sb.backendName, func(crw *captureResponseWriter) {
sb.next.ServeHTTP(crw, r, next)
})
}
func serveSaveBackend(rw http.ResponseWriter, r *http.Request, backendName string, apply func(*captureResponseWriter)) {
table := GetLogDataTable(r)
table.Core[BackendName] = backendName
table.Core[BackendURL] = r.URL // note that this is *not* the original incoming URL
table.Core[BackendAddr] = r.URL.Host
crw := &captureResponseWriter{rw: rw}
start := time.Now().UTC()
apply(crw)
// use UTC to handle switchover of daylight saving correctly
table.Core[OriginDuration] = time.Now().UTC().Sub(start)
table.Core[OriginStatus] = crw.Status()
// make copy of headers so we can ensure there is no subsequent mutation during response processing
table.OriginResponse = make(http.Header)
utils.CopyHeaders(table.OriginResponse, crw.Header())
table.Core[OriginContentSize] = crw.Size()
}

View file

@ -0,0 +1,51 @@
package accesslog
import (
"net/http"
"strings"
"github.com/urfave/negroni"
)
// SaveFrontend sends the frontend name to the logger.
// These are sometimes used with a corresponding SaveBackend handler, but not always.
// For example, redirected requests don't reach a backend.
type SaveFrontend struct {
next http.Handler
frontendName string
}
// NewSaveFrontend creates a SaveFrontend handler.
func NewSaveFrontend(next http.Handler, frontendName string) http.Handler {
return &SaveFrontend{next, frontendName}
}
func (sf *SaveFrontend) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
serveSaveFrontend(r, sf.frontendName, func() {
sf.next.ServeHTTP(rw, r)
})
}
// SaveNegroniFrontend sends the frontend name to the logger.
type SaveNegroniFrontend struct {
next negroni.Handler
frontendName string
}
// NewSaveNegroniFrontend creates a SaveNegroniFrontend handler.
func NewSaveNegroniFrontend(next negroni.Handler, frontendName string) negroni.Handler {
return &SaveNegroniFrontend{next, frontendName}
}
func (sf *SaveNegroniFrontend) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
serveSaveFrontend(r, sf.frontendName, func() {
sf.next.ServeHTTP(rw, r, next)
})
}
func serveSaveFrontend(r *http.Request, frontendName string, apply func()) {
table := GetLogDataTable(r)
table.Core[FrontendName] = strings.TrimPrefix(frontendName, "frontend-")
apply()
}

View file

@ -0,0 +1,19 @@
package accesslog
import (
"net/http"
)
// SaveRetries is an implementation of RetryListener that stores RetryAttempts in the LogDataTable.
type SaveRetries struct{}
// Retried implements the RetryListener interface and will be called for each retry that happens.
func (s *SaveRetries) Retried(req *http.Request, attempt int) {
// it is the request attempt x, but the retry attempt is x-1
if attempt > 0 {
attempt--
}
table := GetLogDataTable(req)
table.Core[RetryAttempts] = attempt
}

View file

@ -0,0 +1,48 @@
package accesslog
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func TestSaveRetries(t *testing.T) {
tests := []struct {
requestAttempt int
wantRetryAttemptsInLog int
}{
{
requestAttempt: 0,
wantRetryAttemptsInLog: 0,
},
{
requestAttempt: 1,
wantRetryAttemptsInLog: 0,
},
{
requestAttempt: 3,
wantRetryAttemptsInLog: 2,
},
}
for _, test := range tests {
test := test
t.Run(fmt.Sprintf("%d retries", test.requestAttempt), func(t *testing.T) {
t.Parallel()
saveRetries := &SaveRetries{}
logDataTable := &LogData{Core: make(CoreLogData)}
req := httptest.NewRequest(http.MethodGet, "/some/path", nil)
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
saveRetries.Retried(reqWithDataTable, test.requestAttempt)
if logDataTable.Core[RetryAttempts] != test.wantRetryAttemptsInLog {
t.Errorf("got %v in logDataTable, want %v", logDataTable.Core[RetryAttempts], test.wantRetryAttemptsInLog)
}
})
}
}

View file

@ -0,0 +1,60 @@
package accesslog
import (
"context"
"net/http"
"github.com/urfave/negroni"
)
const (
clientUsernameKey key = "ClientUsername"
)
// SaveUsername sends the Username name to the access logger.
type SaveUsername struct {
next http.Handler
}
// NewSaveUsername creates a SaveUsername handler.
func NewSaveUsername(next http.Handler) http.Handler {
return &SaveUsername{next}
}
func (sf *SaveUsername) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
serveSaveUsername(r, func() {
sf.next.ServeHTTP(rw, r)
})
}
// SaveNegroniUsername adds the Username to the access logger data table.
type SaveNegroniUsername struct {
next negroni.Handler
}
// NewSaveNegroniUsername creates a SaveNegroniUsername handler.
func NewSaveNegroniUsername(next negroni.Handler) negroni.Handler {
return &SaveNegroniUsername{next}
}
func (sf *SaveNegroniUsername) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
serveSaveUsername(r, func() {
sf.next.ServeHTTP(rw, r, next)
})
}
func serveSaveUsername(r *http.Request, apply func()) {
table := GetLogDataTable(r)
username, ok := r.Context().Value(clientUsernameKey).(string)
if ok {
table.Core[ClientUsername] = username
}
apply()
}
// WithUserName adds a username to a requests' context
func WithUserName(req *http.Request, username string) *http.Request {
return req.WithContext(context.WithValue(req.Context(), clientUsernameKey, username))
}

View file

@ -0,0 +1,35 @@
package middlewares
import (
"context"
"net/http"
)
// AddPrefix is a middleware used to add prefix to an URL request
type AddPrefix struct {
Handler http.Handler
Prefix string
}
type key string
const (
// AddPrefixKey is the key within the request context used to
// store the added prefix
AddPrefixKey key = "AddPrefix"
)
func (s *AddPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r.URL.Path = s.Prefix + r.URL.Path
if r.URL.RawPath != "" {
r.URL.RawPath = s.Prefix + r.URL.RawPath
}
r.RequestURI = r.URL.RequestURI()
r = r.WithContext(context.WithValue(r.Context(), AddPrefixKey, s.Prefix))
s.Handler.ServeHTTP(w, r)
}
// SetHandler sets handler
func (s *AddPrefix) SetHandler(Handler http.Handler) {
s.Handler = Handler
}

View file

@ -0,0 +1,66 @@
package middlewares
import (
"net/http"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)
func TestAddPrefix(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel)
tests := []struct {
desc string
prefix string
path string
expectedPath string
expectedRawPath string
}{
{
desc: "regular path",
prefix: "/a",
path: "/b",
expectedPath: "/a/b",
},
{
desc: "raw path is supported",
prefix: "/a",
path: "/b%2Fc",
expectedPath: "/a/b/c",
expectedRawPath: "/a/b%2Fc",
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, requestURI string
handler := &AddPrefix{
Prefix: test.prefix,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
requestURI = r.RequestURI
}),
}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
expectedURI := test.expectedPath
if test.expectedRawPath != "" {
// go HTTP uses the raw path when existent in the RequestURI
expectedURI = test.expectedRawPath
}
assert.Equal(t, expectedURI, requestURI, "Unexpected request URI.")
})
}
}

View file

@ -0,0 +1,167 @@
package auth
import (
"fmt"
"io/ioutil"
"net/http"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/middlewares/accesslog"
"github.com/containous/traefik/old/middlewares/tracing"
"github.com/containous/traefik/old/types"
"github.com/urfave/negroni"
)
// Authenticator is a middleware that provides HTTP basic and digest authentication
type Authenticator struct {
handler negroni.Handler
users map[string]string
}
type tracingAuthenticator struct {
name string
handler negroni.Handler
clientSpanKind bool
}
const (
authorizationHeader = "Authorization"
)
// NewAuthenticator builds a new Authenticator given a config
func NewAuthenticator(authConfig *types.Auth, tracingMiddleware *tracing.Tracing) (*Authenticator, error) {
if authConfig == nil {
return nil, fmt.Errorf("error creating Authenticator: auth is nil")
}
var err error
authenticator := &Authenticator{}
tracingAuth := tracingAuthenticator{}
if authConfig.Basic != nil {
authenticator.users, err = parserBasicUsers(authConfig.Basic)
if err != nil {
return nil, err
}
realm := "traefik"
if authConfig.Basic.Realm != "" {
realm = authConfig.Basic.Realm
}
basicAuth := goauth.NewBasicAuthenticator(realm, authenticator.secretBasic)
tracingAuth.handler = createAuthBasicHandler(basicAuth, authConfig)
tracingAuth.name = "Auth Basic"
tracingAuth.clientSpanKind = false
} else if authConfig.Digest != nil {
authenticator.users, err = parserDigestUsers(authConfig.Digest)
if err != nil {
return nil, err
}
digestAuth := goauth.NewDigestAuthenticator("traefik", authenticator.secretDigest)
tracingAuth.handler = createAuthDigestHandler(digestAuth, authConfig)
tracingAuth.name = "Auth Digest"
tracingAuth.clientSpanKind = false
} else if authConfig.Forward != nil {
tracingAuth.handler = createAuthForwardHandler(authConfig)
tracingAuth.name = "Auth Forward"
tracingAuth.clientSpanKind = true
}
if tracingMiddleware != nil {
authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuth.name, tracingAuth.handler, tracingAuth.clientSpanKind)
} else {
authenticator.handler = tracingAuth.handler
}
return authenticator, nil
}
func createAuthForwardHandler(authConfig *types.Auth) negroni.HandlerFunc {
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
Forward(authConfig.Forward, w, r, next)
})
}
func createAuthDigestHandler(digestAuth *goauth.DigestAuth, authConfig *types.Auth) negroni.HandlerFunc {
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if username, _ := digestAuth.CheckAuth(r); username == "" {
log.Debugf("Digest auth failed")
digestAuth.RequireAuth(w, r)
} else {
log.Debugf("Digest auth succeeded")
// set username in request context
r = accesslog.WithUserName(r, username)
if authConfig.HeaderField != "" {
r.Header[authConfig.HeaderField] = []string{username}
}
if authConfig.Digest.RemoveHeader {
log.Debugf("Remove the Authorization header from the Digest auth")
r.Header.Del(authorizationHeader)
}
next.ServeHTTP(w, r)
}
})
}
func createAuthBasicHandler(basicAuth *goauth.BasicAuth, authConfig *types.Auth) negroni.HandlerFunc {
return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if username := basicAuth.CheckAuth(r); username == "" {
log.Debugf("Basic auth failed")
basicAuth.RequireAuth(w, r)
} else {
log.Debugf("Basic auth succeeded")
// set username in request context
r = accesslog.WithUserName(r, username)
if authConfig.HeaderField != "" {
r.Header[authConfig.HeaderField] = []string{username}
}
if authConfig.Basic.RemoveHeader {
log.Debugf("Remove the Authorization header from the Basic auth")
r.Header.Del(authorizationHeader)
}
next.ServeHTTP(w, r)
}
})
}
func getLinesFromFile(filename string) ([]string, error) {
dat, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
// Trim lines and filter out blanks
rawLines := strings.Split(string(dat), "\n")
var filteredLines []string
for _, rawLine := range rawLines {
line := strings.TrimSpace(rawLine)
if line != "" {
filteredLines = append(filteredLines, line)
}
}
return filteredLines, nil
}
func (a *Authenticator) secretBasic(user, realm string) string {
if secret, ok := a.users[user]; ok {
return secret
}
log.Debugf("User not found: %s", user)
return ""
}
func (a *Authenticator) secretDigest(user, realm string) string {
if secret, ok := a.users[user+":"+realm]; ok {
return secret
}
log.Debugf("User not found: %s:%s", user, realm)
return ""
}
func (a *Authenticator) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
a.handler.ServeHTTP(rw, r, next)
}

View file

@ -0,0 +1,297 @@
package auth
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/old/middlewares/tracing"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func TestAuthUsersFromFile(t *testing.T) {
tests := []struct {
authType string
usersStr string
userKeys []string
parserFunc func(fileName string) (map[string]string, error)
}{
{
authType: "basic",
usersStr: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
userKeys: []string{"test", "test2"},
parserFunc: func(fileName string) (map[string]string, error) {
basic := &types.Basic{
UsersFile: fileName,
}
return parserBasicUsers(basic)
},
},
{
authType: "digest",
usersStr: "test:traefik:a2688e031edb4be6a3797f3882655c05 \ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
userKeys: []string{"test:traefik", "test2:traefik"},
parserFunc: func(fileName string) (map[string]string, error) {
digest := &types.Digest{
UsersFile: fileName,
}
return parserDigestUsers(digest)
},
},
}
for _, test := range tests {
test := test
t.Run(test.authType, func(t *testing.T) {
t.Parallel()
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.usersStr))
require.NoError(t, err)
users, err := test.parserFunc(usersFile.Name())
require.NoError(t, err)
assert.Equal(t, 2, len(users), "they should be equal")
_, ok := users[test.userKeys[0]]
assert.True(t, ok, "user test should be found")
_, ok = users[test.userKeys[1]]
assert.True(t, ok, "user test2 should be found")
})
}
}
func TestBasicAuthFail(t *testing.T) {
_, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test"},
},
}, &tracing.Tracing{})
assert.Contains(t, err.Error(), "error parsing Authenticator user", "should contains")
authMiddleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:test"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
}
func TestBasicAuthSuccess(t *testing.T) {
authMiddleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicRealm(t *testing.T) {
authMiddlewareDefaultRealm, errdefault := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, errdefault)
authMiddlewareCustomRealm, errcustom := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Realm: "foobar",
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, errcustom)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddlewareDefaultRealm)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, "Basic realm=\"traefik\"", res.Header.Get("Www-Authenticate"), "they should be equal")
n = negroni.New(authMiddlewareCustomRealm)
n.UseHandler(handler)
ts = httptest.NewServer(n)
defer ts.Close()
client = &http.Client{}
req = testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err = client.Do(req)
require.NoError(t, err)
assert.Equal(t, "Basic realm=\"foobar\"", res.Header.Get("Www-Authenticate"), "they should be equal")
}
func TestDigestAuthFail(t *testing.T) {
_, err := NewAuthenticator(&types.Auth{
Digest: &types.Digest{
Users: []string{"test"},
},
}, &tracing.Tracing{})
assert.Contains(t, err.Error(), "error parsing Authenticator user", "should contains")
authMiddleware, err := NewAuthenticator(&types.Auth{
Digest: &types.Digest{
Users: []string{"test:traefik:test"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
assert.NotNil(t, authMiddleware, "this should not be nil")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
}
func TestBasicAuthUserHeader(t *testing.T) {
middleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
HeaderField: "X-Webauth-User",
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicAuthHeaderRemoved(t *testing.T) {
middleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
RemoveHeader: true,
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicAuthHeaderPresent(t *testing.T) {
middleware, err := NewAuthenticator(&types.Auth{
Basic: &types.Basic{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
},
}, &tracing.Tracing{})
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}

View file

@ -0,0 +1,157 @@
package auth
import (
"io/ioutil"
"net"
"net/http"
"strings"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/middlewares/tracing"
"github.com/containous/traefik/old/types"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
const (
xForwardedURI = "X-Forwarded-Uri"
xForwardedMethod = "X-Forwarded-Method"
)
// Forward the authentication to a external server
func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// Ensure our request client does not follow redirects
httpClient := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
if config.TLS != nil {
tlsConfig, err := config.TLS.CreateTLSConfig()
if err != nil {
tracing.SetErrorAndDebugLog(r, "Unable to configure TLS to call %s. Cause %s", config.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
httpClient.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}
}
forwardReq, err := http.NewRequest(http.MethodGet, config.Address, http.NoBody)
tracing.LogRequest(tracing.GetSpan(r), forwardReq)
if err != nil {
tracing.SetErrorAndDebugLog(r, "Error calling %s. Cause %s", config.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
writeHeader(r, forwardReq, config.TrustForwardHeader)
tracing.InjectRequestHeaders(forwardReq)
forwardResponse, forwardErr := httpClient.Do(forwardReq)
if forwardErr != nil {
tracing.SetErrorAndDebugLog(r, "Error calling %s. Cause: %s", config.Address, forwardErr)
w.WriteHeader(http.StatusInternalServerError)
return
}
body, readError := ioutil.ReadAll(forwardResponse.Body)
if readError != nil {
tracing.SetErrorAndDebugLog(r, "Error reading body %s. Cause: %s", config.Address, readError)
w.WriteHeader(http.StatusInternalServerError)
return
}
defer forwardResponse.Body.Close()
// Pass the forward response's body and selected headers if it
// didn't return a response within the range of [200, 300).
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode)
utils.CopyHeaders(w.Header(), forwardResponse.Header)
utils.RemoveHeaders(w.Header(), forward.HopHeaders...)
// Grab the location header, if any.
redirectURL, err := forwardResponse.Location()
if err != nil {
if err != http.ErrNoLocation {
tracing.SetErrorAndDebugLog(r, "Error reading response location header %s. Cause: %s", config.Address, err)
w.WriteHeader(http.StatusInternalServerError)
return
}
} else if redirectURL.String() != "" {
// Set the location in our response if one was sent back.
w.Header().Set("Location", redirectURL.String())
}
tracing.LogResponseCode(tracing.GetSpan(r), forwardResponse.StatusCode)
w.WriteHeader(forwardResponse.StatusCode)
if _, err = w.Write(body); err != nil {
log.Error(err)
}
return
}
for _, headerName := range config.AuthResponseHeaders {
r.Header.Set(headerName, forwardResponse.Header.Get(headerName))
}
r.RequestURI = r.URL.RequestURI()
next(w, r)
}
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
utils.CopyHeaders(forwardReq.Header, req.Header)
utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if trustForwardHeader {
if prior, ok := req.Header[forward.XForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
forwardReq.Header.Set(forward.XForwardedFor, clientIP)
}
if xMethod := req.Header.Get(xForwardedMethod); xMethod != "" && trustForwardHeader {
forwardReq.Header.Set(xForwardedMethod, xMethod)
} else if req.Method != "" {
forwardReq.Header.Set(xForwardedMethod, req.Method)
} else {
forwardReq.Header.Del(xForwardedMethod)
}
if xfp := req.Header.Get(forward.XForwardedProto); xfp != "" && trustForwardHeader {
forwardReq.Header.Set(forward.XForwardedProto, xfp)
} else if req.TLS != nil {
forwardReq.Header.Set(forward.XForwardedProto, "https")
} else {
forwardReq.Header.Set(forward.XForwardedProto, "http")
}
if xfp := req.Header.Get(forward.XForwardedPort); xfp != "" && trustForwardHeader {
forwardReq.Header.Set(forward.XForwardedPort, xfp)
}
if xfh := req.Header.Get(forward.XForwardedHost); xfh != "" && trustForwardHeader {
forwardReq.Header.Set(forward.XForwardedHost, xfh)
} else if req.Host != "" {
forwardReq.Header.Set(forward.XForwardedHost, req.Host)
} else {
forwardReq.Header.Del(forward.XForwardedHost)
}
if xfURI := req.Header.Get(xForwardedURI); xfURI != "" && trustForwardHeader {
forwardReq.Header.Set(xForwardedURI, xfURI)
} else if req.URL.RequestURI() != "" {
forwardReq.Header.Set(xForwardedURI, req.URL.RequestURI())
} else {
forwardReq.Header.Del(xForwardedURI)
}
}

View file

@ -0,0 +1,392 @@
package auth
import (
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/old/middlewares/tracing"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
"github.com/vulcand/oxy/forward"
)
func TestForwardAuthFail(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer server.Close()
middleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: server.URL,
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
}
func TestForwardAuthSuccess(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Auth-User", "user@example.com")
w.Header().Set("X-Auth-Secret", "secret")
fmt.Fprintln(w, "Success")
}))
defer server.Close()
middleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: server.URL,
AuthResponseHeaders: []string{"X-Auth-User"},
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
fmt.Fprintln(w, "traefik")
})
n := negroni.New(middleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestForwardAuthRedirect(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
}))
defer authTs.Close()
authMiddleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: authTs.URL,
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
location, err := res.Location()
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.NotEmpty(t, string(body), "there should be something in the body")
}
func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := w.Header()
for _, header := range forward.HopHeaders {
if header == forward.TransferEncoding {
headers.Add(header, "identity")
} else {
headers.Add(header, "test")
}
}
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
}))
defer authTs.Close()
authMiddleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: authTs.URL,
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
client := &http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
for _, header := range forward.HopHeaders {
assert.Equal(t, "", res.Header.Get(header), "hop-by-hop header '%s' mustn't be set", header)
}
location, err := res.Location()
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.NotEmpty(t, string(body), "there should be something in the body")
}
func TestForwardAuthFailResponseHeaders(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
http.SetCookie(w, cookie)
w.Header().Add("X-Foo", "bar")
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer authTs.Close()
authMiddleware, err := NewAuthenticator(&types.Auth{
Forward: &types.Forward{
Address: authTs.URL,
},
}, &tracing.Tracing{})
assert.NoError(t, err, "there should be no error")
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
n := negroni.New(authMiddleware)
n.UseHandler(handler)
ts := httptest.NewServer(n)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
client := &http.Client{}
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
require.Len(t, res.Cookies(), 1)
for _, cookie := range res.Cookies() {
assert.Equal(t, "testing", cookie.Value, "they should be equal")
}
expectedHeaders := http.Header{
"Content-Length": []string{"10"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"X-Foo": []string{"bar"},
"Set-Cookie": []string{"example=testing; Path=/"},
"X-Content-Type-Options": []string{"nosniff"},
}
assert.Len(t, res.Header, 6)
for key, value := range expectedHeaders {
assert.Equal(t, value, res.Header[key])
}
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
}
func Test_writeHeader(t *testing.T) {
testCases := []struct {
name string
headers map[string]string
trustForwardHeader bool
emptyHost bool
expectedHeaders map[string]string
checkForUnexpectedHeaders bool
}{
{
name: "trust Forward Header",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
},
{
name: "not trust Forward Header",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "foo.bar",
},
},
{
name: "trust Forward Header with empty Host",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: true,
emptyHost: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
},
{
name: "not trust Forward Header with empty Host",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: false,
emptyHost: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "",
},
},
{
name: "trust Forward Header with forwarded URI",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
},
{
name: "not trust Forward Header with forward requested URI",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "foo.bar",
"X-Forwarded-Uri": "/path?q=1",
},
}, {
name: "trust Forward Header with forwarded request Method",
headers: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
},
{
name: "not trust Forward Header with forward request Method",
headers: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"X-Forwarded-Method": "GET",
},
},
{
name: "remove hop-by-hop headers",
headers: map[string]string{
forward.Connection: "Connection",
forward.KeepAlive: "KeepAlive",
forward.ProxyAuthenticate: "ProxyAuthenticate",
forward.ProxyAuthorization: "ProxyAuthorization",
forward.Te: "Te",
forward.Trailers: "Trailers",
forward.TransferEncoding: "TransferEncoding",
forward.Upgrade: "Upgrade",
"X-CustomHeader": "CustomHeader",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"X-CustomHeader": "CustomHeader",
"X-Forwarded-Proto": "http",
"X-Forwarded-Host": "foo.bar",
"X-Forwarded-Uri": "/path?q=1",
"X-Forwarded-Method": "GET",
},
checkForUnexpectedHeaders: true,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
for key, value := range test.headers {
req.Header.Set(key, value)
}
if test.emptyHost {
req.Host = ""
}
forwardReq := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
writeHeader(req, forwardReq, test.trustForwardHeader)
actualHeaders := forwardReq.Header
expectedHeaders := test.expectedHeaders
for key, value := range expectedHeaders {
assert.Equal(t, value, actualHeaders.Get(key))
actualHeaders.Del(key)
}
if test.checkForUnexpectedHeaders {
for key := range actualHeaders {
assert.Fail(t, "Unexpected header found", key)
}
}
})
}
}

View file

@ -0,0 +1,48 @@
package auth
import (
"fmt"
"strings"
"github.com/containous/traefik/old/types"
)
func parserBasicUsers(basic *types.Basic) (map[string]string, error) {
var userStrs []string
if basic.UsersFile != "" {
var err error
if userStrs, err = getLinesFromFile(basic.UsersFile); err != nil {
return nil, err
}
}
userStrs = append(basic.Users, userStrs...)
userMap := make(map[string]string)
for _, user := range userStrs {
split := strings.Split(user, ":")
if len(split) != 2 {
return nil, fmt.Errorf("error parsing Authenticator user: %v", user)
}
userMap[split[0]] = split[1]
}
return userMap, nil
}
func parserDigestUsers(digest *types.Digest) (map[string]string, error) {
var userStrs []string
if digest.UsersFile != "" {
var err error
if userStrs, err = getLinesFromFile(digest.UsersFile); err != nil {
return nil, err
}
}
userStrs = append(digest.Users, userStrs...)
userMap := make(map[string]string)
for _, user := range userStrs {
split := strings.Split(user, ":")
if len(split) != 3 {
return nil, fmt.Errorf("error parsing Authenticator user: %v", user)
}
userMap[split[0]+":"+split[1]] = split[2]
}
return userMap, nil
}

View file

@ -0,0 +1,40 @@
package middlewares
import (
"net/http"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/middlewares/tracing"
"github.com/vulcand/oxy/cbreaker"
)
// CircuitBreaker holds the oxy circuit breaker.
type CircuitBreaker struct {
circuitBreaker *cbreaker.CircuitBreaker
}
// NewCircuitBreaker returns a new CircuitBreaker.
func NewCircuitBreaker(next http.Handler, expression string, options ...cbreaker.CircuitBreakerOption) (*CircuitBreaker, error) {
circuitBreaker, err := cbreaker.New(next, expression, options...)
if err != nil {
return nil, err
}
return &CircuitBreaker{circuitBreaker}, nil
}
// NewCircuitBreakerOptions returns a new CircuitBreakerOption
func NewCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
return cbreaker.Fallback(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tracing.LogEventf(r, "blocked by circuit-breaker (%q)", expression)
w.WriteHeader(http.StatusServiceUnavailable)
if _, err := w.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
log.Error(err)
}
}))
}
func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
cb.circuitBreaker.ServeHTTP(rw, r)
}

View file

@ -0,0 +1,33 @@
package middlewares
import (
"compress/gzip"
"net/http"
"strings"
"github.com/NYTimes/gziphandler"
"github.com/containous/traefik/old/log"
)
// Compress is a middleware that allows to compress the response
type Compress struct{}
// ServeHTTP is a function used by Negroni
func (c *Compress) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
contentType := r.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/grpc") {
next.ServeHTTP(rw, r)
} else {
gzipHandler(next).ServeHTTP(rw, r)
}
}
func gzipHandler(h http.Handler) http.Handler {
wrapper, err := gziphandler.GzipHandlerWithOpts(
gziphandler.CompressionLevel(gzip.DefaultCompression),
gziphandler.MinSize(gziphandler.DefaultMinSize))
if err != nil {
log.Error(err)
}
return wrapper(h)
}

View file

@ -0,0 +1,248 @@
package middlewares
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/NYTimes/gziphandler"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
const (
acceptEncodingHeader = "Accept-Encoding"
contentEncodingHeader = "Content-Encoding"
contentTypeHeader = "Content-Type"
varyHeader = "Vary"
gzipValue = "gzip"
)
func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
handler := &Compress{}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
baseBody := generateBytes(gziphandler.DefaultMinSize)
next := func(rw http.ResponseWriter, r *http.Request) {
_, err := rw.Write(baseBody)
assert.NoError(t, err)
}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req, next)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
if assert.ObjectsAreEqualValues(rw.Body.Bytes(), baseBody) {
assert.Fail(t, "expected a compressed body", "got %v", rw.Body.Bytes())
}
}
func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
handler := &Compress{}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
fakeCompressedBody := generateBytes(gziphandler.DefaultMinSize)
next := func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.Write(fakeCompressedBody)
}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req, next)
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
assert.EqualValues(t, rw.Body.Bytes(), fakeCompressedBody)
}
func TestShouldNotCompressWhenNoAcceptEncodingHeader(t *testing.T) {
handler := &Compress{}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
fakeBody := generateBytes(gziphandler.DefaultMinSize)
next := func(rw http.ResponseWriter, r *http.Request) {
rw.Write(fakeBody)
}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req, next)
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
assert.EqualValues(t, rw.Body.Bytes(), fakeBody)
}
func TestShouldNotCompressWhenGRPC(t *testing.T) {
handler := &Compress{}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
req.Header.Add(contentTypeHeader, "application/grpc")
baseBody := generateBytes(gziphandler.DefaultMinSize)
next := func(rw http.ResponseWriter, r *http.Request) {
rw.Write(baseBody)
}
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req, next)
assert.Empty(t, rw.Header().Get(acceptEncodingHeader))
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
assert.EqualValues(t, rw.Body.Bytes(), baseBody)
}
func TestIntegrationShouldNotCompress(t *testing.T) {
fakeCompressedBody := generateBytes(100000)
comp := &Compress{}
testCases := []struct {
name string
handler func(rw http.ResponseWriter, r *http.Request)
expectedStatusCode int
}{
{
name: "when content already compressed",
handler: func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.Write(fakeCompressedBody)
},
expectedStatusCode: http.StatusOK,
},
{
name: "when content already compressed and status code Created",
handler: func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusCreated)
rw.Write(fakeCompressedBody)
},
expectedStatusCode: http.StatusCreated,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
negro := negroni.New(comp)
negro.UseHandlerFunc(test.handler)
ts := httptest.NewServer(negro)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
assert.EqualValues(t, fakeCompressedBody, body)
})
}
}
func TestShouldWriteHeaderWhenFlush(t *testing.T) {
comp := &Compress{}
negro := negroni.New(comp)
negro.UseHandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.Header().Add(contentEncodingHeader, gzipValue)
rw.Header().Add(varyHeader, acceptEncodingHeader)
rw.WriteHeader(http.StatusUnauthorized)
rw.(http.Flusher).Flush()
rw.Write([]byte("short"))
})
ts := httptest.NewServer(negro)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
}
func TestIntegrationShouldCompress(t *testing.T) {
fakeBody := generateBytes(100000)
testCases := []struct {
name string
handler func(rw http.ResponseWriter, r *http.Request)
expectedStatusCode int
}{
{
name: "when AcceptEncoding header is present",
handler: func(rw http.ResponseWriter, r *http.Request) {
rw.Write(fakeBody)
},
expectedStatusCode: http.StatusOK,
},
{
name: "when AcceptEncoding header is present and status code Created",
handler: func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusCreated)
rw.Write(fakeBody)
},
expectedStatusCode: http.StatusCreated,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
comp := &Compress{}
negro := negroni.New(comp)
negro.UseHandlerFunc(test.handler)
ts := httptest.NewServer(negro)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Add(acceptEncodingHeader, gzipValue)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
if assert.ObjectsAreEqualValues(body, fakeBody) {
assert.Fail(t, "expected a compressed body", "got %v", body)
}
})
}
}
func generateBytes(len int) []byte {
var value []byte
for i := 0; i < len; i++ {
value = append(value, 0x61+byte(i))
}
return value
}

View file

@ -0,0 +1,30 @@
package middlewares
import (
"net/http"
"github.com/containous/traefik/healthcheck"
)
// EmptyBackendHandler is a middlware that checks whether the current Backend
// has at least one active Server in respect to the healthchecks and if this
// is not the case, it will stop the middleware chain and respond with 503.
type EmptyBackendHandler struct {
next healthcheck.BalancerHandler
}
// NewEmptyBackendHandler creates a new EmptyBackendHandler instance.
func NewEmptyBackendHandler(lb healthcheck.BalancerHandler) *EmptyBackendHandler {
return &EmptyBackendHandler{next: lb}
}
// ServeHTTP responds with 503 when there is no active Server and otherwise
// invokes the next handler in the middleware chain.
func (h *EmptyBackendHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
if len(h.next.Servers()) == 0 {
rw.WriteHeader(http.StatusServiceUnavailable)
rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
} else {
h.next.ServeHTTP(rw, r)
}
}

View file

@ -0,0 +1,83 @@
package middlewares
import (
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/vulcand/oxy/roundrobin"
)
func TestEmptyBackendHandler(t *testing.T) {
tests := []struct {
amountServer int
wantStatusCode int
}{
{
amountServer: 0,
wantStatusCode: http.StatusServiceUnavailable,
},
{
amountServer: 1,
wantStatusCode: http.StatusOK,
},
}
for _, test := range tests {
test := test
t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) {
t.Parallel()
handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer})
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
handler.ServeHTTP(recorder, req)
if recorder.Result().StatusCode != test.wantStatusCode {
t.Errorf("Received status code %d, wanted %d", recorder.Result().StatusCode, test.wantStatusCode)
}
})
}
}
type healthCheckLoadBalancer struct {
amountServer int
}
func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}
func (lb *healthCheckLoadBalancer) Servers() []*url.URL {
servers := make([]*url.URL, lb.amountServer)
for i := 0; i < lb.amountServer; i++ {
servers = append(servers, testhelpers.MustParseURL("http://localhost"))
}
return servers
}
func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error {
return nil
}
func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
return nil
}
func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) {
return 0, false
}
func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) {
return nil, nil
}
func (lb *healthCheckLoadBalancer) Next() http.Handler {
return nil
}

View file

@ -0,0 +1,236 @@
package errorpages
import (
"bufio"
"bytes"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/middlewares"
"github.com/containous/traefik/old/types"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
// Compile time validation that the response recorder implements http interfaces correctly.
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{}
// Handler is a middleware that provides the custom error pages
type Handler struct {
BackendName string
backendHandler http.Handler
httpCodeRanges types.HTTPCodeRanges
backendURL string
backendQuery string
FallbackURL string // Deprecated
}
// NewHandler initializes the utils.ErrorHandler for the custom error pages
func NewHandler(errorPage *types.ErrorPage, backendName string) (*Handler, error) {
if len(backendName) == 0 {
return nil, errors.New("error pages: backend name is mandatory ")
}
httpCodeRanges, err := types.NewHTTPCodeRanges(errorPage.Status)
if err != nil {
return nil, err
}
return &Handler{
BackendName: backendName,
httpCodeRanges: httpCodeRanges,
backendQuery: errorPage.Query,
backendURL: "http://0.0.0.0",
}, nil
}
// PostLoad adds backend handler if available
func (h *Handler) PostLoad(backendHandler http.Handler) error {
if backendHandler == nil {
fwd, err := forward.New()
if err != nil {
return err
}
h.backendHandler = fwd
h.backendURL = h.FallbackURL
} else {
h.backendHandler = backendHandler
}
return nil
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
if h.backendHandler == nil {
log.Error("Error pages: no backend handler.")
next.ServeHTTP(w, req)
return
}
recorder := newResponseRecorder(w)
next.ServeHTTP(recorder, req)
// check the recorder code against the configured http status code ranges
for _, block := range h.httpCodeRanges {
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
var query string
if len(h.backendQuery) > 0 {
query = "/" + strings.TrimPrefix(h.backendQuery, "/")
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
}
pageReq, err := newRequest(h.backendURL + query)
if err != nil {
log.Error(err)
w.WriteHeader(recorder.GetCode())
fmt.Fprint(w, http.StatusText(recorder.GetCode()))
return
}
recorderErrorPage := newResponseRecorder(w)
utils.CopyHeaders(pageReq.Header, req.Header)
h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
utils.CopyHeaders(w.Header(), recorderErrorPage.Header())
w.WriteHeader(recorder.GetCode())
if _, err = w.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
log.Error(err)
}
return
}
}
// did not catch a configured status code so proceed with the request
utils.CopyHeaders(w.Header(), recorder.Header())
w.WriteHeader(recorder.GetCode())
w.Write(recorder.GetBody().Bytes())
}
func newRequest(baseURL string) (*http.Request, error) {
u, err := url.Parse(baseURL)
if err != nil {
return nil, fmt.Errorf("error pages: error when parse URL: %v", err)
}
req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody)
if err != nil {
return nil, fmt.Errorf("error pages: error when create query: %v", err)
}
req.RequestURI = u.RequestURI()
return req, nil
}
type responseRecorder interface {
http.ResponseWriter
http.Flusher
GetCode() int
GetBody() *bytes.Buffer
IsStreamingResponseStarted() bool
}
// newResponseRecorder returns an initialized responseRecorder.
func newResponseRecorder(rw http.ResponseWriter) responseRecorder {
recorder := &responseRecorderWithoutCloseNotify{
HeaderMap: make(http.Header),
Body: new(bytes.Buffer),
Code: http.StatusOK,
responseWriter: rw,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &responseRecorderWithCloseNotify{recorder}
}
return recorder
}
// responseRecorderWithoutCloseNotify is an implementation of http.ResponseWriter that
// records its mutations for later inspection.
type responseRecorderWithoutCloseNotify struct {
Code int // the HTTP response code from WriteHeader
HeaderMap http.Header // the HTTP response headers
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
responseWriter http.ResponseWriter
err error
streamingResponseStarted bool
}
type responseRecorderWithCloseNotify struct {
*responseRecorderWithoutCloseNotify
}
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone away.
func (r *responseRecorderWithCloseNotify) CloseNotify() <-chan bool {
return r.responseWriter.(http.CloseNotifier).CloseNotify()
}
// Header returns the response headers.
func (r *responseRecorderWithoutCloseNotify) Header() http.Header {
if r.HeaderMap == nil {
r.HeaderMap = make(http.Header)
}
return r.HeaderMap
}
func (r *responseRecorderWithoutCloseNotify) GetCode() int {
return r.Code
}
func (r *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer {
return r.Body
}
func (r *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool {
return r.streamingResponseStarted
}
// Write always succeeds and writes to rw.Body, if not nil.
func (r *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) {
if r.err != nil {
return 0, r.err
}
return r.Body.Write(buf)
}
// WriteHeader sets rw.Code.
func (r *responseRecorderWithoutCloseNotify) WriteHeader(code int) {
r.Code = code
}
// Hijack hijacks the connection
func (r *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.responseWriter.(http.Hijacker).Hijack()
}
// Flush sends any buffered data to the client.
func (r *responseRecorderWithoutCloseNotify) Flush() {
if !r.streamingResponseStarted {
utils.CopyHeaders(r.responseWriter.Header(), r.Header())
r.responseWriter.WriteHeader(r.Code)
r.streamingResponseStarted = true
}
_, err := r.responseWriter.Write(r.Body.Bytes())
if err != nil {
log.Errorf("Error writing response in responseRecorder: %v", err)
r.err = err
}
r.Body.Reset()
if flusher, ok := r.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}

View file

@ -0,0 +1,384 @@
package errorpages
import (
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/urfave/negroni"
)
func TestHandler(t *testing.T) {
testCases := []struct {
desc string
errorPage *types.ErrorPage
backendCode int
backendErrorHandler http.HandlerFunc
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
desc: "no error",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusOK,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
},
},
{
desc: "in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusInternalServerError,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My error page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "not in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusBadGateway,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
},
},
{
desc: "query replacement",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "Single code",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
backendCode: http.StatusServiceUnavailable,
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.RequestURI == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
errorPageHandler, err := NewHandler(test.errorPage, "test")
require.NoError(t, err)
errorPageHandler.backendHandler = test.backendErrorHandler
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
n := negroni.New()
n.Use(errorPageHandler)
n.UseHandler(handler)
recorder := httptest.NewRecorder()
n.ServeHTTP(recorder, req)
test.validate(t, recorder)
})
}
}
func TestHandlerOldWay(t *testing.T) {
testCases := []struct {
desc string
errorPage *types.ErrorPage
backendCode int
errorPageForwarder http.HandlerFunc
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
desc: "no error",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusOK,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "OK")
},
},
{
desc: "in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusInternalServerError,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
assert.Contains(t, recorder.Body.String(), "My error page.")
assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page")
},
},
{
desc: "not in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusBadGateway,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "My error page.")
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusBadGateway, recorder.Code)
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
assert.NotContains(t, recorder.Body.String(), "My error page.", "Should return the oops page since we have not configured the 502 code")
},
},
{
desc: "query replacement",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
backendCode: http.StatusServiceUnavailable,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.RequestURI() == "/"+strconv.Itoa(503) {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "Single code",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
backendCode: http.StatusServiceUnavailable,
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.RequestURI() == "/"+strconv.Itoa(503) {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Failed")
}
}),
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
errorPageHandler, err := NewHandler(test.errorPage, "test")
require.NoError(t, err)
errorPageHandler.FallbackURL = "http://localhost"
err = errorPageHandler.PostLoad(test.errorPageForwarder)
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
n := negroni.New()
n.Use(errorPageHandler)
n.UseHandler(handler)
recorder := httptest.NewRecorder()
n.ServeHTTP(recorder, req)
test.validate(t, recorder)
})
}
}
func TestHandlerOldWayIntegration(t *testing.T) {
errorPagesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.RequestURI() == "/503" {
fmt.Fprintln(w, "My 503 page.")
} else {
fmt.Fprintln(w, "Test Server")
}
}))
defer errorPagesServer.Close()
testCases := []struct {
desc string
errorPage *types.ErrorPage
backendCode int
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
}{
{
desc: "no error",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusOK,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "OK")
},
},
{
desc: "in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusInternalServerError,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
assert.Contains(t, recorder.Body.String(), "Test Server")
assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page")
},
},
{
desc: "not in the range",
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
backendCode: http.StatusBadGateway,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusBadGateway, recorder.Code)
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
},
},
{
desc: "query replacement",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
backendCode: http.StatusServiceUnavailable,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
{
desc: "Single code",
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
backendCode: http.StatusServiceUnavailable,
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
assert.Contains(t, recorder.Body.String(), "My 503 page.")
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
},
},
}
req := testhelpers.MustNewRequest(http.MethodGet, errorPagesServer.URL+"/test", nil)
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
errorPageHandler, err := NewHandler(test.errorPage, "test")
require.NoError(t, err)
errorPageHandler.FallbackURL = errorPagesServer.URL
err = errorPageHandler.PostLoad(nil)
require.NoError(t, err)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(test.backendCode)
fmt.Fprintln(w, http.StatusText(test.backendCode))
})
n := negroni.New()
n.Use(errorPageHandler)
n.UseHandler(handler)
recorder := httptest.NewRecorder()
n.ServeHTTP(recorder, req)
test.validate(t, recorder)
})
}
}
func TestNewResponseRecorder(t *testing.T) {
testCases := []struct {
desc string
rw http.ResponseWriter
expected http.ResponseWriter
}{
{
desc: "Without Close Notify",
rw: httptest.NewRecorder(),
expected: &responseRecorderWithoutCloseNotify{},
},
{
desc: "With Close Notify",
rw: &mockRWCloseNotify{},
expected: &responseRecorderWithCloseNotify{},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
rec := newResponseRecorder(test.rw)
assert.IsType(t, rec, test.expected)
})
}
}
type mockRWCloseNotify struct{}
func (m *mockRWCloseNotify) CloseNotify() <-chan bool {
panic("implement me")
}
func (m *mockRWCloseNotify) Header() http.Header {
panic("implement me")
}
func (m *mockRWCloseNotify) Write([]byte) (int, error) {
panic("implement me")
}
func (m *mockRWCloseNotify) WriteHeader(int) {
panic("implement me")
}

View file

@ -0,0 +1,52 @@
package forwardedheaders
import (
"net/http"
"github.com/containous/traefik/ip"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
// XForwarded filter for XForwarded headers
type XForwarded struct {
insecure bool
trustedIps []string
ipChecker *ip.Checker
}
// NewXforwarded creates a new XForwarded
func NewXforwarded(insecure bool, trustedIps []string) (*XForwarded, error) {
var ipChecker *ip.Checker
if len(trustedIps) > 0 {
var err error
ipChecker, err = ip.NewChecker(trustedIps)
if err != nil {
return nil, err
}
}
return &XForwarded{
insecure: insecure,
trustedIps: trustedIps,
ipChecker: ipChecker,
}, nil
}
func (x *XForwarded) isTrustedIP(ip string) bool {
if x.ipChecker == nil {
return false
}
return x.ipChecker.IsAuthorized(ip) == nil
}
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
utils.RemoveHeaders(r.Header, forward.XHeaders...)
}
// If there is a next, call it.
if next != nil {
next(w, r)
}
}

View file

@ -0,0 +1,128 @@
package forwardedheaders
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestServeHTTP(t *testing.T) {
testCases := []struct {
desc string
insecure bool
trustedIps []string
incomingHeaders map[string]string
remoteAddr string
expectedHeaders map[string]string
}{
{
desc: "all Empty",
insecure: true,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure true with incoming X-Forwarded-For",
insecure: true,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For",
insecure: false,
trustedIps: nil,
remoteAddr: "",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips",
insecure: false,
trustedIps: []string{"10.0.1.100"},
remoteAddr: "10.0.1.100:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips",
insecure: false,
trustedIps: []string{"10.0.1.100"},
remoteAddr: "10.0.1.101:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips CIDR",
insecure: false,
trustedIps: []string{"1.2.3.4/24"},
remoteAddr: "1.2.3.156:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
},
{
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips CIDR",
insecure: false,
trustedIps: []string{"1.2.3.4/24"},
remoteAddr: "10.0.1.101:80",
incomingHeaders: map[string]string{
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
},
expectedHeaders: map[string]string{
"X-Forwarded-for": "",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req, err := http.NewRequest(http.MethodGet, "", nil)
require.NoError(t, err)
req.RemoteAddr = test.remoteAddr
for k, v := range test.incomingHeaders {
req.Header.Set(k, v)
}
m, err := NewXforwarded(test.insecure, test.trustedIps)
require.NoError(t, err)
m.ServeHTTP(nil, req, nil)
for k, v := range test.expectedHeaders {
assert.Equal(t, v, req.Header.Get(k))
}
})
}
}

View file

@ -0,0 +1,36 @@
package middlewares
import (
"net/http"
"github.com/containous/mux"
"github.com/containous/traefik/safe"
)
// HandlerSwitcher allows hot switching of http.ServeMux
type HandlerSwitcher struct {
handler *safe.Safe
}
// NewHandlerSwitcher builds a new instance of HandlerSwitcher
func NewHandlerSwitcher(newHandler *mux.Router) (hs *HandlerSwitcher) {
return &HandlerSwitcher{
handler: safe.New(newHandler),
}
}
func (hs *HandlerSwitcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
handlerBackup := hs.handler.Get().(*mux.Router)
handlerBackup.ServeHTTP(rw, r)
}
// GetHandler returns the current http.ServeMux
func (hs *HandlerSwitcher) GetHandler() (newHandler *mux.Router) {
handler := hs.handler.Get().(*mux.Router)
return handler
}
// UpdateHandler safely updates the current http.ServeMux with a new one
func (hs *HandlerSwitcher) UpdateHandler(newHandler *mux.Router) {
hs.handler.Set(newHandler)
}

View file

@ -0,0 +1,71 @@
package middlewares
// Middleware based on https://github.com/unrolled/secure
import (
"net/http"
"github.com/containous/traefik/old/types"
)
// HeaderOptions is a struct for specifying configuration options for the headers middleware.
type HeaderOptions struct {
// If Custom request headers are set, these will be added to the request
CustomRequestHeaders map[string]string
// If Custom response headers are set, these will be added to the ResponseWriter
CustomResponseHeaders map[string]string
}
// HeaderStruct is a middleware that helps setup a few basic security features. A single headerOptions struct can be
// provided to configure which features should be enabled, and the ability to override a few of the default values.
type HeaderStruct struct {
// Customize headers with a headerOptions struct.
opt HeaderOptions
}
// NewHeaderFromStruct constructs a new header instance from supplied frontend header struct.
func NewHeaderFromStruct(headers *types.Headers) *HeaderStruct {
if headers == nil || !headers.HasCustomHeadersDefined() {
return nil
}
return &HeaderStruct{
opt: HeaderOptions{
CustomRequestHeaders: headers.CustomRequestHeaders,
CustomResponseHeaders: headers.CustomResponseHeaders,
},
}
}
func (s *HeaderStruct) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
s.ModifyRequestHeaders(r)
// If there is a next, call it.
if next != nil {
next(w, r)
}
}
// ModifyRequestHeaders set or delete request headers
func (s *HeaderStruct) ModifyRequestHeaders(r *http.Request) {
// Loop through Custom request headers
for header, value := range s.opt.CustomRequestHeaders {
if value == "" {
r.Header.Del(header)
} else {
r.Header.Set(header, value)
}
}
}
// ModifyResponseHeaders set or delete response headers
func (s *HeaderStruct) ModifyResponseHeaders(res *http.Response) error {
// Loop through Custom response headers
for header, value := range s.opt.CustomResponseHeaders {
if value == "" {
res.Header.Del(header)
} else {
res.Header.Set(header, value)
}
}
return nil
}

View file

@ -0,0 +1,118 @@
package middlewares
// Middleware tests based on https://github.com/unrolled/secure
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("bar"))
})
// newHeader constructs a new header instance with supplied options.
func newHeader(options ...HeaderOptions) *HeaderStruct {
var opt HeaderOptions
if len(options) == 0 {
opt = HeaderOptions{}
} else {
opt = options[0]
}
return &HeaderStruct{opt: opt}
}
func TestNoConfig(t *testing.T) {
header := newHeader()
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
header.ServeHTTP(res, req, myHandler)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "bar", res.Body.String(), "Body not the expected")
}
func TestModifyResponseHeaders(t *testing.T) {
header := newHeader(HeaderOptions{
CustomResponseHeaders: map[string]string{
"X-Custom-Response-Header": "test_response",
},
})
res := httptest.NewRecorder()
res.HeaderMap.Add("X-Custom-Response-Header", "test_response")
err := header.ModifyResponseHeaders(res.Result())
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "test_response", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
res = httptest.NewRecorder()
res.HeaderMap.Add("X-Custom-Response-Header", "")
err = header.ModifyResponseHeaders(res.Result())
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
res = httptest.NewRecorder()
res.HeaderMap.Add("X-Custom-Response-Header", "test_override")
err = header.ModifyResponseHeaders(res.Result())
assert.NoError(t, err)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "test_override", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
}
func TestCustomRequestHeader(t *testing.T) {
header := newHeader(HeaderOptions{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req, nil)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header")
}
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
header := newHeader(HeaderOptions{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "test_request",
},
})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
header.ServeHTTP(res, req, nil)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header")
header = newHeader(HeaderOptions{
CustomRequestHeaders: map[string]string{
"X-Custom-Request-Header": "",
},
})
header.ServeHTTP(res, req, nil)
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"), "This header is not expected")
}

View file

@ -0,0 +1,67 @@
package middlewares
import (
"fmt"
"net/http"
"github.com/containous/traefik/ip"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/middlewares/tracing"
"github.com/pkg/errors"
"github.com/urfave/negroni"
)
// IPWhiteLister is a middleware that provides Checks of the Requesting IP against a set of Whitelists
type IPWhiteLister struct {
handler negroni.Handler
whiteLister *ip.Checker
strategy ip.Strategy
}
// NewIPWhiteLister builds a new IPWhiteLister given a list of CIDR-Strings to whitelist
func NewIPWhiteLister(whiteList []string, strategy ip.Strategy) (*IPWhiteLister, error) {
if len(whiteList) == 0 {
return nil, errors.New("no white list provided")
}
checker, err := ip.NewChecker(whiteList)
if err != nil {
return nil, fmt.Errorf("parsing CIDR whitelist %s: %v", whiteList, err)
}
whiteLister := IPWhiteLister{
strategy: strategy,
whiteLister: checker,
}
whiteLister.handler = negroni.HandlerFunc(whiteLister.handle)
log.Debugf("configured IP white list: %s", whiteList)
return &whiteLister, nil
}
func (wl *IPWhiteLister) handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
err := wl.whiteLister.IsAuthorized(wl.strategy.GetIP(r))
if err != nil {
tracing.SetErrorAndDebugLog(r, "request %+v - rejecting: %v", r, err)
reject(w)
return
}
log.Debugf("Accept %s: %+v", wl.strategy.GetIP(r), r)
tracing.SetErrorAndDebugLog(r, "request %+v matched white list %v - passing", r, wl.whiteLister)
next.ServeHTTP(w, r)
}
func (wl *IPWhiteLister) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
wl.handler.ServeHTTP(rw, r, next)
}
func reject(w http.ResponseWriter) {
statusCode := http.StatusForbidden
w.WriteHeader(statusCode)
_, err := w.Write([]byte(http.StatusText(statusCode)))
if err != nil {
log.Error(err)
}
}

View file

@ -0,0 +1,92 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/ip"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewIPWhiteLister(t *testing.T) {
testCases := []struct {
desc string
whiteList []string
expectedError string
}{
{
desc: "invalid IP",
whiteList: []string{"foo"},
expectedError: "parsing CIDR whitelist [foo]: parsing CIDR trusted IPs <nil>: invalid CIDR address: foo",
},
{
desc: "valid IP",
whiteList: []string{"10.10.10.10"},
expectedError: "",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
whiteLister, err := NewIPWhiteLister(test.whiteList, &ip.RemoteAddrStrategy{})
if len(test.expectedError) > 0 {
assert.EqualError(t, err, test.expectedError)
} else {
require.NoError(t, err)
assert.NotNil(t, whiteLister)
}
})
}
}
func TestIPWhiteLister_ServeHTTP(t *testing.T) {
testCases := []struct {
desc string
whiteList []string
remoteAddr string
expected int
}{
{
desc: "authorized with remote address",
whiteList: []string{"20.20.20.20"},
remoteAddr: "20.20.20.20:1234",
expected: 200,
},
{
desc: "non authorized with remote address",
whiteList: []string{"20.20.20.20"},
remoteAddr: "20.20.20.21:1234",
expected: 403,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
whiteLister, err := NewIPWhiteLister(test.whiteList, &ip.RemoteAddrStrategy{})
require.NoError(t, err)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://10.10.10.10", nil)
if len(test.remoteAddr) > 0 {
req.RemoteAddr = test.remoteAddr
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
whiteLister.ServeHTTP(recorder, req, next)
assert.Equal(t, test.expected, recorder.Code)
})
}
}

View file

@ -0,0 +1,62 @@
package pipelining
import (
"bufio"
"net"
"net/http"
)
// Pipelining returns a middleware
type Pipelining struct {
next http.Handler
}
// NewPipelining returns a new Pipelining instance
func NewPipelining(next http.Handler) *Pipelining {
return &Pipelining{
next: next,
}
}
func (p *Pipelining) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// https://github.com/golang/go/blob/3d59583836630cf13ec4bfbed977d27b1b7adbdc/src/net/http/server.go#L201-L218
if r.Method == http.MethodPut || r.Method == http.MethodPost {
p.next.ServeHTTP(rw, r)
} else {
p.next.ServeHTTP(&writerWithoutCloseNotify{rw}, r)
}
}
// writerWithoutCloseNotify helps to disable closeNotify
type writerWithoutCloseNotify struct {
W http.ResponseWriter
}
// Header returns the response headers.
func (w *writerWithoutCloseNotify) Header() http.Header {
return w.W.Header()
}
// Write writes the data to the connection as part of an HTTP reply.
func (w *writerWithoutCloseNotify) Write(buf []byte) (int, error) {
return w.W.Write(buf)
}
// WriteHeader sends an HTTP response header with the provided
// status code.
func (w *writerWithoutCloseNotify) WriteHeader(code int) {
w.W.WriteHeader(code)
}
// Flush sends any buffered data to the client.
func (w *writerWithoutCloseNotify) Flush() {
if f, ok := w.W.(http.Flusher); ok {
f.Flush()
}
}
// Hijack hijacks the connection.
func (w *writerWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return w.W.(http.Hijacker).Hijack()
}

View file

@ -0,0 +1,69 @@
package pipelining
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
type recorderWithCloseNotify struct {
*httptest.ResponseRecorder
}
func (r *recorderWithCloseNotify) CloseNotify() <-chan bool {
panic("implement me")
}
func TestNewPipelining(t *testing.T) {
testCases := []struct {
desc string
HTTPMethod string
implementCloseNotifier bool
}{
{
desc: "should not implement CloseNotifier with GET method",
HTTPMethod: http.MethodGet,
implementCloseNotifier: false,
},
{
desc: "should implement CloseNotifier with PUT method",
HTTPMethod: http.MethodPut,
implementCloseNotifier: true,
},
{
desc: "should implement CloseNotifier with POST method",
HTTPMethod: http.MethodPost,
implementCloseNotifier: true,
},
{
desc: "should not implement CloseNotifier with GET method",
HTTPMethod: http.MethodHead,
implementCloseNotifier: false,
},
{
desc: "should not implement CloseNotifier with PROPFIND method",
HTTPMethod: "PROPFIND",
implementCloseNotifier: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, ok := w.(http.CloseNotifier)
assert.Equal(t, test.implementCloseNotifier, ok)
w.WriteHeader(http.StatusOK)
})
handler := NewPipelining(nextHandler)
req := httptest.NewRequest(test.HTTPMethod, "http://localhost", nil)
handler.ServeHTTP(&recorderWithCloseNotify{httptest.NewRecorder()}, req)
})
}
}

View file

@ -0,0 +1,51 @@
package middlewares
import (
"net/http"
"runtime"
"github.com/containous/traefik/old/log"
"github.com/urfave/negroni"
)
// RecoverHandler recovers from a panic in http handlers
func RecoverHandler(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
defer recoverFunc(w, r)
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
// NegroniRecoverHandler recovers from a panic in negroni handlers
func NegroniRecoverHandler() negroni.Handler {
fn := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
defer recoverFunc(w, r)
next.ServeHTTP(w, r)
}
return negroni.HandlerFunc(fn)
}
func recoverFunc(w http.ResponseWriter, r *http.Request) {
if err := recover(); err != nil {
if !shouldLogPanic(err) {
log.Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err)
return
}
log.Errorf("Recovered from panic in HTTP handler [%s - %s]: %+v", r.RemoteAddr, r.URL, err)
const size = 64 << 10
buf := make([]byte, size)
buf = buf[:runtime.Stack(buf, false)]
log.Errorf("Stack: %s", buf)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}
// https://github.com/golang/go/blob/a0d6420d8be2ae7164797051ec74fa2a2df466a1/src/net/http/server.go#L1761-L1775
// https://github.com/golang/go/blob/c33153f7b416c03983324b3e8f869ce1116d84bc/src/net/http/httputil/reverseproxy.go#L284
func shouldLogPanic(panicValue interface{}) bool {
return panicValue != nil && panicValue != http.ErrAbortHandler
}

View file

@ -0,0 +1,45 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/urfave/negroni"
)
func TestRecoverHandler(t *testing.T) {
fn := func(w http.ResponseWriter, r *http.Request) {
panic("I love panicing!")
}
recoverHandler := RecoverHandler(http.HandlerFunc(fn))
server := httptest.NewServer(recoverHandler)
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
}
}
func TestNegroniRecoverHandler(t *testing.T) {
n := negroni.New()
n.Use(NegroniRecoverHandler())
panicHandler := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
panic("I love panicing!")
}
n.UseFunc(negroni.HandlerFunc(panicHandler))
server := httptest.NewServer(n)
defer server.Close()
resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusInternalServerError {
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
}
}

View file

@ -0,0 +1,163 @@
package redirect
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"regexp"
"strings"
"text/template"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/old/middlewares"
"github.com/urfave/negroni"
"github.com/vulcand/oxy/utils"
)
const (
defaultRedirectRegex = `^(?:https?:\/\/)?([\w\._-]+)(?::\d+)?(.*)$`
)
// NewEntryPointHandler create a new redirection handler base on entry point
func NewEntryPointHandler(dstEntryPoint *configuration.EntryPoint, permanent bool) (negroni.Handler, error) {
exp := regexp.MustCompile(`(:\d+)`)
match := exp.FindStringSubmatch(dstEntryPoint.Address)
if len(match) == 0 {
return nil, fmt.Errorf("bad Address format %q", dstEntryPoint.Address)
}
protocol := "http"
if dstEntryPoint.TLS != nil {
protocol = "https"
}
replacement := protocol + "://${1}" + match[0] + "${2}"
return NewRegexHandler(defaultRedirectRegex, replacement, permanent)
}
// NewRegexHandler create a new redirection handler base on regex
func NewRegexHandler(exp string, replacement string, permanent bool) (negroni.Handler, error) {
re, err := regexp.Compile(exp)
if err != nil {
return nil, err
}
return &handler{
regexp: re,
replacement: replacement,
permanent: permanent,
errHandler: utils.DefaultHandler,
}, nil
}
type handler struct {
regexp *regexp.Regexp
replacement string
permanent bool
errHandler utils.ErrorHandler
}
func (h *handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
oldURL := rawURL(req)
// only continue if the Regexp param matches the URL
if !h.regexp.MatchString(oldURL) {
next.ServeHTTP(rw, req)
return
}
// apply a rewrite regexp to the URL
newURL := h.regexp.ReplaceAllString(oldURL, h.replacement)
// replace any variables that may be in there
rewrittenURL := &bytes.Buffer{}
if err := applyString(newURL, rewrittenURL, req); err != nil {
h.errHandler.ServeHTTP(rw, req, err)
return
}
// parse the rewritten URL and replace request URL with it
parsedURL, err := url.Parse(rewrittenURL.String())
if err != nil {
h.errHandler.ServeHTTP(rw, req, err)
return
}
if stripPrefix, stripPrefixOk := req.Context().Value(middlewares.StripPrefixKey).(string); stripPrefixOk {
if len(stripPrefix) > 0 {
parsedURL.Path = stripPrefix
}
}
if addPrefix, addPrefixOk := req.Context().Value(middlewares.AddPrefixKey).(string); addPrefixOk {
if len(addPrefix) > 0 {
parsedURL.Path = strings.Replace(parsedURL.Path, addPrefix, "", 1)
}
}
if replacePath, replacePathOk := req.Context().Value(middlewares.ReplacePathKey).(string); replacePathOk {
if len(replacePath) > 0 {
parsedURL.Path = replacePath
}
}
if newURL != oldURL {
handler := &moveHandler{location: parsedURL, permanent: h.permanent}
handler.ServeHTTP(rw, req)
return
}
req.URL = parsedURL
// make sure the request URI corresponds the rewritten URL
req.RequestURI = req.URL.RequestURI()
next.ServeHTTP(rw, req)
}
type moveHandler struct {
location *url.URL
permanent bool
}
func (m *moveHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Location", m.location.String())
status := http.StatusFound
if m.permanent {
status = http.StatusMovedPermanently
}
rw.WriteHeader(status)
rw.Write([]byte(http.StatusText(status)))
}
func rawURL(request *http.Request) string {
scheme := "http"
if request.TLS != nil || isXForwardedHTTPS(request) {
scheme = "https"
}
return strings.Join([]string{scheme, "://", request.Host, request.RequestURI}, "")
}
func isXForwardedHTTPS(request *http.Request) bool {
xForwardedProto := request.Header.Get("X-Forwarded-Proto")
return len(xForwardedProto) > 0 && xForwardedProto == "https"
}
func applyString(in string, out io.Writer, request *http.Request) error {
t, err := template.New("t").Parse(in)
if err != nil {
return err
}
data := struct {
Request *http.Request
}{
Request: request,
}
return t.Execute(out, data)
}

View file

@ -0,0 +1,182 @@
package redirect
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/old/configuration"
"github.com/containous/traefik/testhelpers"
"github.com/containous/traefik/tls"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewEntryPointHandler(t *testing.T) {
testCases := []struct {
desc string
entryPoint *configuration.EntryPoint
permanent bool
url string
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "HTTP to HTTPS",
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS to HTTP",
entryPoint: &configuration.EntryPoint{Address: ":80"},
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP",
entryPoint: &configuration.EntryPoint{Address: ":88"},
url: "http://foo:80",
expectedURL: "http://foo:88",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTPS permanent",
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
permanent: true,
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTPS to HTTP permanent",
entryPoint: &configuration.EntryPoint{Address: ":80"},
permanent: true,
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "invalid address",
entryPoint: &configuration.EntryPoint{Address: ":foo", TLS: &tls.TLS{}},
url: "http://foo:80",
errorExpected: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := NewEntryPointHandler(test.entryPoint, test.permanent)
if test.errorExpected {
require.Error(t, err)
} else {
require.NoError(t, err)
recorder := httptest.NewRecorder()
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
handler.ServeHTTP(recorder, r, nil)
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
assert.Equal(t, test.expectedStatus, recorder.Code)
}
})
}
}
func TestNewRegexHandler(t *testing.T) {
testCases := []struct {
desc string
regex string
replacement string
permanent bool
url string
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "simple redirection",
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
replacement: "https://${1}bar$2:443$4",
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "use request header",
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "URL doesn't match regex",
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
replacement: "https://${1}bar$2:443$4",
url: "http://bar.com:80",
expectedStatus: http.StatusOK,
},
{
desc: "invalid rewritten URL",
regex: `^(.*)$`,
replacement: "http://192.168.0.%31/",
url: "http://foo.com:80",
expectedStatus: http.StatusBadGateway,
},
{
desc: "invalid regex",
regex: `^(.*`,
replacement: "$1",
url: "http://foo.com:80",
errorExpected: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := NewRegexHandler(test.regex, test.replacement, test.permanent)
if test.errorExpected {
require.Nil(t, handler)
require.Error(t, err)
} else {
require.NotNil(t, handler)
require.NoError(t, err)
recorder := httptest.NewRecorder()
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
r.Header.Set("X-Foo", "bar")
next := func(rw http.ResponseWriter, req *http.Request) {}
handler.ServeHTTP(recorder, r, next)
if test.expectedStatus == http.StatusMovedPermanently || test.expectedStatus == http.StatusFound {
assert.Equal(t, test.expectedStatus, recorder.Code)
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
} else {
assert.Equal(t, test.expectedStatus, recorder.Code)
location, err := recorder.Result().Location()
require.Errorf(t, err, "Location %v", location)
}
}
})
}
}

View file

@ -0,0 +1,28 @@
package middlewares
import (
"context"
"net/http"
)
const (
// ReplacePathKey is the key within the request context used to
// store the replaced path
ReplacePathKey key = "ReplacePath"
// ReplacedPathHeader is the default header to set the old path to
ReplacedPathHeader = "X-Replaced-Path"
)
// ReplacePath is a middleware used to replace the path of a URL request
type ReplacePath struct {
Handler http.Handler
Path string
}
func (s *ReplacePath) ServeHTTP(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), ReplacePathKey, r.URL.Path))
r.Header.Add(ReplacedPathHeader, r.URL.Path)
r.URL.Path = s.Path
r.RequestURI = r.URL.RequestURI()
s.Handler.ServeHTTP(w, r)
}

View file

@ -0,0 +1,40 @@
package middlewares
import (
"context"
"net/http"
"regexp"
"strings"
"github.com/containous/traefik/old/log"
)
// ReplacePathRegex is a middleware used to replace the path of a URL request with a regular expression
type ReplacePathRegex struct {
Handler http.Handler
Regexp *regexp.Regexp
Replacement string
}
// NewReplacePathRegexHandler returns a new ReplacePathRegex
func NewReplacePathRegexHandler(regex string, replacement string, handler http.Handler) http.Handler {
exp, err := regexp.Compile(strings.TrimSpace(regex))
if err != nil {
log.Errorf("Error compiling regular expression %s: %s", regex, err)
}
return &ReplacePathRegex{
Regexp: exp,
Replacement: strings.TrimSpace(replacement),
Handler: handler,
}
}
func (s *ReplacePathRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.Regexp != nil && len(s.Replacement) > 0 && s.Regexp.MatchString(r.URL.Path) {
r = r.WithContext(context.WithValue(r.Context(), ReplacePathKey, r.URL.Path))
r.Header.Add(ReplacedPathHeader, r.URL.Path)
r.URL.Path = s.Regexp.ReplaceAllString(r.URL.Path, s.Replacement)
r.RequestURI = r.URL.RequestURI()
}
s.Handler.ServeHTTP(w, r)
}

View file

@ -0,0 +1,80 @@
package middlewares
import (
"net/http"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestReplacePathRegex(t *testing.T) {
testCases := []struct {
desc string
path string
replacement string
regex string
expectedPath string
expectedHeader string
}{
{
desc: "simple regex",
path: "/whoami/and/whoami",
replacement: "/who-am-i/$1",
regex: `^/whoami/(.*)`,
expectedPath: "/who-am-i/and/whoami",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "simple replace (no regex)",
path: "/whoami/and/whoami",
replacement: "/who-am-i",
regex: `/whoami`,
expectedPath: "/who-am-i/and/who-am-i",
expectedHeader: "/whoami/and/whoami",
},
{
desc: "multiple replacement",
path: "/downloads/src/source.go",
replacement: "/downloads/$1-$2",
regex: `^(?i)/downloads/([^/]+)/([^/]+)$`,
expectedPath: "/downloads/src-source.go",
expectedHeader: "/downloads/src/source.go",
},
{
desc: "invalid regular expression",
path: "/invalid/regexp/test",
replacement: "/valid/regexp/$1",
regex: `^(?err)/invalid/regexp/([^/]+)$`,
expectedPath: "/invalid/regexp/test",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualHeader, requestURI string
handler := NewReplacePathRegexHandler(
test.regex,
test.replacement,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualHeader = r.Header.Get(ReplacedPathHeader)
requestURI = r.RequestURI
}),
)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
handler.ServeHTTP(nil, req)
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
if test.expectedHeader != "" {
assert.Equal(t, actualPath, requestURI, "Unexpected request URI.")
}
})
}
}

View file

@ -0,0 +1,41 @@
package middlewares
import (
"net/http"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestReplacePath(t *testing.T) {
const replacementPath = "/replacement-path"
paths := []string{
"/example",
"/some/really/long/path",
}
for _, path := range paths {
t.Run(path, func(t *testing.T) {
var expectedPath, actualHeader, requestURI string
handler := &ReplacePath{
Path: replacementPath,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
expectedPath = r.URL.Path
actualHeader = r.Header.Get(ReplacedPathHeader)
requestURI = r.RequestURI
}),
}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+path, nil)
handler.ServeHTTP(nil, req)
assert.Equal(t, expectedPath, replacementPath, "Unexpected path.")
assert.Equal(t, path, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
assert.Equal(t, expectedPath, requestURI, "Unexpected request URI.")
})
}
}

View file

@ -0,0 +1,45 @@
package middlewares
import (
"context"
"net"
"net/http"
"strings"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/types"
)
var requestHostKey struct{}
// RequestHost is the struct for the middleware that adds the CanonicalDomain of the request Host into a context for later use.
type RequestHost struct{}
func (rh *RequestHost) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
if next != nil {
host := types.CanonicalDomain(parseHost(r.Host))
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), requestHostKey, host)))
}
}
func parseHost(addr string) string {
if !strings.Contains(addr, ":") {
return addr
}
host, _, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
return host
}
// GetCanonizedHost plucks the canonized host key from the request of a context that was put through the middleware
func GetCanonizedHost(ctx context.Context) string {
if val, ok := ctx.Value(requestHostKey).(string); ok {
return val
}
log.Warn("RequestHost is missing in the middleware chain")
return ""
}

View file

@ -0,0 +1,94 @@
package middlewares
import (
"net/http"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestRequestHost(t *testing.T) {
testCases := []struct {
desc string
url string
expected string
}{
{
desc: "host without :",
url: "http://host",
expected: "host",
},
{
desc: "host with : and without port",
url: "http://host:",
expected: "host",
},
{
desc: "IP host with : and with port",
url: "http://127.0.0.1:123",
expected: "127.0.0.1",
},
{
desc: "IP host with : and without port",
url: "http://127.0.0.1:",
expected: "127.0.0.1",
},
}
rh := &RequestHost{}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
rh.ServeHTTP(nil, req, func(_ http.ResponseWriter, r *http.Request) {
host := GetCanonizedHost(r.Context())
assert.Equal(t, test.expected, host)
})
})
}
}
func TestRequestHostParseHost(t *testing.T) {
testCases := []struct {
desc string
host string
expected string
}{
{
desc: "host without :",
host: "host",
expected: "host",
},
{
desc: "host with : and without port",
host: "host:",
expected: "host",
},
{
desc: "IP host with : and with port",
host: "127.0.0.1:123",
expected: "127.0.0.1",
},
{
desc: "IP host with : and without port",
host: "127.0.0.1:",
expected: "127.0.0.1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := parseHost(test.host)
assert.Equal(t, test.expected, actual)
})
}
}

173
old/middlewares/retry.go Normal file
View file

@ -0,0 +1,173 @@
package middlewares
import (
"bufio"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptrace"
"github.com/containous/traefik/old/log"
)
// Compile time validation that the response writer implements http interfaces correctly.
var _ Stateful = &retryResponseWriterWithCloseNotify{}
// Retry is a middleware that retries requests
type Retry struct {
attempts int
next http.Handler
listener RetryListener
}
// NewRetry returns a new Retry instance
func NewRetry(attempts int, next http.Handler, listener RetryListener) *Retry {
return &Retry{
attempts: attempts,
next: next,
listener: listener,
}
}
func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// if we might make multiple attempts, swap the body for an ioutil.NopCloser
// cf https://github.com/containous/traefik/issues/1008
if retry.attempts > 1 {
body := r.Body
if body == nil {
body = http.NoBody
}
defer body.Close()
r.Body = ioutil.NopCloser(body)
}
attempts := 1
for {
attemptsExhausted := attempts >= retry.attempts
shouldRetry := !attemptsExhausted
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)
// Disable retries when the backend already received request data
trace := &httptrace.ClientTrace{
WroteHeaders: func() {
retryResponseWriter.DisableRetries()
},
WroteRequest: func(httptrace.WroteRequestInfo) {
retryResponseWriter.DisableRetries()
},
}
newCtx := httptrace.WithClientTrace(r.Context(), trace)
retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx))
if !retryResponseWriter.ShouldRetry() {
break
}
attempts++
log.Debugf("New attempt %d for request: %v", attempts, r.URL)
retry.listener.Retried(r, attempts)
}
}
// RetryListener is used to inform about retry attempts.
type RetryListener interface {
// Retried will be called when a retry happens, with the request attempt passed to it.
// For the first retry this will be attempt 2.
Retried(req *http.Request, attempt int)
}
// RetryListeners is a convenience type to construct a list of RetryListener and notify
// each of them about a retry attempt.
type RetryListeners []RetryListener
// Retried exists to implement the RetryListener interface. It calls Retried on each of its slice entries.
func (l RetryListeners) Retried(req *http.Request, attempt int) {
for _, retryListener := range l {
retryListener.Retried(req, attempt)
}
}
type retryResponseWriter interface {
http.ResponseWriter
http.Flusher
ShouldRetry() bool
DisableRetries()
}
func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
responseWriter := &retryResponseWriterWithoutCloseNotify{
responseWriter: rw,
shouldRetry: shouldRetry,
}
if _, ok := rw.(http.CloseNotifier); ok {
return &retryResponseWriterWithCloseNotify{responseWriter}
}
return responseWriter
}
type retryResponseWriterWithoutCloseNotify struct {
responseWriter http.ResponseWriter
shouldRetry bool
}
func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool {
return rr.shouldRetry
}
func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() {
rr.shouldRetry = false
}
func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
if rr.ShouldRetry() {
return make(http.Header)
}
return rr.responseWriter.Header()
}
func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
if rr.ShouldRetry() {
return len(buf), nil
}
return rr.responseWriter.Write(buf)
}
func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
if rr.ShouldRetry() && code == http.StatusServiceUnavailable {
// We get a 503 HTTP Status Code when there is no backend server in the pool
// to which the request could be sent. Also, note that rr.ShouldRetry()
// will never return true in case there was a connection established to
// the backend server and so we can be sure that the 503 was produced
// inside Traefik already and we don't have to retry in this cases.
rr.DisableRetries()
}
if rr.ShouldRetry() {
return
}
rr.responseWriter.WriteHeader(code)
}
func (rr *retryResponseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rr.responseWriter.(http.Hijacker)
if !ok {
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", rr.responseWriter)
}
return hijacker.Hijack()
}
func (rr *retryResponseWriterWithoutCloseNotify) Flush() {
if flusher, ok := rr.responseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
type retryResponseWriterWithCloseNotify struct {
*retryResponseWriterWithoutCloseNotify
}
func (rr *retryResponseWriterWithCloseNotify) CloseNotify() <-chan bool {
return rr.responseWriter.(http.CloseNotifier).CloseNotify()
}

View file

@ -0,0 +1,260 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/roundrobin"
)
func TestRetry(t *testing.T) {
testCases := []struct {
desc string
maxRequestAttempts int
wantRetryAttempts int
wantResponseStatus int
amountFaultyEndpoints int
}{
{
desc: "no retry on success",
maxRequestAttempts: 1,
wantRetryAttempts: 0,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 0,
},
{
desc: "no retry when max request attempts is one",
maxRequestAttempts: 1,
wantRetryAttempts: 0,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 1,
},
{
desc: "one retry when one server is faulty",
maxRequestAttempts: 2,
wantRetryAttempts: 1,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 1,
},
{
desc: "two retries when two servers are faulty",
maxRequestAttempts: 3,
wantRetryAttempts: 2,
wantResponseStatus: http.StatusOK,
amountFaultyEndpoints: 2,
},
{
desc: "max attempts exhausted delivers the 5xx response",
maxRequestAttempts: 3,
wantRetryAttempts: 2,
wantResponseStatus: http.StatusInternalServerError,
amountFaultyEndpoints: 3,
},
}
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
rw.Write([]byte("OK"))
}))
forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %s", err)
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
loadBalancer, err := roundrobin.New(forwarder)
if err != nil {
t.Fatalf("Error creating load balancer: %s", err)
}
basePort := 33444
for i := 0; i < test.amountFaultyEndpoints; i++ {
// 192.0.2.0 is a non-routable IP for testing purposes.
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
err = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
assert.NoError(t, err)
}
// add the functioning server to the end of the load balancer list
err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
assert.NoError(t, err)
retryListener := &countingRetryListener{}
retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
retry.ServeHTTP(recorder, req)
assert.Equal(t, test.wantResponseStatus, recorder.Code)
assert.Equal(t, test.wantRetryAttempts, retryListener.timesCalled)
})
}
}
func TestRetryWebsocket(t *testing.T) {
testCases := []struct {
desc string
maxRequestAttempts int
expectedRetryAttempts int
expectedResponseStatus int
expectedError bool
amountFaultyEndpoints int
}{
{
desc: "Switching ok after 2 retries",
maxRequestAttempts: 3,
expectedRetryAttempts: 2,
amountFaultyEndpoints: 2,
expectedResponseStatus: http.StatusSwitchingProtocols,
},
{
desc: "Switching failed",
maxRequestAttempts: 2,
expectedRetryAttempts: 1,
amountFaultyEndpoints: 2,
expectedResponseStatus: http.StatusBadGateway,
expectedError: true,
},
}
forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %s", err)
}
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
upgrader := websocket.Upgrader{}
upgrader.Upgrade(rw, req, nil)
}))
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
loadBalancer, err := roundrobin.New(forwarder)
if err != nil {
t.Fatalf("Error creating load balancer: %s", err)
}
basePort := 33444
for i := 0; i < test.amountFaultyEndpoints; i++ {
// 192.0.2.0 is a non-routable IP for testing purposes.
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
// We only use the port specification here because the URL is used as identifier
// in the load balancer and using the exact same URL would not add a new server.
loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + string(basePort+i)))
}
// add the functioning server to the end of the load balancer list
loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
retryListener := &countingRetryListener{}
retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener)
retryServer := httptest.NewServer(retry)
url := strings.Replace(retryServer.URL, "http", "ws", 1)
_, response, err := websocket.DefaultDialer.Dial(url, nil)
if !test.expectedError {
require.NoError(t, err)
}
assert.Equal(t, test.expectedResponseStatus, response.StatusCode)
assert.Equal(t, test.expectedRetryAttempts, retryListener.timesCalled)
})
}
}
func TestRetryEmptyServerList(t *testing.T) {
forwarder, err := forward.New()
if err != nil {
t.Fatalf("Error creating forwarder: %s", err)
}
loadBalancer, err := roundrobin.New(forwarder)
if err != nil {
t.Fatalf("Error creating load balancer: %s", err)
}
// The EmptyBackendHandler middleware ensures that there is a 503
// response status set when there is no backend server in the pool.
next := NewEmptyBackendHandler(loadBalancer)
retryListener := &countingRetryListener{}
retry := NewRetry(3, next, retryListener)
recorder := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
retry.ServeHTTP(recorder, req)
const wantResponseStatus = http.StatusServiceUnavailable
if wantResponseStatus != recorder.Code {
t.Errorf("got status code %d, want %d", recorder.Code, wantResponseStatus)
}
const wantRetryAttempts = 0
if wantRetryAttempts != retryListener.timesCalled {
t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, wantRetryAttempts)
}
}
func TestRetryListeners(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
retryListeners := RetryListeners{&countingRetryListener{}, &countingRetryListener{}}
retryListeners.Retried(req, 1)
retryListeners.Retried(req, 1)
for _, retryListener := range retryListeners {
listener := retryListener.(*countingRetryListener)
if listener.timesCalled != 2 {
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
}
}
}
// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called.
type countingRetryListener struct {
timesCalled int
}
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
l.timesCalled++
}
func TestRetryWithFlush(t *testing.T) {
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(200)
rw.Write([]byte("FULL "))
rw.(http.Flusher).Flush()
rw.Write([]byte("DATA"))
})
retry := NewRetry(1, next, &countingRetryListener{})
responseRecorder := httptest.NewRecorder()
retry.ServeHTTP(responseRecorder, &http.Request{})
if responseRecorder.Body.String() != "FULL DATA" {
t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA")
}
}

28
old/middlewares/routes.go Normal file
View file

@ -0,0 +1,28 @@
package middlewares
import (
"encoding/json"
"log"
"net/http"
"github.com/containous/mux"
)
// Routes holds the gorilla mux routes (for the API & co).
type Routes struct {
router *mux.Router
}
// NewRoutes return a Routes based on the given router.
func NewRoutes(router *mux.Router) *Routes {
return &Routes{router}
}
func (router *Routes) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
routeMatch := mux.RouteMatch{}
if router.router.Match(r, &routeMatch) {
rt, _ := json.Marshal(routeMatch.Handler)
log.Println("Request match route ", rt)
}
next(rw, r)
}

36
old/middlewares/secure.go Normal file
View file

@ -0,0 +1,36 @@
package middlewares
import (
"github.com/containous/traefik/old/types"
"github.com/unrolled/secure"
)
// NewSecure constructs a new Secure instance with supplied options.
func NewSecure(headers *types.Headers) *secure.Secure {
if headers == nil || !headers.HasSecureHeadersDefined() {
return nil
}
opt := secure.Options{
AllowedHosts: headers.AllowedHosts,
HostsProxyHeaders: headers.HostsProxyHeaders,
SSLRedirect: headers.SSLRedirect,
SSLTemporaryRedirect: headers.SSLTemporaryRedirect,
SSLHost: headers.SSLHost,
SSLProxyHeaders: headers.SSLProxyHeaders,
STSSeconds: headers.STSSeconds,
STSIncludeSubdomains: headers.STSIncludeSubdomains,
STSPreload: headers.STSPreload,
ForceSTSHeader: headers.ForceSTSHeader,
FrameDeny: headers.FrameDeny,
CustomFrameOptionsValue: headers.CustomFrameOptionsValue,
ContentTypeNosniff: headers.ContentTypeNosniff,
BrowserXssFilter: headers.BrowserXSSFilter,
CustomBrowserXssValue: headers.CustomBrowserXSSValue,
ContentSecurityPolicy: headers.ContentSecurityPolicy,
PublicKey: headers.PublicKey,
ReferrerPolicy: headers.ReferrerPolicy,
IsDevelopment: headers.IsDevelopment,
}
return secure.New(opt)
}

View file

@ -0,0 +1,12 @@
package middlewares
import "net/http"
// Stateful interface groups all http interfaces that must be
// implemented by a stateful middleware (ie: recorders)
type Stateful interface {
http.ResponseWriter
http.Hijacker
http.Flusher
http.CloseNotifier
}

115
old/middlewares/stats.go Normal file
View file

@ -0,0 +1,115 @@
package middlewares
import (
"bufio"
"net"
"net/http"
"sync"
"time"
)
var (
_ Stateful = &responseRecorder{}
)
// StatsRecorder is an optional middleware that records more details statistics
// about requests and how they are processed. This currently consists of recent
// requests that have caused errors (4xx and 5xx status codes), making it easy
// to pinpoint problems.
type StatsRecorder struct {
mutex sync.RWMutex
numRecentErrors int
recentErrors []*statsError
}
// NewStatsRecorder returns a new StatsRecorder
func NewStatsRecorder(numRecentErrors int) *StatsRecorder {
return &StatsRecorder{
numRecentErrors: numRecentErrors,
}
}
// Stats includes all of the stats gathered by the recorder.
type Stats struct {
RecentErrors []*statsError `json:"recent_errors"`
}
// statsError represents an error that has occurred during request processing.
type statsError struct {
StatusCode int `json:"status_code"`
Status string `json:"status"`
Method string `json:"method"`
Host string `json:"host"`
Path string `json:"path"`
Time time.Time `json:"time"`
}
// responseRecorder captures information from the response and preserves it for
// later analysis.
type responseRecorder struct {
http.ResponseWriter
statusCode int
}
// WriteHeader captures the status code for later retrieval.
func (r *responseRecorder) WriteHeader(status int) {
r.ResponseWriter.WriteHeader(status)
r.statusCode = status
}
// Hijack hijacks the connection
func (r *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return r.ResponseWriter.(http.Hijacker).Hijack()
}
// CloseNotify returns a channel that receives at most a
// single value (true) when the client connection has gone
// away.
func (r *responseRecorder) CloseNotify() <-chan bool {
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// Flush sends any buffered data to the client.
func (r *responseRecorder) Flush() {
r.ResponseWriter.(http.Flusher).Flush()
}
// ServeHTTP silently extracts information from the request and response as it
// is processed. If the response is 4xx or 5xx, add it to the list of 10 most
// recent errors.
func (s *StatsRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
recorder := &responseRecorder{w, http.StatusOK}
next(recorder, r)
if recorder.statusCode >= http.StatusBadRequest {
s.mutex.Lock()
defer s.mutex.Unlock()
s.recentErrors = append([]*statsError{
{
StatusCode: recorder.statusCode,
Status: http.StatusText(recorder.statusCode),
Method: r.Method,
Host: r.Host,
Path: r.URL.Path,
Time: time.Now(),
},
}, s.recentErrors...)
// Limit the size of the list to numRecentErrors
if len(s.recentErrors) > s.numRecentErrors {
s.recentErrors = s.recentErrors[:s.numRecentErrors]
}
}
}
// Data returns a copy of the statistics that have been gathered.
func (s *StatsRecorder) Data() *Stats {
s.mutex.RLock()
defer s.mutex.RUnlock()
// We can't return the slice directly or a race condition might develop
recentErrors := make([]*statsError, len(s.recentErrors))
copy(recentErrors, s.recentErrors)
return &Stats{
RecentErrors: recentErrors,
}
}

View file

@ -0,0 +1,56 @@
package middlewares
import (
"context"
"net/http"
"strings"
)
const (
// StripPrefixKey is the key within the request context used to
// store the stripped prefix
StripPrefixKey key = "StripPrefix"
// ForwardedPrefixHeader is the default header to set prefix
ForwardedPrefixHeader = "X-Forwarded-Prefix"
)
// StripPrefix is a middleware used to strip prefix from an URL request
type StripPrefix struct {
Handler http.Handler
Prefixes []string
}
func (s *StripPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for _, prefix := range s.Prefixes {
if strings.HasPrefix(r.URL.Path, prefix) {
rawReqPath := r.URL.Path
r.URL.Path = stripPrefix(r.URL.Path, prefix)
if r.URL.RawPath != "" {
r.URL.RawPath = stripPrefix(r.URL.RawPath, prefix)
}
s.serveRequest(w, r, strings.TrimSpace(prefix), rawReqPath)
return
}
}
http.NotFound(w, r)
}
func (s *StripPrefix) serveRequest(w http.ResponseWriter, r *http.Request, prefix string, rawReqPath string) {
r = r.WithContext(context.WithValue(r.Context(), StripPrefixKey, rawReqPath))
r.Header.Add(ForwardedPrefixHeader, prefix)
r.RequestURI = r.URL.RequestURI()
s.Handler.ServeHTTP(w, r)
}
// SetHandler sets handler
func (s *StripPrefix) SetHandler(Handler http.Handler) {
s.Handler = Handler
}
func stripPrefix(s, prefix string) string {
return ensureLeadingSlash(strings.TrimPrefix(s, prefix))
}
func ensureLeadingSlash(str string) string {
return "/" + strings.TrimPrefix(str, "/")
}

View file

@ -0,0 +1,59 @@
package middlewares
import (
"context"
"net/http"
"github.com/containous/mux"
"github.com/containous/traefik/old/log"
)
// StripPrefixRegex is a middleware used to strip prefix from an URL request
type StripPrefixRegex struct {
Handler http.Handler
router *mux.Router
}
// NewStripPrefixRegex builds a new StripPrefixRegex given a handler and prefixes
func NewStripPrefixRegex(handler http.Handler, prefixes []string) *StripPrefixRegex {
stripPrefix := StripPrefixRegex{Handler: handler, router: mux.NewRouter()}
for _, prefix := range prefixes {
stripPrefix.router.PathPrefix(prefix)
}
return &stripPrefix
}
func (s *StripPrefixRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var match mux.RouteMatch
if s.router.Match(r, &match) {
params := make([]string, 0, len(match.Vars)*2)
for key, val := range match.Vars {
params = append(params, key)
params = append(params, val)
}
prefix, err := match.Route.URL(params...)
if err != nil || len(prefix.Path) > len(r.URL.Path) {
log.Error("Error in stripPrefix middleware", err)
return
}
rawReqPath := r.URL.Path
r.URL.Path = r.URL.Path[len(prefix.Path):]
if r.URL.RawPath != "" {
r.URL.RawPath = r.URL.RawPath[len(prefix.Path):]
}
r = r.WithContext(context.WithValue(r.Context(), StripPrefixKey, rawReqPath))
r.Header.Add(ForwardedPrefixHeader, prefix.Path)
r.RequestURI = ensureLeadingSlash(r.URL.RequestURI())
s.Handler.ServeHTTP(w, r)
return
}
http.NotFound(w, r)
}
// SetHandler sets handler
func (s *StripPrefixRegex) SetHandler(Handler http.Handler) {
s.Handler = Handler
}

View file

@ -0,0 +1,97 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestStripPrefixRegex(t *testing.T) {
testPrefixRegex := []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"}
tests := []struct {
path string
expectedStatusCode int
expectedPath string
expectedRawPath string
expectedHeader string
}{
{
path: "/a/test",
expectedStatusCode: http.StatusNotFound,
},
{
path: "/a/api/test",
expectedStatusCode: http.StatusOK,
expectedPath: "test",
expectedHeader: "/a/api/",
},
{
path: "/b/api/",
expectedStatusCode: http.StatusOK,
expectedHeader: "/b/api/",
},
{
path: "/b/api/test1",
expectedStatusCode: http.StatusOK,
expectedPath: "test1",
expectedHeader: "/b/api/",
},
{
path: "/b/api2/test2",
expectedStatusCode: http.StatusOK,
expectedPath: "test2",
expectedHeader: "/b/api2/",
},
{
path: "/c/api/123/",
expectedStatusCode: http.StatusOK,
expectedHeader: "/c/api/123/",
},
{
path: "/c/api/123/test3",
expectedStatusCode: http.StatusOK,
expectedPath: "test3",
expectedHeader: "/c/api/123/",
},
{
path: "/c/api/abc/test4",
expectedStatusCode: http.StatusNotFound,
},
{
path: "/a/api/a%2Fb",
expectedStatusCode: http.StatusOK,
expectedPath: "a/b",
expectedRawPath: "a%2Fb",
expectedHeader: "/a/api/",
},
}
for _, test := range tests {
test := test
t.Run(test.path, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, actualHeader string
handlerPath := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
actualHeader = r.Header.Get(ForwardedPrefixHeader)
})
handler := NewStripPrefixRegex(handlerPath, testPrefixRegex)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
resp := &httptest.ResponseRecorder{Code: http.StatusOK}
handler.ServeHTTP(resp, req)
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader)
})
}
}

View file

@ -0,0 +1,143 @@
package middlewares
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestStripPrefix(t *testing.T) {
tests := []struct {
desc string
prefixes []string
path string
expectedStatusCode int
expectedPath string
expectedRawPath string
expectedHeader string
}{
{
desc: "no prefixes configured",
prefixes: []string{},
path: "/noprefixes",
expectedStatusCode: http.StatusNotFound,
},
{
desc: "wildcard (.*) requests",
prefixes: []string{"/"},
path: "/",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/",
},
{
desc: "prefix and path matching",
prefixes: []string{"/stat"},
path: "/stat",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat",
},
{
desc: "path prefix on exactly matching path",
prefixes: []string{"/stat/"},
path: "/stat/",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat/",
},
{
desc: "path prefix on matching longer path",
prefixes: []string{"/stat/"},
path: "/stat/us",
expectedStatusCode: http.StatusOK,
expectedPath: "/us",
expectedHeader: "/stat/",
},
{
desc: "path prefix on mismatching path",
prefixes: []string{"/stat/"},
path: "/status",
expectedStatusCode: http.StatusNotFound,
},
{
desc: "general prefix on matching path",
prefixes: []string{"/stat"},
path: "/stat/",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat",
},
{
desc: "earlier prefix matching",
prefixes: []string{"/stat", "/stat/us"},
path: "/stat/us",
expectedStatusCode: http.StatusOK,
expectedPath: "/us",
expectedHeader: "/stat",
},
{
desc: "later prefix matching",
prefixes: []string{"/mismatch", "/stat"},
path: "/stat",
expectedStatusCode: http.StatusOK,
expectedPath: "/",
expectedHeader: "/stat",
},
{
desc: "prefix matching within slash boundaries",
prefixes: []string{"/stat"},
path: "/status",
expectedStatusCode: http.StatusOK,
expectedPath: "/us",
expectedHeader: "/stat",
},
{
desc: "raw path is also stripped",
prefixes: []string{"/stat"},
path: "/stat/a%2Fb",
expectedStatusCode: http.StatusOK,
expectedPath: "/a/b",
expectedRawPath: "/a%2Fb",
expectedHeader: "/stat",
},
}
for _, test := range tests {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var actualPath, actualRawPath, actualHeader, requestURI string
handler := &StripPrefix{
Prefixes: test.prefixes,
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actualPath = r.URL.Path
actualRawPath = r.URL.RawPath
actualHeader = r.Header.Get(ForwardedPrefixHeader)
requestURI = r.RequestURI
}),
}
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
resp := &httptest.ResponseRecorder{Code: http.StatusOK}
handler.ServeHTTP(resp, req)
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader)
expectedURI := test.expectedPath
if test.expectedRawPath != "" {
// go HTTP uses the raw path when existent in the RequestURI
expectedURI = test.expectedRawPath
}
assert.Equal(t, expectedURI, requestURI, "Unexpected request URI.")
})
}
}

View file

@ -0,0 +1,251 @@
package middlewares
import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/types"
)
const xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
const xForwardedTLSClientCertInfos = "X-Forwarded-Tls-Client-Cert-Infos"
// TLSClientCertificateInfos is a struct for specifying the configuration for the tlsClientHeaders middleware.
type TLSClientCertificateInfos struct {
NotAfter bool
NotBefore bool
Subject *TLSCLientCertificateSubjectInfos
Sans bool
}
// TLSCLientCertificateSubjectInfos contains the configuration for the certificate subject infos.
type TLSCLientCertificateSubjectInfos struct {
Country bool
Province bool
Locality bool
Organization bool
CommonName bool
SerialNumber bool
}
// TLSClientHeaders is a middleware that helps setup a few tls infos features.
type TLSClientHeaders struct {
PEM bool // pass the sanitized pem to the backend in a specific header
Infos *TLSClientCertificateInfos // pass selected informations from the client certificate
}
func newTLSCLientCertificateSubjectInfos(infos *types.TLSCLientCertificateSubjectInfos) *TLSCLientCertificateSubjectInfos {
if infos == nil {
return nil
}
return &TLSCLientCertificateSubjectInfos{
SerialNumber: infos.SerialNumber,
CommonName: infos.CommonName,
Country: infos.Country,
Locality: infos.Locality,
Organization: infos.Organization,
Province: infos.Province,
}
}
func newTLSClientInfos(infos *types.TLSClientCertificateInfos) *TLSClientCertificateInfos {
if infos == nil {
return nil
}
return &TLSClientCertificateInfos{
NotBefore: infos.NotBefore,
NotAfter: infos.NotAfter,
Sans: infos.Sans,
Subject: newTLSCLientCertificateSubjectInfos(infos.Subject),
}
}
// NewTLSClientHeaders constructs a new TLSClientHeaders instance from supplied frontend header struct.
func NewTLSClientHeaders(frontend *types.Frontend) *TLSClientHeaders {
if frontend == nil {
return nil
}
var pem bool
var infos *TLSClientCertificateInfos
if frontend.PassTLSClientCert != nil {
conf := frontend.PassTLSClientCert
pem = conf.PEM
infos = newTLSClientInfos(conf.Infos)
}
return &TLSClientHeaders{
PEM: pem,
Infos: infos,
}
}
func (s *TLSClientHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
s.ModifyRequestHeaders(r)
// If there is a next, call it.
if next != nil {
next(w, r)
}
}
// sanitize As we pass the raw certificates, remove the useless data and make it http request compliant
func sanitize(cert []byte) string {
s := string(cert)
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
"-----END CERTIFICATE-----", "",
"\n", "")
cleaned := r.Replace(s)
return url.QueryEscape(cleaned)
}
// extractCertificate extract the certificate from the request
func extractCertificate(cert *x509.Certificate) string {
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
certPEM := pem.EncodeToMemory(&b)
if certPEM == nil {
log.Error("Cannot extract the certificate content")
return ""
}
return sanitize(certPEM)
}
// getXForwardedTLSClientCert Build a string with the client certificates
func getXForwardedTLSClientCert(certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
headerValues = append(headerValues, extractCertificate(peerCert))
}
return strings.Join(headerValues, ",")
}
// getSANs get the Subject Alternate Name values
func getSANs(cert *x509.Certificate) []string {
var sans []string
if cert == nil {
return sans
}
sans = append(cert.DNSNames, cert.EmailAddresses...)
var ips []string
for _, ip := range cert.IPAddresses {
ips = append(ips, ip.String())
}
sans = append(sans, ips...)
var uris []string
for _, uri := range cert.URIs {
uris = append(uris, uri.String())
}
return append(sans, uris...)
}
// getSubjectInfos extract the requested informations from the certificate subject
func (s *TLSClientHeaders) getSubjectInfos(cs *pkix.Name) string {
var subject string
if s.Infos != nil && s.Infos.Subject != nil {
options := s.Infos.Subject
var content []string
if options.Country && len(cs.Country) > 0 {
content = append(content, fmt.Sprintf("C=%s", cs.Country[0]))
}
if options.Province && len(cs.Province) > 0 {
content = append(content, fmt.Sprintf("ST=%s", cs.Province[0]))
}
if options.Locality && len(cs.Locality) > 0 {
content = append(content, fmt.Sprintf("L=%s", cs.Locality[0]))
}
if options.Organization && len(cs.Organization) > 0 {
content = append(content, fmt.Sprintf("O=%s", cs.Organization[0]))
}
if options.CommonName && len(cs.CommonName) > 0 {
content = append(content, fmt.Sprintf("CN=%s", cs.CommonName))
}
if len(content) > 0 {
subject = `Subject="` + strings.Join(content, ",") + `"`
}
}
return subject
}
// getXForwardedTLSClientCertInfos Build a string with the wanted client certificates informations
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
func (s *TLSClientHeaders) getXForwardedTLSClientCertInfos(certs []*x509.Certificate) string {
var headerValues []string
for _, peerCert := range certs {
var values []string
var sans string
var nb string
var na string
subject := s.getSubjectInfos(&peerCert.Subject)
if len(subject) > 0 {
values = append(values, subject)
}
ci := s.Infos
if ci != nil {
if ci.NotBefore {
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
values = append(values, nb)
}
if ci.NotAfter {
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
values = append(values, na)
}
if ci.Sans {
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
values = append(values, sans)
}
}
value := strings.Join(values, ",")
headerValues = append(headerValues, value)
}
return strings.Join(headerValues, ";")
}
// ModifyRequestHeaders set the wanted headers with the certificates informations
func (s *TLSClientHeaders) ModifyRequestHeaders(r *http.Request) {
if s.PEM {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(r.TLS.PeerCertificates))
} else {
log.Warn("Try to extract certificate on a request without TLS")
}
}
if s.Infos != nil {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
headerContent := s.getXForwardedTLSClientCertInfos(r.TLS.PeerCertificates)
r.Header.Set(xForwardedTLSClientCertInfos, url.QueryEscape(headerContent))
} else {
log.Warn("Try to extract certificate on a request without TLS")
}
}
}

View file

@ -0,0 +1,799 @@
package middlewares
import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"net"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/testhelpers"
"github.com/stretchr/testify/require"
)
const (
rootCrt = `-----BEGIN CERTIFICATE-----
MIIDhjCCAm6gAwIBAgIJAIKZlW9a3VrYMA0GCSqGSIb3DQEBCwUAMFgxCzAJBgNV
BAYTAkZSMRMwEQYDVQQIDApTb21lLVN0YXRlMREwDwYDVQQHDAhUb3Vsb3VzZTEh
MB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMB4XDTE4MDcxNzIwMzQz
OFoXDTE4MDgxNjIwMzQzOFowWDELMAkGA1UEBhMCRlIxEzARBgNVBAgMClNvbWUt
U3RhdGUxETAPBgNVBAcMCFRvdWxvdXNlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn
aXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC1P8GJ
H9LkIxIIqK9MyUpushnjmjwccpSMB3OecISKYLy62QDIcAw6NzGcSe8hMwciMJr+
CdCjJlohybnaRI9hrJ3GPnI++UT/MMthf2IIcjmJxmD4k9L1fgs1V6zSTlo0+o0x
0gkAGlWvRkgA+3nt555ee84XQZuneKKeRRIlSA1ygycewFobZ/pGYijIEko+gYkV
sF3LnRGxNl673w+EQsvI7+z29T1nzjmM/xE7WlvnsrVd1/N61jAohLota0YTufwd
ioJZNryzuPejHBCiQRGMbJ7uEEZLiSCN6QiZEfqhS3AulykjgFXQQHn4zoVljSBR
UyLV0prIn5Scbks/AgMBAAGjUzBRMB0GA1UdDgQWBBTroRRnSgtkV+8dumtcftb/
lwIkATAfBgNVHSMEGDAWgBTroRRnSgtkV+8dumtcftb/lwIkATAPBgNVHRMBAf8E
BTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAJ67U5cLa0ZFa/7zQQT4ldkY6YOEgR
0LNoTu51hc+ozaXSvF8YIBzkEpEnbGS3x4xodrwEBZjK2LFhNu/33gkCAuhmedgk
KwZrQM6lqRFGHGVOlkVz+QrJ2EsKYaO4SCUIwVjijXRLA7A30G5C/CIh66PsMgBY
6QHXVPEWm/v1d1Q/DfFfFzSOa1n1rIUw03qVJsxqSwfwYcegOF8YvS/eH4HUr2gF
cEujh6CCnylf35ExHa45atr3+xxbOVdNjobISkYADtbhAAn4KjLS4v8W6445vxxj
G5EIZLjOHyWg1sGaHaaAPkVpZQg8EKm21c4hrEEMfel60AMSSzad/a/V
-----END CERTIFICATE-----`
minimalCert = `-----BEGIN CERTIFICATE-----
MIIDGTCCAgECCQCqLd75YLi2kDANBgkqhkiG9w0BAQsFADBYMQswCQYDVQQGEwJG
UjETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UEBwwIVG91bG91c2UxITAfBgNV
BAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xODA3MTgwODI4MTZaFw0x
ODA4MTcwODI4MTZaMEUxCzAJBgNVBAYTAkZSMRMwEQYDVQQIDApTb21lLVN0YXRl
MSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3
DQEBAQUAA4IBDwAwggEKAoIBAQC/+frDMMTLQyXG34F68BPhQq0kzK4LIq9Y0/gl
FjySZNn1C0QDWA1ubVCAcA6yY204I9cxcQDPNrhC7JlS5QA8Y5rhIBrqQlzZizAi
Rj3NTrRjtGUtOScnHuJaWjLy03DWD+aMwb7q718xt5SEABmmUvLwQK+EjW2MeDwj
y8/UEIpvrRDmdhGaqv7IFpIDkcIF7FowJ/hwDvx3PMc+z/JWK0ovzpvgbx69AVbw
ZxCimeha65rOqVi+lEetD26le+WnOdYsdJ2IkmpPNTXGdfb15xuAc+gFXfMCh7Iw
3Ynl6dZtZM/Ok2kiA7/OsmVnRKkWrtBfGYkI9HcNGb3zrk6nAgMBAAEwDQYJKoZI
hvcNAQELBQADggEBAC/R+Yvhh1VUhcbK49olWsk/JKqfS3VIDQYZg1Eo+JCPbwgS
I1BSYVfMcGzuJTX6ua3m/AHzGF3Tap4GhF4tX12jeIx4R4utnjj7/YKkTvuEM2f4
xT56YqI7zalGScIB0iMeyNz1QcimRl+M/49au8ow9hNX8C2tcA2cwd/9OIj/6T8q
SBRHc6ojvbqZSJCO0jziGDT1L3D+EDgTjED4nd77v/NRdP+egb0q3P0s4dnQ/5AV
aQlQADUn61j3ScbGJ4NSeZFFvsl38jeRi/MEzp0bGgNBcPj6JHi7qbbauZcZfQ05
jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=
-----END CERTIFICATE-----`
completeCert = `Certificate:
Data:
Version: 3 (0x2)
Serial Number: 3 (0x3)
Signature Algorithm: sha1WithRSAEncryption
Issuer: C=FR, ST=Some-State, L=Toulouse, O=Internet Widgits Pty Ltd
Validity
Not Before: Jul 18 08:00:16 2018 GMT
Not After : Jul 18 08:00:16 2019 GMT
Subject: C=FR, ST=SomeState, L=Toulouse, O=Cheese, CN=*.cheese.org
Subject Public Key Info:
Public Key Algorithm: rsaEncryption
Public-Key: (2048 bit)
Modulus:
00:a6:1f:96:7c:c1:cc:b8:1c:b5:91:5d:b8:bf:70:
bc:f7:b8:04:4f:2a:42:de:ea:c5:c3:19:0b:03:04:
ec:ef:a1:24:25:de:ad:05:e7:26:ea:89:6c:59:60:
10:18:0c:73:f1:bf:d3:cc:7b:ed:6b:9c:ea:1d:88:
e2:ee:14:81:d7:07:ee:87:95:3d:36:df:9c:38:b7:
7b:1e:2b:51:9c:4a:1f:d0:cc:5b:af:5d:6c:5c:35:
49:32:e4:01:5b:f9:8c:71:cf:62:48:5a:ea:b7:31:
58:e2:c6:d0:5b:1c:50:b5:5c:6d:5a:6f:da:41:5e:
d5:4c:6e:1a:21:f3:40:f9:9e:52:76:50:25:3e:03:
9b:87:19:48:5b:47:87:d3:67:c6:25:69:77:29:8e:
56:97:45:d9:6f:64:a8:4e:ad:35:75:2e:fc:6a:2e:
47:87:76:fc:4e:3e:44:e9:16:b2:c7:f0:23:98:13:
a2:df:15:23:cb:0c:3d:fd:48:5e:c7:2c:86:70:63:
8b:c6:c8:89:17:52:d5:a7:8e:cb:4e:11:9d:69:8e:
8e:59:cc:7e:a3:bd:a1:11:88:d7:cf:7b:8c:19:46:
9c:1b:7a:c9:39:81:4c:58:08:1f:c7:ce:b0:0e:79:
64:d3:11:72:65:e6:dd:bd:00:7f:22:30:46:9b:66:
9c:b9
Exponent: 65537 (0x10001)
X509v3 extensions:
X509v3 Basic Constraints:
CA:FALSE
X509v3 Subject Alternative Name:
DNS:*.cheese.org, DNS:*.cheese.net, DNS:cheese.in, IP Address:10.0.1.0, IP Address:10.0.1.2, email:test@cheese.org, email:test@cheese.net
X509v3 Subject Key Identifier:
AB:6B:89:25:11:FC:5E:7B:D4:B0:F7:D4:B6:D9:EB:D0:30:93:E5:58
Signature Algorithm: sha1WithRSAEncryption
ad:87:84:a0:88:a3:4c:d9:0a:c0:14:e4:2d:9a:1d:bb:57:b7:
12:ef:3a:fb:8b:b2:ce:32:b8:04:e6:59:c8:4f:14:6a:b5:12:
46:e9:c9:0a:11:64:ea:a1:86:20:96:0e:a7:40:e3:aa:e5:98:
91:36:89:77:b6:b9:73:7e:1a:58:19:ae:d1:14:83:1e:c1:5f:
a5:a0:32:bb:52:68:b4:8d:a3:1d:b3:08:d7:45:6e:3b:87:64:
7e:ef:46:e6:6f:d5:79:d7:1d:57:68:67:d8:18:39:61:5b:8b:
1a:7f:88:da:0a:51:9b:3d:6c:5d:b1:cf:b7:e9:1e:06:65:8e:
96:d3:61:96:f8:a2:61:f9:40:5e:fa:bc:76:b9:64:0e:6f:90:
37:de:ac:6d:7f:36:84:35:19:88:8c:26:af:3e:c3:6a:1a:03:
ed:d7:90:89:ed:18:4c:9e:94:1f:d8:ae:6c:61:36:17:72:f9:
bb:de:0a:56:9a:79:b4:7d:4a:9d:cb:4a:7d:71:9f:38:e7:8d:
f0:87:24:21:0a:24:1f:82:9a:6b:67:ce:7d:af:cb:91:6b:8a:
de:e6:d8:6f:a1:37:b9:2d:d0:cb:e8:4e:f4:43:af:ad:90:13:
7d:61:7a:ce:86:48:fc:00:8c:37:fb:e0:31:6b:e2:18:ad:fd:
1e:df:08:db
-----BEGIN CERTIFICATE-----
MIIDvTCCAqWgAwIBAgIBAzANBgkqhkiG9w0BAQUFADBYMQswCQYDVQQGEwJGUjET
MBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UEBwwIVG91bG91c2UxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xODA3MTgwODAwMTZaFw0xOTA3
MTgwODAwMTZaMFwxCzAJBgNVBAYTAkZSMRIwEAYDVQQIDAlTb21lU3RhdGUxETAP
BgNVBAcMCFRvdWxvdXNlMQ8wDQYDVQQKDAZDaGVlc2UxFTATBgNVBAMMDCouY2hl
ZXNlLm9yZzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKYflnzBzLgc
tZFduL9wvPe4BE8qQt7qxcMZCwME7O+hJCXerQXnJuqJbFlgEBgMc/G/08x77Wuc
6h2I4u4UgdcH7oeVPTbfnDi3ex4rUZxKH9DMW69dbFw1STLkAVv5jHHPYkha6rcx
WOLG0FscULVcbVpv2kFe1UxuGiHzQPmeUnZQJT4Dm4cZSFtHh9NnxiVpdymOVpdF
2W9kqE6tNXUu/GouR4d2/E4+ROkWssfwI5gTot8VI8sMPf1IXscshnBji8bIiRdS
1aeOy04RnWmOjlnMfqO9oRGI1897jBlGnBt6yTmBTFgIH8fOsA55ZNMRcmXm3b0A
fyIwRptmnLkCAwEAAaOBjTCBijAJBgNVHRMEAjAAMF4GA1UdEQRXMFWCDCouY2hl
ZXNlLm9yZ4IMKi5jaGVlc2UubmV0ggljaGVlc2UuaW6HBAoAAQCHBAoAAQKBD3Rl
c3RAY2hlZXNlLm9yZ4EPdGVzdEBjaGVlc2UubmV0MB0GA1UdDgQWBBSra4klEfxe
e9Sw99S22evQMJPlWDANBgkqhkiG9w0BAQUFAAOCAQEArYeEoIijTNkKwBTkLZod
u1e3Eu86+4uyzjK4BOZZyE8UarUSRunJChFk6qGGIJYOp0DjquWYkTaJd7a5c34a
WBmu0RSDHsFfpaAyu1JotI2jHbMI10VuO4dkfu9G5m/VedcdV2hn2Bg5YVuLGn+I
2gpRmz1sXbHPt+keBmWOltNhlviiYflAXvq8drlkDm+QN96sbX82hDUZiIwmrz7D
ahoD7deQie0YTJ6UH9iubGE2F3L5u94KVpp5tH1KnctKfXGfOOeN8IckIQokH4Ka
a2fOfa/LkWuK3ubYb6E3uS3Qy+hO9EOvrZATfWF6zoZI/ACMN/vgMWviGK39Ht8I
2w==
-----END CERTIFICATE-----
`
)
func getCleanCertContents(certContents []string) string {
var re = regexp.MustCompile("-----BEGIN CERTIFICATE-----(?s)(.*)")
var cleanedCertContent []string
for _, certContent := range certContents {
cert := re.FindString(string(certContent))
cleanedCertContent = append(cleanedCertContent, sanitize([]byte(cert)))
}
return strings.Join(cleanedCertContent, ",")
}
func getCertificate(certContent string) *x509.Certificate {
roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM([]byte(rootCrt))
if !ok {
panic("failed to parse root certificate")
}
block, _ := pem.Decode([]byte(certContent))
if block == nil {
panic("failed to parse certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
panic("failed to parse certificate: " + err.Error())
}
return cert
}
func buildTLSWith(certContents []string) *tls.ConnectionState {
var peerCertificates []*x509.Certificate
for _, certContent := range certContents {
peerCertificates = append(peerCertificates, getCertificate(certContent))
}
return &tls.ConnectionState{PeerCertificates: peerCertificates}
}
var myPassTLSClientCustomHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("bar"))
})
func getExpectedSanitized(s string) string {
return url.QueryEscape(strings.Replace(s, "\n", "", -1))
}
func TestSanitize(t *testing.T) {
testCases := []struct {
desc string
toSanitize []byte
expected string
}{
{
desc: "Empty",
},
{
desc: "With a minimal cert",
toSanitize: []byte(minimalCert),
expected: getExpectedSanitized(`MIIDGTCCAgECCQCqLd75YLi2kDANBgkqhkiG9w0BAQsFADBYMQswCQYDVQQGEwJG
UjETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UEBwwIVG91bG91c2UxITAfBgNV
BAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xODA3MTgwODI4MTZaFw0x
ODA4MTcwODI4MTZaMEUxCzAJBgNVBAYTAkZSMRMwEQYDVQQIDApTb21lLVN0YXRl
MSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3
DQEBAQUAA4IBDwAwggEKAoIBAQC/+frDMMTLQyXG34F68BPhQq0kzK4LIq9Y0/gl
FjySZNn1C0QDWA1ubVCAcA6yY204I9cxcQDPNrhC7JlS5QA8Y5rhIBrqQlzZizAi
Rj3NTrRjtGUtOScnHuJaWjLy03DWD+aMwb7q718xt5SEABmmUvLwQK+EjW2MeDwj
y8/UEIpvrRDmdhGaqv7IFpIDkcIF7FowJ/hwDvx3PMc+z/JWK0ovzpvgbx69AVbw
ZxCimeha65rOqVi+lEetD26le+WnOdYsdJ2IkmpPNTXGdfb15xuAc+gFXfMCh7Iw
3Ynl6dZtZM/Ok2kiA7/OsmVnRKkWrtBfGYkI9HcNGb3zrk6nAgMBAAEwDQYJKoZI
hvcNAQELBQADggEBAC/R+Yvhh1VUhcbK49olWsk/JKqfS3VIDQYZg1Eo+JCPbwgS
I1BSYVfMcGzuJTX6ua3m/AHzGF3Tap4GhF4tX12jeIx4R4utnjj7/YKkTvuEM2f4
xT56YqI7zalGScIB0iMeyNz1QcimRl+M/49au8ow9hNX8C2tcA2cwd/9OIj/6T8q
SBRHc6ojvbqZSJCO0jziGDT1L3D+EDgTjED4nd77v/NRdP+egb0q3P0s4dnQ/5AV
aQlQADUn61j3ScbGJ4NSeZFFvsl38jeRi/MEzp0bGgNBcPj6JHi7qbbauZcZfQ05
jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=`),
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, test.expected, sanitize(test.toSanitize), "The sanitized certificates should be equal")
})
}
}
func TestTlsClientheadersWithPEM(t *testing.T) {
testCases := []struct {
desc string
certContents []string // set the request TLS attribute if defined
tlsClientCertHeaders *types.TLSClientHeaders
expectedHeader string
}{
{
desc: "No TLS, no option",
},
{
desc: "TLS, no option",
certContents: []string{minimalCert},
},
{
desc: "No TLS, with pem option true",
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
},
{
desc: "TLS with simple certificate, with pem option true",
certContents: []string{minimalCert},
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert}),
},
{
desc: "TLS with complete certificate, with pem option true",
certContents: []string{completeCert},
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
expectedHeader: getCleanCertContents([]string{completeCert}),
},
{
desc: "TLS with two certificate, with pem option true",
certContents: []string{minimalCert, completeCert},
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
expectedHeader: getCleanCertContents([]string{minimalCert, completeCert}),
},
}
for _, test := range testCases {
tlsClientHeaders := NewTLSClientHeaders(&types.Frontend{PassTLSClientCert: test.tlsClientCertHeaders})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
if test.certContents != nil && len(test.certContents) > 0 {
req.TLS = buildTLSWith(test.certContents)
}
tlsClientHeaders.ServeHTTP(res, req, myPassTLSClientCustomHandler)
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
if test.expectedHeader != "" {
require.Equal(t, getCleanCertContents(test.certContents), req.Header.Get(xForwardedTLSClientCert), "The request header should contain the cleaned certificate")
} else {
require.Empty(t, req.Header.Get(xForwardedTLSClientCert))
}
require.Empty(t, res.Header().Get(xForwardedTLSClientCert), "The response header should be always empty")
})
}
}
func TestGetSans(t *testing.T) {
urlFoo, err := url.Parse("my.foo.com")
require.NoError(t, err)
urlBar, err := url.Parse("my.bar.com")
require.NoError(t, err)
testCases := []struct {
desc string
cert *x509.Certificate // set the request TLS attribute if defined
expected []string
}{
{
desc: "With nil",
},
{
desc: "Certificate without Sans",
cert: &x509.Certificate{},
},
{
desc: "Certificate with all Sans",
cert: &x509.Certificate{
DNSNames: []string{"foo", "bar"},
EmailAddresses: []string{"test@test.com", "test2@test.com"},
IPAddresses: []net.IP{net.IPv4(10, 0, 0, 1), net.IPv4(10, 0, 0, 2)},
URIs: []*url.URL{urlFoo, urlBar},
},
expected: []string{"foo", "bar", "test@test.com", "test2@test.com", "10.0.0.1", "10.0.0.2", urlFoo.String(), urlBar.String()},
},
}
for _, test := range testCases {
sans := getSANs(test.cert)
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
if len(test.expected) > 0 {
for i, expected := range test.expected {
require.Equal(t, expected, sans[i])
}
} else {
require.Empty(t, sans)
}
})
}
}
func TestTlsClientheadersWithCertInfos(t *testing.T) {
minimalCertAllInfos := `Subject="C=FR,ST=Some-State,O=Internet Widgits Pty Ltd",NB=1531902496,NA=1534494496,SAN=`
completeCertAllInfos := `Subject="C=FR,ST=SomeState,L=Toulouse,O=Cheese,CN=*.cheese.org",NB=1531900816,NA=1563436816,SAN=*.cheese.org,*.cheese.net,cheese.in,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2`
testCases := []struct {
desc string
certContents []string // set the request TLS attribute if defined
tlsClientCertHeaders *types.TLSClientHeaders
expectedHeader string
}{
{
desc: "No TLS, no option",
},
{
desc: "TLS, no option",
certContents: []string{minimalCert},
},
{
desc: "No TLS, with pem option true",
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
CommonName: true,
Organization: true,
Locality: true,
Province: true,
Country: true,
SerialNumber: true,
},
},
},
},
{
desc: "No TLS, with pem option true with no flag",
tlsClientCertHeaders: &types.TLSClientHeaders{
PEM: false,
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{},
},
},
},
{
desc: "TLS with simple certificate, with all infos",
certContents: []string{minimalCert},
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
CommonName: true,
Organization: true,
Locality: true,
Province: true,
Country: true,
SerialNumber: true,
},
Sans: true,
},
},
expectedHeader: url.QueryEscape(minimalCertAllInfos),
},
{
desc: "TLS with simple certificate, with some infos",
certContents: []string{minimalCert},
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
Organization: true,
},
Sans: true,
},
},
expectedHeader: url.QueryEscape(`Subject="O=Internet Widgits Pty Ltd",NA=1534494496,SAN=`),
},
{
desc: "TLS with complete certificate, with all infos",
certContents: []string{completeCert},
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
CommonName: true,
Organization: true,
Locality: true,
Province: true,
Country: true,
SerialNumber: true,
},
Sans: true,
},
},
expectedHeader: url.QueryEscape(completeCertAllInfos),
},
{
desc: "TLS with 2 certificates, with all infos",
certContents: []string{minimalCert, completeCert},
tlsClientCertHeaders: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
CommonName: true,
Organization: true,
Locality: true,
Province: true,
Country: true,
SerialNumber: true,
},
Sans: true,
},
},
expectedHeader: url.QueryEscape(strings.Join([]string{minimalCertAllInfos, completeCertAllInfos}, ";")),
},
}
for _, test := range testCases {
tlsClientHeaders := NewTLSClientHeaders(&types.Frontend{PassTLSClientCert: test.tlsClientCertHeaders})
res := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
if test.certContents != nil && len(test.certContents) > 0 {
req.TLS = buildTLSWith(test.certContents)
}
tlsClientHeaders.ServeHTTP(res, req, myPassTLSClientCustomHandler)
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
if test.expectedHeader != "" {
require.Equal(t, test.expectedHeader, req.Header.Get(xForwardedTLSClientCertInfos), "The request header should contain the cleaned certificate")
} else {
require.Empty(t, req.Header.Get(xForwardedTLSClientCertInfos))
}
require.Empty(t, res.Header().Get(xForwardedTLSClientCertInfos), "The response header should be always empty")
})
}
}
func TestNewTLSClientHeadersFromStruct(t *testing.T) {
testCases := []struct {
desc string
frontend *types.Frontend
expected *TLSClientHeaders
}{
{
desc: "Without frontend",
},
{
desc: "frontend without the option",
frontend: &types.Frontend{},
expected: &TLSClientHeaders{},
},
{
desc: "frontend with the pem set false",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
PEM: false,
},
},
expected: &TLSClientHeaders{PEM: false},
},
{
desc: "frontend with the pem set true",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
PEM: true,
},
},
expected: &TLSClientHeaders{PEM: true},
},
{
desc: "frontend with the Infos with no flag",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: false,
NotBefore: false,
Sans: false,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{},
},
},
{
desc: "frontend with the Infos basic",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Sans: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
NotBefore: true,
NotAfter: true,
Sans: true,
},
},
},
{
desc: "frontend with the Infos NotAfter",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
NotAfter: true,
},
},
},
{
desc: "frontend with the Infos NotBefore",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotBefore: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
NotBefore: true,
},
},
},
{
desc: "frontend with the Infos Sans",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Sans: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Sans: true,
},
},
},
{
desc: "frontend with the Infos Subject Organization",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
Organization: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
Organization: true,
},
},
},
},
{
desc: "frontend with the Infos Subject Country",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
Country: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
Country: true,
},
},
},
},
{
desc: "frontend with the Infos Subject SerialNumber",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
SerialNumber: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
SerialNumber: true,
},
},
},
},
{
desc: "frontend with the Infos Subject Province",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
Province: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
Province: true,
},
},
},
},
{
desc: "frontend with the Infos Subject Locality",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
Locality: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
Locality: true,
},
},
},
},
{
desc: "frontend with the Infos Subject CommonName",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Subject: &types.TLSCLientCertificateSubjectInfos{
CommonName: true,
},
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Subject: &TLSCLientCertificateSubjectInfos{
CommonName: true,
},
},
},
},
{
desc: "frontend with the Infos NotBefore",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
Sans: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
Sans: true,
},
},
},
{
desc: "frontend with the Infos all",
frontend: &types.Frontend{
PassTLSClientCert: &types.TLSClientHeaders{
Infos: &types.TLSClientCertificateInfos{
NotAfter: true,
NotBefore: true,
Subject: &types.TLSCLientCertificateSubjectInfos{
CommonName: true,
Country: true,
Locality: true,
Organization: true,
Province: true,
SerialNumber: true,
},
Sans: true,
},
},
},
expected: &TLSClientHeaders{
PEM: false,
Infos: &TLSClientCertificateInfos{
NotBefore: true,
NotAfter: true,
Sans: true,
Subject: &TLSCLientCertificateSubjectInfos{
Province: true,
Organization: true,
Locality: true,
Country: true,
CommonName: true,
SerialNumber: true,
},
}},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
require.Equal(t, test.expected, NewTLSClientHeaders(test.frontend))
})
}
}

View file

@ -0,0 +1,25 @@
package tracing
import "net/http"
// HTTPHeadersCarrier custom implementation to fix duplicated headers
// It has been fixed in https://github.com/opentracing/opentracing-go/pull/191
type HTTPHeadersCarrier http.Header
// Set conforms to the TextMapWriter interface.
func (c HTTPHeadersCarrier) Set(key, val string) {
h := http.Header(c)
h.Set(key, val)
}
// ForeachKey conforms to the TextMapReader interface.
func (c HTTPHeadersCarrier) ForeachKey(handler func(key, val string) error) error {
for k, vals := range c {
for _, v := range vals {
if err := handler(k, v); err != nil {
return err
}
}
}
return nil
}

View file

@ -0,0 +1,45 @@
package datadog
import (
"io"
"strings"
"github.com/containous/traefik/old/log"
"github.com/opentracing/opentracing-go"
ddtracer "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer"
datadog "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
)
// Name sets the name of this tracer
const Name = "datadog"
// Config provides configuration settings for a datadog tracer
type Config struct {
LocalAgentHostPort string `description:"Set datadog-agent's host:port that the reporter will used. Defaults to localhost:8126" export:"false"`
GlobalTag string `description:"Key:Value tag to be set on all the spans." export:"true"`
Debug bool `description:"Enable DataDog debug." export:"true"`
}
// Setup sets up the tracer
func (c *Config) Setup(serviceName string) (opentracing.Tracer, io.Closer, error) {
tag := strings.SplitN(c.GlobalTag, ":", 2)
value := ""
if len(tag) == 2 {
value = tag[1]
}
tracer := ddtracer.New(
datadog.WithAgentAddr(c.LocalAgentHostPort),
datadog.WithServiceName(serviceName),
datadog.WithGlobalTag(tag[0], value),
datadog.WithDebugMode(c.Debug),
)
// Without this, child spans are getting the NOOP tracer
opentracing.SetGlobalTracer(tracer)
log.Debug("DataDog tracer configured")
return tracer, nil, nil
}

View file

@ -0,0 +1,57 @@
package tracing
import (
"fmt"
"net/http"
"github.com/containous/traefik/old/log"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
"github.com/urfave/negroni"
)
type entryPointMiddleware struct {
entryPoint string
*Tracing
}
// NewEntryPoint creates a new middleware that the incoming request
func (t *Tracing) NewEntryPoint(name string) negroni.Handler {
log.Debug("Added entrypoint tracing middleware")
return &entryPointMiddleware{Tracing: t, entryPoint: name}
}
func (e *entryPointMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
opNameFunc := generateEntryPointSpanName
ctx, _ := e.Extract(opentracing.HTTPHeaders, HTTPHeadersCarrier(r.Header))
span := e.StartSpan(opNameFunc(r, e.entryPoint, e.SpanNameLimit), ext.RPCServerOption(ctx))
ext.Component.Set(span, e.ServiceName)
LogRequest(span, r)
ext.SpanKindRPCServer.Set(span)
r = r.WithContext(opentracing.ContextWithSpan(r.Context(), span))
recorder := newStatusCodeRecoder(w, 200)
next(recorder, r)
LogResponseCode(span, recorder.Status())
span.Finish()
}
// generateEntryPointSpanName will return a Span name of an appropriate lenth based on the 'spanLimit' argument. If needed, it will be truncated, but will not be less than 24 characters.
func generateEntryPointSpanName(r *http.Request, entryPoint string, spanLimit int) string {
name := fmt.Sprintf("Entrypoint %s %s", entryPoint, r.Host)
if spanLimit > 0 && len(name) > spanLimit {
if spanLimit < EntryPointMaxLengthNumber {
log.Warnf("SpanNameLimit is set to be less than required static number of characters, defaulting to %d + 3", EntryPointMaxLengthNumber)
spanLimit = EntryPointMaxLengthNumber + 3
}
hash := computeHash(name)
limit := (spanLimit - EntryPointMaxLengthNumber) / 2
name = fmt.Sprintf("Entrypoint %s %s %s", truncateString(entryPoint, limit), truncateString(r.Host, limit), hash)
}
return name
}

View file

@ -0,0 +1,70 @@
package tracing
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/opentracing/opentracing-go/ext"
"github.com/stretchr/testify/assert"
)
func TestEntryPointMiddlewareServeHTTP(t *testing.T) {
expectedTags := map[string]interface{}{
"span.kind": ext.SpanKindRPCServerEnum,
"http.method": "GET",
"component": "",
"http.url": "http://www.test.com",
"http.host": "www.test.com",
}
testCases := []struct {
desc string
entryPoint string
tracing *Tracing
expectedTags map[string]interface{}
expectedName string
}{
{
desc: "no truncation test",
entryPoint: "test",
tracing: &Tracing{
SpanNameLimit: 0,
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
expectedTags: expectedTags,
expectedName: "Entrypoint test www.test.com",
}, {
desc: "basic test",
entryPoint: "test",
tracing: &Tracing{
SpanNameLimit: 25,
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
},
expectedTags: expectedTags,
expectedName: "Entrypoint te... ww... 39b97e58",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
e := &entryPointMiddleware{
entryPoint: test.entryPoint,
Tracing: test.tracing,
}
next := func(http.ResponseWriter, *http.Request) {
span := test.tracing.tracer.(*MockTracer).Span
actual := span.Tags
assert.Equal(t, test.expectedTags, actual)
assert.Equal(t, test.expectedName, span.OpName)
}
e.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "http://www.test.com", nil), next)
})
}
}

View file

@ -0,0 +1,63 @@
package tracing
import (
"fmt"
"net/http"
"github.com/containous/traefik/old/log"
"github.com/opentracing/opentracing-go/ext"
"github.com/urfave/negroni"
)
type forwarderMiddleware struct {
frontend string
backend string
opName string
*Tracing
}
// NewForwarderMiddleware creates a new forwarder middleware that traces the outgoing request
func (t *Tracing) NewForwarderMiddleware(frontend, backend string) negroni.Handler {
log.Debugf("Added outgoing tracing middleware %s", frontend)
return &forwarderMiddleware{
Tracing: t,
frontend: frontend,
backend: backend,
opName: generateForwardSpanName(frontend, backend, t.SpanNameLimit),
}
}
func (f *forwarderMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
span, r, finish := StartSpan(r, f.opName, true)
defer finish()
span.SetTag("frontend.name", f.frontend)
span.SetTag("backend.name", f.backend)
ext.HTTPMethod.Set(span, r.Method)
ext.HTTPUrl.Set(span, fmt.Sprintf("%s%s", r.URL.String(), r.RequestURI))
span.SetTag("http.host", r.Host)
InjectRequestHeaders(r)
recorder := newStatusCodeRecoder(w, 200)
next(recorder, r)
LogResponseCode(span, recorder.Status())
}
// generateForwardSpanName will return a Span name of an appropriate lenth based on the 'spanLimit' argument. If needed, it will be truncated, but will not be less than 21 characters
func generateForwardSpanName(frontend, backend string, spanLimit int) string {
name := fmt.Sprintf("forward %s/%s", frontend, backend)
if spanLimit > 0 && len(name) > spanLimit {
if spanLimit < ForwardMaxLengthNumber {
log.Warnf("SpanNameLimit is set to be less than required static number of characters, defaulting to %d + 3", ForwardMaxLengthNumber)
spanLimit = ForwardMaxLengthNumber + 3
}
hash := computeHash(name)
limit := (spanLimit - ForwardMaxLengthNumber) / 2
name = fmt.Sprintf("forward %s/%s/%s", truncateString(frontend, limit), truncateString(backend, limit), hash)
}
return name
}

View file

@ -0,0 +1,93 @@
package tracing
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestTracingNewForwarderMiddleware(t *testing.T) {
testCases := []struct {
desc string
tracer *Tracing
frontend string
backend string
expected *forwarderMiddleware
}{
{
desc: "Simple Forward Tracer without truncation and hashing",
tracer: &Tracing{
SpanNameLimit: 101,
},
frontend: "some-service.domain.tld",
backend: "some-service.domain.tld",
expected: &forwarderMiddleware{
Tracing: &Tracing{
SpanNameLimit: 101,
},
frontend: "some-service.domain.tld",
backend: "some-service.domain.tld",
opName: "forward some-service.domain.tld/some-service.domain.tld",
},
}, {
desc: "Simple Forward Tracer with truncation and hashing",
tracer: &Tracing{
SpanNameLimit: 101,
},
frontend: "some-service-100.slug.namespace.environment.domain.tld",
backend: "some-service-100.slug.namespace.environment.domain.tld",
expected: &forwarderMiddleware{
Tracing: &Tracing{
SpanNameLimit: 101,
},
frontend: "some-service-100.slug.namespace.environment.domain.tld",
backend: "some-service-100.slug.namespace.environment.domain.tld",
opName: "forward some-service-100.slug.namespace.enviro.../some-service-100.slug.namespace.enviro.../bc4a0d48",
},
},
{
desc: "Exactly 101 chars",
tracer: &Tracing{
SpanNameLimit: 101,
},
frontend: "some-service1.namespace.environment.domain.tld",
backend: "some-service1.namespace.environment.domain.tld",
expected: &forwarderMiddleware{
Tracing: &Tracing{
SpanNameLimit: 101,
},
frontend: "some-service1.namespace.environment.domain.tld",
backend: "some-service1.namespace.environment.domain.tld",
opName: "forward some-service1.namespace.environment.domain.tld/some-service1.namespace.environment.domain.tld",
},
},
{
desc: "More than 101 chars",
tracer: &Tracing{
SpanNameLimit: 101,
},
frontend: "some-service1.frontend.namespace.environment.domain.tld",
backend: "some-service1.backend.namespace.environment.domain.tld",
expected: &forwarderMiddleware{
Tracing: &Tracing{
SpanNameLimit: 101,
},
frontend: "some-service1.frontend.namespace.environment.domain.tld",
backend: "some-service1.backend.namespace.environment.domain.tld",
opName: "forward some-service1.frontend.namespace.envir.../some-service1.backend.namespace.enviro.../fa49dd23",
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := test.tracer.NewForwarderMiddleware(test.frontend, test.backend)
assert.Equal(t, test.expected, actual)
assert.True(t, len(test.expected.opName) <= test.tracer.SpanNameLimit)
})
}
}

View file

@ -0,0 +1,73 @@
package jaeger
import (
"fmt"
"io"
"github.com/containous/traefik/old/log"
"github.com/opentracing/opentracing-go"
jaegercfg "github.com/uber/jaeger-client-go/config"
"github.com/uber/jaeger-client-go/zipkin"
jaegermet "github.com/uber/jaeger-lib/metrics"
)
// Name sets the name of this tracer
const Name = "jaeger"
// Config provides configuration settings for a jaeger tracer
type Config struct {
SamplingServerURL string `description:"set the sampling server url." export:"false"`
SamplingType string `description:"set the sampling type." export:"true"`
SamplingParam float64 `description:"set the sampling parameter." export:"true"`
LocalAgentHostPort string `description:"set jaeger-agent's host:port that the reporter will used." export:"false"`
Gen128Bit bool `description:"generate 128 bit span IDs." export:"true"`
Propagation string `description:"which propgation format to use (jaeger/b3)." export:"true"`
}
// Setup sets up the tracer
func (c *Config) Setup(componentName string) (opentracing.Tracer, io.Closer, error) {
jcfg := jaegercfg.Configuration{
Sampler: &jaegercfg.SamplerConfig{
SamplingServerURL: c.SamplingServerURL,
Type: c.SamplingType,
Param: c.SamplingParam,
},
Reporter: &jaegercfg.ReporterConfig{
LogSpans: true,
LocalAgentHostPort: c.LocalAgentHostPort,
},
}
jMetricsFactory := jaegermet.NullFactory
opts := []jaegercfg.Option{
jaegercfg.Logger(&jaegerLogger{}),
jaegercfg.Metrics(jMetricsFactory),
jaegercfg.Gen128Bit(c.Gen128Bit),
}
switch c.Propagation {
case "b3":
p := zipkin.NewZipkinB3HTTPHeaderPropagator()
opts = append(opts,
jaegercfg.Injector(opentracing.HTTPHeaders, p),
jaegercfg.Extractor(opentracing.HTTPHeaders, p),
)
case "jaeger", "":
default:
return nil, nil, fmt.Errorf("unknown propagation format: %s", c.Propagation)
}
// Initialize tracer with a logger and a metrics factory
closer, err := jcfg.InitGlobalTracer(
componentName,
opts...,
)
if err != nil {
log.Warnf("Could not initialize jaeger tracer: %s", err.Error())
return nil, nil, err
}
log.Debug("Jaeger tracer configured")
return opentracing.GlobalTracer(), closer, nil
}

View file

@ -0,0 +1,15 @@
package jaeger
import "github.com/containous/traefik/old/log"
// jaegerLogger is an implementation of the Logger interface that delegates to traefik log
type jaegerLogger struct{}
func (l *jaegerLogger) Error(msg string) {
log.Errorf("Tracing jaeger error: %s", msg)
}
// Infof logs a message at debug priority
func (l *jaegerLogger) Infof(msg string, args ...interface{}) {
log.Debugf(msg, args...)
}

View file

@ -0,0 +1,57 @@
package tracing
import (
"bufio"
"net"
"net/http"
)
type statusCodeRecoder interface {
http.ResponseWriter
Status() int
}
type statusCodeWithoutCloseNotify struct {
http.ResponseWriter
status int
}
// WriteHeader captures the status code for later retrieval.
func (s *statusCodeWithoutCloseNotify) WriteHeader(status int) {
s.status = status
s.ResponseWriter.WriteHeader(status)
}
// Status get response status
func (s *statusCodeWithoutCloseNotify) Status() int {
return s.status
}
// Hijack hijacks the connection
func (s *statusCodeWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return s.ResponseWriter.(http.Hijacker).Hijack()
}
// Flush sends any buffered data to the client.
func (s *statusCodeWithoutCloseNotify) Flush() {
if flusher, ok := s.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
type statusCodeWithCloseNotify struct {
*statusCodeWithoutCloseNotify
}
func (s *statusCodeWithCloseNotify) CloseNotify() <-chan bool {
return s.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// newStatusCodeRecoder returns an initialized statusCodeRecoder.
func newStatusCodeRecoder(rw http.ResponseWriter, status int) statusCodeRecoder {
recorder := &statusCodeWithoutCloseNotify{rw, status}
if _, ok := rw.(http.CloseNotifier); ok {
return &statusCodeWithCloseNotify{recorder}
}
return recorder
}

View file

@ -0,0 +1,197 @@
package tracing
import (
"crypto/sha256"
"fmt"
"io"
"net/http"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/middlewares/tracing/datadog"
"github.com/containous/traefik/old/middlewares/tracing/jaeger"
"github.com/containous/traefik/old/middlewares/tracing/zipkin"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// ForwardMaxLengthNumber defines the number of static characters in the Forwarding Span Trace name : 8 chars for 'forward ' + 8 chars for hash + 2 chars for '_'.
const ForwardMaxLengthNumber = 18
// EntryPointMaxLengthNumber defines the number of static characters in the Entrypoint Span Trace name : 11 chars for 'Entrypoint ' + 8 chars for hash + 2 chars for '_'.
const EntryPointMaxLengthNumber = 21
// TraceNameHashLength defines the number of characters to use from the head of the generated hash.
const TraceNameHashLength = 8
// Tracing middleware
type Tracing struct {
Backend string `description:"Selects the tracking backend ('jaeger','zipkin', 'datadog')." export:"true"`
ServiceName string `description:"Set the name for this service" export:"true"`
SpanNameLimit int `description:"Set the maximum character limit for Span names (default 0 = no limit)" export:"true"`
Jaeger *jaeger.Config `description:"Settings for jaeger"`
Zipkin *zipkin.Config `description:"Settings for zipkin"`
DataDog *datadog.Config `description:"Settings for DataDog"`
tracer opentracing.Tracer
closer io.Closer
}
// StartSpan delegates to opentracing.Tracer
func (t *Tracing) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
return t.tracer.StartSpan(operationName, opts...)
}
// Inject delegates to opentracing.Tracer
func (t *Tracing) Inject(sm opentracing.SpanContext, format interface{}, carrier interface{}) error {
return t.tracer.Inject(sm, format, carrier)
}
// Extract delegates to opentracing.Tracer
func (t *Tracing) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) {
return t.tracer.Extract(format, carrier)
}
// Backend describes things we can use to setup tracing
type Backend interface {
Setup(serviceName string) (opentracing.Tracer, io.Closer, error)
}
// Setup Tracing middleware
func (t *Tracing) Setup() {
var err error
switch t.Backend {
case jaeger.Name:
t.tracer, t.closer, err = t.Jaeger.Setup(t.ServiceName)
case zipkin.Name:
t.tracer, t.closer, err = t.Zipkin.Setup(t.ServiceName)
case datadog.Name:
t.tracer, t.closer, err = t.DataDog.Setup(t.ServiceName)
default:
log.Warnf("Unknown tracer %q", t.Backend)
return
}
if err != nil {
log.Warnf("Could not initialize %s tracing: %v", t.Backend, err)
}
}
// IsEnabled determines if tracing was successfully activated
func (t *Tracing) IsEnabled() bool {
if t == nil || t.tracer == nil {
return false
}
return true
}
// Close tracer
func (t *Tracing) Close() {
if t.closer != nil {
err := t.closer.Close()
if err != nil {
log.Warn(err)
}
}
}
// LogRequest used to create span tags from the request
func LogRequest(span opentracing.Span, r *http.Request) {
if span != nil && r != nil {
ext.HTTPMethod.Set(span, r.Method)
ext.HTTPUrl.Set(span, r.URL.String())
span.SetTag("http.host", r.Host)
}
}
// LogResponseCode used to log response code in span
func LogResponseCode(span opentracing.Span, code int) {
if span != nil {
ext.HTTPStatusCode.Set(span, uint16(code))
if code >= 400 {
ext.Error.Set(span, true)
}
}
}
// GetSpan used to retrieve span from request context
func GetSpan(r *http.Request) opentracing.Span {
return opentracing.SpanFromContext(r.Context())
}
// InjectRequestHeaders used to inject OpenTracing headers into the request
func InjectRequestHeaders(r *http.Request) {
if span := GetSpan(r); span != nil {
err := opentracing.GlobalTracer().Inject(
span.Context(),
opentracing.HTTPHeaders,
HTTPHeadersCarrier(r.Header))
if err != nil {
log.Error(err)
}
}
}
// LogEventf logs an event to the span in the request context.
func LogEventf(r *http.Request, format string, args ...interface{}) {
if span := GetSpan(r); span != nil {
span.LogKV("event", fmt.Sprintf(format, args...))
}
}
// StartSpan starts a new span from the one in the request context
func StartSpan(r *http.Request, operationName string, spanKinClient bool, opts ...opentracing.StartSpanOption) (opentracing.Span, *http.Request, func()) {
span, ctx := opentracing.StartSpanFromContext(r.Context(), operationName, opts...)
if spanKinClient {
ext.SpanKindRPCClient.Set(span)
}
r = r.WithContext(ctx)
return span, r, func() {
span.Finish()
}
}
// SetError flags the span associated with this request as in error
func SetError(r *http.Request) {
if span := GetSpan(r); span != nil {
ext.Error.Set(span, true)
}
}
// SetErrorAndDebugLog flags the span associated with this request as in error and create a debug log.
func SetErrorAndDebugLog(r *http.Request, format string, args ...interface{}) {
SetError(r)
log.Debugf(format, args...)
LogEventf(r, format, args...)
}
// SetErrorAndWarnLog flags the span associated with this request as in error and create a debug log.
func SetErrorAndWarnLog(r *http.Request, format string, args ...interface{}) {
SetError(r)
log.Warnf(format, args...)
LogEventf(r, format, args...)
}
// truncateString reduces the length of the 'str' argument to 'num' - 3 and adds a '...' suffix to the tail.
func truncateString(str string, num int) string {
text := str
if len(str) > num {
if num > 3 {
num -= 3
}
text = str[0:num] + "..."
}
return text
}
// computeHash returns the first TraceNameHashLength character of the sha256 hash for 'name' argument.
func computeHash(name string) string {
data := []byte(name)
hash := sha256.New()
if _, err := hash.Write(data); err != nil {
// Impossible case
log.Errorf("Fail to create Span name hash for %s: %v", name, err)
}
return fmt.Sprintf("%x", hash.Sum(nil))[:TraceNameHashLength]
}

View file

@ -0,0 +1,133 @@
package tracing
import (
"testing"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/log"
"github.com/stretchr/testify/assert"
)
type MockTracer struct {
Span *MockSpan
}
type MockSpan struct {
OpName string
Tags map[string]interface{}
}
type MockSpanContext struct {
}
// MockSpanContext:
func (n MockSpanContext) ForeachBaggageItem(handler func(k, v string) bool) {}
// MockSpan:
func (n MockSpan) Context() opentracing.SpanContext { return MockSpanContext{} }
func (n MockSpan) SetBaggageItem(key, val string) opentracing.Span {
return MockSpan{Tags: make(map[string]interface{})}
}
func (n MockSpan) BaggageItem(key string) string { return "" }
func (n MockSpan) SetTag(key string, value interface{}) opentracing.Span {
n.Tags[key] = value
return n
}
func (n MockSpan) LogFields(fields ...log.Field) {}
func (n MockSpan) LogKV(keyVals ...interface{}) {}
func (n MockSpan) Finish() {}
func (n MockSpan) FinishWithOptions(opts opentracing.FinishOptions) {}
func (n MockSpan) SetOperationName(operationName string) opentracing.Span { return n }
func (n MockSpan) Tracer() opentracing.Tracer { return MockTracer{} }
func (n MockSpan) LogEvent(event string) {}
func (n MockSpan) LogEventWithPayload(event string, payload interface{}) {}
func (n MockSpan) Log(data opentracing.LogData) {}
func (n MockSpan) Reset() {
n.Tags = make(map[string]interface{})
}
// StartSpan belongs to the Tracer interface.
func (n MockTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
n.Span.OpName = operationName
return n.Span
}
// Inject belongs to the Tracer interface.
func (n MockTracer) Inject(sp opentracing.SpanContext, format interface{}, carrier interface{}) error {
return nil
}
// Extract belongs to the Tracer interface.
func (n MockTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) {
return nil, opentracing.ErrSpanContextNotFound
}
func TestTruncateString(t *testing.T) {
testCases := []struct {
desc string
text string
limit int
expected string
}{
{
desc: "short text less than limit 10",
text: "short",
limit: 10,
expected: "short",
},
{
desc: "basic truncate with limit 10",
text: "some very long pice of text",
limit: 10,
expected: "some ve...",
},
{
desc: "truncate long FQDN to 39 chars",
text: "some-service-100.slug.namespace.environment.domain.tld",
limit: 39,
expected: "some-service-100.slug.namespace.envi...",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := truncateString(test.text, test.limit)
assert.Equal(t, test.expected, actual)
assert.True(t, len(actual) <= test.limit)
})
}
}
func TestComputeHash(t *testing.T) {
testCases := []struct {
desc string
text string
expected string
}{
{
desc: "hashing",
text: "some very long pice of text",
expected: "0258ea1c",
},
{
desc: "short text less than limit 10",
text: "short",
expected: "f9b0078b",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := computeHash(test.text)
assert.Equal(t, test.expected, actual)
})
}
}

View file

@ -0,0 +1,66 @@
package tracing
import (
"net/http"
"github.com/urfave/negroni"
)
// NewNegroniHandlerWrapper return a negroni.Handler struct
func (t *Tracing) NewNegroniHandlerWrapper(name string, handler negroni.Handler, clientSpanKind bool) negroni.Handler {
if t.IsEnabled() && handler != nil {
return &NegroniHandlerWrapper{
name: name,
next: handler,
clientSpanKind: clientSpanKind,
}
}
return handler
}
// NewHTTPHandlerWrapper return a http.Handler struct
func (t *Tracing) NewHTTPHandlerWrapper(name string, handler http.Handler, clientSpanKind bool) http.Handler {
if t.IsEnabled() && handler != nil {
return &HTTPHandlerWrapper{
name: name,
handler: handler,
clientSpanKind: clientSpanKind,
}
}
return handler
}
// NegroniHandlerWrapper is used to wrap negroni handler middleware
type NegroniHandlerWrapper struct {
name string
next negroni.Handler
clientSpanKind bool
}
func (t *NegroniHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
var finish func()
_, r, finish = StartSpan(r, t.name, t.clientSpanKind)
defer finish()
if t.next != nil {
t.next.ServeHTTP(rw, r, next)
}
}
// HTTPHandlerWrapper is used to wrap http handler middleware
type HTTPHandlerWrapper struct {
name string
handler http.Handler
clientSpanKind bool
}
func (t *HTTPHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
var finish func()
_, r, finish = StartSpan(r, t.name, t.clientSpanKind)
defer finish()
if t.handler != nil {
t.handler.ServeHTTP(rw, r)
}
}

View file

@ -0,0 +1,49 @@
package zipkin
import (
"io"
"time"
"github.com/containous/traefik/old/log"
"github.com/opentracing/opentracing-go"
zipkin "github.com/openzipkin/zipkin-go-opentracing"
)
// Name sets the name of this tracer
const Name = "zipkin"
// Config provides configuration settings for a zipkin tracer
type Config struct {
HTTPEndpoint string `description:"HTTP Endpoint to report traces to." export:"false"`
SameSpan bool `description:"Use Zipkin SameSpan RPC style traces." export:"true"`
ID128Bit bool `description:"Use Zipkin 128 bit root span IDs." export:"true"`
Debug bool `description:"Enable Zipkin debug." export:"true"`
SampleRate float64 `description:"The rate between 0.0 and 1.0 of requests to trace." export:"true"`
}
// Setup sets up the tracer
func (c *Config) Setup(serviceName string) (opentracing.Tracer, io.Closer, error) {
collector, err := zipkin.NewHTTPCollector(c.HTTPEndpoint)
if err != nil {
return nil, nil, err
}
recorder := zipkin.NewRecorder(collector, c.Debug, "0.0.0.0:0", serviceName)
tracer, err := zipkin.NewTracer(
recorder,
zipkin.ClientServerSameSpan(c.SameSpan),
zipkin.TraceID128Bit(c.ID128Bit),
zipkin.DebugMode(c.Debug),
zipkin.WithSampler(zipkin.NewBoundarySampler(c.SampleRate, time.Now().Unix())),
)
if err != nil {
return nil, nil, err
}
// Without this, child spans are getting the NOOP tracer
opentracing.SetGlobalTracer(tracer)
log.Debug("Zipkin tracer configured")
return tracer, collector, nil
}

36
old/ping/ping.go Normal file
View file

@ -0,0 +1,36 @@
package ping
import (
"context"
"fmt"
"net/http"
"github.com/containous/mux"
)
// Handler expose ping routes
type Handler struct {
EntryPoint string `description:"Ping entryPoint" export:"true"`
terminating bool
}
// WithContext causes the ping endpoint to serve non 200 responses.
func (h *Handler) WithContext(ctx context.Context) {
go func() {
<-ctx.Done()
h.terminating = true
}()
}
// AddRoutes add ping routes on a router
func (h *Handler) AddRoutes(router *mux.Router) {
router.Methods(http.MethodGet, http.MethodHead).Path("/ping").
HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
statusCode := http.StatusOK
if h.terminating {
statusCode = http.StatusServiceUnavailable
}
response.WriteHeader(statusCode)
fmt.Fprint(response, http.StatusText(statusCode))
})
}

View file

@ -0,0 +1,83 @@
package acme
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"github.com/containous/traefik/old/log"
"github.com/xenolf/lego/acme"
)
// Account is used to store lets encrypt registration info
type Account struct {
Email string
Registration *acme.RegistrationResource
PrivateKey []byte
KeyType acme.KeyType
}
const (
// RegistrationURLPathV1Regexp is a regexp which match ACME registration URL in the V1 format
RegistrationURLPathV1Regexp = `^.*/acme/reg/\d+$`
)
// NewAccount creates an account
func NewAccount(email string, keyTypeValue string) (*Account, error) {
keyType := GetKeyType(keyTypeValue)
// Create a user. New accounts need an email and private key to start
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, err
}
return &Account{
Email: email,
PrivateKey: x509.MarshalPKCS1PrivateKey(privateKey),
KeyType: keyType,
}, nil
}
// GetEmail returns email
func (a *Account) GetEmail() string {
return a.Email
}
// GetRegistration returns lets encrypt registration resource
func (a *Account) GetRegistration() *acme.RegistrationResource {
return a.Registration
}
// GetPrivateKey returns private key
func (a *Account) GetPrivateKey() crypto.PrivateKey {
if privateKey, err := x509.ParsePKCS1PrivateKey(a.PrivateKey); err == nil {
return privateKey
}
log.Errorf("Cannot unmarshal private key %+v", a.PrivateKey)
return nil
}
// GetKeyType used to determine which algo to used
func GetKeyType(value string) acme.KeyType {
switch value {
case "EC256":
return acme.EC256
case "EC384":
return acme.EC384
case "RSA2048":
return acme.RSA2048
case "RSA4096":
return acme.RSA4096
case "RSA8192":
return acme.RSA8192
case "":
log.Infof("The key type is empty. Use default key type %v.", acme.RSA4096)
return acme.RSA4096
default:
log.Infof("Unable to determine key type value %q. Use default key type %v.", value, acme.RSA4096)
return acme.RSA4096
}
}

View file

@ -0,0 +1,86 @@
package acme
import (
"net"
"net/http"
"time"
"github.com/cenk/backoff"
"github.com/containous/mux"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/safe"
"github.com/xenolf/lego/acme"
)
var _ acme.ChallengeProviderTimeout = (*challengeHTTP)(nil)
type challengeHTTP struct {
Store Store
}
// Present presents a challenge to obtain new ACME certificate
func (c *challengeHTTP) Present(domain, token, keyAuth string) error {
return c.Store.SetHTTPChallengeToken(token, domain, []byte(keyAuth))
}
// CleanUp cleans the challenges when certificate is obtained
func (c *challengeHTTP) CleanUp(domain, token, keyAuth string) error {
return c.Store.RemoveHTTPChallengeToken(token, domain)
}
// Timeout calculates the maximum of time allowed to resolved an ACME challenge
func (c *challengeHTTP) Timeout() (timeout, interval time.Duration) {
return 60 * time.Second, 5 * time.Second
}
func getTokenValue(token, domain string, store Store) []byte {
log.Debugf("Looking for an existing ACME challenge for token %v...", token)
var result []byte
operation := func() error {
var err error
result, err = store.GetHTTPChallengeToken(token, domain)
return err
}
notify := func(err error, time time.Duration) {
log.Errorf("Error getting challenge for token retrying in %s", time)
}
ebo := backoff.NewExponentialBackOff()
ebo.MaxElapsedTime = 60 * time.Second
err := backoff.RetryNotify(safe.OperationWithRecover(operation), ebo, notify)
if err != nil {
log.Errorf("Error getting challenge for token: %v", err)
return []byte{}
}
return result
}
// AddRoutes add routes on internal router
func (p *Provider) AddRoutes(router *mux.Router) {
router.Methods(http.MethodGet).
Path(acme.HTTP01ChallengePath("{token}")).
Handler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
vars := mux.Vars(req)
if token, ok := vars["token"]; ok {
domain, _, err := net.SplitHostPort(req.Host)
if err != nil {
log.Debugf("Unable to split host and port: %v. Fallback to request host.", err)
domain = req.Host
}
tokenValue := getTokenValue(token, domain, p.Store)
if len(tokenValue) > 0 {
rw.WriteHeader(http.StatusOK)
_, err = rw.Write(tokenValue)
if err != nil {
log.Errorf("Unable to write token : %v", err)
}
return
}
}
rw.WriteHeader(http.StatusNotFound)
}))
}

View file

@ -0,0 +1,52 @@
package acme
import (
"crypto/tls"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/types"
"github.com/xenolf/lego/acme"
)
var _ acme.ChallengeProvider = (*challengeTLSALPN)(nil)
type challengeTLSALPN struct {
Store Store
}
func (c *challengeTLSALPN) Present(domain, token, keyAuth string) error {
log.Debugf("TLS Challenge Present temp certificate for %s", domain)
certPEMBlock, keyPEMBlock, err := acme.TLSALPNChallengeBlocks(domain, keyAuth)
if err != nil {
return err
}
cert := &Certificate{Certificate: certPEMBlock, Key: keyPEMBlock, Domain: types.Domain{Main: "TEMP-" + domain}}
return c.Store.AddTLSChallenge(domain, cert)
}
func (c *challengeTLSALPN) CleanUp(domain, token, keyAuth string) error {
log.Debugf("TLS Challenge CleanUp temp certificate for %s", domain)
return c.Store.RemoveTLSChallenge(domain)
}
// GetTLSALPNCertificate Get the temp certificate for ACME TLS-ALPN-O1 challenge.
func (p *Provider) GetTLSALPNCertificate(domain string) (*tls.Certificate, error) {
cert, err := p.Store.GetTLSChallenge(domain)
if err != nil {
return nil, err
}
if cert == nil {
return nil, nil
}
certificate, err := tls.X509KeyPair(cert.Certificate, cert.Key)
if err != nil {
return nil, err
}
return &certificate, nil
}

View file

@ -0,0 +1,251 @@
package acme
import (
"encoding/json"
"fmt"
"io/ioutil"
"os"
"regexp"
"sync"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/safe"
)
var _ Store = (*LocalStore)(nil)
// LocalStore Store implementation for local file
type LocalStore struct {
filename string
storedData *StoredData
SaveDataChan chan *StoredData `json:"-"`
lock sync.RWMutex
}
// NewLocalStore initializes a new LocalStore with a file name
func NewLocalStore(filename string) *LocalStore {
store := &LocalStore{filename: filename, SaveDataChan: make(chan *StoredData)}
store.listenSaveAction()
return store
}
func (s *LocalStore) get() (*StoredData, error) {
if s.storedData == nil {
s.storedData = &StoredData{
HTTPChallenges: make(map[string]map[string][]byte),
TLSChallenges: make(map[string]*Certificate),
}
hasData, err := CheckFile(s.filename)
if err != nil {
return nil, err
}
if hasData {
f, err := os.Open(s.filename)
if err != nil {
return nil, err
}
defer f.Close()
file, err := ioutil.ReadAll(f)
if err != nil {
return nil, err
}
if len(file) > 0 {
if err := json.Unmarshal(file, s.storedData); err != nil {
return nil, err
}
}
// Check if ACME Account is in ACME V1 format
if s.storedData.Account != nil && s.storedData.Account.Registration != nil {
isOldRegistration, err := regexp.MatchString(RegistrationURLPathV1Regexp, s.storedData.Account.Registration.URI)
if err != nil {
return nil, err
}
if isOldRegistration {
log.Debug("Reset ACME account.")
s.storedData.Account = nil
s.SaveDataChan <- s.storedData
}
}
// Delete all certificates with no value
var certificates []*Certificate
for _, certificate := range s.storedData.Certificates {
if len(certificate.Certificate) == 0 || len(certificate.Key) == 0 {
log.Debugf("Delete certificate %v for domains %v which have no value.", certificate, certificate.Domain.ToStrArray())
continue
}
certificates = append(certificates, certificate)
}
if len(certificates) < len(s.storedData.Certificates) {
s.storedData.Certificates = certificates
s.SaveDataChan <- s.storedData
}
}
}
return s.storedData, nil
}
// listenSaveAction listens to a chan to store ACME data in json format into LocalStore.filename
func (s *LocalStore) listenSaveAction() {
safe.Go(func() {
for object := range s.SaveDataChan {
data, err := json.MarshalIndent(object, "", " ")
if err != nil {
log.Error(err)
}
err = ioutil.WriteFile(s.filename, data, 0600)
if err != nil {
log.Error(err)
}
}
})
}
// GetAccount returns ACME Account
func (s *LocalStore) GetAccount() (*Account, error) {
storedData, err := s.get()
if err != nil {
return nil, err
}
return storedData.Account, nil
}
// SaveAccount stores ACME Account
func (s *LocalStore) SaveAccount(account *Account) error {
storedData, err := s.get()
if err != nil {
return err
}
storedData.Account = account
s.SaveDataChan <- storedData
return nil
}
// GetCertificates returns ACME Certificates list
func (s *LocalStore) GetCertificates() ([]*Certificate, error) {
storedData, err := s.get()
if err != nil {
return nil, err
}
return storedData.Certificates, nil
}
// SaveCertificates stores ACME Certificates list
func (s *LocalStore) SaveCertificates(certificates []*Certificate) error {
storedData, err := s.get()
if err != nil {
return err
}
storedData.Certificates = certificates
s.SaveDataChan <- storedData
return nil
}
// GetHTTPChallengeToken Get the http challenge token from the store
func (s *LocalStore) GetHTTPChallengeToken(token, domain string) ([]byte, error) {
s.lock.RLock()
defer s.lock.RUnlock()
if s.storedData.HTTPChallenges == nil {
s.storedData.HTTPChallenges = map[string]map[string][]byte{}
}
if _, ok := s.storedData.HTTPChallenges[token]; !ok {
return nil, fmt.Errorf("cannot find challenge for token %v", token)
}
result, ok := s.storedData.HTTPChallenges[token][domain]
if !ok {
return nil, fmt.Errorf("cannot find challenge for token %v", token)
}
return result, nil
}
// SetHTTPChallengeToken Set the http challenge token in the store
func (s *LocalStore) SetHTTPChallengeToken(token, domain string, keyAuth []byte) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.storedData.HTTPChallenges == nil {
s.storedData.HTTPChallenges = map[string]map[string][]byte{}
}
if _, ok := s.storedData.HTTPChallenges[token]; !ok {
s.storedData.HTTPChallenges[token] = map[string][]byte{}
}
s.storedData.HTTPChallenges[token][domain] = keyAuth
return nil
}
// RemoveHTTPChallengeToken Remove the http challenge token in the store
func (s *LocalStore) RemoveHTTPChallengeToken(token, domain string) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.storedData.HTTPChallenges == nil {
return nil
}
if _, ok := s.storedData.HTTPChallenges[token]; ok {
if _, domainOk := s.storedData.HTTPChallenges[token][domain]; domainOk {
delete(s.storedData.HTTPChallenges[token], domain)
}
if len(s.storedData.HTTPChallenges[token]) == 0 {
delete(s.storedData.HTTPChallenges, token)
}
}
return nil
}
// AddTLSChallenge Add a certificate to the ACME TLS-ALPN-01 certificates storage
func (s *LocalStore) AddTLSChallenge(domain string, cert *Certificate) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.storedData.TLSChallenges == nil {
s.storedData.TLSChallenges = make(map[string]*Certificate)
}
s.storedData.TLSChallenges[domain] = cert
return nil
}
// GetTLSChallenge Get a certificate from the ACME TLS-ALPN-01 certificates storage
func (s *LocalStore) GetTLSChallenge(domain string) (*Certificate, error) {
s.lock.Lock()
defer s.lock.Unlock()
if s.storedData.TLSChallenges == nil {
s.storedData.TLSChallenges = make(map[string]*Certificate)
}
return s.storedData.TLSChallenges[domain], nil
}
// RemoveTLSChallenge Remove a certificate from the ACME TLS-ALPN-01 certificates storage
func (s *LocalStore) RemoveTLSChallenge(domain string) error {
s.lock.Lock()
defer s.lock.Unlock()
if s.storedData.TLSChallenges == nil {
return nil
}
delete(s.storedData.TLSChallenges, domain)
return nil
}

View file

@ -0,0 +1,35 @@
// +build !windows
package acme
import (
"fmt"
"os"
)
// CheckFile checks file permissions and content size
func CheckFile(name string) (bool, error) {
f, err := os.Open(name)
if err != nil {
if os.IsNotExist(err) {
f, err = os.Create(name)
if err != nil {
return false, err
}
return false, f.Chmod(0600)
}
return false, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return false, err
}
if fi.Mode().Perm()&0077 != 0 {
return false, fmt.Errorf("permissions %o for %s are too open, please use 600", fi.Mode().Perm(), name)
}
return fi.Size() > 0, nil
}

View file

@ -0,0 +1,27 @@
package acme
import "os"
// CheckFile checks file content size
// Do not check file permissions on Windows right now
func CheckFile(name string) (bool, error) {
f, err := os.Open(name)
if err != nil {
if os.IsNotExist(err) {
f, err = os.Create(name)
if err != nil {
return false, err
}
return false, f.Chmod(0600)
}
return false, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return false, err
}
return fi.Size() > 0, nil
}

View file

@ -0,0 +1,826 @@
package acme
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
fmtlog "log"
"net"
"net/url"
"reflect"
"strings"
"sync"
"time"
"github.com/cenk/backoff"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/rules"
"github.com/containous/traefik/safe"
traefiktls "github.com/containous/traefik/tls"
"github.com/containous/traefik/version"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
"github.com/xenolf/lego/acme"
legolog "github.com/xenolf/lego/log"
"github.com/xenolf/lego/providers/dns"
)
var (
// OSCPMustStaple enables OSCP stapling as from https://github.com/xenolf/lego/issues/270
OSCPMustStaple = false
)
// Configuration holds ACME configuration provided by users
type Configuration struct {
Email string `description:"Email address used for registration"`
ACMELogging bool `description:"Enable debug logging of ACME actions."`
CAServer string `description:"CA server to use."`
Storage string `description:"Storage to use."`
EntryPoint string `description:"EntryPoint to use."`
KeyType string `description:"KeyType used for generating certificate private key. Allow value 'EC256', 'EC384', 'RSA2048', 'RSA4096', 'RSA8192'. Default to 'RSA4096'"`
OnHostRule bool `description:"Enable certificate generation on frontends Host rules."`
OnDemand bool `description:"Enable on demand certificate generation. This will request a certificate from Let's Encrypt during the first TLS handshake for a hostname that does not yet have a certificate."` // Deprecated
DNSChallenge *DNSChallenge `description:"Activate DNS-01 Challenge"`
HTTPChallenge *HTTPChallenge `description:"Activate HTTP-01 Challenge"`
TLSChallenge *TLSChallenge `description:"Activate TLS-ALPN-01 Challenge"`
Domains []types.Domain `description:"CN and SANs (alternative domains) to each main domain using format: --acme.domains='main.com,san1.com,san2.com' --acme.domains='*.main.net'. No SANs for wildcards domain. Wildcard domains only accepted with DNSChallenge"`
}
// Provider holds configurations of the provider.
type Provider struct {
*Configuration
Store Store
certificates []*Certificate
account *Account
client *acme.Client
certsChan chan *Certificate
configurationChan chan<- types.ConfigMessage
certificateStore *traefiktls.CertificateStore
clientMutex sync.Mutex
configFromListenerChan chan types.Configuration
pool *safe.Pool
resolvingDomains map[string]struct{}
resolvingDomainsMutex sync.RWMutex
}
// Certificate is a struct which contains all data needed from an ACME certificate
type Certificate struct {
Domain types.Domain
Certificate []byte
Key []byte
}
// DNSChallenge contains DNS challenge Configuration
type DNSChallenge struct {
Provider string `description:"Use a DNS-01 based challenge provider rather than HTTPS."`
DelayBeforeCheck parse.Duration `description:"Assume DNS propagates after a delay in seconds rather than finding and querying nameservers."`
Resolvers types.DNSResolvers `description:"Use following DNS servers to resolve the FQDN authority."`
DisablePropagationCheck bool `description:"Disable the DNS propagation checks before notifying ACME that the DNS challenge is ready. [not recommended]"`
preCheckTimeout time.Duration
preCheckInterval time.Duration
}
// HTTPChallenge contains HTTP challenge Configuration
type HTTPChallenge struct {
EntryPoint string `description:"HTTP challenge EntryPoint"`
}
// TLSChallenge contains TLS challenge Configuration
type TLSChallenge struct{}
// SetConfigListenerChan initializes the configFromListenerChan
func (p *Provider) SetConfigListenerChan(configFromListenerChan chan types.Configuration) {
p.configFromListenerChan = configFromListenerChan
}
// SetCertificateStore allow to initialize certificate store
func (p *Provider) SetCertificateStore(certificateStore *traefiktls.CertificateStore) {
p.certificateStore = certificateStore
}
// ListenConfiguration sets a new Configuration into the configFromListenerChan
func (p *Provider) ListenConfiguration(config types.Configuration) {
p.configFromListenerChan <- config
}
// ListenRequest resolves new certificates for a domain from an incoming request and return a valid Certificate to serve (onDemand option)
func (p *Provider) ListenRequest(domain string) (*tls.Certificate, error) {
acmeCert, err := p.resolveCertificate(types.Domain{Main: domain}, false)
if acmeCert == nil || err != nil {
return nil, err
}
certificate, err := tls.X509KeyPair(acmeCert.Certificate, acmeCert.PrivateKey)
return &certificate, err
}
// Init for compatibility reason the BaseProvider implements an empty Init
func (p *Provider) Init(_ types.Constraints) error {
acme.UserAgent = fmt.Sprintf("containous-traefik/%s", version.Version)
if p.ACMELogging {
legolog.Logger = fmtlog.New(log.WriterLevel(logrus.InfoLevel), "legolog: ", 0)
} else {
legolog.Logger = fmtlog.New(ioutil.Discard, "", 0)
}
if p.Store == nil {
return errors.New("no store found for the ACME provider")
}
var err error
p.account, err = p.Store.GetAccount()
if err != nil {
return fmt.Errorf("unable to get ACME account : %v", err)
}
// Reset Account if caServer changed, thus registration URI can be updated
if p.account != nil && p.account.Registration != nil && !isAccountMatchingCaServer(p.account.Registration.URI, p.CAServer) {
log.Info("Account URI does not match the current CAServer. The account will be reset")
p.account = nil
}
p.certificates, err = p.Store.GetCertificates()
if err != nil {
return fmt.Errorf("unable to get ACME certificates : %v", err)
}
// Init the currently resolved domain map
p.resolvingDomains = make(map[string]struct{})
return nil
}
func isAccountMatchingCaServer(accountURI string, serverURI string) bool {
aru, err := url.Parse(accountURI)
if err != nil {
log.Infof("Unable to parse account.Registration URL : %v", err)
return false
}
cau, err := url.Parse(serverURI)
if err != nil {
log.Infof("Unable to parse CAServer URL : %v", err)
return false
}
return cau.Hostname() == aru.Hostname()
}
// Provide allows the file provider to provide configurations to traefik
// using the given Configuration channel.
func (p *Provider) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
p.pool = pool
p.watchCertificate()
p.watchNewDomains()
p.configurationChan = configurationChan
p.refreshCertificates()
p.deleteUnnecessaryDomains()
for i := 0; i < len(p.Domains); i++ {
domain := p.Domains[i]
safe.Go(func() {
if _, err := p.resolveCertificate(domain, true); err != nil {
log.Errorf("Unable to obtain ACME certificate for domains %q : %v", strings.Join(domain.ToStrArray(), ","), err)
}
})
}
p.renewCertificates()
ticker := time.NewTicker(24 * time.Hour)
pool.Go(func(stop chan bool) {
for {
select {
case <-ticker.C:
p.renewCertificates()
case <-stop:
ticker.Stop()
return
}
}
})
return nil
}
func (p *Provider) getClient() (*acme.Client, error) {
p.clientMutex.Lock()
defer p.clientMutex.Unlock()
if p.client != nil {
return p.client, nil
}
account, err := p.initAccount()
if err != nil {
return nil, err
}
log.Debug("Building ACME client...")
caServer := "https://acme-v02.api.letsencrypt.org/directory"
if len(p.CAServer) > 0 {
caServer = p.CAServer
}
log.Debug(caServer)
client, err := acme.NewClient(caServer, account, account.KeyType)
if err != nil {
return nil, err
}
// New users will need to register; be sure to save it
if account.GetRegistration() == nil {
log.Info("Register...")
reg, err := client.Register(true)
if err != nil {
return nil, err
}
account.Registration = reg
}
// Save the account once before all the certificates generation/storing
// No certificate can be generated if account is not initialized
err = p.Store.SaveAccount(account)
if err != nil {
return nil, err
}
if p.DNSChallenge != nil && len(p.DNSChallenge.Provider) > 0 {
log.Debugf("Using DNS Challenge provider: %s", p.DNSChallenge.Provider)
SetRecursiveNameServers(p.DNSChallenge.Resolvers)
SetPropagationCheck(p.DNSChallenge.DisablePropagationCheck)
err = dnsOverrideDelay(p.DNSChallenge.DelayBeforeCheck)
if err != nil {
return nil, err
}
var provider acme.ChallengeProvider
provider, err = dns.NewDNSChallengeProviderByName(p.DNSChallenge.Provider)
if err != nil {
return nil, err
}
client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSALPN01})
err = client.SetChallengeProvider(acme.DNS01, provider)
if err != nil {
return nil, err
}
// Same default values than LEGO
p.DNSChallenge.preCheckTimeout = 60 * time.Second
p.DNSChallenge.preCheckInterval = 2 * time.Second
// Set the precheck timeout into the DNSChallenge provider
if challengeProviderTimeout, ok := provider.(acme.ChallengeProviderTimeout); ok {
p.DNSChallenge.preCheckTimeout, p.DNSChallenge.preCheckInterval = challengeProviderTimeout.Timeout()
}
} else if p.HTTPChallenge != nil && len(p.HTTPChallenge.EntryPoint) > 0 {
log.Debug("Using HTTP Challenge provider.")
client.ExcludeChallenges([]acme.Challenge{acme.DNS01, acme.TLSALPN01})
err = client.SetChallengeProvider(acme.HTTP01, &challengeHTTP{Store: p.Store})
if err != nil {
return nil, err
}
} else if p.TLSChallenge != nil {
log.Debug("Using TLS Challenge provider.")
client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.DNS01})
err = client.SetChallengeProvider(acme.TLSALPN01, &challengeTLSALPN{Store: p.Store})
if err != nil {
return nil, err
}
} else {
return nil, errors.New("ACME challenge not specified, please select TLS or HTTP or DNS Challenge")
}
p.client = client
return p.client, nil
}
func (p *Provider) initAccount() (*Account, error) {
if p.account == nil || len(p.account.Email) == 0 {
var err error
p.account, err = NewAccount(p.Email, p.KeyType)
if err != nil {
return nil, err
}
}
// Set the KeyType if not already defined in the account
if len(p.account.KeyType) == 0 {
p.account.KeyType = GetKeyType(p.KeyType)
}
return p.account, nil
}
func contains(entryPoints []string, acmeEntryPoint string) bool {
for _, entryPoint := range entryPoints {
if entryPoint == acmeEntryPoint {
return true
}
}
return false
}
func (p *Provider) watchNewDomains() {
p.pool.Go(func(stop chan bool) {
for {
select {
case config := <-p.configFromListenerChan:
for _, frontend := range config.Frontends {
if !contains(frontend.EntryPoints, p.EntryPoint) {
continue
}
for _, route := range frontend.Routes {
domainRules := rules.Rules{}
domains, err := domainRules.ParseDomains(route.Rule)
if err != nil {
log.Errorf("Error parsing domains in provider ACME: %v", err)
continue
}
if len(domains) == 0 {
log.Debugf("No domain parsed in rule %q in provider ACME", route.Rule)
continue
}
log.Debugf("Try to challenge certificate for domain %v founded in Host rule", domains)
var domain types.Domain
if len(domains) > 0 {
domain = types.Domain{Main: domains[0]}
if len(domains) > 1 {
domain.SANs = domains[1:]
}
safe.Go(func() {
if _, err := p.resolveCertificate(domain, false); err != nil {
log.Errorf("Unable to obtain ACME certificate for domains %q detected thanks to rule %q : %v", strings.Join(domains, ","), route.Rule, err)
}
})
}
}
}
case <-stop:
return
}
}
})
}
func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurationFile bool) (*acme.CertificateResource, error) {
domains, err := p.getValidDomains(domain, domainFromConfigurationFile)
if err != nil {
return nil, err
}
// Check provided certificates
uncheckedDomains := p.getUncheckedDomains(domains, !domainFromConfigurationFile)
if len(uncheckedDomains) == 0 {
return nil, nil
}
p.addResolvingDomains(uncheckedDomains)
defer p.removeResolvingDomains(uncheckedDomains)
log.Debugf("Loading ACME certificates %+v...", uncheckedDomains)
client, err := p.getClient()
if err != nil {
return nil, fmt.Errorf("cannot get ACME client %v", err)
}
var certificate *acme.CertificateResource
bundle := true
if p.useCertificateWithRetry(uncheckedDomains) {
certificate, err = obtainCertificateWithRetry(domains, client, p.DNSChallenge.preCheckTimeout, p.DNSChallenge.preCheckInterval, bundle)
} else {
certificate, err = client.ObtainCertificate(domains, bundle, nil, OSCPMustStaple)
}
if err != nil {
return nil, fmt.Errorf("unable to generate a certificate for the domains %v: %v", uncheckedDomains, err)
}
if certificate == nil {
return nil, fmt.Errorf("domains %v do not generate a certificate", uncheckedDomains)
}
if len(certificate.Certificate) == 0 || len(certificate.PrivateKey) == 0 {
return nil, fmt.Errorf("domains %v generate certificate with no value: %v", uncheckedDomains, certificate)
}
log.Debugf("Certificates obtained for domains %+v", uncheckedDomains)
if len(uncheckedDomains) > 1 {
domain = types.Domain{Main: uncheckedDomains[0], SANs: uncheckedDomains[1:]}
} else {
domain = types.Domain{Main: uncheckedDomains[0]}
}
p.addCertificateForDomain(domain, certificate.Certificate, certificate.PrivateKey)
return certificate, nil
}
func (p *Provider) removeResolvingDomains(resolvingDomains []string) {
p.resolvingDomainsMutex.Lock()
defer p.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
delete(p.resolvingDomains, domain)
}
}
func (p *Provider) addResolvingDomains(resolvingDomains []string) {
p.resolvingDomainsMutex.Lock()
defer p.resolvingDomainsMutex.Unlock()
for _, domain := range resolvingDomains {
p.resolvingDomains[domain] = struct{}{}
}
}
func (p *Provider) useCertificateWithRetry(domains []string) bool {
// Check if we can use the retry mechanism only if we use the DNS Challenge and if is there are at least 2 domains to check
if p.DNSChallenge != nil && len(domains) > 1 {
rootDomain := ""
for _, searchWildcardDomain := range domains {
// Search a wildcard domain if not already found
if len(rootDomain) == 0 && strings.HasPrefix(searchWildcardDomain, "*.") {
rootDomain = strings.TrimPrefix(searchWildcardDomain, "*.")
if len(rootDomain) > 0 {
// Look for a root domain which matches the wildcard domain
for _, searchRootDomain := range domains {
if rootDomain == searchRootDomain {
// If the domains list contains a wildcard domain and its root domain, we can use the retry mechanism to obtain the certificate
return true
}
}
}
// There is only one wildcard domain in the slice, if its root domain has not been found, the retry mechanism does not have to be used
return false
}
}
}
return false
}
func obtainCertificateWithRetry(domains []string, client *acme.Client, timeout, interval time.Duration, bundle bool) (*acme.CertificateResource, error) {
var certificate *acme.CertificateResource
var err error
operation := func() error {
certificate, err = client.ObtainCertificate(domains, bundle, nil, OSCPMustStaple)
return err
}
notify := func(err error, time time.Duration) {
log.Errorf("Error obtaining certificate retrying in %s", time)
}
// Define a retry backOff to let LEGO tries twice to obtain a certificate for both wildcard and root domain
ebo := backoff.NewExponentialBackOff()
ebo.MaxElapsedTime = 2 * timeout
ebo.MaxInterval = interval
rbo := backoff.WithMaxRetries(ebo, 2)
err = backoff.RetryNotify(safe.OperationWithRecover(operation), rbo, notify)
if err != nil {
log.Errorf("Error obtaining certificate: %v", err)
return nil, err
}
return certificate, nil
}
func dnsOverrideDelay(delay parse.Duration) error {
if delay == 0 {
return nil
}
if delay > 0 {
log.Debugf("Delaying %d rather than validating DNS propagation now.", delay)
acme.PreCheckDNS = func(_, _ string) (bool, error) {
time.Sleep(time.Duration(delay))
return true, nil
}
} else {
return fmt.Errorf("delayBeforeCheck: %d cannot be less than 0", delay)
}
return nil
}
func (p *Provider) addCertificateForDomain(domain types.Domain, certificate []byte, key []byte) {
p.certsChan <- &Certificate{Certificate: certificate, Key: key, Domain: domain}
}
// deleteUnnecessaryDomains deletes from the configuration :
// - Duplicated domains
// - Domains which are checked by wildcard domain
func (p *Provider) deleteUnnecessaryDomains() {
var newDomains []types.Domain
for idxDomainToCheck, domainToCheck := range p.Domains {
keepDomain := true
for idxDomain, domain := range p.Domains {
if idxDomainToCheck == idxDomain {
continue
}
if reflect.DeepEqual(domain, domainToCheck) {
if idxDomainToCheck > idxDomain {
log.Warnf("The domain %v is duplicated in the configuration but will be process by ACME provider only once.", domainToCheck)
keepDomain = false
}
break
}
// Check if CN or SANS to check already exists
// or can not be checked by a wildcard
var newDomainsToCheck []string
for _, domainProcessed := range domainToCheck.ToStrArray() {
if idxDomain < idxDomainToCheck && isDomainAlreadyChecked(domainProcessed, domain.ToStrArray()) {
// The domain is duplicated in a CN
log.Warnf("Domain %q is duplicated in the configuration or validated by the domain %v. It will be processed once.", domainProcessed, domain)
continue
} else if domain.Main != domainProcessed && strings.HasPrefix(domain.Main, "*") && isDomainAlreadyChecked(domainProcessed, []string{domain.Main}) {
// Check if a wildcard can validate the domain
log.Warnf("Domain %q will not be processed by ACME provider because it is validated by the wildcard %q", domainProcessed, domain.Main)
continue
}
newDomainsToCheck = append(newDomainsToCheck, domainProcessed)
}
// Delete the domain if both Main and SANs can be validated by the wildcard domain
// otherwise keep the unchecked values
if newDomainsToCheck == nil {
keepDomain = false
break
}
domainToCheck.Set(newDomainsToCheck)
}
if keepDomain {
newDomains = append(newDomains, domainToCheck)
}
}
p.Domains = newDomains
}
func (p *Provider) watchCertificate() {
p.certsChan = make(chan *Certificate)
p.pool.Go(func(stop chan bool) {
for {
select {
case cert := <-p.certsChan:
certUpdated := false
for _, domainsCertificate := range p.certificates {
if reflect.DeepEqual(cert.Domain, domainsCertificate.Domain) {
domainsCertificate.Certificate = cert.Certificate
domainsCertificate.Key = cert.Key
certUpdated = true
break
}
}
if !certUpdated {
p.certificates = append(p.certificates, cert)
}
err := p.saveCertificates()
if err != nil {
log.Error(err)
}
case <-stop:
return
}
}
})
}
func (p *Provider) saveCertificates() error {
err := p.Store.SaveCertificates(p.certificates)
p.refreshCertificates()
return err
}
func (p *Provider) refreshCertificates() {
config := types.ConfigMessage{
ProviderName: "ACME",
Configuration: &types.Configuration{
Backends: map[string]*types.Backend{},
Frontends: map[string]*types.Frontend{},
TLS: []*traefiktls.Configuration{},
},
}
for _, cert := range p.certificates {
certificate := &traefiktls.Certificate{CertFile: traefiktls.FileOrContent(cert.Certificate), KeyFile: traefiktls.FileOrContent(cert.Key)}
config.Configuration.TLS = append(config.Configuration.TLS, &traefiktls.Configuration{Certificate: certificate, EntryPoints: []string{p.EntryPoint}})
}
p.configurationChan <- config
}
func (p *Provider) renewCertificates() {
log.Info("Testing certificate renew...")
for _, certificate := range p.certificates {
crt, err := getX509Certificate(certificate)
// If there's an error, we assume the cert is broken, and needs update
// <= 30 days left, renew certificate
if err != nil || crt == nil || crt.NotAfter.Before(time.Now().Add(24*30*time.Hour)) {
client, err := p.getClient()
if err != nil {
log.Infof("Error renewing certificate from LE : %+v, %v", certificate.Domain, err)
continue
}
log.Infof("Renewing certificate from LE : %+v", certificate.Domain)
renewedCert, err := client.RenewCertificate(acme.CertificateResource{
Domain: certificate.Domain.Main,
PrivateKey: certificate.Key,
Certificate: certificate.Certificate,
}, true, OSCPMustStaple)
if err != nil {
log.Errorf("Error renewing certificate from LE: %v, %v", certificate.Domain, err)
continue
}
if len(renewedCert.Certificate) == 0 || len(renewedCert.PrivateKey) == 0 {
log.Errorf("domains %v renew certificate with no value: %v", certificate.Domain.ToStrArray(), certificate)
continue
}
p.addCertificateForDomain(certificate.Domain, renewedCert.Certificate, renewedCert.PrivateKey)
}
}
}
// Get provided certificate which check a domains list (Main and SANs)
// from static and dynamic provided certificates
func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string {
p.resolvingDomainsMutex.RLock()
defer p.resolvingDomainsMutex.RUnlock()
log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck)
allDomains := p.certificateStore.GetAllDomains()
// Get ACME certificates
for _, certificate := range p.certificates {
allDomains = append(allDomains, strings.Join(certificate.Domain.ToStrArray(), ","))
}
// Get currently resolved domains
for domain := range p.resolvingDomains {
allDomains = append(allDomains, domain)
}
// Get Configuration Domains
if checkConfigurationDomains {
for i := 0; i < len(p.Domains); i++ {
allDomains = append(allDomains, strings.Join(p.Domains[i].ToStrArray(), ","))
}
}
return searchUncheckedDomains(domainsToCheck, allDomains)
}
func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) []string {
var uncheckedDomains []string
for _, domainToCheck := range domainsToCheck {
if !isDomainAlreadyChecked(domainToCheck, existentDomains) {
uncheckedDomains = append(uncheckedDomains, domainToCheck)
}
}
if len(uncheckedDomains) == 0 {
log.Debugf("No ACME certificate generation required for domains %q.", domainsToCheck)
} else {
log.Debugf("Domains %q need ACME certificates generation for domains %q.", domainsToCheck, strings.Join(uncheckedDomains, ","))
}
return uncheckedDomains
}
func getX509Certificate(certificate *Certificate) (*x509.Certificate, error) {
tlsCert, err := tls.X509KeyPair(certificate.Certificate, certificate.Key)
if err != nil {
log.Errorf("Failed to load TLS keypair from ACME certificate for domain %q (SAN : %q), certificate will be renewed : %v", certificate.Domain.Main, strings.Join(certificate.Domain.SANs, ","), err)
return nil, err
}
crt := tlsCert.Leaf
if crt == nil {
crt, err = x509.ParseCertificate(tlsCert.Certificate[0])
if err != nil {
log.Errorf("Failed to parse TLS keypair from ACME certificate for domain %q (SAN : %q), certificate will be renewed : %v", certificate.Domain.Main, strings.Join(certificate.Domain.SANs, ","), err)
}
}
return crt, err
}
// getValidDomains checks if given domain is allowed to generate a ACME certificate and return it
func (p *Provider) getValidDomains(domain types.Domain, wildcardAllowed bool) ([]string, error) {
domains := domain.ToStrArray()
if len(domains) == 0 {
return nil, errors.New("unable to generate a certificate in ACME provider when no domain is given")
}
if strings.HasPrefix(domain.Main, "*") {
if !wildcardAllowed {
return nil, fmt.Errorf("unable to generate a wildcard certificate in ACME provider for domain %q from a 'Host' rule", strings.Join(domains, ","))
}
if p.DNSChallenge == nil {
return nil, fmt.Errorf("unable to generate a wildcard certificate in ACME provider for domain %q : ACME needs a DNSChallenge", strings.Join(domains, ","))
}
if strings.HasPrefix(domain.Main, "*.*") {
return nil, fmt.Errorf("unable to generate a wildcard certificate in ACME provider for domain %q : ACME does not allow '*.*' wildcard domain", strings.Join(domains, ","))
}
}
for _, san := range domain.SANs {
if strings.HasPrefix(san, "*") {
return nil, fmt.Errorf("unable to generate a certificate in ACME provider for domains %q: SAN %q can not be a wildcard domain", strings.Join(domains, ","), san)
}
}
var cleanDomains []string
for _, domain := range domains {
canonicalDomain := types.CanonicalDomain(domain)
cleanDomain := acme.UnFqdn(canonicalDomain)
if canonicalDomain != cleanDomain {
log.Warnf("FQDN detected, please remove the trailing dot: %s", canonicalDomain)
}
cleanDomains = append(cleanDomains, cleanDomain)
}
return cleanDomains, nil
}
func isDomainAlreadyChecked(domainToCheck string, existentDomains []string) bool {
for _, certDomains := range existentDomains {
for _, certDomain := range strings.Split(certDomains, ",") {
if types.MatchDomain(domainToCheck, certDomain) {
return true
}
}
}
return false
}
// SetPropagationCheck to disable the Lego PreCheck.
func SetPropagationCheck(disable bool) {
if disable {
acme.PreCheckDNS = func(_, _ string) (bool, error) {
return true, nil
}
}
}
// SetRecursiveNameServers to provide a custom DNS resolver.
func SetRecursiveNameServers(dnsResolvers []string) {
resolvers := normaliseDNSResolvers(dnsResolvers)
if len(resolvers) > 0 {
acme.RecursiveNameservers = resolvers
log.Infof("Validating FQDN authority with DNS using %+v", resolvers)
}
}
// ensure all servers have a port number
func normaliseDNSResolvers(dnsResolvers []string) []string {
var normalisedResolvers []string
for _, server := range dnsResolvers {
srv := strings.TrimSpace(server)
if len(srv) > 0 {
if host, port, err := net.SplitHostPort(srv); err != nil {
normalisedResolvers = append(normalisedResolvers, net.JoinHostPort(srv, "53"))
} else {
normalisedResolvers = append(normalisedResolvers, net.JoinHostPort(host, port))
}
}
}
return normalisedResolvers
}

View file

@ -0,0 +1,684 @@
package acme
import (
"crypto/tls"
"testing"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/safe"
traefiktls "github.com/containous/traefik/tls"
"github.com/stretchr/testify/assert"
"github.com/xenolf/lego/acme"
)
func TestGetUncheckedCertificates(t *testing.T) {
wildcardMap := make(map[string]*tls.Certificate)
wildcardMap["*.traefik.wtf"] = &tls.Certificate{}
wildcardSafe := &safe.Safe{}
wildcardSafe.Set(wildcardMap)
domainMap := make(map[string]*tls.Certificate)
domainMap["traefik.wtf"] = &tls.Certificate{}
domainSafe := &safe.Safe{}
domainSafe.Set(domainMap)
testCases := []struct {
desc string
dynamicCerts *safe.Safe
staticCerts *safe.Safe
resolvingDomains map[string]struct{}
acmeCertificates []*Certificate
domains []string
expectedDomains []string
}{
{
desc: "wildcard to generate",
domains: []string{"*.traefik.wtf"},
expectedDomains: []string{"*.traefik.wtf"},
},
{
desc: "wildcard already exists in dynamic certificates",
domains: []string{"*.traefik.wtf"},
dynamicCerts: wildcardSafe,
expectedDomains: nil,
},
{
desc: "wildcard already exists in static certificates",
domains: []string{"*.traefik.wtf"},
staticCerts: wildcardSafe,
expectedDomains: nil,
},
{
desc: "wildcard already exists in ACME certificates",
domains: []string{"*.traefik.wtf"},
acmeCertificates: []*Certificate{
{
Domain: types.Domain{Main: "*.traefik.wtf"},
},
},
expectedDomains: nil,
},
{
desc: "domain CN and SANs to generate",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
expectedDomains: []string{"traefik.wtf", "foo.traefik.wtf"},
},
{
desc: "domain CN already exists in dynamic certificates and SANs to generate",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
dynamicCerts: domainSafe,
expectedDomains: []string{"foo.traefik.wtf"},
},
{
desc: "domain CN already exists in static certificates and SANs to generate",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
staticCerts: domainSafe,
expectedDomains: []string{"foo.traefik.wtf"},
},
{
desc: "domain CN already exists in ACME certificates and SANs to generate",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
acmeCertificates: []*Certificate{
{
Domain: types.Domain{Main: "traefik.wtf"},
},
},
expectedDomains: []string{"foo.traefik.wtf"},
},
{
desc: "domain already exists in dynamic certificates",
domains: []string{"traefik.wtf"},
dynamicCerts: domainSafe,
expectedDomains: nil,
},
{
desc: "domain already exists in static certificates",
domains: []string{"traefik.wtf"},
staticCerts: domainSafe,
expectedDomains: nil,
},
{
desc: "domain already exists in ACME certificates",
domains: []string{"traefik.wtf"},
acmeCertificates: []*Certificate{
{
Domain: types.Domain{Main: "traefik.wtf"},
},
},
expectedDomains: nil,
},
{
desc: "domain matched by wildcard in dynamic certificates",
domains: []string{"who.traefik.wtf", "foo.traefik.wtf"},
dynamicCerts: wildcardSafe,
expectedDomains: nil,
},
{
desc: "domain matched by wildcard in static certificates",
domains: []string{"who.traefik.wtf", "foo.traefik.wtf"},
staticCerts: wildcardSafe,
expectedDomains: nil,
},
{
desc: "domain matched by wildcard in ACME certificates",
domains: []string{"who.traefik.wtf", "foo.traefik.wtf"},
acmeCertificates: []*Certificate{
{
Domain: types.Domain{Main: "*.traefik.wtf"},
},
},
expectedDomains: nil,
},
{
desc: "root domain with wildcard in ACME certificates",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
acmeCertificates: []*Certificate{
{
Domain: types.Domain{Main: "*.traefik.wtf"},
},
},
expectedDomains: []string{"traefik.wtf"},
},
{
desc: "all domains already managed by ACME",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
resolvingDomains: map[string]struct{}{
"traefik.wtf": {},
"foo.traefik.wtf": {},
},
expectedDomains: []string{},
},
{
desc: "one domain already managed by ACME",
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
resolvingDomains: map[string]struct{}{
"traefik.wtf": {},
},
expectedDomains: []string{"foo.traefik.wtf"},
},
{
desc: "wildcard domain already managed by ACME checks the domains",
domains: []string{"bar.traefik.wtf", "foo.traefik.wtf"},
resolvingDomains: map[string]struct{}{
"*.traefik.wtf": {},
},
expectedDomains: []string{},
},
{
desc: "wildcard domain already managed by ACME checks domains and another domain checks one other domain, one domain still unchecked",
domains: []string{"traefik.wtf", "bar.traefik.wtf", "foo.traefik.wtf", "acme.wtf"},
resolvingDomains: map[string]struct{}{
"*.traefik.wtf": {},
"traefik.wtf": {},
},
expectedDomains: []string{"acme.wtf"},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
if test.resolvingDomains == nil {
test.resolvingDomains = make(map[string]struct{})
}
acmeProvider := Provider{
certificateStore: &traefiktls.CertificateStore{
DynamicCerts: test.dynamicCerts,
StaticCerts: test.staticCerts,
},
certificates: test.acmeCertificates,
resolvingDomains: test.resolvingDomains,
}
domains := acmeProvider.getUncheckedDomains(test.domains, false)
assert.Equal(t, len(test.expectedDomains), len(domains), "Unexpected domains.")
})
}
}
func TestGetValidDomain(t *testing.T) {
testCases := []struct {
desc string
domains types.Domain
wildcardAllowed bool
dnsChallenge *DNSChallenge
expectedErr string
expectedDomains []string
}{
{
desc: "valid wildcard",
domains: types.Domain{Main: "*.traefik.wtf"},
dnsChallenge: &DNSChallenge{},
wildcardAllowed: true,
expectedErr: "",
expectedDomains: []string{"*.traefik.wtf"},
},
{
desc: "no wildcard",
domains: types.Domain{Main: "traefik.wtf", SANs: []string{"foo.traefik.wtf"}},
dnsChallenge: &DNSChallenge{},
expectedErr: "",
wildcardAllowed: true,
expectedDomains: []string{"traefik.wtf", "foo.traefik.wtf"},
},
{
desc: "unauthorized wildcard",
domains: types.Domain{Main: "*.traefik.wtf"},
dnsChallenge: &DNSChallenge{},
wildcardAllowed: false,
expectedErr: "unable to generate a wildcard certificate in ACME provider for domain \"*.traefik.wtf\" from a 'Host' rule",
expectedDomains: nil,
},
{
desc: "no domain",
domains: types.Domain{},
dnsChallenge: nil,
wildcardAllowed: true,
expectedErr: "unable to generate a certificate in ACME provider when no domain is given",
expectedDomains: nil,
},
{
desc: "no DNSChallenge",
domains: types.Domain{Main: "*.traefik.wtf", SANs: []string{"foo.traefik.wtf"}},
dnsChallenge: nil,
wildcardAllowed: true,
expectedErr: "unable to generate a wildcard certificate in ACME provider for domain \"*.traefik.wtf,foo.traefik.wtf\" : ACME needs a DNSChallenge",
expectedDomains: nil,
},
{
desc: "unauthorized wildcard with SAN",
domains: types.Domain{Main: "*.*.traefik.wtf", SANs: []string{"foo.traefik.wtf"}},
dnsChallenge: &DNSChallenge{},
wildcardAllowed: true,
expectedErr: "unable to generate a wildcard certificate in ACME provider for domain \"*.*.traefik.wtf,foo.traefik.wtf\" : ACME does not allow '*.*' wildcard domain",
expectedDomains: nil,
},
{
desc: "wildcard and SANs",
domains: types.Domain{Main: "*.traefik.wtf", SANs: []string{"traefik.wtf"}},
dnsChallenge: &DNSChallenge{},
wildcardAllowed: true,
expectedErr: "",
expectedDomains: []string{"*.traefik.wtf", "traefik.wtf"},
},
{
desc: "unexpected SANs",
domains: types.Domain{Main: "*.traefik.wtf", SANs: []string{"*.acme.wtf"}},
dnsChallenge: &DNSChallenge{},
wildcardAllowed: true,
expectedErr: "unable to generate a certificate in ACME provider for domains \"*.traefik.wtf,*.acme.wtf\": SAN \"*.acme.wtf\" can not be a wildcard domain",
expectedDomains: nil,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
acmeProvider := Provider{Configuration: &Configuration{DNSChallenge: test.dnsChallenge}}
domains, err := acmeProvider.getValidDomains(test.domains, test.wildcardAllowed)
if len(test.expectedErr) > 0 {
assert.EqualError(t, err, test.expectedErr, "Unexpected error.")
} else {
assert.Equal(t, len(test.expectedDomains), len(domains), "Unexpected domains.")
}
})
}
}
func TestDeleteUnnecessaryDomains(t *testing.T) {
testCases := []struct {
desc string
domains []types.Domain
expectedDomains []types.Domain
}{
{
desc: "no domain to delete",
domains: []types.Domain{
{
Main: "acme.wtf",
SANs: []string{"traefik.acme.wtf", "foo.bar"},
},
{
Main: "*.foo.acme.wtf",
},
{
Main: "acme02.wtf",
SANs: []string{"traefik.acme02.wtf", "bar.foo"},
},
},
expectedDomains: []types.Domain{
{
Main: "acme.wtf",
SANs: []string{"traefik.acme.wtf", "foo.bar"},
},
{
Main: "*.foo.acme.wtf",
SANs: []string{},
},
{
Main: "acme02.wtf",
SANs: []string{"traefik.acme02.wtf", "bar.foo"},
},
},
},
{
desc: "wildcard and root domain",
domains: []types.Domain{
{
Main: "acme.wtf",
},
{
Main: "*.acme.wtf",
SANs: []string{"acme.wtf"},
},
},
expectedDomains: []types.Domain{
{
Main: "acme.wtf",
SANs: []string{},
},
{
Main: "*.acme.wtf",
SANs: []string{},
},
},
},
{
desc: "2 equals domains",
domains: []types.Domain{
{
Main: "acme.wtf",
SANs: []string{"traefik.acme.wtf", "foo.bar"},
},
{
Main: "acme.wtf",
SANs: []string{"traefik.acme.wtf", "foo.bar"},
},
},
expectedDomains: []types.Domain{
{
Main: "acme.wtf",
SANs: []string{"traefik.acme.wtf", "foo.bar"},
},
},
},
{
desc: "2 domains with same values",
domains: []types.Domain{
{
Main: "acme.wtf",
SANs: []string{"traefik.acme.wtf"},
},
{
Main: "acme.wtf",
SANs: []string{"traefik.acme.wtf", "foo.bar"},
},
},
expectedDomains: []types.Domain{
{
Main: "acme.wtf",
SANs: []string{"traefik.acme.wtf"},
},
{
Main: "foo.bar",
SANs: []string{},
},
},
},
{
desc: "domain totally checked by wildcard",
domains: []types.Domain{
{
Main: "who.acme.wtf",
SANs: []string{"traefik.acme.wtf", "bar.acme.wtf"},
},
{
Main: "*.acme.wtf",
},
},
expectedDomains: []types.Domain{
{
Main: "*.acme.wtf",
SANs: []string{},
},
},
},
{
desc: "duplicated wildcard",
domains: []types.Domain{
{
Main: "*.acme.wtf",
SANs: []string{"acme.wtf"},
},
{
Main: "*.acme.wtf",
},
},
expectedDomains: []types.Domain{
{
Main: "*.acme.wtf",
SANs: []string{"acme.wtf"},
},
},
},
{
desc: "domain partially checked by wildcard",
domains: []types.Domain{
{
Main: "traefik.acme.wtf",
SANs: []string{"acme.wtf", "foo.bar"},
},
{
Main: "*.acme.wtf",
},
{
Main: "who.acme.wtf",
SANs: []string{"traefik.acme.wtf", "bar.acme.wtf"},
},
},
expectedDomains: []types.Domain{
{
Main: "acme.wtf",
SANs: []string{"foo.bar"},
},
{
Main: "*.acme.wtf",
SANs: []string{},
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
acmeProvider := Provider{Configuration: &Configuration{Domains: test.domains}}
acmeProvider.deleteUnnecessaryDomains()
assert.Equal(t, test.expectedDomains, acmeProvider.Domains, "unexpected domain")
})
}
}
func TestIsAccountMatchingCaServer(t *testing.T) {
testCases := []struct {
desc string
accountURI string
serverURI string
expected bool
}{
{
desc: "acme staging with matching account",
accountURI: "https://acme-staging-v02.api.letsencrypt.org/acme/acct/1234567",
serverURI: "https://acme-staging-v02.api.letsencrypt.org/acme/directory",
expected: true,
},
{
desc: "acme production with matching account",
accountURI: "https://acme-v02.api.letsencrypt.org/acme/acct/1234567",
serverURI: "https://acme-v02.api.letsencrypt.org/acme/directory",
expected: true,
},
{
desc: "http only acme with matching account",
accountURI: "http://acme.api.letsencrypt.org/acme/acct/1234567",
serverURI: "http://acme.api.letsencrypt.org/acme/directory",
expected: true,
},
{
desc: "different subdomains for account and server",
accountURI: "https://test1.example.org/acme/acct/1234567",
serverURI: "https://test2.example.org/acme/directory",
expected: false,
},
{
desc: "different domains for account and server",
accountURI: "https://test.example1.org/acme/acct/1234567",
serverURI: "https://test.example2.org/acme/directory",
expected: false,
},
{
desc: "different tld for account and server",
accountURI: "https://test.example.com/acme/acct/1234567",
serverURI: "https://test.example.org/acme/directory",
expected: false,
},
{
desc: "malformed account url",
accountURI: "//|\\/test.example.com/acme/acct/1234567",
serverURI: "https://test.example.com/acme/directory",
expected: false,
},
{
desc: "malformed server url",
accountURI: "https://test.example.com/acme/acct/1234567",
serverURI: "//|\\/test.example.com/acme/directory",
expected: false,
},
{
desc: "malformed server and account url",
accountURI: "//|\\/test.example.com/acme/acct/1234567",
serverURI: "//|\\/test.example.com/acme/directory",
expected: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
result := isAccountMatchingCaServer(test.accountURI, test.serverURI)
assert.Equal(t, test.expected, result)
})
}
}
func TestUseBackOffToObtainCertificate(t *testing.T) {
testCases := []struct {
desc string
domains []string
dnsChallenge *DNSChallenge
expectedResponse bool
}{
{
desc: "only one single domain",
domains: []string{"acme.wtf"},
dnsChallenge: &DNSChallenge{},
expectedResponse: false,
},
{
desc: "only one wildcard domain",
domains: []string{"*.acme.wtf"},
dnsChallenge: &DNSChallenge{},
expectedResponse: false,
},
{
desc: "wildcard domain with no root domain",
domains: []string{"*.acme.wtf", "foo.acme.wtf", "bar.acme.wtf", "foo.bar"},
dnsChallenge: &DNSChallenge{},
expectedResponse: false,
},
{
desc: "wildcard and root domain",
domains: []string{"*.acme.wtf", "foo.acme.wtf", "bar.acme.wtf", "acme.wtf"},
dnsChallenge: &DNSChallenge{},
expectedResponse: true,
},
{
desc: "wildcard and root domain but no DNS challenge",
domains: []string{"*.acme.wtf", "acme.wtf"},
dnsChallenge: nil,
expectedResponse: false,
},
{
desc: "two wildcard domains (must never happen)",
domains: []string{"*.acme.wtf", "*.bar.foo"},
dnsChallenge: nil,
expectedResponse: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
acmeProvider := Provider{Configuration: &Configuration{DNSChallenge: test.dnsChallenge}}
actualResponse := acmeProvider.useCertificateWithRetry(test.domains)
assert.Equal(t, test.expectedResponse, actualResponse, "unexpected response to use backOff")
})
}
}
func TestInitAccount(t *testing.T) {
testCases := []struct {
desc string
account *Account
email string
keyType string
expectedAccount *Account
}{
{
desc: "Existing account with all information",
account: &Account{
Email: "foo@foo.net",
KeyType: acme.EC256,
},
expectedAccount: &Account{
Email: "foo@foo.net",
KeyType: acme.EC256,
},
},
{
desc: "Account nil",
email: "foo@foo.net",
keyType: "EC256",
expectedAccount: &Account{
Email: "foo@foo.net",
KeyType: acme.EC256,
},
},
{
desc: "Existing account with no email",
account: &Account{
KeyType: acme.RSA4096,
},
email: "foo@foo.net",
keyType: "EC256",
expectedAccount: &Account{
Email: "foo@foo.net",
KeyType: acme.EC256,
},
},
{
desc: "Existing account with no key type",
account: &Account{
Email: "foo@foo.net",
},
email: "bar@foo.net",
keyType: "EC256",
expectedAccount: &Account{
Email: "foo@foo.net",
KeyType: acme.EC256,
},
},
{
desc: "Existing account and provider with no key type",
account: &Account{
Email: "foo@foo.net",
},
email: "bar@foo.net",
expectedAccount: &Account{
Email: "foo@foo.net",
KeyType: acme.RSA4096,
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
acmeProvider := Provider{account: test.account, Configuration: &Configuration{Email: test.email, KeyType: test.keyType}}
actualAccount, err := acmeProvider.initAccount()
assert.Nil(t, err, "Init account in error")
assert.Equal(t, test.expectedAccount.Email, actualAccount.Email, "unexpected email account")
assert.Equal(t, test.expectedAccount.KeyType, actualAccount.KeyType, "unexpected keyType account")
})
}
}

View file

@ -0,0 +1,25 @@
package acme
// StoredData represents the data managed by the Store
type StoredData struct {
Account *Account
Certificates []*Certificate
HTTPChallenges map[string]map[string][]byte
TLSChallenges map[string]*Certificate
}
// Store is a generic interface to represents a storage
type Store interface {
GetAccount() (*Account, error)
SaveAccount(*Account) error
GetCertificates() ([]*Certificate, error)
SaveCertificates([]*Certificate) error
GetHTTPChallengeToken(token, domain string) ([]byte, error)
SetHTTPChallengeToken(token, domain string, keyAuth []byte) error
RemoveHTTPChallengeToken(token, domain string) error
AddTLSChallenge(domain string, cert *Certificate) error
GetTLSChallenge(domain string) (*Certificate, error)
RemoveTLSChallenge(domain string) error
}

View file

@ -0,0 +1,48 @@
package boltdb
import (
"fmt"
"github.com/abronan/valkeyrie/store"
"github.com/abronan/valkeyrie/store/boltdb"
"github.com/containous/traefik/old/provider"
"github.com/containous/traefik/old/provider/kv"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/safe"
)
var _ provider.Provider = (*Provider)(nil)
// Provider holds configurations of the provider.
type Provider struct {
kv.Provider `mapstructure:",squash" export:"true"`
}
// Init the provider
func (p *Provider) Init(constraints types.Constraints) error {
err := p.Provider.Init(constraints)
if err != nil {
return err
}
store, err := p.CreateStore()
if err != nil {
return fmt.Errorf("failed to Connect to KV store: %v", err)
}
p.SetKVClient(store)
return nil
}
// Provide allows the boltdb provider to Provide configurations to traefik
// using the given configuration channel.
func (p *Provider) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
return p.Provider.Provide(configurationChan, pool)
}
// CreateStore creates the KV store
func (p *Provider) CreateStore() (store.Store, error) {
p.SetStoreType(store.BOLTDB)
boltdb.Register()
return p.Provider.CreateStore()
}

View file

@ -0,0 +1,48 @@
package consul
import (
"fmt"
"github.com/abronan/valkeyrie/store"
"github.com/abronan/valkeyrie/store/consul"
"github.com/containous/traefik/old/provider"
"github.com/containous/traefik/old/provider/kv"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/safe"
)
var _ provider.Provider = (*Provider)(nil)
// Provider holds configurations of the p.
type Provider struct {
kv.Provider `mapstructure:",squash" export:"true"`
}
// Init the provider
func (p *Provider) Init(constraints types.Constraints) error {
err := p.Provider.Init(constraints)
if err != nil {
return err
}
store, err := p.CreateStore()
if err != nil {
return fmt.Errorf("failed to Connect to KV store: %v", err)
}
p.SetKVClient(store)
return nil
}
// Provide allows the consul provider to provide configurations to traefik
// using the given configuration channel.
func (p *Provider) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
return p.Provider.Provide(configurationChan, pool)
}
// CreateStore creates the KV store
func (p *Provider) CreateStore() (store.Store, error) {
p.SetStoreType(store.CONSUL)
consul.Register()
return p.Provider.CreateStore()
}

View file

@ -0,0 +1,227 @@
package consulcatalog
import (
"bytes"
"crypto/sha1"
"encoding/base64"
"fmt"
"net"
"sort"
"strconv"
"strings"
"text/template"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/provider"
"github.com/containous/traefik/old/provider/label"
"github.com/containous/traefik/old/types"
"github.com/hashicorp/consul/api"
)
func (p *Provider) buildConfiguration(catalog []catalogUpdate) *types.Configuration {
var funcMap = template.FuncMap{
"getAttribute": p.getAttribute,
"getTag": getTag,
"hasTag": hasTag,
// Backend functions
"getNodeBackendName": getNodeBackendName,
"getServiceBackendName": getServiceBackendName,
"getBackendAddress": getBackendAddress,
"getServerName": getServerName,
"getCircuitBreaker": label.GetCircuitBreaker,
"getLoadBalancer": label.GetLoadBalancer,
"getMaxConn": label.GetMaxConn,
"getHealthCheck": label.GetHealthCheck,
"getBuffering": label.GetBuffering,
"getResponseForwarding": label.GetResponseForwarding,
"getServer": p.getServer,
// Frontend functions
"getFrontendRule": p.getFrontendRule,
"getBasicAuth": label.GetFuncSliceString(label.TraefikFrontendAuthBasic), // Deprecated
"getAuth": label.GetAuth,
"getFrontEndEntryPoints": label.GetFuncSliceString(label.TraefikFrontendEntryPoints),
"getPriority": label.GetFuncInt(label.TraefikFrontendPriority, label.DefaultFrontendPriority),
"getPassHostHeader": label.GetFuncBool(label.TraefikFrontendPassHostHeader, label.DefaultPassHostHeader),
"getPassTLSCert": label.GetFuncBool(label.TraefikFrontendPassTLSCert, label.DefaultPassTLSCert),
"getPassTLSClientCert": label.GetTLSClientCert,
"getWhiteList": label.GetWhiteList,
"getRedirect": label.GetRedirect,
"getErrorPages": label.GetErrorPages,
"getRateLimit": label.GetRateLimit,
"getHeaders": label.GetHeaders,
}
var allNodes []*api.ServiceEntry
var services []*serviceUpdate
for _, info := range catalog {
if len(info.Nodes) > 0 {
services = append(services, p.generateFrontends(info.Service)...)
allNodes = append(allNodes, info.Nodes...)
}
}
// Ensure a stable ordering of nodes so that identical configurations may be detected
sort.Sort(nodeSorter(allNodes))
templateObjects := struct {
Services []*serviceUpdate
Nodes []*api.ServiceEntry
}{
Services: services,
Nodes: allNodes,
}
configuration, err := p.GetConfiguration("templates/consul_catalog.tmpl", funcMap, templateObjects)
if err != nil {
log.WithError(err).Error("Failed to create config")
}
return configuration
}
// Specific functions
func (p *Provider) getFrontendRule(service serviceUpdate) string {
customFrontendRule := label.GetStringValue(service.TraefikLabels, label.TraefikFrontendRule, "")
if customFrontendRule == "" {
customFrontendRule = p.FrontEndRule
}
tmpl := p.frontEndRuleTemplate
tmpl, err := tmpl.Parse(customFrontendRule)
if err != nil {
log.Errorf("Failed to parse Consul Catalog custom frontend rule: %v", err)
return ""
}
templateObjects := struct {
ServiceName string
Domain string
Attributes []string
}{
ServiceName: service.ServiceName,
Domain: p.Domain,
Attributes: service.Attributes,
}
var buffer bytes.Buffer
err = tmpl.Execute(&buffer, templateObjects)
if err != nil {
log.Errorf("Failed to execute Consul Catalog custom frontend rule template: %v", err)
return ""
}
return strings.TrimSuffix(buffer.String(), ".")
}
func (p *Provider) getServer(node *api.ServiceEntry) types.Server {
scheme := p.getAttribute(label.SuffixProtocol, node.Service.Tags, label.DefaultProtocol)
address := getBackendAddress(node)
return types.Server{
URL: fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(address, strconv.Itoa(node.Service.Port))),
Weight: p.getWeight(node.Service.Tags),
}
}
func (p *Provider) setupFrontEndRuleTemplate() {
var FuncMap = template.FuncMap{
"getAttribute": p.getAttribute,
"getTag": getTag,
"hasTag": hasTag,
}
p.frontEndRuleTemplate = template.New("consul catalog frontend rule").Funcs(FuncMap)
}
// Specific functions
func getServiceBackendName(service *serviceUpdate) string {
if service.ParentServiceName != "" {
return strings.ToLower(service.ParentServiceName)
}
return strings.ToLower(service.ServiceName)
}
func getNodeBackendName(node *api.ServiceEntry) string {
return strings.ToLower(node.Service.Service)
}
func getBackendAddress(node *api.ServiceEntry) string {
if node.Service.Address != "" {
return node.Service.Address
}
return node.Node.Address
}
func getServerName(node *api.ServiceEntry, index int) string {
serviceName := node.Service.Service + node.Service.Address + strconv.Itoa(node.Service.Port)
// TODO sort tags ?
serviceName += strings.Join(node.Service.Tags, "")
hash := sha1.New()
_, err := hash.Write([]byte(serviceName))
if err != nil {
// Impossible case
log.Error(err)
} else {
serviceName = base64.URLEncoding.EncodeToString(hash.Sum(nil))
}
// unique int at the end
return provider.Normalize(node.Service.Service + "-" + strconv.Itoa(index) + "-" + serviceName)
}
func (p *Provider) getWeight(tags []string) int {
labels := tagsToNeutralLabels(tags, p.Prefix)
return label.GetIntValue(labels, p.getPrefixedName(label.SuffixWeight), label.DefaultWeight)
}
// Base functions
func (p *Provider) getAttribute(name string, tags []string, defaultValue string) string {
return getTag(p.getPrefixedName(name), tags, defaultValue)
}
func (p *Provider) getPrefixedName(name string) string {
if len(p.Prefix) > 0 && len(name) > 0 {
return p.Prefix + "." + name
}
return name
}
func hasTag(name string, tags []string) bool {
lowerName := strings.ToLower(name)
for _, tag := range tags {
lowerTag := strings.ToLower(tag)
// Given the nature of Consul tags, which could be either singular markers, or key=value pairs
if strings.HasPrefix(lowerTag, lowerName+"=") || lowerTag == lowerName {
return true
}
}
return false
}
func getTag(name string, tags []string, defaultValue string) string {
lowerName := strings.ToLower(name)
for _, tag := range tags {
lowerTag := strings.ToLower(tag)
// Given the nature of Consul tags, which could be either singular markers, or key=value pairs
if strings.HasPrefix(lowerTag, lowerName+"=") || lowerTag == lowerName {
// In case, where a tag might be a key=value, try to split it by the first '='
kv := strings.SplitN(tag, "=", 2)
// If the returned result is a key=value pair, return the 'value' component
if len(kv) == 2 {
return kv[1]
}
// If the returned result is a singular marker, return the 'key' component
return kv[0]
}
}
return defaultValue
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,617 @@
package consulcatalog
import (
"fmt"
"strconv"
"strings"
"sync"
"text/template"
"time"
"github.com/BurntSushi/ty/fun"
"github.com/cenk/backoff"
"github.com/containous/traefik/job"
"github.com/containous/traefik/old/log"
"github.com/containous/traefik/old/provider"
"github.com/containous/traefik/old/provider/label"
"github.com/containous/traefik/old/types"
"github.com/containous/traefik/safe"
"github.com/hashicorp/consul/api"
)
const (
// DefaultWatchWaitTime is the duration to wait when polling consul
DefaultWatchWaitTime = 15 * time.Second
)
var _ provider.Provider = (*Provider)(nil)
// Provider holds configurations of the Consul catalog provider.
type Provider struct {
provider.BaseProvider `mapstructure:",squash" export:"true"`
Endpoint string `description:"Consul server endpoint"`
Domain string `description:"Default domain used"`
Stale bool `description:"Use stale consistency for catalog reads" export:"true"`
ExposedByDefault bool `description:"Expose Consul services by default" export:"true"`
Prefix string `description:"Prefix used for Consul catalog tags" export:"true"`
FrontEndRule string `description:"Frontend rule used for Consul services" export:"true"`
TLS *types.ClientTLS `description:"Enable TLS support" export:"true"`
client *api.Client
frontEndRuleTemplate *template.Template
}
// Service represent a Consul service.
type Service struct {
Name string
Tags []string
Nodes []string
Addresses []string
Ports []int
}
type serviceUpdate struct {
ServiceName string
ParentServiceName string
Attributes []string
TraefikLabels map[string]string
}
type frontendSegment struct {
Name string
Labels map[string]string
}
type catalogUpdate struct {
Service *serviceUpdate
Nodes []*api.ServiceEntry
}
type nodeSorter []*api.ServiceEntry
func (a nodeSorter) Len() int {
return len(a)
}
func (a nodeSorter) Swap(i int, j int) {
a[i], a[j] = a[j], a[i]
}
func (a nodeSorter) Less(i int, j int) bool {
lEntry := a[i]
rEntry := a[j]
ls := strings.ToLower(lEntry.Service.Service)
lr := strings.ToLower(rEntry.Service.Service)
if ls != lr {
return ls < lr
}
if lEntry.Service.Address != rEntry.Service.Address {
return lEntry.Service.Address < rEntry.Service.Address
}
if lEntry.Node.Address != rEntry.Node.Address {
return lEntry.Node.Address < rEntry.Node.Address
}
return lEntry.Service.Port < rEntry.Service.Port
}
// Init the provider
func (p *Provider) Init(constraints types.Constraints) error {
err := p.BaseProvider.Init(constraints)
if err != nil {
return err
}
client, err := p.createClient()
if err != nil {
return err
}
p.client = client
p.setupFrontEndRuleTemplate()
return nil
}
// Provide allows the consul catalog provider to provide configurations to traefik
// using the given configuration channel.
func (p *Provider) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
pool.Go(func(stop chan bool) {
notify := func(err error, time time.Duration) {
log.Errorf("Consul connection error %+v, retrying in %s", err, time)
}
operation := func() error {
return p.watch(configurationChan, stop)
}
errRetry := backoff.RetryNotify(safe.OperationWithRecover(operation), job.NewBackOff(backoff.NewExponentialBackOff()), notify)
if errRetry != nil {
log.Errorf("Cannot connect to consul server %+v", errRetry)
}
})
return nil
}
func (p *Provider) createClient() (*api.Client, error) {
config := api.DefaultConfig()
config.Address = p.Endpoint
if p.TLS != nil {
tlsConfig, err := p.TLS.CreateTLSConfig()
if err != nil {
return nil, err
}
config.Scheme = "https"
config.Transport.TLSClientConfig = tlsConfig
}
client, err := api.NewClient(config)
if err != nil {
return nil, err
}
return client, nil
}
func (p *Provider) watch(configurationChan chan<- types.ConfigMessage, stop chan bool) error {
stopCh := make(chan struct{})
watchCh := make(chan map[string][]string)
errorCh := make(chan error)
var errorOnce sync.Once
notifyError := func(err error) {
errorOnce.Do(func() {
errorCh <- err
})
}
p.watchHealthState(stopCh, watchCh, notifyError)
p.watchCatalogServices(stopCh, watchCh, notifyError)
defer close(stopCh)
defer close(watchCh)
safe.Go(func() {
for index := range watchCh {
log.Debug("List of services changed")
nodes, err := p.getNodes(index)
if err != nil {
notifyError(err)
}
configuration := p.buildConfiguration(nodes)
configurationChan <- types.ConfigMessage{
ProviderName: "consul_catalog",
Configuration: configuration,
}
}
})
for {
select {
case <-stop:
return nil
case err := <-errorCh:
return err
}
}
}
func (p *Provider) watchCatalogServices(stopCh <-chan struct{}, watchCh chan<- map[string][]string, notifyError func(error)) {
catalog := p.client.Catalog()
safe.Go(func() {
// variable to hold previous state
var flashback map[string]Service
options := &api.QueryOptions{WaitTime: DefaultWatchWaitTime, AllowStale: p.Stale}
for {
select {
case <-stopCh:
return
default:
}
data, meta, err := catalog.Services(options)
if err != nil {
log.Errorf("Failed to list services: %v", err)
notifyError(err)
return
}
if options.WaitIndex == meta.LastIndex {
continue
}
options.WaitIndex = meta.LastIndex
if data != nil {
current := make(map[string]Service)
for key, value := range data {
nodes, _, err := catalog.Service(key, "", &api.QueryOptions{AllowStale: p.Stale})
if err != nil {
log.Errorf("Failed to get detail of service %s: %v", key, err)
notifyError(err)
return
}
nodesID := getServiceIds(nodes)
ports := getServicePorts(nodes)
addresses := getServiceAddresses(nodes)
if service, ok := current[key]; ok {
service.Tags = value
service.Nodes = nodesID
service.Ports = ports
} else {
service := Service{
Name: key,
Tags: value,
Nodes: nodesID,
Addresses: addresses,
Ports: ports,
}
current[key] = service
}
}
// A critical note is that the return of a blocking request is no guarantee of a change.
// It is possible that there was an idempotent write that does not affect the result of the query.
// Thus it is required to do extra check for changes...
if hasChanged(current, flashback) {
watchCh <- data
flashback = current
}
}
}
})
}
func (p *Provider) watchHealthState(stopCh <-chan struct{}, watchCh chan<- map[string][]string, notifyError func(error)) {
health := p.client.Health()
catalog := p.client.Catalog()
safe.Go(func() {
// variable to hold previous state
var flashback map[string][]string
var flashbackMaintenance []string
options := &api.QueryOptions{WaitTime: DefaultWatchWaitTime, AllowStale: p.Stale}
for {
select {
case <-stopCh:
return
default:
}
// Listening to changes that leads to `passing` state or degrades from it.
healthyState, meta, err := health.State("any", options)
if err != nil {
log.WithError(err).Error("Failed to retrieve health checks")
notifyError(err)
return
}
var current = make(map[string][]string)
var currentFailing = make(map[string]*api.HealthCheck)
var maintenance []string
if healthyState != nil {
for _, healthy := range healthyState {
key := fmt.Sprintf("%s-%s", healthy.Node, healthy.ServiceID)
_, failing := currentFailing[key]
if healthy.Status == "passing" && !failing {
current[key] = append(current[key], healthy.Node)
} else if strings.HasPrefix(healthy.CheckID, "_service_maintenance") || strings.HasPrefix(healthy.CheckID, "_node_maintenance") {
maintenance = append(maintenance, healthy.CheckID)
} else {
currentFailing[key] = healthy
if _, ok := current[key]; ok {
delete(current, key)
}
}
}
}
// If LastIndex didn't change then it means `Get` returned
// because of the WaitTime and the key didn't changed.
if options.WaitIndex == meta.LastIndex {
continue
}
options.WaitIndex = meta.LastIndex
// The response should be unified with watchCatalogServices
data, _, err := catalog.Services(&api.QueryOptions{AllowStale: p.Stale})
if err != nil {
log.Errorf("Failed to list services: %v", err)
notifyError(err)
return
}
if data != nil {
// A critical note is that the return of a blocking request is no guarantee of a change.
// It is possible that there was an idempotent write that does not affect the result of the query.
// Thus it is required to do extra check for changes...
addedKeys, removedKeys, changedKeys := getChangedHealth(current, flashback)
if len(addedKeys) > 0 || len(removedKeys) > 0 || len(changedKeys) > 0 {
log.WithField("DiscoveredServices", addedKeys).
WithField("MissingServices", removedKeys).
WithField("ChangedServices", changedKeys).
Debug("Health State change detected.")
watchCh <- data
flashback = current
flashbackMaintenance = maintenance
} else {
addedKeysMaintenance, removedMaintenance := getChangedStringKeys(maintenance, flashbackMaintenance)
if len(addedKeysMaintenance) > 0 || len(removedMaintenance) > 0 {
log.WithField("MaintenanceMode", maintenance).Debug("Maintenance change detected.")
watchCh <- data
flashback = current
flashbackMaintenance = maintenance
}
}
}
}
})
}
func (p *Provider) getNodes(index map[string][]string) ([]catalogUpdate, error) {
visited := make(map[string]bool)
var nodes []catalogUpdate
for service := range index {
name := strings.ToLower(service)
if !strings.Contains(name, " ") && !visited[name] {
visited[name] = true
log.WithField("service", name).Debug("Fetching service")
healthy, err := p.healthyNodes(name)
if err != nil {
return nil, err
}
// healthy.Nodes can be empty if constraints do not match, without throwing error
if healthy.Service != nil && len(healthy.Nodes) > 0 {
nodes = append(nodes, healthy)
}
}
}
return nodes, nil
}
func hasChanged(current map[string]Service, previous map[string]Service) bool {
if len(current) != len(previous) {
return true
}
addedServiceKeys, removedServiceKeys := getChangedServiceKeys(current, previous)
return len(removedServiceKeys) > 0 || len(addedServiceKeys) > 0 || hasServiceChanged(current, previous)
}
func getChangedServiceKeys(current map[string]Service, previous map[string]Service) ([]string, []string) {
currKeySet := fun.Set(fun.Keys(current).([]string)).(map[string]bool)
prevKeySet := fun.Set(fun.Keys(previous).([]string)).(map[string]bool)
addedKeys := fun.Difference(currKeySet, prevKeySet).(map[string]bool)
removedKeys := fun.Difference(prevKeySet, currKeySet).(map[string]bool)
return fun.Keys(addedKeys).([]string), fun.Keys(removedKeys).([]string)
}
func hasServiceChanged(current map[string]Service, previous map[string]Service) bool {
for key, value := range current {
if prevValue, ok := previous[key]; ok {
addedNodesKeys, removedNodesKeys := getChangedStringKeys(value.Nodes, prevValue.Nodes)
if len(addedNodesKeys) > 0 || len(removedNodesKeys) > 0 {
return true
}
addedTagsKeys, removedTagsKeys := getChangedStringKeys(value.Tags, prevValue.Tags)
if len(addedTagsKeys) > 0 || len(removedTagsKeys) > 0 {
return true
}
addedAddressesKeys, removedAddressesKeys := getChangedStringKeys(value.Addresses, prevValue.Addresses)
if len(addedAddressesKeys) > 0 || len(removedAddressesKeys) > 0 {
return true
}
addedPortsKeys, removedPortsKeys := getChangedIntKeys(value.Ports, prevValue.Ports)
if len(addedPortsKeys) > 0 || len(removedPortsKeys) > 0 {
return true
}
}
}
return false
}
func getChangedStringKeys(currState []string, prevState []string) ([]string, []string) {
currKeySet := fun.Set(currState).(map[string]bool)
prevKeySet := fun.Set(prevState).(map[string]bool)
addedKeys := fun.Difference(currKeySet, prevKeySet).(map[string]bool)
removedKeys := fun.Difference(prevKeySet, currKeySet).(map[string]bool)
return fun.Keys(addedKeys).([]string), fun.Keys(removedKeys).([]string)
}
func getChangedHealth(current map[string][]string, previous map[string][]string) ([]string, []string, []string) {
currKeySet := fun.Set(fun.Keys(current).([]string)).(map[string]bool)
prevKeySet := fun.Set(fun.Keys(previous).([]string)).(map[string]bool)
addedKeys := fun.Difference(currKeySet, prevKeySet).(map[string]bool)
removedKeys := fun.Difference(prevKeySet, currKeySet).(map[string]bool)
var changedKeys []string
for key, value := range current {
if prevValue, ok := previous[key]; ok {
addedNodesKeys, removedNodesKeys := getChangedStringKeys(value, prevValue)
if len(addedNodesKeys) > 0 || len(removedNodesKeys) > 0 {
changedKeys = append(changedKeys, key)
}
}
}
return fun.Keys(addedKeys).([]string), fun.Keys(removedKeys).([]string), changedKeys
}
func getChangedIntKeys(currState []int, prevState []int) ([]int, []int) {
currKeySet := fun.Set(currState).(map[int]bool)
prevKeySet := fun.Set(prevState).(map[int]bool)
addedKeys := fun.Difference(currKeySet, prevKeySet).(map[int]bool)
removedKeys := fun.Difference(prevKeySet, currKeySet).(map[int]bool)
return fun.Keys(addedKeys).([]int), fun.Keys(removedKeys).([]int)
}
func getServiceIds(services []*api.CatalogService) []string {
var serviceIds []string
for _, service := range services {
serviceIds = append(serviceIds, service.ID)
}
return serviceIds
}
func getServicePorts(services []*api.CatalogService) []int {
var servicePorts []int
for _, service := range services {
servicePorts = append(servicePorts, service.ServicePort)
}
return servicePorts
}
func getServiceAddresses(services []*api.CatalogService) []string {
var serviceAddresses []string
for _, service := range services {
serviceAddresses = append(serviceAddresses, service.ServiceAddress)
}
return serviceAddresses
}
func (p *Provider) healthyNodes(service string) (catalogUpdate, error) {
health := p.client.Health()
data, _, err := health.Service(service, "", true, &api.QueryOptions{AllowStale: p.Stale})
if err != nil {
log.WithError(err).Errorf("Failed to fetch details of %s", service)
return catalogUpdate{}, err
}
nodes := fun.Filter(func(node *api.ServiceEntry) bool {
return p.nodeFilter(service, node)
}, data).([]*api.ServiceEntry)
// Merge tags of nodes matching constraints, in a single slice.
tags := fun.Foldl(func(node *api.ServiceEntry, set []string) []string {
return fun.Keys(fun.Union(
fun.Set(set),
fun.Set(node.Service.Tags),
).(map[string]bool)).([]string)
}, []string{}, nodes).([]string)
labels := tagsToNeutralLabels(tags, p.Prefix)
return catalogUpdate{
Service: &serviceUpdate{
ServiceName: service,
Attributes: tags,
TraefikLabels: labels,
},
Nodes: nodes,
}, nil
}
func (p *Provider) nodeFilter(service string, node *api.ServiceEntry) bool {
// Filter disabled application.
if !p.isServiceEnabled(node) {
log.Debugf("Filtering disabled Consul service %s", service)
return false
}
// Filter by constraints.
constraintTags := p.getConstraintTags(node.Service.Tags)
ok, failingConstraint := p.MatchConstraints(constraintTags)
if !ok && failingConstraint != nil {
log.Debugf("Service %v pruned by '%v' constraint", service, failingConstraint.String())
return false
}
return true
}
func (p *Provider) isServiceEnabled(node *api.ServiceEntry) bool {
rawValue := getTag(p.getPrefixedName(label.SuffixEnable), node.Service.Tags, "")
if len(rawValue) == 0 {
return p.ExposedByDefault
}
value, err := strconv.ParseBool(rawValue)
if err != nil {
log.Errorf("Invalid value for %s: %s", label.SuffixEnable, rawValue)
return p.ExposedByDefault
}
return value
}
func (p *Provider) getConstraintTags(tags []string) []string {
var values []string
prefix := p.getPrefixedName("tags=")
for _, tag := range tags {
// We look for a Consul tag named 'traefik.tags' (unless different 'prefix' is configured)
if strings.HasPrefix(strings.ToLower(tag), prefix) {
// If 'traefik.tags=' tag is found, take the tag value and split by ',' adding the result to the list to be returned
splitedTags := label.SplitAndTrimString(tag[len(prefix):], ",")
values = append(values, splitedTags...)
}
}
return values
}
func (p *Provider) generateFrontends(service *serviceUpdate) []*serviceUpdate {
frontends := make([]*serviceUpdate, 0)
// to support <prefix>.frontend.xxx
frontends = append(frontends, &serviceUpdate{
ServiceName: service.ServiceName,
ParentServiceName: service.ServiceName,
Attributes: service.Attributes,
TraefikLabels: service.TraefikLabels,
})
// loop over children of <prefix>.frontends.*
for _, frontend := range getSegments(p.Prefix+".frontends", p.Prefix, service.TraefikLabels) {
frontends = append(frontends, &serviceUpdate{
ServiceName: service.ServiceName + "-" + frontend.Name,
ParentServiceName: service.ServiceName,
Attributes: service.Attributes,
TraefikLabels: frontend.Labels,
})
}
return frontends
}
func getSegments(path string, prefix string, tree map[string]string) []*frontendSegment {
segments := make([]*frontendSegment, 0)
// find segment names
segmentNames := make(map[string]bool)
for key := range tree {
if strings.HasPrefix(key, path+".") {
segmentNames[strings.SplitN(strings.TrimPrefix(key, path+"."), ".", 2)[0]] = true
}
}
// get labels for each segment found
for segment := range segmentNames {
labels := make(map[string]string)
for key, value := range tree {
if strings.HasPrefix(key, path+"."+segment) {
labels[prefix+".frontend"+strings.TrimPrefix(key, path+"."+segment)] = value
}
}
segments = append(segments, &frontendSegment{
Name: segment,
Labels: labels,
})
}
return segments
}

View file

@ -0,0 +1,862 @@
package consulcatalog
import (
"sort"
"testing"
"github.com/BurntSushi/ty/fun"
"github.com/hashicorp/consul/api"
"github.com/stretchr/testify/assert"
)
func TestNodeSorter(t *testing.T) {
testCases := []struct {
desc string
nodes []*api.ServiceEntry
expected []*api.ServiceEntry
}{
{
desc: "Should sort nothing",
nodes: []*api.ServiceEntry{},
expected: []*api.ServiceEntry{},
},
{
desc: "Should sort by node address",
nodes: []*api.ServiceEntry{
{
Service: &api.AgentService{
Service: "foo",
Address: "127.0.0.1",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.1",
},
},
},
expected: []*api.ServiceEntry{
{
Service: &api.AgentService{
Service: "foo",
Address: "127.0.0.1",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.1",
},
},
},
},
{
desc: "Should sort by service name",
nodes: []*api.ServiceEntry{
{
Service: &api.AgentService{
Service: "foo",
Address: "127.0.0.2",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.2",
},
},
{
Service: &api.AgentService{
Service: "bar",
Address: "127.0.0.2",
Port: 81,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.2",
},
},
{
Service: &api.AgentService{
Service: "foo",
Address: "127.0.0.1",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.1",
},
},
{
Service: &api.AgentService{
Service: "bar",
Address: "127.0.0.2",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.2",
},
},
},
expected: []*api.ServiceEntry{
{
Service: &api.AgentService{
Service: "bar",
Address: "127.0.0.2",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.2",
},
},
{
Service: &api.AgentService{
Service: "bar",
Address: "127.0.0.2",
Port: 81,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.2",
},
},
{
Service: &api.AgentService{
Service: "foo",
Address: "127.0.0.1",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.1",
},
},
{
Service: &api.AgentService{
Service: "foo",
Address: "127.0.0.2",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.2",
},
},
},
},
{
desc: "Should sort by node address",
nodes: []*api.ServiceEntry{
{
Service: &api.AgentService{
Service: "foo",
Address: "",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.2",
},
},
{
Service: &api.AgentService{
Service: "foo",
Address: "",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.1",
},
},
},
expected: []*api.ServiceEntry{
{
Service: &api.AgentService{
Service: "foo",
Address: "",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.1",
},
},
{
Service: &api.AgentService{
Service: "foo",
Address: "",
Port: 80,
},
Node: &api.Node{
Node: "localhost",
Address: "127.0.0.2",
},
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
sort.Sort(nodeSorter(test.nodes))
actual := test.nodes
assert.Equal(t, test.expected, actual)
})
}
}
func TestGetChangedKeys(t *testing.T) {
type Input struct {
currState map[string]Service
prevState map[string]Service
}
type Output struct {
addedKeys []string
removedKeys []string
}
testCases := []struct {
desc string
input Input
output Output
}{
{
desc: "Should add 0 services and removed 0",
input: Input{
currState: map[string]Service{
"foo-service": {Name: "v1"},
"bar-service": {Name: "v1"},
"baz-service": {Name: "v1"},
"qux-service": {Name: "v1"},
"quux-service": {Name: "v1"},
"quuz-service": {Name: "v1"},
"corge-service": {Name: "v1"},
"grault-service": {Name: "v1"},
"garply-service": {Name: "v1"},
"waldo-service": {Name: "v1"},
"fred-service": {Name: "v1"},
"plugh-service": {Name: "v1"},
"xyzzy-service": {Name: "v1"},
"thud-service": {Name: "v1"},
},
prevState: map[string]Service{
"foo-service": {Name: "v1"},
"bar-service": {Name: "v1"},
"baz-service": {Name: "v1"},
"qux-service": {Name: "v1"},
"quux-service": {Name: "v1"},
"quuz-service": {Name: "v1"},
"corge-service": {Name: "v1"},
"grault-service": {Name: "v1"},
"garply-service": {Name: "v1"},
"waldo-service": {Name: "v1"},
"fred-service": {Name: "v1"},
"plugh-service": {Name: "v1"},
"xyzzy-service": {Name: "v1"},
"thud-service": {Name: "v1"},
},
},
output: Output{
addedKeys: []string{},
removedKeys: []string{},
},
},
{
desc: "Should add 3 services and removed 0",
input: Input{
currState: map[string]Service{
"foo-service": {Name: "v1"},
"bar-service": {Name: "v1"},
"baz-service": {Name: "v1"},
"qux-service": {Name: "v1"},
"quux-service": {Name: "v1"},
"quuz-service": {Name: "v1"},
"corge-service": {Name: "v1"},
"grault-service": {Name: "v1"},
"garply-service": {Name: "v1"},
"waldo-service": {Name: "v1"},
"fred-service": {Name: "v1"},
"plugh-service": {Name: "v1"},
"xyzzy-service": {Name: "v1"},
"thud-service": {Name: "v1"},
},
prevState: map[string]Service{
"foo-service": {Name: "v1"},
"bar-service": {Name: "v1"},
"baz-service": {Name: "v1"},
"corge-service": {Name: "v1"},
"grault-service": {Name: "v1"},
"garply-service": {Name: "v1"},
"waldo-service": {Name: "v1"},
"fred-service": {Name: "v1"},
"plugh-service": {Name: "v1"},
"xyzzy-service": {Name: "v1"},
"thud-service": {Name: "v1"},
},
},
output: Output{
addedKeys: []string{"qux-service", "quux-service", "quuz-service"},
removedKeys: []string{},
},
},
{
desc: "Should add 2 services and removed 2",
input: Input{
currState: map[string]Service{
"foo-service": {Name: "v1"},
"qux-service": {Name: "v1"},
"quux-service": {Name: "v1"},
"quuz-service": {Name: "v1"},
"corge-service": {Name: "v1"},
"grault-service": {Name: "v1"},
"garply-service": {Name: "v1"},
"waldo-service": {Name: "v1"},
"fred-service": {Name: "v1"},
"plugh-service": {Name: "v1"},
"xyzzy-service": {Name: "v1"},
"thud-service": {Name: "v1"},
},
prevState: map[string]Service{
"foo-service": {Name: "v1"},
"bar-service": {Name: "v1"},
"baz-service": {Name: "v1"},
"qux-service": {Name: "v1"},
"quux-service": {Name: "v1"},
"quuz-service": {Name: "v1"},
"corge-service": {Name: "v1"},
"waldo-service": {Name: "v1"},
"fred-service": {Name: "v1"},
"plugh-service": {Name: "v1"},
"xyzzy-service": {Name: "v1"},
"thud-service": {Name: "v1"},
},
},
output: Output{
addedKeys: []string{"grault-service", "garply-service"},
removedKeys: []string{"bar-service", "baz-service"},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
addedKeys, removedKeys := getChangedServiceKeys(test.input.currState, test.input.prevState)
assert.Equal(t, fun.Set(test.output.addedKeys), fun.Set(addedKeys), "Added keys comparison results: got %q, want %q", addedKeys, test.output.addedKeys)
assert.Equal(t, fun.Set(test.output.removedKeys), fun.Set(removedKeys), "Removed keys comparison results: got %q, want %q", removedKeys, test.output.removedKeys)
})
}
}
func TestFilterEnabled(t *testing.T) {
testCases := []struct {
desc string
exposedByDefault bool
node *api.ServiceEntry
expected bool
}{
{
desc: "exposed",
exposedByDefault: true,
node: &api.ServiceEntry{
Service: &api.AgentService{
Service: "api",
Address: "10.0.0.1",
Port: 80,
Tags: []string{""},
},
},
expected: true,
},
{
desc: "exposed and tolerated by valid label value",
exposedByDefault: true,
node: &api.ServiceEntry{
Service: &api.AgentService{
Service: "api",
Address: "10.0.0.1",
Port: 80,
Tags: []string{"", "traefik.enable=true"},
},
},
expected: true,
},
{
desc: "exposed and tolerated by invalid label value",
exposedByDefault: true,
node: &api.ServiceEntry{
Service: &api.AgentService{
Service: "api",
Address: "10.0.0.1",
Port: 80,
Tags: []string{"", "traefik.enable=bad"},
},
},
expected: true,
},
{
desc: "exposed but overridden by label",
exposedByDefault: true,
node: &api.ServiceEntry{
Service: &api.AgentService{
Service: "api",
Address: "10.0.0.1",
Port: 80,
Tags: []string{"", "traefik.enable=false"},
},
},
expected: false,
},
{
desc: "non-exposed",
exposedByDefault: false,
node: &api.ServiceEntry{
Service: &api.AgentService{
Service: "api",
Address: "10.0.0.1",
Port: 80,
Tags: []string{""},
},
},
expected: false,
},
{
desc: "non-exposed but overridden by label",
exposedByDefault: false,
node: &api.ServiceEntry{
Service: &api.AgentService{
Service: "api",
Address: "10.0.0.1",
Port: 80,
Tags: []string{"", "traefik.enable=true"},
},
},
expected: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
provider := &Provider{
Domain: "localhost",
Prefix: "traefik",
ExposedByDefault: test.exposedByDefault,
}
actual := provider.nodeFilter("test", test.node)
assert.Equal(t, test.expected, actual)
})
}
}
func TestGetChangedStringKeys(t *testing.T) {
testCases := []struct {
desc string
current []string
previous []string
expectedAdded []string
expectedRemoved []string
}{
{
desc: "1 element added, 0 removed",
current: []string{"chou"},
previous: []string{},
expectedAdded: []string{"chou"},
expectedRemoved: []string{},
}, {
desc: "0 element added, 0 removed",
current: []string{"chou"},
previous: []string{"chou"},
expectedAdded: []string{},
expectedRemoved: []string{},
},
{
desc: "0 element added, 1 removed",
current: []string{},
previous: []string{"chou"},
expectedAdded: []string{},
expectedRemoved: []string{"chou"},
},
{
desc: "1 element added, 1 removed",
current: []string{"carotte"},
previous: []string{"chou"},
expectedAdded: []string{"carotte"},
expectedRemoved: []string{"chou"},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actualAdded, actualRemoved := getChangedStringKeys(test.current, test.previous)
assert.Equal(t, test.expectedAdded, actualAdded)
assert.Equal(t, test.expectedRemoved, actualRemoved)
})
}
}
func TestHasServiceChanged(t *testing.T) {
testCases := []struct {
desc string
current map[string]Service
previous map[string]Service
expected bool
}{
{
desc: "Change detected due to change of nodes",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node2"},
Tags: []string{},
},
},
expected: true,
},
{
desc: "No change missing current service",
current: make(map[string]Service),
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{},
},
},
expected: false,
},
{
desc: "No change on nodes",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{},
},
},
expected: false,
},
{
desc: "No change on nodes and tags",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo=bar"},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo=bar"},
},
},
expected: false,
},
{
desc: "Change detected on tags",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo=bar"},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo"},
},
},
expected: true,
},
{
desc: "Change detected on ports",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo=bar"},
Ports: []int{80},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo"},
Ports: []int{81},
},
},
expected: true,
},
{
desc: "Change detected on ports",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo=bar"},
Ports: []int{80},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo"},
Ports: []int{81, 82},
},
},
expected: true,
},
{
desc: "Change detected on addresses",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Addresses: []string{"127.0.0.1"},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Addresses: []string{"127.0.0.2"},
},
},
expected: true,
},
{
desc: "No Change detected",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo"},
Ports: []int{80},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo"},
Ports: []int{80},
},
},
expected: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := hasServiceChanged(test.current, test.previous)
assert.Equal(t, test.expected, actual)
})
}
}
func TestHasChanged(t *testing.T) {
testCases := []struct {
desc string
current map[string]Service
previous map[string]Service
expected bool
}{
{
desc: "Change detected due to change new service",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{},
},
},
previous: make(map[string]Service),
expected: true,
},
{
desc: "Change detected due to change service removed",
current: make(map[string]Service),
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{},
},
},
expected: true,
},
{
desc: "Change detected due to change of nodes",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node2"},
Tags: []string{},
},
},
expected: true,
},
{
desc: "No change on nodes",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{},
},
},
expected: false,
},
{
desc: "No change on nodes and tags",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo=bar"},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo=bar"},
},
},
expected: false,
},
{
desc: "Change detected on tags",
current: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo=bar"},
},
},
previous: map[string]Service{
"foo-service": {
Name: "foo",
Nodes: []string{"node1"},
Tags: []string{"foo"},
},
},
expected: true,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
actual := hasChanged(test.current, test.previous)
assert.Equal(t, test.expected, actual)
})
}
}
func TestGetConstraintTags(t *testing.T) {
provider := &Provider{
Domain: "localhost",
Prefix: "traefik",
}
testCases := []struct {
desc string
tags []string
expected []string
}{
{
desc: "nil tags",
},
{
desc: "invalid tag",
tags: []string{"tags=foobar"},
expected: nil,
},
{
desc: "wrong tag",
tags: []string{"traefik_tags=foobar"},
expected: nil,
},
{
desc: "empty value",
tags: []string{"traefik.tags="},
expected: nil,
},
{
desc: "simple tag",
tags: []string{"traefik.tags=foobar "},
expected: []string{"foobar"},
},
{
desc: "multiple values tag",
tags: []string{"traefik.tags=foobar, fiibir"},
expected: []string{"foobar", "fiibir"},
},
{
desc: "multiple tags",
tags: []string{"traefik.tags=foobar", "traefik.tags=foobor"},
expected: []string{"foobar", "foobor"},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
constraints := provider.getConstraintTags(test.tags)
assert.EqualValues(t, test.expected, constraints)
})
}
}

Some files were not shown because too many files have changed in this diff Show more