Dynamic Configuration Refactoring
parent d3ae88f108
commit a09dfa3ce1
452 changed files with 21023 additions and 9419 deletions
39 old/api/dashboard.go Normal file
@@ -0,0 +1,39 @@
package api

import (
    "net/http"

    "github.com/containous/mux"
    "github.com/containous/traefik/old/log"
    "github.com/elazarl/go-bindata-assetfs"
)

// DashboardHandler expose dashboard routes
type DashboardHandler struct {
    Assets *assetfs.AssetFS
}

// AddRoutes add dashboard routes on a router
func (g DashboardHandler) AddRoutes(router *mux.Router) {
    if g.Assets == nil {
        log.Error("No assets for dashboard")
        return
    }

    // Expose dashboard
    router.Methods(http.MethodGet).
        Path("/").
        HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
            http.Redirect(response, request, request.Header.Get("X-Forwarded-Prefix")+"/dashboard/", 302)
        })

    router.Methods(http.MethodGet).
        Path("/dashboard/status").
        HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
            http.Redirect(response, request, "/dashboard/", 302)
        })

    router.Methods(http.MethodGet).
        PathPrefix("/dashboard/").
        Handler(http.StripPrefix("/dashboard/", http.FileServer(g.Assets)))
}
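A minimal sketch (not part of this commit) of how DashboardHandler might be mounted on a router; it assumes the old/api import path shown above and the github.com/containous/mux router that the file itself uses. With Assets left nil, AddRoutes takes the guard path and only logs an error.

package main

import (
    "log"
    "net/http"

    "github.com/containous/mux"
    "github.com/containous/traefik/old/api"
)

func main() {
    router := mux.NewRouter()

    // Assets is nil here, so AddRoutes logs "No assets for dashboard" and
    // registers nothing; pass the generated *assetfs.AssetFS to serve the UI.
    api.DashboardHandler{}.AddRoutes(router)

    log.Fatal(http.ListenAndServe(":8080", router))
}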
48 old/api/debug.go Normal file
@@ -0,0 +1,48 @@
package api

import (
    "expvar"
    "fmt"
    "net/http"
    "net/http/pprof"
    "runtime"

    "github.com/containous/mux"
)

func init() {
    expvar.Publish("Goroutines", expvar.Func(goroutines))
}

func goroutines() interface{} {
    return runtime.NumGoroutine()
}

// DebugHandler expose debug routes
type DebugHandler struct{}

// AddRoutes add debug routes on a router
func (g DebugHandler) AddRoutes(router *mux.Router) {
    router.Methods(http.MethodGet).Path("/debug/vars").
        HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
            w.Header().Set("Content-Type", "application/json; charset=utf-8")
            fmt.Fprint(w, "{\n")
            first := true
            expvar.Do(func(kv expvar.KeyValue) {
                if !first {
                    fmt.Fprint(w, ",\n")
                }
                first = false
                fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value)
            })
            fmt.Fprint(w, "\n}\n")
        })

    runtime.SetBlockProfileRate(1)
    runtime.SetMutexProfileFraction(5)
    router.Methods(http.MethodGet).PathPrefix("/debug/pprof/cmdline").HandlerFunc(pprof.Cmdline)
    router.Methods(http.MethodGet).PathPrefix("/debug/pprof/profile").HandlerFunc(pprof.Profile)
    router.Methods(http.MethodGet).PathPrefix("/debug/pprof/symbol").HandlerFunc(pprof.Symbol)
    router.Methods(http.MethodGet).PathPrefix("/debug/pprof/trace").HandlerFunc(pprof.Trace)
    router.Methods(http.MethodGet).PathPrefix("/debug/pprof/").HandlerFunc(pprof.Index)
}
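The /debug/vars handler above hand-assembles one JSON object from the published expvar variables. The stdlib-only sketch below (illustrative, not code from this commit) shows the same expvar.Do loop writing to stdout, which makes the first/comma bookkeeping easier to follow.

package main

import (
    "expvar"
    "fmt"
    "os"
)

func main() {
    // Same pattern as DebugHandler's /debug/vars handler: emit a JSON object
    // with one member per published expvar, separating members with commas.
    fmt.Fprint(os.Stdout, "{\n")
    first := true
    expvar.Do(func(kv expvar.KeyValue) {
        if !first {
            fmt.Fprint(os.Stdout, ",\n")
        }
        first = false
        fmt.Fprintf(os.Stdout, "%q: %s", kv.Key, kv.Value)
    })
    fmt.Fprint(os.Stdout, "\n}\n")
}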
252 old/api/handler.go Normal file
@@ -0,0 +1,252 @@
package api

import (
    "net/http"

    "github.com/containous/mux"
    "github.com/containous/traefik/old/log"
    "github.com/containous/traefik/old/middlewares"
    "github.com/containous/traefik/old/types"
    "github.com/containous/traefik/safe"
    "github.com/containous/traefik/version"
    "github.com/elazarl/go-bindata-assetfs"
    thoas_stats "github.com/thoas/stats"
    "github.com/unrolled/render"
)

// Handler expose api routes
type Handler struct {
    EntryPoint            string `description:"EntryPoint" export:"true"`
    Dashboard             bool   `description:"Activate dashboard" export:"true"`
    Debug                 bool   `export:"true"`
    CurrentConfigurations *safe.Safe
    Statistics            *types.Statistics          `description:"Enable more detailed statistics" export:"true"`
    Stats                 *thoas_stats.Stats         `json:"-"`
    StatsRecorder         *middlewares.StatsRecorder `json:"-"`
    DashboardAssets       *assetfs.AssetFS           `json:"-"`
}

var (
    templatesRenderer = render.New(render.Options{
        Directory: "nowhere",
    })
)

// AddRoutes add api routes on a router
func (p Handler) AddRoutes(router *mux.Router) {
    if p.Debug {
        DebugHandler{}.AddRoutes(router)
    }

    router.Methods(http.MethodGet).Path("/api").HandlerFunc(p.getConfigHandler)
    router.Methods(http.MethodGet).Path("/api/providers").HandlerFunc(p.getConfigHandler)
    router.Methods(http.MethodGet).Path("/api/providers/{provider}").HandlerFunc(p.getProviderHandler)
    router.Methods(http.MethodGet).Path("/api/providers/{provider}/backends").HandlerFunc(p.getBackendsHandler)
    router.Methods(http.MethodGet).Path("/api/providers/{provider}/backends/{backend}").HandlerFunc(p.getBackendHandler)
    router.Methods(http.MethodGet).Path("/api/providers/{provider}/backends/{backend}/servers").HandlerFunc(p.getServersHandler)
    router.Methods(http.MethodGet).Path("/api/providers/{provider}/backends/{backend}/servers/{server}").HandlerFunc(p.getServerHandler)
    router.Methods(http.MethodGet).Path("/api/providers/{provider}/frontends").HandlerFunc(p.getFrontendsHandler)
    router.Methods(http.MethodGet).Path("/api/providers/{provider}/frontends/{frontend}").HandlerFunc(p.getFrontendHandler)
    router.Methods(http.MethodGet).Path("/api/providers/{provider}/frontends/{frontend}/routes").HandlerFunc(p.getRoutesHandler)
    router.Methods(http.MethodGet).Path("/api/providers/{provider}/frontends/{frontend}/routes/{route}").HandlerFunc(p.getRouteHandler)

    // health route
    router.Methods(http.MethodGet).Path("/health").HandlerFunc(p.getHealthHandler)

    version.Handler{}.Append(router)

    if p.Dashboard {
        DashboardHandler{Assets: p.DashboardAssets}.AddRoutes(router)
    }
}

func getProviderIDFromVars(vars map[string]string) string {
    providerID := vars["provider"]
    // TODO: Deprecated
    if providerID == "rest" {
        providerID = "web"
    }
    return providerID
}

func (p Handler) getConfigHandler(response http.ResponseWriter, request *http.Request) {
    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    err := templatesRenderer.JSON(response, http.StatusOK, currentConfigurations)
    if err != nil {
        log.Error(err)
    }
}

func (p Handler) getProviderHandler(response http.ResponseWriter, request *http.Request) {
    providerID := getProviderIDFromVars(mux.Vars(request))

    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    if provider, ok := currentConfigurations[providerID]; ok {
        err := templatesRenderer.JSON(response, http.StatusOK, provider)
        if err != nil {
            log.Error(err)
        }
    } else {
        http.NotFound(response, request)
    }
}

func (p Handler) getBackendsHandler(response http.ResponseWriter, request *http.Request) {
    providerID := getProviderIDFromVars(mux.Vars(request))

    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    if provider, ok := currentConfigurations[providerID]; ok {
        err := templatesRenderer.JSON(response, http.StatusOK, provider.Backends)
        if err != nil {
            log.Error(err)
        }
    } else {
        http.NotFound(response, request)
    }
}

func (p Handler) getBackendHandler(response http.ResponseWriter, request *http.Request) {
    vars := mux.Vars(request)
    providerID := getProviderIDFromVars(vars)
    backendID := vars["backend"]

    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    if provider, ok := currentConfigurations[providerID]; ok {
        if backend, ok := provider.Backends[backendID]; ok {
            err := templatesRenderer.JSON(response, http.StatusOK, backend)
            if err != nil {
                log.Error(err)
            }
            return
        }
    }
    http.NotFound(response, request)
}

func (p Handler) getServersHandler(response http.ResponseWriter, request *http.Request) {
    vars := mux.Vars(request)
    providerID := getProviderIDFromVars(vars)
    backendID := vars["backend"]

    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    if provider, ok := currentConfigurations[providerID]; ok {
        if backend, ok := provider.Backends[backendID]; ok {
            err := templatesRenderer.JSON(response, http.StatusOK, backend.Servers)
            if err != nil {
                log.Error(err)
            }
            return
        }
    }
    http.NotFound(response, request)
}

func (p Handler) getServerHandler(response http.ResponseWriter, request *http.Request) {
    vars := mux.Vars(request)
    providerID := getProviderIDFromVars(vars)
    backendID := vars["backend"]
    serverID := vars["server"]

    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    if provider, ok := currentConfigurations[providerID]; ok {
        if backend, ok := provider.Backends[backendID]; ok {
            if server, ok := backend.Servers[serverID]; ok {
                err := templatesRenderer.JSON(response, http.StatusOK, server)
                if err != nil {
                    log.Error(err)
                }
                return
            }
        }
    }
    http.NotFound(response, request)
}

func (p Handler) getFrontendsHandler(response http.ResponseWriter, request *http.Request) {
    providerID := getProviderIDFromVars(mux.Vars(request))

    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    if provider, ok := currentConfigurations[providerID]; ok {
        err := templatesRenderer.JSON(response, http.StatusOK, provider.Frontends)
        if err != nil {
            log.Error(err)
        }
    } else {
        http.NotFound(response, request)
    }
}

func (p Handler) getFrontendHandler(response http.ResponseWriter, request *http.Request) {
    vars := mux.Vars(request)
    providerID := getProviderIDFromVars(vars)
    frontendID := vars["frontend"]

    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    if provider, ok := currentConfigurations[providerID]; ok {
        if frontend, ok := provider.Frontends[frontendID]; ok {
            err := templatesRenderer.JSON(response, http.StatusOK, frontend)
            if err != nil {
                log.Error(err)
            }
            return
        }
    }
    http.NotFound(response, request)
}

func (p Handler) getRoutesHandler(response http.ResponseWriter, request *http.Request) {
    vars := mux.Vars(request)
    providerID := getProviderIDFromVars(vars)
    frontendID := vars["frontend"]

    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    if provider, ok := currentConfigurations[providerID]; ok {
        if frontend, ok := provider.Frontends[frontendID]; ok {
            err := templatesRenderer.JSON(response, http.StatusOK, frontend.Routes)
            if err != nil {
                log.Error(err)
            }
            return
        }
    }
    http.NotFound(response, request)
}

func (p Handler) getRouteHandler(response http.ResponseWriter, request *http.Request) {
    vars := mux.Vars(request)
    providerID := getProviderIDFromVars(vars)
    frontendID := vars["frontend"]
    routeID := vars["route"]

    currentConfigurations := p.CurrentConfigurations.Get().(types.Configurations)
    if provider, ok := currentConfigurations[providerID]; ok {
        if frontend, ok := provider.Frontends[frontendID]; ok {
            if route, ok := frontend.Routes[routeID]; ok {
                err := templatesRenderer.JSON(response, http.StatusOK, route)
                if err != nil {
                    log.Error(err)
                }
                return
            }
        }
    }
    http.NotFound(response, request)
}

// healthResponse combines data returned by thoas/stats with statistics (if
// they are enabled).
type healthResponse struct {
    *thoas_stats.Data
    *middlewares.Stats
}

func (p *Handler) getHealthHandler(response http.ResponseWriter, request *http.Request) {
    health := &healthResponse{Data: p.Stats.Data()}
    if p.StatsRecorder != nil {
        health.Stats = p.StatsRecorder.Data()
    }
    err := templatesRenderer.JSON(response, http.StatusOK, health)
    if err != nil {
        log.Error(err)
    }
}
445 old/configuration/configuration.go Normal file
@@ -0,0 +1,445 @@
package configuration

import (
    "fmt"
    "strings"
    "time"

    "github.com/containous/flaeg/parse"
    "github.com/containous/traefik/acme"
    "github.com/containous/traefik/old/api"
    "github.com/containous/traefik/old/log"
    "github.com/containous/traefik/old/middlewares/tracing"
    "github.com/containous/traefik/old/middlewares/tracing/datadog"
    "github.com/containous/traefik/old/middlewares/tracing/jaeger"
    "github.com/containous/traefik/old/middlewares/tracing/zipkin"
    "github.com/containous/traefik/old/ping"
    "github.com/containous/traefik/old/provider/boltdb"
    "github.com/containous/traefik/old/provider/consul"
    "github.com/containous/traefik/old/provider/consulcatalog"
    "github.com/containous/traefik/old/provider/docker"
    "github.com/containous/traefik/old/provider/dynamodb"
    "github.com/containous/traefik/old/provider/ecs"
    "github.com/containous/traefik/old/provider/etcd"
    "github.com/containous/traefik/old/provider/eureka"
    "github.com/containous/traefik/old/provider/file"
    "github.com/containous/traefik/old/provider/kubernetes"
    "github.com/containous/traefik/old/provider/marathon"
    "github.com/containous/traefik/old/provider/mesos"
    "github.com/containous/traefik/old/provider/rancher"
    "github.com/containous/traefik/old/provider/rest"
    "github.com/containous/traefik/old/provider/zk"
    "github.com/containous/traefik/old/types"
    acmeprovider "github.com/containous/traefik/provider/acme"
    "github.com/containous/traefik/tls"
    newtypes "github.com/containous/traefik/types"
    "github.com/pkg/errors"
    lego "github.com/xenolf/lego/acme"
)

const (
    // DefaultInternalEntryPointName the name of the default internal entry point
    DefaultInternalEntryPointName = "traefik"

    // DefaultHealthCheckInterval is the default health check interval.
    DefaultHealthCheckInterval = 30 * time.Second

    // DefaultHealthCheckTimeout is the default health check request timeout.
    DefaultHealthCheckTimeout = 5 * time.Second

    // DefaultDialTimeout when connecting to a backend server.
    DefaultDialTimeout = 30 * time.Second

    // DefaultIdleTimeout before closing an idle connection.
    DefaultIdleTimeout = 180 * time.Second

    // DefaultGraceTimeout controls how long Traefik serves pending requests
    // prior to shutting down.
    DefaultGraceTimeout = 10 * time.Second

    // DefaultAcmeCAServer is the default ACME API endpoint
    DefaultAcmeCAServer = "https://acme-v02.api.letsencrypt.org/directory"
)

// GlobalConfiguration holds global configuration (with providers, etc.).
// It's populated from the traefik configuration file passed as an argument to the binary.
type GlobalConfiguration struct {
    LifeCycle *LifeCycle `description:"Timeouts influencing the server life cycle" export:"true"`
    Debug bool `short:"d" description:"Enable debug mode" export:"true"`
    CheckNewVersion bool `description:"Periodically check if a new version has been released" export:"true"`
    SendAnonymousUsage bool `description:"send periodically anonymous usage statistics" export:"true"`
    AccessLog *types.AccessLog `description:"Access log settings" export:"true"`
    TraefikLog *types.TraefikLog `description:"Traefik log settings" export:"true"`
    Tracing *tracing.Tracing `description:"OpenTracing configuration" export:"true"`
    LogLevel string `short:"l" description:"Log level" export:"true"`
    EntryPoints EntryPoints `description:"Entrypoints definition using format: --entryPoints='Name:http Address::8000 Redirect.EntryPoint:https' --entryPoints='Name:https Address::4442 TLS:tests/traefik.crt,tests/traefik.key;prod/traefik.crt,prod/traefik.key'" export:"true"`
    Cluster *types.Cluster
    Constraints types.Constraints `description:"Filter services by constraint, matching with service tags" export:"true"`
    ACME *acme.ACME `description:"Enable ACME (Let's Encrypt): automatic SSL" export:"true"`
    DefaultEntryPoints DefaultEntryPoints `description:"Entrypoints to be used by frontends that do not specify any entrypoint" export:"true"`
    ProvidersThrottleDuration parse.Duration `description:"Backends throttle duration: minimum duration between 2 events from providers before applying a new configuration. It avoids unnecessary reloads if multiples events are sent in a short amount of time." export:"true"`
    MaxIdleConnsPerHost int `description:"If non-zero, controls the maximum idle (keep-alive) to keep per-host. If zero, DefaultMaxIdleConnsPerHost is used" export:"true"`
    InsecureSkipVerify bool `description:"Disable SSL certificate verification" export:"true"`
    RootCAs tls.FilesOrContents `description:"Add cert file for self-signed certificate"`
    Retry *Retry `description:"Enable retry sending request if network error" export:"true"`
    HealthCheck *HealthCheckConfig `description:"Health check parameters" export:"true"`
    RespondingTimeouts *RespondingTimeouts `description:"Timeouts for incoming requests to the Traefik instance" export:"true"`
    ForwardingTimeouts *ForwardingTimeouts `description:"Timeouts for requests forwarded to the backend servers" export:"true"`
    KeepTrailingSlash bool `description:"Do not remove trailing slash." export:"true"` // Deprecated
    Docker *docker.Provider `description:"Enable Docker backend with default settings" export:"true"`
    File *file.Provider `description:"Enable File backend with default settings" export:"true"`
    Marathon *marathon.Provider `description:"Enable Marathon backend with default settings" export:"true"`
    Consul *consul.Provider `description:"Enable Consul backend with default settings" export:"true"`
    ConsulCatalog *consulcatalog.Provider `description:"Enable Consul catalog backend with default settings" export:"true"`
    Etcd *etcd.Provider `description:"Enable Etcd backend with default settings" export:"true"`
    Zookeeper *zk.Provider `description:"Enable Zookeeper backend with default settings" export:"true"`
    Boltdb *boltdb.Provider `description:"Enable Boltdb backend with default settings" export:"true"`
    Kubernetes *kubernetes.Provider `description:"Enable Kubernetes backend with default settings" export:"true"`
    Mesos *mesos.Provider `description:"Enable Mesos backend with default settings" export:"true"`
    Eureka *eureka.Provider `description:"Enable Eureka backend with default settings" export:"true"`
    ECS *ecs.Provider `description:"Enable ECS backend with default settings" export:"true"`
    Rancher *rancher.Provider `description:"Enable Rancher backend with default settings" export:"true"`
    DynamoDB *dynamodb.Provider `description:"Enable DynamoDB backend with default settings" export:"true"`
    Rest *rest.Provider `description:"Enable Rest backend with default settings" export:"true"`
    API *api.Handler `description:"Enable api/dashboard" export:"true"`
    Metrics *types.Metrics `description:"Enable a metrics exporter" export:"true"`
    Ping *ping.Handler `description:"Enable ping" export:"true"`
    HostResolver *HostResolverConfig `description:"Enable CNAME Flattening" export:"true"`
}

// SetEffectiveConfiguration adds missing configuration parameters derived from existing ones.
// It also takes care of maintaining backwards compatibility.
func (gc *GlobalConfiguration) SetEffectiveConfiguration(configFile string) {
    if len(gc.EntryPoints) == 0 {
        gc.EntryPoints = map[string]*EntryPoint{"http": {
            Address:          ":80",
            ForwardedHeaders: &ForwardedHeaders{},
        }}
        gc.DefaultEntryPoints = []string{"http"}
    }

    if (gc.API != nil && gc.API.EntryPoint == DefaultInternalEntryPointName) ||
        (gc.Ping != nil && gc.Ping.EntryPoint == DefaultInternalEntryPointName) ||
        (gc.Metrics != nil && gc.Metrics.Prometheus != nil && gc.Metrics.Prometheus.EntryPoint == DefaultInternalEntryPointName) ||
        (gc.Rest != nil && gc.Rest.EntryPoint == DefaultInternalEntryPointName) {
        if _, ok := gc.EntryPoints[DefaultInternalEntryPointName]; !ok {
            gc.EntryPoints[DefaultInternalEntryPointName] = &EntryPoint{Address: ":8080"}
        }
    }

    for entryPointName := range gc.EntryPoints {
        entryPoint := gc.EntryPoints[entryPointName]
        // ForwardedHeaders must be remove in the next breaking version
        if entryPoint.ForwardedHeaders == nil {
            entryPoint.ForwardedHeaders = &ForwardedHeaders{}
        }

        if entryPoint.TLS != nil && entryPoint.TLS.DefaultCertificate == nil && len(entryPoint.TLS.Certificates) > 0 {
            log.Infof("No tls.defaultCertificate given for %s: using the first item in tls.certificates as a fallback.", entryPointName)
            entryPoint.TLS.DefaultCertificate = &entryPoint.TLS.Certificates[0]
        }
    }

    // Make sure LifeCycle isn't nil to spare nil checks elsewhere.
    if gc.LifeCycle == nil {
        gc.LifeCycle = &LifeCycle{}
    }

    if gc.Rancher != nil {
        // Ensure backwards compatibility for now
        if len(gc.Rancher.AccessKey) > 0 ||
            len(gc.Rancher.Endpoint) > 0 ||
            len(gc.Rancher.SecretKey) > 0 {

            if gc.Rancher.API == nil {
                gc.Rancher.API = &rancher.APIConfiguration{
                    AccessKey: gc.Rancher.AccessKey,
                    SecretKey: gc.Rancher.SecretKey,
                    Endpoint:  gc.Rancher.Endpoint,
                }
            }
            log.Warn("Deprecated configuration found: rancher.[accesskey|secretkey|endpoint]. " +
                "Please use rancher.api.[accesskey|secretkey|endpoint] instead.")
        }

        if gc.Rancher.Metadata != nil && len(gc.Rancher.Metadata.Prefix) == 0 {
            gc.Rancher.Metadata.Prefix = "latest"
        }
    }

    if gc.API != nil {
        gc.API.Debug = gc.Debug
    }

    if gc.File != nil {
        gc.File.TraefikFile = configFile
    }

    gc.initACMEProvider()
    gc.initTracing()
}

func (gc *GlobalConfiguration) initTracing() {
    if gc.Tracing != nil {
        switch gc.Tracing.Backend {
        case jaeger.Name:
            if gc.Tracing.Jaeger == nil {
                gc.Tracing.Jaeger = &jaeger.Config{
                    SamplingServerURL:  "http://localhost:5778/sampling",
                    SamplingType:       "const",
                    SamplingParam:      1.0,
                    LocalAgentHostPort: "127.0.0.1:6831",
                    Propagation:        "jaeger",
                    Gen128Bit:          false,
                }
            }
            if gc.Tracing.Zipkin != nil {
                log.Warn("Zipkin configuration will be ignored")
                gc.Tracing.Zipkin = nil
            }
            if gc.Tracing.DataDog != nil {
                log.Warn("DataDog configuration will be ignored")
                gc.Tracing.DataDog = nil
            }
        case zipkin.Name:
            if gc.Tracing.Zipkin == nil {
                gc.Tracing.Zipkin = &zipkin.Config{
                    HTTPEndpoint: "http://localhost:9411/api/v1/spans",
                    SameSpan:     false,
                    ID128Bit:     true,
                    Debug:        false,
                    SampleRate:   1.0,
                }
            }
            if gc.Tracing.Jaeger != nil {
                log.Warn("Jaeger configuration will be ignored")
                gc.Tracing.Jaeger = nil
            }
            if gc.Tracing.DataDog != nil {
                log.Warn("DataDog configuration will be ignored")
                gc.Tracing.DataDog = nil
            }
        case datadog.Name:
            if gc.Tracing.DataDog == nil {
                gc.Tracing.DataDog = &datadog.Config{
                    LocalAgentHostPort: "localhost:8126",
                    GlobalTag:          "",
                    Debug:              false,
                }
            }
            if gc.Tracing.Zipkin != nil {
                log.Warn("Zipkin configuration will be ignored")
                gc.Tracing.Zipkin = nil
            }
            if gc.Tracing.Jaeger != nil {
                log.Warn("Jaeger configuration will be ignored")
                gc.Tracing.Jaeger = nil
            }
        default:
            log.Warnf("Unknown tracer %q", gc.Tracing.Backend)
            return
        }
    }
}

func (gc *GlobalConfiguration) initACMEProvider() {
    if gc.ACME != nil {
        gc.ACME.CAServer = getSafeACMECAServer(gc.ACME.CAServer)

        if gc.ACME.DNSChallenge != nil && gc.ACME.HTTPChallenge != nil {
            log.Warn("Unable to use DNS challenge and HTTP challenge at the same time. Fallback to DNS challenge.")
            gc.ACME.HTTPChallenge = nil
        }

        if gc.ACME.DNSChallenge != nil && gc.ACME.TLSChallenge != nil {
            log.Warn("Unable to use DNS challenge and TLS challenge at the same time. Fallback to DNS challenge.")
            gc.ACME.TLSChallenge = nil
        }

        if gc.ACME.HTTPChallenge != nil && gc.ACME.TLSChallenge != nil {
            log.Warn("Unable to use HTTP challenge and TLS challenge at the same time. Fallback to TLS challenge.")
            gc.ACME.HTTPChallenge = nil
        }

        if gc.ACME.OnDemand {
            log.Warn("ACME.OnDemand is deprecated")
        }
    }
}

// InitACMEProvider create an acme provider from the ACME part of globalConfiguration
func (gc *GlobalConfiguration) InitACMEProvider() (*acmeprovider.Provider, error) {
    if gc.ACME != nil {
        if len(gc.ACME.Storage) == 0 {
            // Delete the ACME configuration to avoid starting ACME in cluster mode
            gc.ACME = nil
            return nil, errors.New("unable to initialize ACME provider with no storage location for the certificates")
        }
        // TODO: Remove when Provider ACME will replace totally ACME
        // If provider file, use Provider ACME instead of ACME
        if gc.Cluster == nil {
            provider := &acmeprovider.Provider{}
            provider.Configuration = convertACMEChallenge(gc.ACME)

            store := acmeprovider.NewLocalStore(provider.Storage)
            provider.Store = store
            acme.ConvertToNewFormat(provider.Storage)
            gc.ACME = nil
            return provider, nil
        }
    }
    return nil, nil
}

func getSafeACMECAServer(caServerSrc string) string {
    if len(caServerSrc) == 0 {
        return DefaultAcmeCAServer
    }

    if strings.HasPrefix(caServerSrc, "https://acme-v01.api.letsencrypt.org") {
        caServer := strings.Replace(caServerSrc, "v01", "v02", 1)
        log.Warnf("The CA server %[1]q refers to a v01 endpoint of the ACME API, please change to %[2]q. Fallback to %[2]q.", caServerSrc, caServer)
        return caServer
    }

    if strings.HasPrefix(caServerSrc, "https://acme-staging.api.letsencrypt.org") {
        caServer := strings.Replace(caServerSrc, "https://acme-staging.api.letsencrypt.org", "https://acme-staging-v02.api.letsencrypt.org", 1)
        log.Warnf("The CA server %[1]q refers to a v01 endpoint of the ACME API, please change to %[2]q. Fallback to %[2]q.", caServerSrc, caServer)
        return caServer
    }

    return caServerSrc
}

// ValidateConfiguration validate that configuration is coherent
func (gc *GlobalConfiguration) ValidateConfiguration() {
    if gc.ACME != nil {
        if _, ok := gc.EntryPoints[gc.ACME.EntryPoint]; !ok {
            log.Fatalf("Unknown entrypoint %q for ACME configuration", gc.ACME.EntryPoint)
        } else {
            if gc.EntryPoints[gc.ACME.EntryPoint].TLS == nil {
                log.Fatalf("Entrypoint %q has no TLS configuration for ACME configuration", gc.ACME.EntryPoint)
            }
        }
    }
}

// DefaultEntryPoints holds default entry points
type DefaultEntryPoints []string

// String is the method to format the flag's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics.
func (dep *DefaultEntryPoints) String() string {
    return strings.Join(*dep, ",")
}

// Set is the method to set the flag value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the flag.
// It's a comma-separated list, so we split it.
func (dep *DefaultEntryPoints) Set(value string) error {
    entrypoints := strings.Split(value, ",")
    if len(entrypoints) == 0 {
        return fmt.Errorf("bad DefaultEntryPoints format: %s", value)
    }
    for _, entrypoint := range entrypoints {
        *dep = append(*dep, entrypoint)
    }
    return nil
}

// Get return the EntryPoints map
func (dep *DefaultEntryPoints) Get() interface{} {
    return *dep
}

// SetValue sets the EntryPoints map with val
func (dep *DefaultEntryPoints) SetValue(val interface{}) {
    *dep = val.(DefaultEntryPoints)
}

// Type is type of the struct
func (dep *DefaultEntryPoints) Type() string {
    return "defaultentrypoints"
}

// Retry contains request retry config
type Retry struct {
    Attempts int `description:"Number of attempts" export:"true"`
}

// HealthCheckConfig contains health check configuration parameters.
type HealthCheckConfig struct {
    Interval parse.Duration `description:"Default periodicity of enabled health checks" export:"true"`
    Timeout  parse.Duration `description:"Default request timeout of enabled health checks" export:"true"`
}

// RespondingTimeouts contains timeout configurations for incoming requests to the Traefik instance.
type RespondingTimeouts struct {
    ReadTimeout  parse.Duration `description:"ReadTimeout is the maximum duration for reading the entire request, including the body. If zero, no timeout is set" export:"true"`
    WriteTimeout parse.Duration `description:"WriteTimeout is the maximum duration before timing out writes of the response. If zero, no timeout is set" export:"true"`
    IdleTimeout  parse.Duration `description:"IdleTimeout is the maximum amount duration an idle (keep-alive) connection will remain idle before closing itself. Defaults to 180 seconds. If zero, no timeout is set" export:"true"`
}

// ForwardingTimeouts contains timeout configurations for forwarding requests to the backend servers.
type ForwardingTimeouts struct {
    DialTimeout           parse.Duration `description:"The amount of time to wait until a connection to a backend server can be established. Defaults to 30 seconds. If zero, no timeout exists" export:"true"`
    ResponseHeaderTimeout parse.Duration `description:"The amount of time to wait for a server's response headers after fully writing the request (including its body, if any). If zero, no timeout exists" export:"true"`
}

// LifeCycle contains configurations relevant to the lifecycle (such as the
// shutdown phase) of Traefik.
type LifeCycle struct {
    RequestAcceptGraceTimeout parse.Duration `description:"Duration to keep accepting requests before Traefik initiates the graceful shutdown procedure"`
    GraceTimeOut              parse.Duration `description:"Duration to give active requests a chance to finish before Traefik stops"`
}

// HostResolverConfig contain configuration for CNAME Flattening
type HostResolverConfig struct {
    CnameFlattening bool   `description:"A flag to enable/disable CNAME flattening" export:"true"`
    ResolvConfig    string `description:"resolv.conf used for DNS resolving" export:"true"`
    ResolvDepth     int    `description:"The maximal depth of DNS recursive resolving" export:"true"`
}

// Deprecated
func convertACMEChallenge(oldACMEChallenge *acme.ACME) *acmeprovider.Configuration {
    conf := &acmeprovider.Configuration{
        KeyType:     oldACMEChallenge.KeyType,
        OnHostRule:  oldACMEChallenge.OnHostRule,
        OnDemand:    oldACMEChallenge.OnDemand,
        Email:       oldACMEChallenge.Email,
        Storage:     oldACMEChallenge.Storage,
        ACMELogging: oldACMEChallenge.ACMELogging,
        CAServer:    oldACMEChallenge.CAServer,
        EntryPoint:  oldACMEChallenge.EntryPoint,
    }

    for _, domain := range oldACMEChallenge.Domains {
        if domain.Main != lego.UnFqdn(domain.Main) {
            log.Warnf("FQDN detected, please remove the trailing dot: %s", domain.Main)
        }
        for _, san := range domain.SANs {
            if san != lego.UnFqdn(san) {
                log.Warnf("FQDN detected, please remove the trailing dot: %s", san)
            }
        }
        conf.Domains = append(conf.Domains, newtypes.Domain(domain))
    }
    if oldACMEChallenge.HTTPChallenge != nil {
        conf.HTTPChallenge = &acmeprovider.HTTPChallenge{
            EntryPoint: oldACMEChallenge.HTTPChallenge.EntryPoint,
        }
    }

    if oldACMEChallenge.DNSChallenge != nil {
        conf.DNSChallenge = &acmeprovider.DNSChallenge{
            Provider:         oldACMEChallenge.DNSChallenge.Provider,
            DelayBeforeCheck: oldACMEChallenge.DNSChallenge.DelayBeforeCheck,
        }
    }

    if oldACMEChallenge.TLSChallenge != nil {
        conf.TLSChallenge = &acmeprovider.TLSChallenge{}
    }

    return conf
}
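A small sketch (not part of the commit) of the defaulting done by SetEffectiveConfiguration above: starting from an empty GlobalConfiguration, it fills in the "http" entrypoint on :80 and the matching DefaultEntryPoints list. The import path mirrors the old/configuration location of the file.

package main

import (
    "fmt"

    "github.com/containous/traefik/old/configuration"
)

func main() {
    gc := &configuration.GlobalConfiguration{}

    // No entrypoints were configured, so the defaults are derived here.
    gc.SetEffectiveConfiguration("traefik.toml")

    fmt.Println(gc.DefaultEntryPoints)          // [http]
    fmt.Println(gc.EntryPoints["http"].Address) // :80
}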
224 old/configuration/configuration_test.go Normal file
@@ -0,0 +1,224 @@
package configuration

import (
    "testing"

    "github.com/containous/traefik/acme"
    "github.com/containous/traefik/old/middlewares/tracing"
    "github.com/containous/traefik/old/middlewares/tracing/jaeger"
    "github.com/containous/traefik/old/middlewares/tracing/zipkin"
    "github.com/containous/traefik/old/provider"
    acmeprovider "github.com/containous/traefik/old/provider/acme"
    "github.com/containous/traefik/old/provider/file"
    "github.com/stretchr/testify/assert"
)

const defaultConfigFile = "traefik.toml"

func TestSetEffectiveConfigurationFileProviderFilename(t *testing.T) {
    testCases := []struct {
        desc                        string
        fileProvider                *file.Provider
        wantFileProviderFilename    string
        wantFileProviderTraefikFile string
    }{
        {
            desc:                        "no filename for file provider given",
            fileProvider:                &file.Provider{},
            wantFileProviderFilename:    "",
            wantFileProviderTraefikFile: defaultConfigFile,
        },
        {
            desc:                        "filename for file provider given",
            fileProvider:                &file.Provider{BaseProvider: provider.BaseProvider{Filename: "other.toml"}},
            wantFileProviderFilename:    "other.toml",
            wantFileProviderTraefikFile: defaultConfigFile,
        },
        {
            desc:                        "directory for file provider given",
            fileProvider:                &file.Provider{Directory: "/"},
            wantFileProviderFilename:    "",
            wantFileProviderTraefikFile: defaultConfigFile,
        },
    }

    for _, test := range testCases {
        test := test
        t.Run(test.desc, func(t *testing.T) {
            t.Parallel()

            gc := &GlobalConfiguration{
                File: test.fileProvider,
            }

            gc.SetEffectiveConfiguration(defaultConfigFile)

            assert.Equal(t, test.wantFileProviderFilename, gc.File.Filename)
            assert.Equal(t, test.wantFileProviderTraefikFile, gc.File.TraefikFile)
        })
    }
}

func TestSetEffectiveConfigurationTracing(t *testing.T) {
    testCases := []struct {
        desc     string
        tracing  *tracing.Tracing
        expected *tracing.Tracing
    }{
        {
            desc:     "no tracing configuration",
            tracing:  &tracing.Tracing{},
            expected: &tracing.Tracing{},
        },
        {
            desc: "tracing bad backend name",
            tracing: &tracing.Tracing{
                Backend: "powpow",
            },
            expected: &tracing.Tracing{
                Backend: "powpow",
            },
        },
        {
            desc: "tracing jaeger backend name",
            tracing: &tracing.Tracing{
                Backend: "jaeger",
                Zipkin: &zipkin.Config{
                    HTTPEndpoint: "http://localhost:9411/api/v1/spans",
                    SameSpan:     false,
                    ID128Bit:     true,
                    Debug:        false,
                },
            },
            expected: &tracing.Tracing{
                Backend: "jaeger",
                Jaeger: &jaeger.Config{
                    SamplingServerURL:  "http://localhost:5778/sampling",
                    SamplingType:       "const",
                    SamplingParam:      1.0,
                    LocalAgentHostPort: "127.0.0.1:6831",
                    Propagation:        "jaeger",
                    Gen128Bit:          false,
                },
                Zipkin: nil,
            },
        },
        {
            desc: "tracing zipkin backend name",
            tracing: &tracing.Tracing{
                Backend: "zipkin",
                Jaeger: &jaeger.Config{
                    SamplingServerURL:  "http://localhost:5778/sampling",
                    SamplingType:       "const",
                    SamplingParam:      1.0,
                    LocalAgentHostPort: "127.0.0.1:6831",
                },
            },
            expected: &tracing.Tracing{
                Backend: "zipkin",
                Jaeger:  nil,
                Zipkin: &zipkin.Config{
                    HTTPEndpoint: "http://localhost:9411/api/v1/spans",
                    SameSpan:     false,
                    ID128Bit:     true,
                    Debug:        false,
                    SampleRate:   1.0,
                },
            },
        },
        {
            desc: "tracing zipkin backend name value override",
            tracing: &tracing.Tracing{
                Backend: "zipkin",
                Jaeger: &jaeger.Config{
                    SamplingServerURL:  "http://localhost:5778/sampling",
                    SamplingType:       "const",
                    SamplingParam:      1.0,
                    LocalAgentHostPort: "127.0.0.1:6831",
                },
                Zipkin: &zipkin.Config{
                    HTTPEndpoint: "http://powpow:9411/api/v1/spans",
                    SameSpan:     true,
                    ID128Bit:     true,
                    Debug:        true,
                    SampleRate:   0.02,
                },
            },
            expected: &tracing.Tracing{
                Backend: "zipkin",
                Jaeger:  nil,
                Zipkin: &zipkin.Config{
                    HTTPEndpoint: "http://powpow:9411/api/v1/spans",
                    SameSpan:     true,
                    ID128Bit:     true,
                    Debug:        true,
                    SampleRate:   0.02,
                },
            },
        },
    }

    for _, test := range testCases {
        test := test
        t.Run(test.desc, func(t *testing.T) {
            t.Parallel()

            gc := &GlobalConfiguration{
                Tracing: test.tracing,
            }

            gc.SetEffectiveConfiguration(defaultConfigFile)

            assert.Equal(t, test.expected, gc.Tracing)
        })
    }
}

func TestInitACMEProvider(t *testing.T) {
    testCases := []struct {
        desc                  string
        acmeConfiguration     *acme.ACME
        expectedConfiguration *acmeprovider.Provider
        noError               bool
    }{
        {
            desc:                  "No ACME configuration",
            acmeConfiguration:     nil,
            expectedConfiguration: nil,
            noError:               true,
        },
        {
            desc:                  "ACME configuration with storage",
            acmeConfiguration:     &acme.ACME{Storage: "foo/acme.json"},
            expectedConfiguration: &acmeprovider.Provider{Configuration: &acmeprovider.Configuration{Storage: "foo/acme.json"}},
            noError:               true,
        },
        {
            desc:                  "ACME configuration with no storage",
            acmeConfiguration:     &acme.ACME{},
            expectedConfiguration: nil,
            noError:               false,
        },
    }

    for _, test := range testCases {
        test := test
        t.Run(test.desc, func(t *testing.T) {
            t.Parallel()

            gc := &GlobalConfiguration{
                ACME: test.acmeConfiguration,
            }

            configuration, err := gc.InitACMEProvider()

            assert.True(t, (err == nil) == test.noError)

            if test.expectedConfiguration == nil {
                assert.Nil(t, configuration)
            } else {
                assert.Equal(t, test.expectedConfiguration.Storage, configuration.Storage)
            }
        })
    }
}
322 old/configuration/entrypoints.go Normal file
@@ -0,0 +1,322 @@
package configuration

import (
    "fmt"
    "strconv"
    "strings"

    "github.com/containous/traefik/old/log"
    "github.com/containous/traefik/old/types"
    "github.com/containous/traefik/tls"
)

// EntryPoint holds an entry point configuration of the reverse proxy (ip, port, TLS...)
type EntryPoint struct {
    Address          string
    TLS              *tls.TLS          `export:"true"`
    Redirect         *types.Redirect   `export:"true"`
    Auth             *types.Auth       `export:"true"`
    WhiteList        *types.WhiteList  `export:"true"`
    Compress         *Compress         `export:"true"`
    ProxyProtocol    *ProxyProtocol    `export:"true"`
    ForwardedHeaders *ForwardedHeaders `export:"true"`
    ClientIPStrategy *types.IPStrategy `export:"true"`
}

// Compress contains compress configuration
type Compress struct{}

// ProxyProtocol contains Proxy-Protocol configuration
type ProxyProtocol struct {
    Insecure   bool `export:"true"`
    TrustedIPs []string
}

// ForwardedHeaders Trust client forwarding headers
type ForwardedHeaders struct {
    Insecure   bool `export:"true"`
    TrustedIPs []string
}

// EntryPoints holds entry points configuration of the reverse proxy (ip, port, TLS...)
type EntryPoints map[string]*EntryPoint

// String is the method to format the flag's value, part of the flag.Value interface.
// The String method's output will be used in diagnostics.
func (ep EntryPoints) String() string {
    return fmt.Sprintf("%+v", map[string]*EntryPoint(ep))
}

// Get return the EntryPoints map
func (ep *EntryPoints) Get() interface{} {
    return *ep
}

// SetValue sets the EntryPoints map with val
func (ep *EntryPoints) SetValue(val interface{}) {
    *ep = val.(EntryPoints)
}

// Type is type of the struct
func (ep *EntryPoints) Type() string {
    return "entrypoints"
}

// Set is the method to set the flag value, part of the flag.Value interface.
// Set's argument is a string to be parsed to set the flag.
// It's a comma-separated list, so we split it.
func (ep *EntryPoints) Set(value string) error {
    result := parseEntryPointsConfiguration(value)

    var compress *Compress
    if len(result["compress"]) > 0 {
        compress = &Compress{}
    }

    configTLS, err := makeEntryPointTLS(result)
    if err != nil {
        return err
    }

    (*ep)[result["name"]] = &EntryPoint{
        Address:          result["address"],
        TLS:              configTLS,
        Auth:             makeEntryPointAuth(result),
        Redirect:         makeEntryPointRedirect(result),
        Compress:         compress,
        WhiteList:        makeWhiteList(result),
        ProxyProtocol:    makeEntryPointProxyProtocol(result),
        ForwardedHeaders: makeEntryPointForwardedHeaders(result),
        ClientIPStrategy: makeIPStrategy("clientipstrategy", result),
    }

    return nil
}

func makeWhiteList(result map[string]string) *types.WhiteList {
    if rawRange, ok := result["whitelist_sourcerange"]; ok {
        return &types.WhiteList{
            SourceRange: strings.Split(rawRange, ","),
            IPStrategy:  makeIPStrategy("whitelist_ipstrategy", result),
        }
    }
    return nil
}

func makeIPStrategy(prefix string, result map[string]string) *types.IPStrategy {
    depth := toInt(result, prefix+"_depth")
    excludedIPs := result[prefix+"_excludedips"]

    if depth == 0 && len(excludedIPs) == 0 {
        return nil
    }

    return &types.IPStrategy{
        Depth:       depth,
        ExcludedIPs: strings.Split(excludedIPs, ","),
    }
}

func makeEntryPointAuth(result map[string]string) *types.Auth {
    var basic *types.Basic
    if v, ok := result["auth_basic_users"]; ok {
        basic = &types.Basic{
            Realm:        result["auth_basic_realm"],
            Users:        strings.Split(v, ","),
            RemoveHeader: toBool(result, "auth_basic_removeheader"),
        }
    }

    var digest *types.Digest
    if v, ok := result["auth_digest_users"]; ok {
        digest = &types.Digest{
            Users:        strings.Split(v, ","),
            RemoveHeader: toBool(result, "auth_digest_removeheader"),
        }
    }

    var forward *types.Forward
    if address, ok := result["auth_forward_address"]; ok {
        var clientTLS *types.ClientTLS

        cert := result["auth_forward_tls_cert"]
        key := result["auth_forward_tls_key"]
        insecureSkipVerify := toBool(result, "auth_forward_tls_insecureskipverify")

        if len(cert) > 0 && len(key) > 0 || insecureSkipVerify {
            clientTLS = &types.ClientTLS{
                CA:                 result["auth_forward_tls_ca"],
                CAOptional:         toBool(result, "auth_forward_tls_caoptional"),
                Cert:               cert,
                Key:                key,
                InsecureSkipVerify: insecureSkipVerify,
            }
        }

        var authResponseHeaders []string
        if v, ok := result["auth_forward_authresponseheaders"]; ok {
            authResponseHeaders = strings.Split(v, ",")
        }

        forward = &types.Forward{
            Address:             address,
            TLS:                 clientTLS,
            TrustForwardHeader:  toBool(result, "auth_forward_trustforwardheader"),
            AuthResponseHeaders: authResponseHeaders,
        }
    }

    var auth *types.Auth
    if basic != nil || digest != nil || forward != nil {
        auth = &types.Auth{
            Basic:       basic,
            Digest:      digest,
            Forward:     forward,
            HeaderField: result["auth_headerfield"],
        }
    }

    return auth
}

func makeEntryPointProxyProtocol(result map[string]string) *ProxyProtocol {
    var proxyProtocol *ProxyProtocol

    ppTrustedIPs := result["proxyprotocol_trustedips"]
    if len(result["proxyprotocol_insecure"]) > 0 || len(ppTrustedIPs) > 0 {
        proxyProtocol = &ProxyProtocol{
            Insecure: toBool(result, "proxyprotocol_insecure"),
        }
        if len(ppTrustedIPs) > 0 {
            proxyProtocol.TrustedIPs = strings.Split(ppTrustedIPs, ",")
        }
    }

    if proxyProtocol != nil && proxyProtocol.Insecure {
        log.Warn("ProxyProtocol.insecure:true is dangerous. Please use 'ProxyProtocol.TrustedIPs:IPs' and remove 'ProxyProtocol.insecure:true'")
    }

    return proxyProtocol
}

func makeEntryPointForwardedHeaders(result map[string]string) *ForwardedHeaders {
    forwardedHeaders := &ForwardedHeaders{}
    if _, ok := result["forwardedheaders_insecure"]; ok {
        forwardedHeaders.Insecure = toBool(result, "forwardedheaders_insecure")
    }

    fhTrustedIPs := result["forwardedheaders_trustedips"]
    if len(fhTrustedIPs) > 0 {
        // TODO must be removed in the next breaking version.
        forwardedHeaders.Insecure = toBool(result, "forwardedheaders_insecure")
        forwardedHeaders.TrustedIPs = strings.Split(fhTrustedIPs, ",")
    }

    return forwardedHeaders
}

func makeEntryPointRedirect(result map[string]string) *types.Redirect {
    var redirect *types.Redirect

    if len(result["redirect_entrypoint"]) > 0 || len(result["redirect_regex"]) > 0 || len(result["redirect_replacement"]) > 0 {
        redirect = &types.Redirect{
            EntryPoint:  result["redirect_entrypoint"],
            Regex:       result["redirect_regex"],
            Replacement: result["redirect_replacement"],
            Permanent:   toBool(result, "redirect_permanent"),
        }
    }

    return redirect
}

func makeEntryPointTLS(result map[string]string) (*tls.TLS, error) {
    var configTLS *tls.TLS

    if len(result["tls"]) > 0 {
        certs := tls.Certificates{}
        if err := certs.Set(result["tls"]); err != nil {
            return nil, err
        }
        configTLS = &tls.TLS{
            Certificates: certs,
        }
    } else if len(result["tls_acme"]) > 0 {
        configTLS = &tls.TLS{
            Certificates: tls.Certificates{},
        }
    }

    if configTLS != nil {
        if len(result["ca"]) > 0 {
            files := tls.FilesOrContents{}
            files.Set(result["ca"])
            optional := toBool(result, "ca_optional")
            configTLS.ClientCA = tls.ClientCA{
                Files:    files,
                Optional: optional,
            }
        }

        if len(result["tls_minversion"]) > 0 {
            configTLS.MinVersion = result["tls_minversion"]
        }

        if len(result["tls_ciphersuites"]) > 0 {
            configTLS.CipherSuites = strings.Split(result["tls_ciphersuites"], ",")
        }

        if len(result["tls_snistrict"]) > 0 {
            configTLS.SniStrict = toBool(result, "tls_snistrict")
        }

        if len(result["tls_defaultcertificate_cert"]) > 0 && len(result["tls_defaultcertificate_key"]) > 0 {
            configTLS.DefaultCertificate = &tls.Certificate{
                CertFile: tls.FileOrContent(result["tls_defaultcertificate_cert"]),
                KeyFile:  tls.FileOrContent(result["tls_defaultcertificate_key"]),
            }
        }
    }

    return configTLS, nil
}

func parseEntryPointsConfiguration(raw string) map[string]string {
    sections := strings.Fields(raw)

    config := make(map[string]string)
    for _, part := range sections {
        field := strings.SplitN(part, ":", 2)
        name := strings.ToLower(strings.Replace(field[0], ".", "_", -1))
        if len(field) > 1 {
            config[name] = field[1]
        } else {
            if strings.EqualFold(name, "TLS") {
                config["tls_acme"] = "TLS"
            } else {
                config[name] = ""
            }
        }
    }
    return config
}

func toBool(conf map[string]string, key string) bool {
    if val, ok := conf[key]; ok {
        return strings.EqualFold(val, "true") ||
            strings.EqualFold(val, "enable") ||
            strings.EqualFold(val, "on")
    }
    return false
}

func toInt(conf map[string]string, key string) int {
    if val, ok := conf[key]; ok {
        intVal, err := strconv.Atoi(val)
        if err != nil {
            return 0
        }
        return intVal
    }
    return 0
}
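A short sketch (illustrative, not from the commit) of the flag syntax handled by EntryPoints.Set and parseEntryPointsConfiguration above: space-separated Key:Value pairs, with dotted keys lowercased into underscore-separated map keys. The import path mirrors the old/configuration location of the file.

package main

import (
    "fmt"

    "github.com/containous/traefik/old/configuration"
)

func main() {
    eps := make(configuration.EntryPoints)

    // Same format as the --entryPoints flag parsed by Set above.
    err := eps.Set("Name:http Address::80 Compress:true ForwardedHeaders.TrustedIPs:10.0.0.0/8")
    if err != nil {
        panic(err)
    }

    ep := eps["http"]
    fmt.Println(ep.Address)                     // :80
    fmt.Println(ep.Compress != nil)             // true
    fmt.Println(ep.ForwardedHeaders.TrustedIPs) // [10.0.0.0/8]
}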
515 old/configuration/entrypoints_test.go Normal file
@@ -0,0 +1,515 @@
package configuration

import (
    "testing"

    "github.com/containous/traefik/old/types"
    "github.com/containous/traefik/tls"
    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"
)

func Test_parseEntryPointsConfiguration(t *testing.T) {
    testCases := []struct {
        name           string
        value          string
        expectedResult map[string]string
    }{
        {
            name: "all parameters",
            value: "Name:foo " +
                "Address::8000 " +
                "TLS:goo,gii " +
                "TLS " +
                "TLS.MinVersion:VersionTLS11 " +
                "TLS.CipherSuites:TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA " +
                "CA:car " +
                "CA.Optional:true " +
                "Redirect.EntryPoint:https " +
                "Redirect.Regex:http://localhost/(.*) " +
                "Redirect.Replacement:http://mydomain/$1 " +
                "Redirect.Permanent:true " +
                "Compress:true " +
                "ProxyProtocol.TrustedIPs:192.168.0.1 " +
                "ForwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24 " +
                "Auth.Basic.Realm:myRealm " +
                "Auth.Basic.Users:test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0 " +
                "Auth.Basic.RemoveHeader:true " +
                "Auth.Digest.Users:test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e " +
                "Auth.Digest.RemoveHeader:true " +
                "Auth.HeaderField:X-WebAuth-User " +
                "Auth.Forward.Address:https://authserver.com/auth " +
                "Auth.Forward.AuthResponseHeaders:X-Auth,X-Test,X-Secret " +
                "Auth.Forward.TrustForwardHeader:true " +
                "Auth.Forward.TLS.CA:path/to/local.crt " +
                "Auth.Forward.TLS.CAOptional:true " +
                "Auth.Forward.TLS.Cert:path/to/foo.cert " +
                "Auth.Forward.TLS.Key:path/to/foo.key " +
                "Auth.Forward.TLS.InsecureSkipVerify:true " +
                "WhiteList.SourceRange:10.42.0.0/16,152.89.1.33/32,afed:be44::/16 " +
                "WhiteList.IPStrategy.depth:3 " +
                "WhiteList.IPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 " +
                "ClientIPStrategy.depth:3 " +
                "ClientIPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 ",
            expectedResult: map[string]string{
                "address":                             ":8000",
                "auth_basic_realm":                    "myRealm",
                "auth_basic_users":                    "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0",
                "auth_basic_removeheader":             "true",
                "auth_digest_users":                   "test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e",
                "auth_digest_removeheader":            "true",
                "auth_forward_address":                "https://authserver.com/auth",
                "auth_forward_authresponseheaders":    "X-Auth,X-Test,X-Secret",
                "auth_forward_tls_ca":                 "path/to/local.crt",
                "auth_forward_tls_caoptional":         "true",
                "auth_forward_tls_cert":               "path/to/foo.cert",
                "auth_forward_tls_insecureskipverify": "true",
                "auth_forward_tls_key":                "path/to/foo.key",
                "auth_forward_trustforwardheader":     "true",
                "auth_headerfield":                    "X-WebAuth-User",
                "ca":                                  "car",
                "ca_optional":                         "true",
                "compress":                            "true",
                "forwardedheaders_trustedips":         "10.0.0.3/24,20.0.0.3/24",
                "name":                                "foo",
                "proxyprotocol_trustedips":            "192.168.0.1",
                "redirect_entrypoint":                 "https",
                "redirect_permanent":                  "true",
                "redirect_regex":                      "http://localhost/(.*)",
                "redirect_replacement":                "http://mydomain/$1",
                "tls":                                 "goo,gii",
                "tls_acme":                            "TLS",
                "tls_ciphersuites":                    "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
                "tls_minversion":                      "VersionTLS11",
                "whitelist_sourcerange":               "10.42.0.0/16,152.89.1.33/32,afed:be44::/16",
                "whitelist_ipstrategy_depth":          "3",
                "whitelist_ipstrategy_excludedips":    "10.0.0.3/24,20.0.0.3/24",
                "clientipstrategy_depth":              "3",
                "clientipstrategy_excludedips":        "10.0.0.3/24,20.0.0.3/24",
            },
        },
        {
            name:  "compress on",
            value: "name:foo Compress:on",
            expectedResult: map[string]string{
                "name":     "foo",
                "compress": "on",
            },
        },
        {
            name:  "TLS",
            value: "Name:foo TLS:goo TLS",
            expectedResult: map[string]string{
                "name":     "foo",
                "tls":      "goo",
                "tls_acme": "TLS",
            },
        },
    }

    for _, test := range testCases {
        test := test
        t.Run(test.name, func(t *testing.T) {
            t.Parallel()

            conf := parseEntryPointsConfiguration(test.value)

            assert.Len(t, conf, len(test.expectedResult))
            assert.Equal(t, test.expectedResult, conf)
        })
    }
}

func Test_toBool(t *testing.T) {
    testCases := []struct {
        name         string
        value        string
        key          string
        expectedBool bool
    }{
        {
            name:         "on",
            value:        "on",
            key:          "foo",
            expectedBool: true,
        },
        {
            name:         "true",
            value:        "true",
            key:          "foo",
            expectedBool: true,
        },
        {
            name:         "enable",
            value:        "enable",
            key:          "foo",
            expectedBool: true,
        },
        {
            name:         "arbitrary string",
            value:        "bar",
            key:          "foo",
            expectedBool: false,
        },
        {
            name:         "no existing entry",
            value:        "bar",
            key:          "fii",
            expectedBool: false,
        },
    }

    for _, test := range testCases {
        test := test
        t.Run(test.name, func(t *testing.T) {
            t.Parallel()

            conf := map[string]string{
                "foo": test.value,
            }

            result := toBool(conf, test.key)

            assert.Equal(t, test.expectedBool, result)
        })
    }
}

func TestEntryPoints_Set(t *testing.T) {
    testCases := []struct {
        name                   string
        expression             string
        expectedEntryPointName string
        expectedEntryPoint     *EntryPoint
    }{
        {
            name: "all parameters camelcase",
            expression: "Name:foo " +
                "Address::8000 " +
                "TLS:goo,gii;foo,fii " +
                "TLS " +
                "TLS.MinVersion:VersionTLS11 " +
                "TLS.CipherSuites:TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA " +
                "CA:car " +
                "CA.Optional:true " +
                "Redirect.EntryPoint:https " +
                "Redirect.Regex:http://localhost/(.*) " +
                "Redirect.Replacement:http://mydomain/$1 " +
                "Redirect.Permanent:true " +
                "Compress:true " +
                "ProxyProtocol.TrustedIPs:192.168.0.1 " +
                "ForwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24 " +
                "Auth.Basic.Realm:myRealm " +
                "Auth.Basic.Users:test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0 " +
                "Auth.Basic.RemoveHeader:true " +
                "Auth.Digest.Users:test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e " +
                "Auth.Digest.RemoveHeader:true " +
                "Auth.HeaderField:X-WebAuth-User " +
                "Auth.Forward.Address:https://authserver.com/auth " +
                "Auth.Forward.AuthResponseHeaders:X-Auth,X-Test,X-Secret " +
                "Auth.Forward.TrustForwardHeader:true " +
                "Auth.Forward.TLS.CA:path/to/local.crt " +
                "Auth.Forward.TLS.CAOptional:true " +
                "Auth.Forward.TLS.Cert:path/to/foo.cert " +
                "Auth.Forward.TLS.Key:path/to/foo.key " +
                "Auth.Forward.TLS.InsecureSkipVerify:true " +
                "WhiteList.SourceRange:10.42.0.0/16,152.89.1.33/32,afed:be44::/16 " +
                "WhiteList.IPStrategy.depth:3 " +
                "WhiteList.IPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 " +
                "ClientIPStrategy.depth:3 " +
                "ClientIPStrategy.ExcludedIPs:10.0.0.3/24,20.0.0.3/24 ",
            expectedEntryPointName: "foo",
            expectedEntryPoint: &EntryPoint{
                Address: ":8000",
                TLS: &tls.TLS{
                    MinVersion:   "VersionTLS11",
                    CipherSuites: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384", "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"},
                    Certificates: tls.Certificates{
                        {
                            CertFile: tls.FileOrContent("goo"),
                            KeyFile:  tls.FileOrContent("gii"),
},
|
||||
{
|
||||
CertFile: tls.FileOrContent("foo"),
|
||||
KeyFile: tls.FileOrContent("fii"),
|
||||
},
|
||||
},
|
||||
ClientCA: tls.ClientCA{
|
||||
Files: tls.FilesOrContents{"car"},
|
||||
Optional: true,
|
||||
},
|
||||
},
|
||||
Redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
Regex: "http://localhost/(.*)",
|
||||
Replacement: "http://mydomain/$1",
|
||||
Permanent: true,
|
||||
},
|
||||
Auth: &types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Realm: "myRealm",
|
||||
RemoveHeader: true,
|
||||
Users: types.Users{
|
||||
"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/",
|
||||
"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0",
|
||||
},
|
||||
},
|
||||
Digest: &types.Digest{
|
||||
RemoveHeader: true,
|
||||
Users: types.Users{
|
||||
"test:traefik:a2688e031edb4be6a3797f3882655c05",
|
||||
"test2:traefik:518845800f9e2bfb1f1f740ec24f074e",
|
||||
},
|
||||
},
|
||||
Forward: &types.Forward{
|
||||
Address: "https://authserver.com/auth",
|
||||
AuthResponseHeaders: []string{"X-Auth", "X-Test", "X-Secret"},
|
||||
TLS: &types.ClientTLS{
|
||||
CA: "path/to/local.crt",
|
||||
CAOptional: true,
|
||||
Cert: "path/to/foo.cert",
|
||||
Key: "path/to/foo.key",
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
TrustForwardHeader: true,
|
||||
},
|
||||
HeaderField: "X-WebAuth-User",
|
||||
},
|
||||
WhiteList: &types.WhiteList{
|
||||
SourceRange: []string{
|
||||
"10.42.0.0/16",
|
||||
"152.89.1.33/32",
|
||||
"afed:be44::/16",
|
||||
},
|
||||
IPStrategy: &types.IPStrategy{
|
||||
Depth: 3,
|
||||
ExcludedIPs: []string{
|
||||
"10.0.0.3/24",
|
||||
"20.0.0.3/24",
|
||||
},
|
||||
},
|
||||
},
|
||||
Compress: &Compress{},
|
||||
ProxyProtocol: &ProxyProtocol{
|
||||
Insecure: false,
|
||||
TrustedIPs: []string{"192.168.0.1"},
|
||||
},
|
||||
ForwardedHeaders: &ForwardedHeaders{
|
||||
Insecure: false,
|
||||
TrustedIPs: []string{
|
||||
"10.0.0.3/24",
|
||||
"20.0.0.3/24",
|
||||
},
|
||||
},
|
||||
ClientIPStrategy: &types.IPStrategy{
|
||||
Depth: 3,
|
||||
ExcludedIPs: []string{
|
||||
"10.0.0.3/24",
|
||||
"20.0.0.3/24",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all parameters lowercase",
|
||||
expression: "Name:foo " +
|
||||
"address::8000 " +
|
||||
"tls:goo,gii;foo,fii " +
|
||||
"tls " +
|
||||
"tls.minversion:VersionTLS11 " +
|
||||
"tls.ciphersuites:TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA " +
|
||||
"ca:car " +
|
||||
"ca.Optional:true " +
|
||||
"redirect.entryPoint:https " +
|
||||
"redirect.regex:http://localhost/(.*) " +
|
||||
"redirect.replacement:http://mydomain/$1 " +
|
||||
"redirect.permanent:true " +
|
||||
"compress:true " +
|
||||
"whiteList.sourceRange:10.42.0.0/16,152.89.1.33/32,afed:be44::/16 " +
|
||||
"proxyProtocol.TrustedIPs:192.168.0.1 " +
|
||||
"forwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24 " +
|
||||
"auth.basic.realm:myRealm " +
|
||||
"auth.basic.users:test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/,test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0 " +
|
||||
"auth.digest.users:test:traefik:a2688e031edb4be6a3797f3882655c05,test2:traefik:518845800f9e2bfb1f1f740ec24f074e " +
|
||||
"auth.headerField:X-WebAuth-User " +
|
||||
"auth.forward.address:https://authserver.com/auth " +
|
||||
"auth.forward.authResponseHeaders:X-Auth,X-Test,X-Secret " +
|
||||
"auth.forward.trustForwardHeader:true " +
|
||||
"auth.forward.tls.ca:path/to/local.crt " +
|
||||
"auth.forward.tls.caOptional:true " +
|
||||
"auth.forward.tls.cert:path/to/foo.cert " +
|
||||
"auth.forward.tls.key:path/to/foo.key " +
|
||||
"auth.forward.tls.insecureSkipVerify:true ",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
Address: ":8000",
|
||||
TLS: &tls.TLS{
|
||||
MinVersion: "VersionTLS11",
|
||||
CipherSuites: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA384", "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"},
|
||||
Certificates: tls.Certificates{
|
||||
{
|
||||
CertFile: tls.FileOrContent("goo"),
|
||||
KeyFile: tls.FileOrContent("gii"),
|
||||
},
|
||||
{
|
||||
CertFile: tls.FileOrContent("foo"),
|
||||
KeyFile: tls.FileOrContent("fii"),
|
||||
},
|
||||
},
|
||||
ClientCA: tls.ClientCA{
|
||||
Files: tls.FilesOrContents{"car"},
|
||||
Optional: true,
|
||||
},
|
||||
},
|
||||
Redirect: &types.Redirect{
|
||||
EntryPoint: "https",
|
||||
Regex: "http://localhost/(.*)",
|
||||
Replacement: "http://mydomain/$1",
|
||||
Permanent: true,
|
||||
},
|
||||
Auth: &types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Realm: "myRealm",
|
||||
Users: types.Users{
|
||||
"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/",
|
||||
"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0",
|
||||
},
|
||||
},
|
||||
Digest: &types.Digest{
|
||||
Users: types.Users{
|
||||
"test:traefik:a2688e031edb4be6a3797f3882655c05",
|
||||
"test2:traefik:518845800f9e2bfb1f1f740ec24f074e",
|
||||
},
|
||||
},
|
||||
Forward: &types.Forward{
|
||||
Address: "https://authserver.com/auth",
|
||||
AuthResponseHeaders: []string{"X-Auth", "X-Test", "X-Secret"},
|
||||
TLS: &types.ClientTLS{
|
||||
CA: "path/to/local.crt",
|
||||
CAOptional: true,
|
||||
Cert: "path/to/foo.cert",
|
||||
Key: "path/to/foo.key",
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
TrustForwardHeader: true,
|
||||
},
|
||||
HeaderField: "X-WebAuth-User",
|
||||
},
|
||||
WhiteList: &types.WhiteList{
|
||||
SourceRange: []string{
|
||||
"10.42.0.0/16",
|
||||
"152.89.1.33/32",
|
||||
"afed:be44::/16",
|
||||
},
|
||||
},
|
||||
Compress: &Compress{},
|
||||
ProxyProtocol: &ProxyProtocol{
|
||||
Insecure: false,
|
||||
TrustedIPs: []string{"192.168.0.1"},
|
||||
},
|
||||
ForwardedHeaders: &ForwardedHeaders{
|
||||
Insecure: false,
|
||||
TrustedIPs: []string{
|
||||
"10.0.0.3/24",
|
||||
"20.0.0.3/24",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default",
|
||||
expression: "Name:foo",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ForwardedHeaders: &ForwardedHeaders{Insecure: false},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ForwardedHeaders insecure true",
|
||||
expression: "Name:foo ForwardedHeaders.insecure:true",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ForwardedHeaders: &ForwardedHeaders{Insecure: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ForwardedHeaders insecure false",
|
||||
expression: "Name:foo ForwardedHeaders.insecure:false",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ForwardedHeaders: &ForwardedHeaders{Insecure: false},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ForwardedHeaders TrustedIPs",
|
||||
expression: "Name:foo ForwardedHeaders.TrustedIPs:10.0.0.3/24,20.0.0.3/24",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ForwardedHeaders: &ForwardedHeaders{
|
||||
TrustedIPs: []string{"10.0.0.3/24", "20.0.0.3/24"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ProxyProtocol insecure true",
|
||||
expression: "Name:foo ProxyProtocol.insecure:true",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
ProxyProtocol: &ProxyProtocol{Insecure: true},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ProxyProtocol insecure false",
|
||||
expression: "Name:foo ProxyProtocol.insecure:false",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
ProxyProtocol: &ProxyProtocol{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ProxyProtocol TrustedIPs",
|
||||
expression: "Name:foo ProxyProtocol.TrustedIPs:10.0.0.3/24,20.0.0.3/24",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
ProxyProtocol: &ProxyProtocol{
|
||||
TrustedIPs: []string{"10.0.0.3/24", "20.0.0.3/24"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "compress on",
|
||||
expression: "Name:foo Compress:on",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
Compress: &Compress{},
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "compress true",
|
||||
expression: "Name:foo Compress:true",
|
||||
expectedEntryPointName: "foo",
|
||||
expectedEntryPoint: &EntryPoint{
|
||||
Compress: &Compress{},
|
||||
ForwardedHeaders: &ForwardedHeaders{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
eps := EntryPoints{}
|
||||
err := eps.Set(test.expression)
|
||||
require.NoError(t, err)
|
||||
|
||||
ep := eps[test.expectedEntryPointName]
|
||||
assert.EqualValues(t, test.expectedEntryPoint, ep)
|
||||
})
|
||||
}
|
||||
}
|
||||
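The tests above exercise the space-separated Key:value expression syntax understood by EntryPoints.Set. A minimal sketch, not part of this commit, of driving the same parser directly (the expression values below are illustrative only):

package main

import (
    "fmt"

    "github.com/containous/traefik/old/configuration"
)

func main() {
    eps := configuration.EntryPoints{}
    // Each space-separated token is a Key:value pair; nested options use dots.
    if err := eps.Set("Name:http Address::80 Compress:true ForwardedHeaders.TrustedIPs:10.0.0.0/8"); err != nil {
        panic(err)
    }
    fmt.Println(eps["http"].Address) // ":80"
}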
111
old/configuration/provider_aggregator.go
Normal file
@ -0,0 +1,111 @@
package configuration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/provider"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/safe"
|
||||
)
|
||||
|
||||
// ProviderAggregator aggregates providers
|
||||
type ProviderAggregator struct {
|
||||
providers []provider.Provider
|
||||
constraints types.Constraints
|
||||
}
|
||||
|
||||
// NewProviderAggregator returns an aggregate of all the providers configured in GlobalConfiguration
|
||||
func NewProviderAggregator(gc *GlobalConfiguration) ProviderAggregator {
|
||||
provider := ProviderAggregator{
|
||||
constraints: gc.Constraints,
|
||||
}
|
||||
if gc.Docker != nil {
|
||||
provider.quietAddProvider(gc.Docker)
|
||||
}
|
||||
if gc.Marathon != nil {
|
||||
provider.quietAddProvider(gc.Marathon)
|
||||
}
|
||||
if gc.File != nil {
|
||||
provider.quietAddProvider(gc.File)
|
||||
}
|
||||
if gc.Rest != nil {
|
||||
provider.quietAddProvider(gc.Rest)
|
||||
}
|
||||
if gc.Consul != nil {
|
||||
provider.quietAddProvider(gc.Consul)
|
||||
}
|
||||
if gc.ConsulCatalog != nil {
|
||||
provider.quietAddProvider(gc.ConsulCatalog)
|
||||
}
|
||||
if gc.Etcd != nil {
|
||||
provider.quietAddProvider(gc.Etcd)
|
||||
}
|
||||
if gc.Zookeeper != nil {
|
||||
provider.quietAddProvider(gc.Zookeeper)
|
||||
}
|
||||
if gc.Boltdb != nil {
|
||||
provider.quietAddProvider(gc.Boltdb)
|
||||
}
|
||||
if gc.Kubernetes != nil {
|
||||
provider.quietAddProvider(gc.Kubernetes)
|
||||
}
|
||||
if gc.Mesos != nil {
|
||||
provider.quietAddProvider(gc.Mesos)
|
||||
}
|
||||
if gc.Eureka != nil {
|
||||
provider.quietAddProvider(gc.Eureka)
|
||||
}
|
||||
if gc.ECS != nil {
|
||||
provider.quietAddProvider(gc.ECS)
|
||||
}
|
||||
if gc.Rancher != nil {
|
||||
provider.quietAddProvider(gc.Rancher)
|
||||
}
|
||||
if gc.DynamoDB != nil {
|
||||
provider.quietAddProvider(gc.DynamoDB)
|
||||
}
|
||||
|
||||
return provider
|
||||
}
|
||||
|
||||
func (p *ProviderAggregator) quietAddProvider(provider provider.Provider) {
|
||||
err := p.AddProvider(provider)
|
||||
if err != nil {
|
||||
log.Errorf("Error initializing provider %T: %v", provider, err)
|
||||
}
|
||||
}
|
||||
|
||||
// AddProvider adds a provider to the providers list
|
||||
func (p *ProviderAggregator) AddProvider(provider provider.Provider) error {
|
||||
err := provider.Init(p.constraints)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.providers = append(p.providers, provider)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Init the provider
|
||||
func (p ProviderAggregator) Init(_ types.Constraints) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Provide calls the Provide method of every provider
|
||||
func (p ProviderAggregator) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
|
||||
for _, p := range p.providers {
|
||||
jsonConf, err := json.Marshal(p)
|
||||
if err != nil {
|
||||
log.Debugf("Unable to marshal provider conf %T with error: %v", p, err)
|
||||
}
|
||||
log.Infof("Starting provider %T %s", p, jsonConf)
|
||||
currentProvider := p
|
||||
safe.Go(func() {
|
||||
err := currentProvider.Provide(configurationChan, pool)
|
||||
if err != nil {
|
||||
log.Errorf("Error starting provider %T: %v", p, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
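The aggregator above treats every configured backend provider uniformly: AddProvider calls Init on the provider, and Provide fans each provider out into its own goroutine that feeds the shared configuration channel. A minimal sketch of wiring it up, not part of this commit and assuming safe.NewPool takes a context as elsewhere in the codebase:

package main

import (
    "context"

    "github.com/containous/traefik/old/configuration"
    "github.com/containous/traefik/old/types"
    "github.com/containous/traefik/safe"
)

func main() {
    gc := &configuration.GlobalConfiguration{} // normally populated from flags or TOML
    aggregator := configuration.NewProviderAggregator(gc)

    configurationChan := make(chan types.ConfigMessage, 100)
    pool := safe.NewPool(context.Background())

    // Each enabled provider starts in its own goroutine and pushes
    // types.ConfigMessage values onto the shared channel.
    if err := aggregator.Provide(configurationChan, pool); err != nil {
        panic(err)
    }

    for msg := range configurationChan {
        _ = msg // hand the dynamic configuration to the server
    }
}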
116
old/configuration/router/internal_router.go
Normal file
@ -0,0 +1,116 @@
package router
|
||||
|
||||
import (
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/middlewares"
|
||||
mauth "github.com/containous/traefik/old/middlewares/auth"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// NewInternalRouterAggregator creates a new InternalRouterAggregator
|
||||
func NewInternalRouterAggregator(globalConfiguration configuration.GlobalConfiguration, entryPointName string) *InternalRouterAggregator {
|
||||
var serverMiddlewares []negroni.Handler
|
||||
|
||||
if globalConfiguration.EntryPoints[entryPointName].WhiteList != nil {
|
||||
ipStrategy := globalConfiguration.EntryPoints[entryPointName].ClientIPStrategy
|
||||
if globalConfiguration.EntryPoints[entryPointName].WhiteList.IPStrategy != nil {
|
||||
ipStrategy = globalConfiguration.EntryPoints[entryPointName].WhiteList.IPStrategy
|
||||
}
|
||||
|
||||
strategy, err := ipStrategy.Get()
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating whitelist middleware: %s", err)
|
||||
}
|
||||
|
||||
ipWhitelistMiddleware, err := middlewares.NewIPWhiteLister(globalConfiguration.EntryPoints[entryPointName].WhiteList.SourceRange, strategy)
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating whitelist middleware: %s", err)
|
||||
}
|
||||
if ipWhitelistMiddleware != nil {
|
||||
serverMiddlewares = append(serverMiddlewares, ipWhitelistMiddleware)
|
||||
}
|
||||
}
|
||||
|
||||
if globalConfiguration.EntryPoints[entryPointName].Auth != nil {
|
||||
authMiddleware, err := mauth.NewAuthenticator(globalConfiguration.EntryPoints[entryPointName].Auth, nil)
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating authenticator middleware: %s", err)
|
||||
}
|
||||
serverMiddlewares = append(serverMiddlewares, authMiddleware)
|
||||
}
|
||||
|
||||
router := InternalRouterAggregator{}
|
||||
routerWithMiddleware := InternalRouterAggregator{}
|
||||
|
||||
if globalConfiguration.Metrics != nil && globalConfiguration.Metrics.Prometheus != nil && globalConfiguration.Metrics.Prometheus.EntryPoint == entryPointName {
|
||||
// routerWithMiddleware.AddRouter(metrics.PrometheusHandler{})
|
||||
}
|
||||
|
||||
if globalConfiguration.Rest != nil && globalConfiguration.Rest.EntryPoint == entryPointName {
|
||||
routerWithMiddleware.AddRouter(globalConfiguration.Rest)
|
||||
}
|
||||
|
||||
if globalConfiguration.API != nil && globalConfiguration.API.EntryPoint == entryPointName {
|
||||
routerWithMiddleware.AddRouter(globalConfiguration.API)
|
||||
}
|
||||
|
||||
if globalConfiguration.Ping != nil && globalConfiguration.Ping.EntryPoint == entryPointName {
|
||||
router.AddRouter(globalConfiguration.Ping)
|
||||
}
|
||||
|
||||
if globalConfiguration.ACME != nil && globalConfiguration.ACME.HTTPChallenge != nil && globalConfiguration.ACME.HTTPChallenge.EntryPoint == entryPointName {
|
||||
router.AddRouter(globalConfiguration.ACME)
|
||||
}
|
||||
|
||||
router.AddRouter(&WithMiddleware{router: &routerWithMiddleware, routerMiddlewares: serverMiddlewares})
|
||||
|
||||
return &router
|
||||
}
|
||||
|
||||
// WithMiddleware wraps an internal router with middlewares
|
||||
type WithMiddleware struct {
|
||||
router types.InternalRouter
|
||||
routerMiddlewares []negroni.Handler
|
||||
}
|
||||
|
||||
// AddRoutes adds routes to the router
|
||||
func (wm *WithMiddleware) AddRoutes(systemRouter *mux.Router) {
|
||||
realRouter := systemRouter.PathPrefix("/").Subrouter()
|
||||
|
||||
wm.router.AddRoutes(realRouter)
|
||||
|
||||
if len(wm.routerMiddlewares) > 0 {
|
||||
if err := realRouter.Walk(wrapRoute(wm.routerMiddlewares)); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InternalRouterAggregator is an InternalRouter that aggregates other internal routers
|
||||
type InternalRouterAggregator struct {
|
||||
internalRouters []types.InternalRouter
|
||||
}
|
||||
|
||||
// AddRouter adds a router to the aggregator
|
||||
func (r *InternalRouterAggregator) AddRouter(router types.InternalRouter) {
|
||||
r.internalRouters = append(r.internalRouters, router)
|
||||
}
|
||||
|
||||
// AddRoutes adds routes to the router
|
||||
func (r *InternalRouterAggregator) AddRoutes(systemRouter *mux.Router) {
|
||||
for _, router := range r.internalRouters {
|
||||
router.AddRoutes(systemRouter)
|
||||
}
|
||||
}
|
||||
|
||||
// wrapRoute wraps each route handler with the given middlewares
|
||||
func wrapRoute(middlewares []negroni.Handler) func(*mux.Route, *mux.Router, []*mux.Route) error {
|
||||
return func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
|
||||
middles := append(middlewares, negroni.Wrap(route.GetHandler()))
|
||||
route.Handler(negroni.New(middles...))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
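Routers registered through routerWithMiddleware (the REST and API handlers) are wrapped with the entry point's whitelist and auth middlewares, while ping and the ACME HTTP challenge are added to the outer aggregator and bypass them. A minimal sketch, not part of this commit, of plugging a hypothetical custom internal router into the aggregator:

package main

import (
    "net/http"

    "github.com/containous/mux"
    "github.com/containous/traefik/old/configuration/router"
)

// healthRouter is hypothetical and used for illustration only.
type healthRouter struct{}

func (healthRouter) AddRoutes(systemRouter *mux.Router) {
    systemRouter.Handle("/healthz", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
        w.WriteHeader(http.StatusOK)
    }))
}

func main() {
    aggregator := &router.InternalRouterAggregator{}
    aggregator.AddRouter(healthRouter{})

    m := mux.NewRouter()
    aggregator.AddRoutes(m) // every registered internal router contributes its routes
    _ = http.ListenAndServe(":8080", m)
}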
151
old/configuration/router/internal_router_test.go
Normal file
@ -0,0 +1,151 @@
package router
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/acme"
|
||||
"github.com/containous/traefik/old/api"
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/ping"
|
||||
acmeprovider "github.com/containous/traefik/old/provider/acme"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestNewInternalRouterAggregatorWithAuth(t *testing.T) {
|
||||
currentConfiguration := &safe.Safe{}
|
||||
currentConfiguration.Set(types.Configurations{})
|
||||
|
||||
globalConfiguration := configuration.GlobalConfiguration{
|
||||
API: &api.Handler{
|
||||
EntryPoint: "traefik",
|
||||
CurrentConfigurations: currentConfiguration,
|
||||
},
|
||||
Ping: &ping.Handler{
|
||||
EntryPoint: "traefik",
|
||||
},
|
||||
ACME: &acme.ACME{
|
||||
HTTPChallenge: &acmeprovider.HTTPChallenge{
|
||||
EntryPoint: "traefik",
|
||||
},
|
||||
},
|
||||
EntryPoints: configuration.EntryPoints{
|
||||
"traefik": &configuration.EntryPoint{
|
||||
Auth: &types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: types.Users{"test:test"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
testedURL string
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
desc: "Wrong url",
|
||||
testedURL: "/wrong",
|
||||
expectedStatusCode: 502,
|
||||
},
|
||||
{
|
||||
desc: "Ping without auth",
|
||||
testedURL: "/ping",
|
||||
expectedStatusCode: 200,
|
||||
},
|
||||
{
|
||||
desc: "acme without auth",
|
||||
testedURL: "/.well-known/acme-challenge/token",
|
||||
expectedStatusCode: 404,
|
||||
},
|
||||
{
|
||||
desc: "api with auth",
|
||||
testedURL: "/api",
|
||||
expectedStatusCode: 401,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
router := NewInternalRouterAggregator(globalConfiguration, "traefik")
|
||||
|
||||
internalMuxRouter := mux.NewRouter()
|
||||
router.AddRoutes(internalMuxRouter)
|
||||
internalMuxRouter.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadGateway)
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, test.testedURL, nil)
|
||||
internalMuxRouter.ServeHTTP(recorder, request)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, recorder.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type MockInternalRouterFunc func(systemRouter *mux.Router)
|
||||
|
||||
func (m MockInternalRouterFunc) AddRoutes(systemRouter *mux.Router) {
|
||||
m(systemRouter)
|
||||
}
|
||||
|
||||
func TestWithMiddleware(t *testing.T) {
|
||||
router := WithMiddleware{
|
||||
router: MockInternalRouterFunc(func(systemRouter *mux.Router) {
|
||||
systemRouter.Handle("/test", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if _, err := w.Write([]byte("router")); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}))
|
||||
}),
|
||||
routerMiddlewares: []negroni.Handler{
|
||||
negroni.HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if _, err := rw.Write([]byte("before middleware1|")); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
if _, err := rw.Write([]byte("|after middleware1")); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
}),
|
||||
negroni.HandlerFunc(func(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if _, err := rw.Write([]byte("before middleware2|")); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
if _, err := rw.Write([]byte("|after middleware2")); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}),
|
||||
},
|
||||
}
|
||||
|
||||
internalMuxRouter := mux.NewRouter()
|
||||
router.AddRoutes(internalMuxRouter)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
request := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
internalMuxRouter.ServeHTTP(recorder, request)
|
||||
|
||||
obtained := recorder.Body.String()
|
||||
|
||||
assert.Equal(t, "before middleware1|before middleware2|router|after middleware2|after middleware1", obtained)
|
||||
}
|
||||
315
old/log/logger.go
Normal file
@ -0,0 +1,315 @@
package log
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Logger allows overriding the logrus logger behavior
|
||||
type Logger interface {
|
||||
logrus.FieldLogger
|
||||
WriterLevel(logrus.Level) *io.PipeWriter
|
||||
}
|
||||
|
||||
var (
|
||||
logger Logger
|
||||
logFilePath string
|
||||
logFile *os.File
|
||||
)
|
||||
|
||||
func init() {
|
||||
logger = logrus.StandardLogger().WithFields(logrus.Fields{})
|
||||
logrus.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
// Context returns a logger entry with the given context attached as a field
|
||||
func Context(context interface{}) *logrus.Entry {
|
||||
return logger.WithField("context", context)
|
||||
}
|
||||
|
||||
// SetOutput sets the standard logger output.
|
||||
func SetOutput(out io.Writer) {
|
||||
logrus.SetOutput(out)
|
||||
}
|
||||
|
||||
// SetFormatter sets the standard logger formatter.
|
||||
func SetFormatter(formatter logrus.Formatter) {
|
||||
logrus.SetFormatter(formatter)
|
||||
}
|
||||
|
||||
// SetLevel sets the standard logger level.
|
||||
func SetLevel(level logrus.Level) {
|
||||
logrus.SetLevel(level)
|
||||
}
|
||||
|
||||
// SetLogger sets the logger.
|
||||
func SetLogger(l Logger) {
|
||||
logger = l
|
||||
}
|
||||
|
||||
// GetLevel returns the standard logger level.
|
||||
func GetLevel() logrus.Level {
|
||||
return logrus.GetLevel()
|
||||
}
|
||||
|
||||
// AddHook adds a hook to the standard logger hooks.
|
||||
func AddHook(hook logrus.Hook) {
|
||||
logrus.AddHook(hook)
|
||||
}
|
||||
|
||||
// WithError creates an entry from the standard logger and adds an error to it, using the value defined in ErrorKey as key.
|
||||
func WithError(err error) *logrus.Entry {
|
||||
return logger.WithError(err)
|
||||
}
|
||||
|
||||
// WithField creates an entry from the standard logger and adds a field to
|
||||
// it. If you want multiple fields, use `WithFields`.
|
||||
//
|
||||
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
|
||||
// or Panic on the Entry it returns.
|
||||
func WithField(key string, value interface{}) *logrus.Entry {
|
||||
return logger.WithField(key, value)
|
||||
}
|
||||
|
||||
// WithFields creates an entry from the standard logger and adds multiple
|
||||
// fields to it. This is simply a helper for `WithField`, invoking it
|
||||
// once for each field.
|
||||
//
|
||||
// Note that it doesn't log until you call Debug, Print, Info, Warn, Fatal
|
||||
// or Panic on the Entry it returns.
|
||||
func WithFields(fields logrus.Fields) *logrus.Entry {
|
||||
return logger.WithFields(fields)
|
||||
}
|
||||
|
||||
// Debug logs a message at level Debug on the standard logger.
|
||||
func Debug(args ...interface{}) {
|
||||
logger.Debug(args...)
|
||||
}
|
||||
|
||||
// Print logs a message at level Info on the standard logger.
|
||||
func Print(args ...interface{}) {
|
||||
logger.Print(args...)
|
||||
}
|
||||
|
||||
// Info logs a message at level Info on the standard logger.
|
||||
func Info(args ...interface{}) {
|
||||
logger.Info(args...)
|
||||
}
|
||||
|
||||
// Warn logs a message at level Warn on the standard logger.
|
||||
func Warn(args ...interface{}) {
|
||||
logger.Warn(args...)
|
||||
}
|
||||
|
||||
// Warning logs a message at level Warn on the standard logger.
|
||||
func Warning(args ...interface{}) {
|
||||
logger.Warning(args...)
|
||||
}
|
||||
|
||||
// Error logs a message at level Error on the standard logger.
|
||||
func Error(args ...interface{}) {
|
||||
logger.Error(args...)
|
||||
}
|
||||
|
||||
// Panic logs a message at level Panic on the standard logger.
|
||||
func Panic(args ...interface{}) {
|
||||
logger.Panic(args...)
|
||||
}
|
||||
|
||||
// Fatal logs a message at level Fatal on the standard logger.
|
||||
func Fatal(args ...interface{}) {
|
||||
logger.Fatal(args...)
|
||||
}
|
||||
|
||||
// Debugf logs a message at level Debug on the standard logger.
|
||||
func Debugf(format string, args ...interface{}) {
|
||||
logger.Debugf(format, args...)
|
||||
}
|
||||
|
||||
// Printf logs a message at level Info on the standard logger.
|
||||
func Printf(format string, args ...interface{}) {
|
||||
logger.Printf(format, args...)
|
||||
}
|
||||
|
||||
// Infof logs a message at level Info on the standard logger.
|
||||
func Infof(format string, args ...interface{}) {
|
||||
logger.Infof(format, args...)
|
||||
}
|
||||
|
||||
// Warnf logs a message at level Warn on the standard logger.
|
||||
func Warnf(format string, args ...interface{}) {
|
||||
logger.Warnf(format, args...)
|
||||
}
|
||||
|
||||
// Warningf logs a message at level Warn on the standard logger.
|
||||
func Warningf(format string, args ...interface{}) {
|
||||
logger.Warningf(format, args...)
|
||||
}
|
||||
|
||||
// Errorf logs a message at level Error on the standard logger.
|
||||
func Errorf(format string, args ...interface{}) {
|
||||
logger.Errorf(format, args...)
|
||||
}
|
||||
|
||||
// Panicf logs a message at level Panic on the standard logger.
|
||||
func Panicf(format string, args ...interface{}) {
|
||||
logger.Panicf(format, args...)
|
||||
}
|
||||
|
||||
// Fatalf logs a message at level Fatal on the standard logger.
|
||||
func Fatalf(format string, args ...interface{}) {
|
||||
logger.Fatalf(format, args...)
|
||||
}
|
||||
|
||||
// Debugln logs a message at level Debug on the standard logger.
|
||||
func Debugln(args ...interface{}) {
|
||||
logger.Debugln(args...)
|
||||
}
|
||||
|
||||
// Println logs a message at level Info on the standard logger.
|
||||
func Println(args ...interface{}) {
|
||||
logger.Println(args...)
|
||||
}
|
||||
|
||||
// Infoln logs a message at level Info on the standard logger.
|
||||
func Infoln(args ...interface{}) {
|
||||
logger.Infoln(args...)
|
||||
}
|
||||
|
||||
// Warnln logs a message at level Warn on the standard logger.
|
||||
func Warnln(args ...interface{}) {
|
||||
logger.Warnln(args...)
|
||||
}
|
||||
|
||||
// Warningln logs a message at level Warn on the standard logger.
|
||||
func Warningln(args ...interface{}) {
|
||||
logger.Warningln(args...)
|
||||
}
|
||||
|
||||
// Errorln logs a message at level Error on the standard logger.
|
||||
func Errorln(args ...interface{}) {
|
||||
logger.Errorln(args...)
|
||||
}
|
||||
|
||||
// Panicln logs a message at level Panic on the standard logger.
|
||||
func Panicln(args ...interface{}) {
|
||||
logger.Panicln(args...)
|
||||
}
|
||||
|
||||
// Fatalln logs a message at level Fatal on the standard logger.
|
||||
func Fatalln(args ...interface{}) {
|
||||
logger.Fatalln(args...)
|
||||
}
|
||||
|
||||
// OpenFile opens the log file using the specified path
|
||||
func OpenFile(path string) error {
|
||||
logFilePath = path
|
||||
var err error
|
||||
logFile, err = os.OpenFile(logFilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666)
|
||||
|
||||
if err == nil {
|
||||
SetOutput(logFile)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// CloseFile closes the log and sets the Output to stdout
|
||||
func CloseFile() error {
|
||||
logrus.SetOutput(os.Stdout)
|
||||
|
||||
if logFile != nil {
|
||||
return logFile.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RotateFile closes and reopens the log file to allow for rotation
|
||||
// by an external source. If the log isn't backed by a file then
|
||||
// it does nothing.
|
||||
func RotateFile() error {
|
||||
if logFile == nil && logFilePath == "" {
|
||||
Debug("Traefik log is not writing to a file, ignoring rotate request")
|
||||
return nil
|
||||
}
|
||||
|
||||
if logFile != nil {
|
||||
defer func(f *os.File) {
|
||||
f.Close()
|
||||
}(logFile)
|
||||
}
|
||||
|
||||
if err := OpenFile(logFilePath); err != nil {
|
||||
return fmt.Errorf("error opening log file: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Writer logs writer (Level Info)
|
||||
func Writer() *io.PipeWriter {
|
||||
return WriterLevel(logrus.InfoLevel)
|
||||
}
|
||||
|
||||
// WriterLevel logs writer for a specific level.
|
||||
func WriterLevel(level logrus.Level) *io.PipeWriter {
|
||||
return logger.WriterLevel(level)
|
||||
}
|
||||
|
||||
// CustomWriterLevel logs writer for a specific level. (with a custom scanner buffer size.)
|
||||
// adapted from github.com/Sirupsen/logrus/writer.go
|
||||
func CustomWriterLevel(level logrus.Level, maxScanTokenSize int) *io.PipeWriter {
|
||||
reader, writer := io.Pipe()
|
||||
|
||||
var printFunc func(args ...interface{})
|
||||
|
||||
switch level {
|
||||
case logrus.DebugLevel:
|
||||
printFunc = Debug
|
||||
case logrus.InfoLevel:
|
||||
printFunc = Info
|
||||
case logrus.WarnLevel:
|
||||
printFunc = Warn
|
||||
case logrus.ErrorLevel:
|
||||
printFunc = Error
|
||||
case logrus.FatalLevel:
|
||||
printFunc = Fatal
|
||||
case logrus.PanicLevel:
|
||||
printFunc = Panic
|
||||
default:
|
||||
printFunc = Print
|
||||
}
|
||||
|
||||
go writerScanner(reader, maxScanTokenSize, printFunc)
|
||||
runtime.SetFinalizer(writer, writerFinalizer)
|
||||
|
||||
return writer
|
||||
}
|
||||
|
||||
// extract from github.com/Sirupsen/logrus/writer.go
|
||||
// Hack the buffer size
|
||||
func writerScanner(reader io.ReadCloser, scanTokenSize int, printFunc func(args ...interface{})) {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
|
||||
if scanTokenSize > bufio.MaxScanTokenSize {
|
||||
buf := make([]byte, bufio.MaxScanTokenSize)
|
||||
scanner.Buffer(buf, scanTokenSize)
|
||||
}
|
||||
|
||||
for scanner.Scan() {
|
||||
printFunc(scanner.Text())
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
Errorf("Error while reading from Writer: %s", err)
|
||||
}
|
||||
reader.Close()
|
||||
}
|
||||
|
||||
func writerFinalizer(writer *io.PipeWriter) {
|
||||
writer.Close()
|
||||
}
|
||||
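WriterLevel and CustomWriterLevel expose the logger as an io.PipeWriter, which is useful for funnelling output from libraries that only accept an io.Writer or a standard-library *log.Logger. A minimal sketch, not part of this commit:

package main

import (
    stdlog "log"

    "github.com/containous/traefik/old/log"
    "github.com/sirupsen/logrus"
)

func main() {
    // Route a third-party library's standard-library logger into the Traefik
    // logger at debug level, allowing scanned lines of up to 1 MiB.
    w := log.CustomWriterLevel(logrus.DebugLevel, 1024*1024)
    defer w.Close()

    libLogger := stdlog.New(w, "", 0)
    libLogger.Println("this line is emitted by the Traefik logger at debug level")
}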
79
old/log/logger_test.go
Normal file
@ -0,0 +1,79 @@
package log
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLogRotation(t *testing.T) {
|
||||
tempDir, err := ioutil.TempDir("", "traefik_")
|
||||
if err != nil {
|
||||
t.Fatalf("Error setting up temporary directory: %s", err)
|
||||
}
|
||||
|
||||
fileName := tempDir + "traefik.log"
|
||||
if err := OpenFile(fileName); err != nil {
|
||||
t.Fatalf("Error opening temporary file %s: %s", fileName, err)
|
||||
}
|
||||
defer CloseFile()
|
||||
|
||||
rotatedFileName := fileName + ".rotated"
|
||||
|
||||
iterations := 20
|
||||
halfDone := make(chan bool)
|
||||
writeDone := make(chan bool)
|
||||
go func() {
|
||||
for i := 0; i < iterations; i++ {
|
||||
Println("Test log line")
|
||||
if i == iterations/2 {
|
||||
halfDone <- true
|
||||
}
|
||||
}
|
||||
writeDone <- true
|
||||
}()
|
||||
|
||||
<-halfDone
|
||||
err = os.Rename(fileName, rotatedFileName)
|
||||
if err != nil {
|
||||
t.Fatalf("Error renaming file: %s", err)
|
||||
}
|
||||
|
||||
err = RotateFile()
|
||||
if err != nil {
|
||||
t.Fatalf("Error rotating file: %s", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-writeDone:
|
||||
gotLineCount := lineCount(t, fileName) + lineCount(t, rotatedFileName)
|
||||
if iterations != gotLineCount {
|
||||
t.Errorf("Wanted %d written log lines, got %d", iterations, gotLineCount)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatalf("test timed out")
|
||||
}
|
||||
|
||||
close(halfDone)
|
||||
close(writeDone)
|
||||
}
|
||||
|
||||
func lineCount(t *testing.T, fileName string) int {
|
||||
t.Helper()
|
||||
fileContents, err := ioutil.ReadFile(fileName)
|
||||
if err != nil {
|
||||
t.Fatalf("Error reading from file %s: %s", fileName, err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, line := range strings.Split(string(fileContents), "\n") {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
18
old/middlewares/accesslog/capture_request_reader.go
Normal file
@ -0,0 +1,18 @@
package accesslog
|
||||
|
||||
import "io"
|
||||
|
||||
type captureRequestReader struct {
|
||||
source io.ReadCloser
|
||||
count int64
|
||||
}
|
||||
|
||||
func (r *captureRequestReader) Read(p []byte) (int, error) {
|
||||
n, err := r.source.Read(p)
|
||||
r.count += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (r *captureRequestReader) Close() error {
|
||||
return r.source.Close()
|
||||
}
|
||||
68
old/middlewares/accesslog/capture_response_writer.go
Normal file
@ -0,0 +1,68 @@
package accesslog
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/old/middlewares"
|
||||
)
|
||||
|
||||
var (
|
||||
_ middlewares.Stateful = &captureResponseWriter{}
|
||||
)
|
||||
|
||||
// captureResponseWriter is a wrapper of type http.ResponseWriter
|
||||
// that tracks request status and size
|
||||
type captureResponseWriter struct {
|
||||
rw http.ResponseWriter
|
||||
status int
|
||||
size int64
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Header() http.Header {
|
||||
return crw.rw.Header()
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Write(b []byte) (int, error) {
|
||||
if crw.status == 0 {
|
||||
crw.status = http.StatusOK
|
||||
}
|
||||
size, err := crw.rw.Write(b)
|
||||
crw.size += int64(size)
|
||||
return size, err
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) WriteHeader(s int) {
|
||||
crw.rw.WriteHeader(s)
|
||||
crw.status = s
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Flush() {
|
||||
if f, ok := crw.rw.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if h, ok := crw.rw.(http.Hijacker); ok {
|
||||
return h.Hijack()
|
||||
}
|
||||
return nil, nil, fmt.Errorf("not a hijacker: %T", crw.rw)
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) CloseNotify() <-chan bool {
|
||||
if c, ok := crw.rw.(http.CloseNotifier); ok {
|
||||
return c.CloseNotify()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Status() int {
|
||||
return crw.status
|
||||
}
|
||||
|
||||
func (crw *captureResponseWriter) Size() int64 {
|
||||
return crw.size
|
||||
}
|
||||
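captureResponseWriter is unexported and only instantiated by the access-log handler further below; it records the status code and byte count while delegating Flush, Hijack, and CloseNotify to the wrapped writer. As a generic illustration of the same wrapper pattern, not part of this commit, a plain net/http middleware can capture the status in the same way:

package main

import (
    "log"
    "net/http"
)

// statusRecorder is a hypothetical, minimal counterpart of captureResponseWriter.
type statusRecorder struct {
    http.ResponseWriter
    status int
}

func (r *statusRecorder) WriteHeader(code int) {
    r.status = code
    r.ResponseWriter.WriteHeader(code)
}

func withStatusLog(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
        rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
        next.ServeHTTP(rec, req)
        log.Printf("%s %s -> %d", req.Method, req.URL.Path, rec.status)
    })
}

func main() {
    _ = http.ListenAndServe(":8080", withStatusLog(http.NotFoundHandler()))
}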
120
old/middlewares/accesslog/logdata.go
Normal file
@ -0,0 +1,120 @@
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// StartUTC is the map key used for the time at which request processing started.
|
||||
StartUTC = "StartUTC"
|
||||
// StartLocal is the map key used for the local time at which request processing started.
|
||||
StartLocal = "StartLocal"
|
||||
// Duration is the map key used for the total time taken by processing the response, including the origin server's time but
|
||||
// not the log writing time.
|
||||
Duration = "Duration"
|
||||
// FrontendName is the map key used for the name of the Traefik frontend.
|
||||
FrontendName = "FrontendName"
|
||||
// BackendName is the map key used for the name of the Traefik backend.
|
||||
BackendName = "BackendName"
|
||||
// BackendURL is the map key used for the URL of the Traefik backend.
|
||||
BackendURL = "BackendURL"
|
||||
// BackendAddr is the map key used for the IP:port of the Traefik backend (extracted from BackendURL)
|
||||
BackendAddr = "BackendAddr"
|
||||
// ClientAddr is the map key used for the remote address in its original form (usually IP:port).
|
||||
ClientAddr = "ClientAddr"
|
||||
// ClientHost is the map key used for the remote IP address from which the client request was received.
|
||||
ClientHost = "ClientHost"
|
||||
// ClientPort is the map key used for the remote TCP port from which the client request was received.
|
||||
ClientPort = "ClientPort"
|
||||
// ClientUsername is the map key used for the username provided in the URL, if present.
|
||||
ClientUsername = "ClientUsername"
|
||||
// RequestAddr is the map key used for the HTTP Host header (usually IP:port). This is treated as not a header by the Go API.
|
||||
RequestAddr = "RequestAddr"
|
||||
// RequestHost is the map key used for the HTTP Host server name (not including port).
|
||||
RequestHost = "RequestHost"
|
||||
// RequestPort is the map key used for the TCP port from the HTTP Host.
|
||||
RequestPort = "RequestPort"
|
||||
// RequestMethod is the map key used for the HTTP method.
|
||||
RequestMethod = "RequestMethod"
|
||||
// RequestPath is the map key used for the HTTP request URI, not including the scheme, host or port.
|
||||
RequestPath = "RequestPath"
|
||||
// RequestProtocol is the map key used for the version of HTTP requested.
|
||||
RequestProtocol = "RequestProtocol"
|
||||
// RequestContentSize is the map key used for the number of bytes in the request entity (a.k.a. body) sent by the client.
|
||||
RequestContentSize = "RequestContentSize"
|
||||
// RequestRefererHeader is the Referer header in the request
|
||||
RequestRefererHeader = "request_Referer"
|
||||
// RequestUserAgentHeader is the User-Agent header in the request
|
||||
RequestUserAgentHeader = "request_User-Agent"
|
||||
// OriginDuration is the map key used for the time taken by the origin server ('upstream') to return its response.
|
||||
OriginDuration = "OriginDuration"
|
||||
// OriginContentSize is the map key used for the content length specified by the origin server, or 0 if unspecified.
|
||||
OriginContentSize = "OriginContentSize"
|
||||
// OriginStatus is the map key used for the HTTP status code returned by the origin server.
|
||||
// If the request was handled by this Traefik instance (e.g. with a redirect), then this value will be absent.
|
||||
OriginStatus = "OriginStatus"
|
||||
// DownstreamStatus is the map key used for the HTTP status code returned to the client.
|
||||
DownstreamStatus = "DownstreamStatus"
|
||||
// DownstreamContentSize is the map key used for the number of bytes in the response entity returned to the client.
|
||||
// This is in addition to the "Content-Length" header, which may be present in the origin response.
|
||||
DownstreamContentSize = "DownstreamContentSize"
|
||||
// RequestCount is the map key used for the number of requests received since the Traefik instance started.
|
||||
RequestCount = "RequestCount"
|
||||
// GzipRatio is the map key used for the response body compression ratio achieved.
|
||||
GzipRatio = "GzipRatio"
|
||||
// Overhead is the map key used for the processing time overhead caused by Traefik.
|
||||
Overhead = "Overhead"
|
||||
// RetryAttempts is the map key used for the amount of attempts the request was retried.
|
||||
RetryAttempts = "RetryAttempts"
|
||||
)
|
||||
|
||||
// These are written out in the default case when no config is provided to specify keys of interest.
|
||||
var defaultCoreKeys = [...]string{
|
||||
StartUTC,
|
||||
Duration,
|
||||
FrontendName,
|
||||
BackendName,
|
||||
BackendURL,
|
||||
ClientHost,
|
||||
ClientPort,
|
||||
ClientUsername,
|
||||
RequestHost,
|
||||
RequestPort,
|
||||
RequestMethod,
|
||||
RequestPath,
|
||||
RequestProtocol,
|
||||
RequestContentSize,
|
||||
OriginDuration,
|
||||
OriginContentSize,
|
||||
OriginStatus,
|
||||
DownstreamStatus,
|
||||
DownstreamContentSize,
|
||||
RequestCount,
|
||||
}
|
||||
|
||||
// This contains the set of all keys, i.e. all the default keys plus all non-default keys.
|
||||
var allCoreKeys = make(map[string]struct{})
|
||||
|
||||
func init() {
|
||||
for _, k := range defaultCoreKeys {
|
||||
allCoreKeys[k] = struct{}{}
|
||||
}
|
||||
allCoreKeys[BackendAddr] = struct{}{}
|
||||
allCoreKeys[ClientAddr] = struct{}{}
|
||||
allCoreKeys[RequestAddr] = struct{}{}
|
||||
allCoreKeys[GzipRatio] = struct{}{}
|
||||
allCoreKeys[StartLocal] = struct{}{}
|
||||
allCoreKeys[Overhead] = struct{}{}
|
||||
allCoreKeys[RetryAttempts] = struct{}{}
|
||||
}
|
||||
|
||||
// CoreLogData holds the fields computed from the request/response.
|
||||
type CoreLogData map[string]interface{}
|
||||
|
||||
// LogData is the data captured by the middleware so that it can be logged.
|
||||
type LogData struct {
|
||||
Core CoreLogData
|
||||
Request http.Header
|
||||
OriginResponse http.Header
|
||||
DownstreamResponse http.Header
|
||||
}
|
||||
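The constants above are the keys of the Core table; middlewares further down the chain fill them in as the request is processed. A minimal sketch, not part of this commit, of a hypothetical middleware recording the frontend name through GetLogDataTable (defined in the access-log handler below):

package main

import (
    "net/http"

    "github.com/containous/traefik/old/middlewares/accesslog"
)

// saveFrontend is hypothetical and shown only to illustrate the API.
func saveFrontend(next http.Handler, frontendName string) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
        // The access-log handler stores a *LogData in the request context;
        // Core is a plain map keyed by the constants defined above.
        table := accesslog.GetLogDataTable(req)
        table.Core[accesslog.FrontendName] = frontendName
        next.ServeHTTP(w, req)
    })
}

func main() {
    http.Handle("/", saveFrontend(http.NotFoundHandler(), "frontend-example"))
    _ = http.ListenAndServe(":8080", nil)
}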
334
old/middlewares/accesslog/logger.go
Normal file
@ -0,0 +1,334 @@
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type key string
|
||||
|
||||
const (
|
||||
// DataTableKey is the key within the request context used to
|
||||
// store the Log Data Table
|
||||
DataTableKey key = "LogDataTable"
|
||||
|
||||
// CommonFormat is the common logging format (CLF)
|
||||
CommonFormat string = "common"
|
||||
|
||||
// JSONFormat is the JSON logging format
|
||||
JSONFormat string = "json"
|
||||
)
|
||||
|
||||
type logHandlerParams struct {
|
||||
logDataTable *LogData
|
||||
crr *captureRequestReader
|
||||
crw *captureResponseWriter
|
||||
}
|
||||
|
||||
// LogHandler will write each request and its response to the access log.
|
||||
type LogHandler struct {
|
||||
config *types.AccessLog
|
||||
logger *logrus.Logger
|
||||
file *os.File
|
||||
mu sync.Mutex
|
||||
httpCodeRanges types.HTTPCodeRanges
|
||||
logHandlerChan chan logHandlerParams
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewLogHandler creates a new LogHandler
|
||||
func NewLogHandler(config *types.AccessLog) (*LogHandler, error) {
|
||||
file := os.Stdout
|
||||
if len(config.FilePath) > 0 {
|
||||
f, err := openAccessLogFile(config.FilePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening access log file: %s", err)
|
||||
}
|
||||
file = f
|
||||
}
|
||||
logHandlerChan := make(chan logHandlerParams, config.BufferingSize)
|
||||
|
||||
var formatter logrus.Formatter
|
||||
|
||||
switch config.Format {
|
||||
case CommonFormat:
|
||||
formatter = new(CommonLogFormatter)
|
||||
case JSONFormat:
|
||||
formatter = new(logrus.JSONFormatter)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported access log format: %s", config.Format)
|
||||
}
|
||||
|
||||
logger := &logrus.Logger{
|
||||
Out: file,
|
||||
Formatter: formatter,
|
||||
Hooks: make(logrus.LevelHooks),
|
||||
Level: logrus.InfoLevel,
|
||||
}
|
||||
|
||||
logHandler := &LogHandler{
|
||||
config: config,
|
||||
logger: logger,
|
||||
file: file,
|
||||
logHandlerChan: logHandlerChan,
|
||||
}
|
||||
|
||||
if config.Filters != nil {
|
||||
if httpCodeRanges, err := types.NewHTTPCodeRanges(config.Filters.StatusCodes); err != nil {
|
||||
log.Errorf("Failed to create new HTTP code ranges: %s", err)
|
||||
} else {
|
||||
logHandler.httpCodeRanges = httpCodeRanges
|
||||
}
|
||||
}
|
||||
|
||||
if config.BufferingSize > 0 {
|
||||
logHandler.wg.Add(1)
|
||||
go func() {
|
||||
defer logHandler.wg.Done()
|
||||
for handlerParams := range logHandler.logHandlerChan {
|
||||
logHandler.logTheRoundTrip(handlerParams.logDataTable, handlerParams.crr, handlerParams.crw)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return logHandler, nil
|
||||
}
|
||||
|
||||
func openAccessLogFile(filePath string) (*os.File, error) {
|
||||
dir := filepath.Dir(filePath)
|
||||
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create log path %s: %s", dir, err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening file %s: %s", filePath, err)
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
// GetLogDataTable gets the request context object that contains logging data.
|
||||
// This creates data as the request passes through the middleware chain.
|
||||
func GetLogDataTable(req *http.Request) *LogData {
|
||||
if ld, ok := req.Context().Value(DataTableKey).(*LogData); ok {
|
||||
return ld
|
||||
}
|
||||
log.Errorf("%s is nil", DataTableKey)
|
||||
return &LogData{Core: make(CoreLogData)}
|
||||
}
|
||||
|
||||
func (l *LogHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
now := time.Now().UTC()
|
||||
|
||||
core := CoreLogData{
|
||||
StartUTC: now,
|
||||
StartLocal: now.Local(),
|
||||
}
|
||||
|
||||
logDataTable := &LogData{Core: core, Request: req.Header}
|
||||
|
||||
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
|
||||
|
||||
var crr *captureRequestReader
|
||||
if req.Body != nil {
|
||||
crr = &captureRequestReader{source: req.Body, count: 0}
|
||||
reqWithDataTable.Body = crr
|
||||
}
|
||||
|
||||
core[RequestCount] = nextRequestCount()
|
||||
if req.Host != "" {
|
||||
core[RequestAddr] = req.Host
|
||||
core[RequestHost], core[RequestPort] = silentSplitHostPort(req.Host)
|
||||
}
|
||||
// copy the URL without the scheme, hostname etc
|
||||
urlCopy := &url.URL{
|
||||
Path: req.URL.Path,
|
||||
RawPath: req.URL.RawPath,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
ForceQuery: req.URL.ForceQuery,
|
||||
Fragment: req.URL.Fragment,
|
||||
}
|
||||
urlCopyString := urlCopy.String()
|
||||
core[RequestMethod] = req.Method
|
||||
core[RequestPath] = urlCopyString
|
||||
core[RequestProtocol] = req.Proto
|
||||
|
||||
core[ClientAddr] = req.RemoteAddr
|
||||
core[ClientHost], core[ClientPort] = silentSplitHostPort(req.RemoteAddr)
|
||||
|
||||
if forwardedFor := req.Header.Get("X-Forwarded-For"); forwardedFor != "" {
|
||||
core[ClientHost] = forwardedFor
|
||||
}
|
||||
|
||||
crw := &captureResponseWriter{rw: rw}
|
||||
|
||||
next.ServeHTTP(crw, reqWithDataTable)
|
||||
|
||||
core[ClientUsername] = formatUsernameForLog(core[ClientUsername])
|
||||
|
||||
logDataTable.DownstreamResponse = crw.Header()
|
||||
|
||||
if l.config.BufferingSize > 0 {
|
||||
l.logHandlerChan <- logHandlerParams{
|
||||
logDataTable: logDataTable,
|
||||
crr: crr,
|
||||
crw: crw,
|
||||
}
|
||||
} else {
|
||||
l.logTheRoundTrip(logDataTable, crr, crw)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the logger (closes the file, drains logHandlerChan, etc.).
|
||||
func (l *LogHandler) Close() error {
|
||||
close(l.logHandlerChan)
|
||||
l.wg.Wait()
|
||||
return l.file.Close()
|
||||
}
|
||||
|
||||
// Rotate closes and reopens the log file to allow for rotation
|
||||
// by an external source.
|
||||
func (l *LogHandler) Rotate() error {
|
||||
var err error
|
||||
|
||||
if l.file != nil {
|
||||
defer func(f *os.File) {
|
||||
f.Close()
|
||||
}(l.file)
|
||||
}
|
||||
|
||||
l.file, err = os.OpenFile(l.config.FilePath, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0664)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.logger.Out = l.file
|
||||
return nil
|
||||
}
|
||||
|
||||
func silentSplitHostPort(value string) (host string, port string) {
|
||||
host, port, err := net.SplitHostPort(value)
|
||||
if err != nil {
|
||||
return value, "-"
|
||||
}
|
||||
return host, port
|
||||
}
|
||||
|
||||
func formatUsernameForLog(usernameField interface{}) string {
|
||||
username, ok := usernameField.(string)
|
||||
if ok && len(username) != 0 {
|
||||
return username
|
||||
}
|
||||
return "-"
|
||||
}
|
||||
|
||||
// logTheRoundTrip logs the frontend name, backend name, and elapsed time for a request.
|
||||
func (l *LogHandler) logTheRoundTrip(logDataTable *LogData, crr *captureRequestReader, crw *captureResponseWriter) {
|
||||
core := logDataTable.Core
|
||||
|
||||
retryAttempts, ok := core[RetryAttempts].(int)
|
||||
if !ok {
|
||||
retryAttempts = 0
|
||||
}
|
||||
core[RetryAttempts] = retryAttempts
|
||||
|
||||
if crr != nil {
|
||||
core[RequestContentSize] = crr.count
|
||||
}
|
||||
|
||||
core[DownstreamStatus] = crw.Status()
|
||||
|
||||
// n.b. take care to perform time arithmetic using UTC to avoid errors at DST boundaries
|
||||
totalDuration := time.Now().UTC().Sub(core[StartUTC].(time.Time))
|
||||
core[Duration] = totalDuration
|
||||
|
||||
if l.keepAccessLog(crw.Status(), retryAttempts, totalDuration) {
|
||||
core[DownstreamContentSize] = crw.Size()
|
||||
if original, ok := core[OriginContentSize]; ok {
|
||||
o64 := original.(int64)
|
||||
if o64 != crw.Size() && 0 != crw.Size() {
|
||||
core[GzipRatio] = float64(o64) / float64(crw.Size())
|
||||
}
|
||||
}
|
||||
|
||||
core[Overhead] = totalDuration
|
||||
if origin, ok := core[OriginDuration]; ok {
|
||||
core[Overhead] = totalDuration - origin.(time.Duration)
|
||||
}
|
||||
|
||||
fields := logrus.Fields{}
|
||||
|
||||
for k, v := range logDataTable.Core {
|
||||
if l.config.Fields.Keep(k) {
|
||||
fields[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
l.redactHeaders(logDataTable.Request, fields, "request_")
|
||||
l.redactHeaders(logDataTable.OriginResponse, fields, "origin_")
|
||||
l.redactHeaders(logDataTable.DownstreamResponse, fields, "downstream_")
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
l.logger.WithFields(fields).Println()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LogHandler) redactHeaders(headers http.Header, fields logrus.Fields, prefix string) {
|
||||
for k := range headers {
|
||||
v := l.config.Fields.KeepHeader(k)
|
||||
if v == types.AccessLogKeep {
|
||||
fields[prefix+k] = headers.Get(k)
|
||||
} else if v == types.AccessLogRedact {
|
||||
fields[prefix+k] = "REDACTED"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LogHandler) keepAccessLog(statusCode, retryAttempts int, duration time.Duration) bool {
|
||||
if l.config.Filters == nil {
|
||||
// no filters were specified
|
||||
return true
|
||||
}
|
||||
|
||||
if len(l.httpCodeRanges) == 0 && !l.config.Filters.RetryAttempts && l.config.Filters.MinDuration == 0 {
|
||||
// empty filters were specified, e.g. by passing --accessLog.filters only (without other filter options)
|
||||
return true
|
||||
}
|
||||
|
||||
if l.httpCodeRanges.Contains(statusCode) {
|
||||
return true
|
||||
}
|
||||
|
||||
if l.config.Filters.RetryAttempts && retryAttempts > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if l.config.Filters.MinDuration > 0 && (parse.Duration(duration) > l.config.Filters.MinDuration) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
var requestCounter uint64 // Request ID
|
||||
|
||||
func nextRequestCount() uint64 {
|
||||
return atomic.AddUint64(&requestCounter, 1)
|
||||
}
|
||||
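NewLogHandler only needs a types.AccessLog value; with BufferingSize left at zero each request is logged synchronously, otherwise entries are queued on logHandlerChan and written by a background goroutine. A minimal sketch of mounting it as negroni middleware, not part of this commit (field names are taken from the code above, the file path is illustrative):

package main

import (
    "net/http"

    "github.com/containous/traefik/old/middlewares/accesslog"
    "github.com/containous/traefik/old/types"
    "github.com/urfave/negroni"
)

func main() {
    handler, err := accesslog.NewLogHandler(&types.AccessLog{
        FilePath: "/var/log/traefik/access.log",
        Format:   accesslog.CommonFormat,
    })
    if err != nil {
        panic(err)
    }
    defer handler.Close()

    n := negroni.New(handler) // LogHandler satisfies negroni's ServeHTTP(rw, req, next) contract
    n.UseHandler(http.NotFoundHandler())
    _ = http.ListenAndServe(":8080", n)
}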
82
old/middlewares/accesslog/logger_formatters.go
Normal file
@ -0,0 +1,82 @@
package accesslog

import (
    "bytes"
    "fmt"
    "time"

    "github.com/sirupsen/logrus"
)

// default format for time presentation
const (
    commonLogTimeFormat = "02/Jan/2006:15:04:05 -0700"
    defaultValue        = "-"
)

// CommonLogFormatter provides formatting in the Traefik common log format
type CommonLogFormatter struct{}

// Format formats the log entry in the Traefik common log format
func (f *CommonLogFormatter) Format(entry *logrus.Entry) ([]byte, error) {
    b := &bytes.Buffer{}

    var timestamp = defaultValue
    if v, ok := entry.Data[StartUTC]; ok {
        timestamp = v.(time.Time).Format(commonLogTimeFormat)
    }

    var elapsedMillis int64
    if v, ok := entry.Data[Duration]; ok {
        elapsedMillis = v.(time.Duration).Nanoseconds() / 1000000
    }

    _, err := fmt.Fprintf(b, "%s - %s [%s] \"%s %s %s\" %v %v %s %s %v %s %s %dms\n",
        toLog(entry.Data, ClientHost, defaultValue, false),
        toLog(entry.Data, ClientUsername, defaultValue, false),
        timestamp,
        toLog(entry.Data, RequestMethod, defaultValue, false),
        toLog(entry.Data, RequestPath, defaultValue, false),
        toLog(entry.Data, RequestProtocol, defaultValue, false),
        toLog(entry.Data, OriginStatus, defaultValue, true),
        toLog(entry.Data, OriginContentSize, defaultValue, true),
        toLog(entry.Data, "request_Referer", `"-"`, true),
        toLog(entry.Data, "request_User-Agent", `"-"`, true),
        toLog(entry.Data, RequestCount, defaultValue, true),
        toLog(entry.Data, FrontendName, defaultValue, true),
        toLog(entry.Data, BackendURL, defaultValue, true),
        elapsedMillis)

    return b.Bytes(), err
}

func toLog(fields logrus.Fields, key string, defaultValue string, quoted bool) interface{} {
    if v, ok := fields[key]; ok {
        if v == nil {
            return defaultValue
        }

        switch s := v.(type) {
        case string:
            return toLogEntry(s, defaultValue, quoted)

        case fmt.Stringer:
            return toLogEntry(s.String(), defaultValue, quoted)

        default:
            return v
        }
    }
    return defaultValue
}

func toLogEntry(s string, defaultValue string, quote bool) string {
    if len(s) == 0 {
        return defaultValue
    }

    if quote {
        return `"` + s + `"`
    }
    return s
}
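A small usage sketch for the formatter above, assuming the package is imported under github.com/containous/traefik/old/middlewares/accesslog (the path used by other files in this commit). Fields missing from the entry fall back to "-", as the unit tests below also show.

package main

import (
    "fmt"
    "net/http"
    "time"

    "github.com/containous/traefik/old/middlewares/accesslog"
    "github.com/sirupsen/logrus"
)

func main() {
    formatter := &accesslog.CommonLogFormatter{}

    entry := &logrus.Entry{Data: logrus.Fields{
        accesslog.StartUTC:        time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
        accesslog.Duration:        123 * time.Second,
        accesslog.ClientHost:      "10.0.0.1",
        accesslog.ClientUsername:  "Client",
        accesslog.RequestMethod:   http.MethodGet,
        accesslog.RequestPath:     "/foo",
        accesslog.RequestProtocol: "HTTP/1.1",
    }}

    line, err := formatter.Format(entry)
    if err != nil {
        panic(err)
    }
    // Missing fields (status, sizes, referer, user agent, ...) print as "-".
    fmt.Print(string(line))
}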
old/middlewares/accesslog/logger_formatters_test.go (new file)
@@ -0,0 +1,140 @@
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCommonLogFormatter_Format(t *testing.T) {
|
||||
clf := CommonLogFormatter{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
data map[string]interface{}
|
||||
expectedLog string
|
||||
}{
|
||||
{
|
||||
name: "OriginStatus & OriginContentSize are nil",
|
||||
data: map[string]interface{}{
|
||||
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
|
||||
Duration: 123 * time.Second,
|
||||
ClientHost: "10.0.0.1",
|
||||
ClientUsername: "Client",
|
||||
RequestMethod: http.MethodGet,
|
||||
RequestPath: "/foo",
|
||||
RequestProtocol: "http",
|
||||
OriginStatus: nil,
|
||||
OriginContentSize: nil,
|
||||
RequestRefererHeader: "",
|
||||
RequestUserAgentHeader: "",
|
||||
RequestCount: 0,
|
||||
FrontendName: "",
|
||||
BackendURL: "",
|
||||
},
|
||||
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" - - "-" "-" 0 - - 123000ms
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "all data",
|
||||
data: map[string]interface{}{
|
||||
StartUTC: time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC),
|
||||
Duration: 123 * time.Second,
|
||||
ClientHost: "10.0.0.1",
|
||||
ClientUsername: "Client",
|
||||
RequestMethod: http.MethodGet,
|
||||
RequestPath: "/foo",
|
||||
RequestProtocol: "http",
|
||||
OriginStatus: 123,
|
||||
OriginContentSize: 132,
|
||||
RequestRefererHeader: "referer",
|
||||
RequestUserAgentHeader: "agent",
|
||||
RequestCount: nil,
|
||||
FrontendName: "foo",
|
||||
BackendURL: "http://10.0.0.2/toto",
|
||||
},
|
||||
expectedLog: `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo http" 123 132 "referer" "agent" - "foo" "http://10.0.0.2/toto" 123000ms
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
entry := &logrus.Entry{Data: test.data}
|
||||
|
||||
raw, err := clf.Format(entry)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedLog, string(raw))
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func Test_toLog(t *testing.T) {
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
fields logrus.Fields
|
||||
fieldName string
|
||||
defaultValue string
|
||||
quoted bool
|
||||
expectedLog interface{}
|
||||
}{
|
||||
{
|
||||
desc: "Should return int 1",
|
||||
fields: logrus.Fields{
|
||||
"Powpow": 1,
|
||||
},
|
||||
fieldName: "Powpow",
|
||||
defaultValue: defaultValue,
|
||||
quoted: false,
|
||||
expectedLog: 1,
|
||||
},
|
||||
{
|
||||
desc: "Should return string foo",
|
||||
fields: logrus.Fields{
|
||||
"Powpow": "foo",
|
||||
},
|
||||
fieldName: "Powpow",
|
||||
defaultValue: defaultValue,
|
||||
quoted: true,
|
||||
expectedLog: `"foo"`,
|
||||
},
|
||||
{
|
||||
desc: "Should return defaultValue if fieldName does not exist",
|
||||
fields: logrus.Fields{
|
||||
"Powpow": "foo",
|
||||
},
|
||||
fieldName: "",
|
||||
defaultValue: defaultValue,
|
||||
quoted: false,
|
||||
expectedLog: "-",
|
||||
},
|
||||
{
|
||||
desc: "Should return defaultValue if fields is nil",
|
||||
fields: nil,
|
||||
fieldName: "",
|
||||
defaultValue: defaultValue,
|
||||
quoted: false,
|
||||
expectedLog: "-",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
lg := toLog(test.fields, test.fieldName, defaultValue, test.quoted)
|
||||
|
||||
assert.Equal(t, test.expectedLog, lg)
|
||||
})
|
||||
}
|
||||
}
|
||||
old/middlewares/accesslog/logger_test.go (new file)
@@ -0,0 +1,644 @@
package accesslog
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
logFileNameSuffix = "/traefik/logger/test.log"
|
||||
testContent = "Hello, World"
|
||||
testBackendName = "http://127.0.0.1/testBackend"
|
||||
testFrontendName = "testFrontend"
|
||||
testStatus = 123
|
||||
testContentSize int64 = 12
|
||||
testHostname = "TestHost"
|
||||
testUsername = "TestUser"
|
||||
testPath = "testpath"
|
||||
testPort = 8181
|
||||
testProto = "HTTP/0.0"
|
||||
testMethod = http.MethodPost
|
||||
testReferer = "testReferer"
|
||||
testUserAgent = "testUserAgent"
|
||||
testRetryAttempts = 2
|
||||
testStart = time.Now()
|
||||
)
|
||||
|
||||
func TestLogRotation(t *testing.T) {
|
||||
tempDir, err := ioutil.TempDir("", "traefik_")
|
||||
if err != nil {
|
||||
t.Fatalf("Error setting up temporary directory: %s", err)
|
||||
}
|
||||
|
||||
fileName := tempDir + "traefik.log"
|
||||
rotatedFileName := fileName + ".rotated"
|
||||
|
||||
config := &types.AccessLog{FilePath: fileName, Format: CommonFormat}
|
||||
logHandler, err := NewLogHandler(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating new log handler: %s", err)
|
||||
}
|
||||
defer logHandler.Close()
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
next := func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
iterations := 20
|
||||
halfDone := make(chan bool)
|
||||
writeDone := make(chan bool)
|
||||
go func() {
|
||||
for i := 0; i < iterations; i++ {
|
||||
logHandler.ServeHTTP(recorder, req, next)
|
||||
if i == iterations/2 {
|
||||
halfDone <- true
|
||||
}
|
||||
}
|
||||
writeDone <- true
|
||||
}()
|
||||
|
||||
<-halfDone
|
||||
err = os.Rename(fileName, rotatedFileName)
|
||||
if err != nil {
|
||||
t.Fatalf("Error renaming file: %s", err)
|
||||
}
|
||||
|
||||
err = logHandler.Rotate()
|
||||
if err != nil {
|
||||
t.Fatalf("Error rotating file: %s", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-writeDone:
|
||||
gotLineCount := lineCount(t, fileName) + lineCount(t, rotatedFileName)
|
||||
if iterations != gotLineCount {
|
||||
t.Errorf("Wanted %d written log lines, got %d", iterations, gotLineCount)
|
||||
}
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatalf("test timed out")
|
||||
}
|
||||
|
||||
close(halfDone)
|
||||
close(writeDone)
|
||||
}
|
||||
|
||||
func lineCount(t *testing.T, fileName string) int {
|
||||
t.Helper()
|
||||
fileContents, err := ioutil.ReadFile(fileName)
|
||||
if err != nil {
|
||||
t.Fatalf("Error reading from file %s: %s", fileName, err)
|
||||
}
|
||||
|
||||
count := 0
|
||||
for _, line := range strings.Split(string(fileContents), "\n") {
|
||||
if strings.TrimSpace(line) == "" {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func TestLoggerCLF(t *testing.T) {
|
||||
tmpDir := createTempDir(t, CommonFormat)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
|
||||
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat}
|
||||
doLogging(t, config)
|
||||
|
||||
logData, err := ioutil.ReadFile(logFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`
|
||||
assertValidLogData(t, expectedLog, logData)
|
||||
}
|
||||
|
||||
func TestAsyncLoggerCLF(t *testing.T) {
|
||||
tmpDir := createTempDir(t, CommonFormat)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
|
||||
config := &types.AccessLog{FilePath: logFilePath, Format: CommonFormat, BufferingSize: 1024}
|
||||
doLogging(t, config)
|
||||
|
||||
logData, err := ioutil.ReadFile(logFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedLog := ` TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`
|
||||
assertValidLogData(t, expectedLog, logData)
|
||||
}
|
||||
|
||||
func assertString(exp string) func(t *testing.T, actual interface{}) {
|
||||
return func(t *testing.T, actual interface{}) {
|
||||
t.Helper()
|
||||
|
||||
assert.Equal(t, exp, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func assertNotEqual(exp string) func(t *testing.T, actual interface{}) {
|
||||
return func(t *testing.T, actual interface{}) {
|
||||
t.Helper()
|
||||
|
||||
assert.NotEqual(t, exp, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func assertFloat64(exp float64) func(t *testing.T, actual interface{}) {
|
||||
return func(t *testing.T, actual interface{}) {
|
||||
t.Helper()
|
||||
|
||||
assert.Equal(t, exp, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func assertFloat64NotZero() func(t *testing.T, actual interface{}) {
|
||||
return func(t *testing.T, actual interface{}) {
|
||||
t.Helper()
|
||||
|
||||
assert.NotZero(t, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoggerJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config *types.AccessLog
|
||||
expected map[string]func(t *testing.T, value interface{})
|
||||
}{
|
||||
{
|
||||
desc: "default config",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
RequestHost: assertString(testHostname),
|
||||
RequestAddr: assertString(testHostname),
|
||||
RequestMethod: assertString(testMethod),
|
||||
RequestPath: assertString(testPath),
|
||||
RequestProtocol: assertString(testProto),
|
||||
RequestPort: assertString("-"),
|
||||
DownstreamStatus: assertFloat64(float64(testStatus)),
|
||||
DownstreamContentSize: assertFloat64(float64(len(testContent))),
|
||||
OriginContentSize: assertFloat64(float64(len(testContent))),
|
||||
OriginStatus: assertFloat64(float64(testStatus)),
|
||||
RequestRefererHeader: assertString(testReferer),
|
||||
RequestUserAgentHeader: assertString(testUserAgent),
|
||||
FrontendName: assertString(testFrontendName),
|
||||
BackendURL: assertString(testBackendName),
|
||||
ClientUsername: assertString(testUsername),
|
||||
ClientHost: assertString(testHostname),
|
||||
ClientPort: assertString(fmt.Sprintf("%d", testPort)),
|
||||
ClientAddr: assertString(fmt.Sprintf("%s:%d", testHostname, testPort)),
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
|
||||
RequestCount: assertFloat64NotZero(),
|
||||
Duration: assertFloat64NotZero(),
|
||||
Overhead: assertFloat64NotZero(),
|
||||
RetryAttempts: assertFloat64(float64(testRetryAttempts)),
|
||||
"time": assertNotEqual(""),
|
||||
"StartLocal": assertNotEqual(""),
|
||||
"StartUTC": assertNotEqual(""),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "default config drop all fields",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
},
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEqual(""),
|
||||
"downstream_Content-Type": assertString("text/plain; charset=utf-8"),
|
||||
RequestRefererHeader: assertString(testReferer),
|
||||
RequestUserAgentHeader: assertString(testUserAgent),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "default config drop all fields and headers",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "drop",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEqual(""),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "default config drop all fields and redact headers",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "redact",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEqual(""),
|
||||
"downstream_Content-Type": assertString("REDACTED"),
|
||||
RequestRefererHeader: assertString("REDACTED"),
|
||||
RequestUserAgentHeader: assertString("REDACTED"),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "default config drop all fields and headers but kept someone",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: JSONFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
RequestHost: "keep",
|
||||
},
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldHeaderNames{
|
||||
"Referer": "keep",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string]func(t *testing.T, value interface{}){
|
||||
RequestHost: assertString(testHostname),
|
||||
"level": assertString("info"),
|
||||
"msg": assertString(""),
|
||||
"time": assertNotEqual(""),
|
||||
RequestRefererHeader: assertString(testReferer),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tmpDir := createTempDir(t, JSONFormat)
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
logFilePath := filepath.Join(tmpDir, logFileNameSuffix)
|
||||
|
||||
test.config.FilePath = logFilePath
|
||||
doLogging(t, test.config)
|
||||
|
||||
logData, err := ioutil.ReadFile(logFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
jsonData := make(map[string]interface{})
|
||||
err = json.Unmarshal(logData, &jsonData)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, len(test.expected), len(jsonData))
|
||||
|
||||
for field, assertion := range test.expected {
|
||||
assertion(t, jsonData[field])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewLogHandlerOutputStdout(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
config *types.AccessLog
|
||||
expectedLog string
|
||||
}{
|
||||
{
|
||||
desc: "default config",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "default config with empty filters",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Status code filter not matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
StatusCodes: []string{"200"},
|
||||
},
|
||||
},
|
||||
expectedLog: ``,
|
||||
},
|
||||
{
|
||||
desc: "Status code filter matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
StatusCodes: []string{"123"},
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Duration filter not matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
MinDuration: parse.Duration(1 * time.Hour),
|
||||
},
|
||||
},
|
||||
expectedLog: ``,
|
||||
},
|
||||
{
|
||||
desc: "Duration filter matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
MinDuration: parse.Duration(1 * time.Millisecond),
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Retry attempts filter matching",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Filters: &types.AccessLogFilters{
|
||||
RetryAttempts: true,
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode keep",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "keep",
|
||||
},
|
||||
},
|
||||
expectedLog: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode keep with override",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "keep",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 23 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
},
|
||||
},
|
||||
expectedLog: `- - - [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop with override",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
ClientUsername: "keep",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [-] "- - -" - - "testReferer" "testUserAgent" - - - 0ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop with header dropped",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
ClientUsername: "keep",
|
||||
},
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "drop",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [-] "- - -" - - "-" "-" - - - 0ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop with header redacted",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
ClientUsername: "keep",
|
||||
},
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "redact",
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "REDACTED" - - - 0ms`,
|
||||
},
|
||||
{
|
||||
desc: "Default mode drop with header redacted",
|
||||
config: &types.AccessLog{
|
||||
FilePath: "",
|
||||
Format: CommonFormat,
|
||||
Fields: &types.AccessLogFields{
|
||||
DefaultMode: "drop",
|
||||
Names: types.FieldNames{
|
||||
ClientHost: "drop",
|
||||
ClientUsername: "keep",
|
||||
},
|
||||
Headers: &types.FieldHeaders{
|
||||
DefaultMode: "keep",
|
||||
Names: types.FieldHeaderNames{
|
||||
"Referer": "redact",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedLog: `- - TestUser [-] "- - -" - - "REDACTED" "testUserAgent" - - - 0ms`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
// NOTE: It is not possible to run these cases in parallel because we capture Stdout
|
||||
|
||||
file, restoreStdout := captureStdout(t)
|
||||
defer restoreStdout()
|
||||
|
||||
doLogging(t, test.config)
|
||||
|
||||
written, err := ioutil.ReadFile(file.Name())
|
||||
require.NoError(t, err, "unable to read captured stdout from file")
|
||||
assertValidLogData(t, test.expectedLog, written)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func assertValidLogData(t *testing.T, expected string, logData []byte) {
|
||||
|
||||
if len(expected) == 0 {
|
||||
assert.Zero(t, len(logData))
|
||||
t.Log(string(logData))
|
||||
return
|
||||
}
|
||||
|
||||
result, err := ParseAccessLog(string(logData))
|
||||
require.NoError(t, err)
|
||||
|
||||
resultExpected, err := ParseAccessLog(expected)
|
||||
require.NoError(t, err)
|
||||
|
||||
formatErrMessage := fmt.Sprintf(`
|
||||
Expected: %s
|
||||
Actual: %s`, expected, string(logData))
|
||||
|
||||
require.Equal(t, len(resultExpected), len(result), formatErrMessage)
|
||||
assert.Equal(t, resultExpected[ClientHost], result[ClientHost], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[ClientUsername], result[ClientUsername], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestMethod], result[RequestMethod], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestPath], result[RequestPath], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestProtocol], result[RequestProtocol], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[OriginStatus], result[OriginStatus], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[OriginContentSize], result[OriginContentSize], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestRefererHeader], result[RequestRefererHeader], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[RequestUserAgentHeader], result[RequestUserAgentHeader], formatErrMessage)
|
||||
assert.Regexp(t, regexp.MustCompile("[0-9]*"), result[RequestCount], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[FrontendName], result[FrontendName], formatErrMessage)
|
||||
assert.Equal(t, resultExpected[BackendURL], result[BackendURL], formatErrMessage)
|
||||
assert.Regexp(t, regexp.MustCompile("[0-9]*ms"), result[Duration], formatErrMessage)
|
||||
}
|
||||
|
||||
func captureStdout(t *testing.T) (out *os.File, restoreStdout func()) {
|
||||
file, err := ioutil.TempFile("", "testlogger")
|
||||
require.NoError(t, err, "failed to create temp file")
|
||||
|
||||
original := os.Stdout
|
||||
os.Stdout = file
|
||||
|
||||
restoreStdout = func() {
|
||||
os.Stdout = original
|
||||
}
|
||||
|
||||
return file, restoreStdout
|
||||
}
|
||||
|
||||
func createTempDir(t *testing.T, prefix string) string {
|
||||
tmpDir, err := ioutil.TempDir("", prefix)
|
||||
require.NoError(t, err, "failed to create temp dir")
|
||||
|
||||
return tmpDir
|
||||
}
|
||||
|
||||
func doLogging(t *testing.T, config *types.AccessLog) {
|
||||
logger, err := NewLogHandler(config)
|
||||
require.NoError(t, err)
|
||||
defer logger.Close()
|
||||
|
||||
if config.FilePath != "" {
|
||||
_, err = os.Stat(config.FilePath)
|
||||
require.NoError(t, err, fmt.Sprintf("logger should create %s", config.FilePath))
|
||||
}
|
||||
|
||||
req := &http.Request{
|
||||
Header: map[string][]string{
|
||||
"User-Agent": {testUserAgent},
|
||||
"Referer": {testReferer},
|
||||
},
|
||||
Proto: testProto,
|
||||
Host: testHostname,
|
||||
Method: testMethod,
|
||||
RemoteAddr: fmt.Sprintf("%s:%d", testHostname, testPort),
|
||||
URL: &url.URL{
|
||||
Path: testPath,
|
||||
},
|
||||
}
|
||||
|
||||
logger.ServeHTTP(httptest.NewRecorder(), req, logWriterTestHandlerFunc)
|
||||
}
|
||||
|
||||
func logWriterTestHandlerFunc(rw http.ResponseWriter, r *http.Request) {
|
||||
if _, err := rw.Write([]byte(testContent)); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
rw.WriteHeader(testStatus)
|
||||
|
||||
logDataTable := GetLogDataTable(r)
|
||||
logDataTable.Core[FrontendName] = testFrontendName
|
||||
logDataTable.Core[BackendURL] = testBackendName
|
||||
logDataTable.Core[OriginStatus] = testStatus
|
||||
logDataTable.Core[OriginContentSize] = testContentSize
|
||||
logDataTable.Core[RetryAttempts] = testRetryAttempts
|
||||
logDataTable.Core[StartUTC] = testStart.UTC()
|
||||
logDataTable.Core[StartLocal] = testStart.Local()
|
||||
logDataTable.Core[ClientUsername] = testUsername
|
||||
}
|
||||
old/middlewares/accesslog/parser.go (new file)
@@ -0,0 +1,54 @@
package accesslog

import (
    "bytes"
    "regexp"
)

// ParseAccessLog parses one line of access log and returns a map of its fields
func ParseAccessLog(data string) (map[string]string, error) {
    var buffer bytes.Buffer
    buffer.WriteString(`(\S+)`)                  // 1 - ClientHost
    buffer.WriteString(`\s-\s`)                  // - - Spaces
    buffer.WriteString(`(\S+)\s`)                // 2 - ClientUsername
    buffer.WriteString(`\[([^]]+)\]\s`)          // 3 - StartUTC
    buffer.WriteString(`"(\S*)\s?`)              // 4 - RequestMethod
    buffer.WriteString(`((?:[^"]*(?:\\")?)*)\s`) // 5 - RequestPath
    buffer.WriteString(`([^"]*)"\s`)             // 6 - RequestProtocol
    buffer.WriteString(`(\S+)\s`)                // 7 - OriginStatus
    buffer.WriteString(`(\S+)\s`)                // 8 - OriginContentSize
    buffer.WriteString(`("?\S+"?)\s`)            // 9 - Referrer
    buffer.WriteString(`("\S+")\s`)              // 10 - User-Agent
    buffer.WriteString(`(\S+)\s`)                // 11 - RequestCount
    buffer.WriteString(`("[^"]*"|-)\s`)          // 12 - FrontendName
    buffer.WriteString(`("[^"]*"|-)\s`)          // 13 - BackendURL
    buffer.WriteString(`(\S+)`)                  // 14 - Duration

    regex, err := regexp.Compile(buffer.String())
    if err != nil {
        return nil, err
    }

    submatch := regex.FindStringSubmatch(data)
    result := make(map[string]string)

    // Need to be > 13 to match CLF format
    if len(submatch) > 13 {
        result[ClientHost] = submatch[1]
        result[ClientUsername] = submatch[2]
        result[StartUTC] = submatch[3]
        result[RequestMethod] = submatch[4]
        result[RequestPath] = submatch[5]
        result[RequestProtocol] = submatch[6]
        result[OriginStatus] = submatch[7]
        result[OriginContentSize] = submatch[8]
        result[RequestRefererHeader] = submatch[9]
        result[RequestUserAgentHeader] = submatch[10]
        result[RequestCount] = submatch[11]
        result[FrontendName] = submatch[12]
        result[BackendURL] = submatch[13]
        result[Duration] = submatch[14]
    }

    return result, nil
}
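For illustration, the parser round-trips a line in the format produced by CommonLogFormatter (again assuming the accesslog import path used elsewhere in this commit). Quoted fields such as the Referer keep their surrounding quotes in the result map, as the tests below confirm.

package main

import (
    "fmt"

    "github.com/containous/traefik/old/middlewares/accesslog"
)

func main() {
    line := `10.0.0.1 - Client [10/Nov/2009:23:00:00 +0000] "GET /foo HTTP/1.1" 200 12 "-" "curl/7.61" 1 "web" "http://10.0.0.2" 3ms`

    fields, err := accesslog.ParseAccessLog(line)
    if err != nil {
        panic(err)
    }

    // Each capture group is keyed by the field-name constants used above.
    fmt.Println(fields[accesslog.ClientHost])    // 10.0.0.1
    fmt.Println(fields[accesslog.RequestMethod]) // GET
    fmt.Println(fields[accesslog.OriginStatus])  // 200
    fmt.Println(fields[accesslog.Duration])      // 3ms
}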
old/middlewares/accesslog/parser_test.go (new file)
@@ -0,0 +1,75 @@
package accesslog
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestParseAccessLog(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
value string
|
||||
expected map[string]string
|
||||
}{
|
||||
{
|
||||
desc: "full log",
|
||||
value: `TestHost - TestUser [13/Apr/2016:07:14:19 -0700] "POST testpath HTTP/0.0" 123 12 "testReferer" "testUserAgent" 1 "testFrontend" "http://127.0.0.1/testBackend" 1ms`,
|
||||
expected: map[string]string{
|
||||
ClientHost: "TestHost",
|
||||
ClientUsername: "TestUser",
|
||||
StartUTC: "13/Apr/2016:07:14:19 -0700",
|
||||
RequestMethod: "POST",
|
||||
RequestPath: "testpath",
|
||||
RequestProtocol: "HTTP/0.0",
|
||||
OriginStatus: "123",
|
||||
OriginContentSize: "12",
|
||||
RequestRefererHeader: `"testReferer"`,
|
||||
RequestUserAgentHeader: `"testUserAgent"`,
|
||||
RequestCount: "1",
|
||||
FrontendName: `"testFrontend"`,
|
||||
BackendURL: `"http://127.0.0.1/testBackend"`,
|
||||
Duration: "1ms",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "log with space",
|
||||
value: `127.0.0.1 - - [09/Mar/2018:10:51:32 +0000] "GET / HTTP/1.1" 401 17 "-" "Go-http-client/1.1" 1 "testFrontend with space" - 0ms`,
|
||||
expected: map[string]string{
|
||||
ClientHost: "127.0.0.1",
|
||||
ClientUsername: "-",
|
||||
StartUTC: "09/Mar/2018:10:51:32 +0000",
|
||||
RequestMethod: "GET",
|
||||
RequestPath: "/",
|
||||
RequestProtocol: "HTTP/1.1",
|
||||
OriginStatus: "401",
|
||||
OriginContentSize: "17",
|
||||
RequestRefererHeader: `"-"`,
|
||||
RequestUserAgentHeader: `"Go-http-client/1.1"`,
|
||||
RequestCount: "1",
|
||||
FrontendName: `"testFrontend with space"`,
|
||||
BackendURL: `-`,
|
||||
Duration: "0ms",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "bad log",
|
||||
value: `bad`,
|
||||
expected: map[string]string{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result, err := ParseAccessLog(test.value)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(test.expected), len(result))
|
||||
for key, value := range test.expected {
|
||||
assert.Equal(t, value, result[key])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
old/middlewares/accesslog/save_backend.go (new file)
@@ -0,0 +1,64 @@
package accesslog

import (
    "net/http"
    "time"

    "github.com/urfave/negroni"
    "github.com/vulcand/oxy/utils"
)

// SaveBackend sends the backend name to the logger.
// These are always used with a corresponding SaveFrontend handler.
type SaveBackend struct {
    next        http.Handler
    backendName string
}

// NewSaveBackend creates a SaveBackend handler.
func NewSaveBackend(next http.Handler, backendName string) http.Handler {
    return &SaveBackend{next, backendName}
}

func (sb *SaveBackend) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
    serveSaveBackend(rw, r, sb.backendName, func(crw *captureResponseWriter) {
        sb.next.ServeHTTP(crw, r)
    })
}

// SaveNegroniBackend sends the backend name to the logger.
type SaveNegroniBackend struct {
    next        negroni.Handler
    backendName string
}

// NewSaveNegroniBackend creates a SaveNegroniBackend handler.
func NewSaveNegroniBackend(next negroni.Handler, backendName string) negroni.Handler {
    return &SaveNegroniBackend{next, backendName}
}

func (sb *SaveNegroniBackend) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
    serveSaveBackend(rw, r, sb.backendName, func(crw *captureResponseWriter) {
        sb.next.ServeHTTP(crw, r, next)
    })
}

func serveSaveBackend(rw http.ResponseWriter, r *http.Request, backendName string, apply func(*captureResponseWriter)) {
    table := GetLogDataTable(r)
    table.Core[BackendName] = backendName
    table.Core[BackendURL] = r.URL // note that this is *not* the original incoming URL
    table.Core[BackendAddr] = r.URL.Host

    crw := &captureResponseWriter{rw: rw}
    start := time.Now().UTC()

    apply(crw)

    // use UTC to handle switchover of daylight saving correctly
    table.Core[OriginDuration] = time.Now().UTC().Sub(start)
    table.Core[OriginStatus] = crw.Status()
    // make copy of headers so we can ensure there is no subsequent mutation during response processing
    table.OriginResponse = make(http.Header)
    utils.CopyHeaders(table.OriginResponse, crw.Header())
    table.Core[OriginContentSize] = crw.Size()
}
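A minimal sketch of how this wrapper is driven, assuming the accesslog import path used elsewhere in this commit and a request whose context already carries a LogData table (which LogHandler normally sets up):

package main

import (
    "context"
    "fmt"
    "net/http"
    "net/http/httptest"

    "github.com/containous/traefik/old/middlewares/accesslog"
)

func main() {
    backend := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
        rw.WriteHeader(http.StatusNoContent)
    })

    // Wrap the backend handler so the log data table records where the request went.
    wrapped := accesslog.NewSaveBackend(backend, "backend-1")

    table := &accesslog.LogData{Core: make(accesslog.CoreLogData)}
    req := httptest.NewRequest(http.MethodGet, "http://10.0.0.2/api", nil)
    req = req.WithContext(context.WithValue(req.Context(), accesslog.DataTableKey, table))

    wrapped.ServeHTTP(httptest.NewRecorder(), req)

    fmt.Println(table.Core[accesslog.BackendName])  // backend-1
    fmt.Println(table.Core[accesslog.OriginStatus]) // 204, as recorded by the capture writer
}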
old/middlewares/accesslog/save_frontend.go (new file)
@@ -0,0 +1,51 @@
package accesslog

import (
    "net/http"
    "strings"

    "github.com/urfave/negroni"
)

// SaveFrontend sends the frontend name to the logger.
// These are sometimes used with a corresponding SaveBackend handler, but not always.
// For example, redirected requests don't reach a backend.
type SaveFrontend struct {
    next         http.Handler
    frontendName string
}

// NewSaveFrontend creates a SaveFrontend handler.
func NewSaveFrontend(next http.Handler, frontendName string) http.Handler {
    return &SaveFrontend{next, frontendName}
}

func (sf *SaveFrontend) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
    serveSaveFrontend(r, sf.frontendName, func() {
        sf.next.ServeHTTP(rw, r)
    })
}

// SaveNegroniFrontend sends the frontend name to the logger.
type SaveNegroniFrontend struct {
    next         negroni.Handler
    frontendName string
}

// NewSaveNegroniFrontend creates a SaveNegroniFrontend handler.
func NewSaveNegroniFrontend(next negroni.Handler, frontendName string) negroni.Handler {
    return &SaveNegroniFrontend{next, frontendName}
}

func (sf *SaveNegroniFrontend) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
    serveSaveFrontend(r, sf.frontendName, func() {
        sf.next.ServeHTTP(rw, r, next)
    })
}

func serveSaveFrontend(r *http.Request, frontendName string, apply func()) {
    table := GetLogDataTable(r)
    table.Core[FrontendName] = strings.TrimPrefix(frontendName, "frontend-")

    apply()
}
old/middlewares/accesslog/save_retries.go (new file)
@@ -0,0 +1,19 @@
package accesslog

import (
    "net/http"
)

// SaveRetries is an implementation of RetryListener that stores RetryAttempts in the LogDataTable.
type SaveRetries struct{}

// Retried implements the RetryListener interface and will be called for each retry that happens.
func (s *SaveRetries) Retried(req *http.Request, attempt int) {
    // it is the request attempt x, but the retry attempt is x-1
    if attempt > 0 {
        attempt--
    }

    table := GetLogDataTable(req)
    table.Core[RetryAttempts] = attempt
}
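A retry middleware is expected to call Retried with the attempt number; a minimal sketch of that interaction, using the same context wiring as the test below (and the accesslog import path assumed earlier):

package main

import (
    "context"
    "fmt"
    "net/http"
    "net/http/httptest"

    "github.com/containous/traefik/old/middlewares/accesslog"
)

func main() {
    listener := &accesslog.SaveRetries{}

    table := &accesslog.LogData{Core: make(accesslog.CoreLogData)}
    req := httptest.NewRequest(http.MethodGet, "/some/path", nil)
    req = req.WithContext(context.WithValue(req.Context(), accesslog.DataTableKey, table))

    // Attempt 3 means the request was retried twice.
    listener.Retried(req, 3)

    fmt.Println(table.Core[accesslog.RetryAttempts]) // 2
}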
old/middlewares/accesslog/save_retries_test.go (new file)
@@ -0,0 +1,48 @@
package accesslog
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSaveRetries(t *testing.T) {
|
||||
tests := []struct {
|
||||
requestAttempt int
|
||||
wantRetryAttemptsInLog int
|
||||
}{
|
||||
{
|
||||
requestAttempt: 0,
|
||||
wantRetryAttemptsInLog: 0,
|
||||
},
|
||||
{
|
||||
requestAttempt: 1,
|
||||
wantRetryAttemptsInLog: 0,
|
||||
},
|
||||
{
|
||||
requestAttempt: 3,
|
||||
wantRetryAttemptsInLog: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(fmt.Sprintf("%d retries", test.requestAttempt), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
saveRetries := &SaveRetries{}
|
||||
|
||||
logDataTable := &LogData{Core: make(CoreLogData)}
|
||||
req := httptest.NewRequest(http.MethodGet, "/some/path", nil)
|
||||
reqWithDataTable := req.WithContext(context.WithValue(req.Context(), DataTableKey, logDataTable))
|
||||
|
||||
saveRetries.Retried(reqWithDataTable, test.requestAttempt)
|
||||
|
||||
if logDataTable.Core[RetryAttempts] != test.wantRetryAttemptsInLog {
|
||||
t.Errorf("got %v in logDataTable, want %v", logDataTable.Core[RetryAttempts], test.wantRetryAttemptsInLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
old/middlewares/accesslog/save_username.go (new file)
@@ -0,0 +1,60 @@
package accesslog

import (
    "context"
    "net/http"

    "github.com/urfave/negroni"
)

const (
    clientUsernameKey key = "ClientUsername"
)

// SaveUsername sends the username to the access logger.
type SaveUsername struct {
    next http.Handler
}

// NewSaveUsername creates a SaveUsername handler.
func NewSaveUsername(next http.Handler) http.Handler {
    return &SaveUsername{next}
}

func (sf *SaveUsername) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
    serveSaveUsername(r, func() {
        sf.next.ServeHTTP(rw, r)
    })
}

// SaveNegroniUsername adds the username to the access logger data table.
type SaveNegroniUsername struct {
    next negroni.Handler
}

// NewSaveNegroniUsername creates a SaveNegroniUsername handler.
func NewSaveNegroniUsername(next negroni.Handler) negroni.Handler {
    return &SaveNegroniUsername{next}
}

func (sf *SaveNegroniUsername) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
    serveSaveUsername(r, func() {
        sf.next.ServeHTTP(rw, r, next)
    })
}

func serveSaveUsername(r *http.Request, apply func()) {
    table := GetLogDataTable(r)

    username, ok := r.Context().Value(clientUsernameKey).(string)
    if ok {
        table.Core[ClientUsername] = username
    }

    apply()
}

// WithUserName adds a username to a request's context
func WithUserName(req *http.Request, username string) *http.Request {
    return req.WithContext(context.WithValue(req.Context(), clientUsernameKey, username))
}
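A minimal sketch of the handoff: an auth layer stores the username on the request with WithUserName, and SaveUsername later copies it into the log data table (assuming the accesslog import path used elsewhere in this commit):

package main

import (
    "context"
    "fmt"
    "net/http"
    "net/http/httptest"

    "github.com/containous/traefik/old/middlewares/accesslog"
)

func main() {
    // The inner handler runs after SaveUsername has copied the username
    // from the request context into the log data table.
    inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
        table := accesslog.GetLogDataTable(r)
        fmt.Println(table.Core[accesslog.ClientUsername]) // alice
    })

    handler := accesslog.NewSaveUsername(inner)

    table := &accesslog.LogData{Core: make(accesslog.CoreLogData)}
    req := httptest.NewRequest(http.MethodGet, "/", nil)
    req = req.WithContext(context.WithValue(req.Context(), accesslog.DataTableKey, table))
    req = accesslog.WithUserName(req, "alice") // what the auth middlewares do on success

    handler.ServeHTTP(httptest.NewRecorder(), req)
}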
old/middlewares/addPrefix.go (new file)
@@ -0,0 +1,35 @@
package middlewares

import (
    "context"
    "net/http"
)

// AddPrefix is a middleware used to add a prefix to the URL of a request
type AddPrefix struct {
    Handler http.Handler
    Prefix  string
}

type key string

const (
    // AddPrefixKey is the key within the request context used to
    // store the added prefix
    AddPrefixKey key = "AddPrefix"
)

func (s *AddPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    r.URL.Path = s.Prefix + r.URL.Path
    if r.URL.RawPath != "" {
        r.URL.RawPath = s.Prefix + r.URL.RawPath
    }
    r.RequestURI = r.URL.RequestURI()
    r = r.WithContext(context.WithValue(r.Context(), AddPrefixKey, s.Prefix))
    s.Handler.ServeHTTP(w, r)
}

// SetHandler sets handler
func (s *AddPrefix) SetHandler(Handler http.Handler) {
    s.Handler = Handler
}
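A quick sketch of the RawPath handling above, mirroring the test below; the import path for this package is assumed to be github.com/containous/traefik/old/middlewares, based on the file location in this commit:

package main

import (
    "fmt"
    "net/http"
    "net/http/httptest"

    "github.com/containous/traefik/old/middlewares"
)

func main() {
    handler := &middlewares.AddPrefix{
        Prefix: "/api",
        Handler: http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
            // Both the decoded path and the raw (still-escaped) path carry the prefix.
            fmt.Println(r.URL.Path)    // /api/b/c
            fmt.Println(r.URL.RawPath) // /api/b%2Fc
            fmt.Println(r.RequestURI)  // /api/b%2Fc
        }),
    }

    req := httptest.NewRequest(http.MethodGet, "http://localhost/b%2Fc", nil)
    handler.ServeHTTP(httptest.NewRecorder(), req)
}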
old/middlewares/addPrefix_test.go (new file)
@@ -0,0 +1,66 @@
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAddPrefix(t *testing.T) {
|
||||
logrus.SetLevel(logrus.DebugLevel)
|
||||
tests := []struct {
|
||||
desc string
|
||||
prefix string
|
||||
path string
|
||||
expectedPath string
|
||||
expectedRawPath string
|
||||
}{
|
||||
{
|
||||
desc: "regular path",
|
||||
prefix: "/a",
|
||||
path: "/b",
|
||||
expectedPath: "/a/b",
|
||||
},
|
||||
{
|
||||
desc: "raw path is supported",
|
||||
prefix: "/a",
|
||||
path: "/b%2Fc",
|
||||
expectedPath: "/a/b/c",
|
||||
expectedRawPath: "/a/b%2Fc",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualPath, actualRawPath, requestURI string
|
||||
handler := &AddPrefix{
|
||||
Prefix: test.prefix,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualRawPath = r.URL.RawPath
|
||||
requestURI = r.RequestURI
|
||||
}),
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
|
||||
handler.ServeHTTP(nil, req)
|
||||
|
||||
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
|
||||
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
|
||||
|
||||
expectedURI := test.expectedPath
|
||||
if test.expectedRawPath != "" {
|
||||
// go HTTP uses the raw path when existent in the RequestURI
|
||||
expectedURI = test.expectedRawPath
|
||||
}
|
||||
assert.Equal(t, expectedURI, requestURI, "Unexpected request URI.")
|
||||
})
|
||||
}
|
||||
}
|
||||
old/middlewares/auth/authenticator.go (new file)
@@ -0,0 +1,167 @@
package auth

import (
    "fmt"
    "io/ioutil"
    "net/http"
    "strings"

    goauth "github.com/abbot/go-http-auth"
    "github.com/containous/traefik/old/log"
    "github.com/containous/traefik/old/middlewares/accesslog"
    "github.com/containous/traefik/old/middlewares/tracing"
    "github.com/containous/traefik/old/types"
    "github.com/urfave/negroni"
)

// Authenticator is a middleware that provides HTTP basic and digest authentication
type Authenticator struct {
    handler negroni.Handler
    users   map[string]string
}

type tracingAuthenticator struct {
    name           string
    handler        negroni.Handler
    clientSpanKind bool
}

const (
    authorizationHeader = "Authorization"
)

// NewAuthenticator builds a new Authenticator given a config
func NewAuthenticator(authConfig *types.Auth, tracingMiddleware *tracing.Tracing) (*Authenticator, error) {
    if authConfig == nil {
        return nil, fmt.Errorf("error creating Authenticator: auth is nil")
    }

    var err error
    authenticator := &Authenticator{}
    tracingAuth := tracingAuthenticator{}

    if authConfig.Basic != nil {
        authenticator.users, err = parserBasicUsers(authConfig.Basic)
        if err != nil {
            return nil, err
        }
        realm := "traefik"
        if authConfig.Basic.Realm != "" {
            realm = authConfig.Basic.Realm
        }
        basicAuth := goauth.NewBasicAuthenticator(realm, authenticator.secretBasic)
        tracingAuth.handler = createAuthBasicHandler(basicAuth, authConfig)
        tracingAuth.name = "Auth Basic"
        tracingAuth.clientSpanKind = false
    } else if authConfig.Digest != nil {
        authenticator.users, err = parserDigestUsers(authConfig.Digest)
        if err != nil {
            return nil, err
        }

        digestAuth := goauth.NewDigestAuthenticator("traefik", authenticator.secretDigest)
        tracingAuth.handler = createAuthDigestHandler(digestAuth, authConfig)
        tracingAuth.name = "Auth Digest"
        tracingAuth.clientSpanKind = false
    } else if authConfig.Forward != nil {
        tracingAuth.handler = createAuthForwardHandler(authConfig)
        tracingAuth.name = "Auth Forward"
        tracingAuth.clientSpanKind = true
    }

    if tracingMiddleware != nil {
        authenticator.handler = tracingMiddleware.NewNegroniHandlerWrapper(tracingAuth.name, tracingAuth.handler, tracingAuth.clientSpanKind)
    } else {
        authenticator.handler = tracingAuth.handler
    }
    return authenticator, nil
}

func createAuthForwardHandler(authConfig *types.Auth) negroni.HandlerFunc {
    return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
        Forward(authConfig.Forward, w, r, next)
    })
}

func createAuthDigestHandler(digestAuth *goauth.DigestAuth, authConfig *types.Auth) negroni.HandlerFunc {
    return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
        if username, _ := digestAuth.CheckAuth(r); username == "" {
            log.Debugf("Digest auth failed")
            digestAuth.RequireAuth(w, r)
        } else {
            log.Debugf("Digest auth succeeded")

            // set username in request context
            r = accesslog.WithUserName(r, username)

            if authConfig.HeaderField != "" {
                r.Header[authConfig.HeaderField] = []string{username}
            }
            if authConfig.Digest.RemoveHeader {
                log.Debugf("Remove the Authorization header from the Digest auth")
                r.Header.Del(authorizationHeader)
            }
            next.ServeHTTP(w, r)
        }
    })
}

func createAuthBasicHandler(basicAuth *goauth.BasicAuth, authConfig *types.Auth) negroni.HandlerFunc {
    return negroni.HandlerFunc(func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
        if username := basicAuth.CheckAuth(r); username == "" {
            log.Debugf("Basic auth failed")
            basicAuth.RequireAuth(w, r)
        } else {
            log.Debugf("Basic auth succeeded")

            // set username in request context
            r = accesslog.WithUserName(r, username)

            if authConfig.HeaderField != "" {
                r.Header[authConfig.HeaderField] = []string{username}
            }
            if authConfig.Basic.RemoveHeader {
                log.Debugf("Remove the Authorization header from the Basic auth")
                r.Header.Del(authorizationHeader)
            }
            next.ServeHTTP(w, r)
        }
    })
}

func getLinesFromFile(filename string) ([]string, error) {
    dat, err := ioutil.ReadFile(filename)
    if err != nil {
        return nil, err
    }
    // Trim lines and filter out blanks
    rawLines := strings.Split(string(dat), "\n")
    var filteredLines []string
    for _, rawLine := range rawLines {
        line := strings.TrimSpace(rawLine)
        if line != "" {
            filteredLines = append(filteredLines, line)
        }
    }
    return filteredLines, nil
}

func (a *Authenticator) secretBasic(user, realm string) string {
    if secret, ok := a.users[user]; ok {
        return secret
    }
    log.Debugf("User not found: %s", user)
    return ""
}

func (a *Authenticator) secretDigest(user, realm string) string {
    if secret, ok := a.users[user+":"+realm]; ok {
        return secret
    }
    log.Debugf("User not found: %s:%s", user, realm)
    return ""
}

func (a *Authenticator) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
    a.handler.ServeHTTP(rw, r, next)
}
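A minimal wiring sketch following the pattern of the tests below; the old/middlewares/auth and old/middlewares/tracing import paths are assumed from the file locations in this commit, and the apr1 hash is the one the tests pair with the password "test":

package main

import (
    "fmt"
    "net/http"
    "net/http/httptest"

    "github.com/containous/traefik/old/middlewares/auth"
    "github.com/containous/traefik/old/middlewares/tracing"
    "github.com/containous/traefik/old/types"
    "github.com/urfave/negroni"
)

func main() {
    authenticator, err := auth.NewAuthenticator(&types.Auth{
        Basic: &types.Basic{
            Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"}, // password: test
        },
    }, &tracing.Tracing{})
    if err != nil {
        panic(err)
    }

    // The authenticator is a negroni.Handler, so it slots in ahead of the real handler.
    n := negroni.New(authenticator)
    n.UseHandler(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
        fmt.Fprintln(w, "traefik")
    }))

    ts := httptest.NewServer(n)
    defer ts.Close()

    req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
    if err != nil {
        panic(err)
    }
    req.SetBasicAuth("test", "test")

    res, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    fmt.Println(res.StatusCode) // 200 when the credentials match
}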
old/middlewares/auth/authenticator_test.go (new file)
@@ -0,0 +1,297 @@
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/old/middlewares/tracing"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestAuthUsersFromFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
authType string
|
||||
usersStr string
|
||||
userKeys []string
|
||||
parserFunc func(fileName string) (map[string]string, error)
|
||||
}{
|
||||
{
|
||||
authType: "basic",
|
||||
usersStr: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
|
||||
userKeys: []string{"test", "test2"},
|
||||
parserFunc: func(fileName string) (map[string]string, error) {
|
||||
basic := &types.Basic{
|
||||
UsersFile: fileName,
|
||||
}
|
||||
return parserBasicUsers(basic)
|
||||
},
|
||||
},
|
||||
{
|
||||
authType: "digest",
|
||||
usersStr: "test:traefik:a2688e031edb4be6a3797f3882655c05 \ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
|
||||
userKeys: []string{"test:traefik", "test2:traefik"},
|
||||
parserFunc: func(fileName string) (map[string]string, error) {
|
||||
digest := &types.Digest{
|
||||
UsersFile: fileName,
|
||||
}
|
||||
return parserDigestUsers(digest)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.authType, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
usersFile, err := ioutil.TempFile("", "auth-users")
|
||||
require.NoError(t, err)
|
||||
defer os.Remove(usersFile.Name())
|
||||
|
||||
_, err = usersFile.Write([]byte(test.usersStr))
|
||||
require.NoError(t, err)
|
||||
|
||||
users, err := test.parserFunc(usersFile.Name())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, len(users), "they should be equal")
|
||||
|
||||
_, ok := users[test.userKeys[0]]
|
||||
assert.True(t, ok, "user test should be found")
|
||||
_, ok = users[test.userKeys[1]]
|
||||
assert.True(t, ok, "user test2 should be found")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicAuthFail(t *testing.T) {
|
||||
_, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.Contains(t, err.Error(), "error parsing Authenticator user", "should contains")
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:test"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthSuccess(t *testing.T) {
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicRealm(t *testing.T) {
|
||||
authMiddlewareDefaultRealm, errdefault := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, errdefault)
|
||||
|
||||
authMiddlewareCustomRealm, errcustom := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Realm: "foobar",
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, errcustom)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
n := negroni.New(authMiddlewareDefaultRealm)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Basic realm=\"traefik\"", res.Header.Get("Www-Authenticate"), "they should be equal")
|
||||
|
||||
n = negroni.New(authMiddlewareCustomRealm)
|
||||
n.UseHandler(handler)
|
||||
ts = httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client = &http.Client{}
|
||||
req = testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err = client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "Basic realm=\"foobar\"", res.Header.Get("Www-Authenticate"), "they should be equal")
|
||||
}
|
||||
|
||||
func TestDigestAuthFail(t *testing.T) {
|
||||
_, err := NewAuthenticator(&types.Auth{
|
||||
Digest: &types.Digest{
|
||||
Users: []string{"test"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.Contains(t, err.Error(), "error parsing Authenticator user", "should contains")
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Digest: &types.Digest{
|
||||
Users: []string{"test:traefik:test"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authMiddleware, "this should not be nil")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthUserHeader(t *testing.T) {
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
HeaderField: "X-Webauth-User",
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthHeaderRemoved(t *testing.T) {
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
RemoveHeader: true,
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Empty(t, r.Header.Get(authorizationHeader))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestBasicAuthHeaderPresent(t *testing.T) {
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Basic: &types.Basic{
|
||||
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.SetBasicAuth("test", "test")
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
157
old/middlewares/auth/forward.go
Normal file
@@ -0,0 +1,157 @@
package auth

import (
	"io/ioutil"
	"net"
	"net/http"
	"strings"

	"github.com/containous/traefik/old/log"
	"github.com/containous/traefik/old/middlewares/tracing"
	"github.com/containous/traefik/old/types"
	"github.com/vulcand/oxy/forward"
	"github.com/vulcand/oxy/utils"
)

const (
	xForwardedURI    = "X-Forwarded-Uri"
	xForwardedMethod = "X-Forwarded-Method"
)

// Forward the authentication to an external server
func Forward(config *types.Forward, w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	// Ensure our request client does not follow redirects
	httpClient := http.Client{
		CheckRedirect: func(r *http.Request, via []*http.Request) error {
			return http.ErrUseLastResponse
		},
	}

	if config.TLS != nil {
		tlsConfig, err := config.TLS.CreateTLSConfig()
		if err != nil {
			tracing.SetErrorAndDebugLog(r, "Unable to configure TLS to call %s. Cause %s", config.Address, err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		httpClient.Transport = &http.Transport{
			TLSClientConfig: tlsConfig,
		}
	}

	forwardReq, err := http.NewRequest(http.MethodGet, config.Address, http.NoBody)
	tracing.LogRequest(tracing.GetSpan(r), forwardReq)
	if err != nil {
		tracing.SetErrorAndDebugLog(r, "Error calling %s. Cause %s", config.Address, err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	writeHeader(r, forwardReq, config.TrustForwardHeader)

	tracing.InjectRequestHeaders(forwardReq)

	forwardResponse, forwardErr := httpClient.Do(forwardReq)
	if forwardErr != nil {
		tracing.SetErrorAndDebugLog(r, "Error calling %s. Cause: %s", config.Address, forwardErr)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}

	body, readError := ioutil.ReadAll(forwardResponse.Body)
	if readError != nil {
		tracing.SetErrorAndDebugLog(r, "Error reading body %s. Cause: %s", config.Address, readError)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	defer forwardResponse.Body.Close()

	// Pass the forward response's body and selected headers if it
	// didn't return a response within the range of [200, 300).
	if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
		log.Debugf("Remote error %s. StatusCode: %d", config.Address, forwardResponse.StatusCode)

		utils.CopyHeaders(w.Header(), forwardResponse.Header)
		utils.RemoveHeaders(w.Header(), forward.HopHeaders...)

		// Grab the location header, if any.
		redirectURL, err := forwardResponse.Location()

		if err != nil {
			if err != http.ErrNoLocation {
				tracing.SetErrorAndDebugLog(r, "Error reading response location header %s. Cause: %s", config.Address, err)
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
		} else if redirectURL.String() != "" {
			// Set the location in our response if one was sent back.
			w.Header().Set("Location", redirectURL.String())
		}

		tracing.LogResponseCode(tracing.GetSpan(r), forwardResponse.StatusCode)
		w.WriteHeader(forwardResponse.StatusCode)

		if _, err = w.Write(body); err != nil {
			log.Error(err)
		}
		return
	}

	for _, headerName := range config.AuthResponseHeaders {
		r.Header.Set(headerName, forwardResponse.Header.Get(headerName))
	}

	r.RequestURI = r.URL.RequestURI()
	next(w, r)
}

func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
	utils.CopyHeaders(forwardReq.Header, req.Header)
	utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)

	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		if trustForwardHeader {
			if prior, ok := req.Header[forward.XForwardedFor]; ok {
				clientIP = strings.Join(prior, ", ") + ", " + clientIP
			}
		}
		forwardReq.Header.Set(forward.XForwardedFor, clientIP)
	}

	if xMethod := req.Header.Get(xForwardedMethod); xMethod != "" && trustForwardHeader {
		forwardReq.Header.Set(xForwardedMethod, xMethod)
	} else if req.Method != "" {
		forwardReq.Header.Set(xForwardedMethod, req.Method)
	} else {
		forwardReq.Header.Del(xForwardedMethod)
	}

	if xfp := req.Header.Get(forward.XForwardedProto); xfp != "" && trustForwardHeader {
		forwardReq.Header.Set(forward.XForwardedProto, xfp)
	} else if req.TLS != nil {
		forwardReq.Header.Set(forward.XForwardedProto, "https")
	} else {
		forwardReq.Header.Set(forward.XForwardedProto, "http")
	}

	if xfp := req.Header.Get(forward.XForwardedPort); xfp != "" && trustForwardHeader {
		forwardReq.Header.Set(forward.XForwardedPort, xfp)
	}

	if xfh := req.Header.Get(forward.XForwardedHost); xfh != "" && trustForwardHeader {
		forwardReq.Header.Set(forward.XForwardedHost, xfh)
	} else if req.Host != "" {
		forwardReq.Header.Set(forward.XForwardedHost, req.Host)
	} else {
		forwardReq.Header.Del(forward.XForwardedHost)
	}

	if xfURI := req.Header.Get(xForwardedURI); xfURI != "" && trustForwardHeader {
		forwardReq.Header.Set(xForwardedURI, xfURI)
	} else if req.URL.RequestURI() != "" {
		forwardReq.Header.Set(xForwardedURI, req.URL.RequestURI())
	} else {
		forwardReq.Header.Del(xForwardedURI)
	}
}
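For context, a minimal sketch of the kind of external endpoint the Forward middleware above can point its Address at (hypothetical handler, route, and port, not part of this commit): any non-2xx answer short-circuits the proxied request with that response, while a 2xx lets it continue and the headers listed in AuthResponseHeaders are copied back onto the original request.

package main

import (
	"fmt"
	"net/http"
)

// Hypothetical external auth endpoint for the forward-auth flow sketched above:
// requests without an X-Api-Key header are rejected with 403, everything else
// gets a 200 plus a header that AuthResponseHeaders could propagate.
func main() {
	http.HandleFunc("/auth", func(w http.ResponseWriter, r *http.Request) {
		if r.Header.Get("X-Api-Key") == "" {
			http.Error(w, "Forbidden", http.StatusForbidden)
			return
		}
		w.Header().Set("X-Auth-User", "user@example.com")
		fmt.Fprintln(w, "OK")
	})

	// Port and route are placeholders; error handling is omitted for brevity.
	_ = http.ListenAndServe(":9000", nil)
}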
392
old/middlewares/auth/forward_test.go
Normal file
@@ -0,0 +1,392 @@
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/old/middlewares/tracing"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
)
|
||||
|
||||
func TestForwardAuthFail(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: server.URL,
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestForwardAuthSuccess(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Auth-User", "user@example.com")
|
||||
w.Header().Set("X-Auth-Secret", "secret")
|
||||
fmt.Fprintln(w, "Success")
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
middleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: server.URL,
|
||||
AuthResponseHeaders: []string{"X-Auth-User"},
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
|
||||
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(middleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "traefik\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func TestForwardAuthRedirect(t *testing.T) {
|
||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
|
||||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: authTs.URL,
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := client.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
|
||||
|
||||
location, err := res.Location()
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.NotEmpty(t, string(body), "there should be something in the body")
|
||||
}
|
||||
|
||||
func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
|
||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
headers := w.Header()
|
||||
for _, header := range forward.HopHeaders {
|
||||
if header == forward.TransferEncoding {
|
||||
headers.Add(header, "identity")
|
||||
} else {
|
||||
headers.Add(header, "test")
|
||||
}
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
|
||||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: authTs.URL,
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
client := &http.Client{
|
||||
CheckRedirect: func(r *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
res, err := client.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
|
||||
|
||||
for _, header := range forward.HopHeaders {
|
||||
assert.Equal(t, "", res.Header.Get(header), "hop-by-hop header '%s' mustn't be set", header)
|
||||
}
|
||||
|
||||
location, err := res.Location()
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.NotEmpty(t, string(body), "there should be something in the body")
|
||||
}
|
||||
|
||||
func TestForwardAuthFailResponseHeaders(t *testing.T) {
|
||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
|
||||
http.SetCookie(w, cookie)
|
||||
w.Header().Add("X-Foo", "bar")
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
}))
|
||||
defer authTs.Close()
|
||||
|
||||
authMiddleware, err := NewAuthenticator(&types.Auth{
|
||||
Forward: &types.Forward{
|
||||
Address: authTs.URL,
|
||||
},
|
||||
}, &tracing.Tracing{})
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
n := negroni.New(authMiddleware)
|
||||
n.UseHandler(handler)
|
||||
ts := httptest.NewServer(n)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
client := &http.Client{}
|
||||
res, err := client.Do(req)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, http.StatusForbidden, res.StatusCode, "they should be equal")
|
||||
|
||||
require.Len(t, res.Cookies(), 1)
|
||||
for _, cookie := range res.Cookies() {
|
||||
assert.Equal(t, "testing", cookie.Value, "they should be equal")
|
||||
}
|
||||
|
||||
expectedHeaders := http.Header{
|
||||
"Content-Length": []string{"10"},
|
||||
"Content-Type": []string{"text/plain; charset=utf-8"},
|
||||
"X-Foo": []string{"bar"},
|
||||
"Set-Cookie": []string{"example=testing; Path=/"},
|
||||
"X-Content-Type-Options": []string{"nosniff"},
|
||||
}
|
||||
|
||||
assert.Len(t, res.Header, 6)
|
||||
for key, value := range expectedHeaders {
|
||||
assert.Equal(t, value, res.Header[key])
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err, "there should be no error")
|
||||
assert.Equal(t, "Forbidden\n", string(body), "they should be equal")
|
||||
}
|
||||
|
||||
func Test_writeHeader(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
trustForwardHeader bool
|
||||
emptyHost bool
|
||||
expectedHeaders map[string]string
|
||||
checkForUnexpectedHeaders bool
|
||||
}{
|
||||
{
|
||||
name: "trust Forward Header",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
trustForwardHeader: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not trust Forward Header",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "foo.bar",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "trust Forward Header with empty Host",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
trustForwardHeader: true,
|
||||
emptyHost: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not trust Forward Header with empty Host",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
emptyHost: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "trust Forward Header with forwarded URI",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
"X-Forwarded-Uri": "/forward?q=1",
|
||||
},
|
||||
trustForwardHeader: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
"X-Forwarded-Uri": "/forward?q=1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not trust Forward Header with forward requested URI",
|
||||
headers: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "fii.bir",
|
||||
"X-Forwarded-Uri": "/forward?q=1",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
expectedHeaders: map[string]string{
|
||||
"Accept": "application/json",
|
||||
"X-Forwarded-Host": "foo.bar",
|
||||
"X-Forwarded-Uri": "/path?q=1",
|
||||
},
|
||||
}, {
|
||||
name: "trust Forward Header with forwarded request Method",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Method": "OPTIONS",
|
||||
},
|
||||
trustForwardHeader: true,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-Method": "OPTIONS",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not trust Forward Header with forward request Method",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-Method": "OPTIONS",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-Method": "GET",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "remove hop-by-hop headers",
|
||||
headers: map[string]string{
|
||||
forward.Connection: "Connection",
|
||||
forward.KeepAlive: "KeepAlive",
|
||||
forward.ProxyAuthenticate: "ProxyAuthenticate",
|
||||
forward.ProxyAuthorization: "ProxyAuthorization",
|
||||
forward.Te: "Te",
|
||||
forward.Trailers: "Trailers",
|
||||
forward.TransferEncoding: "TransferEncoding",
|
||||
forward.Upgrade: "Upgrade",
|
||||
"X-CustomHeader": "CustomHeader",
|
||||
},
|
||||
trustForwardHeader: false,
|
||||
expectedHeaders: map[string]string{
|
||||
"X-CustomHeader": "CustomHeader",
|
||||
"X-Forwarded-Proto": "http",
|
||||
"X-Forwarded-Host": "foo.bar",
|
||||
"X-Forwarded-Uri": "/path?q=1",
|
||||
"X-Forwarded-Method": "GET",
|
||||
},
|
||||
checkForUnexpectedHeaders: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
|
||||
for key, value := range test.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
if test.emptyHost {
|
||||
req.Host = ""
|
||||
}
|
||||
|
||||
forwardReq := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
|
||||
|
||||
writeHeader(req, forwardReq, test.trustForwardHeader)
|
||||
|
||||
actualHeaders := forwardReq.Header
|
||||
expectedHeaders := test.expectedHeaders
|
||||
for key, value := range expectedHeaders {
|
||||
assert.Equal(t, value, actualHeaders.Get(key))
|
||||
actualHeaders.Del(key)
|
||||
}
|
||||
if test.checkForUnexpectedHeaders {
|
||||
for key := range actualHeaders {
|
||||
assert.Fail(t, "Unexpected header found", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
48
old/middlewares/auth/parser.go
Normal file
@@ -0,0 +1,48 @@
package auth

import (
	"fmt"
	"strings"

	"github.com/containous/traefik/old/types"
)

func parserBasicUsers(basic *types.Basic) (map[string]string, error) {
	var userStrs []string
	if basic.UsersFile != "" {
		var err error
		if userStrs, err = getLinesFromFile(basic.UsersFile); err != nil {
			return nil, err
		}
	}
	userStrs = append(basic.Users, userStrs...)
	userMap := make(map[string]string)
	for _, user := range userStrs {
		split := strings.Split(user, ":")
		if len(split) != 2 {
			return nil, fmt.Errorf("error parsing Authenticator user: %v", user)
		}
		userMap[split[0]] = split[1]
	}
	return userMap, nil
}

func parserDigestUsers(digest *types.Digest) (map[string]string, error) {
	var userStrs []string
	if digest.UsersFile != "" {
		var err error
		if userStrs, err = getLinesFromFile(digest.UsersFile); err != nil {
			return nil, err
		}
	}
	userStrs = append(digest.Users, userStrs...)
	userMap := make(map[string]string)
	for _, user := range userStrs {
		split := strings.Split(user, ":")
		if len(split) != 3 {
			return nil, fmt.Errorf("error parsing Authenticator user: %v", user)
		}
		userMap[split[0]+":"+split[1]] = split[2]
	}
	return userMap, nil
}
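As a usage note, the user strings accepted by these parsers follow the formats already used in the tests above: basic entries are name:hash pairs, digest entries are name:realm:password triples. A minimal in-package sketch (hypothetical test, not part of this commit) exercising both:

package auth

import (
	"testing"

	"github.com/containous/traefik/old/types"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Hypothetical illustration of the expected user-string formats.
func TestParserUsersSketch(t *testing.T) {
	// Basic auth: "name:hash", keyed by user name.
	basicUsers, err := parserBasicUsers(&types.Basic{
		Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
	})
	require.NoError(t, err)
	assert.Equal(t, "$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/", basicUsers["test"])

	// Digest auth: "name:realm:password", keyed by "name:realm".
	digestUsers, err := parserDigestUsers(&types.Digest{
		Users: []string{"test:traefik:test"},
	})
	require.NoError(t, err)
	assert.Equal(t, "test", digestUsers["test:traefik"])
}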
40
old/middlewares/cbreaker.go
Normal file
@@ -0,0 +1,40 @@
package middlewares

import (
	"net/http"

	"github.com/containous/traefik/old/log"
	"github.com/containous/traefik/old/middlewares/tracing"
	"github.com/vulcand/oxy/cbreaker"
)

// CircuitBreaker holds the oxy circuit breaker.
type CircuitBreaker struct {
	circuitBreaker *cbreaker.CircuitBreaker
}

// NewCircuitBreaker returns a new CircuitBreaker.
func NewCircuitBreaker(next http.Handler, expression string, options ...cbreaker.CircuitBreakerOption) (*CircuitBreaker, error) {
	circuitBreaker, err := cbreaker.New(next, expression, options...)
	if err != nil {
		return nil, err
	}
	return &CircuitBreaker{circuitBreaker}, nil
}

// NewCircuitBreakerOptions returns a new CircuitBreakerOption
func NewCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
	return cbreaker.Fallback(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tracing.LogEventf(r, "blocked by circuit-breaker (%q)", expression)

		w.WriteHeader(http.StatusServiceUnavailable)

		if _, err := w.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
			log.Error(err)
		}
	}))
}

func (cb *CircuitBreaker) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	cb.circuitBreaker.ServeHTTP(rw, r)
}
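As a usage note, a minimal in-package sketch of wiring the circuit breaker (hypothetical test, not part of this commit; the expression is the usual oxy/cbreaker example): the fallback installed by NewCircuitBreakerOptions answers 503 once the breaker opens, while a closed breaker simply passes requests through to the next handler.

package middlewares

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/require"
)

// Hypothetical example: with the breaker still closed, requests reach the next handler.
func TestCircuitBreakerSketch(t *testing.T) {
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusTeapot)
	})

	expression := "NetworkErrorRatio() > 0.5" // example expression understood by oxy/cbreaker
	cb, err := NewCircuitBreaker(next, expression, NewCircuitBreakerOptions(expression))
	require.NoError(t, err)

	recorder := httptest.NewRecorder()
	cb.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "http://localhost", nil))
	require.Equal(t, http.StatusTeapot, recorder.Code)
}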
33
old/middlewares/compress.go
Normal file
@@ -0,0 +1,33 @@
package middlewares

import (
	"compress/gzip"
	"net/http"
	"strings"

	"github.com/NYTimes/gziphandler"
	"github.com/containous/traefik/old/log"
)

// Compress is a middleware that compresses the response
type Compress struct{}

// ServeHTTP is a function used by Negroni
func (c *Compress) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	contentType := r.Header.Get("Content-Type")
	if strings.HasPrefix(contentType, "application/grpc") {
		next.ServeHTTP(rw, r)
	} else {
		gzipHandler(next).ServeHTTP(rw, r)
	}
}

func gzipHandler(h http.Handler) http.Handler {
	wrapper, err := gziphandler.GzipHandlerWithOpts(
		gziphandler.CompressionLevel(gzip.DefaultCompression),
		gziphandler.MinSize(gziphandler.DefaultMinSize))
	if err != nil {
		log.Error(err)
	}
	return wrapper(h)
}
248
old/middlewares/compress_test.go
Normal file
@@ -0,0 +1,248 @@
package middlewares
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
const (
|
||||
acceptEncodingHeader = "Accept-Encoding"
|
||||
contentEncodingHeader = "Content-Encoding"
|
||||
contentTypeHeader = "Content-Type"
|
||||
varyHeader = "Vary"
|
||||
gzipValue = "gzip"
|
||||
)
|
||||
|
||||
func TestShouldCompressWhenNoContentEncodingHeader(t *testing.T) {
|
||||
handler := &Compress{}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
baseBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := func(rw http.ResponseWriter, r *http.Request) {
|
||||
_, err := rw.Write(baseBody)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req, next)
|
||||
|
||||
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
|
||||
|
||||
if assert.ObjectsAreEqualValues(rw.Body.Bytes(), baseBody) {
|
||||
assert.Fail(t, "expected a compressed body", "got %v", rw.Body.Bytes())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldNotCompressWhenContentEncodingHeader(t *testing.T) {
|
||||
handler := &Compress{}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
fakeCompressedBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.Write(fakeCompressedBody)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req, next)
|
||||
|
||||
assert.Equal(t, gzipValue, rw.Header().Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, rw.Header().Get(varyHeader))
|
||||
|
||||
assert.EqualValues(t, rw.Body.Bytes(), fakeCompressedBody)
|
||||
}
|
||||
|
||||
func TestShouldNotCompressWhenNoAcceptEncodingHeader(t *testing.T) {
|
||||
handler := &Compress{}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
|
||||
fakeBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write(fakeBody)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req, next)
|
||||
|
||||
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
|
||||
assert.EqualValues(t, rw.Body.Bytes(), fakeBody)
|
||||
}
|
||||
|
||||
func TestShouldNotCompressWhenGRPC(t *testing.T) {
|
||||
handler := &Compress{}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
req.Header.Add(contentTypeHeader, "application/grpc")
|
||||
|
||||
baseBody := generateBytes(gziphandler.DefaultMinSize)
|
||||
next := func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write(baseBody)
|
||||
}
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req, next)
|
||||
|
||||
assert.Empty(t, rw.Header().Get(acceptEncodingHeader))
|
||||
assert.Empty(t, rw.Header().Get(contentEncodingHeader))
|
||||
assert.EqualValues(t, rw.Body.Bytes(), baseBody)
|
||||
}
|
||||
|
||||
func TestIntegrationShouldNotCompress(t *testing.T) {
|
||||
fakeCompressedBody := generateBytes(100000)
|
||||
comp := &Compress{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
handler func(rw http.ResponseWriter, r *http.Request)
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "when content already compressed",
|
||||
handler: func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.Write(fakeCompressedBody)
|
||||
},
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "when content already compressed and status code Created",
|
||||
handler: func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
rw.Write(fakeCompressedBody)
|
||||
},
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
negro := negroni.New(comp)
|
||||
negro.UseHandlerFunc(test.handler)
|
||||
ts := httptest.NewServer(negro)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.EqualValues(t, fakeCompressedBody, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldWriteHeaderWhenFlush(t *testing.T) {
|
||||
comp := &Compress{}
|
||||
negro := negroni.New(comp)
|
||||
negro.UseHandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Header().Add(contentEncodingHeader, gzipValue)
|
||||
rw.Header().Add(varyHeader, acceptEncodingHeader)
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
rw.(http.Flusher).Flush()
|
||||
rw.Write([]byte("short"))
|
||||
})
|
||||
ts := httptest.NewServer(negro)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
|
||||
}
|
||||
|
||||
func TestIntegrationShouldCompress(t *testing.T) {
|
||||
fakeBody := generateBytes(100000)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
handler func(rw http.ResponseWriter, r *http.Request)
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "when AcceptEncoding header is present",
|
||||
handler: func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.Write(fakeBody)
|
||||
},
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "when AcceptEncoding header is present and status code Created",
|
||||
handler: func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(http.StatusCreated)
|
||||
rw.Write(fakeBody)
|
||||
},
|
||||
expectedStatusCode: http.StatusCreated,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
comp := &Compress{}
|
||||
|
||||
negro := negroni.New(comp)
|
||||
negro.UseHandlerFunc(test.handler)
|
||||
ts := httptest.NewServer(negro)
|
||||
defer ts.Close()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.Header.Add(acceptEncodingHeader, gzipValue)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, resp.StatusCode)
|
||||
|
||||
assert.Equal(t, gzipValue, resp.Header.Get(contentEncodingHeader))
|
||||
assert.Equal(t, acceptEncodingHeader, resp.Header.Get(varyHeader))
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
if assert.ObjectsAreEqualValues(body, fakeBody) {
|
||||
assert.Fail(t, "expected a compressed body", "got %v", body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateBytes(len int) []byte {
|
||||
var value []byte
|
||||
for i := 0; i < len; i++ {
|
||||
value = append(value, 0x61+byte(i))
|
||||
}
|
||||
return value
|
||||
}
|
||||
30
old/middlewares/empty_backend_handler.go
Normal file
@@ -0,0 +1,30 @@
package middlewares

import (
	"net/http"

	"github.com/containous/traefik/healthcheck"
)

// EmptyBackendHandler is a middleware that checks whether the current Backend
// has at least one active Server with respect to the health checks and, if this
// is not the case, stops the middleware chain and responds with 503.
type EmptyBackendHandler struct {
	next healthcheck.BalancerHandler
}

// NewEmptyBackendHandler creates a new EmptyBackendHandler instance.
func NewEmptyBackendHandler(lb healthcheck.BalancerHandler) *EmptyBackendHandler {
	return &EmptyBackendHandler{next: lb}
}

// ServeHTTP responds with 503 when there is no active Server and otherwise
// invokes the next handler in the middleware chain.
func (h *EmptyBackendHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	if len(h.next.Servers()) == 0 {
		rw.WriteHeader(http.StatusServiceUnavailable)
		rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable)))
	} else {
		h.next.ServeHTTP(rw, r)
	}
}
83
old/middlewares/empty_backend_handler_test.go
Normal file
@@ -0,0 +1,83 @@
package middlewares
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
func TestEmptyBackendHandler(t *testing.T) {
|
||||
tests := []struct {
|
||||
amountServer int
|
||||
wantStatusCode int
|
||||
}{
|
||||
{
|
||||
amountServer: 0,
|
||||
wantStatusCode: http.StatusServiceUnavailable,
|
||||
},
|
||||
{
|
||||
amountServer: 1,
|
||||
wantStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
|
||||
t.Run(fmt.Sprintf("amount servers %d", test.amountServer), func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewEmptyBackendHandler(&healthCheckLoadBalancer{test.amountServer})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
|
||||
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Result().StatusCode != test.wantStatusCode {
|
||||
t.Errorf("Received status code %d, wanted %d", recorder.Result().StatusCode, test.wantStatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type healthCheckLoadBalancer struct {
|
||||
amountServer int
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) Servers() []*url.URL {
|
||||
servers := make([]*url.URL, 0, lb.amountServer)
|
||||
for i := 0; i < lb.amountServer; i++ {
|
||||
servers = append(servers, testhelpers.MustParseURL("http://localhost"))
|
||||
}
|
||||
return servers
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) RemoveServer(u *url.URL) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) UpsertServer(u *url.URL, options ...roundrobin.ServerOption) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) ServerWeight(u *url.URL) (int, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) NextServer() (*url.URL, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (lb *healthCheckLoadBalancer) Next() http.Handler {
|
||||
return nil
|
||||
}
|
||||
236
old/middlewares/errorpages/error_pages.go
Normal file
@@ -0,0 +1,236 @@
package errorpages
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/middlewares"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// Compile time validation that the response recorder implements http interfaces correctly.
|
||||
var _ middlewares.Stateful = &responseRecorderWithCloseNotify{}
|
||||
|
||||
// Handler is a middleware that provides the custom error pages
|
||||
type Handler struct {
|
||||
BackendName string
|
||||
backendHandler http.Handler
|
||||
httpCodeRanges types.HTTPCodeRanges
|
||||
backendURL string
|
||||
backendQuery string
|
||||
FallbackURL string // Deprecated
|
||||
}
|
||||
|
||||
// NewHandler initializes the utils.ErrorHandler for the custom error pages
|
||||
func NewHandler(errorPage *types.ErrorPage, backendName string) (*Handler, error) {
|
||||
if len(backendName) == 0 {
|
||||
return nil, errors.New("error pages: backend name is mandatory ")
|
||||
}
|
||||
|
||||
httpCodeRanges, err := types.NewHTTPCodeRanges(errorPage.Status)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
BackendName: backendName,
|
||||
httpCodeRanges: httpCodeRanges,
|
||||
backendQuery: errorPage.Query,
|
||||
backendURL: "http://0.0.0.0",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PostLoad adds backend handler if available
|
||||
func (h *Handler) PostLoad(backendHandler http.Handler) error {
|
||||
if backendHandler == nil {
|
||||
fwd, err := forward.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.backendHandler = fwd
|
||||
h.backendURL = h.FallbackURL
|
||||
} else {
|
||||
h.backendHandler = backendHandler
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) ServeHTTP(w http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
if h.backendHandler == nil {
|
||||
log.Error("Error pages: no backend handler.")
|
||||
next.ServeHTTP(w, req)
|
||||
return
|
||||
}
|
||||
|
||||
recorder := newResponseRecorder(w)
|
||||
next.ServeHTTP(recorder, req)
|
||||
|
||||
// check the recorder code against the configured http status code ranges
|
||||
for _, block := range h.httpCodeRanges {
|
||||
if recorder.GetCode() >= block[0] && recorder.GetCode() <= block[1] {
|
||||
log.Errorf("Caught HTTP Status Code %d, returning error page", recorder.GetCode())
|
||||
|
||||
var query string
|
||||
if len(h.backendQuery) > 0 {
|
||||
query = "/" + strings.TrimPrefix(h.backendQuery, "/")
|
||||
query = strings.Replace(query, "{status}", strconv.Itoa(recorder.GetCode()), -1)
|
||||
}
|
||||
|
||||
pageReq, err := newRequest(h.backendURL + query)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(recorder.GetCode())
|
||||
fmt.Fprint(w, http.StatusText(recorder.GetCode()))
|
||||
return
|
||||
}
|
||||
|
||||
recorderErrorPage := newResponseRecorder(w)
|
||||
utils.CopyHeaders(pageReq.Header, req.Header)
|
||||
|
||||
h.backendHandler.ServeHTTP(recorderErrorPage, pageReq.WithContext(req.Context()))
|
||||
|
||||
utils.CopyHeaders(w.Header(), recorderErrorPage.Header())
|
||||
w.WriteHeader(recorder.GetCode())
|
||||
|
||||
if _, err = w.Write(recorderErrorPage.GetBody().Bytes()); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// did not catch a configured status code so proceed with the request
|
||||
utils.CopyHeaders(w.Header(), recorder.Header())
|
||||
w.WriteHeader(recorder.GetCode())
|
||||
w.Write(recorder.GetBody().Bytes())
|
||||
}
|
||||
|
||||
func newRequest(baseURL string) (*http.Request, error) {
|
||||
u, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error pages: error when parse URL: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, u.String(), http.NoBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error pages: error when create query: %v", err)
|
||||
}
|
||||
|
||||
req.RequestURI = u.RequestURI()
|
||||
return req, nil
|
||||
}
|
||||
|
||||
type responseRecorder interface {
|
||||
http.ResponseWriter
|
||||
http.Flusher
|
||||
GetCode() int
|
||||
GetBody() *bytes.Buffer
|
||||
IsStreamingResponseStarted() bool
|
||||
}
|
||||
|
||||
// newResponseRecorder returns an initialized responseRecorder.
|
||||
func newResponseRecorder(rw http.ResponseWriter) responseRecorder {
|
||||
recorder := &responseRecorderWithoutCloseNotify{
|
||||
HeaderMap: make(http.Header),
|
||||
Body: new(bytes.Buffer),
|
||||
Code: http.StatusOK,
|
||||
responseWriter: rw,
|
||||
}
|
||||
if _, ok := rw.(http.CloseNotifier); ok {
|
||||
return &responseRecorderWithCloseNotify{recorder}
|
||||
}
|
||||
return recorder
|
||||
}
|
||||
|
||||
// responseRecorderWithoutCloseNotify is an implementation of http.ResponseWriter that
|
||||
// records its mutations for later inspection.
|
||||
type responseRecorderWithoutCloseNotify struct {
|
||||
Code int // the HTTP response code from WriteHeader
|
||||
HeaderMap http.Header // the HTTP response headers
|
||||
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
|
||||
|
||||
responseWriter http.ResponseWriter
|
||||
err error
|
||||
streamingResponseStarted bool
|
||||
}
|
||||
|
||||
type responseRecorderWithCloseNotify struct {
|
||||
*responseRecorderWithoutCloseNotify
|
||||
}
|
||||
|
||||
// CloseNotify returns a channel that receives at most a
|
||||
// single value (true) when the client connection has gone away.
|
||||
func (r *responseRecorderWithCloseNotify) CloseNotify() <-chan bool {
|
||||
return r.responseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
||||
// Header returns the response headers.
|
||||
func (r *responseRecorderWithoutCloseNotify) Header() http.Header {
|
||||
if r.HeaderMap == nil {
|
||||
r.HeaderMap = make(http.Header)
|
||||
}
|
||||
|
||||
return r.HeaderMap
|
||||
}
|
||||
|
||||
func (r *responseRecorderWithoutCloseNotify) GetCode() int {
|
||||
return r.Code
|
||||
}
|
||||
|
||||
func (r *responseRecorderWithoutCloseNotify) GetBody() *bytes.Buffer {
|
||||
return r.Body
|
||||
}
|
||||
|
||||
func (r *responseRecorderWithoutCloseNotify) IsStreamingResponseStarted() bool {
|
||||
return r.streamingResponseStarted
|
||||
}
|
||||
|
||||
// Write always succeeds and writes to rw.Body, if not nil.
|
||||
func (r *responseRecorderWithoutCloseNotify) Write(buf []byte) (int, error) {
|
||||
if r.err != nil {
|
||||
return 0, r.err
|
||||
}
|
||||
return r.Body.Write(buf)
|
||||
}
|
||||
|
||||
// WriteHeader sets rw.Code.
|
||||
func (r *responseRecorderWithoutCloseNotify) WriteHeader(code int) {
|
||||
r.Code = code
|
||||
}
|
||||
|
||||
// Hijack hijacks the connection
|
||||
func (r *responseRecorderWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return r.responseWriter.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client.
|
||||
func (r *responseRecorderWithoutCloseNotify) Flush() {
|
||||
if !r.streamingResponseStarted {
|
||||
utils.CopyHeaders(r.responseWriter.Header(), r.Header())
|
||||
r.responseWriter.WriteHeader(r.Code)
|
||||
r.streamingResponseStarted = true
|
||||
}
|
||||
|
||||
_, err := r.responseWriter.Write(r.Body.Bytes())
|
||||
if err != nil {
|
||||
log.Errorf("Error writing response in responseRecorder: %v", err)
|
||||
r.err = err
|
||||
}
|
||||
r.Body.Reset()
|
||||
|
||||
if flusher, ok := r.responseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
384
old/middlewares/errorpages/error_pages_test.go
Normal file
@@ -0,0 +1,384 @@
package errorpages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
errorPage *types.ErrorPage
|
||||
backendCode int
|
||||
backendErrorHandler http.HandlerFunc
|
||||
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
desc: "no error",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusOK,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusOK))
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusInternalServerError,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My error page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "not in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusBadGateway,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusBadGateway, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
|
||||
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "query replacement",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Single code",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
backendErrorHandler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.RequestURI == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errorPageHandler, err := NewHandler(test.errorPage, "test")
|
||||
require.NoError(t, err)
|
||||
|
||||
errorPageHandler.backendHandler = test.backendErrorHandler
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(test.backendCode)
|
||||
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
||||
})
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
|
||||
|
||||
n := negroni.New()
|
||||
n.Use(errorPageHandler)
|
||||
n.UseHandler(handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
n.ServeHTTP(recorder, req)
|
||||
|
||||
test.validate(t, recorder)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerOldWay(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
errorPage *types.ErrorPage
|
||||
backendCode int
|
||||
errorPageForwarder http.HandlerFunc
|
||||
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
desc: "no error",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusOK,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "OK")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusInternalServerError,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "My error page.")
|
||||
assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "not in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusBadGateway,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "My error page.")
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusBadGateway, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
|
||||
assert.NotContains(t, recorder.Body.String(), "My error page.", "Should return the oops page since we have not configured the 502 code")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "query replacement",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.RequestURI() == "/"+strconv.Itoa(503) {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Single code",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
errorPageForwarder: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.RequestURI() == "/"+strconv.Itoa(503) {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Failed")
|
||||
}
|
||||
}),
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost/test", nil)
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
errorPageHandler, err := NewHandler(test.errorPage, "test")
|
||||
require.NoError(t, err)
|
||||
errorPageHandler.FallbackURL = "http://localhost"
|
||||
|
||||
err = errorPageHandler.PostLoad(test.errorPageForwarder)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(test.backendCode)
|
||||
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
||||
})
|
||||
|
||||
n := negroni.New()
|
||||
n.Use(errorPageHandler)
|
||||
n.UseHandler(handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
n.ServeHTTP(recorder, req)
|
||||
|
||||
test.validate(t, recorder)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerOldWayIntegration(t *testing.T) {
|
||||
errorPagesServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.RequestURI() == "/503" {
|
||||
fmt.Fprintln(w, "My 503 page.")
|
||||
} else {
|
||||
fmt.Fprintln(w, "Test Server")
|
||||
}
|
||||
}))
|
||||
defer errorPagesServer.Close()
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
errorPage *types.ErrorPage
|
||||
backendCode int
|
||||
validate func(t *testing.T, recorder *httptest.ResponseRecorder)
|
||||
}{
|
||||
{
|
||||
desc: "no error",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusOK,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusOK, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "OK")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusInternalServerError,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), "Test Server")
|
||||
assert.NotContains(t, recorder.Body.String(), http.StatusText(http.StatusInternalServerError), "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "not in the range",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/test", Status: []string{"500-501", "503-599"}},
|
||||
backendCode: http.StatusBadGateway,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusBadGateway, recorder.Code)
|
||||
assert.Contains(t, recorder.Body.String(), http.StatusText(http.StatusBadGateway))
|
||||
assert.NotContains(t, recorder.Body.String(), "Test Server", "Should return the oops page since we have not configured the 502 code")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "query replacement",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503-503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Single code",
|
||||
errorPage: &types.ErrorPage{Backend: "error", Query: "/{status}", Status: []string{"503"}},
|
||||
backendCode: http.StatusServiceUnavailable,
|
||||
validate: func(t *testing.T, recorder *httptest.ResponseRecorder) {
|
||||
assert.Equal(t, http.StatusServiceUnavailable, recorder.Code, "HTTP status")
|
||||
assert.Contains(t, recorder.Body.String(), "My 503 page.")
|
||||
assert.NotContains(t, recorder.Body.String(), "oops", "Should not return the oops page")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, errorPagesServer.URL+"/test", nil)
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
|
||||
errorPageHandler, err := NewHandler(test.errorPage, "test")
|
||||
require.NoError(t, err)
|
||||
errorPageHandler.FallbackURL = errorPagesServer.URL
|
||||
|
||||
err = errorPageHandler.PostLoad(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(test.backendCode)
|
||||
fmt.Fprintln(w, http.StatusText(test.backendCode))
|
||||
})
|
||||
|
||||
n := negroni.New()
|
||||
n.Use(errorPageHandler)
|
||||
n.UseHandler(handler)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
n.ServeHTTP(recorder, req)
|
||||
|
||||
test.validate(t, recorder)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewResponseRecorder(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
rw http.ResponseWriter
|
||||
expected http.ResponseWriter
|
||||
}{
|
||||
{
|
||||
desc: "Without Close Notify",
|
||||
rw: httptest.NewRecorder(),
|
||||
expected: &responseRecorderWithoutCloseNotify{},
|
||||
},
|
||||
{
|
||||
desc: "With Close Notify",
|
||||
rw: &mockRWCloseNotify{},
|
||||
expected: &responseRecorderWithCloseNotify{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rec := newResponseRecorder(test.rw)
|
||||
|
||||
assert.IsType(t, rec, test.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type mockRWCloseNotify struct{}
|
||||
|
||||
func (m *mockRWCloseNotify) CloseNotify() <-chan bool {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) Header() http.Header {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) Write([]byte) (int, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (m *mockRWCloseNotify) WriteHeader(int) {
|
||||
panic("implement me")
|
||||
}
|
||||
52
old/middlewares/forwardedheaders/forwarded_header.go
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
package forwardedheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
// XForwarded is a middleware that filters the X-Forwarded-* headers
|
||||
type XForwarded struct {
|
||||
insecure bool
|
||||
trustedIps []string
|
||||
ipChecker *ip.Checker
|
||||
}
|
||||
|
||||
// NewXforwarded creates a new XForwarded
|
||||
func NewXforwarded(insecure bool, trustedIps []string) (*XForwarded, error) {
|
||||
var ipChecker *ip.Checker
|
||||
if len(trustedIps) > 0 {
|
||||
var err error
|
||||
ipChecker, err = ip.NewChecker(trustedIps)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &XForwarded{
|
||||
insecure: insecure,
|
||||
trustedIps: trustedIps,
|
||||
ipChecker: ipChecker,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (x *XForwarded) isTrustedIP(ip string) bool {
|
||||
if x.ipChecker == nil {
|
||||
return false
|
||||
}
|
||||
return x.ipChecker.IsAuthorized(ip) == nil
|
||||
}
|
||||
|
||||
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
|
||||
utils.RemoveHeaders(r.Header, forward.XHeaders...)
|
||||
}
|
||||
|
||||
// If there is a next, call it.
|
||||
if next != nil {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
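A minimal usage sketch for the XForwarded middleware above. It is not part of this commit: the listen address, trusted CIDR, and wiring into a standalone negroni chain are illustrative assumptions only.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/traefik/old/middlewares/forwardedheaders"
	"github.com/urfave/negroni"
)

func main() {
	// insecure=false: X-Forwarded-* headers are stripped unless the peer IP is trusted.
	xf, err := forwardedheaders.NewXforwarded(false, []string{"10.0.0.0/8"})
	if err != nil {
		log.Fatal(err)
	}

	n := negroni.New()
	n.Use(xf)
	n.UseHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Only trusted peers can influence this value.
		fmt.Fprintln(w, r.Header.Get("X-Forwarded-For"))
	}))

	log.Fatal(http.ListenAndServe(":8080", n))
}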
128
old/middlewares/forwardedheaders/forwarded_header_test.go
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
package forwardedheaders
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestServeHTTP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
insecure bool
|
||||
trustedIps []string
|
||||
incomingHeaders map[string]string
|
||||
remoteAddr string
|
||||
expectedHeaders map[string]string
|
||||
}{
|
||||
{
|
||||
desc: "all Empty",
|
||||
insecure: true,
|
||||
trustedIps: nil,
|
||||
remoteAddr: "",
|
||||
incomingHeaders: map[string]string{},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure true with incoming X-Forwarded-For",
|
||||
insecure: true,
|
||||
trustedIps: nil,
|
||||
remoteAddr: "",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For",
|
||||
insecure: false,
|
||||
trustedIps: nil,
|
||||
remoteAddr: "",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips",
|
||||
insecure: false,
|
||||
trustedIps: []string{"10.0.1.100"},
|
||||
remoteAddr: "10.0.1.100:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips",
|
||||
insecure: false,
|
||||
trustedIps: []string{"10.0.1.100"},
|
||||
remoteAddr: "10.0.1.101:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and valid Trusted Ips CIDR",
|
||||
insecure: false,
|
||||
trustedIps: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "1.2.3.156:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "insecure false with incoming X-Forwarded-For and invalid Trusted Ips CIDR",
|
||||
insecure: false,
|
||||
trustedIps: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "10.0.1.101:80",
|
||||
incomingHeaders: map[string]string{
|
||||
"X-Forwarded-for": "10.0.1.0, 10.0.1.12",
|
||||
},
|
||||
expectedHeaders: map[string]string{
|
||||
"X-Forwarded-for": "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.RemoteAddr = test.remoteAddr
|
||||
|
||||
for k, v := range test.incomingHeaders {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
m, err := NewXforwarded(test.insecure, test.trustedIps)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.ServeHTTP(nil, req, nil)
|
||||
|
||||
for k, v := range test.expectedHeaders {
|
||||
assert.Equal(t, v, req.Header.Get(k))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
36
old/middlewares/handlerSwitcher.go
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/safe"
|
||||
)
|
||||
|
||||
// HandlerSwitcher allows hot switching of the underlying mux.Router
|
||||
type HandlerSwitcher struct {
|
||||
handler *safe.Safe
|
||||
}
|
||||
|
||||
// NewHandlerSwitcher builds a new instance of HandlerSwitcher
|
||||
func NewHandlerSwitcher(newHandler *mux.Router) (hs *HandlerSwitcher) {
|
||||
return &HandlerSwitcher{
|
||||
handler: safe.New(newHandler),
|
||||
}
|
||||
}
|
||||
|
||||
func (hs *HandlerSwitcher) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
handlerBackup := hs.handler.Get().(*mux.Router)
|
||||
handlerBackup.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
// GetHandler returns the current mux.Router
|
||||
func (hs *HandlerSwitcher) GetHandler() (newHandler *mux.Router) {
|
||||
handler := hs.handler.Get().(*mux.Router)
|
||||
return handler
|
||||
}
|
||||
|
||||
// UpdateHandler safely swaps in a new mux.Router
|
||||
func (hs *HandlerSwitcher) UpdateHandler(newHandler *mux.Router) {
|
||||
hs.handler.Set(newHandler)
|
||||
}
|
||||
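A hypothetical sketch, not taken from this commit, of how the HandlerSwitcher above hot-swaps routers without restarting the listener; the routes and address are made up.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/mux"
	"github.com/containous/traefik/old/middlewares"
)

func main() {
	routerA := mux.NewRouter()
	routerA.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "A") })

	switcher := middlewares.NewHandlerSwitcher(routerA)

	// Later (e.g. after a configuration reload) a new router can be swapped in
	// atomically while requests keep being served.
	routerB := mux.NewRouter()
	routerB.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "B") })
	switcher.UpdateHandler(routerB)

	log.Fatal(http.ListenAndServe(":8080", switcher))
}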
71
old/middlewares/headers.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package middlewares
|
||||
|
||||
// Middleware based on https://github.com/unrolled/secure
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/old/types"
|
||||
)
|
||||
|
||||
// HeaderOptions is a struct for specifying configuration options for the headers middleware.
|
||||
type HeaderOptions struct {
|
||||
// If Custom request headers are set, these will be added to the request
|
||||
CustomRequestHeaders map[string]string
|
||||
// If Custom response headers are set, these will be added to the ResponseWriter
|
||||
CustomResponseHeaders map[string]string
|
||||
}
|
||||
|
||||
// HeaderStruct is a middleware that helps set up a few basic security features. A single HeaderOptions struct can be
|
||||
// provided to configure which features should be enabled and to override a few of the default values.
|
||||
type HeaderStruct struct {
|
||||
// Customize headers with a headerOptions struct.
|
||||
opt HeaderOptions
|
||||
}
|
||||
|
||||
// NewHeaderFromStruct constructs a new header instance from supplied frontend header struct.
|
||||
func NewHeaderFromStruct(headers *types.Headers) *HeaderStruct {
|
||||
if headers == nil || !headers.HasCustomHeadersDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &HeaderStruct{
|
||||
opt: HeaderOptions{
|
||||
CustomRequestHeaders: headers.CustomRequestHeaders,
|
||||
CustomResponseHeaders: headers.CustomResponseHeaders,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *HeaderStruct) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
s.ModifyRequestHeaders(r)
|
||||
// If there is a next, call it.
|
||||
if next != nil {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// ModifyRequestHeaders sets or deletes request headers
|
||||
func (s *HeaderStruct) ModifyRequestHeaders(r *http.Request) {
|
||||
// Loop through Custom request headers
|
||||
for header, value := range s.opt.CustomRequestHeaders {
|
||||
if value == "" {
|
||||
r.Header.Del(header)
|
||||
} else {
|
||||
r.Header.Set(header, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ModifyResponseHeaders sets or deletes response headers
|
||||
func (s *HeaderStruct) ModifyResponseHeaders(res *http.Response) error {
|
||||
// Loop through Custom response headers
|
||||
for header, value := range s.opt.CustomResponseHeaders {
|
||||
if value == "" {
|
||||
res.Header.Del(header)
|
||||
} else {
|
||||
res.Header.Set(header, value)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
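An illustrative sketch, not from this commit, showing how NewHeaderFromStruct above could be wired into a negroni chain; the header names and values are invented.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/traefik/old/middlewares"
	"github.com/containous/traefik/old/types"
	"github.com/urfave/negroni"
)

func main() {
	headerMiddleware := middlewares.NewHeaderFromStruct(&types.Headers{
		CustomRequestHeaders: map[string]string{
			"X-Script-Name": "traefik", // added to every forwarded request
			"X-Debug":       "",        // an empty value removes the header instead
		},
	})

	n := negroni.New()
	n.Use(headerMiddleware)
	n.UseHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, r.Header.Get("X-Script-Name"))
	}))

	log.Fatal(http.ListenAndServe(":8080", n))
}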
118
old/middlewares/headers_test.go
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
package middlewares
|
||||
|
||||
// Middleware tests based on https://github.com/unrolled/secure
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
var myHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("bar"))
|
||||
})
|
||||
|
||||
// newHeader constructs a new header instance with supplied options.
|
||||
func newHeader(options ...HeaderOptions) *HeaderStruct {
|
||||
var opt HeaderOptions
|
||||
if len(options) == 0 {
|
||||
opt = HeaderOptions{}
|
||||
} else {
|
||||
opt = options[0]
|
||||
}
|
||||
|
||||
return &HeaderStruct{opt: opt}
|
||||
}
|
||||
|
||||
func TestNoConfig(t *testing.T) {
|
||||
header := newHeader()
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req, myHandler)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "bar", res.Body.String(), "Body not the expected")
|
||||
}
|
||||
|
||||
func TestModifyResponseHeaders(t *testing.T) {
|
||||
header := newHeader(HeaderOptions{
|
||||
CustomResponseHeaders: map[string]string{
|
||||
"X-Custom-Response-Header": "test_response",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
res.HeaderMap.Add("X-Custom-Response-Header", "test_response")
|
||||
|
||||
err := header.ModifyResponseHeaders(res.Result())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "test_response", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
|
||||
|
||||
res = httptest.NewRecorder()
|
||||
res.HeaderMap.Add("X-Custom-Response-Header", "")
|
||||
|
||||
err = header.ModifyResponseHeaders(res.Result())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
|
||||
|
||||
res = httptest.NewRecorder()
|
||||
res.HeaderMap.Add("X-Custom-Response-Header", "test_override")
|
||||
|
||||
err = header.ModifyResponseHeaders(res.Result())
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "test_override", res.Header().Get("X-Custom-Response-Header"), "Did not get expected header")
|
||||
}
|
||||
|
||||
func TestCustomRequestHeader(t *testing.T) {
|
||||
header := newHeader(HeaderOptions{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "test_request",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req, nil)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header")
|
||||
}
|
||||
|
||||
func TestCustomRequestHeaderEmptyValue(t *testing.T) {
|
||||
header := newHeader(HeaderOptions{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "test_request",
|
||||
},
|
||||
})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "/foo", nil)
|
||||
|
||||
header.ServeHTTP(res, req, nil)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "test_request", req.Header.Get("X-Custom-Request-Header"), "Did not get expected header")
|
||||
|
||||
header = newHeader(HeaderOptions{
|
||||
CustomRequestHeaders: map[string]string{
|
||||
"X-Custom-Request-Header": "",
|
||||
},
|
||||
})
|
||||
|
||||
header.ServeHTTP(res, req, nil)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.Code, "Status not OK")
|
||||
assert.Equal(t, "", req.Header.Get("X-Custom-Request-Header"), "This header is not expected")
|
||||
}
|
||||
67
old/middlewares/ip_whitelister.go
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/middlewares/tracing"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// IPWhiteLister is a middleware that checks the requesting IP against a set of whitelisted CIDRs
|
||||
type IPWhiteLister struct {
|
||||
handler negroni.Handler
|
||||
whiteLister *ip.Checker
|
||||
strategy ip.Strategy
|
||||
}
|
||||
|
||||
// NewIPWhiteLister builds a new IPWhiteLister given a list of CIDR-Strings to whitelist
|
||||
func NewIPWhiteLister(whiteList []string, strategy ip.Strategy) (*IPWhiteLister, error) {
|
||||
if len(whiteList) == 0 {
|
||||
return nil, errors.New("no white list provided")
|
||||
}
|
||||
|
||||
checker, err := ip.NewChecker(whiteList)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing CIDR whitelist %s: %v", whiteList, err)
|
||||
}
|
||||
|
||||
whiteLister := IPWhiteLister{
|
||||
strategy: strategy,
|
||||
whiteLister: checker,
|
||||
}
|
||||
|
||||
whiteLister.handler = negroni.HandlerFunc(whiteLister.handle)
|
||||
log.Debugf("configured IP white list: %s", whiteList)
|
||||
|
||||
return &whiteLister, nil
|
||||
}
|
||||
|
||||
func (wl *IPWhiteLister) handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
err := wl.whiteLister.IsAuthorized(wl.strategy.GetIP(r))
|
||||
if err != nil {
|
||||
tracing.SetErrorAndDebugLog(r, "request %+v - rejecting: %v", r, err)
|
||||
reject(w)
|
||||
return
|
||||
}
|
||||
log.Debugf("Accept %s: %+v", wl.strategy.GetIP(r), r)
|
||||
tracing.SetErrorAndDebugLog(r, "request %+v matched white list %v - passing", r, wl.whiteLister)
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func (wl *IPWhiteLister) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
wl.handler.ServeHTTP(rw, r, next)
|
||||
}
|
||||
|
||||
func reject(w http.ResponseWriter) {
|
||||
statusCode := http.StatusForbidden
|
||||
|
||||
w.WriteHeader(statusCode)
|
||||
_, err := w.Write([]byte(http.StatusText(statusCode)))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
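A minimal usage sketch for the IPWhiteLister above, assuming the RemoteAddrStrategy from the traefik/ip package used in the tests below; the CIDR and address are example values, not part of the commit.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/traefik/ip"
	"github.com/containous/traefik/old/middlewares"
	"github.com/urfave/negroni"
)

func main() {
	// Requests whose remote address falls outside the CIDR receive a 403.
	whiteLister, err := middlewares.NewIPWhiteLister([]string{"10.0.0.0/8"}, &ip.RemoteAddrStrategy{})
	if err != nil {
		log.Fatal(err)
	}

	n := negroni.New()
	n.Use(whiteLister)
	n.UseHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "allowed")
	}))

	log.Fatal(http.ListenAndServe(":8080", n))
}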
92
old/middlewares/ip_whitelister_test.go
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/ip"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewIPWhiteLister(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whiteList []string
|
||||
expectedError string
|
||||
}{
|
||||
{
|
||||
desc: "invalid IP",
|
||||
whiteList: []string{"foo"},
|
||||
expectedError: "parsing CIDR whitelist [foo]: parsing CIDR trusted IPs <nil>: invalid CIDR address: foo",
|
||||
},
|
||||
{
|
||||
desc: "valid IP",
|
||||
whiteList: []string{"10.10.10.10"},
|
||||
expectedError: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
whiteLister, err := NewIPWhiteLister(test.whiteList, &ip.RemoteAddrStrategy{})
|
||||
|
||||
if len(test.expectedError) > 0 {
|
||||
assert.EqualError(t, err, test.expectedError)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, whiteLister)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPWhiteLister_ServeHTTP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whiteList []string
|
||||
remoteAddr string
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
desc: "authorized with remote address",
|
||||
whiteList: []string{"20.20.20.20"},
|
||||
remoteAddr: "20.20.20.20:1234",
|
||||
expected: 200,
|
||||
},
|
||||
{
|
||||
desc: "non authorized with remote address",
|
||||
whiteList: []string{"20.20.20.20"},
|
||||
remoteAddr: "20.20.20.21:1234",
|
||||
expected: 403,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
whiteLister, err := NewIPWhiteLister(test.whiteList, &ip.RemoteAddrStrategy{})
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://10.10.10.10", nil)
|
||||
|
||||
if len(test.remoteAddr) > 0 {
|
||||
req.RemoteAddr = test.remoteAddr
|
||||
}
|
||||
|
||||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
|
||||
whiteLister.ServeHTTP(recorder, req, next)
|
||||
|
||||
assert.Equal(t, test.expected, recorder.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
62
old/middlewares/pipelining/pipelining.go
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
package pipelining
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Pipelining is a middleware that hides the CloseNotifier from the wrapped handler for requests that may be pipelined (everything but PUT and POST)
|
||||
type Pipelining struct {
|
||||
next http.Handler
|
||||
}
|
||||
|
||||
// NewPipelining returns a new Pipelining instance
|
||||
func NewPipelining(next http.Handler) *Pipelining {
|
||||
return &Pipelining{
|
||||
next: next,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Pipelining) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
// https://github.com/golang/go/blob/3d59583836630cf13ec4bfbed977d27b1b7adbdc/src/net/http/server.go#L201-L218
|
||||
if r.Method == http.MethodPut || r.Method == http.MethodPost {
|
||||
p.next.ServeHTTP(rw, r)
|
||||
} else {
|
||||
p.next.ServeHTTP(&writerWithoutCloseNotify{rw}, r)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// writerWithoutCloseNotify wraps a ResponseWriter without exposing its CloseNotifier implementation
|
||||
type writerWithoutCloseNotify struct {
|
||||
W http.ResponseWriter
|
||||
}
|
||||
|
||||
// Header returns the response headers.
|
||||
func (w *writerWithoutCloseNotify) Header() http.Header {
|
||||
return w.W.Header()
|
||||
}
|
||||
|
||||
// Write writes the data to the connection as part of an HTTP reply.
|
||||
func (w *writerWithoutCloseNotify) Write(buf []byte) (int, error) {
|
||||
return w.W.Write(buf)
|
||||
}
|
||||
|
||||
// WriteHeader sends an HTTP response header with the provided
|
||||
// status code.
|
||||
func (w *writerWithoutCloseNotify) WriteHeader(code int) {
|
||||
w.W.WriteHeader(code)
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client.
|
||||
func (w *writerWithoutCloseNotify) Flush() {
|
||||
if f, ok := w.W.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack hijacks the connection.
|
||||
func (w *writerWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.W.(http.Hijacker).Hijack()
|
||||
}
|
||||
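A short wiring sketch for the Pipelining middleware above; the backend handler and address are placeholders, not part of this commit.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/traefik/old/middlewares/pipelining"
)

func main() {
	backend := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "hello")
	})

	// For anything but PUT and POST the CloseNotifier is hidden from the
	// wrapped handler, so the standard library can keep pipelining responses.
	log.Fatal(http.ListenAndServe(":8080", pipelining.NewPipelining(backend)))
}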
69
old/middlewares/pipelining/pipelining_test.go
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
package pipelining
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type recorderWithCloseNotify struct {
|
||||
*httptest.ResponseRecorder
|
||||
}
|
||||
|
||||
func (r *recorderWithCloseNotify) CloseNotify() <-chan bool {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func TestNewPipelining(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
HTTPMethod string
|
||||
implementCloseNotifier bool
|
||||
}{
|
||||
{
|
||||
desc: "should not implement CloseNotifier with GET method",
|
||||
HTTPMethod: http.MethodGet,
|
||||
implementCloseNotifier: false,
|
||||
},
|
||||
{
|
||||
desc: "should implement CloseNotifier with PUT method",
|
||||
HTTPMethod: http.MethodPut,
|
||||
implementCloseNotifier: true,
|
||||
},
|
||||
{
|
||||
desc: "should implement CloseNotifier with POST method",
|
||||
HTTPMethod: http.MethodPost,
|
||||
implementCloseNotifier: true,
|
||||
},
|
||||
{
|
||||
desc:                   "should not implement CloseNotifier with HEAD method",
|
||||
HTTPMethod: http.MethodHead,
|
||||
implementCloseNotifier: false,
|
||||
},
|
||||
{
|
||||
desc: "should not implement CloseNotifier with PROPFIND method",
|
||||
HTTPMethod: "PROPFIND",
|
||||
implementCloseNotifier: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := w.(http.CloseNotifier)
|
||||
assert.Equal(t, test.implementCloseNotifier, ok)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler := NewPipelining(nextHandler)
|
||||
|
||||
req := httptest.NewRequest(test.HTTPMethod, "http://localhost", nil)
|
||||
|
||||
handler.ServeHTTP(&recorderWithCloseNotify{httptest.NewRecorder()}, req)
|
||||
})
|
||||
}
|
||||
}
|
||||
51
old/middlewares/recover.go
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// RecoverHandler recovers from a panic in http handlers
|
||||
func RecoverHandler(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
defer recoverFunc(w, r)
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
// NegroniRecoverHandler recovers from a panic in negroni handlers
|
||||
func NegroniRecoverHandler() negroni.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
defer recoverFunc(w, r)
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
return negroni.HandlerFunc(fn)
|
||||
}
|
||||
|
||||
func recoverFunc(w http.ResponseWriter, r *http.Request) {
|
||||
if err := recover(); err != nil {
|
||||
if !shouldLogPanic(err) {
|
||||
log.Debugf("Request has been aborted [%s - %s]: %v", r.RemoteAddr, r.URL, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("Recovered from panic in HTTP handler [%s - %s]: %+v", r.RemoteAddr, r.URL, err)
|
||||
|
||||
const size = 64 << 10
|
||||
buf := make([]byte, size)
|
||||
buf = buf[:runtime.Stack(buf, false)]
|
||||
log.Errorf("Stack: %s", buf)
|
||||
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/golang/go/blob/a0d6420d8be2ae7164797051ec74fa2a2df466a1/src/net/http/server.go#L1761-L1775
|
||||
// https://github.com/golang/go/blob/c33153f7b416c03983324b3e8f869ce1116d84bc/src/net/http/httputil/reverseproxy.go#L284
|
||||
func shouldLogPanic(panicValue interface{}) bool {
|
||||
return panicValue != nil && panicValue != http.ErrAbortHandler
|
||||
}
|
||||
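A minimal sketch showing how RecoverHandler above wraps a handler that may panic; the handler body and address are invented for illustration.

package main

import (
	"log"
	"net/http"

	"github.com/containous/traefik/old/middlewares"
)

func main() {
	mayPanic := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		panic("something went wrong")
	})

	// The panic is logged with a stack trace and turned into a plain 500 response.
	log.Fatal(http.ListenAndServe(":8080", middlewares.RecoverHandler(mayPanic)))
}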
45
old/middlewares/recover_test.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
func TestRecoverHandler(t *testing.T) {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
panic("I love panicing!")
|
||||
}
|
||||
recoverHandler := RecoverHandler(http.HandlerFunc(fn))
|
||||
server := httptest.NewServer(recoverHandler)
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNegroniRecoverHandler(t *testing.T) {
|
||||
n := negroni.New()
|
||||
n.Use(NegroniRecoverHandler())
|
||||
panicHandler := func(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
panic("I love panicing!")
|
||||
}
|
||||
n.UseFunc(negroni.HandlerFunc(panicHandler))
|
||||
server := httptest.NewServer(n)
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Fatalf("Received non-%d response: %d\n", http.StatusInternalServerError, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
163
old/middlewares/redirect/redirect.go
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
package redirect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
"github.com/containous/traefik/old/middlewares"
|
||||
"github.com/urfave/negroni"
|
||||
"github.com/vulcand/oxy/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRedirectRegex = `^(?:https?:\/\/)?([\w\._-]+)(?::\d+)?(.*)$`
|
||||
)
|
||||
|
||||
// NewEntryPointHandler creates a new redirection handler based on an entry point
|
||||
func NewEntryPointHandler(dstEntryPoint *configuration.EntryPoint, permanent bool) (negroni.Handler, error) {
|
||||
exp := regexp.MustCompile(`(:\d+)`)
|
||||
match := exp.FindStringSubmatch(dstEntryPoint.Address)
|
||||
if len(match) == 0 {
|
||||
return nil, fmt.Errorf("bad Address format %q", dstEntryPoint.Address)
|
||||
}
|
||||
|
||||
protocol := "http"
|
||||
if dstEntryPoint.TLS != nil {
|
||||
protocol = "https"
|
||||
}
|
||||
|
||||
replacement := protocol + "://${1}" + match[0] + "${2}"
|
||||
|
||||
return NewRegexHandler(defaultRedirectRegex, replacement, permanent)
|
||||
}
|
||||
|
||||
// NewRegexHandler creates a new redirection handler based on a regex
|
||||
func NewRegexHandler(exp string, replacement string, permanent bool) (negroni.Handler, error) {
|
||||
re, err := regexp.Compile(exp)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &handler{
|
||||
regexp: re,
|
||||
replacement: replacement,
|
||||
permanent: permanent,
|
||||
errHandler: utils.DefaultHandler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
regexp *regexp.Regexp
|
||||
replacement string
|
||||
permanent bool
|
||||
errHandler utils.ErrorHandler
|
||||
}
|
||||
|
||||
func (h *handler) ServeHTTP(rw http.ResponseWriter, req *http.Request, next http.HandlerFunc) {
|
||||
oldURL := rawURL(req)
|
||||
|
||||
// only continue if the Regexp param matches the URL
|
||||
if !h.regexp.MatchString(oldURL) {
|
||||
next.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
// apply a rewrite regexp to the URL
|
||||
newURL := h.regexp.ReplaceAllString(oldURL, h.replacement)
|
||||
|
||||
// replace any variables that may be in there
|
||||
rewrittenURL := &bytes.Buffer{}
|
||||
if err := applyString(newURL, rewrittenURL, req); err != nil {
|
||||
h.errHandler.ServeHTTP(rw, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
// parse the rewritten URL and replace request URL with it
|
||||
parsedURL, err := url.Parse(rewrittenURL.String())
|
||||
if err != nil {
|
||||
h.errHandler.ServeHTTP(rw, req, err)
|
||||
return
|
||||
}
|
||||
|
||||
if stripPrefix, stripPrefixOk := req.Context().Value(middlewares.StripPrefixKey).(string); stripPrefixOk {
|
||||
if len(stripPrefix) > 0 {
|
||||
parsedURL.Path = stripPrefix
|
||||
}
|
||||
}
|
||||
|
||||
if addPrefix, addPrefixOk := req.Context().Value(middlewares.AddPrefixKey).(string); addPrefixOk {
|
||||
if len(addPrefix) > 0 {
|
||||
parsedURL.Path = strings.Replace(parsedURL.Path, addPrefix, "", 1)
|
||||
}
|
||||
}
|
||||
|
||||
if replacePath, replacePathOk := req.Context().Value(middlewares.ReplacePathKey).(string); replacePathOk {
|
||||
if len(replacePath) > 0 {
|
||||
parsedURL.Path = replacePath
|
||||
}
|
||||
}
|
||||
|
||||
if newURL != oldURL {
|
||||
handler := &moveHandler{location: parsedURL, permanent: h.permanent}
|
||||
handler.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
req.URL = parsedURL
|
||||
|
||||
// make sure the request URI corresponds to the rewritten URL
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
type moveHandler struct {
|
||||
location *url.URL
|
||||
permanent bool
|
||||
}
|
||||
|
||||
func (m *moveHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.Header().Set("Location", m.location.String())
|
||||
status := http.StatusFound
|
||||
if m.permanent {
|
||||
status = http.StatusMovedPermanently
|
||||
}
|
||||
rw.WriteHeader(status)
|
||||
rw.Write([]byte(http.StatusText(status)))
|
||||
}
|
||||
|
||||
func rawURL(request *http.Request) string {
|
||||
scheme := "http"
|
||||
if request.TLS != nil || isXForwardedHTTPS(request) {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
return strings.Join([]string{scheme, "://", request.Host, request.RequestURI}, "")
|
||||
}
|
||||
|
||||
func isXForwardedHTTPS(request *http.Request) bool {
|
||||
xForwardedProto := request.Header.Get("X-Forwarded-Proto")
|
||||
|
||||
return len(xForwardedProto) > 0 && xForwardedProto == "https"
|
||||
}
|
||||
|
||||
func applyString(in string, out io.Writer, request *http.Request) error {
|
||||
t, err := template.New("t").Parse(in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data := struct {
|
||||
Request *http.Request
|
||||
}{
|
||||
Request: request,
|
||||
}
|
||||
|
||||
return t.Execute(out, data)
|
||||
}
|
||||
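A hedged usage sketch for NewRegexHandler above, not part of this commit: the host, regex, and replacement are illustrative values for a permanent HTTP-to-HTTPS redirect.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/traefik/old/middlewares/redirect"
	"github.com/urfave/negroni"
)

func main() {
	// Permanently redirect plain-HTTP requests for foo.com to HTTPS.
	redirector, err := redirect.NewRegexHandler(
		`^http://foo\.com(:\d+)?(.*)$`,
		"https://foo.com${2}",
		true,
	)
	if err != nil {
		log.Fatal(err)
	}

	n := negroni.New()
	n.Use(redirector)
	n.UseHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, "not redirected")
	}))

	log.Fatal(http.ListenAndServe(":80", n))
}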
182
old/middlewares/redirect/redirect_test.go
Normal file
|
|
@ -0,0 +1,182 @@
|
|||
package redirect
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/old/configuration"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/containous/traefik/tls"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewEntryPointHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
entryPoint *configuration.EntryPoint
|
||||
permanent bool
|
||||
url string
|
||||
expectedURL string
|
||||
expectedStatus int
|
||||
errorExpected bool
|
||||
}{
|
||||
{
|
||||
desc: "HTTP to HTTPS",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
|
||||
url: "http://foo:80",
|
||||
expectedURL: "https://foo:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS to HTTP",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":80"},
|
||||
url: "https://foo:443",
|
||||
expectedURL: "http://foo:80",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTP",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":88"},
|
||||
url: "http://foo:80",
|
||||
expectedURL: "http://foo:88",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "HTTP to HTTPS permanent",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":443", TLS: &tls.TLS{}},
|
||||
permanent: true,
|
||||
url: "http://foo:80",
|
||||
expectedURL: "https://foo:443",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "HTTPS to HTTP permanent",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":80"},
|
||||
permanent: true,
|
||||
url: "https://foo:443",
|
||||
expectedURL: "http://foo:80",
|
||||
expectedStatus: http.StatusMovedPermanently,
|
||||
},
|
||||
{
|
||||
desc: "invalid address",
|
||||
entryPoint: &configuration.EntryPoint{Address: ":foo", TLS: &tls.TLS{}},
|
||||
url: "http://foo:80",
|
||||
errorExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := NewEntryPointHandler(test.entryPoint, test.permanent)
|
||||
|
||||
if test.errorExpected {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
handler.ServeHTTP(recorder, r, nil)
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedURL, location.String())
|
||||
assert.Equal(t, test.expectedStatus, recorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRegexHandler(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
regex string
|
||||
replacement string
|
||||
permanent bool
|
||||
url string
|
||||
expectedURL string
|
||||
expectedStatus int
|
||||
errorExpected bool
|
||||
}{
|
||||
{
|
||||
desc: "simple redirection",
|
||||
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
replacement: "https://${1}bar$2:443$4",
|
||||
url: "http://foo.com:80",
|
||||
expectedURL: "https://foobar.com:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "use request header",
|
||||
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
replacement: `https://${1}{{ .Request.Header.Get "X-Foo" }}$2:443$4`,
|
||||
url: "http://foo.com:80",
|
||||
expectedURL: "https://foobar.com:443",
|
||||
expectedStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
desc: "URL doesn't match regex",
|
||||
regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
|
||||
replacement: "https://${1}bar$2:443$4",
|
||||
url: "http://bar.com:80",
|
||||
expectedStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
desc: "invalid rewritten URL",
|
||||
regex: `^(.*)$`,
|
||||
replacement: "http://192.168.0.%31/",
|
||||
url: "http://foo.com:80",
|
||||
expectedStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
desc: "invalid regex",
|
||||
regex: `^(.*`,
|
||||
replacement: "$1",
|
||||
url: "http://foo.com:80",
|
||||
errorExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler, err := NewRegexHandler(test.regex, test.replacement, test.permanent)
|
||||
|
||||
if test.errorExpected {
|
||||
require.Nil(t, handler)
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NotNil(t, handler)
|
||||
require.NoError(t, err)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
r := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
r.Header.Set("X-Foo", "bar")
|
||||
next := func(rw http.ResponseWriter, req *http.Request) {}
|
||||
handler.ServeHTTP(recorder, r, next)
|
||||
|
||||
if test.expectedStatus == http.StatusMovedPermanently || test.expectedStatus == http.StatusFound {
|
||||
assert.Equal(t, test.expectedStatus, recorder.Code)
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expectedURL, location.String())
|
||||
} else {
|
||||
assert.Equal(t, test.expectedStatus, recorder.Code)
|
||||
|
||||
location, err := recorder.Result().Location()
|
||||
require.Errorf(t, err, "Location %v", location)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
28
old/middlewares/replace_path.go
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
const (
|
||||
// ReplacePathKey is the key within the request context used to
|
||||
// store the replaced path
|
||||
ReplacePathKey key = "ReplacePath"
|
||||
// ReplacedPathHeader is the default header to set the old path to
|
||||
ReplacedPathHeader = "X-Replaced-Path"
|
||||
)
|
||||
|
||||
// ReplacePath is a middleware used to replace the path of a URL request
|
||||
type ReplacePath struct {
|
||||
Handler http.Handler
|
||||
Path string
|
||||
}
|
||||
|
||||
func (s *ReplacePath) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
r = r.WithContext(context.WithValue(r.Context(), ReplacePathKey, r.URL.Path))
|
||||
r.Header.Add(ReplacedPathHeader, r.URL.Path)
|
||||
r.URL.Path = s.Path
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
s.Handler.ServeHTTP(w, r)
|
||||
}
|
||||
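A small sketch of the ReplacePath middleware above in isolation; the replacement path and backend handler are assumptions for illustration only.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/traefik/old/middlewares"
)

func main() {
	backend := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The original path is still available in the X-Replaced-Path header.
		fmt.Fprintf(w, "path=%s original=%s\n", r.URL.Path, r.Header.Get(middlewares.ReplacedPathHeader))
	})

	// Every incoming path is rewritten to /api before reaching the backend.
	log.Fatal(http.ListenAndServe(":8080", &middlewares.ReplacePath{
		Path:    "/api",
		Handler: backend,
	}))
}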
40
old/middlewares/replace_path_regex.go
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
)
|
||||
|
||||
// ReplacePathRegex is a middleware used to replace the path of a URL request with a regular expression
|
||||
type ReplacePathRegex struct {
|
||||
Handler http.Handler
|
||||
Regexp *regexp.Regexp
|
||||
Replacement string
|
||||
}
|
||||
|
||||
// NewReplacePathRegexHandler returns a new ReplacePathRegex
|
||||
func NewReplacePathRegexHandler(regex string, replacement string, handler http.Handler) http.Handler {
|
||||
exp, err := regexp.Compile(strings.TrimSpace(regex))
|
||||
if err != nil {
|
||||
log.Errorf("Error compiling regular expression %s: %s", regex, err)
|
||||
}
|
||||
return &ReplacePathRegex{
|
||||
Regexp: exp,
|
||||
Replacement: strings.TrimSpace(replacement),
|
||||
Handler: handler,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ReplacePathRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if s.Regexp != nil && len(s.Replacement) > 0 && s.Regexp.MatchString(r.URL.Path) {
|
||||
r = r.WithContext(context.WithValue(r.Context(), ReplacePathKey, r.URL.Path))
|
||||
r.Header.Add(ReplacedPathHeader, r.URL.Path)
|
||||
r.URL.Path = s.Regexp.ReplaceAllString(r.URL.Path, s.Replacement)
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
}
|
||||
s.Handler.ServeHTTP(w, r)
|
||||
}
|
||||
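A usage sketch for NewReplacePathRegexHandler above, using the same regex and replacement as the "multiple replacement" test case below; the wiring itself is hypothetical.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/traefik/old/middlewares"
)

func main() {
	backend := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintln(w, r.URL.Path)
	})

	// /downloads/src/source.go is rewritten to /downloads/src-source.go.
	handler := middlewares.NewReplacePathRegexHandler(
		`^/downloads/([^/]+)/([^/]+)$`,
		"/downloads/$1-$2",
		backend,
	)

	log.Fatal(http.ListenAndServe(":8080", handler))
}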
80
old/middlewares/replace_path_regex_test.go
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestReplacePathRegex(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
path string
|
||||
replacement string
|
||||
regex string
|
||||
expectedPath string
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "simple regex",
|
||||
path: "/whoami/and/whoami",
|
||||
replacement: "/who-am-i/$1",
|
||||
regex: `^/whoami/(.*)`,
|
||||
expectedPath: "/who-am-i/and/whoami",
|
||||
expectedHeader: "/whoami/and/whoami",
|
||||
},
|
||||
{
|
||||
desc: "simple replace (no regex)",
|
||||
path: "/whoami/and/whoami",
|
||||
replacement: "/who-am-i",
|
||||
regex: `/whoami`,
|
||||
expectedPath: "/who-am-i/and/who-am-i",
|
||||
expectedHeader: "/whoami/and/whoami",
|
||||
},
|
||||
{
|
||||
desc: "multiple replacement",
|
||||
path: "/downloads/src/source.go",
|
||||
replacement: "/downloads/$1-$2",
|
||||
regex: `^(?i)/downloads/([^/]+)/([^/]+)$`,
|
||||
expectedPath: "/downloads/src-source.go",
|
||||
expectedHeader: "/downloads/src/source.go",
|
||||
},
|
||||
{
|
||||
desc: "invalid regular expression",
|
||||
path: "/invalid/regexp/test",
|
||||
replacement: "/valid/regexp/$1",
|
||||
regex: `^(?err)/invalid/regexp/([^/]+)$`,
|
||||
expectedPath: "/invalid/regexp/test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualPath, actualHeader, requestURI string
|
||||
handler := NewReplacePathRegexHandler(
|
||||
test.regex,
|
||||
test.replacement,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualHeader = r.Header.Get(ReplacedPathHeader)
|
||||
requestURI = r.RequestURI
|
||||
}),
|
||||
)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
|
||||
handler.ServeHTTP(nil, req)
|
||||
|
||||
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
|
||||
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
|
||||
if test.expectedHeader != "" {
|
||||
assert.Equal(t, actualPath, requestURI, "Unexpected request URI.")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
41
old/middlewares/replace_path_test.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestReplacePath(t *testing.T) {
|
||||
const replacementPath = "/replacement-path"
|
||||
|
||||
paths := []string{
|
||||
"/example",
|
||||
"/some/really/long/path",
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
|
||||
var expectedPath, actualHeader, requestURI string
|
||||
handler := &ReplacePath{
|
||||
Path: replacementPath,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
expectedPath = r.URL.Path
|
||||
actualHeader = r.Header.Get(ReplacedPathHeader)
|
||||
requestURI = r.RequestURI
|
||||
}),
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+path, nil)
|
||||
|
||||
handler.ServeHTTP(nil, req)
|
||||
|
||||
assert.Equal(t, expectedPath, replacementPath, "Unexpected path.")
|
||||
assert.Equal(t, path, actualHeader, "Unexpected '%s' header.", ReplacedPathHeader)
|
||||
assert.Equal(t, expectedPath, requestURI, "Unexpected request URI.")
|
||||
})
|
||||
}
|
||||
}
|
||||
45
old/middlewares/request_host.go
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/types"
|
||||
)
|
||||
|
||||
var requestHostKey struct{}
|
||||
|
||||
// RequestHost is a middleware that stores the canonical domain of the request Host in the request context for later use.
|
||||
type RequestHost struct{}
|
||||
|
||||
func (rh *RequestHost) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
if next != nil {
|
||||
host := types.CanonicalDomain(parseHost(r.Host))
|
||||
next.ServeHTTP(rw, r.WithContext(context.WithValue(r.Context(), requestHostKey, host)))
|
||||
}
|
||||
}
|
||||
|
||||
func parseHost(addr string) string {
|
||||
if !strings.Contains(addr, ":") {
|
||||
return addr
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
// GetCanonizedHost retrieves the canonical host from a request context that has passed through the RequestHost middleware
|
||||
func GetCanonizedHost(ctx context.Context) string {
|
||||
if val, ok := ctx.Value(requestHostKey).(string); ok {
|
||||
return val
|
||||
}
|
||||
|
||||
log.Warn("RequestHost is missing in the middleware chain")
|
||||
return ""
|
||||
}
|
||||
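A minimal sketch, not part of this commit, showing RequestHost above feeding GetCanonizedHost through a negroni chain; the address is a placeholder.

package main

import (
	"fmt"
	"log"
	"net/http"

	"github.com/containous/traefik/old/middlewares"
	"github.com/urfave/negroni"
)

func main() {
	n := negroni.New()
	n.Use(&middlewares.RequestHost{})
	n.UseHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The canonical host was stored in the request context by RequestHost.
		fmt.Fprintln(w, middlewares.GetCanonizedHost(r.Context()))
	}))

	log.Fatal(http.ListenAndServe(":8080", n))
}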
94
old/middlewares/request_host_test.go
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRequestHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
url string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "host without :",
|
||||
url: "http://host",
|
||||
expected: "host",
|
||||
},
|
||||
{
|
||||
desc: "host with : and without port",
|
||||
url: "http://host:",
|
||||
expected: "host",
|
||||
},
|
||||
{
|
||||
desc: "IP host with : and with port",
|
||||
url: "http://127.0.0.1:123",
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
desc: "IP host with : and without port",
|
||||
url: "http://127.0.0.1:",
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
rh := &RequestHost{}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, test.url, nil)
|
||||
|
||||
rh.ServeHTTP(nil, req, func(_ http.ResponseWriter, r *http.Request) {
|
||||
host := GetCanonizedHost(r.Context())
|
||||
assert.Equal(t, test.expected, host)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestHostParseHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
host string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "host without :",
|
||||
host: "host",
|
||||
expected: "host",
|
||||
},
|
||||
{
|
||||
desc: "host with : and without port",
|
||||
host: "host:",
|
||||
expected: "host",
|
||||
},
|
||||
{
|
||||
desc: "IP host with : and with port",
|
||||
host: "127.0.0.1:123",
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
desc: "IP host with : and without port",
|
||||
host: "127.0.0.1:",
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := parseHost(test.host)
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
173
old/middlewares/retry.go
Normal file
|
|
@ -0,0 +1,173 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
)
|
||||
|
||||
// Compile time validation that the response writer implements http interfaces correctly.
|
||||
var _ Stateful = &retryResponseWriterWithCloseNotify{}
|
||||
|
||||
// Retry is a middleware that retries requests
|
||||
type Retry struct {
|
||||
attempts int
|
||||
next http.Handler
|
||||
listener RetryListener
|
||||
}
|
||||
|
||||
// NewRetry returns a new Retry instance
|
||||
func NewRetry(attempts int, next http.Handler, listener RetryListener) *Retry {
|
||||
return &Retry{
|
||||
attempts: attempts,
|
||||
next: next,
|
||||
listener: listener,
|
||||
}
|
||||
}
|
||||
|
||||
func (retry *Retry) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
// if we might make multiple attempts, swap the body for an ioutil.NopCloser
|
||||
// cf https://github.com/containous/traefik/issues/1008
|
||||
if retry.attempts > 1 {
|
||||
body := r.Body
|
||||
if body == nil {
|
||||
body = http.NoBody
|
||||
}
|
||||
defer body.Close()
|
||||
r.Body = ioutil.NopCloser(body)
|
||||
}
|
||||
|
||||
attempts := 1
|
||||
for {
|
||||
attemptsExhausted := attempts >= retry.attempts
|
||||
|
||||
shouldRetry := !attemptsExhausted
|
||||
retryResponseWriter := newRetryResponseWriter(rw, shouldRetry)
|
||||
|
||||
// Disable retries when the backend already received request data
|
||||
trace := &httptrace.ClientTrace{
|
||||
WroteHeaders: func() {
|
||||
retryResponseWriter.DisableRetries()
|
||||
},
|
||||
WroteRequest: func(httptrace.WroteRequestInfo) {
|
||||
retryResponseWriter.DisableRetries()
|
||||
},
|
||||
}
|
||||
newCtx := httptrace.WithClientTrace(r.Context(), trace)
|
||||
|
||||
retry.next.ServeHTTP(retryResponseWriter, r.WithContext(newCtx))
|
||||
if !retryResponseWriter.ShouldRetry() {
|
||||
break
|
||||
}
|
||||
|
||||
attempts++
|
||||
log.Debugf("New attempt %d for request: %v", attempts, r.URL)
|
||||
retry.listener.Retried(r, attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// RetryListener is used to inform about retry attempts.
|
||||
type RetryListener interface {
|
||||
// Retried will be called when a retry happens, with the request attempt passed to it.
|
||||
// For the first retry this will be attempt 2.
|
||||
Retried(req *http.Request, attempt int)
|
||||
}
|
||||
|
||||
// RetryListeners is a convenience type to construct a list of RetryListener and notify
|
||||
// each of them about a retry attempt.
|
||||
type RetryListeners []RetryListener
|
||||
|
||||
// Retried exists to implement the RetryListener interface. It calls Retried on each of its slice entries.
|
||||
func (l RetryListeners) Retried(req *http.Request, attempt int) {
|
||||
for _, retryListener := range l {
|
||||
retryListener.Retried(req, attempt)
|
||||
}
|
||||
}
|
||||
|
||||
type retryResponseWriter interface {
|
||||
http.ResponseWriter
|
||||
http.Flusher
|
||||
ShouldRetry() bool
|
||||
DisableRetries()
|
||||
}
|
||||
|
||||
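// newRetryResponseWriter wraps rw in a writer that swallows writes while a retry is still possible,
// adding CloseNotify support when the underlying writer provides it.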
func newRetryResponseWriter(rw http.ResponseWriter, shouldRetry bool) retryResponseWriter {
|
||||
responseWriter := &retryResponseWriterWithoutCloseNotify{
|
||||
responseWriter: rw,
|
||||
shouldRetry: shouldRetry,
|
||||
}
|
||||
if _, ok := rw.(http.CloseNotifier); ok {
|
||||
return &retryResponseWriterWithCloseNotify{responseWriter}
|
||||
}
|
||||
return responseWriter
|
||||
}
|
||||
|
||||
type retryResponseWriterWithoutCloseNotify struct {
|
||||
responseWriter http.ResponseWriter
|
||||
shouldRetry bool
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) ShouldRetry() bool {
|
||||
return rr.shouldRetry
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) DisableRetries() {
|
||||
rr.shouldRetry = false
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) Header() http.Header {
|
||||
if rr.ShouldRetry() {
|
||||
return make(http.Header)
|
||||
}
|
||||
return rr.responseWriter.Header()
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) Write(buf []byte) (int, error) {
|
||||
if rr.ShouldRetry() {
|
||||
return len(buf), nil
|
||||
}
|
||||
return rr.responseWriter.Write(buf)
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) WriteHeader(code int) {
|
||||
if rr.ShouldRetry() && code == http.StatusServiceUnavailable {
|
||||
// We get a 503 HTTP Status Code when there is no backend server in the pool
|
||||
// to which the request could be sent. Also, note that rr.ShouldRetry()
|
||||
// will never return true in case there was a connection established to
|
||||
// the backend server and so we can be sure that the 503 was produced
|
||||
// inside Traefik already and we don't have to retry in these cases.
|
||||
rr.DisableRetries()
|
||||
}
|
||||
|
||||
if rr.ShouldRetry() {
|
||||
return
|
||||
}
|
||||
rr.responseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := rr.responseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("%T is not a http.Hijacker", rr.responseWriter)
|
||||
}
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithoutCloseNotify) Flush() {
|
||||
if flusher, ok := rr.responseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
type retryResponseWriterWithCloseNotify struct {
|
||||
*retryResponseWriterWithoutCloseNotify
|
||||
}
|
||||
|
||||
func (rr *retryResponseWriterWithCloseNotify) CloseNotify() <-chan bool {
|
||||
return rr.responseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
260
old/middlewares/retry_test.go
Normal file
@ -0,0 +1,260 @@
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
func TestRetry(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
maxRequestAttempts int
|
||||
wantRetryAttempts int
|
||||
wantResponseStatus int
|
||||
amountFaultyEndpoints int
|
||||
}{
|
||||
{
|
||||
desc: "no retry on success",
|
||||
maxRequestAttempts: 1,
|
||||
wantRetryAttempts: 0,
|
||||
wantResponseStatus: http.StatusOK,
|
||||
amountFaultyEndpoints: 0,
|
||||
},
|
||||
{
|
||||
desc: "no retry when max request attempts is one",
|
||||
maxRequestAttempts: 1,
|
||||
wantRetryAttempts: 0,
|
||||
wantResponseStatus: http.StatusInternalServerError,
|
||||
amountFaultyEndpoints: 1,
|
||||
},
|
||||
{
|
||||
desc: "one retry when one server is faulty",
|
||||
maxRequestAttempts: 2,
|
||||
wantRetryAttempts: 1,
|
||||
wantResponseStatus: http.StatusOK,
|
||||
amountFaultyEndpoints: 1,
|
||||
},
|
||||
{
|
||||
desc: "two retries when two servers are faulty",
|
||||
maxRequestAttempts: 3,
|
||||
wantRetryAttempts: 2,
|
||||
wantResponseStatus: http.StatusOK,
|
||||
amountFaultyEndpoints: 2,
|
||||
},
|
||||
{
|
||||
desc: "max attempts exhausted delivers the 5xx response",
|
||||
maxRequestAttempts: 3,
|
||||
wantRetryAttempts: 2,
|
||||
wantResponseStatus: http.StatusInternalServerError,
|
||||
amountFaultyEndpoints: 3,
|
||||
},
|
||||
}
|
||||
|
||||
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
rw.Write([]byte("OK"))
|
||||
}))
|
||||
|
||||
forwarder, err := forward.New()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating forwarder: %s", err)
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
loadBalancer, err := roundrobin.New(forwarder)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating load balancer: %s", err)
|
||||
}
|
||||
|
||||
basePort := 33444
|
||||
for i := 0; i < test.amountFaultyEndpoints; i++ {
|
||||
// 192.0.2.0 is a non-routable IP for testing purposes.
|
||||
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
|
||||
// We only use the port specification here because the URL is used as identifier
|
||||
// in the load balancer and using the exact same URL would not add a new server.
|
||||
err = loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + strconv.Itoa(basePort+i)))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// add the functioning server to the end of the load balancer list
|
||||
err = loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
|
||||
assert.NoError(t, err)
|
||||
|
||||
retryListener := &countingRetryListener{}
|
||||
retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
|
||||
|
||||
retry.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, test.wantResponseStatus, recorder.Code)
|
||||
assert.Equal(t, test.wantRetryAttempts, retryListener.timesCalled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryWebsocket(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
maxRequestAttempts int
|
||||
expectedRetryAttempts int
|
||||
expectedResponseStatus int
|
||||
expectedError bool
|
||||
amountFaultyEndpoints int
|
||||
}{
|
||||
{
|
||||
desc: "Switching ok after 2 retries",
|
||||
maxRequestAttempts: 3,
|
||||
expectedRetryAttempts: 2,
|
||||
amountFaultyEndpoints: 2,
|
||||
expectedResponseStatus: http.StatusSwitchingProtocols,
|
||||
},
|
||||
{
|
||||
desc: "Switching failed",
|
||||
maxRequestAttempts: 2,
|
||||
expectedRetryAttempts: 1,
|
||||
amountFaultyEndpoints: 2,
|
||||
expectedResponseStatus: http.StatusBadGateway,
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
forwarder, err := forward.New()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating forwarder: %s", err)
|
||||
}
|
||||
|
||||
backendServer := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
upgrader := websocket.Upgrader{}
|
||||
upgrader.Upgrade(rw, req, nil)
|
||||
}))
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
loadBalancer, err := roundrobin.New(forwarder)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating load balancer: %s", err)
|
||||
}
|
||||
|
||||
basePort := 33444
|
||||
for i := 0; i < test.amountFaultyEndpoints; i++ {
|
||||
// 192.0.2.0 is a non-routable IP for testing purposes.
|
||||
// See: https://stackoverflow.com/questions/528538/non-routable-ip-address/18436928#18436928
|
||||
// We only use the port specification here because the URL is used as identifier
|
||||
// in the load balancer and using the exact same URL would not add a new server.
|
||||
loadBalancer.UpsertServer(testhelpers.MustParseURL("http://192.0.2.0:" + strconv.Itoa(basePort+i)))
|
||||
}
|
||||
|
||||
// add the functioning server to the end of the load balancer list
|
||||
loadBalancer.UpsertServer(testhelpers.MustParseURL(backendServer.URL))
|
||||
|
||||
retryListener := &countingRetryListener{}
|
||||
retry := NewRetry(test.maxRequestAttempts, loadBalancer, retryListener)
|
||||
|
||||
retryServer := httptest.NewServer(retry)
|
||||
|
||||
url := strings.Replace(retryServer.URL, "http", "ws", 1)
|
||||
_, response, err := websocket.DefaultDialer.Dial(url, nil)
|
||||
|
||||
if !test.expectedError {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, test.expectedResponseStatus, response.StatusCode)
|
||||
assert.Equal(t, test.expectedRetryAttempts, retryListener.timesCalled)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryEmptyServerList(t *testing.T) {
|
||||
forwarder, err := forward.New()
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating forwarder: %s", err)
|
||||
}
|
||||
|
||||
loadBalancer, err := roundrobin.New(forwarder)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating load balancer: %s", err)
|
||||
}
|
||||
|
||||
// The EmptyBackendHandler middleware ensures that there is a 503
|
||||
// response status set when there is no backend server in the pool.
|
||||
next := NewEmptyBackendHandler(loadBalancer)
|
||||
|
||||
retryListener := &countingRetryListener{}
|
||||
retry := NewRetry(3, next, retryListener)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "http://localhost:3000/ok", nil)
|
||||
|
||||
retry.ServeHTTP(recorder, req)
|
||||
|
||||
const wantResponseStatus = http.StatusServiceUnavailable
|
||||
if wantResponseStatus != recorder.Code {
|
||||
t.Errorf("got status code %d, want %d", recorder.Code, wantResponseStatus)
|
||||
}
|
||||
const wantRetryAttempts = 0
|
||||
if wantRetryAttempts != retryListener.timesCalled {
|
||||
t.Errorf("retry listener called %d time(s), want %d time(s)", retryListener.timesCalled, wantRetryAttempts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryListeners(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
retryListeners := RetryListeners{&countingRetryListener{}, &countingRetryListener{}}
|
||||
|
||||
retryListeners.Retried(req, 1)
|
||||
retryListeners.Retried(req, 1)
|
||||
|
||||
for _, retryListener := range retryListeners {
|
||||
listener := retryListener.(*countingRetryListener)
|
||||
if listener.timesCalled != 2 {
|
||||
t.Errorf("retry listener was called %d time(s), want %d time(s)", listener.timesCalled, 2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// countingRetryListener is a RetryListener implementation to count the times the Retried fn is called.
|
||||
type countingRetryListener struct {
|
||||
timesCalled int
|
||||
}
|
||||
|
||||
func (l *countingRetryListener) Retried(req *http.Request, attempt int) {
|
||||
l.timesCalled++
|
||||
}
|
||||
|
||||
func TestRetryWithFlush(t *testing.T) {
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(200)
|
||||
rw.Write([]byte("FULL "))
|
||||
rw.(http.Flusher).Flush()
|
||||
rw.Write([]byte("DATA"))
|
||||
})
|
||||
|
||||
retry := NewRetry(1, next, &countingRetryListener{})
|
||||
responseRecorder := httptest.NewRecorder()
|
||||
|
||||
retry.ServeHTTP(responseRecorder, &http.Request{})
|
||||
|
||||
if responseRecorder.Body.String() != "FULL DATA" {
|
||||
t.Errorf("Wrong body %q want %q", responseRecorder.Body.String(), "FULL DATA")
|
||||
}
|
||||
}
|
||||
28
old/middlewares/routes.go
Normal file
@ -0,0 +1,28 @@
package middlewares

import (
	"encoding/json"
	"log"
	"net/http"

	"github.com/containous/mux"
)

// Routes holds the gorilla mux routes (for the API & co).
type Routes struct {
	router *mux.Router
}

// NewRoutes returns a Routes based on the given router.
func NewRoutes(router *mux.Router) *Routes {
	return &Routes{router}
}

func (router *Routes) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	routeMatch := mux.RouteMatch{}
	if router.router.Match(r, &routeMatch) {
		rt, _ := json.Marshal(routeMatch.Handler)
		log.Println("Request match route", string(rt))
	}
	next(rw, r)
}
36
old/middlewares/secure.go
Normal file
@ -0,0 +1,36 @@
package middlewares

import (
	"github.com/containous/traefik/old/types"
	"github.com/unrolled/secure"
)

// NewSecure constructs a new Secure instance with supplied options.
func NewSecure(headers *types.Headers) *secure.Secure {
	if headers == nil || !headers.HasSecureHeadersDefined() {
		return nil
	}

	opt := secure.Options{
		AllowedHosts:            headers.AllowedHosts,
		HostsProxyHeaders:       headers.HostsProxyHeaders,
		SSLRedirect:             headers.SSLRedirect,
		SSLTemporaryRedirect:    headers.SSLTemporaryRedirect,
		SSLHost:                 headers.SSLHost,
		SSLProxyHeaders:         headers.SSLProxyHeaders,
		STSSeconds:              headers.STSSeconds,
		STSIncludeSubdomains:    headers.STSIncludeSubdomains,
		STSPreload:              headers.STSPreload,
		ForceSTSHeader:          headers.ForceSTSHeader,
		FrameDeny:               headers.FrameDeny,
		CustomFrameOptionsValue: headers.CustomFrameOptionsValue,
		ContentTypeNosniff:      headers.ContentTypeNosniff,
		BrowserXssFilter:        headers.BrowserXSSFilter,
		CustomBrowserXssValue:   headers.CustomBrowserXSSValue,
		ContentSecurityPolicy:   headers.ContentSecurityPolicy,
		PublicKey:               headers.PublicKey,
		ReferrerPolicy:          headers.ReferrerPolicy,
		IsDevelopment:           headers.IsDevelopment,
	}
	return secure.New(opt)
}
12
old/middlewares/stateful.go
Normal file
@ -0,0 +1,12 @@
package middlewares

import "net/http"

// Stateful interface groups all http interfaces that must be
// implemented by a stateful middleware (ie: recorders)
type Stateful interface {
	http.ResponseWriter
	http.Hijacker
	http.Flusher
	http.CloseNotifier
}
115
old/middlewares/stats.go
Normal file
@ -0,0 +1,115 @@
package middlewares
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
_ Stateful = &responseRecorder{}
|
||||
)
|
||||
|
||||
// StatsRecorder is an optional middleware that records more detailed statistics
|
||||
// about requests and how they are processed. This currently consists of recent
|
||||
// requests that have caused errors (4xx and 5xx status codes), making it easy
|
||||
// to pinpoint problems.
|
||||
type StatsRecorder struct {
|
||||
mutex sync.RWMutex
|
||||
numRecentErrors int
|
||||
recentErrors []*statsError
|
||||
}
|
||||
|
||||
// NewStatsRecorder returns a new StatsRecorder
|
||||
func NewStatsRecorder(numRecentErrors int) *StatsRecorder {
|
||||
return &StatsRecorder{
|
||||
numRecentErrors: numRecentErrors,
|
||||
}
|
||||
}
|
||||
|
||||
// Stats includes all of the stats gathered by the recorder.
|
||||
type Stats struct {
|
||||
RecentErrors []*statsError `json:"recent_errors"`
|
||||
}
|
||||
|
||||
// statsError represents an error that has occurred during request processing.
|
||||
type statsError struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
Status string `json:"status"`
|
||||
Method string `json:"method"`
|
||||
Host string `json:"host"`
|
||||
Path string `json:"path"`
|
||||
Time time.Time `json:"time"`
|
||||
}
|
||||
|
||||
// responseRecorder captures information from the response and preserves it for
|
||||
// later analysis.
|
||||
type responseRecorder struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code for later retrieval.
|
||||
func (r *responseRecorder) WriteHeader(status int) {
|
||||
r.ResponseWriter.WriteHeader(status)
|
||||
r.statusCode = status
|
||||
}
|
||||
|
||||
// Hijack hijacks the connection
|
||||
func (r *responseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return r.ResponseWriter.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
||||
// CloseNotify returns a channel that receives at most a
|
||||
// single value (true) when the client connection has gone
|
||||
// away.
|
||||
func (r *responseRecorder) CloseNotify() <-chan bool {
|
||||
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client.
|
||||
func (r *responseRecorder) Flush() {
|
||||
r.ResponseWriter.(http.Flusher).Flush()
|
||||
}
|
||||
|
||||
// ServeHTTP silently extracts information from the request and response as it
|
||||
// is processed. If the response is 4xx or 5xx, it is added to the list of most
|
||||
// recent errors.
|
||||
func (s *StatsRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
recorder := &responseRecorder{w, http.StatusOK}
|
||||
next(recorder, r)
|
||||
if recorder.statusCode >= http.StatusBadRequest {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
s.recentErrors = append([]*statsError{
|
||||
{
|
||||
StatusCode: recorder.statusCode,
|
||||
Status: http.StatusText(recorder.statusCode),
|
||||
Method: r.Method,
|
||||
Host: r.Host,
|
||||
Path: r.URL.Path,
|
||||
Time: time.Now(),
|
||||
},
|
||||
}, s.recentErrors...)
|
||||
// Limit the size of the list to numRecentErrors
|
||||
if len(s.recentErrors) > s.numRecentErrors {
|
||||
s.recentErrors = s.recentErrors[:s.numRecentErrors]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Data returns a copy of the statistics that have been gathered.
|
||||
func (s *StatsRecorder) Data() *Stats {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
// We can't return the slice directly or a race condition might develop
|
||||
recentErrors := make([]*statsError, len(s.recentErrors))
|
||||
copy(recentErrors, s.recentErrors)
|
||||
|
||||
return &Stats{
|
||||
RecentErrors: recentErrors,
|
||||
}
|
||||
}
|
||||
56
old/middlewares/stripPrefix.go
Normal file
@ -0,0 +1,56 @@
package middlewares

import (
	"context"
	"net/http"
	"strings"
)

const (
	// StripPrefixKey is the key within the request context used to
	// store the stripped prefix
	StripPrefixKey key = "StripPrefix"
	// ForwardedPrefixHeader is the default header used to pass the stripped prefix
	ForwardedPrefixHeader = "X-Forwarded-Prefix"
)

// StripPrefix is a middleware used to strip a prefix from a URL request
type StripPrefix struct {
	Handler  http.Handler
	Prefixes []string
}

func (s *StripPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	for _, prefix := range s.Prefixes {
		if strings.HasPrefix(r.URL.Path, prefix) {
			rawReqPath := r.URL.Path
			r.URL.Path = stripPrefix(r.URL.Path, prefix)
			if r.URL.RawPath != "" {
				r.URL.RawPath = stripPrefix(r.URL.RawPath, prefix)
			}
			s.serveRequest(w, r, strings.TrimSpace(prefix), rawReqPath)
			return
		}
	}
	http.NotFound(w, r)
}

func (s *StripPrefix) serveRequest(w http.ResponseWriter, r *http.Request, prefix string, rawReqPath string) {
	r = r.WithContext(context.WithValue(r.Context(), StripPrefixKey, rawReqPath))
	r.Header.Add(ForwardedPrefixHeader, prefix)
	r.RequestURI = r.URL.RequestURI()
	s.Handler.ServeHTTP(w, r)
}

// SetHandler sets the handler
func (s *StripPrefix) SetHandler(Handler http.Handler) {
	s.Handler = Handler
}

func stripPrefix(s, prefix string) string {
	return ensureLeadingSlash(strings.TrimPrefix(s, prefix))
}

func ensureLeadingSlash(str string) string {
	return "/" + strings.TrimPrefix(str, "/")
}
59
old/middlewares/stripPrefixRegex.go
Normal file
@ -0,0 +1,59 @@
package middlewares
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/old/log"
|
||||
)
|
||||
|
||||
// StripPrefixRegex is a middleware used to strip a prefix from a URL request
|
||||
type StripPrefixRegex struct {
|
||||
Handler http.Handler
|
||||
router *mux.Router
|
||||
}
|
||||
|
||||
// NewStripPrefixRegex builds a new StripPrefixRegex given a handler and prefixes
|
||||
func NewStripPrefixRegex(handler http.Handler, prefixes []string) *StripPrefixRegex {
|
||||
stripPrefix := StripPrefixRegex{Handler: handler, router: mux.NewRouter()}
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
stripPrefix.router.PathPrefix(prefix)
|
||||
}
|
||||
|
||||
return &stripPrefix
|
||||
}
|
||||
|
||||
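// ServeHTTP matches the request against the configured prefix routes, strips the matched prefix
// from the URL path (and raw path), and passes the stripped prefix downstream via X-Forwarded-Prefix.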
func (s *StripPrefixRegex) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
var match mux.RouteMatch
|
||||
if s.router.Match(r, &match) {
|
||||
params := make([]string, 0, len(match.Vars)*2)
|
||||
for key, val := range match.Vars {
|
||||
params = append(params, key)
|
||||
params = append(params, val)
|
||||
}
|
||||
|
||||
prefix, err := match.Route.URL(params...)
|
||||
if err != nil || len(prefix.Path) > len(r.URL.Path) {
|
||||
log.Error("Error in stripPrefix middleware", err)
|
||||
return
|
||||
}
|
||||
rawReqPath := r.URL.Path
|
||||
r.URL.Path = r.URL.Path[len(prefix.Path):]
|
||||
if r.URL.RawPath != "" {
|
||||
r.URL.RawPath = r.URL.RawPath[len(prefix.Path):]
|
||||
}
|
||||
r = r.WithContext(context.WithValue(r.Context(), StripPrefixKey, rawReqPath))
|
||||
r.Header.Add(ForwardedPrefixHeader, prefix.Path)
|
||||
r.RequestURI = ensureLeadingSlash(r.URL.RequestURI())
|
||||
s.Handler.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
|
||||
// SetHandler sets handler
|
||||
func (s *StripPrefixRegex) SetHandler(Handler http.Handler) {
|
||||
s.Handler = Handler
|
||||
}
|
||||
97
old/middlewares/stripPrefixRegex_test.go
Normal file
@ -0,0 +1,97 @@
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStripPrefixRegex(t *testing.T) {
|
||||
testPrefixRegex := []string{"/a/api/", "/b/{regex}/", "/c/{category}/{id:[0-9]+}/"}
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
expectedStatusCode int
|
||||
expectedPath string
|
||||
expectedRawPath string
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
path: "/a/test",
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
path: "/a/api/test",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "test",
|
||||
expectedHeader: "/a/api/",
|
||||
},
|
||||
{
|
||||
path: "/b/api/",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedHeader: "/b/api/",
|
||||
},
|
||||
{
|
||||
path: "/b/api/test1",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "test1",
|
||||
expectedHeader: "/b/api/",
|
||||
},
|
||||
{
|
||||
path: "/b/api2/test2",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "test2",
|
||||
expectedHeader: "/b/api2/",
|
||||
},
|
||||
{
|
||||
path: "/c/api/123/",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedHeader: "/c/api/123/",
|
||||
},
|
||||
{
|
||||
path: "/c/api/123/test3",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "test3",
|
||||
expectedHeader: "/c/api/123/",
|
||||
},
|
||||
{
|
||||
path: "/c/api/abc/test4",
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
path: "/a/api/a%2Fb",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "a/b",
|
||||
expectedRawPath: "a%2Fb",
|
||||
expectedHeader: "/a/api/",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.path, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualPath, actualRawPath, actualHeader string
|
||||
handlerPath := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualRawPath = r.URL.RawPath
|
||||
actualHeader = r.Header.Get(ForwardedPrefixHeader)
|
||||
})
|
||||
handler := NewStripPrefixRegex(handlerPath, testPrefixRegex)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
resp := &httptest.ResponseRecorder{Code: http.StatusOK}
|
||||
|
||||
handler.ServeHTTP(resp, req)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
|
||||
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
|
||||
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
|
||||
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader)
|
||||
})
|
||||
}
|
||||
}
|
||||
143
old/middlewares/stripPrefix_test.go
Normal file
@ -0,0 +1,143 @@
package middlewares
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStripPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
desc string
|
||||
prefixes []string
|
||||
path string
|
||||
expectedStatusCode int
|
||||
expectedPath string
|
||||
expectedRawPath string
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "no prefixes configured",
|
||||
prefixes: []string{},
|
||||
path: "/noprefixes",
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
desc: "wildcard (.*) requests",
|
||||
prefixes: []string{"/"},
|
||||
path: "/",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/",
|
||||
},
|
||||
{
|
||||
desc: "prefix and path matching",
|
||||
prefixes: []string{"/stat"},
|
||||
path: "/stat",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "path prefix on exactly matching path",
|
||||
prefixes: []string{"/stat/"},
|
||||
path: "/stat/",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/stat/",
|
||||
},
|
||||
{
|
||||
desc: "path prefix on matching longer path",
|
||||
prefixes: []string{"/stat/"},
|
||||
path: "/stat/us",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/us",
|
||||
expectedHeader: "/stat/",
|
||||
},
|
||||
{
|
||||
desc: "path prefix on mismatching path",
|
||||
prefixes: []string{"/stat/"},
|
||||
path: "/status",
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
desc: "general prefix on matching path",
|
||||
prefixes: []string{"/stat"},
|
||||
path: "/stat/",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "earlier prefix matching",
|
||||
prefixes: []string{"/stat", "/stat/us"},
|
||||
path: "/stat/us",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/us",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "later prefix matching",
|
||||
prefixes: []string{"/mismatch", "/stat"},
|
||||
path: "/stat",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "prefix matching within slash boundaries",
|
||||
prefixes: []string{"/stat"},
|
||||
path: "/status",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/us",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
{
|
||||
desc: "raw path is also stripped",
|
||||
prefixes: []string{"/stat"},
|
||||
path: "/stat/a%2Fb",
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedPath: "/a/b",
|
||||
expectedRawPath: "/a%2Fb",
|
||||
expectedHeader: "/stat",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var actualPath, actualRawPath, actualHeader, requestURI string
|
||||
handler := &StripPrefix{
|
||||
Prefixes: test.prefixes,
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualPath = r.URL.Path
|
||||
actualRawPath = r.URL.RawPath
|
||||
actualHeader = r.Header.Get(ForwardedPrefixHeader)
|
||||
requestURI = r.RequestURI
|
||||
}),
|
||||
}
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost"+test.path, nil)
|
||||
resp := &httptest.ResponseRecorder{Code: http.StatusOK}
|
||||
|
||||
handler.ServeHTTP(resp, req)
|
||||
|
||||
assert.Equal(t, test.expectedStatusCode, resp.Code, "Unexpected status code.")
|
||||
assert.Equal(t, test.expectedPath, actualPath, "Unexpected path.")
|
||||
assert.Equal(t, test.expectedRawPath, actualRawPath, "Unexpected raw path.")
|
||||
assert.Equal(t, test.expectedHeader, actualHeader, "Unexpected '%s' header.", ForwardedPrefixHeader)
|
||||
|
||||
expectedURI := test.expectedPath
|
||||
if test.expectedRawPath != "" {
|
||||
// go HTTP uses the raw path when existent in the RequestURI
|
||||
expectedURI = test.expectedRawPath
|
||||
}
|
||||
assert.Equal(t, expectedURI, requestURI, "Unexpected request URI.")
|
||||
})
|
||||
}
|
||||
}
|
||||
251
old/middlewares/tlsClientHeaders.go
Normal file
@ -0,0 +1,251 @@
package middlewares
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/types"
|
||||
)
|
||||
|
||||
const xForwardedTLSClientCert = "X-Forwarded-Tls-Client-Cert"
|
||||
const xForwardedTLSClientCertInfos = "X-Forwarded-Tls-Client-Cert-Infos"
|
||||
|
||||
// TLSClientCertificateInfos is a struct for specifying the configuration for the tlsClientHeaders middleware.
|
||||
type TLSClientCertificateInfos struct {
|
||||
NotAfter bool
|
||||
NotBefore bool
|
||||
Subject *TLSCLientCertificateSubjectInfos
|
||||
Sans bool
|
||||
}
|
||||
|
||||
// TLSCLientCertificateSubjectInfos contains the configuration for the certificate subject infos.
|
||||
type TLSCLientCertificateSubjectInfos struct {
|
||||
Country bool
|
||||
Province bool
|
||||
Locality bool
|
||||
Organization bool
|
||||
CommonName bool
|
||||
SerialNumber bool
|
||||
}
|
||||
|
||||
// TLSClientHeaders is a middleware that passes information from the TLS client certificate to the backend in headers.
|
||||
type TLSClientHeaders struct {
|
||||
PEM bool // pass the sanitized pem to the backend in a specific header
|
||||
Infos *TLSClientCertificateInfos // pass selected information from the client certificate
|
||||
}
|
||||
|
||||
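// newTLSCLientCertificateSubjectInfos maps the subject options from the configuration type onto the middleware's internal representation.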
func newTLSCLientCertificateSubjectInfos(infos *types.TLSCLientCertificateSubjectInfos) *TLSCLientCertificateSubjectInfos {
|
||||
if infos == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &TLSCLientCertificateSubjectInfos{
|
||||
SerialNumber: infos.SerialNumber,
|
||||
CommonName: infos.CommonName,
|
||||
Country: infos.Country,
|
||||
Locality: infos.Locality,
|
||||
Organization: infos.Organization,
|
||||
Province: infos.Province,
|
||||
}
|
||||
}
|
||||
|
||||
func newTLSClientInfos(infos *types.TLSClientCertificateInfos) *TLSClientCertificateInfos {
|
||||
if infos == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &TLSClientCertificateInfos{
|
||||
NotBefore: infos.NotBefore,
|
||||
NotAfter: infos.NotAfter,
|
||||
Sans: infos.Sans,
|
||||
Subject: newTLSCLientCertificateSubjectInfos(infos.Subject),
|
||||
}
|
||||
}
|
||||
|
||||
// NewTLSClientHeaders constructs a new TLSClientHeaders instance from supplied frontend header struct.
|
||||
func NewTLSClientHeaders(frontend *types.Frontend) *TLSClientHeaders {
|
||||
if frontend == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var pem bool
|
||||
var infos *TLSClientCertificateInfos
|
||||
|
||||
if frontend.PassTLSClientCert != nil {
|
||||
conf := frontend.PassTLSClientCert
|
||||
pem = conf.PEM
|
||||
infos = newTLSClientInfos(conf.Infos)
|
||||
}
|
||||
|
||||
return &TLSClientHeaders{
|
||||
PEM: pem,
|
||||
Infos: infos,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TLSClientHeaders) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
s.ModifyRequestHeaders(r)
|
||||
// If there is a next, call it.
|
||||
if next != nil {
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// sanitize removes the PEM markers and newlines from the raw certificate so that it is HTTP header compliant
|
||||
func sanitize(cert []byte) string {
|
||||
s := string(cert)
|
||||
r := strings.NewReplacer("-----BEGIN CERTIFICATE-----", "",
|
||||
"-----END CERTIFICATE-----", "",
|
||||
"\n", "")
|
||||
cleaned := r.Replace(s)
|
||||
|
||||
return url.QueryEscape(cleaned)
|
||||
}
|
||||
|
||||
// extractCertificate extracts the sanitized PEM content of the given certificate
|
||||
func extractCertificate(cert *x509.Certificate) string {
|
||||
b := pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
|
||||
certPEM := pem.EncodeToMemory(&b)
|
||||
if certPEM == nil {
|
||||
log.Error("Cannot extract the certificate content")
|
||||
return ""
|
||||
}
|
||||
return sanitize(certPEM)
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCert builds a string with the client certificates
|
||||
func getXForwardedTLSClientCert(certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
headerValues = append(headerValues, extractCertificate(peerCert))
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ",")
|
||||
}
|
||||
|
||||
// getSANs gets the Subject Alternative Name values
|
||||
func getSANs(cert *x509.Certificate) []string {
|
||||
var sans []string
|
||||
if cert == nil {
|
||||
return sans
|
||||
}
|
||||
|
||||
sans = append(cert.DNSNames, cert.EmailAddresses...)
|
||||
|
||||
var ips []string
|
||||
for _, ip := range cert.IPAddresses {
|
||||
ips = append(ips, ip.String())
|
||||
}
|
||||
sans = append(sans, ips...)
|
||||
|
||||
var uris []string
|
||||
for _, uri := range cert.URIs {
|
||||
uris = append(uris, uri.String())
|
||||
}
|
||||
|
||||
return append(sans, uris...)
|
||||
}
|
||||
|
||||
// getSubjectInfos extracts the requested information from the certificate subject
|
||||
func (s *TLSClientHeaders) getSubjectInfos(cs *pkix.Name) string {
|
||||
var subject string
|
||||
|
||||
if s.Infos != nil && s.Infos.Subject != nil {
|
||||
options := s.Infos.Subject
|
||||
|
||||
var content []string
|
||||
|
||||
if options.Country && len(cs.Country) > 0 {
|
||||
content = append(content, fmt.Sprintf("C=%s", cs.Country[0]))
|
||||
}
|
||||
|
||||
if options.Province && len(cs.Province) > 0 {
|
||||
content = append(content, fmt.Sprintf("ST=%s", cs.Province[0]))
|
||||
}
|
||||
|
||||
if options.Locality && len(cs.Locality) > 0 {
|
||||
content = append(content, fmt.Sprintf("L=%s", cs.Locality[0]))
|
||||
}
|
||||
|
||||
if options.Organization && len(cs.Organization) > 0 {
|
||||
content = append(content, fmt.Sprintf("O=%s", cs.Organization[0]))
|
||||
}
|
||||
|
||||
if options.CommonName && len(cs.CommonName) > 0 {
|
||||
content = append(content, fmt.Sprintf("CN=%s", cs.CommonName))
|
||||
}
|
||||
|
||||
if len(content) > 0 {
|
||||
subject = `Subject="` + strings.Join(content, ",") + `"`
|
||||
}
|
||||
}
|
||||
|
||||
return subject
|
||||
}
|
||||
|
||||
// getXForwardedTLSClientCertInfos builds a string with the wanted client certificate information
|
||||
// like Subject="C=%s,ST=%s,L=%s,O=%s,CN=%s",NB=%d,NA=%d,SAN=%s;
|
||||
func (s *TLSClientHeaders) getXForwardedTLSClientCertInfos(certs []*x509.Certificate) string {
|
||||
var headerValues []string
|
||||
|
||||
for _, peerCert := range certs {
|
||||
var values []string
|
||||
var sans string
|
||||
var nb string
|
||||
var na string
|
||||
|
||||
subject := s.getSubjectInfos(&peerCert.Subject)
|
||||
if len(subject) > 0 {
|
||||
values = append(values, subject)
|
||||
}
|
||||
|
||||
ci := s.Infos
|
||||
if ci != nil {
|
||||
if ci.NotBefore {
|
||||
nb = fmt.Sprintf("NB=%d", uint64(peerCert.NotBefore.Unix()))
|
||||
values = append(values, nb)
|
||||
}
|
||||
if ci.NotAfter {
|
||||
na = fmt.Sprintf("NA=%d", uint64(peerCert.NotAfter.Unix()))
|
||||
values = append(values, na)
|
||||
}
|
||||
|
||||
if ci.Sans {
|
||||
sans = fmt.Sprintf("SAN=%s", strings.Join(getSANs(peerCert), ","))
|
||||
values = append(values, sans)
|
||||
}
|
||||
}
|
||||
|
||||
value := strings.Join(values, ",")
|
||||
headerValues = append(headerValues, value)
|
||||
}
|
||||
|
||||
return strings.Join(headerValues, ";")
|
||||
}
|
||||
|
||||
// ModifyRequestHeaders sets the wanted headers with the certificate information
|
||||
func (s *TLSClientHeaders) ModifyRequestHeaders(r *http.Request) {
|
||||
if s.PEM {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
r.Header.Set(xForwardedTLSClientCert, getXForwardedTLSClientCert(r.TLS.PeerCertificates))
|
||||
} else {
|
||||
log.Warn("Try to extract certificate on a request without TLS")
|
||||
}
|
||||
}
|
||||
|
||||
if s.Infos != nil {
|
||||
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
|
||||
headerContent := s.getXForwardedTLSClientCertInfos(r.TLS.PeerCertificates)
|
||||
r.Header.Set(xForwardedTLSClientCertInfos, url.QueryEscape(headerContent))
|
||||
} else {
|
||||
log.Warn("Try to extract certificate on a request without TLS")
|
||||
}
|
||||
}
|
||||
}
|
||||
799
old/middlewares/tlsClientHeaders_test.go
Normal file
@ -0,0 +1,799 @@
package middlewares
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
rootCrt = `-----BEGIN CERTIFICATE-----
|
||||
MIIDhjCCAm6gAwIBAgIJAIKZlW9a3VrYMA0GCSqGSIb3DQEBCwUAMFgxCzAJBgNV
|
||||
BAYTAkZSMRMwEQYDVQQIDApTb21lLVN0YXRlMREwDwYDVQQHDAhUb3Vsb3VzZTEh
|
||||
MB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMB4XDTE4MDcxNzIwMzQz
|
||||
OFoXDTE4MDgxNjIwMzQzOFowWDELMAkGA1UEBhMCRlIxEzARBgNVBAgMClNvbWUt
|
||||
U3RhdGUxETAPBgNVBAcMCFRvdWxvdXNlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn
|
||||
aXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC1P8GJ
|
||||
H9LkIxIIqK9MyUpushnjmjwccpSMB3OecISKYLy62QDIcAw6NzGcSe8hMwciMJr+
|
||||
CdCjJlohybnaRI9hrJ3GPnI++UT/MMthf2IIcjmJxmD4k9L1fgs1V6zSTlo0+o0x
|
||||
0gkAGlWvRkgA+3nt555ee84XQZuneKKeRRIlSA1ygycewFobZ/pGYijIEko+gYkV
|
||||
sF3LnRGxNl673w+EQsvI7+z29T1nzjmM/xE7WlvnsrVd1/N61jAohLota0YTufwd
|
||||
ioJZNryzuPejHBCiQRGMbJ7uEEZLiSCN6QiZEfqhS3AulykjgFXQQHn4zoVljSBR
|
||||
UyLV0prIn5Scbks/AgMBAAGjUzBRMB0GA1UdDgQWBBTroRRnSgtkV+8dumtcftb/
|
||||
lwIkATAfBgNVHSMEGDAWgBTroRRnSgtkV+8dumtcftb/lwIkATAPBgNVHRMBAf8E
|
||||
BTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQAJ67U5cLa0ZFa/7zQQT4ldkY6YOEgR
|
||||
0LNoTu51hc+ozaXSvF8YIBzkEpEnbGS3x4xodrwEBZjK2LFhNu/33gkCAuhmedgk
|
||||
KwZrQM6lqRFGHGVOlkVz+QrJ2EsKYaO4SCUIwVjijXRLA7A30G5C/CIh66PsMgBY
|
||||
6QHXVPEWm/v1d1Q/DfFfFzSOa1n1rIUw03qVJsxqSwfwYcegOF8YvS/eH4HUr2gF
|
||||
cEujh6CCnylf35ExHa45atr3+xxbOVdNjobISkYADtbhAAn4KjLS4v8W6445vxxj
|
||||
G5EIZLjOHyWg1sGaHaaAPkVpZQg8EKm21c4hrEEMfel60AMSSzad/a/V
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
minimalCert = `-----BEGIN CERTIFICATE-----
|
||||
MIIDGTCCAgECCQCqLd75YLi2kDANBgkqhkiG9w0BAQsFADBYMQswCQYDVQQGEwJG
|
||||
UjETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UEBwwIVG91bG91c2UxITAfBgNV
|
||||
BAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xODA3MTgwODI4MTZaFw0x
|
||||
ODA4MTcwODI4MTZaMEUxCzAJBgNVBAYTAkZSMRMwEQYDVQQIDApTb21lLVN0YXRl
|
||||
MSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3
|
||||
DQEBAQUAA4IBDwAwggEKAoIBAQC/+frDMMTLQyXG34F68BPhQq0kzK4LIq9Y0/gl
|
||||
FjySZNn1C0QDWA1ubVCAcA6yY204I9cxcQDPNrhC7JlS5QA8Y5rhIBrqQlzZizAi
|
||||
Rj3NTrRjtGUtOScnHuJaWjLy03DWD+aMwb7q718xt5SEABmmUvLwQK+EjW2MeDwj
|
||||
y8/UEIpvrRDmdhGaqv7IFpIDkcIF7FowJ/hwDvx3PMc+z/JWK0ovzpvgbx69AVbw
|
||||
ZxCimeha65rOqVi+lEetD26le+WnOdYsdJ2IkmpPNTXGdfb15xuAc+gFXfMCh7Iw
|
||||
3Ynl6dZtZM/Ok2kiA7/OsmVnRKkWrtBfGYkI9HcNGb3zrk6nAgMBAAEwDQYJKoZI
|
||||
hvcNAQELBQADggEBAC/R+Yvhh1VUhcbK49olWsk/JKqfS3VIDQYZg1Eo+JCPbwgS
|
||||
I1BSYVfMcGzuJTX6ua3m/AHzGF3Tap4GhF4tX12jeIx4R4utnjj7/YKkTvuEM2f4
|
||||
xT56YqI7zalGScIB0iMeyNz1QcimRl+M/49au8ow9hNX8C2tcA2cwd/9OIj/6T8q
|
||||
SBRHc6ojvbqZSJCO0jziGDT1L3D+EDgTjED4nd77v/NRdP+egb0q3P0s4dnQ/5AV
|
||||
aQlQADUn61j3ScbGJ4NSeZFFvsl38jeRi/MEzp0bGgNBcPj6JHi7qbbauZcZfQ05
|
||||
jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
completeCert = `Certificate:
|
||||
Data:
|
||||
Version: 3 (0x2)
|
||||
Serial Number: 3 (0x3)
|
||||
Signature Algorithm: sha1WithRSAEncryption
|
||||
Issuer: C=FR, ST=Some-State, L=Toulouse, O=Internet Widgits Pty Ltd
|
||||
Validity
|
||||
Not Before: Jul 18 08:00:16 2018 GMT
|
||||
Not After : Jul 18 08:00:16 2019 GMT
|
||||
Subject: C=FR, ST=SomeState, L=Toulouse, O=Cheese, CN=*.cheese.org
|
||||
Subject Public Key Info:
|
||||
Public Key Algorithm: rsaEncryption
|
||||
Public-Key: (2048 bit)
|
||||
Modulus:
|
||||
00:a6:1f:96:7c:c1:cc:b8:1c:b5:91:5d:b8:bf:70:
|
||||
bc:f7:b8:04:4f:2a:42:de:ea:c5:c3:19:0b:03:04:
|
||||
ec:ef:a1:24:25:de:ad:05:e7:26:ea:89:6c:59:60:
|
||||
10:18:0c:73:f1:bf:d3:cc:7b:ed:6b:9c:ea:1d:88:
|
||||
e2:ee:14:81:d7:07:ee:87:95:3d:36:df:9c:38:b7:
|
||||
7b:1e:2b:51:9c:4a:1f:d0:cc:5b:af:5d:6c:5c:35:
|
||||
49:32:e4:01:5b:f9:8c:71:cf:62:48:5a:ea:b7:31:
|
||||
58:e2:c6:d0:5b:1c:50:b5:5c:6d:5a:6f:da:41:5e:
|
||||
d5:4c:6e:1a:21:f3:40:f9:9e:52:76:50:25:3e:03:
|
||||
9b:87:19:48:5b:47:87:d3:67:c6:25:69:77:29:8e:
|
||||
56:97:45:d9:6f:64:a8:4e:ad:35:75:2e:fc:6a:2e:
|
||||
47:87:76:fc:4e:3e:44:e9:16:b2:c7:f0:23:98:13:
|
||||
a2:df:15:23:cb:0c:3d:fd:48:5e:c7:2c:86:70:63:
|
||||
8b:c6:c8:89:17:52:d5:a7:8e:cb:4e:11:9d:69:8e:
|
||||
8e:59:cc:7e:a3:bd:a1:11:88:d7:cf:7b:8c:19:46:
|
||||
9c:1b:7a:c9:39:81:4c:58:08:1f:c7:ce:b0:0e:79:
|
||||
64:d3:11:72:65:e6:dd:bd:00:7f:22:30:46:9b:66:
|
||||
9c:b9
|
||||
Exponent: 65537 (0x10001)
|
||||
X509v3 extensions:
|
||||
X509v3 Basic Constraints:
|
||||
CA:FALSE
|
||||
X509v3 Subject Alternative Name:
|
||||
DNS:*.cheese.org, DNS:*.cheese.net, DNS:cheese.in, IP Address:10.0.1.0, IP Address:10.0.1.2, email:test@cheese.org, email:test@cheese.net
|
||||
X509v3 Subject Key Identifier:
|
||||
AB:6B:89:25:11:FC:5E:7B:D4:B0:F7:D4:B6:D9:EB:D0:30:93:E5:58
|
||||
Signature Algorithm: sha1WithRSAEncryption
|
||||
ad:87:84:a0:88:a3:4c:d9:0a:c0:14:e4:2d:9a:1d:bb:57:b7:
|
||||
12:ef:3a:fb:8b:b2:ce:32:b8:04:e6:59:c8:4f:14:6a:b5:12:
|
||||
46:e9:c9:0a:11:64:ea:a1:86:20:96:0e:a7:40:e3:aa:e5:98:
|
||||
91:36:89:77:b6:b9:73:7e:1a:58:19:ae:d1:14:83:1e:c1:5f:
|
||||
a5:a0:32:bb:52:68:b4:8d:a3:1d:b3:08:d7:45:6e:3b:87:64:
|
||||
7e:ef:46:e6:6f:d5:79:d7:1d:57:68:67:d8:18:39:61:5b:8b:
|
||||
1a:7f:88:da:0a:51:9b:3d:6c:5d:b1:cf:b7:e9:1e:06:65:8e:
|
||||
96:d3:61:96:f8:a2:61:f9:40:5e:fa:bc:76:b9:64:0e:6f:90:
|
||||
37:de:ac:6d:7f:36:84:35:19:88:8c:26:af:3e:c3:6a:1a:03:
|
||||
ed:d7:90:89:ed:18:4c:9e:94:1f:d8:ae:6c:61:36:17:72:f9:
|
||||
bb:de:0a:56:9a:79:b4:7d:4a:9d:cb:4a:7d:71:9f:38:e7:8d:
|
||||
f0:87:24:21:0a:24:1f:82:9a:6b:67:ce:7d:af:cb:91:6b:8a:
|
||||
de:e6:d8:6f:a1:37:b9:2d:d0:cb:e8:4e:f4:43:af:ad:90:13:
|
||||
7d:61:7a:ce:86:48:fc:00:8c:37:fb:e0:31:6b:e2:18:ad:fd:
|
||||
1e:df:08:db
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDvTCCAqWgAwIBAgIBAzANBgkqhkiG9w0BAQUFADBYMQswCQYDVQQGEwJGUjET
|
||||
MBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UEBwwIVG91bG91c2UxITAfBgNVBAoM
|
||||
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xODA3MTgwODAwMTZaFw0xOTA3
|
||||
MTgwODAwMTZaMFwxCzAJBgNVBAYTAkZSMRIwEAYDVQQIDAlTb21lU3RhdGUxETAP
|
||||
BgNVBAcMCFRvdWxvdXNlMQ8wDQYDVQQKDAZDaGVlc2UxFTATBgNVBAMMDCouY2hl
|
||||
ZXNlLm9yZzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKYflnzBzLgc
|
||||
tZFduL9wvPe4BE8qQt7qxcMZCwME7O+hJCXerQXnJuqJbFlgEBgMc/G/08x77Wuc
|
||||
6h2I4u4UgdcH7oeVPTbfnDi3ex4rUZxKH9DMW69dbFw1STLkAVv5jHHPYkha6rcx
|
||||
WOLG0FscULVcbVpv2kFe1UxuGiHzQPmeUnZQJT4Dm4cZSFtHh9NnxiVpdymOVpdF
|
||||
2W9kqE6tNXUu/GouR4d2/E4+ROkWssfwI5gTot8VI8sMPf1IXscshnBji8bIiRdS
|
||||
1aeOy04RnWmOjlnMfqO9oRGI1897jBlGnBt6yTmBTFgIH8fOsA55ZNMRcmXm3b0A
|
||||
fyIwRptmnLkCAwEAAaOBjTCBijAJBgNVHRMEAjAAMF4GA1UdEQRXMFWCDCouY2hl
|
||||
ZXNlLm9yZ4IMKi5jaGVlc2UubmV0ggljaGVlc2UuaW6HBAoAAQCHBAoAAQKBD3Rl
|
||||
c3RAY2hlZXNlLm9yZ4EPdGVzdEBjaGVlc2UubmV0MB0GA1UdDgQWBBSra4klEfxe
|
||||
e9Sw99S22evQMJPlWDANBgkqhkiG9w0BAQUFAAOCAQEArYeEoIijTNkKwBTkLZod
|
||||
u1e3Eu86+4uyzjK4BOZZyE8UarUSRunJChFk6qGGIJYOp0DjquWYkTaJd7a5c34a
|
||||
WBmu0RSDHsFfpaAyu1JotI2jHbMI10VuO4dkfu9G5m/VedcdV2hn2Bg5YVuLGn+I
|
||||
2gpRmz1sXbHPt+keBmWOltNhlviiYflAXvq8drlkDm+QN96sbX82hDUZiIwmrz7D
|
||||
ahoD7deQie0YTJ6UH9iubGE2F3L5u94KVpp5tH1KnctKfXGfOOeN8IckIQokH4Ka
|
||||
a2fOfa/LkWuK3ubYb6E3uS3Qy+hO9EOvrZATfWF6zoZI/ACMN/vgMWviGK39Ht8I
|
||||
2w==
|
||||
-----END CERTIFICATE-----
|
||||
`
|
||||
)
|
||||
|
||||
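// getCleanCertContents keeps only the PEM body of each test certificate, sanitizes it, and joins the results with commas, mirroring the expected header value.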
func getCleanCertContents(certContents []string) string {
|
||||
var re = regexp.MustCompile("-----BEGIN CERTIFICATE-----(?s)(.*)")
|
||||
|
||||
var cleanedCertContent []string
|
||||
for _, certContent := range certContents {
|
||||
cert := re.FindString(string(certContent))
|
||||
cleanedCertContent = append(cleanedCertContent, sanitize([]byte(cert)))
|
||||
}
|
||||
|
||||
return strings.Join(cleanedCertContent, ",")
|
||||
}
|
||||
|
||||
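// getCertificate decodes a PEM test fixture into an *x509.Certificate, panicking if the fixture cannot be parsed.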
func getCertificate(certContent string) *x509.Certificate {
|
||||
roots := x509.NewCertPool()
|
||||
ok := roots.AppendCertsFromPEM([]byte(rootCrt))
|
||||
if !ok {
|
||||
panic("failed to parse root certificate")
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(certContent))
|
||||
if block == nil {
|
||||
panic("failed to parse certificate PEM")
|
||||
}
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
panic("failed to parse certificate: " + err.Error())
|
||||
}
|
||||
|
||||
return cert
|
||||
}
|
||||
|
||||
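// buildTLSWith assembles a tls.ConnectionState whose peer certificates are parsed from the given PEM contents.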
func buildTLSWith(certContents []string) *tls.ConnectionState {
|
||||
var peerCertificates []*x509.Certificate
|
||||
|
||||
for _, certContent := range certContents {
|
||||
peerCertificates = append(peerCertificates, getCertificate(certContent))
|
||||
}
|
||||
|
||||
return &tls.ConnectionState{PeerCertificates: peerCertificates}
|
||||
}
|
||||
|
||||
var myPassTLSClientCustomHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("bar"))
|
||||
})
|
||||
|
||||
func getExpectedSanitized(s string) string {
|
||||
return url.QueryEscape(strings.Replace(s, "\n", "", -1))
|
||||
}
|
||||
|
||||
func TestSanitize(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
toSanitize []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Empty",
|
||||
},
|
||||
{
|
||||
desc: "With a minimal cert",
|
||||
toSanitize: []byte(minimalCert),
|
||||
expected: getExpectedSanitized(`MIIDGTCCAgECCQCqLd75YLi2kDANBgkqhkiG9w0BAQsFADBYMQswCQYDVQQGEwJG
|
||||
UjETMBEGA1UECAwKU29tZS1TdGF0ZTERMA8GA1UEBwwIVG91bG91c2UxITAfBgNV
|
||||
BAoMGEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0xODA3MTgwODI4MTZaFw0x
|
||||
ODA4MTcwODI4MTZaMEUxCzAJBgNVBAYTAkZSMRMwEQYDVQQIDApTb21lLVN0YXRl
|
||||
MSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggEiMA0GCSqGSIb3
|
||||
DQEBAQUAA4IBDwAwggEKAoIBAQC/+frDMMTLQyXG34F68BPhQq0kzK4LIq9Y0/gl
|
||||
FjySZNn1C0QDWA1ubVCAcA6yY204I9cxcQDPNrhC7JlS5QA8Y5rhIBrqQlzZizAi
|
||||
Rj3NTrRjtGUtOScnHuJaWjLy03DWD+aMwb7q718xt5SEABmmUvLwQK+EjW2MeDwj
|
||||
y8/UEIpvrRDmdhGaqv7IFpIDkcIF7FowJ/hwDvx3PMc+z/JWK0ovzpvgbx69AVbw
|
||||
ZxCimeha65rOqVi+lEetD26le+WnOdYsdJ2IkmpPNTXGdfb15xuAc+gFXfMCh7Iw
|
||||
3Ynl6dZtZM/Ok2kiA7/OsmVnRKkWrtBfGYkI9HcNGb3zrk6nAgMBAAEwDQYJKoZI
|
||||
hvcNAQELBQADggEBAC/R+Yvhh1VUhcbK49olWsk/JKqfS3VIDQYZg1Eo+JCPbwgS
|
||||
I1BSYVfMcGzuJTX6ua3m/AHzGF3Tap4GhF4tX12jeIx4R4utnjj7/YKkTvuEM2f4
|
||||
xT56YqI7zalGScIB0iMeyNz1QcimRl+M/49au8ow9hNX8C2tcA2cwd/9OIj/6T8q
|
||||
SBRHc6ojvbqZSJCO0jziGDT1L3D+EDgTjED4nd77v/NRdP+egb0q3P0s4dnQ/5AV
|
||||
aQlQADUn61j3ScbGJ4NSeZFFvsl38jeRi/MEzp0bGgNBcPj6JHi7qbbauZcZfQ05
|
||||
jECvgAY7Nfd9mZ1KtyNaW31is+kag7NsvjxU/kM=`),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, test.expected, sanitize(test.toSanitize), "The sanitized certificates should be equal")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestTlsClientheadersWithPEM(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
certContents []string // set the request TLS attribute if defined
|
||||
tlsClientCertHeaders *types.TLSClientHeaders
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "No TLS, no option",
|
||||
},
|
||||
{
|
||||
desc: "TLS, no option",
|
||||
certContents: []string{minimalCert},
|
||||
},
|
||||
{
|
||||
desc: "No TLS, with pem option true",
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
|
||||
},
|
||||
{
|
||||
desc: "TLS with simple certificate, with pem option true",
|
||||
certContents: []string{minimalCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{minimalCert}),
|
||||
},
|
||||
{
|
||||
desc: "TLS with complete certificate, with pem option true",
|
||||
certContents: []string{completeCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{completeCert}),
|
||||
},
|
||||
{
|
||||
desc: "TLS with two certificate, with pem option true",
|
||||
certContents: []string{minimalCert, completeCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{PEM: true},
|
||||
expectedHeader: getCleanCertContents([]string{minimalCert, completeCert}),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
tlsClientHeaders := NewTLSClientHeaders(&types.Frontend{PassTLSClientCert: test.tlsClientCertHeaders})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
|
||||
if test.certContents != nil && len(test.certContents) > 0 {
|
||||
req.TLS = buildTLSWith(test.certContents)
|
||||
}
|
||||
|
||||
tlsClientHeaders.ServeHTTP(res, req, myPassTLSClientCustomHandler)
|
||||
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
|
||||
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
|
||||
|
||||
if test.expectedHeader != "" {
|
||||
require.Equal(t, getCleanCertContents(test.certContents), req.Header.Get(xForwardedTLSClientCert), "The request header should contain the cleaned certificate")
|
||||
} else {
|
||||
require.Empty(t, req.Header.Get(xForwardedTLSClientCert))
|
||||
}
|
||||
require.Empty(t, res.Header().Get(xForwardedTLSClientCert), "The response header should be always empty")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGetSans(t *testing.T) {
|
||||
urlFoo, err := url.Parse("my.foo.com")
|
||||
require.NoError(t, err)
|
||||
urlBar, err := url.Parse("my.bar.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
cert *x509.Certificate // set the request TLS attribute if defined
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "With nil",
|
||||
},
|
||||
{
|
||||
desc: "Certificate without Sans",
|
||||
cert: &x509.Certificate{},
|
||||
},
|
||||
{
|
||||
desc: "Certificate with all Sans",
|
||||
cert: &x509.Certificate{
|
||||
DNSNames: []string{"foo", "bar"},
|
||||
EmailAddresses: []string{"test@test.com", "test2@test.com"},
|
||||
IPAddresses: []net.IP{net.IPv4(10, 0, 0, 1), net.IPv4(10, 0, 0, 2)},
|
||||
URIs: []*url.URL{urlFoo, urlBar},
|
||||
},
|
||||
expected: []string{"foo", "bar", "test@test.com", "test2@test.com", "10.0.0.1", "10.0.0.2", urlFoo.String(), urlBar.String()},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
sans := getSANs(test.cert)
|
||||
test := test
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if len(test.expected) > 0 {
|
||||
for i, expected := range test.expected {
|
||||
require.Equal(t, expected, sans[i])
|
||||
}
|
||||
} else {
|
||||
require.Empty(t, sans)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestTlsClientheadersWithCertInfos(t *testing.T) {
|
||||
minimalCertAllInfos := `Subject="C=FR,ST=Some-State,O=Internet Widgits Pty Ltd",NB=1531902496,NA=1534494496,SAN=`
|
||||
completeCertAllInfos := `Subject="C=FR,ST=SomeState,L=Toulouse,O=Cheese,CN=*.cheese.org",NB=1531900816,NA=1563436816,SAN=*.cheese.org,*.cheese.net,cheese.in,test@cheese.org,test@cheese.net,10.0.1.0,10.0.1.2`
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
certContents []string // set the request TLS attribute if defined
|
||||
tlsClientCertHeaders *types.TLSClientHeaders
|
||||
expectedHeader string
|
||||
}{
|
||||
{
|
||||
desc: "No TLS, no option",
|
||||
},
|
||||
{
|
||||
desc: "TLS, no option",
|
||||
certContents: []string{minimalCert},
|
||||
},
|
||||
{
|
||||
desc: "No TLS, with pem option true",
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
Province: true,
|
||||
Country: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "No TLS, with pem option true with no flag",
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "TLS with simple certificate, with all infos",
|
||||
certContents: []string{minimalCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
Province: true,
|
||||
Country: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
expectedHeader: url.QueryEscape(minimalCertAllInfos),
|
||||
},
|
||||
{
|
||||
desc: "TLS with simple certificate, with some infos",
|
||||
certContents: []string{minimalCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Organization: true,
|
||||
},
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
expectedHeader: url.QueryEscape(`Subject="O=Internet Widgits Pty Ltd",NA=1534494496,SAN=`),
|
||||
},
|
||||
{
|
||||
desc: "TLS with complete certificate, with all infos",
|
||||
certContents: []string{completeCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
Province: true,
|
||||
Country: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
expectedHeader: url.QueryEscape(completeCertAllInfos),
|
||||
},
|
||||
{
|
||||
desc: "TLS with 2 certificates, with all infos",
|
||||
certContents: []string{minimalCert, completeCert},
|
||||
tlsClientCertHeaders: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
Province: true,
|
||||
Country: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
expectedHeader: url.QueryEscape(strings.Join([]string{minimalCertAllInfos, completeCertAllInfos}, ";")),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
tlsClientHeaders := NewTLSClientHeaders(&types.Frontend{PassTLSClientCert: test.tlsClientCertHeaders})
|
||||
|
||||
res := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
|
||||
if len(test.certContents) > 0 {
|
||||
req.TLS = buildTLSWith(test.certContents)
|
||||
}
|
||||
|
||||
tlsClientHeaders.ServeHTTP(res, req, myPassTLSClientCustomHandler)
|
||||
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, http.StatusOK, res.Code, "Http Status should be OK")
|
||||
require.Equal(t, "bar", res.Body.String(), "Should be the expected body")
|
||||
|
||||
if test.expectedHeader != "" {
|
||||
require.Equal(t, test.expectedHeader, req.Header.Get(xForwardedTLSClientCertInfos), "The request header should contain the cleaned certificate")
|
||||
} else {
|
||||
require.Empty(t, req.Header.Get(xForwardedTLSClientCertInfos))
|
||||
}
|
||||
require.Empty(t, res.Header().Get(xForwardedTLSClientCertInfos), "The response header should be always empty")
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestNewTLSClientHeadersFromStruct(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
frontend *types.Frontend
|
||||
expected *TLSClientHeaders
|
||||
}{
|
||||
{
|
||||
desc: "Without frontend",
|
||||
},
|
||||
{
|
||||
desc: "frontend without the option",
|
||||
frontend: &types.Frontend{},
|
||||
expected: &TLSClientHeaders{},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the pem set false",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
PEM: false,
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{PEM: false},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the pem set true",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
PEM: true,
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{PEM: true},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos with no flag",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: false,
|
||||
NotBefore: false,
|
||||
Sans: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos basic",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
NotBefore: true,
|
||||
NotAfter: true,
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos NotAfter",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos NotBefore",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotBefore: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
NotBefore: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Sans",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject Organization",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Organization: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Organization: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject Country",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Country: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Country: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject SerialNumber",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
SerialNumber: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
SerialNumber: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject Province",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Province: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Province: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject Locality",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
Locality: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Locality: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos Subject CommonName",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos NotBefore",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "frontend with the Infos all",
|
||||
frontend: &types.Frontend{
|
||||
PassTLSClientCert: &types.TLSClientHeaders{
|
||||
Infos: &types.TLSClientCertificateInfos{
|
||||
NotAfter: true,
|
||||
NotBefore: true,
|
||||
Subject: &types.TLSCLientCertificateSubjectInfos{
|
||||
CommonName: true,
|
||||
Country: true,
|
||||
Locality: true,
|
||||
Organization: true,
|
||||
Province: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
Sans: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &TLSClientHeaders{
|
||||
PEM: false,
|
||||
Infos: &TLSClientCertificateInfos{
|
||||
NotBefore: true,
|
||||
NotAfter: true,
|
||||
Sans: true,
|
||||
Subject: &TLSCLientCertificateSubjectInfos{
|
||||
Province: true,
|
||||
Organization: true,
|
||||
Locality: true,
|
||||
Country: true,
|
||||
CommonName: true,
|
||||
SerialNumber: true,
|
||||
},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, test.expected, NewTLSClientHeaders(test.frontend))
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
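The expectations above pin down the contract of the header this middleware writes: the certificate details are query-escaped, and when the client presents several certificates the per-certificate strings are joined with ";". A minimal consumer-side sketch of decoding that value; the header name used here is an assumption, since the middleware keeps the real one in an unexported constant:

package main

import (
	"fmt"
	"net/http"
	"net/url"
	"strings"
)

// clientCertInfos decodes the forwarded certificate details, one entry per
// certificate in the order the client presented them.
func clientCertInfos(r *http.Request) ([]string, error) {
	// Assumed header name; the middleware stores the real one in an unexported constant.
	raw := r.Header.Get("X-Forwarded-Tls-Client-Cert-Infos")
	if raw == "" {
		return nil, nil
	}
	decoded, err := url.QueryUnescape(raw)
	if err != nil {
		return nil, err
	}
	return strings.Split(decoded, ";"), nil
}

func main() {
	req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil)
	req.Header.Set("X-Forwarded-Tls-Client-Cert-Infos",
		url.QueryEscape(`Subject="O=Cheese",NB=1531900816,NA=1563436816,SAN=*.cheese.org`))

	infos, _ := clientCertInfos(req)
	fmt.Println(infos)
}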
25
old/middlewares/tracing/carrier.go
Normal file
@@ -0,0 +1,25 @@
|
|||
package tracing
|
||||
|
||||
import "net/http"
|
||||
|
||||
// HTTPHeadersCarrier custom implementation to fix duplicated headers
|
||||
// It has been fixed in https://github.com/opentracing/opentracing-go/pull/191
|
||||
type HTTPHeadersCarrier http.Header
|
||||
|
||||
// Set conforms to the TextMapWriter interface.
|
||||
func (c HTTPHeadersCarrier) Set(key, val string) {
|
||||
h := http.Header(c)
|
||||
h.Set(key, val)
|
||||
}
|
||||
|
||||
// ForeachKey conforms to the TextMapReader interface.
|
||||
func (c HTTPHeadersCarrier) ForeachKey(handler func(key, val string) error) error {
|
||||
for k, vals := range c {
|
||||
for _, v := range vals {
|
||||
if err := handler(k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
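A sketch of how this carrier plugs into an OpenTracing tracer on both sides of a hop, mirroring the Inject/Extract calls made later in this commit. It is written as if it lived in this tracing package (so HTTPHeadersCarrier resolves directly), with the usual net/http and opentracing-go imports assumed:

func propagateExample(tracer opentracing.Tracer, outgoing, incoming *http.Request) {
	// Client side: serialize the span context into the outgoing request headers.
	span := tracer.StartSpan("example")
	defer span.Finish()
	_ = tracer.Inject(span.Context(), opentracing.HTTPHeaders, HTTPHeadersCarrier(outgoing.Header))

	// Server side: rebuild the context from the incoming headers and start a child span.
	if spanCtx, err := tracer.Extract(opentracing.HTTPHeaders, HTTPHeadersCarrier(incoming.Header)); err == nil {
		tracer.StartSpan("example-child", opentracing.ChildOf(spanCtx)).Finish()
	}
}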
45
old/middlewares/tracing/datadog/datadog.go
Normal file
@@ -0,0 +1,45 @@
|
|||
package datadog
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
ddtracer "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/opentracer"
|
||||
datadog "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/tracer"
|
||||
)
|
||||
|
||||
// Name sets the name of this tracer
|
||||
const Name = "datadog"
|
||||
|
||||
// Config provides configuration settings for a datadog tracer
|
||||
type Config struct {
|
||||
LocalAgentHostPort string `description:"Set datadog-agent's host:port that the reporter will use. Defaults to localhost:8126" export:"false"`
|
||||
GlobalTag string `description:"Key:Value tag to be set on all the spans." export:"true"`
|
||||
Debug bool `description:"Enable DataDog debug." export:"true"`
|
||||
}
|
||||
|
||||
// Setup sets up the tracer
|
||||
func (c *Config) Setup(serviceName string) (opentracing.Tracer, io.Closer, error) {
|
||||
tag := strings.SplitN(c.GlobalTag, ":", 2)
|
||||
|
||||
value := ""
|
||||
if len(tag) == 2 {
|
||||
value = tag[1]
|
||||
}
|
||||
|
||||
tracer := ddtracer.New(
|
||||
datadog.WithAgentAddr(c.LocalAgentHostPort),
|
||||
datadog.WithServiceName(serviceName),
|
||||
datadog.WithGlobalTag(tag[0], value),
|
||||
datadog.WithDebugMode(c.Debug),
|
||||
)
|
||||
|
||||
// Without this, child spans are getting the NOOP tracer
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
|
||||
log.Debug("DataDog tracer configured")
|
||||
|
||||
return tracer, nil, nil
|
||||
}
|
||||
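One detail of the setup above worth noting: GlobalTag is split with strings.SplitN on the first ':' only, so everything after the first colon becomes the tag value, and an empty option degrades to an empty key and value. A small self-contained sketch of that behaviour:

package main

import (
	"fmt"
	"strings"
)

func splitGlobalTag(globalTag string) (key, value string) {
	// Mirrors the parsing in Config.Setup: split on the first ':' only.
	parts := strings.SplitN(globalTag, ":", 2)
	key = parts[0]
	if len(parts) == 2 {
		value = parts[1]
	}
	return key, value
}

func main() {
	fmt.Println(splitGlobalTag("env:prod"))    // env prod
	fmt.Println(splitGlobalTag("env:prod:eu")) // env prod:eu
	fmt.Println(splitGlobalTag(""))            // empty key and value
}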
57
old/middlewares/tracing/entrypoint.go
Normal file
@@ -0,0 +1,57 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
type entryPointMiddleware struct {
|
||||
entryPoint string
|
||||
*Tracing
|
||||
}
|
||||
|
||||
// NewEntryPoint creates a new middleware that traces the incoming request
|
||||
func (t *Tracing) NewEntryPoint(name string) negroni.Handler {
|
||||
log.Debug("Added entrypoint tracing middleware")
|
||||
return &entryPointMiddleware{Tracing: t, entryPoint: name}
|
||||
}
|
||||
|
||||
func (e *entryPointMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
opNameFunc := generateEntryPointSpanName
|
||||
|
||||
ctx, _ := e.Extract(opentracing.HTTPHeaders, HTTPHeadersCarrier(r.Header))
|
||||
span := e.StartSpan(opNameFunc(r, e.entryPoint, e.SpanNameLimit), ext.RPCServerOption(ctx))
|
||||
ext.Component.Set(span, e.ServiceName)
|
||||
LogRequest(span, r)
|
||||
ext.SpanKindRPCServer.Set(span)
|
||||
|
||||
r = r.WithContext(opentracing.ContextWithSpan(r.Context(), span))
|
||||
|
||||
recorder := newStatusCodeRecoder(w, 200)
|
||||
next(recorder, r)
|
||||
|
||||
LogResponseCode(span, recorder.Status())
|
||||
span.Finish()
|
||||
}
|
||||
|
||||
// generateEntryPointSpanName will return a Span name of an appropriate length based on the 'spanLimit' argument. If needed, it will be truncated, but will not be less than 24 characters.
|
||||
func generateEntryPointSpanName(r *http.Request, entryPoint string, spanLimit int) string {
|
||||
name := fmt.Sprintf("Entrypoint %s %s", entryPoint, r.Host)
|
||||
|
||||
if spanLimit > 0 && len(name) > spanLimit {
|
||||
if spanLimit < EntryPointMaxLengthNumber {
|
||||
log.Warnf("SpanNameLimit is set to be less than required static number of characters, defaulting to %d + 3", EntryPointMaxLengthNumber)
|
||||
spanLimit = EntryPointMaxLengthNumber + 3
|
||||
}
|
||||
hash := computeHash(name)
|
||||
limit := (spanLimit - EntryPointMaxLengthNumber) / 2
|
||||
name = fmt.Sprintf("Entrypoint %s %s %s", truncateString(entryPoint, limit), truncateString(r.Host, limit), hash)
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
70
old/middlewares/tracing/entrypoint_test.go
Normal file
@@ -0,0 +1,70 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestEntryPointMiddlewareServeHTTP(t *testing.T) {
|
||||
expectedTags := map[string]interface{}{
|
||||
"span.kind": ext.SpanKindRPCServerEnum,
|
||||
"http.method": "GET",
|
||||
"component": "",
|
||||
"http.url": "http://www.test.com",
|
||||
"http.host": "www.test.com",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
entryPoint string
|
||||
tracing *Tracing
|
||||
expectedTags map[string]interface{}
|
||||
expectedName string
|
||||
}{
|
||||
{
|
||||
desc: "no truncation test",
|
||||
entryPoint: "test",
|
||||
tracing: &Tracing{
|
||||
SpanNameLimit: 0,
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
},
|
||||
expectedTags: expectedTags,
|
||||
expectedName: "Entrypoint test www.test.com",
|
||||
}, {
|
||||
desc: "basic test",
|
||||
entryPoint: "test",
|
||||
tracing: &Tracing{
|
||||
SpanNameLimit: 25,
|
||||
tracer: &MockTracer{Span: &MockSpan{Tags: make(map[string]interface{})}},
|
||||
},
|
||||
expectedTags: expectedTags,
|
||||
expectedName: "Entrypoint te... ww... 39b97e58",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
e := &entryPointMiddleware{
|
||||
entryPoint: test.entryPoint,
|
||||
Tracing: test.tracing,
|
||||
}
|
||||
|
||||
next := func(http.ResponseWriter, *http.Request) {
|
||||
span := test.tracing.tracer.(*MockTracer).Span
|
||||
|
||||
actual := span.Tags
|
||||
assert.Equal(t, test.expectedTags, actual)
|
||||
assert.Equal(t, test.expectedName, span.OpName)
|
||||
}
|
||||
|
||||
e.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "http://www.test.com", nil), next)
|
||||
})
|
||||
}
|
||||
}
|
||||
63
old/middlewares/tracing/forwarder.go
Normal file
@@ -0,0 +1,63 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
type forwarderMiddleware struct {
|
||||
frontend string
|
||||
backend string
|
||||
opName string
|
||||
*Tracing
|
||||
}
|
||||
|
||||
// NewForwarderMiddleware creates a new forwarder middleware that traces the outgoing request
|
||||
func (t *Tracing) NewForwarderMiddleware(frontend, backend string) negroni.Handler {
|
||||
log.Debugf("Added outgoing tracing middleware %s", frontend)
|
||||
return &forwarderMiddleware{
|
||||
Tracing: t,
|
||||
frontend: frontend,
|
||||
backend: backend,
|
||||
opName: generateForwardSpanName(frontend, backend, t.SpanNameLimit),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *forwarderMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
span, r, finish := StartSpan(r, f.opName, true)
|
||||
defer finish()
|
||||
span.SetTag("frontend.name", f.frontend)
|
||||
span.SetTag("backend.name", f.backend)
|
||||
ext.HTTPMethod.Set(span, r.Method)
|
||||
ext.HTTPUrl.Set(span, fmt.Sprintf("%s%s", r.URL.String(), r.RequestURI))
|
||||
span.SetTag("http.host", r.Host)
|
||||
|
||||
InjectRequestHeaders(r)
|
||||
|
||||
recorder := newStatusCodeRecoder(w, 200)
|
||||
|
||||
next(recorder, r)
|
||||
|
||||
LogResponseCode(span, recorder.Status())
|
||||
}
|
||||
|
||||
// generateForwardSpanName will return a Span name of an appropriate length based on the 'spanLimit' argument. If needed, it will be truncated, but will not be less than 21 characters.
|
||||
func generateForwardSpanName(frontend, backend string, spanLimit int) string {
|
||||
name := fmt.Sprintf("forward %s/%s", frontend, backend)
|
||||
|
||||
if spanLimit > 0 && len(name) > spanLimit {
|
||||
if spanLimit < ForwardMaxLengthNumber {
|
||||
log.Warnf("SpanNameLimit is set to be less than required static number of characters, defaulting to %d + 3", ForwardMaxLengthNumber)
|
||||
spanLimit = ForwardMaxLengthNumber + 3
|
||||
}
|
||||
hash := computeHash(name)
|
||||
limit := (spanLimit - ForwardMaxLengthNumber) / 2
|
||||
name = fmt.Sprintf("forward %s/%s/%s", truncateString(frontend, limit), truncateString(backend, limit), hash)
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
93
old/middlewares/tracing/forwarder_test.go
Normal file
@@ -0,0 +1,93 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTracingNewForwarderMiddleware(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
tracer *Tracing
|
||||
frontend string
|
||||
backend string
|
||||
expected *forwarderMiddleware
|
||||
}{
|
||||
{
|
||||
desc: "Simple Forward Tracer without truncation and hashing",
|
||||
tracer: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
},
|
||||
frontend: "some-service.domain.tld",
|
||||
backend: "some-service.domain.tld",
|
||||
expected: &forwarderMiddleware{
|
||||
Tracing: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
},
|
||||
frontend: "some-service.domain.tld",
|
||||
backend: "some-service.domain.tld",
|
||||
opName: "forward some-service.domain.tld/some-service.domain.tld",
|
||||
},
|
||||
}, {
|
||||
desc: "Simple Forward Tracer with truncation and hashing",
|
||||
tracer: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
},
|
||||
frontend: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
backend: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
expected: &forwarderMiddleware{
|
||||
Tracing: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
},
|
||||
frontend: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
backend: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
opName: "forward some-service-100.slug.namespace.enviro.../some-service-100.slug.namespace.enviro.../bc4a0d48",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Exactly 101 chars",
|
||||
tracer: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
},
|
||||
frontend: "some-service1.namespace.environment.domain.tld",
|
||||
backend: "some-service1.namespace.environment.domain.tld",
|
||||
expected: &forwarderMiddleware{
|
||||
Tracing: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
},
|
||||
frontend: "some-service1.namespace.environment.domain.tld",
|
||||
backend: "some-service1.namespace.environment.domain.tld",
|
||||
opName: "forward some-service1.namespace.environment.domain.tld/some-service1.namespace.environment.domain.tld",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "More than 101 chars",
|
||||
tracer: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
},
|
||||
frontend: "some-service1.frontend.namespace.environment.domain.tld",
|
||||
backend: "some-service1.backend.namespace.environment.domain.tld",
|
||||
expected: &forwarderMiddleware{
|
||||
Tracing: &Tracing{
|
||||
SpanNameLimit: 101,
|
||||
},
|
||||
frontend: "some-service1.frontend.namespace.environment.domain.tld",
|
||||
backend: "some-service1.backend.namespace.environment.domain.tld",
|
||||
opName: "forward some-service1.frontend.namespace.envir.../some-service1.backend.namespace.enviro.../fa49dd23",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := test.tracer.NewForwarderMiddleware(test.frontend, test.backend)
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
assert.True(t, len(test.expected.opName) <= test.tracer.SpanNameLimit)
|
||||
})
|
||||
}
|
||||
}
|
||||
73
old/middlewares/tracing/jaeger/jaeger.go
Normal file
@@ -0,0 +1,73 @@
|
|||
package jaeger
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
jaegercfg "github.com/uber/jaeger-client-go/config"
|
||||
"github.com/uber/jaeger-client-go/zipkin"
|
||||
jaegermet "github.com/uber/jaeger-lib/metrics"
|
||||
)
|
||||
|
||||
// Name sets the name of this tracer
|
||||
const Name = "jaeger"
|
||||
|
||||
// Config provides configuration settings for a jaeger tracer
|
||||
type Config struct {
|
||||
SamplingServerURL string `description:"set the sampling server url." export:"false"`
|
||||
SamplingType string `description:"set the sampling type." export:"true"`
|
||||
SamplingParam float64 `description:"set the sampling parameter." export:"true"`
|
||||
LocalAgentHostPort string `description:"set jaeger-agent's host:port that the reporter will use." export:"false"`
|
||||
Gen128Bit bool `description:"generate 128 bit span IDs." export:"true"`
|
||||
Propagation string `description:"which propagation format to use (jaeger/b3)." export:"true"`
|
||||
}
|
||||
|
||||
// Setup sets up the tracer
|
||||
func (c *Config) Setup(componentName string) (opentracing.Tracer, io.Closer, error) {
|
||||
jcfg := jaegercfg.Configuration{
|
||||
Sampler: &jaegercfg.SamplerConfig{
|
||||
SamplingServerURL: c.SamplingServerURL,
|
||||
Type: c.SamplingType,
|
||||
Param: c.SamplingParam,
|
||||
},
|
||||
Reporter: &jaegercfg.ReporterConfig{
|
||||
LogSpans: true,
|
||||
LocalAgentHostPort: c.LocalAgentHostPort,
|
||||
},
|
||||
}
|
||||
|
||||
jMetricsFactory := jaegermet.NullFactory
|
||||
|
||||
opts := []jaegercfg.Option{
|
||||
jaegercfg.Logger(&jaegerLogger{}),
|
||||
jaegercfg.Metrics(jMetricsFactory),
|
||||
jaegercfg.Gen128Bit(c.Gen128Bit),
|
||||
}
|
||||
|
||||
switch c.Propagation {
|
||||
case "b3":
|
||||
p := zipkin.NewZipkinB3HTTPHeaderPropagator()
|
||||
opts = append(opts,
|
||||
jaegercfg.Injector(opentracing.HTTPHeaders, p),
|
||||
jaegercfg.Extractor(opentracing.HTTPHeaders, p),
|
||||
)
|
||||
case "jaeger", "":
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unknown propagation format: %s", c.Propagation)
|
||||
}
|
||||
|
||||
// Initialize tracer with a logger and a metrics factory
|
||||
closer, err := jcfg.InitGlobalTracer(
|
||||
componentName,
|
||||
opts...,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warnf("Could not initialize jaeger tracer: %s", err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
log.Debug("Jaeger tracer configured")
|
||||
|
||||
return opentracing.GlobalTracer(), closer, nil
|
||||
}
|
||||
15
old/middlewares/tracing/jaeger/logger.go
Normal file
@@ -0,0 +1,15 @@
|
|||
package jaeger
|
||||
|
||||
import "github.com/containous/traefik/old/log"
|
||||
|
||||
// jaegerLogger is an implementation of the Logger interface that delegates to traefik log
|
||||
type jaegerLogger struct{}
|
||||
|
||||
func (l *jaegerLogger) Error(msg string) {
|
||||
log.Errorf("Tracing jaeger error: %s", msg)
|
||||
}
|
||||
|
||||
// Infof logs a message at debug priority
|
||||
func (l *jaegerLogger) Infof(msg string, args ...interface{}) {
|
||||
log.Debugf(msg, args...)
|
||||
}
|
||||
57
old/middlewares/tracing/status_code.go
Normal file
@@ -0,0 +1,57 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type statusCodeRecoder interface {
|
||||
http.ResponseWriter
|
||||
Status() int
|
||||
}
|
||||
|
||||
type statusCodeWithoutCloseNotify struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code for later retrieval.
|
||||
func (s *statusCodeWithoutCloseNotify) WriteHeader(status int) {
|
||||
s.status = status
|
||||
s.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
// Status get response status
|
||||
func (s *statusCodeWithoutCloseNotify) Status() int {
|
||||
return s.status
|
||||
}
|
||||
|
||||
// Hijack hijacks the connection
|
||||
func (s *statusCodeWithoutCloseNotify) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return s.ResponseWriter.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
||||
// Flush sends any buffered data to the client.
|
||||
func (s *statusCodeWithoutCloseNotify) Flush() {
|
||||
if flusher, ok := s.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
type statusCodeWithCloseNotify struct {
|
||||
*statusCodeWithoutCloseNotify
|
||||
}
|
||||
|
||||
func (s *statusCodeWithCloseNotify) CloseNotify() <-chan bool {
|
||||
return s.ResponseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
||||
// newStatusCodeRecoder returns an initialized statusCodeRecoder.
|
||||
func newStatusCodeRecoder(rw http.ResponseWriter, status int) statusCodeRecoder {
|
||||
recorder := &statusCodeWithoutCloseNotify{rw, status}
|
||||
if _, ok := rw.(http.CloseNotifier); ok {
|
||||
return &statusCodeWithCloseNotify{recorder}
|
||||
}
|
||||
return recorder
|
||||
}
|
||||
197
old/middlewares/tracing/tracing.go
Normal file
@@ -0,0 +1,197 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/middlewares/tracing/datadog"
|
||||
"github.com/containous/traefik/old/middlewares/tracing/jaeger"
|
||||
"github.com/containous/traefik/old/middlewares/tracing/zipkin"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
)
|
||||
|
||||
// ForwardMaxLengthNumber defines the number of static characters in the Forwarding Span Trace name: 8 chars for 'forward ' + 8 chars for the hash + 2 separator chars.
|
||||
const ForwardMaxLengthNumber = 18
|
||||
|
||||
// EntryPointMaxLengthNumber defines the number of static characters in the Entrypoint Span Trace name: 11 chars for 'Entrypoint ' + 8 chars for the hash + 2 separator chars.
|
||||
const EntryPointMaxLengthNumber = 21
|
||||
|
||||
// TraceNameHashLength defines the number of characters to use from the head of the generated hash.
|
||||
const TraceNameHashLength = 8
|
||||
|
||||
// Tracing middleware
|
||||
type Tracing struct {
|
||||
Backend string `description:"Selects the tracing backend ('jaeger', 'zipkin', 'datadog')." export:"true"`
|
||||
ServiceName string `description:"Set the name for this service" export:"true"`
|
||||
SpanNameLimit int `description:"Set the maximum character limit for Span names (default 0 = no limit)" export:"true"`
|
||||
Jaeger *jaeger.Config `description:"Settings for jaeger"`
|
||||
Zipkin *zipkin.Config `description:"Settings for zipkin"`
|
||||
DataDog *datadog.Config `description:"Settings for DataDog"`
|
||||
|
||||
tracer opentracing.Tracer
|
||||
closer io.Closer
|
||||
}
|
||||
|
||||
// StartSpan delegates to opentracing.Tracer
|
||||
func (t *Tracing) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
|
||||
return t.tracer.StartSpan(operationName, opts...)
|
||||
}
|
||||
|
||||
// Inject delegates to opentracing.Tracer
|
||||
func (t *Tracing) Inject(sm opentracing.SpanContext, format interface{}, carrier interface{}) error {
|
||||
return t.tracer.Inject(sm, format, carrier)
|
||||
}
|
||||
|
||||
// Extract delegates to opentracing.Tracer
|
||||
func (t *Tracing) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) {
|
||||
return t.tracer.Extract(format, carrier)
|
||||
}
|
||||
|
||||
// Backend describes things we can use to setup tracing
|
||||
type Backend interface {
|
||||
Setup(serviceName string) (opentracing.Tracer, io.Closer, error)
|
||||
}
|
||||
|
||||
// Setup Tracing middleware
|
||||
func (t *Tracing) Setup() {
|
||||
var err error
|
||||
|
||||
switch t.Backend {
|
||||
case jaeger.Name:
|
||||
t.tracer, t.closer, err = t.Jaeger.Setup(t.ServiceName)
|
||||
case zipkin.Name:
|
||||
t.tracer, t.closer, err = t.Zipkin.Setup(t.ServiceName)
|
||||
case datadog.Name:
|
||||
t.tracer, t.closer, err = t.DataDog.Setup(t.ServiceName)
|
||||
default:
|
||||
log.Warnf("Unknown tracer %q", t.Backend)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("Could not initialize %s tracing: %v", t.Backend, err)
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled determines if tracing was successfully activated
|
||||
func (t *Tracing) IsEnabled() bool {
|
||||
if t == nil || t.tracer == nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Close tracer
|
||||
func (t *Tracing) Close() {
|
||||
if t.closer != nil {
|
||||
err := t.closer.Close()
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LogRequest used to create span tags from the request
|
||||
func LogRequest(span opentracing.Span, r *http.Request) {
|
||||
if span != nil && r != nil {
|
||||
ext.HTTPMethod.Set(span, r.Method)
|
||||
ext.HTTPUrl.Set(span, r.URL.String())
|
||||
span.SetTag("http.host", r.Host)
|
||||
}
|
||||
}
|
||||
|
||||
// LogResponseCode used to log response code in span
|
||||
func LogResponseCode(span opentracing.Span, code int) {
|
||||
if span != nil {
|
||||
ext.HTTPStatusCode.Set(span, uint16(code))
|
||||
if code >= 400 {
|
||||
ext.Error.Set(span, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetSpan used to retrieve span from request context
|
||||
func GetSpan(r *http.Request) opentracing.Span {
|
||||
return opentracing.SpanFromContext(r.Context())
|
||||
}
|
||||
|
||||
// InjectRequestHeaders used to inject OpenTracing headers into the request
|
||||
func InjectRequestHeaders(r *http.Request) {
|
||||
if span := GetSpan(r); span != nil {
|
||||
err := opentracing.GlobalTracer().Inject(
|
||||
span.Context(),
|
||||
opentracing.HTTPHeaders,
|
||||
HTTPHeadersCarrier(r.Header))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LogEventf logs an event to the span in the request context.
|
||||
func LogEventf(r *http.Request, format string, args ...interface{}) {
|
||||
if span := GetSpan(r); span != nil {
|
||||
span.LogKV("event", fmt.Sprintf(format, args...))
|
||||
}
|
||||
}
|
||||
|
||||
// StartSpan starts a new span from the one in the request context
|
||||
func StartSpan(r *http.Request, operationName string, spanKinClient bool, opts ...opentracing.StartSpanOption) (opentracing.Span, *http.Request, func()) {
|
||||
span, ctx := opentracing.StartSpanFromContext(r.Context(), operationName, opts...)
|
||||
if spanKinClient {
|
||||
ext.SpanKindRPCClient.Set(span)
|
||||
}
|
||||
r = r.WithContext(ctx)
|
||||
return span, r, func() {
|
||||
span.Finish()
|
||||
}
|
||||
}
|
||||
|
||||
// SetError flags the span associated with this request as in error
|
||||
func SetError(r *http.Request) {
|
||||
if span := GetSpan(r); span != nil {
|
||||
ext.Error.Set(span, true)
|
||||
}
|
||||
}
|
||||
|
||||
// SetErrorAndDebugLog flags the span associated with this request as in error and creates a debug log.
|
||||
func SetErrorAndDebugLog(r *http.Request, format string, args ...interface{}) {
|
||||
SetError(r)
|
||||
log.Debugf(format, args...)
|
||||
LogEventf(r, format, args...)
|
||||
}
|
||||
|
||||
// SetErrorAndWarnLog flags the span associated with this request as in error and creates a warn log.
|
||||
func SetErrorAndWarnLog(r *http.Request, format string, args ...interface{}) {
|
||||
SetError(r)
|
||||
log.Warnf(format, args...)
|
||||
LogEventf(r, format, args...)
|
||||
}
|
||||
|
||||
// truncateString truncates 'str' so that, including the '...' suffix it appends, the result is at most 'num' characters long (for 'num' greater than 3).
|
||||
func truncateString(str string, num int) string {
|
||||
text := str
|
||||
if len(str) > num {
|
||||
if num > 3 {
|
||||
num -= 3
|
||||
}
|
||||
text = str[0:num] + "..."
|
||||
}
|
||||
return text
|
||||
}
|
||||
|
||||
// computeHash returns the first TraceNameHashLength characters of the sha256 hash of the 'name' argument.
|
||||
func computeHash(name string) string {
|
||||
data := []byte(name)
|
||||
hash := sha256.New()
|
||||
if _, err := hash.Write(data); err != nil {
|
||||
// Impossible case
|
||||
log.Errorf("Fail to create Span name hash for %s: %v", name, err)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", hash.Sum(nil))[:TraceNameHashLength]
|
||||
}
|
||||
133
old/middlewares/tracing/tracing_test.go
Normal file
@@ -0,0 +1,133 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type MockTracer struct {
|
||||
Span *MockSpan
|
||||
}
|
||||
|
||||
type MockSpan struct {
|
||||
OpName string
|
||||
Tags map[string]interface{}
|
||||
}
|
||||
|
||||
type MockSpanContext struct {
|
||||
}
|
||||
|
||||
// MockSpanContext:
|
||||
func (n MockSpanContext) ForeachBaggageItem(handler func(k, v string) bool) {}
|
||||
|
||||
// MockSpan:
|
||||
func (n MockSpan) Context() opentracing.SpanContext { return MockSpanContext{} }
|
||||
func (n MockSpan) SetBaggageItem(key, val string) opentracing.Span {
|
||||
return MockSpan{Tags: make(map[string]interface{})}
|
||||
}
|
||||
func (n MockSpan) BaggageItem(key string) string { return "" }
|
||||
func (n MockSpan) SetTag(key string, value interface{}) opentracing.Span {
|
||||
n.Tags[key] = value
|
||||
return n
|
||||
}
|
||||
func (n MockSpan) LogFields(fields ...log.Field) {}
|
||||
func (n MockSpan) LogKV(keyVals ...interface{}) {}
|
||||
func (n MockSpan) Finish() {}
|
||||
func (n MockSpan) FinishWithOptions(opts opentracing.FinishOptions) {}
|
||||
func (n MockSpan) SetOperationName(operationName string) opentracing.Span { return n }
|
||||
func (n MockSpan) Tracer() opentracing.Tracer { return MockTracer{} }
|
||||
func (n MockSpan) LogEvent(event string) {}
|
||||
func (n MockSpan) LogEventWithPayload(event string, payload interface{}) {}
|
||||
func (n MockSpan) Log(data opentracing.LogData) {}
|
||||
func (n MockSpan) Reset() {
|
||||
n.Tags = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// StartSpan belongs to the Tracer interface.
|
||||
func (n MockTracer) StartSpan(operationName string, opts ...opentracing.StartSpanOption) opentracing.Span {
|
||||
n.Span.OpName = operationName
|
||||
return n.Span
|
||||
}
|
||||
|
||||
// Inject belongs to the Tracer interface.
|
||||
func (n MockTracer) Inject(sp opentracing.SpanContext, format interface{}, carrier interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Extract belongs to the Tracer interface.
|
||||
func (n MockTracer) Extract(format interface{}, carrier interface{}) (opentracing.SpanContext, error) {
|
||||
return nil, opentracing.ErrSpanContextNotFound
|
||||
}
|
||||
|
||||
func TestTruncateString(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
text string
|
||||
limit int
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "short text less than limit 10",
|
||||
text: "short",
|
||||
limit: 10,
|
||||
expected: "short",
|
||||
},
|
||||
{
|
||||
desc: "basic truncate with limit 10",
|
||||
text: "some very long pice of text",
|
||||
limit: 10,
|
||||
expected: "some ve...",
|
||||
},
|
||||
{
|
||||
desc: "truncate long FQDN to 39 chars",
|
||||
text: "some-service-100.slug.namespace.environment.domain.tld",
|
||||
limit: 39,
|
||||
expected: "some-service-100.slug.namespace.envi...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := truncateString(test.text, test.limit)
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
assert.True(t, len(actual) <= test.limit)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeHash(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
text string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "hashing",
|
||||
text: "some very long pice of text",
|
||||
expected: "0258ea1c",
|
||||
},
|
||||
{
|
||||
desc: "short text less than limit 10",
|
||||
text: "short",
|
||||
expected: "f9b0078b",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := computeHash(test.text)
|
||||
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
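Putting the helpers above together: generateForwardSpanName (from forwarder.go in this commit) reserves ForwardMaxLengthNumber (18) characters for the static 'forward ' prefix, the hash and the separators, then splits the remainder of SpanNameLimit evenly between the frontend and backend names. A sketch of that arithmetic in the same style as the tests above (same imports assumed, plus fmt):

func TestForwardSpanNameArithmetic(t *testing.T) {
	frontend := "some-service-100.slug.namespace.environment.domain.tld"
	backend := frontend

	// With SpanNameLimit=101, each side keeps (101 - ForwardMaxLengthNumber) / 2 = 41 characters
	// (38 characters of the original name plus the "..." suffix added by truncateString).
	limit := (101 - ForwardMaxLengthNumber) / 2
	expected := fmt.Sprintf("forward %s/%s/%s",
		truncateString(frontend, limit),
		truncateString(backend, limit),
		computeHash(fmt.Sprintf("forward %s/%s", frontend, backend)))

	assert.Equal(t, expected, generateForwardSpanName(frontend, backend, 101))
}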
66
old/middlewares/tracing/wrapper.go
Normal file
@@ -0,0 +1,66 @@
|
|||
package tracing
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/urfave/negroni"
|
||||
)
|
||||
|
||||
// NewNegroniHandlerWrapper returns a negroni.Handler struct
|
||||
func (t *Tracing) NewNegroniHandlerWrapper(name string, handler negroni.Handler, clientSpanKind bool) negroni.Handler {
|
||||
if t.IsEnabled() && handler != nil {
|
||||
return &NegroniHandlerWrapper{
|
||||
name: name,
|
||||
next: handler,
|
||||
clientSpanKind: clientSpanKind,
|
||||
}
|
||||
}
|
||||
return handler
|
||||
}
|
||||
|
||||
// NewHTTPHandlerWrapper returns an http.Handler struct
|
||||
func (t *Tracing) NewHTTPHandlerWrapper(name string, handler http.Handler, clientSpanKind bool) http.Handler {
|
||||
if t.IsEnabled() && handler != nil {
|
||||
return &HTTPHandlerWrapper{
|
||||
name: name,
|
||||
handler: handler,
|
||||
clientSpanKind: clientSpanKind,
|
||||
}
|
||||
}
|
||||
return handler
|
||||
}
|
||||
|
||||
// NegroniHandlerWrapper is used to wrap negroni handler middleware
|
||||
type NegroniHandlerWrapper struct {
|
||||
name string
|
||||
next negroni.Handler
|
||||
clientSpanKind bool
|
||||
}
|
||||
|
||||
func (t *NegroniHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
|
||||
var finish func()
|
||||
_, r, finish = StartSpan(r, t.name, t.clientSpanKind)
|
||||
defer finish()
|
||||
|
||||
if t.next != nil {
|
||||
t.next.ServeHTTP(rw, r, next)
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPHandlerWrapper is used to wrap http handler middleware
|
||||
type HTTPHandlerWrapper struct {
|
||||
name string
|
||||
handler http.Handler
|
||||
clientSpanKind bool
|
||||
}
|
||||
|
||||
func (t *HTTPHandlerWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
var finish func()
|
||||
_, r, finish = StartSpan(r, t.name, t.clientSpanKind)
|
||||
defer finish()
|
||||
|
||||
if t.handler != nil {
|
||||
t.handler.ServeHTTP(rw, r)
|
||||
}
|
||||
|
||||
}
|
||||
49
old/middlewares/tracing/zipkin/zipkin.go
Normal file
@@ -0,0 +1,49 @@
|
|||
package zipkin
|
||||
|
||||
import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
zipkin "github.com/openzipkin/zipkin-go-opentracing"
|
||||
)
|
||||
|
||||
// Name sets the name of this tracer
|
||||
const Name = "zipkin"
|
||||
|
||||
// Config provides configuration settings for a zipkin tracer
|
||||
type Config struct {
|
||||
HTTPEndpoint string `description:"HTTP Endpoint to report traces to." export:"false"`
|
||||
SameSpan bool `description:"Use Zipkin SameSpan RPC style traces." export:"true"`
|
||||
ID128Bit bool `description:"Use Zipkin 128 bit root span IDs." export:"true"`
|
||||
Debug bool `description:"Enable Zipkin debug." export:"true"`
|
||||
SampleRate float64 `description:"The rate between 0.0 and 1.0 of requests to trace." export:"true"`
|
||||
}
|
||||
|
||||
// Setup sets up the tracer
|
||||
func (c *Config) Setup(serviceName string) (opentracing.Tracer, io.Closer, error) {
|
||||
collector, err := zipkin.NewHTTPCollector(c.HTTPEndpoint)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
recorder := zipkin.NewRecorder(collector, c.Debug, "0.0.0.0:0", serviceName)
|
||||
tracer, err := zipkin.NewTracer(
|
||||
recorder,
|
||||
zipkin.ClientServerSameSpan(c.SameSpan),
|
||||
zipkin.TraceID128Bit(c.ID128Bit),
|
||||
zipkin.DebugMode(c.Debug),
|
||||
zipkin.WithSampler(zipkin.NewBoundarySampler(c.SampleRate, time.Now().Unix())),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Without this, child spans are getting the NOOP tracer
|
||||
opentracing.SetGlobalTracer(tracer)
|
||||
|
||||
log.Debug("Zipkin tracer configured")
|
||||
|
||||
return tracer, collector, nil
|
||||
}
|
||||
36
old/ping/ping.go
Normal file
@@ -0,0 +1,36 @@
|
|||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/containous/mux"
|
||||
)
|
||||
|
||||
// Handler expose ping routes
|
||||
type Handler struct {
|
||||
EntryPoint string `description:"Ping entryPoint" export:"true"`
|
||||
terminating bool
|
||||
}
|
||||
|
||||
// WithContext causes the ping endpoint to serve non-200 responses once the given context is done.
|
||||
func (h *Handler) WithContext(ctx context.Context) {
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
h.terminating = true
|
||||
}()
|
||||
}
|
||||
|
||||
// AddRoutes add ping routes on a router
|
||||
func (h *Handler) AddRoutes(router *mux.Router) {
|
||||
router.Methods(http.MethodGet, http.MethodHead).Path("/ping").
|
||||
HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
|
||||
statusCode := http.StatusOK
|
||||
if h.terminating {
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
}
|
||||
response.WriteHeader(statusCode)
|
||||
fmt.Fprint(response, http.StatusText(statusCode))
|
||||
})
|
||||
}
|
||||
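A minimal sketch of wiring the handler above into a server: once the context given to WithContext is cancelled (typically from a shutdown hook), GET /ping switches from 200 OK to 503 Service Unavailable so load balancers can drain the instance. Import paths are taken from this commit:

package main

import (
	"context"
	"net/http"

	"github.com/containous/mux"
	"github.com/containous/traefik/old/ping"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	pingHandler := &ping.Handler{}
	// After cancel() is called, /ping starts answering 503 Service Unavailable.
	pingHandler.WithContext(ctx)

	router := mux.NewRouter()
	pingHandler.AddRoutes(router)

	_ = http.ListenAndServe(":8082", router)
}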
83
old/provider/acme/account.go
Normal file
@@ -0,0 +1,83 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
// Account is used to store Let's Encrypt registration info
|
||||
type Account struct {
|
||||
Email string
|
||||
Registration *acme.RegistrationResource
|
||||
PrivateKey []byte
|
||||
KeyType acme.KeyType
|
||||
}
|
||||
|
||||
const (
|
||||
// RegistrationURLPathV1Regexp is a regexp which matches ACME registration URLs in the V1 format
|
||||
RegistrationURLPathV1Regexp = `^.*/acme/reg/\d+$`
|
||||
)
|
||||
|
||||
// NewAccount creates an account
|
||||
func NewAccount(email string, keyTypeValue string) (*Account, error) {
|
||||
keyType := GetKeyType(keyTypeValue)
|
||||
|
||||
// Create a user. New accounts need an email and private key to start
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Account{
|
||||
Email: email,
|
||||
PrivateKey: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||
KeyType: keyType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetEmail returns email
|
||||
func (a *Account) GetEmail() string {
|
||||
return a.Email
|
||||
}
|
||||
|
||||
// GetRegistration returns the Let's Encrypt registration resource
|
||||
func (a *Account) GetRegistration() *acme.RegistrationResource {
|
||||
return a.Registration
|
||||
}
|
||||
|
||||
// GetPrivateKey returns private key
|
||||
func (a *Account) GetPrivateKey() crypto.PrivateKey {
|
||||
if privateKey, err := x509.ParsePKCS1PrivateKey(a.PrivateKey); err == nil {
|
||||
return privateKey
|
||||
}
|
||||
|
||||
log.Errorf("Cannot unmarshal private key %+v", a.PrivateKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetKeyType is used to determine which key algorithm to use
|
||||
func GetKeyType(value string) acme.KeyType {
|
||||
switch value {
|
||||
case "EC256":
|
||||
return acme.EC256
|
||||
case "EC384":
|
||||
return acme.EC384
|
||||
case "RSA2048":
|
||||
return acme.RSA2048
|
||||
case "RSA4096":
|
||||
return acme.RSA4096
|
||||
case "RSA8192":
|
||||
return acme.RSA8192
|
||||
case "":
|
||||
log.Infof("The key type is empty. Use default key type %v.", acme.RSA4096)
|
||||
return acme.RSA4096
|
||||
default:
|
||||
log.Infof("Unable to determine key type value %q. Use default key type %v.", value, acme.RSA4096)
|
||||
return acme.RSA4096
|
||||
}
|
||||
}
|
||||
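A short usage sketch of the account helpers above; note that NewAccount always generates an RSA 4096 private key and only records the requested key type, while GetKeyType falls back to RSA4096 for empty or unrecognised values. Import path taken from this commit:

package main

import (
	"fmt"

	"github.com/containous/traefik/old/provider/acme"
)

func main() {
	// "EC256" is mapped by GetKeyType to the lego EC256 key type; an unknown value would fall back to RSA4096.
	account, err := acme.NewAccount("admin@example.com", "EC256")
	if err != nil {
		panic(err)
	}
	fmt.Println(account.GetEmail(), account.KeyType)
}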
86
old/provider/acme/challenge_http.go
Normal file
@@ -0,0 +1,86 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/cenk/backoff"
|
||||
"github.com/containous/mux"
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
var _ acme.ChallengeProviderTimeout = (*challengeHTTP)(nil)
|
||||
|
||||
type challengeHTTP struct {
|
||||
Store Store
|
||||
}
|
||||
|
||||
// Present presents a challenge to obtain a new ACME certificate
|
||||
func (c *challengeHTTP) Present(domain, token, keyAuth string) error {
|
||||
return c.Store.SetHTTPChallengeToken(token, domain, []byte(keyAuth))
|
||||
}
|
||||
|
||||
// CleanUp cleans the challenges when the certificate is obtained
|
||||
func (c *challengeHTTP) CleanUp(domain, token, keyAuth string) error {
|
||||
return c.Store.RemoveHTTPChallengeToken(token, domain)
|
||||
}
|
||||
|
||||
// Timeout calculates the maximum time allowed to resolve an ACME challenge
|
||||
func (c *challengeHTTP) Timeout() (timeout, interval time.Duration) {
|
||||
return 60 * time.Second, 5 * time.Second
|
||||
}
|
||||
|
||||
func getTokenValue(token, domain string, store Store) []byte {
|
||||
log.Debugf("Looking for an existing ACME challenge for token %v...", token)
|
||||
var result []byte
|
||||
|
||||
operation := func() error {
|
||||
var err error
|
||||
result, err = store.GetHTTPChallengeToken(token, domain)
|
||||
return err
|
||||
}
|
||||
|
||||
notify := func(err error, time time.Duration) {
|
||||
log.Errorf("Error getting challenge for token retrying in %s", time)
|
||||
}
|
||||
|
||||
ebo := backoff.NewExponentialBackOff()
|
||||
ebo.MaxElapsedTime = 60 * time.Second
|
||||
err := backoff.RetryNotify(safe.OperationWithRecover(operation), ebo, notify)
|
||||
if err != nil {
|
||||
log.Errorf("Error getting challenge for token: %v", err)
|
||||
return []byte{}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// AddRoutes add routes on internal router
|
||||
func (p *Provider) AddRoutes(router *mux.Router) {
|
||||
router.Methods(http.MethodGet).
|
||||
Path(acme.HTTP01ChallengePath("{token}")).
|
||||
Handler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
vars := mux.Vars(req)
|
||||
if token, ok := vars["token"]; ok {
|
||||
domain, _, err := net.SplitHostPort(req.Host)
|
||||
if err != nil {
|
||||
log.Debugf("Unable to split host and port: %v. Fallback to request host.", err)
|
||||
domain = req.Host
|
||||
}
|
||||
|
||||
tokenValue := getTokenValue(token, domain, p.Store)
|
||||
if len(tokenValue) > 0 {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
_, err = rw.Write(tokenValue)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to write token : %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
}
|
||||
52
old/provider/acme/challenge_tls.go
Normal file
@@ -0,0 +1,52 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
var _ acme.ChallengeProvider = (*challengeTLSALPN)(nil)
|
||||
|
||||
type challengeTLSALPN struct {
|
||||
Store Store
|
||||
}
|
||||
|
||||
func (c *challengeTLSALPN) Present(domain, token, keyAuth string) error {
|
||||
log.Debugf("TLS Challenge Present temp certificate for %s", domain)
|
||||
|
||||
certPEMBlock, keyPEMBlock, err := acme.TLSALPNChallengeBlocks(domain, keyAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cert := &Certificate{Certificate: certPEMBlock, Key: keyPEMBlock, Domain: types.Domain{Main: "TEMP-" + domain}}
|
||||
return c.Store.AddTLSChallenge(domain, cert)
|
||||
}
|
||||
|
||||
func (c *challengeTLSALPN) CleanUp(domain, token, keyAuth string) error {
|
||||
log.Debugf("TLS Challenge CleanUp temp certificate for %s", domain)
|
||||
|
||||
return c.Store.RemoveTLSChallenge(domain)
|
||||
}
|
||||
|
||||
// GetTLSALPNCertificate Get the temp certificate for the ACME TLS-ALPN-01 challenge.
|
||||
func (p *Provider) GetTLSALPNCertificate(domain string) (*tls.Certificate, error) {
|
||||
cert, err := p.Store.GetTLSChallenge(domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cert == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
certificate, err := tls.X509KeyPair(cert.Certificate, cert.Key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &certificate, nil
|
||||
}
|
||||
251
old/provider/acme/local_store.go
Normal file
@@ -0,0 +1,251 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/safe"
|
||||
)
|
||||
|
||||
var _ Store = (*LocalStore)(nil)
|
||||
|
||||
// LocalStore Store implementation for local file
|
||||
type LocalStore struct {
|
||||
filename string
|
||||
storedData *StoredData
|
||||
SaveDataChan chan *StoredData `json:"-"`
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
// NewLocalStore initializes a new LocalStore with a file name
|
||||
func NewLocalStore(filename string) *LocalStore {
|
||||
store := &LocalStore{filename: filename, SaveDataChan: make(chan *StoredData)}
|
||||
store.listenSaveAction()
|
||||
return store
|
||||
}
|
||||
|
||||
func (s *LocalStore) get() (*StoredData, error) {
|
||||
if s.storedData == nil {
|
||||
s.storedData = &StoredData{
|
||||
HTTPChallenges: make(map[string]map[string][]byte),
|
||||
TLSChallenges: make(map[string]*Certificate),
|
||||
}
|
||||
|
||||
hasData, err := CheckFile(s.filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if hasData {
|
||||
f, err := os.Open(s.filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
file, err := ioutil.ReadAll(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(file) > 0 {
|
||||
if err := json.Unmarshal(file, s.storedData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Check if ACME Account is in ACME V1 format
|
||||
if s.storedData.Account != nil && s.storedData.Account.Registration != nil {
|
||||
isOldRegistration, err := regexp.MatchString(RegistrationURLPathV1Regexp, s.storedData.Account.Registration.URI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isOldRegistration {
|
||||
log.Debug("Reset ACME account.")
|
||||
s.storedData.Account = nil
|
||||
s.SaveDataChan <- s.storedData
|
||||
}
|
||||
}
|
||||
|
||||
// Delete all certificates with no value
|
||||
var certificates []*Certificate
|
||||
for _, certificate := range s.storedData.Certificates {
|
||||
if len(certificate.Certificate) == 0 || len(certificate.Key) == 0 {
|
||||
log.Debugf("Delete certificate %v for domains %v which have no value.", certificate, certificate.Domain.ToStrArray())
|
||||
continue
|
||||
}
|
||||
certificates = append(certificates, certificate)
|
||||
}
|
||||
|
||||
if len(certificates) < len(s.storedData.Certificates) {
|
||||
s.storedData.Certificates = certificates
|
||||
s.SaveDataChan <- s.storedData
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.storedData, nil
|
||||
}
|
||||
|
||||
// listenSaveAction listens to a chan to store ACME data in json format into LocalStore.filename
|
||||
func (s *LocalStore) listenSaveAction() {
|
||||
safe.Go(func() {
|
||||
for object := range s.SaveDataChan {
|
||||
data, err := json.MarshalIndent(object, "", " ")
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(s.filename, data, 0600)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// GetAccount returns ACME Account
|
||||
func (s *LocalStore) GetAccount() (*Account, error) {
|
||||
storedData, err := s.get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return storedData.Account, nil
|
||||
}
|
||||
|
||||
// SaveAccount stores ACME Account
|
||||
func (s *LocalStore) SaveAccount(account *Account) error {
|
||||
storedData, err := s.get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
storedData.Account = account
|
||||
s.SaveDataChan <- storedData
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCertificates returns ACME Certificates list
|
||||
func (s *LocalStore) GetCertificates() ([]*Certificate, error) {
|
||||
storedData, err := s.get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return storedData.Certificates, nil
|
||||
}
|
||||
|
||||
// SaveCertificates stores ACME Certificates list
|
||||
func (s *LocalStore) SaveCertificates(certificates []*Certificate) error {
|
||||
storedData, err := s.get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
storedData.Certificates = certificates
|
||||
s.SaveDataChan <- storedData
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHTTPChallengeToken Get the http challenge token from the store
|
||||
func (s *LocalStore) GetHTTPChallengeToken(token, domain string) ([]byte, error) {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
|
||||
if s.storedData.HTTPChallenges == nil {
|
||||
s.storedData.HTTPChallenges = map[string]map[string][]byte{}
|
||||
}
|
||||
|
||||
if _, ok := s.storedData.HTTPChallenges[token]; !ok {
|
||||
return nil, fmt.Errorf("cannot find challenge for token %v", token)
|
||||
}
|
||||
|
||||
result, ok := s.storedData.HTTPChallenges[token][domain]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("cannot find challenge for token %v", token)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetHTTPChallengeToken Set the http challenge token in the store
|
||||
func (s *LocalStore) SetHTTPChallengeToken(token, domain string, keyAuth []byte) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.storedData.HTTPChallenges == nil {
|
||||
s.storedData.HTTPChallenges = map[string]map[string][]byte{}
|
||||
}
|
||||
|
||||
if _, ok := s.storedData.HTTPChallenges[token]; !ok {
|
||||
s.storedData.HTTPChallenges[token] = map[string][]byte{}
|
||||
}
|
||||
|
||||
s.storedData.HTTPChallenges[token][domain] = keyAuth
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveHTTPChallengeToken removes the HTTP challenge token from the store
|
||||
func (s *LocalStore) RemoveHTTPChallengeToken(token, domain string) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.storedData.HTTPChallenges == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, ok := s.storedData.HTTPChallenges[token]; ok {
|
||||
if _, domainOk := s.storedData.HTTPChallenges[token][domain]; domainOk {
|
||||
delete(s.storedData.HTTPChallenges[token], domain)
|
||||
}
|
||||
if len(s.storedData.HTTPChallenges[token]) == 0 {
|
||||
delete(s.storedData.HTTPChallenges, token)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddTLSChallenge adds a certificate to the ACME TLS-ALPN-01 certificates storage
|
||||
func (s *LocalStore) AddTLSChallenge(domain string, cert *Certificate) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.storedData.TLSChallenges == nil {
|
||||
s.storedData.TLSChallenges = make(map[string]*Certificate)
|
||||
}
|
||||
|
||||
s.storedData.TLSChallenges[domain] = cert
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTLSChallenge gets a certificate from the ACME TLS-ALPN-01 certificates storage
|
||||
func (s *LocalStore) GetTLSChallenge(domain string) (*Certificate, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.storedData.TLSChallenges == nil {
|
||||
s.storedData.TLSChallenges = make(map[string]*Certificate)
|
||||
}
|
||||
|
||||
return s.storedData.TLSChallenges[domain], nil
|
||||
}
|
||||
|
||||
// RemoveTLSChallenge removes a certificate from the ACME TLS-ALPN-01 certificates storage
|
||||
func (s *LocalStore) RemoveTLSChallenge(domain string) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
if s.storedData.TLSChallenges == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(s.storedData.TLSChallenges, domain)
|
||||
return nil
|
||||
}
|
||||
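// Illustrative sketch, not part of this commit: LocalStore persists asynchronously.
// Callers mutate storedData under the lock and push it on SaveDataChan; listenSaveAction
// (above) marshals it to JSON and rewrites LocalStore.filename with 0600 permissions.
// Assuming a NewLocalStore constructor (name hypothetical here), a minimal usage could look like:
//
//	store := NewLocalStore("acme.json")
//	if err := store.SaveAccount(&Account{Email: "admin@example.com"}); err != nil {
//		log.Error(err)
//	}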
35
old/provider/acme/local_store_unix.go
Normal file
35
old/provider/acme/local_store_unix.go
Normal file
|
|
@@ -0,0 +1,35 @@
|
|||
// +build !windows
|
||||
|
||||
package acme
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// CheckFile checks file permissions and content size
|
||||
func CheckFile(name string) (bool, error) {
|
||||
f, err := os.Open(name)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
f, err = os.Create(name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, f.Chmod(0600)
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if fi.Mode().Perm()&0077 != 0 {
|
||||
return false, fmt.Errorf("permissions %o for %s are too open, please use 600", fi.Mode().Perm(), name)
|
||||
}
|
||||
|
||||
return fi.Size() > 0, nil
|
||||
}
|
||||
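// Illustrative sketch, not part of this commit: how a caller might use CheckFile
// before loading stored ACME data. "acme.json" is a placeholder path.
//
//	hasData, err := CheckFile("acme.json")
//	if err != nil {
//		return err
//	}
//	if !hasData {
//		// the file was just created or is empty: start from an empty StoredData
//	}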
27
old/provider/acme/local_store_windows.go
Normal file
27
old/provider/acme/local_store_windows.go
Normal file
|
|
@@ -0,0 +1,27 @@
|
|||
package acme
|
||||
|
||||
import "os"
|
||||
|
||||
// CheckFile checks file content size
|
||||
// Do not check file permissions on Windows right now
|
||||
func CheckFile(name string) (bool, error) {
|
||||
f, err := os.Open(name)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
f, err = os.Create(name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, f.Chmod(0600)
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return fi.Size() > 0, nil
|
||||
}
|
||||
826
old/provider/acme/provider.go
Normal file
826
old/provider/acme/provider.go
Normal file
|
|
@@ -0,0 +1,826 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
fmtlog "log"
|
||||
"net"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenk/backoff"
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/rules"
|
||||
"github.com/containous/traefik/safe"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/containous/traefik/version"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/xenolf/lego/acme"
|
||||
legolog "github.com/xenolf/lego/log"
|
||||
"github.com/xenolf/lego/providers/dns"
|
||||
)
|
||||
|
||||
var (
|
||||
// OSCPMustStaple enables OCSP stapling as from https://github.com/xenolf/lego/issues/270
|
||||
OSCPMustStaple = false
|
||||
)
|
||||
|
||||
// Configuration holds ACME configuration provided by users
|
||||
type Configuration struct {
|
||||
Email string `description:"Email address used for registration"`
|
||||
ACMELogging bool `description:"Enable debug logging of ACME actions."`
|
||||
CAServer string `description:"CA server to use."`
|
||||
Storage string `description:"Storage to use."`
|
||||
EntryPoint string `description:"EntryPoint to use."`
|
||||
KeyType string `description:"KeyType used for generating the certificate private key. Allowed values: 'EC256', 'EC384', 'RSA2048', 'RSA4096', 'RSA8192'. Defaults to 'RSA4096'"`
|
||||
OnHostRule bool `description:"Enable certificate generation on frontends Host rules."`
|
||||
OnDemand bool `description:"Enable on demand certificate generation. This will request a certificate from Let's Encrypt during the first TLS handshake for a hostname that does not yet have a certificate."` // Deprecated
|
||||
DNSChallenge *DNSChallenge `description:"Activate DNS-01 Challenge"`
|
||||
HTTPChallenge *HTTPChallenge `description:"Activate HTTP-01 Challenge"`
|
||||
TLSChallenge *TLSChallenge `description:"Activate TLS-ALPN-01 Challenge"`
|
||||
Domains []types.Domain `description:"CN and SANs (alternative domains) to each main domain using format: --acme.domains='main.com,san1.com,san2.com' --acme.domains='*.main.net'. No SANs for wildcards domain. Wildcard domains only accepted with DNSChallenge"`
|
||||
}
|
||||
|
||||
// Provider holds configurations of the provider.
|
||||
type Provider struct {
|
||||
*Configuration
|
||||
Store Store
|
||||
certificates []*Certificate
|
||||
account *Account
|
||||
client *acme.Client
|
||||
certsChan chan *Certificate
|
||||
configurationChan chan<- types.ConfigMessage
|
||||
certificateStore *traefiktls.CertificateStore
|
||||
clientMutex sync.Mutex
|
||||
configFromListenerChan chan types.Configuration
|
||||
pool *safe.Pool
|
||||
resolvingDomains map[string]struct{}
|
||||
resolvingDomainsMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// Certificate is a struct which contains all data needed from an ACME certificate
|
||||
type Certificate struct {
|
||||
Domain types.Domain
|
||||
Certificate []byte
|
||||
Key []byte
|
||||
}
|
||||
|
||||
// DNSChallenge contains DNS challenge Configuration
|
||||
type DNSChallenge struct {
|
||||
Provider string `description:"Use a DNS-01 based challenge provider rather than HTTPS."`
|
||||
DelayBeforeCheck parse.Duration `description:"Assume DNS propagates after a delay in seconds rather than finding and querying nameservers."`
|
||||
Resolvers types.DNSResolvers `description:"Use following DNS servers to resolve the FQDN authority."`
|
||||
DisablePropagationCheck bool `description:"Disable the DNS propagation checks before notifying ACME that the DNS challenge is ready. [not recommended]"`
|
||||
preCheckTimeout time.Duration
|
||||
preCheckInterval time.Duration
|
||||
}
|
||||
|
||||
// HTTPChallenge contains HTTP challenge Configuration
|
||||
type HTTPChallenge struct {
|
||||
EntryPoint string `description:"HTTP challenge EntryPoint"`
|
||||
}
|
||||
|
||||
// TLSChallenge contains TLS challenge Configuration
|
||||
type TLSChallenge struct{}
|
||||
|
||||
// SetConfigListenerChan initializes the configFromListenerChan
|
||||
func (p *Provider) SetConfigListenerChan(configFromListenerChan chan types.Configuration) {
|
||||
p.configFromListenerChan = configFromListenerChan
|
||||
}
|
||||
|
||||
// SetCertificateStore allows initializing the certificate store
|
||||
func (p *Provider) SetCertificateStore(certificateStore *traefiktls.CertificateStore) {
|
||||
p.certificateStore = certificateStore
|
||||
}
|
||||
|
||||
// ListenConfiguration sets a new Configuration into the configFromListenerChan
|
||||
func (p *Provider) ListenConfiguration(config types.Configuration) {
|
||||
p.configFromListenerChan <- config
|
||||
}
|
||||
|
||||
// ListenRequest resolves new certificates for a domain from an incoming request and returns a valid Certificate to serve (onDemand option)
|
||||
func (p *Provider) ListenRequest(domain string) (*tls.Certificate, error) {
|
||||
acmeCert, err := p.resolveCertificate(types.Domain{Main: domain}, false)
|
||||
if acmeCert == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certificate, err := tls.X509KeyPair(acmeCert.Certificate, acmeCert.PrivateKey)
|
||||
|
||||
return &certificate, err
|
||||
}
|
||||
|
||||
// Init initializes the ACME provider
|
||||
func (p *Provider) Init(_ types.Constraints) error {
|
||||
acme.UserAgent = fmt.Sprintf("containous-traefik/%s", version.Version)
|
||||
if p.ACMELogging {
|
||||
legolog.Logger = fmtlog.New(log.WriterLevel(logrus.InfoLevel), "legolog: ", 0)
|
||||
} else {
|
||||
legolog.Logger = fmtlog.New(ioutil.Discard, "", 0)
|
||||
}
|
||||
|
||||
if p.Store == nil {
|
||||
return errors.New("no store found for the ACME provider")
|
||||
}
|
||||
|
||||
var err error
|
||||
p.account, err = p.Store.GetAccount()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get ACME account : %v", err)
|
||||
}
|
||||
|
||||
// Reset the account if the CA server changed, so that the registration URI can be updated
|
||||
if p.account != nil && p.account.Registration != nil && !isAccountMatchingCaServer(p.account.Registration.URI, p.CAServer) {
|
||||
log.Info("Account URI does not match the current CAServer. The account will be reset")
|
||||
p.account = nil
|
||||
}
|
||||
|
||||
p.certificates, err = p.Store.GetCertificates()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get ACME certificates : %v", err)
|
||||
}
|
||||
|
||||
// Init the currently resolved domain map
|
||||
p.resolvingDomains = make(map[string]struct{})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isAccountMatchingCaServer(accountURI string, serverURI string) bool {
|
||||
aru, err := url.Parse(accountURI)
|
||||
if err != nil {
|
||||
log.Infof("Unable to parse account.Registration URL : %v", err)
|
||||
return false
|
||||
}
|
||||
cau, err := url.Parse(serverURI)
|
||||
if err != nil {
|
||||
log.Infof("Unable to parse CAServer URL : %v", err)
|
||||
return false
|
||||
}
|
||||
return cau.Hostname() == aru.Hostname()
|
||||
}
|
||||
|
||||
// Provide allows the ACME provider to provide configurations to traefik
|
||||
// using the given Configuration channel.
|
||||
func (p *Provider) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
|
||||
p.pool = pool
|
||||
|
||||
p.watchCertificate()
|
||||
p.watchNewDomains()
|
||||
|
||||
p.configurationChan = configurationChan
|
||||
p.refreshCertificates()
|
||||
|
||||
p.deleteUnnecessaryDomains()
|
||||
for i := 0; i < len(p.Domains); i++ {
|
||||
domain := p.Domains[i]
|
||||
safe.Go(func() {
|
||||
if _, err := p.resolveCertificate(domain, true); err != nil {
|
||||
log.Errorf("Unable to obtain ACME certificate for domains %q : %v", strings.Join(domain.ToStrArray(), ","), err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
p.renewCertificates()
|
||||
|
||||
ticker := time.NewTicker(24 * time.Hour)
|
||||
pool.Go(func(stop chan bool) {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
p.renewCertificates()
|
||||
case <-stop:
|
||||
ticker.Stop()
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) getClient() (*acme.Client, error) {
|
||||
p.clientMutex.Lock()
|
||||
defer p.clientMutex.Unlock()
|
||||
|
||||
if p.client != nil {
|
||||
return p.client, nil
|
||||
}
|
||||
|
||||
account, err := p.initAccount()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Debug("Building ACME client...")
|
||||
|
||||
caServer := "https://acme-v02.api.letsencrypt.org/directory"
|
||||
if len(p.CAServer) > 0 {
|
||||
caServer = p.CAServer
|
||||
}
|
||||
log.Debug(caServer)
|
||||
|
||||
client, err := acme.NewClient(caServer, account, account.KeyType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// New users will need to register; be sure to save it
|
||||
if account.GetRegistration() == nil {
|
||||
log.Info("Register...")
|
||||
|
||||
reg, err := client.Register(true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account.Registration = reg
|
||||
}
|
||||
|
||||
// Save the account once before any certificate is generated or stored
|
||||
// No certificate can be generated if the account is not initialized
|
||||
err = p.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.DNSChallenge != nil && len(p.DNSChallenge.Provider) > 0 {
|
||||
log.Debugf("Using DNS Challenge provider: %s", p.DNSChallenge.Provider)
|
||||
|
||||
SetRecursiveNameServers(p.DNSChallenge.Resolvers)
|
||||
SetPropagationCheck(p.DNSChallenge.DisablePropagationCheck)
|
||||
|
||||
err = dnsOverrideDelay(p.DNSChallenge.DelayBeforeCheck)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var provider acme.ChallengeProvider
|
||||
provider, err = dns.NewDNSChallengeProviderByName(p.DNSChallenge.Provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSALPN01})
|
||||
|
||||
err = client.SetChallengeProvider(acme.DNS01, provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Same default values as LEGO
|
||||
p.DNSChallenge.preCheckTimeout = 60 * time.Second
|
||||
p.DNSChallenge.preCheckInterval = 2 * time.Second
|
||||
|
||||
// Set the precheck timeout into the DNSChallenge provider
|
||||
if challengeProviderTimeout, ok := provider.(acme.ChallengeProviderTimeout); ok {
|
||||
p.DNSChallenge.preCheckTimeout, p.DNSChallenge.preCheckInterval = challengeProviderTimeout.Timeout()
|
||||
}
|
||||
|
||||
} else if p.HTTPChallenge != nil && len(p.HTTPChallenge.EntryPoint) > 0 {
|
||||
log.Debug("Using HTTP Challenge provider.")
|
||||
|
||||
client.ExcludeChallenges([]acme.Challenge{acme.DNS01, acme.TLSALPN01})
|
||||
|
||||
err = client.SetChallengeProvider(acme.HTTP01, &challengeHTTP{Store: p.Store})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if p.TLSChallenge != nil {
|
||||
log.Debug("Using TLS Challenge provider.")
|
||||
|
||||
client.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.DNS01})
|
||||
|
||||
err = client.SetChallengeProvider(acme.TLSALPN01, &challengeTLSALPN{Store: p.Store})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("ACME challenge not specified, please select TLS or HTTP or DNS Challenge")
|
||||
}
|
||||
|
||||
p.client = client
|
||||
return p.client, nil
|
||||
}
|
||||
|
||||
func (p *Provider) initAccount() (*Account, error) {
|
||||
if p.account == nil || len(p.account.Email) == 0 {
|
||||
var err error
|
||||
p.account, err = NewAccount(p.Email, p.KeyType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Set the KeyType if not already defined in the account
|
||||
if len(p.account.KeyType) == 0 {
|
||||
p.account.KeyType = GetKeyType(p.KeyType)
|
||||
}
|
||||
|
||||
return p.account, nil
|
||||
}
|
||||
|
||||
func contains(entryPoints []string, acmeEntryPoint string) bool {
|
||||
for _, entryPoint := range entryPoints {
|
||||
if entryPoint == acmeEntryPoint {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Provider) watchNewDomains() {
|
||||
p.pool.Go(func(stop chan bool) {
|
||||
for {
|
||||
select {
|
||||
case config := <-p.configFromListenerChan:
|
||||
for _, frontend := range config.Frontends {
|
||||
if !contains(frontend.EntryPoints, p.EntryPoint) {
|
||||
continue
|
||||
}
|
||||
for _, route := range frontend.Routes {
|
||||
domainRules := rules.Rules{}
|
||||
domains, err := domainRules.ParseDomains(route.Rule)
|
||||
if err != nil {
|
||||
log.Errorf("Error parsing domains in provider ACME: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(domains) == 0 {
|
||||
log.Debugf("No domain parsed in rule %q in provider ACME", route.Rule)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("Try to challenge certificate for domain %v founded in Host rule", domains)
|
||||
|
||||
var domain types.Domain
|
||||
if len(domains) > 0 {
|
||||
domain = types.Domain{Main: domains[0]}
|
||||
if len(domains) > 1 {
|
||||
domain.SANs = domains[1:]
|
||||
}
|
||||
|
||||
safe.Go(func() {
|
||||
if _, err := p.resolveCertificate(domain, false); err != nil {
|
||||
log.Errorf("Unable to obtain ACME certificate for domains %q detected thanks to rule %q : %v", strings.Join(domains, ","), route.Rule, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Provider) resolveCertificate(domain types.Domain, domainFromConfigurationFile bool) (*acme.CertificateResource, error) {
|
||||
domains, err := p.getValidDomains(domain, domainFromConfigurationFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check provided certificates
|
||||
uncheckedDomains := p.getUncheckedDomains(domains, !domainFromConfigurationFile)
|
||||
if len(uncheckedDomains) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
p.addResolvingDomains(uncheckedDomains)
|
||||
defer p.removeResolvingDomains(uncheckedDomains)
|
||||
|
||||
log.Debugf("Loading ACME certificates %+v...", uncheckedDomains)
|
||||
|
||||
client, err := p.getClient()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot get ACME client %v", err)
|
||||
}
|
||||
|
||||
var certificate *acme.CertificateResource
|
||||
bundle := true
|
||||
if p.useCertificateWithRetry(uncheckedDomains) {
|
||||
certificate, err = obtainCertificateWithRetry(domains, client, p.DNSChallenge.preCheckTimeout, p.DNSChallenge.preCheckInterval, bundle)
|
||||
} else {
|
||||
certificate, err = client.ObtainCertificate(domains, bundle, nil, OSCPMustStaple)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to generate a certificate for the domains %v: %v", uncheckedDomains, err)
|
||||
}
|
||||
if certificate == nil {
|
||||
return nil, fmt.Errorf("domains %v do not generate a certificate", uncheckedDomains)
|
||||
}
|
||||
if len(certificate.Certificate) == 0 || len(certificate.PrivateKey) == 0 {
|
||||
return nil, fmt.Errorf("domains %v generate certificate with no value: %v", uncheckedDomains, certificate)
|
||||
}
|
||||
|
||||
log.Debugf("Certificates obtained for domains %+v", uncheckedDomains)
|
||||
|
||||
if len(uncheckedDomains) > 1 {
|
||||
domain = types.Domain{Main: uncheckedDomains[0], SANs: uncheckedDomains[1:]}
|
||||
} else {
|
||||
domain = types.Domain{Main: uncheckedDomains[0]}
|
||||
}
|
||||
p.addCertificateForDomain(domain, certificate.Certificate, certificate.PrivateKey)
|
||||
|
||||
return certificate, nil
|
||||
}
|
||||
|
||||
func (p *Provider) removeResolvingDomains(resolvingDomains []string) {
|
||||
p.resolvingDomainsMutex.Lock()
|
||||
defer p.resolvingDomainsMutex.Unlock()
|
||||
|
||||
for _, domain := range resolvingDomains {
|
||||
delete(p.resolvingDomains, domain)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) addResolvingDomains(resolvingDomains []string) {
|
||||
p.resolvingDomainsMutex.Lock()
|
||||
defer p.resolvingDomainsMutex.Unlock()
|
||||
|
||||
for _, domain := range resolvingDomains {
|
||||
p.resolvingDomains[domain] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) useCertificateWithRetry(domains []string) bool {
|
||||
// The retry mechanism can only be used with the DNS challenge and when there are at least 2 domains to check
|
||||
if p.DNSChallenge != nil && len(domains) > 1 {
|
||||
rootDomain := ""
|
||||
for _, searchWildcardDomain := range domains {
|
||||
// Search a wildcard domain if not already found
|
||||
if len(rootDomain) == 0 && strings.HasPrefix(searchWildcardDomain, "*.") {
|
||||
rootDomain = strings.TrimPrefix(searchWildcardDomain, "*.")
|
||||
if len(rootDomain) > 0 {
|
||||
// Look for a root domain which matches the wildcard domain
|
||||
for _, searchRootDomain := range domains {
|
||||
if rootDomain == searchRootDomain {
|
||||
// If the domains list contains a wildcard domain and its root domain, we can use the retry mechanism to obtain the certificate
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
// There is only one wildcard domain in the slice; if its root domain has not been found, the retry mechanism does not have to be used
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func obtainCertificateWithRetry(domains []string, client *acme.Client, timeout, interval time.Duration, bundle bool) (*acme.CertificateResource, error) {
|
||||
var certificate *acme.CertificateResource
|
||||
var err error
|
||||
|
||||
operation := func() error {
|
||||
certificate, err = client.ObtainCertificate(domains, bundle, nil, OSCPMustStaple)
|
||||
return err
|
||||
}
|
||||
|
||||
notify := func(err error, time time.Duration) {
|
||||
log.Errorf("Error obtaining certificate retrying in %s", time)
|
||||
}
|
||||
|
||||
// Define a retry backOff to let LEGO try twice to obtain a certificate for both the wildcard and the root domain
|
||||
ebo := backoff.NewExponentialBackOff()
|
||||
ebo.MaxElapsedTime = 2 * timeout
|
||||
ebo.MaxInterval = interval
|
||||
rbo := backoff.WithMaxRetries(ebo, 2)
|
||||
|
||||
err = backoff.RetryNotify(safe.OperationWithRecover(operation), rbo, notify)
|
||||
if err != nil {
|
||||
log.Errorf("Error obtaining certificate: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return certificate, nil
|
||||
}
|
||||
|
||||
func dnsOverrideDelay(delay parse.Duration) error {
|
||||
if delay == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if delay > 0 {
|
||||
log.Debugf("Delaying %d rather than validating DNS propagation now.", delay)
|
||||
|
||||
acme.PreCheckDNS = func(_, _ string) (bool, error) {
|
||||
time.Sleep(time.Duration(delay))
|
||||
return true, nil
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("delayBeforeCheck: %d cannot be less than 0", delay)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) addCertificateForDomain(domain types.Domain, certificate []byte, key []byte) {
|
||||
p.certsChan <- &Certificate{Certificate: certificate, Key: key, Domain: domain}
|
||||
}
|
||||
|
||||
// deleteUnnecessaryDomains deletes from the configuration:
|
||||
// - Duplicated domains
|
||||
// - Domains which are covered by a wildcard domain
|
||||
func (p *Provider) deleteUnnecessaryDomains() {
|
||||
var newDomains []types.Domain
|
||||
|
||||
for idxDomainToCheck, domainToCheck := range p.Domains {
|
||||
keepDomain := true
|
||||
|
||||
for idxDomain, domain := range p.Domains {
|
||||
if idxDomainToCheck == idxDomain {
|
||||
continue
|
||||
}
|
||||
|
||||
if reflect.DeepEqual(domain, domainToCheck) {
|
||||
if idxDomainToCheck > idxDomain {
|
||||
log.Warnf("The domain %v is duplicated in the configuration but will be process by ACME provider only once.", domainToCheck)
|
||||
keepDomain = false
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Check if the CN or SANs to check already exist
|
||||
// or cannot be checked by a wildcard
|
||||
var newDomainsToCheck []string
|
||||
for _, domainProcessed := range domainToCheck.ToStrArray() {
|
||||
if idxDomain < idxDomainToCheck && isDomainAlreadyChecked(domainProcessed, domain.ToStrArray()) {
|
||||
// The domain is duplicated in a CN
|
||||
log.Warnf("Domain %q is duplicated in the configuration or validated by the domain %v. It will be processed once.", domainProcessed, domain)
|
||||
continue
|
||||
} else if domain.Main != domainProcessed && strings.HasPrefix(domain.Main, "*") && isDomainAlreadyChecked(domainProcessed, []string{domain.Main}) {
|
||||
// Check if a wildcard can validate the domain
|
||||
log.Warnf("Domain %q will not be processed by ACME provider because it is validated by the wildcard %q", domainProcessed, domain.Main)
|
||||
continue
|
||||
}
|
||||
newDomainsToCheck = append(newDomainsToCheck, domainProcessed)
|
||||
}
|
||||
|
||||
// Delete the domain if both Main and SANs can be validated by the wildcard domain
|
||||
// otherwise keep the unchecked values
|
||||
if newDomainsToCheck == nil {
|
||||
keepDomain = false
|
||||
break
|
||||
}
|
||||
domainToCheck.Set(newDomainsToCheck)
|
||||
}
|
||||
|
||||
if keepDomain {
|
||||
newDomains = append(newDomains, domainToCheck)
|
||||
}
|
||||
}
|
||||
|
||||
p.Domains = newDomains
|
||||
}
|
||||
|
||||
func (p *Provider) watchCertificate() {
|
||||
p.certsChan = make(chan *Certificate)
|
||||
p.pool.Go(func(stop chan bool) {
|
||||
for {
|
||||
select {
|
||||
case cert := <-p.certsChan:
|
||||
certUpdated := false
|
||||
for _, domainsCertificate := range p.certificates {
|
||||
if reflect.DeepEqual(cert.Domain, domainsCertificate.Domain) {
|
||||
domainsCertificate.Certificate = cert.Certificate
|
||||
domainsCertificate.Key = cert.Key
|
||||
certUpdated = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !certUpdated {
|
||||
p.certificates = append(p.certificates, cert)
|
||||
}
|
||||
|
||||
err := p.saveCertificates()
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Provider) saveCertificates() error {
|
||||
err := p.Store.SaveCertificates(p.certificates)
|
||||
|
||||
p.refreshCertificates()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *Provider) refreshCertificates() {
|
||||
config := types.ConfigMessage{
|
||||
ProviderName: "ACME",
|
||||
Configuration: &types.Configuration{
|
||||
Backends: map[string]*types.Backend{},
|
||||
Frontends: map[string]*types.Frontend{},
|
||||
TLS: []*traefiktls.Configuration{},
|
||||
},
|
||||
}
|
||||
|
||||
for _, cert := range p.certificates {
|
||||
certificate := &traefiktls.Certificate{CertFile: traefiktls.FileOrContent(cert.Certificate), KeyFile: traefiktls.FileOrContent(cert.Key)}
|
||||
config.Configuration.TLS = append(config.Configuration.TLS, &traefiktls.Configuration{Certificate: certificate, EntryPoints: []string{p.EntryPoint}})
|
||||
}
|
||||
p.configurationChan <- config
|
||||
}
|
||||
|
||||
func (p *Provider) renewCertificates() {
|
||||
log.Info("Testing certificate renew...")
|
||||
for _, certificate := range p.certificates {
|
||||
crt, err := getX509Certificate(certificate)
|
||||
// If there's an error, we assume the cert is broken and needs to be updated
|
||||
// <= 30 days left, renew certificate
|
||||
if err != nil || crt == nil || crt.NotAfter.Before(time.Now().Add(24*30*time.Hour)) {
|
||||
client, err := p.getClient()
|
||||
if err != nil {
|
||||
log.Infof("Error renewing certificate from LE : %+v, %v", certificate.Domain, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Infof("Renewing certificate from LE : %+v", certificate.Domain)
|
||||
|
||||
renewedCert, err := client.RenewCertificate(acme.CertificateResource{
|
||||
Domain: certificate.Domain.Main,
|
||||
PrivateKey: certificate.Key,
|
||||
Certificate: certificate.Certificate,
|
||||
}, true, OSCPMustStaple)
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("Error renewing certificate from LE: %v, %v", certificate.Domain, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(renewedCert.Certificate) == 0 || len(renewedCert.PrivateKey) == 0 {
|
||||
log.Errorf("domains %v renew certificate with no value: %v", certificate.Domain.ToStrArray(), certificate)
|
||||
continue
|
||||
}
|
||||
|
||||
p.addCertificateForDomain(certificate.Domain, renewedCert.Certificate, renewedCert.PrivateKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getUncheckedDomains checks a domain list (Main and SANs) against the provided certificates
|
||||
// (static and dynamic) and returns the domains that still need a certificate
|
||||
func (p *Provider) getUncheckedDomains(domainsToCheck []string, checkConfigurationDomains bool) []string {
|
||||
p.resolvingDomainsMutex.RLock()
|
||||
defer p.resolvingDomainsMutex.RUnlock()
|
||||
|
||||
log.Debugf("Looking for provided certificate(s) to validate %q...", domainsToCheck)
|
||||
|
||||
allDomains := p.certificateStore.GetAllDomains()
|
||||
|
||||
// Get ACME certificates
|
||||
for _, certificate := range p.certificates {
|
||||
allDomains = append(allDomains, strings.Join(certificate.Domain.ToStrArray(), ","))
|
||||
}
|
||||
|
||||
// Get currently resolved domains
|
||||
for domain := range p.resolvingDomains {
|
||||
allDomains = append(allDomains, domain)
|
||||
}
|
||||
|
||||
// Get Configuration Domains
|
||||
if checkConfigurationDomains {
|
||||
for i := 0; i < len(p.Domains); i++ {
|
||||
allDomains = append(allDomains, strings.Join(p.Domains[i].ToStrArray(), ","))
|
||||
}
|
||||
}
|
||||
|
||||
return searchUncheckedDomains(domainsToCheck, allDomains)
|
||||
}
|
||||
|
||||
func searchUncheckedDomains(domainsToCheck []string, existentDomains []string) []string {
|
||||
var uncheckedDomains []string
|
||||
for _, domainToCheck := range domainsToCheck {
|
||||
if !isDomainAlreadyChecked(domainToCheck, existentDomains) {
|
||||
uncheckedDomains = append(uncheckedDomains, domainToCheck)
|
||||
}
|
||||
}
|
||||
|
||||
if len(uncheckedDomains) == 0 {
|
||||
log.Debugf("No ACME certificate generation required for domains %q.", domainsToCheck)
|
||||
} else {
|
||||
log.Debugf("Domains %q need ACME certificates generation for domains %q.", domainsToCheck, strings.Join(uncheckedDomains, ","))
|
||||
}
|
||||
return uncheckedDomains
|
||||
}
|
||||
|
||||
func getX509Certificate(certificate *Certificate) (*x509.Certificate, error) {
|
||||
tlsCert, err := tls.X509KeyPair(certificate.Certificate, certificate.Key)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to load TLS keypair from ACME certificate for domain %q (SAN : %q), certificate will be renewed : %v", certificate.Domain.Main, strings.Join(certificate.Domain.SANs, ","), err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
crt := tlsCert.Leaf
|
||||
if crt == nil {
|
||||
crt, err = x509.ParseCertificate(tlsCert.Certificate[0])
|
||||
if err != nil {
|
||||
log.Errorf("Failed to parse TLS keypair from ACME certificate for domain %q (SAN : %q), certificate will be renewed : %v", certificate.Domain.Main, strings.Join(certificate.Domain.SANs, ","), err)
|
||||
}
|
||||
}
|
||||
|
||||
return crt, err
|
||||
}
|
||||
|
||||
// getValidDomains checks if the given domain is allowed to generate an ACME certificate and returns it
|
||||
func (p *Provider) getValidDomains(domain types.Domain, wildcardAllowed bool) ([]string, error) {
|
||||
domains := domain.ToStrArray()
|
||||
if len(domains) == 0 {
|
||||
return nil, errors.New("unable to generate a certificate in ACME provider when no domain is given")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(domain.Main, "*") {
|
||||
if !wildcardAllowed {
|
||||
return nil, fmt.Errorf("unable to generate a wildcard certificate in ACME provider for domain %q from a 'Host' rule", strings.Join(domains, ","))
|
||||
}
|
||||
|
||||
if p.DNSChallenge == nil {
|
||||
return nil, fmt.Errorf("unable to generate a wildcard certificate in ACME provider for domain %q : ACME needs a DNSChallenge", strings.Join(domains, ","))
|
||||
}
|
||||
|
||||
if strings.HasPrefix(domain.Main, "*.*") {
|
||||
return nil, fmt.Errorf("unable to generate a wildcard certificate in ACME provider for domain %q : ACME does not allow '*.*' wildcard domain", strings.Join(domains, ","))
|
||||
}
|
||||
}
|
||||
|
||||
for _, san := range domain.SANs {
|
||||
if strings.HasPrefix(san, "*") {
|
||||
return nil, fmt.Errorf("unable to generate a certificate in ACME provider for domains %q: SAN %q can not be a wildcard domain", strings.Join(domains, ","), san)
|
||||
}
|
||||
}
|
||||
|
||||
var cleanDomains []string
|
||||
for _, domain := range domains {
|
||||
canonicalDomain := types.CanonicalDomain(domain)
|
||||
cleanDomain := acme.UnFqdn(canonicalDomain)
|
||||
if canonicalDomain != cleanDomain {
|
||||
log.Warnf("FQDN detected, please remove the trailing dot: %s", canonicalDomain)
|
||||
}
|
||||
cleanDomains = append(cleanDomains, cleanDomain)
|
||||
}
|
||||
|
||||
return cleanDomains, nil
|
||||
}
|
||||
|
||||
func isDomainAlreadyChecked(domainToCheck string, existentDomains []string) bool {
|
||||
for _, certDomains := range existentDomains {
|
||||
for _, certDomain := range strings.Split(certDomains, ",") {
|
||||
if types.MatchDomain(domainToCheck, certDomain) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetPropagationCheck to disable the Lego PreCheck.
|
||||
func SetPropagationCheck(disable bool) {
|
||||
if disable {
|
||||
acme.PreCheckDNS = func(_, _ string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetRecursiveNameServers to provide a custom DNS resolver.
|
||||
func SetRecursiveNameServers(dnsResolvers []string) {
|
||||
resolvers := normaliseDNSResolvers(dnsResolvers)
|
||||
if len(resolvers) > 0 {
|
||||
acme.RecursiveNameservers = resolvers
|
||||
log.Infof("Validating FQDN authority with DNS using %+v", resolvers)
|
||||
}
|
||||
}
|
||||
|
||||
// ensure all servers have a port number
|
||||
func normaliseDNSResolvers(dnsResolvers []string) []string {
|
||||
var normalisedResolvers []string
|
||||
for _, server := range dnsResolvers {
|
||||
srv := strings.TrimSpace(server)
|
||||
if len(srv) > 0 {
|
||||
if host, port, err := net.SplitHostPort(srv); err != nil {
|
||||
normalisedResolvers = append(normalisedResolvers, net.JoinHostPort(srv, "53"))
|
||||
} else {
|
||||
normalisedResolvers = append(normalisedResolvers, net.JoinHostPort(host, port))
|
||||
}
|
||||
}
|
||||
}
|
||||
return normalisedResolvers
|
||||
}
|
||||
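// Illustrative sketch, not part of this commit: normaliseDNSResolvers appends the
// default DNS port when none is given, so both forms below end up as host:port
// entries in acme.RecursiveNameservers.
//
//	SetRecursiveNameServers([]string{"8.8.8.8", "1.1.1.1:5353"})
//	// acme.RecursiveNameservers == []string{"8.8.8.8:53", "1.1.1.1:5353"}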
684
old/provider/acme/provider_test.go
Normal file
684
old/provider/acme/provider_test.go
Normal file
|
|
@@ -0,0 +1,684 @@
|
|||
package acme
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/safe"
|
||||
traefiktls "github.com/containous/traefik/tls"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
func TestGetUncheckedCertificates(t *testing.T) {
|
||||
wildcardMap := make(map[string]*tls.Certificate)
|
||||
wildcardMap["*.traefik.wtf"] = &tls.Certificate{}
|
||||
|
||||
wildcardSafe := &safe.Safe{}
|
||||
wildcardSafe.Set(wildcardMap)
|
||||
|
||||
domainMap := make(map[string]*tls.Certificate)
|
||||
domainMap["traefik.wtf"] = &tls.Certificate{}
|
||||
|
||||
domainSafe := &safe.Safe{}
|
||||
domainSafe.Set(domainMap)
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
dynamicCerts *safe.Safe
|
||||
staticCerts *safe.Safe
|
||||
resolvingDomains map[string]struct{}
|
||||
acmeCertificates []*Certificate
|
||||
domains []string
|
||||
expectedDomains []string
|
||||
}{
|
||||
{
|
||||
desc: "wildcard to generate",
|
||||
domains: []string{"*.traefik.wtf"},
|
||||
expectedDomains: []string{"*.traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "wildcard already exists in dynamic certificates",
|
||||
domains: []string{"*.traefik.wtf"},
|
||||
dynamicCerts: wildcardSafe,
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "wildcard already exists in static certificates",
|
||||
domains: []string{"*.traefik.wtf"},
|
||||
staticCerts: wildcardSafe,
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "wildcard already exists in ACME certificates",
|
||||
domains: []string{"*.traefik.wtf"},
|
||||
acmeCertificates: []*Certificate{
|
||||
{
|
||||
Domain: types.Domain{Main: "*.traefik.wtf"},
|
||||
},
|
||||
},
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "domain CN and SANs to generate",
|
||||
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
expectedDomains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "domain CN already exists in dynamic certificates and SANs to generate",
|
||||
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
dynamicCerts: domainSafe,
|
||||
expectedDomains: []string{"foo.traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "domain CN already exists in static certificates and SANs to generate",
|
||||
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
staticCerts: domainSafe,
|
||||
expectedDomains: []string{"foo.traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "domain CN already exists in ACME certificates and SANs to generate",
|
||||
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
acmeCertificates: []*Certificate{
|
||||
{
|
||||
Domain: types.Domain{Main: "traefik.wtf"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{"foo.traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "domain already exists in dynamic certificates",
|
||||
domains: []string{"traefik.wtf"},
|
||||
dynamicCerts: domainSafe,
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "domain already exists in static certificates",
|
||||
domains: []string{"traefik.wtf"},
|
||||
staticCerts: domainSafe,
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "domain already exists in ACME certificates",
|
||||
domains: []string{"traefik.wtf"},
|
||||
acmeCertificates: []*Certificate{
|
||||
{
|
||||
Domain: types.Domain{Main: "traefik.wtf"},
|
||||
},
|
||||
},
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "domain matched by wildcard in dynamic certificates",
|
||||
domains: []string{"who.traefik.wtf", "foo.traefik.wtf"},
|
||||
dynamicCerts: wildcardSafe,
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "domain matched by wildcard in static certificates",
|
||||
domains: []string{"who.traefik.wtf", "foo.traefik.wtf"},
|
||||
staticCerts: wildcardSafe,
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "domain matched by wildcard in ACME certificates",
|
||||
domains: []string{"who.traefik.wtf", "foo.traefik.wtf"},
|
||||
acmeCertificates: []*Certificate{
|
||||
{
|
||||
Domain: types.Domain{Main: "*.traefik.wtf"},
|
||||
},
|
||||
},
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "root domain with wildcard in ACME certificates",
|
||||
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
acmeCertificates: []*Certificate{
|
||||
{
|
||||
Domain: types.Domain{Main: "*.traefik.wtf"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []string{"traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "all domains already managed by ACME",
|
||||
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
resolvingDomains: map[string]struct{}{
|
||||
"traefik.wtf": {},
|
||||
"foo.traefik.wtf": {},
|
||||
},
|
||||
expectedDomains: []string{},
|
||||
},
|
||||
{
|
||||
desc: "one domain already managed by ACME",
|
||||
domains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
resolvingDomains: map[string]struct{}{
|
||||
"traefik.wtf": {},
|
||||
},
|
||||
expectedDomains: []string{"foo.traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "wildcard domain already managed by ACME checks the domains",
|
||||
domains: []string{"bar.traefik.wtf", "foo.traefik.wtf"},
|
||||
resolvingDomains: map[string]struct{}{
|
||||
"*.traefik.wtf": {},
|
||||
},
|
||||
expectedDomains: []string{},
|
||||
},
|
||||
{
|
||||
desc: "wildcard domain already managed by ACME checks domains and another domain checks one other domain, one domain still unchecked",
|
||||
domains: []string{"traefik.wtf", "bar.traefik.wtf", "foo.traefik.wtf", "acme.wtf"},
|
||||
resolvingDomains: map[string]struct{}{
|
||||
"*.traefik.wtf": {},
|
||||
"traefik.wtf": {},
|
||||
},
|
||||
expectedDomains: []string{"acme.wtf"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if test.resolvingDomains == nil {
|
||||
test.resolvingDomains = make(map[string]struct{})
|
||||
}
|
||||
|
||||
acmeProvider := Provider{
|
||||
certificateStore: &traefiktls.CertificateStore{
|
||||
DynamicCerts: test.dynamicCerts,
|
||||
StaticCerts: test.staticCerts,
|
||||
},
|
||||
certificates: test.acmeCertificates,
|
||||
resolvingDomains: test.resolvingDomains,
|
||||
}
|
||||
|
||||
domains := acmeProvider.getUncheckedDomains(test.domains, false)
|
||||
assert.Equal(t, len(test.expectedDomains), len(domains), "Unexpected domains.")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetValidDomain(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
domains types.Domain
|
||||
wildcardAllowed bool
|
||||
dnsChallenge *DNSChallenge
|
||||
expectedErr string
|
||||
expectedDomains []string
|
||||
}{
|
||||
{
|
||||
desc: "valid wildcard",
|
||||
domains: types.Domain{Main: "*.traefik.wtf"},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
wildcardAllowed: true,
|
||||
expectedErr: "",
|
||||
expectedDomains: []string{"*.traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "no wildcard",
|
||||
domains: types.Domain{Main: "traefik.wtf", SANs: []string{"foo.traefik.wtf"}},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
expectedErr: "",
|
||||
wildcardAllowed: true,
|
||||
expectedDomains: []string{"traefik.wtf", "foo.traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "unauthorized wildcard",
|
||||
domains: types.Domain{Main: "*.traefik.wtf"},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
wildcardAllowed: false,
|
||||
expectedErr: "unable to generate a wildcard certificate in ACME provider for domain \"*.traefik.wtf\" from a 'Host' rule",
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "no domain",
|
||||
domains: types.Domain{},
|
||||
dnsChallenge: nil,
|
||||
wildcardAllowed: true,
|
||||
expectedErr: "unable to generate a certificate in ACME provider when no domain is given",
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "no DNSChallenge",
|
||||
domains: types.Domain{Main: "*.traefik.wtf", SANs: []string{"foo.traefik.wtf"}},
|
||||
dnsChallenge: nil,
|
||||
wildcardAllowed: true,
|
||||
expectedErr: "unable to generate a wildcard certificate in ACME provider for domain \"*.traefik.wtf,foo.traefik.wtf\" : ACME needs a DNSChallenge",
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "unauthorized wildcard with SAN",
|
||||
domains: types.Domain{Main: "*.*.traefik.wtf", SANs: []string{"foo.traefik.wtf"}},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
wildcardAllowed: true,
|
||||
expectedErr: "unable to generate a wildcard certificate in ACME provider for domain \"*.*.traefik.wtf,foo.traefik.wtf\" : ACME does not allow '*.*' wildcard domain",
|
||||
expectedDomains: nil,
|
||||
},
|
||||
{
|
||||
desc: "wildcard and SANs",
|
||||
domains: types.Domain{Main: "*.traefik.wtf", SANs: []string{"traefik.wtf"}},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
wildcardAllowed: true,
|
||||
expectedErr: "",
|
||||
expectedDomains: []string{"*.traefik.wtf", "traefik.wtf"},
|
||||
},
|
||||
{
|
||||
desc: "unexpected SANs",
|
||||
domains: types.Domain{Main: "*.traefik.wtf", SANs: []string{"*.acme.wtf"}},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
wildcardAllowed: true,
|
||||
expectedErr: "unable to generate a certificate in ACME provider for domains \"*.traefik.wtf,*.acme.wtf\": SAN \"*.acme.wtf\" can not be a wildcard domain",
|
||||
expectedDomains: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
acmeProvider := Provider{Configuration: &Configuration{DNSChallenge: test.dnsChallenge}}
|
||||
|
||||
domains, err := acmeProvider.getValidDomains(test.domains, test.wildcardAllowed)
|
||||
|
||||
if len(test.expectedErr) > 0 {
|
||||
assert.EqualError(t, err, test.expectedErr, "Unexpected error.")
|
||||
} else {
|
||||
assert.Equal(t, len(test.expectedDomains), len(domains), "Unexpected domains.")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteUnnecessaryDomains(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
domains []types.Domain
|
||||
expectedDomains []types.Domain
|
||||
}{
|
||||
{
|
||||
desc: "no domain to delete",
|
||||
domains: []types.Domain{
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf", "foo.bar"},
|
||||
},
|
||||
{
|
||||
Main: "*.foo.acme.wtf",
|
||||
},
|
||||
{
|
||||
Main: "acme02.wtf",
|
||||
SANs: []string{"traefik.acme02.wtf", "bar.foo"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []types.Domain{
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf", "foo.bar"},
|
||||
},
|
||||
{
|
||||
Main: "*.foo.acme.wtf",
|
||||
SANs: []string{},
|
||||
},
|
||||
{
|
||||
Main: "acme02.wtf",
|
||||
SANs: []string{"traefik.acme02.wtf", "bar.foo"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "wildcard and root domain",
|
||||
domains: []types.Domain{
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
},
|
||||
{
|
||||
Main: "*.acme.wtf",
|
||||
SANs: []string{"acme.wtf"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []types.Domain{
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{},
|
||||
},
|
||||
{
|
||||
Main: "*.acme.wtf",
|
||||
SANs: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "2 equals domains",
|
||||
domains: []types.Domain{
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf", "foo.bar"},
|
||||
},
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf", "foo.bar"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []types.Domain{
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf", "foo.bar"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "2 domains with same values",
|
||||
domains: []types.Domain{
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf"},
|
||||
},
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf", "foo.bar"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []types.Domain{
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf"},
|
||||
},
|
||||
{
|
||||
Main: "foo.bar",
|
||||
SANs: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "domain totally checked by wildcard",
|
||||
domains: []types.Domain{
|
||||
{
|
||||
Main: "who.acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf", "bar.acme.wtf"},
|
||||
},
|
||||
{
|
||||
Main: "*.acme.wtf",
|
||||
},
|
||||
},
|
||||
expectedDomains: []types.Domain{
|
||||
{
|
||||
Main: "*.acme.wtf",
|
||||
SANs: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "duplicated wildcard",
|
||||
domains: []types.Domain{
|
||||
{
|
||||
Main: "*.acme.wtf",
|
||||
SANs: []string{"acme.wtf"},
|
||||
},
|
||||
{
|
||||
Main: "*.acme.wtf",
|
||||
},
|
||||
},
|
||||
expectedDomains: []types.Domain{
|
||||
{
|
||||
Main: "*.acme.wtf",
|
||||
SANs: []string{"acme.wtf"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "domain partially checked by wildcard",
|
||||
domains: []types.Domain{
|
||||
{
|
||||
Main: "traefik.acme.wtf",
|
||||
SANs: []string{"acme.wtf", "foo.bar"},
|
||||
},
|
||||
{
|
||||
Main: "*.acme.wtf",
|
||||
},
|
||||
{
|
||||
Main: "who.acme.wtf",
|
||||
SANs: []string{"traefik.acme.wtf", "bar.acme.wtf"},
|
||||
},
|
||||
},
|
||||
expectedDomains: []types.Domain{
|
||||
{
|
||||
Main: "acme.wtf",
|
||||
SANs: []string{"foo.bar"},
|
||||
},
|
||||
{
|
||||
Main: "*.acme.wtf",
|
||||
SANs: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
acmeProvider := Provider{Configuration: &Configuration{Domains: test.domains}}
|
||||
|
||||
acmeProvider.deleteUnnecessaryDomains()
|
||||
assert.Equal(t, test.expectedDomains, acmeProvider.Domains, "unexpected domain")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsAccountMatchingCaServer(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
accountURI string
|
||||
serverURI string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
desc: "acme staging with matching account",
|
||||
accountURI: "https://acme-staging-v02.api.letsencrypt.org/acme/acct/1234567",
|
||||
serverURI: "https://acme-staging-v02.api.letsencrypt.org/acme/directory",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "acme production with matching account",
|
||||
accountURI: "https://acme-v02.api.letsencrypt.org/acme/acct/1234567",
|
||||
serverURI: "https://acme-v02.api.letsencrypt.org/acme/directory",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "http only acme with matching account",
|
||||
accountURI: "http://acme.api.letsencrypt.org/acme/acct/1234567",
|
||||
serverURI: "http://acme.api.letsencrypt.org/acme/directory",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "different subdomains for account and server",
|
||||
accountURI: "https://test1.example.org/acme/acct/1234567",
|
||||
serverURI: "https://test2.example.org/acme/directory",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "different domains for account and server",
|
||||
accountURI: "https://test.example1.org/acme/acct/1234567",
|
||||
serverURI: "https://test.example2.org/acme/directory",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "different tld for account and server",
|
||||
accountURI: "https://test.example.com/acme/acct/1234567",
|
||||
serverURI: "https://test.example.org/acme/directory",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "malformed account url",
|
||||
accountURI: "//|\\/test.example.com/acme/acct/1234567",
|
||||
serverURI: "https://test.example.com/acme/directory",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "malformed server url",
|
||||
accountURI: "https://test.example.com/acme/acct/1234567",
|
||||
serverURI: "//|\\/test.example.com/acme/directory",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "malformed server and account url",
|
||||
accountURI: "//|\\/test.example.com/acme/acct/1234567",
|
||||
serverURI: "//|\\/test.example.com/acme/directory",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
result := isAccountMatchingCaServer(test.accountURI, test.serverURI)
|
||||
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUseBackOffToObtainCertificate(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
domains []string
|
||||
dnsChallenge *DNSChallenge
|
||||
expectedResponse bool
|
||||
}{
|
||||
{
|
||||
desc: "only one single domain",
|
||||
domains: []string{"acme.wtf"},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
expectedResponse: false,
|
||||
},
|
||||
{
|
||||
desc: "only one wildcard domain",
|
||||
domains: []string{"*.acme.wtf"},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
expectedResponse: false,
|
||||
},
|
||||
{
|
||||
desc: "wildcard domain with no root domain",
|
||||
domains: []string{"*.acme.wtf", "foo.acme.wtf", "bar.acme.wtf", "foo.bar"},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
expectedResponse: false,
|
||||
},
|
||||
{
|
||||
desc: "wildcard and root domain",
|
||||
domains: []string{"*.acme.wtf", "foo.acme.wtf", "bar.acme.wtf", "acme.wtf"},
|
||||
dnsChallenge: &DNSChallenge{},
|
||||
expectedResponse: true,
|
||||
},
|
||||
{
|
||||
desc: "wildcard and root domain but no DNS challenge",
|
||||
domains: []string{"*.acme.wtf", "acme.wtf"},
|
||||
dnsChallenge: nil,
|
||||
expectedResponse: false,
|
||||
},
|
||||
{
|
||||
desc: "two wildcard domains (must never happen)",
|
||||
domains: []string{"*.acme.wtf", "*.bar.foo"},
|
||||
dnsChallenge: nil,
|
||||
expectedResponse: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
acmeProvider := Provider{Configuration: &Configuration{DNSChallenge: test.dnsChallenge}}
|
||||
|
||||
actualResponse := acmeProvider.useCertificateWithRetry(test.domains)
|
||||
assert.Equal(t, test.expectedResponse, actualResponse, "unexpected response to use backOff")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitAccount(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
account *Account
|
||||
email string
|
||||
keyType string
|
||||
expectedAccount *Account
|
||||
}{
|
||||
{
|
||||
desc: "Existing account with all information",
|
||||
account: &Account{
|
||||
Email: "foo@foo.net",
|
||||
KeyType: acme.EC256,
|
||||
},
|
||||
expectedAccount: &Account{
|
||||
Email: "foo@foo.net",
|
||||
KeyType: acme.EC256,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Account nil",
|
||||
email: "foo@foo.net",
|
||||
keyType: "EC256",
|
||||
expectedAccount: &Account{
|
||||
Email: "foo@foo.net",
|
||||
KeyType: acme.EC256,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Existing account with no email",
|
||||
account: &Account{
|
||||
KeyType: acme.RSA4096,
|
||||
},
|
||||
email: "foo@foo.net",
|
||||
keyType: "EC256",
|
||||
expectedAccount: &Account{
|
||||
Email: "foo@foo.net",
|
||||
KeyType: acme.EC256,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Existing account with no key type",
|
||||
account: &Account{
|
||||
Email: "foo@foo.net",
|
||||
},
|
||||
email: "bar@foo.net",
|
||||
keyType: "EC256",
|
||||
expectedAccount: &Account{
|
||||
Email: "foo@foo.net",
|
||||
KeyType: acme.EC256,
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Existing account and provider with no key type",
|
||||
account: &Account{
|
||||
Email: "foo@foo.net",
|
||||
},
|
||||
email: "bar@foo.net",
|
||||
expectedAccount: &Account{
|
||||
Email: "foo@foo.net",
|
||||
KeyType: acme.RSA4096,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
acmeProvider := Provider{account: test.account, Configuration: &Configuration{Email: test.email, KeyType: test.keyType}}
|
||||
|
||||
actualAccount, err := acmeProvider.initAccount()
|
||||
assert.Nil(t, err, "Init account in error")
|
||||
assert.Equal(t, test.expectedAccount.Email, actualAccount.Email, "unexpected email account")
|
||||
assert.Equal(t, test.expectedAccount.KeyType, actualAccount.KeyType, "unexpected keyType account")
|
||||
})
|
||||
}
|
||||
}
|
||||
25
old/provider/acme/store.go
Normal file
25
old/provider/acme/store.go
Normal file
|
|
@@ -0,0 +1,25 @@
|
|||
package acme
|
||||
|
||||
// StoredData represents the data managed by the Store
|
||||
type StoredData struct {
|
||||
Account *Account
|
||||
Certificates []*Certificate
|
||||
HTTPChallenges map[string]map[string][]byte
|
||||
TLSChallenges map[string]*Certificate
|
||||
}
|
||||
|
||||
// Store is a generic interface to represents a storage
|
||||
type Store interface {
|
||||
GetAccount() (*Account, error)
|
||||
SaveAccount(*Account) error
|
||||
GetCertificates() ([]*Certificate, error)
|
||||
SaveCertificates([]*Certificate) error
|
||||
|
||||
GetHTTPChallengeToken(token, domain string) ([]byte, error)
|
||||
SetHTTPChallengeToken(token, domain string, keyAuth []byte) error
|
||||
RemoveHTTPChallengeToken(token, domain string) error
|
||||
|
||||
AddTLSChallenge(domain string, cert *Certificate) error
|
||||
GetTLSChallenge(domain string) (*Certificate, error)
|
||||
RemoveTLSChallenge(domain string) error
|
||||
}
|
||||
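// Illustrative sketch, not part of this commit: any type implementing Store can back
// the ACME provider, for instance an in-memory store for tests. Only a few methods
// are shown; the remaining ones follow the same pattern.
//
//	type inMemoryStore struct {
//		account      *Account
//		certificates []*Certificate
//	}
//
//	func (s *inMemoryStore) GetAccount() (*Account, error) { return s.account, nil }
//	func (s *inMemoryStore) SaveAccount(a *Account) error  { s.account = a; return nil }
//	func (s *inMemoryStore) GetCertificates() ([]*Certificate, error) {
//		return s.certificates, nil
//	}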
48
old/provider/boltdb/boltdb.go
Normal file
48
old/provider/boltdb/boltdb.go
Normal file
|
|
@@ -0,0 +1,48 @@
|
|||
package boltdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/abronan/valkeyrie/store"
|
||||
"github.com/abronan/valkeyrie/store/boltdb"
|
||||
"github.com/containous/traefik/old/provider"
|
||||
"github.com/containous/traefik/old/provider/kv"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/safe"
|
||||
)
|
||||
|
||||
var _ provider.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider holds configurations of the provider.
|
||||
type Provider struct {
|
||||
kv.Provider `mapstructure:",squash" export:"true"`
|
||||
}
|
||||
|
||||
// Init the provider
|
||||
func (p *Provider) Init(constraints types.Constraints) error {
|
||||
err := p.Provider.Init(constraints)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store, err := p.CreateStore()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to Connect to KV store: %v", err)
|
||||
}
|
||||
|
||||
p.SetKVClient(store)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Provide allows the boltdb provider to Provide configurations to traefik
|
||||
// using the given configuration channel.
|
||||
func (p *Provider) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
|
||||
return p.Provider.Provide(configurationChan, pool)
|
||||
}
|
||||
|
||||
// CreateStore creates the KV store
|
||||
func (p *Provider) CreateStore() (store.Store, error) {
|
||||
p.SetStoreType(store.BOLTDB)
|
||||
boltdb.Register()
|
||||
return p.Provider.CreateStore()
|
||||
}
48
old/provider/consul/consul.go
Normal file

@@ -0,0 +1,48 @@
package consul

import (
	"fmt"

	"github.com/abronan/valkeyrie/store"
	"github.com/abronan/valkeyrie/store/consul"
	"github.com/containous/traefik/old/provider"
	"github.com/containous/traefik/old/provider/kv"
	"github.com/containous/traefik/old/types"
	"github.com/containous/traefik/safe"
)

var _ provider.Provider = (*Provider)(nil)

// Provider holds configurations of the p.
type Provider struct {
	kv.Provider `mapstructure:",squash" export:"true"`
}

// Init the provider
func (p *Provider) Init(constraints types.Constraints) error {
	err := p.Provider.Init(constraints)
	if err != nil {
		return err
	}

	store, err := p.CreateStore()
	if err != nil {
		return fmt.Errorf("failed to Connect to KV store: %v", err)
	}

	p.SetKVClient(store)
	return nil
}

// Provide allows the consul provider to provide configurations to traefik
// using the given configuration channel.
func (p *Provider) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
	return p.Provider.Provide(configurationChan, pool)
}

// CreateStore creates the KV store
func (p *Provider) CreateStore() (store.Store, error) {
	p.SetStoreType(store.CONSUL)
	consul.Register()
	return p.Provider.CreateStore()
}
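Note (not part of the diff): both KV providers embed `kv.Provider` with a `mapstructure:",squash"` tag, which makes the embedded struct's fields decode as if they were declared directly on the outer struct. The standalone sketch below shows how that tag behaves with the github.com/mitchellh/mapstructure library; the `KVProvider` and `Provider` types here are illustrative, and whether this is the exact decode path Traefik takes for these structs is not shown in this diff.

// squash_example.go — standalone sketch of the ",squash" tag behaviour.
package main

import (
	"fmt"

	"github.com/mitchellh/mapstructure"
)

type KVProvider struct {
	Endpoint string
	Prefix   string
}

type Provider struct {
	KVProvider `mapstructure:",squash"`
	Watch      bool
}

func main() {
	input := map[string]interface{}{
		"endpoint": "127.0.0.1:8500",
		"prefix":   "traefik",
		"watch":    true,
	}

	var p Provider
	if err := mapstructure.Decode(input, &p); err != nil {
		panic(err)
	}
	// The embedded KVProvider fields are filled from the flat keys.
	fmt.Printf("%+v\n", p)
}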
227
old/provider/consulcatalog/config.go
Normal file

@@ -0,0 +1,227 @@
package consulcatalog

import (
	"bytes"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"net"
	"sort"
	"strconv"
	"strings"
	"text/template"

	"github.com/containous/traefik/old/log"
	"github.com/containous/traefik/old/provider"
	"github.com/containous/traefik/old/provider/label"
	"github.com/containous/traefik/old/types"
	"github.com/hashicorp/consul/api"
)

func (p *Provider) buildConfiguration(catalog []catalogUpdate) *types.Configuration {
	var funcMap = template.FuncMap{
		"getAttribute": p.getAttribute,
		"getTag":       getTag,
		"hasTag":       hasTag,

		// Backend functions
		"getNodeBackendName":    getNodeBackendName,
		"getServiceBackendName": getServiceBackendName,
		"getBackendAddress":     getBackendAddress,
		"getServerName":         getServerName,
		"getCircuitBreaker":     label.GetCircuitBreaker,
		"getLoadBalancer":       label.GetLoadBalancer,
		"getMaxConn":            label.GetMaxConn,
		"getHealthCheck":        label.GetHealthCheck,
		"getBuffering":          label.GetBuffering,
		"getResponseForwarding": label.GetResponseForwarding,
		"getServer":             p.getServer,

		// Frontend functions
		"getFrontendRule":        p.getFrontendRule,
		"getBasicAuth":           label.GetFuncSliceString(label.TraefikFrontendAuthBasic), // Deprecated
		"getAuth":                label.GetAuth,
		"getFrontEndEntryPoints": label.GetFuncSliceString(label.TraefikFrontendEntryPoints),
		"getPriority":            label.GetFuncInt(label.TraefikFrontendPriority, label.DefaultFrontendPriority),
		"getPassHostHeader":      label.GetFuncBool(label.TraefikFrontendPassHostHeader, label.DefaultPassHostHeader),
		"getPassTLSCert":         label.GetFuncBool(label.TraefikFrontendPassTLSCert, label.DefaultPassTLSCert),
		"getPassTLSClientCert":   label.GetTLSClientCert,
		"getWhiteList":           label.GetWhiteList,
		"getRedirect":            label.GetRedirect,
		"getErrorPages":          label.GetErrorPages,
		"getRateLimit":           label.GetRateLimit,
		"getHeaders":             label.GetHeaders,
	}

	var allNodes []*api.ServiceEntry
	var services []*serviceUpdate
	for _, info := range catalog {
		if len(info.Nodes) > 0 {
			services = append(services, p.generateFrontends(info.Service)...)
			allNodes = append(allNodes, info.Nodes...)
		}
	}
	// Ensure a stable ordering of nodes so that identical configurations may be detected
	sort.Sort(nodeSorter(allNodes))

	templateObjects := struct {
		Services []*serviceUpdate
		Nodes    []*api.ServiceEntry
	}{
		Services: services,
		Nodes:    allNodes,
	}

	configuration, err := p.GetConfiguration("templates/consul_catalog.tmpl", funcMap, templateObjects)
	if err != nil {
		log.WithError(err).Error("Failed to create config")
	}

	return configuration
}

// Specific functions

func (p *Provider) getFrontendRule(service serviceUpdate) string {
	customFrontendRule := label.GetStringValue(service.TraefikLabels, label.TraefikFrontendRule, "")
	if customFrontendRule == "" {
		customFrontendRule = p.FrontEndRule
	}

	tmpl := p.frontEndRuleTemplate
	tmpl, err := tmpl.Parse(customFrontendRule)
	if err != nil {
		log.Errorf("Failed to parse Consul Catalog custom frontend rule: %v", err)
		return ""
	}

	templateObjects := struct {
		ServiceName string
		Domain      string
		Attributes  []string
	}{
		ServiceName: service.ServiceName,
		Domain:      p.Domain,
		Attributes:  service.Attributes,
	}

	var buffer bytes.Buffer
	err = tmpl.Execute(&buffer, templateObjects)
	if err != nil {
		log.Errorf("Failed to execute Consul Catalog custom frontend rule template: %v", err)
		return ""
	}

	return strings.TrimSuffix(buffer.String(), ".")
}
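Note (not part of the diff): getFrontendRule is the standard text/template pattern — parse a user-supplied rule, execute it against a small struct, read the result from a buffer. A standalone sketch with an illustrative rule of the shape this provider uses:

// rule_template_example.go — standalone sketch of the rule expansion pattern above.
package main

import (
	"bytes"
	"fmt"
	"text/template"
)

func main() {
	const rule = "Host:{{.ServiceName}}.{{.Domain}}"

	tmpl, err := template.New("frontend rule").Parse(rule)
	if err != nil {
		panic(err)
	}

	data := struct {
		ServiceName string
		Domain      string
	}{ServiceName: "whoami", Domain: "example.com"}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, data); err != nil {
		panic(err)
	}

	fmt.Println(buf.String()) // Host:whoami.example.com
}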

func (p *Provider) getServer(node *api.ServiceEntry) types.Server {
	scheme := p.getAttribute(label.SuffixProtocol, node.Service.Tags, label.DefaultProtocol)
	address := getBackendAddress(node)

	return types.Server{
		URL:    fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(address, strconv.Itoa(node.Service.Port))),
		Weight: p.getWeight(node.Service.Tags),
	}
}

func (p *Provider) setupFrontEndRuleTemplate() {
	var FuncMap = template.FuncMap{
		"getAttribute": p.getAttribute,
		"getTag":       getTag,
		"hasTag":       hasTag,
	}
	p.frontEndRuleTemplate = template.New("consul catalog frontend rule").Funcs(FuncMap)
}

// Specific functions

func getServiceBackendName(service *serviceUpdate) string {
	if service.ParentServiceName != "" {
		return strings.ToLower(service.ParentServiceName)
	}
	return strings.ToLower(service.ServiceName)
}

func getNodeBackendName(node *api.ServiceEntry) string {
	return strings.ToLower(node.Service.Service)
}

func getBackendAddress(node *api.ServiceEntry) string {
	if node.Service.Address != "" {
		return node.Service.Address
	}
	return node.Node.Address
}

func getServerName(node *api.ServiceEntry, index int) string {
	serviceName := node.Service.Service + node.Service.Address + strconv.Itoa(node.Service.Port)
	// TODO sort tags ?
	serviceName += strings.Join(node.Service.Tags, "")

	hash := sha1.New()
	_, err := hash.Write([]byte(serviceName))
	if err != nil {
		// Impossible case
		log.Error(err)
	} else {
		serviceName = base64.URLEncoding.EncodeToString(hash.Sum(nil))
	}

	// unique int at the end
	return provider.Normalize(node.Service.Service + "-" + strconv.Itoa(index) + "-" + serviceName)
}
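Note (not part of the diff): the naming trick in getServerName is to hash the service identity with SHA-1 and URL-safe base64-encode it, so the resulting server name is short, stable for identical inputs, and free of characters such as ':' or ','. A standalone sketch with illustrative values:

// servername_hash_example.go — standalone sketch of the hashing step above.
package main

import (
	"crypto/sha1"
	"encoding/base64"
	"fmt"
)

func main() {
	identity := "redis" + "10.0.0.5" + "6379" + "traefik.enable=true"

	h := sha1.New()
	h.Write([]byte(identity)) // hash.Hash.Write never returns an error

	name := base64.URLEncoding.EncodeToString(h.Sum(nil))
	fmt.Println("redis-0-" + name)
}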

func (p *Provider) getWeight(tags []string) int {
	labels := tagsToNeutralLabels(tags, p.Prefix)
	return label.GetIntValue(labels, p.getPrefixedName(label.SuffixWeight), label.DefaultWeight)
}

// Base functions

func (p *Provider) getAttribute(name string, tags []string, defaultValue string) string {
	return getTag(p.getPrefixedName(name), tags, defaultValue)
}

func (p *Provider) getPrefixedName(name string) string {
	if len(p.Prefix) > 0 && len(name) > 0 {
		return p.Prefix + "." + name
	}
	return name
}

func hasTag(name string, tags []string) bool {
	lowerName := strings.ToLower(name)

	for _, tag := range tags {
		lowerTag := strings.ToLower(tag)

		// Given the nature of Consul tags, which could be either singular markers, or key=value pairs
		if strings.HasPrefix(lowerTag, lowerName+"=") || lowerTag == lowerName {
			return true
		}
	}
	return false
}

func getTag(name string, tags []string, defaultValue string) string {
	lowerName := strings.ToLower(name)

	for _, tag := range tags {
		lowerTag := strings.ToLower(tag)

		// Given the nature of Consul tags, which could be either singular markers, or key=value pairs
		if strings.HasPrefix(lowerTag, lowerName+"=") || lowerTag == lowerName {
			// In case, where a tag might be a key=value, try to split it by the first '='
			kv := strings.SplitN(tag, "=", 2)

			// If the returned result is a key=value pair, return the 'value' component
			if len(kv) == 2 {
				return kv[1]
			}
			// If the returned result is a singular marker, return the 'key' component
			return kv[0]
		}
	}
	return defaultValue
}
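Note (not part of the diff): getTag/hasTag treat a Consul tag as either a bare marker ("traefik.enable") or a key=value pair ("traefik.backend.weight=42"), with case-insensitive matching on the key. A standalone sketch that reproduces the same lookup logic on sample tags (the `lookup` helper is illustrative only):

// tag_lookup_example.go — standalone sketch of the tag lookup convention above.
package main

import (
	"fmt"
	"strings"
)

func lookup(name string, tags []string, defaultValue string) string {
	lowerName := strings.ToLower(name)
	for _, tag := range tags {
		lowerTag := strings.ToLower(tag)
		if strings.HasPrefix(lowerTag, lowerName+"=") || lowerTag == lowerName {
			kv := strings.SplitN(tag, "=", 2)
			if len(kv) == 2 {
				return kv[1] // key=value pair: return the value
			}
			return kv[0] // bare marker: return the key itself
		}
	}
	return defaultValue
}

func main() {
	tags := []string{"Traefik.Backend.Weight=42", "traefik.enable"}
	fmt.Println(lookup("traefik.backend.weight", tags, "1"))  // 42
	fmt.Println(lookup("traefik.enable", tags, "false"))      // traefik.enable
	fmt.Println(lookup("traefik.missing", tags, "default"))   // default
}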
1231
old/provider/consulcatalog/config_test.go
Normal file

File diff suppressed because it is too large
617
old/provider/consulcatalog/consul_catalog.go
Normal file

@@ -0,0 +1,617 @@
|
|||
package consulcatalog
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/BurntSushi/ty/fun"
|
||||
"github.com/cenk/backoff"
|
||||
"github.com/containous/traefik/job"
|
||||
"github.com/containous/traefik/old/log"
|
||||
"github.com/containous/traefik/old/provider"
|
||||
"github.com/containous/traefik/old/provider/label"
|
||||
"github.com/containous/traefik/old/types"
|
||||
"github.com/containous/traefik/safe"
|
||||
"github.com/hashicorp/consul/api"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultWatchWaitTime is the duration to wait when polling consul
|
||||
DefaultWatchWaitTime = 15 * time.Second
|
||||
)
|
||||
|
||||
var _ provider.Provider = (*Provider)(nil)
|
||||
|
||||
// Provider holds configurations of the Consul catalog provider.
|
||||
type Provider struct {
|
||||
provider.BaseProvider `mapstructure:",squash" export:"true"`
|
||||
Endpoint string `description:"Consul server endpoint"`
|
||||
Domain string `description:"Default domain used"`
|
||||
Stale bool `description:"Use stale consistency for catalog reads" export:"true"`
|
||||
ExposedByDefault bool `description:"Expose Consul services by default" export:"true"`
|
||||
Prefix string `description:"Prefix used for Consul catalog tags" export:"true"`
|
||||
FrontEndRule string `description:"Frontend rule used for Consul services" export:"true"`
|
||||
TLS *types.ClientTLS `description:"Enable TLS support" export:"true"`
|
||||
client *api.Client
|
||||
frontEndRuleTemplate *template.Template
|
||||
}
|
||||
|
||||
// Service represent a Consul service.
|
||||
type Service struct {
|
||||
Name string
|
||||
Tags []string
|
||||
Nodes []string
|
||||
Addresses []string
|
||||
Ports []int
|
||||
}
|
||||
|
||||
type serviceUpdate struct {
|
||||
ServiceName string
|
||||
ParentServiceName string
|
||||
Attributes []string
|
||||
TraefikLabels map[string]string
|
||||
}
|
||||
|
||||
type frontendSegment struct {
|
||||
Name string
|
||||
Labels map[string]string
|
||||
}
|
||||
|
||||
type catalogUpdate struct {
|
||||
Service *serviceUpdate
|
||||
Nodes []*api.ServiceEntry
|
||||
}
|
||||
|
||||
type nodeSorter []*api.ServiceEntry
|
||||
|
||||
func (a nodeSorter) Len() int {
|
||||
return len(a)
|
||||
}
|
||||
|
||||
func (a nodeSorter) Swap(i int, j int) {
|
||||
a[i], a[j] = a[j], a[i]
|
||||
}
|
||||
|
||||
func (a nodeSorter) Less(i int, j int) bool {
|
||||
lEntry := a[i]
|
||||
rEntry := a[j]
|
||||
|
||||
ls := strings.ToLower(lEntry.Service.Service)
|
||||
lr := strings.ToLower(rEntry.Service.Service)
|
||||
|
||||
if ls != lr {
|
||||
return ls < lr
|
||||
}
|
||||
if lEntry.Service.Address != rEntry.Service.Address {
|
||||
return lEntry.Service.Address < rEntry.Service.Address
|
||||
}
|
||||
if lEntry.Node.Address != rEntry.Node.Address {
|
||||
return lEntry.Node.Address < rEntry.Node.Address
|
||||
}
|
||||
return lEntry.Service.Port < rEntry.Service.Port
|
||||
}
|
||||
|
||||
// Init the provider
|
||||
func (p *Provider) Init(constraints types.Constraints) error {
|
||||
err := p.BaseProvider.Init(constraints)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := p.createClient()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.client = client
|
||||
p.setupFrontEndRuleTemplate()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Provide allows the consul catalog provider to provide configurations to traefik
|
||||
// using the given configuration channel.
|
||||
func (p *Provider) Provide(configurationChan chan<- types.ConfigMessage, pool *safe.Pool) error {
|
||||
pool.Go(func(stop chan bool) {
|
||||
notify := func(err error, time time.Duration) {
|
||||
log.Errorf("Consul connection error %+v, retrying in %s", err, time)
|
||||
}
|
||||
operation := func() error {
|
||||
return p.watch(configurationChan, stop)
|
||||
}
|
||||
errRetry := backoff.RetryNotify(safe.OperationWithRecover(operation), job.NewBackOff(backoff.NewExponentialBackOff()), notify)
|
||||
if errRetry != nil {
|
||||
log.Errorf("Cannot connect to consul server %+v", errRetry)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Provider) createClient() (*api.Client, error) {
|
||||
config := api.DefaultConfig()
|
||||
config.Address = p.Endpoint
|
||||
if p.TLS != nil {
|
||||
tlsConfig, err := p.TLS.CreateTLSConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.Scheme = "https"
|
||||
config.Transport.TLSClientConfig = tlsConfig
|
||||
}
|
||||
|
||||
client, err := api.NewClient(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (p *Provider) watch(configurationChan chan<- types.ConfigMessage, stop chan bool) error {
|
||||
stopCh := make(chan struct{})
|
||||
watchCh := make(chan map[string][]string)
|
||||
errorCh := make(chan error)
|
||||
|
||||
var errorOnce sync.Once
|
||||
notifyError := func(err error) {
|
||||
errorOnce.Do(func() {
|
||||
errorCh <- err
|
||||
})
|
||||
}
|
||||
|
||||
p.watchHealthState(stopCh, watchCh, notifyError)
|
||||
p.watchCatalogServices(stopCh, watchCh, notifyError)
|
||||
|
||||
defer close(stopCh)
|
||||
defer close(watchCh)
|
||||
|
||||
safe.Go(func() {
|
||||
for index := range watchCh {
|
||||
log.Debug("List of services changed")
|
||||
nodes, err := p.getNodes(index)
|
||||
if err != nil {
|
||||
notifyError(err)
|
||||
}
|
||||
configuration := p.buildConfiguration(nodes)
|
||||
configurationChan <- types.ConfigMessage{
|
||||
ProviderName: "consul_catalog",
|
||||
Configuration: configuration,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return nil
|
||||
case err := <-errorCh:
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Provider) watchCatalogServices(stopCh <-chan struct{}, watchCh chan<- map[string][]string, notifyError func(error)) {
|
||||
catalog := p.client.Catalog()
|
||||
|
||||
safe.Go(func() {
|
||||
// variable to hold previous state
|
||||
var flashback map[string]Service
|
||||
|
||||
options := &api.QueryOptions{WaitTime: DefaultWatchWaitTime, AllowStale: p.Stale}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
data, meta, err := catalog.Services(options)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to list services: %v", err)
|
||||
notifyError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if options.WaitIndex == meta.LastIndex {
|
||||
continue
|
||||
}
|
||||
|
||||
options.WaitIndex = meta.LastIndex
|
||||
|
||||
if data != nil {
|
||||
current := make(map[string]Service)
|
||||
for key, value := range data {
|
||||
nodes, _, err := catalog.Service(key, "", &api.QueryOptions{AllowStale: p.Stale})
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get detail of service %s: %v", key, err)
|
||||
notifyError(err)
|
||||
return
|
||||
}
|
||||
|
||||
nodesID := getServiceIds(nodes)
|
||||
ports := getServicePorts(nodes)
|
||||
addresses := getServiceAddresses(nodes)
|
||||
|
||||
if service, ok := current[key]; ok {
|
||||
service.Tags = value
|
||||
service.Nodes = nodesID
|
||||
service.Ports = ports
|
||||
} else {
|
||||
service := Service{
|
||||
Name: key,
|
||||
Tags: value,
|
||||
Nodes: nodesID,
|
||||
Addresses: addresses,
|
||||
Ports: ports,
|
||||
}
|
||||
current[key] = service
|
||||
}
|
||||
}
|
||||
|
||||
// A critical note is that the return of a blocking request is no guarantee of a change.
|
||||
// It is possible that there was an idempotent write that does not affect the result of the query.
|
||||
// Thus it is required to do extra check for changes...
|
||||
if hasChanged(current, flashback) {
|
||||
watchCh <- data
|
||||
flashback = current
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
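Note (not part of the diff): watchCatalogServices and watchHealthState both use Consul blocking queries — the last seen index is passed back as WaitIndex so the call long-polls up to WaitTime, and a cycle where LastIndex did not move is skipped because the poll merely timed out. A standalone sketch of that loop against the hashicorp/consul/api client (endpoint and timings are illustrative):

// blocking_query_example.go — standalone sketch of the blocking-query loop above.
package main

import (
	"log"
	"time"

	"github.com/hashicorp/consul/api"
)

func main() {
	client, err := api.NewClient(api.DefaultConfig())
	if err != nil {
		log.Fatal(err)
	}

	options := &api.QueryOptions{WaitTime: 15 * time.Second}
	for {
		services, meta, err := client.Catalog().Services(options)
		if err != nil {
			log.Printf("listing services failed: %v", err)
			time.Sleep(time.Second)
			continue
		}

		if options.WaitIndex == meta.LastIndex {
			continue // the blocking query timed out without a change
		}
		options.WaitIndex = meta.LastIndex

		log.Printf("catalog changed, %d services", len(services))
	}
}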
|
||||
|
||||
func (p *Provider) watchHealthState(stopCh <-chan struct{}, watchCh chan<- map[string][]string, notifyError func(error)) {
|
||||
health := p.client.Health()
|
||||
catalog := p.client.Catalog()
|
||||
|
||||
safe.Go(func() {
|
||||
// variable to hold previous state
|
||||
var flashback map[string][]string
|
||||
var flashbackMaintenance []string
|
||||
|
||||
options := &api.QueryOptions{WaitTime: DefaultWatchWaitTime, AllowStale: p.Stale}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Listening to changes that leads to `passing` state or degrades from it.
|
||||
healthyState, meta, err := health.State("any", options)
|
||||
if err != nil {
|
||||
log.WithError(err).Error("Failed to retrieve health checks")
|
||||
notifyError(err)
|
||||
return
|
||||
}
|
||||
|
||||
var current = make(map[string][]string)
|
||||
var currentFailing = make(map[string]*api.HealthCheck)
|
||||
var maintenance []string
|
||||
if healthyState != nil {
|
||||
for _, healthy := range healthyState {
|
||||
key := fmt.Sprintf("%s-%s", healthy.Node, healthy.ServiceID)
|
||||
_, failing := currentFailing[key]
|
||||
if healthy.Status == "passing" && !failing {
|
||||
current[key] = append(current[key], healthy.Node)
|
||||
} else if strings.HasPrefix(healthy.CheckID, "_service_maintenance") || strings.HasPrefix(healthy.CheckID, "_node_maintenance") {
|
||||
maintenance = append(maintenance, healthy.CheckID)
|
||||
} else {
|
||||
currentFailing[key] = healthy
|
||||
if _, ok := current[key]; ok {
|
||||
delete(current, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If LastIndex didn't change then it means `Get` returned
|
||||
// because of the WaitTime and the key didn't changed.
|
||||
if options.WaitIndex == meta.LastIndex {
|
||||
continue
|
||||
}
|
||||
|
||||
options.WaitIndex = meta.LastIndex
|
||||
|
||||
// The response should be unified with watchCatalogServices
|
||||
data, _, err := catalog.Services(&api.QueryOptions{AllowStale: p.Stale})
|
||||
if err != nil {
|
||||
log.Errorf("Failed to list services: %v", err)
|
||||
notifyError(err)
|
||||
return
|
||||
}
|
||||
|
||||
if data != nil {
|
||||
// A critical note is that the return of a blocking request is no guarantee of a change.
|
||||
// It is possible that there was an idempotent write that does not affect the result of the query.
|
||||
// Thus it is required to do extra check for changes...
|
||||
addedKeys, removedKeys, changedKeys := getChangedHealth(current, flashback)
|
||||
|
||||
if len(addedKeys) > 0 || len(removedKeys) > 0 || len(changedKeys) > 0 {
|
||||
log.WithField("DiscoveredServices", addedKeys).
|
||||
WithField("MissingServices", removedKeys).
|
||||
WithField("ChangedServices", changedKeys).
|
||||
Debug("Health State change detected.")
|
||||
|
||||
watchCh <- data
|
||||
flashback = current
|
||||
flashbackMaintenance = maintenance
|
||||
} else {
|
||||
addedKeysMaintenance, removedMaintenance := getChangedStringKeys(maintenance, flashbackMaintenance)
|
||||
|
||||
if len(addedKeysMaintenance) > 0 || len(removedMaintenance) > 0 {
|
||||
log.WithField("MaintenanceMode", maintenance).Debug("Maintenance change detected.")
|
||||
watchCh <- data
|
||||
flashback = current
|
||||
flashbackMaintenance = maintenance
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Provider) getNodes(index map[string][]string) ([]catalogUpdate, error) {
|
||||
visited := make(map[string]bool)
|
||||
|
||||
var nodes []catalogUpdate
|
||||
for service := range index {
|
||||
name := strings.ToLower(service)
|
||||
if !strings.Contains(name, " ") && !visited[name] {
|
||||
visited[name] = true
|
||||
log.WithField("service", name).Debug("Fetching service")
|
||||
healthy, err := p.healthyNodes(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// healthy.Nodes can be empty if constraints do not match, without throwing error
|
||||
if healthy.Service != nil && len(healthy.Nodes) > 0 {
|
||||
nodes = append(nodes, healthy)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nodes, nil
|
||||
}
|
||||
|
||||
func hasChanged(current map[string]Service, previous map[string]Service) bool {
|
||||
if len(current) != len(previous) {
|
||||
return true
|
||||
}
|
||||
addedServiceKeys, removedServiceKeys := getChangedServiceKeys(current, previous)
|
||||
return len(removedServiceKeys) > 0 || len(addedServiceKeys) > 0 || hasServiceChanged(current, previous)
|
||||
}
|
||||
|
||||
func getChangedServiceKeys(current map[string]Service, previous map[string]Service) ([]string, []string) {
|
||||
currKeySet := fun.Set(fun.Keys(current).([]string)).(map[string]bool)
|
||||
prevKeySet := fun.Set(fun.Keys(previous).([]string)).(map[string]bool)
|
||||
|
||||
addedKeys := fun.Difference(currKeySet, prevKeySet).(map[string]bool)
|
||||
removedKeys := fun.Difference(prevKeySet, currKeySet).(map[string]bool)
|
||||
|
||||
return fun.Keys(addedKeys).([]string), fun.Keys(removedKeys).([]string)
|
||||
}
|
||||
|
||||
func hasServiceChanged(current map[string]Service, previous map[string]Service) bool {
|
||||
for key, value := range current {
|
||||
if prevValue, ok := previous[key]; ok {
|
||||
addedNodesKeys, removedNodesKeys := getChangedStringKeys(value.Nodes, prevValue.Nodes)
|
||||
if len(addedNodesKeys) > 0 || len(removedNodesKeys) > 0 {
|
||||
return true
|
||||
}
|
||||
addedTagsKeys, removedTagsKeys := getChangedStringKeys(value.Tags, prevValue.Tags)
|
||||
if len(addedTagsKeys) > 0 || len(removedTagsKeys) > 0 {
|
||||
return true
|
||||
}
|
||||
addedAddressesKeys, removedAddressesKeys := getChangedStringKeys(value.Addresses, prevValue.Addresses)
|
||||
if len(addedAddressesKeys) > 0 || len(removedAddressesKeys) > 0 {
|
||||
return true
|
||||
}
|
||||
addedPortsKeys, removedPortsKeys := getChangedIntKeys(value.Ports, prevValue.Ports)
|
||||
if len(addedPortsKeys) > 0 || len(removedPortsKeys) > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func getChangedStringKeys(currState []string, prevState []string) ([]string, []string) {
|
||||
currKeySet := fun.Set(currState).(map[string]bool)
|
||||
prevKeySet := fun.Set(prevState).(map[string]bool)
|
||||
|
||||
addedKeys := fun.Difference(currKeySet, prevKeySet).(map[string]bool)
|
||||
removedKeys := fun.Difference(prevKeySet, currKeySet).(map[string]bool)
|
||||
|
||||
return fun.Keys(addedKeys).([]string), fun.Keys(removedKeys).([]string)
|
||||
}
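Note (not part of the diff): the change-detection helpers above compute set differences through the reflection-based BurntSushi/ty/fun package. The same added/removed computation written with plain maps, as a standalone sketch (the `diffStrings` helper is illustrative only):

// diff_example.go — standalone sketch of the added/removed key computation.
package main

import "fmt"

func diffStrings(current, previous []string) (added, removed []string) {
	currSet := make(map[string]bool, len(current))
	for _, v := range current {
		currSet[v] = true
	}
	prevSet := make(map[string]bool, len(previous))
	for _, v := range previous {
		prevSet[v] = true
	}

	for v := range currSet {
		if !prevSet[v] {
			added = append(added, v)
		}
	}
	for v := range prevSet {
		if !currSet[v] {
			removed = append(removed, v)
		}
	}
	return added, removed
}

func main() {
	added, removed := diffStrings([]string{"node1", "node3"}, []string{"node1", "node2"})
	fmt.Println(added, removed) // [node3] [node2]
}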
|
||||
|
||||
func getChangedHealth(current map[string][]string, previous map[string][]string) ([]string, []string, []string) {
|
||||
currKeySet := fun.Set(fun.Keys(current).([]string)).(map[string]bool)
|
||||
prevKeySet := fun.Set(fun.Keys(previous).([]string)).(map[string]bool)
|
||||
|
||||
addedKeys := fun.Difference(currKeySet, prevKeySet).(map[string]bool)
|
||||
removedKeys := fun.Difference(prevKeySet, currKeySet).(map[string]bool)
|
||||
|
||||
var changedKeys []string
|
||||
|
||||
for key, value := range current {
|
||||
if prevValue, ok := previous[key]; ok {
|
||||
addedNodesKeys, removedNodesKeys := getChangedStringKeys(value, prevValue)
|
||||
if len(addedNodesKeys) > 0 || len(removedNodesKeys) > 0 {
|
||||
changedKeys = append(changedKeys, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fun.Keys(addedKeys).([]string), fun.Keys(removedKeys).([]string), changedKeys
|
||||
}
|
||||
|
||||
func getChangedIntKeys(currState []int, prevState []int) ([]int, []int) {
|
||||
currKeySet := fun.Set(currState).(map[int]bool)
|
||||
prevKeySet := fun.Set(prevState).(map[int]bool)
|
||||
|
||||
addedKeys := fun.Difference(currKeySet, prevKeySet).(map[int]bool)
|
||||
removedKeys := fun.Difference(prevKeySet, currKeySet).(map[int]bool)
|
||||
|
||||
return fun.Keys(addedKeys).([]int), fun.Keys(removedKeys).([]int)
|
||||
}
|
||||
|
||||
func getServiceIds(services []*api.CatalogService) []string {
|
||||
var serviceIds []string
|
||||
for _, service := range services {
|
||||
serviceIds = append(serviceIds, service.ID)
|
||||
}
|
||||
return serviceIds
|
||||
}
|
||||
|
||||
func getServicePorts(services []*api.CatalogService) []int {
|
||||
var servicePorts []int
|
||||
for _, service := range services {
|
||||
servicePorts = append(servicePorts, service.ServicePort)
|
||||
}
|
||||
return servicePorts
|
||||
}
|
||||
|
||||
func getServiceAddresses(services []*api.CatalogService) []string {
|
||||
var serviceAddresses []string
|
||||
for _, service := range services {
|
||||
serviceAddresses = append(serviceAddresses, service.ServiceAddress)
|
||||
}
|
||||
return serviceAddresses
|
||||
}
|
||||
|
||||
func (p *Provider) healthyNodes(service string) (catalogUpdate, error) {
|
||||
health := p.client.Health()
|
||||
data, _, err := health.Service(service, "", true, &api.QueryOptions{AllowStale: p.Stale})
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("Failed to fetch details of %s", service)
|
||||
return catalogUpdate{}, err
|
||||
}
|
||||
|
||||
nodes := fun.Filter(func(node *api.ServiceEntry) bool {
|
||||
return p.nodeFilter(service, node)
|
||||
}, data).([]*api.ServiceEntry)
|
||||
|
||||
// Merge tags of nodes matching constraints, in a single slice.
|
||||
tags := fun.Foldl(func(node *api.ServiceEntry, set []string) []string {
|
||||
return fun.Keys(fun.Union(
|
||||
fun.Set(set),
|
||||
fun.Set(node.Service.Tags),
|
||||
).(map[string]bool)).([]string)
|
||||
}, []string{}, nodes).([]string)
|
||||
|
||||
labels := tagsToNeutralLabels(tags, p.Prefix)
|
||||
|
||||
return catalogUpdate{
|
||||
Service: &serviceUpdate{
|
||||
ServiceName: service,
|
||||
Attributes: tags,
|
||||
TraefikLabels: labels,
|
||||
},
|
||||
Nodes: nodes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *Provider) nodeFilter(service string, node *api.ServiceEntry) bool {
|
||||
// Filter disabled application.
|
||||
if !p.isServiceEnabled(node) {
|
||||
log.Debugf("Filtering disabled Consul service %s", service)
|
||||
return false
|
||||
}
|
||||
|
||||
// Filter by constraints.
|
||||
constraintTags := p.getConstraintTags(node.Service.Tags)
|
||||
ok, failingConstraint := p.MatchConstraints(constraintTags)
|
||||
if !ok && failingConstraint != nil {
|
||||
log.Debugf("Service %v pruned by '%v' constraint", service, failingConstraint.String())
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Provider) isServiceEnabled(node *api.ServiceEntry) bool {
|
||||
rawValue := getTag(p.getPrefixedName(label.SuffixEnable), node.Service.Tags, "")
|
||||
|
||||
if len(rawValue) == 0 {
|
||||
return p.ExposedByDefault
|
||||
}
|
||||
|
||||
value, err := strconv.ParseBool(rawValue)
|
||||
if err != nil {
|
||||
log.Errorf("Invalid value for %s: %s", label.SuffixEnable, rawValue)
|
||||
return p.ExposedByDefault
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (p *Provider) getConstraintTags(tags []string) []string {
|
||||
var values []string
|
||||
|
||||
prefix := p.getPrefixedName("tags=")
|
||||
for _, tag := range tags {
|
||||
// We look for a Consul tag named 'traefik.tags' (unless different 'prefix' is configured)
|
||||
if strings.HasPrefix(strings.ToLower(tag), prefix) {
|
||||
// If 'traefik.tags=' tag is found, take the tag value and split by ',' adding the result to the list to be returned
|
||||
splitedTags := label.SplitAndTrimString(tag[len(prefix):], ",")
|
||||
values = append(values, splitedTags...)
|
||||
}
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func (p *Provider) generateFrontends(service *serviceUpdate) []*serviceUpdate {
|
||||
frontends := make([]*serviceUpdate, 0)
|
||||
// to support <prefix>.frontend.xxx
|
||||
frontends = append(frontends, &serviceUpdate{
|
||||
ServiceName: service.ServiceName,
|
||||
ParentServiceName: service.ServiceName,
|
||||
Attributes: service.Attributes,
|
||||
TraefikLabels: service.TraefikLabels,
|
||||
})
|
||||
|
||||
// loop over children of <prefix>.frontends.*
|
||||
for _, frontend := range getSegments(p.Prefix+".frontends", p.Prefix, service.TraefikLabels) {
|
||||
frontends = append(frontends, &serviceUpdate{
|
||||
ServiceName: service.ServiceName + "-" + frontend.Name,
|
||||
ParentServiceName: service.ServiceName,
|
||||
Attributes: service.Attributes,
|
||||
TraefikLabels: frontend.Labels,
|
||||
})
|
||||
}
|
||||
|
||||
return frontends
|
||||
}
|
||||
func getSegments(path string, prefix string, tree map[string]string) []*frontendSegment {
|
||||
segments := make([]*frontendSegment, 0)
|
||||
// find segment names
|
||||
segmentNames := make(map[string]bool)
|
||||
for key := range tree {
|
||||
if strings.HasPrefix(key, path+".") {
|
||||
segmentNames[strings.SplitN(strings.TrimPrefix(key, path+"."), ".", 2)[0]] = true
|
||||
}
|
||||
}
|
||||
|
||||
// get labels for each segment found
|
||||
for segment := range segmentNames {
|
||||
labels := make(map[string]string)
|
||||
for key, value := range tree {
|
||||
if strings.HasPrefix(key, path+"."+segment) {
|
||||
labels[prefix+".frontend"+strings.TrimPrefix(key, path+"."+segment)] = value
|
||||
}
|
||||
}
|
||||
segments = append(segments, &frontendSegment{
|
||||
Name: segment,
|
||||
Labels: labels,
|
||||
})
|
||||
}
|
||||
|
||||
return segments
|
||||
}
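Note (not part of the diff): getSegments regroups labels found under "<prefix>.frontends.<name>." by segment name and rewrites each of them under the plain "<prefix>.frontend." prefix, so every segment looks like an ordinary frontend to the rest of the template code. A standalone sketch of that transformation on sample labels (the label values are illustrative):

// segments_example.go — standalone sketch of the segment label regrouping above.
package main

import (
	"fmt"
	"strings"
)

func main() {
	prefix := "traefik"
	path := prefix + ".frontends"
	tree := map[string]string{
		"traefik.frontends.admin.rule":  "Host:admin.example.com",
		"traefik.frontends.public.rule": "Host:example.com",
		"traefik.frontend.priority":     "10", // not under .frontends, ignored here
	}

	segments := map[string]map[string]string{}
	for key, value := range tree {
		if !strings.HasPrefix(key, path+".") {
			continue
		}
		name := strings.SplitN(strings.TrimPrefix(key, path+"."), ".", 2)[0]
		if segments[name] == nil {
			segments[name] = map[string]string{}
		}
		// rewrite "traefik.frontends.<name>.rule" as "traefik.frontend.rule"
		segments[name][prefix+".frontend"+strings.TrimPrefix(key, path+"."+name)] = value
	}

	fmt.Println(segments)
}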
|
||||
862
old/provider/consulcatalog/consul_catalog_test.go
Normal file

@@ -0,0 +1,862 @@
|
|||
package consulcatalog
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/BurntSushi/ty/fun"
|
||||
"github.com/hashicorp/consul/api"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNodeSorter(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
nodes []*api.ServiceEntry
|
||||
expected []*api.ServiceEntry
|
||||
}{
|
||||
{
|
||||
desc: "Should sort nothing",
|
||||
nodes: []*api.ServiceEntry{},
|
||||
expected: []*api.ServiceEntry{},
|
||||
},
|
||||
{
|
||||
desc: "Should sort by node address",
|
||||
nodes: []*api.ServiceEntry{
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "127.0.0.1",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []*api.ServiceEntry{
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "127.0.0.1",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Should sort by service name",
|
||||
nodes: []*api.ServiceEntry{
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "127.0.0.2",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "bar",
|
||||
Address: "127.0.0.2",
|
||||
Port: 81,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "127.0.0.1",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "bar",
|
||||
Address: "127.0.0.2",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []*api.ServiceEntry{
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "bar",
|
||||
Address: "127.0.0.2",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "bar",
|
||||
Address: "127.0.0.2",
|
||||
Port: 81,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "127.0.0.1",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "127.0.0.2",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Should sort by node address",
|
||||
nodes: []*api.ServiceEntry{
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []*api.ServiceEntry{
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
{
|
||||
Service: &api.AgentService{
|
||||
Service: "foo",
|
||||
Address: "",
|
||||
Port: 80,
|
||||
},
|
||||
Node: &api.Node{
|
||||
Node: "localhost",
|
||||
Address: "127.0.0.2",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sort.Sort(nodeSorter(test.nodes))
|
||||
actual := test.nodes
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChangedKeys(t *testing.T) {
|
||||
type Input struct {
|
||||
currState map[string]Service
|
||||
prevState map[string]Service
|
||||
}
|
||||
|
||||
type Output struct {
|
||||
addedKeys []string
|
||||
removedKeys []string
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
input Input
|
||||
output Output
|
||||
}{
|
||||
{
|
||||
desc: "Should add 0 services and removed 0",
|
||||
input: Input{
|
||||
currState: map[string]Service{
|
||||
"foo-service": {Name: "v1"},
|
||||
"bar-service": {Name: "v1"},
|
||||
"baz-service": {Name: "v1"},
|
||||
"qux-service": {Name: "v1"},
|
||||
"quux-service": {Name: "v1"},
|
||||
"quuz-service": {Name: "v1"},
|
||||
"corge-service": {Name: "v1"},
|
||||
"grault-service": {Name: "v1"},
|
||||
"garply-service": {Name: "v1"},
|
||||
"waldo-service": {Name: "v1"},
|
||||
"fred-service": {Name: "v1"},
|
||||
"plugh-service": {Name: "v1"},
|
||||
"xyzzy-service": {Name: "v1"},
|
||||
"thud-service": {Name: "v1"},
|
||||
},
|
||||
prevState: map[string]Service{
|
||||
"foo-service": {Name: "v1"},
|
||||
"bar-service": {Name: "v1"},
|
||||
"baz-service": {Name: "v1"},
|
||||
"qux-service": {Name: "v1"},
|
||||
"quux-service": {Name: "v1"},
|
||||
"quuz-service": {Name: "v1"},
|
||||
"corge-service": {Name: "v1"},
|
||||
"grault-service": {Name: "v1"},
|
||||
"garply-service": {Name: "v1"},
|
||||
"waldo-service": {Name: "v1"},
|
||||
"fred-service": {Name: "v1"},
|
||||
"plugh-service": {Name: "v1"},
|
||||
"xyzzy-service": {Name: "v1"},
|
||||
"thud-service": {Name: "v1"},
|
||||
},
|
||||
},
|
||||
output: Output{
|
||||
addedKeys: []string{},
|
||||
removedKeys: []string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Should add 3 services and removed 0",
|
||||
input: Input{
|
||||
currState: map[string]Service{
|
||||
"foo-service": {Name: "v1"},
|
||||
"bar-service": {Name: "v1"},
|
||||
"baz-service": {Name: "v1"},
|
||||
"qux-service": {Name: "v1"},
|
||||
"quux-service": {Name: "v1"},
|
||||
"quuz-service": {Name: "v1"},
|
||||
"corge-service": {Name: "v1"},
|
||||
"grault-service": {Name: "v1"},
|
||||
"garply-service": {Name: "v1"},
|
||||
"waldo-service": {Name: "v1"},
|
||||
"fred-service": {Name: "v1"},
|
||||
"plugh-service": {Name: "v1"},
|
||||
"xyzzy-service": {Name: "v1"},
|
||||
"thud-service": {Name: "v1"},
|
||||
},
|
||||
prevState: map[string]Service{
|
||||
"foo-service": {Name: "v1"},
|
||||
"bar-service": {Name: "v1"},
|
||||
"baz-service": {Name: "v1"},
|
||||
"corge-service": {Name: "v1"},
|
||||
"grault-service": {Name: "v1"},
|
||||
"garply-service": {Name: "v1"},
|
||||
"waldo-service": {Name: "v1"},
|
||||
"fred-service": {Name: "v1"},
|
||||
"plugh-service": {Name: "v1"},
|
||||
"xyzzy-service": {Name: "v1"},
|
||||
"thud-service": {Name: "v1"},
|
||||
},
|
||||
},
|
||||
output: Output{
|
||||
addedKeys: []string{"qux-service", "quux-service", "quuz-service"},
|
||||
removedKeys: []string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "Should add 2 services and removed 2",
|
||||
input: Input{
|
||||
currState: map[string]Service{
|
||||
"foo-service": {Name: "v1"},
|
||||
"qux-service": {Name: "v1"},
|
||||
"quux-service": {Name: "v1"},
|
||||
"quuz-service": {Name: "v1"},
|
||||
"corge-service": {Name: "v1"},
|
||||
"grault-service": {Name: "v1"},
|
||||
"garply-service": {Name: "v1"},
|
||||
"waldo-service": {Name: "v1"},
|
||||
"fred-service": {Name: "v1"},
|
||||
"plugh-service": {Name: "v1"},
|
||||
"xyzzy-service": {Name: "v1"},
|
||||
"thud-service": {Name: "v1"},
|
||||
},
|
||||
prevState: map[string]Service{
|
||||
"foo-service": {Name: "v1"},
|
||||
"bar-service": {Name: "v1"},
|
||||
"baz-service": {Name: "v1"},
|
||||
"qux-service": {Name: "v1"},
|
||||
"quux-service": {Name: "v1"},
|
||||
"quuz-service": {Name: "v1"},
|
||||
"corge-service": {Name: "v1"},
|
||||
"waldo-service": {Name: "v1"},
|
||||
"fred-service": {Name: "v1"},
|
||||
"plugh-service": {Name: "v1"},
|
||||
"xyzzy-service": {Name: "v1"},
|
||||
"thud-service": {Name: "v1"},
|
||||
},
|
||||
},
|
||||
output: Output{
|
||||
addedKeys: []string{"grault-service", "garply-service"},
|
||||
removedKeys: []string{"bar-service", "baz-service"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
addedKeys, removedKeys := getChangedServiceKeys(test.input.currState, test.input.prevState)
|
||||
assert.Equal(t, fun.Set(test.output.addedKeys), fun.Set(addedKeys), "Added keys comparison results: got %q, want %q", addedKeys, test.output.addedKeys)
|
||||
assert.Equal(t, fun.Set(test.output.removedKeys), fun.Set(removedKeys), "Removed keys comparison results: got %q, want %q", removedKeys, test.output.removedKeys)
|
||||
})
|
||||
}
|
||||
}
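Note (not part of the diff): the assertions above wrap both sides in fun.Set because the diff helpers return keys in arbitrary map-iteration order. testify's assert.ElementsMatch is a plain, order-insensitive alternative, sketched here as a hypothetical standalone test:

// elementsmatch_example_test.go — standalone sketch of an order-insensitive assertion.
package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestAddedKeysIgnoringOrder(t *testing.T) {
	expected := []string{"qux-service", "quux-service", "quuz-service"}
	actual := []string{"quuz-service", "qux-service", "quux-service"}

	// Passes regardless of element order, fails on any missing or extra element.
	assert.ElementsMatch(t, expected, actual)
}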
|
||||
|
||||
func TestFilterEnabled(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
exposedByDefault bool
|
||||
node *api.ServiceEntry
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
desc: "exposed",
|
||||
exposedByDefault: true,
|
||||
node: &api.ServiceEntry{
|
||||
Service: &api.AgentService{
|
||||
Service: "api",
|
||||
Address: "10.0.0.1",
|
||||
Port: 80,
|
||||
Tags: []string{""},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "exposed and tolerated by valid label value",
|
||||
exposedByDefault: true,
|
||||
node: &api.ServiceEntry{
|
||||
Service: &api.AgentService{
|
||||
Service: "api",
|
||||
Address: "10.0.0.1",
|
||||
Port: 80,
|
||||
Tags: []string{"", "traefik.enable=true"},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "exposed and tolerated by invalid label value",
|
||||
exposedByDefault: true,
|
||||
node: &api.ServiceEntry{
|
||||
Service: &api.AgentService{
|
||||
Service: "api",
|
||||
Address: "10.0.0.1",
|
||||
Port: 80,
|
||||
Tags: []string{"", "traefik.enable=bad"},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "exposed but overridden by label",
|
||||
exposedByDefault: true,
|
||||
node: &api.ServiceEntry{
|
||||
Service: &api.AgentService{
|
||||
Service: "api",
|
||||
Address: "10.0.0.1",
|
||||
Port: 80,
|
||||
Tags: []string{"", "traefik.enable=false"},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "non-exposed",
|
||||
exposedByDefault: false,
|
||||
node: &api.ServiceEntry{
|
||||
Service: &api.AgentService{
|
||||
Service: "api",
|
||||
Address: "10.0.0.1",
|
||||
Port: 80,
|
||||
Tags: []string{""},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "non-exposed but overridden by label",
|
||||
exposedByDefault: false,
|
||||
node: &api.ServiceEntry{
|
||||
Service: &api.AgentService{
|
||||
Service: "api",
|
||||
Address: "10.0.0.1",
|
||||
Port: 80,
|
||||
Tags: []string{"", "traefik.enable=true"},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
provider := &Provider{
|
||||
Domain: "localhost",
|
||||
Prefix: "traefik",
|
||||
ExposedByDefault: test.exposedByDefault,
|
||||
}
|
||||
actual := provider.nodeFilter("test", test.node)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChangedStringKeys(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
current []string
|
||||
previous []string
|
||||
expectedAdded []string
|
||||
expectedRemoved []string
|
||||
}{
|
||||
{
|
||||
desc: "1 element added, 0 removed",
|
||||
current: []string{"chou"},
|
||||
previous: []string{},
|
||||
expectedAdded: []string{"chou"},
|
||||
expectedRemoved: []string{},
|
||||
}, {
|
||||
desc: "0 element added, 0 removed",
|
||||
current: []string{"chou"},
|
||||
previous: []string{"chou"},
|
||||
expectedAdded: []string{},
|
||||
expectedRemoved: []string{},
|
||||
},
|
||||
{
|
||||
desc: "0 element added, 1 removed",
|
||||
current: []string{},
|
||||
previous: []string{"chou"},
|
||||
expectedAdded: []string{},
|
||||
expectedRemoved: []string{"chou"},
|
||||
},
|
||||
{
|
||||
desc: "1 element added, 1 removed",
|
||||
current: []string{"carotte"},
|
||||
previous: []string{"chou"},
|
||||
expectedAdded: []string{"carotte"},
|
||||
expectedRemoved: []string{"chou"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actualAdded, actualRemoved := getChangedStringKeys(test.current, test.previous)
|
||||
assert.Equal(t, test.expectedAdded, actualAdded)
|
||||
assert.Equal(t, test.expectedRemoved, actualRemoved)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasServiceChanged(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
current map[string]Service
|
||||
previous map[string]Service
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
desc: "Change detected due to change of nodes",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node2"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "No change missing current service",
|
||||
current: make(map[string]Service),
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "No change on nodes",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "No change on nodes and tags",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo=bar"},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo=bar"},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "Change detected on tags",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo=bar"},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo"},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "Change detected on ports",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo=bar"},
|
||||
Ports: []int{80},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo"},
|
||||
Ports: []int{81},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "Change detected on ports",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo=bar"},
|
||||
Ports: []int{80},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo"},
|
||||
Ports: []int{81, 82},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "Change detected on addresses",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Addresses: []string{"127.0.0.1"},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Addresses: []string{"127.0.0.2"},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "No Change detected",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo"},
|
||||
Ports: []int{80},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo"},
|
||||
Ports: []int{80},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := hasServiceChanged(test.current, test.previous)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasChanged(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
current map[string]Service
|
||||
previous map[string]Service
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
desc: "Change detected due to change new service",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
previous: make(map[string]Service),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "Change detected due to change service removed",
|
||||
current: make(map[string]Service),
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "Change detected due to change of nodes",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node2"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "No change on nodes",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "No change on nodes and tags",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo=bar"},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo=bar"},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "Change detected on tags",
|
||||
current: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo=bar"},
|
||||
},
|
||||
},
|
||||
previous: map[string]Service{
|
||||
"foo-service": {
|
||||
Name: "foo",
|
||||
Nodes: []string{"node1"},
|
||||
Tags: []string{"foo"},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
actual := hasChanged(test.current, test.previous)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConstraintTags(t *testing.T) {
|
||||
provider := &Provider{
|
||||
Domain: "localhost",
|
||||
Prefix: "traefik",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
tags []string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
desc: "nil tags",
|
||||
},
|
||||
{
|
||||
desc: "invalid tag",
|
||||
tags: []string{"tags=foobar"},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "wrong tag",
|
||||
tags: []string{"traefik_tags=foobar"},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "empty value",
|
||||
tags: []string{"traefik.tags="},
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
desc: "simple tag",
|
||||
tags: []string{"traefik.tags=foobar "},
|
||||
expected: []string{"foobar"},
|
||||
},
|
||||
{
|
||||
desc: "multiple values tag",
|
||||
tags: []string{"traefik.tags=foobar, fiibir"},
|
||||
expected: []string{"foobar", "fiibir"},
|
||||
},
|
||||
{
|
||||
desc: "multiple tags",
|
||||
tags: []string{"traefik.tags=foobar", "traefik.tags=foobor"},
|
||||
expected: []string{"foobar", "foobor"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
constraints := provider.getConstraintTags(test.tags)
|
||||
assert.EqualValues(t, test.expected, constraints)
|
||||
})
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff