Rework servers load-balancer to use the WRR

Co-authored-by: Kevin Pollet <pollet.kevin@gmail.com>
This commit is contained in:
Julien Salleyron 2022-11-16 11:38:07 +01:00 committed by GitHub
parent 67d9c8da0b
commit fadee5e87b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
70 changed files with 2085 additions and 2211 deletions

View file

@ -24,7 +24,7 @@ type middlewareBuilder interface {
type serviceManager interface {
BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error)
LaunchHealthCheck()
LaunchHealthCheck(ctx context.Context)
}
// Manager A route/router manager.

View file

@ -7,10 +7,12 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/containous/alice"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/config/runtime"
"github.com/traefik/traefik/v2/pkg/metrics"
@ -478,7 +480,7 @@ func TestRuntimeConfiguration(t *testing.T) {
},
},
HealthCheck: &dynamic.ServerHealthCheck{
Interval: "500ms",
Interval: ptypes.Duration(500 * time.Millisecond),
Path: "/health",
},
},

View file

@ -31,6 +31,8 @@ type RouterFactory struct {
chainBuilder *middleware.ChainBuilder
tlsManager *tls.Manager
cancelPrevState func()
}
// NewRouterFactory creates a new RouterFactory.
@ -65,7 +67,12 @@ func NewRouterFactory(staticConfiguration static.Configuration, managerFactory *
// CreateRouters creates new TCPRouters and UDPRouters.
func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string]*tcprouter.Router, map[string]udptypes.Handler) {
ctx := context.Background()
if f.cancelPrevState != nil {
f.cancelPrevState()
}
var ctx context.Context
ctx, f.cancelPrevState = context.WithCancel(context.Background())
// HTTP
serviceManager := f.managerFactory.Build(rtConf)
@ -77,7 +84,7 @@ func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string
handlersNonTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, false)
handlersTLS := routerManager.BuildHandlers(ctx, f.entryPointsTCP, true)
serviceManager.LaunchHealthCheck()
serviceManager.LaunchHealthCheck(ctx)
// TCP
svcTCPManager := tcp.NewManager(rtConf)

View file

@ -10,7 +10,7 @@ import (
type serviceManager interface {
BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error)
LaunchHealthCheck()
LaunchHealthCheck(ctx context.Context)
}
// InternalHandlers is the internal HTTP handlers builder.

View file

@ -4,7 +4,6 @@ import (
"container/heap"
"context"
"errors"
"fmt"
"net/http"
"sync"
@ -39,7 +38,7 @@ type Balancer struct {
curDeadline float64
// status is a record of which child services of the Balancer are healthy, keyed
// by name of child service. A service is initially added to the map when it is
// created via AddService, and it is later removed or added to the map as needed,
// created via Add, and it is later removed or added to the map as needed,
// through the SetStatus method.
status map[string]struct{}
// updaters is the list of hooks that are run (to update the Balancer
@ -48,10 +47,10 @@ type Balancer struct {
}
// New creates a new load balancer.
func New(sticky *dynamic.Sticky, hc *dynamic.HealthCheck) *Balancer {
func New(sticky *dynamic.Sticky, wantHealthCheck bool) *Balancer {
balancer := &Balancer{
status: make(map[string]struct{}),
wantsHealthCheck: hc != nil,
wantsHealthCheck: wantHealthCheck,
}
if sticky != nil && sticky.Cookie != nil {
balancer.stickyCookie = &stickyCookie{
@ -150,10 +149,7 @@ func (b *Balancer) nextServer() (*namedHandler, error) {
b.mutex.Lock()
defer b.mutex.Unlock()
if len(b.handlers) == 0 {
return nil, fmt.Errorf("no servers in the pool")
}
if len(b.status) == 0 {
if len(b.handlers) == 0 || len(b.status) == 0 {
return nil, errNoAvailableServer
}
@ -223,9 +219,9 @@ func (b *Balancer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
server.ServeHTTP(w, req)
}
// AddService adds a handler.
// Add adds a handler.
// A handler with a non-positive weight is ignored.
func (b *Balancer) AddService(name string, handler http.Handler, weight *int) {
func (b *Balancer) Add(name string, handler http.Handler, weight *int) {
w := 1
if weight != nil {
w = *weight

View file

@ -10,31 +10,15 @@ import (
"github.com/traefik/traefik/v2/pkg/config/dynamic"
)
func Int(v int) *int { return &v }
type responseRecorder struct {
*httptest.ResponseRecorder
save map[string]int
sequence []string
status []int
}
func (r *responseRecorder) WriteHeader(statusCode int) {
r.save[r.Header().Get("server")]++
r.sequence = append(r.sequence, r.Header().Get("server"))
r.status = append(r.status, statusCode)
r.ResponseRecorder.WriteHeader(statusCode)
}
func TestBalancer(t *testing.T) {
balancer := New(nil, nil)
balancer := New(nil, false)
balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "first")
rw.WriteHeader(http.StatusOK)
}), Int(3))
balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "second")
rw.WriteHeader(http.StatusOK)
}), Int(1))
@ -49,23 +33,23 @@ func TestBalancer(t *testing.T) {
}
func TestBalancerNoService(t *testing.T) {
balancer := New(nil, nil)
balancer := New(nil, false)
recorder := httptest.NewRecorder()
balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
assert.Equal(t, http.StatusInternalServerError, recorder.Result().StatusCode)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode)
}
func TestBalancerOneServerZeroWeight(t *testing.T) {
balancer := New(nil, nil)
balancer := New(nil, false)
balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "first")
rw.WriteHeader(http.StatusOK)
}), Int(1))
balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0))
balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0))
recorder := &responseRecorder{ResponseRecorder: httptest.NewRecorder(), save: map[string]int{}}
for i := 0; i < 3; i++ {
@ -80,13 +64,13 @@ type key string
const serviceName key = "serviceName"
func TestBalancerNoServiceUp(t *testing.T) {
balancer := New(nil, nil)
balancer := New(nil, false)
balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
}), Int(1))
balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
}), Int(1))
@ -100,14 +84,14 @@ func TestBalancerNoServiceUp(t *testing.T) {
}
func TestBalancerOneServerDown(t *testing.T) {
balancer := New(nil, nil)
balancer := New(nil, false)
balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "first")
rw.WriteHeader(http.StatusOK)
}), Int(1))
balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusInternalServerError)
}), Int(1))
balancer.SetStatus(context.WithValue(context.Background(), serviceName, "parent"), "second", false)
@ -121,14 +105,14 @@ func TestBalancerOneServerDown(t *testing.T) {
}
func TestBalancerDownThenUp(t *testing.T) {
balancer := New(nil, nil)
balancer := New(nil, false)
balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "first")
rw.WriteHeader(http.StatusOK)
}), Int(1))
balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "second")
rw.WriteHeader(http.StatusOK)
}), Int(1))
@ -150,35 +134,35 @@ func TestBalancerDownThenUp(t *testing.T) {
}
func TestBalancerPropagate(t *testing.T) {
balancer1 := New(nil, &dynamic.HealthCheck{})
balancer1 := New(nil, true)
balancer1.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer1.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "first")
rw.WriteHeader(http.StatusOK)
}), Int(1))
balancer1.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer1.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "second")
rw.WriteHeader(http.StatusOK)
}), Int(1))
balancer2 := New(nil, &dynamic.HealthCheck{})
balancer2.AddService("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer2 := New(nil, true)
balancer2.Add("third", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "third")
rw.WriteHeader(http.StatusOK)
}), Int(1))
balancer2.AddService("fourth", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer2.Add("fourth", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "fourth")
rw.WriteHeader(http.StatusOK)
}), Int(1))
topBalancer := New(nil, &dynamic.HealthCheck{})
topBalancer.AddService("balancer1", balancer1, Int(1))
topBalancer := New(nil, true)
topBalancer.Add("balancer1", balancer1, Int(1))
_ = balancer1.RegisterStatusUpdater(func(up bool) {
topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer1", up)
// TODO(mpl): if test gets flaky, add channel or something here to signal that
// propagation is done, and wait on it before sending request.
})
topBalancer.AddService("balancer2", balancer2, Int(1))
topBalancer.Add("balancer2", balancer2, Int(1))
_ = balancer2.RegisterStatusUpdater(func(up bool) {
topBalancer.SetStatus(context.WithValue(context.Background(), serviceName, "top"), "balancer2", up)
})
@ -223,28 +207,28 @@ func TestBalancerPropagate(t *testing.T) {
}
func TestBalancerAllServersZeroWeight(t *testing.T) {
balancer := New(nil, nil)
balancer := New(nil, false)
balancer.AddService("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0))
balancer.AddService("test2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0))
balancer.Add("test", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0))
balancer.Add("test2", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), Int(0))
recorder := httptest.NewRecorder()
balancer.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))
assert.Equal(t, http.StatusInternalServerError, recorder.Result().StatusCode)
assert.Equal(t, http.StatusServiceUnavailable, recorder.Result().StatusCode)
}
func TestSticky(t *testing.T) {
balancer := New(&dynamic.Sticky{
Cookie: &dynamic.Cookie{Name: "test"},
}, nil)
}, false)
balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "first")
rw.WriteHeader(http.StatusOK)
}), Int(1))
balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "second")
rw.WriteHeader(http.StatusOK)
}), Int(2))
@ -268,14 +252,14 @@ func TestSticky(t *testing.T) {
// TestBalancerBias makes sure that the WRR algorithm spreads elements evenly right from the start,
// and that it does not "over-favor" the high-weighted ones with a biased start-up regime.
func TestBalancerBias(t *testing.T) {
balancer := New(nil, nil)
balancer := New(nil, false)
balancer.AddService("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("first", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "A")
rw.WriteHeader(http.StatusOK)
}), Int(11))
balancer.AddService("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
balancer.Add("second", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("server", "B")
rw.WriteHeader(http.StatusOK)
}), Int(3))
@ -290,3 +274,19 @@ func TestBalancerBias(t *testing.T) {
assert.Equal(t, wantSequence, recorder.sequence)
}
func Int(v int) *int { return &v }
type responseRecorder struct {
*httptest.ResponseRecorder
save map[string]int
sequence []string
status []int
}
func (r *responseRecorder) WriteHeader(statusCode int) {
r.save[r.Header().Get("server")]++
r.sequence = append(r.sequence, r.Header().Get("server"))
r.status = append(r.status, statusCode)
r.ResponseRecorder.WriteHeader(statusCode)
}

View file

@ -3,7 +3,6 @@ package service
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
@ -12,8 +11,6 @@ import (
"strings"
"time"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/log"
"golang.org/x/net/http/httpguts"
)
@ -24,103 +21,107 @@ const StatusClientClosedRequest = 499
// StatusClientClosedRequestText non-standard HTTP status for client disconnection.
const StatusClientClosedRequestText = "Client Closed Request"
func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwarding, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) (http.Handler, error) {
var flushInterval ptypes.Duration
if responseForwarding != nil {
err := flushInterval.Set(responseForwarding.FlushInterval)
if err != nil {
return nil, fmt.Errorf("error creating flush interval: %w", err)
}
}
if flushInterval == 0 {
flushInterval = ptypes.Duration(100 * time.Millisecond)
}
proxy := &httputil.ReverseProxy{
Director: func(outReq *http.Request) {
u := outReq.URL
if outReq.RequestURI != "" {
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
if err == nil {
u = parsedURL
}
}
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.Proto = "HTTP/1.1"
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
if _, ok := outReq.Header["User-Agent"]; !ok {
outReq.Header.Set("User-Agent", "")
}
// Do not pass client Host header unless optsetter PassHostHeader is set.
if passHostHeader != nil && !*passHostHeader {
outReq.Host = outReq.URL.Host
}
// Even if the websocket RFC says that headers should be case-insensitive,
// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
// https://tools.ietf.org/html/rfc6455#page-20
if isWebSocketUpgrade(outReq) {
outReq.Header["Sec-WebSocket-Key"] = outReq.Header["Sec-Websocket-Key"]
outReq.Header["Sec-WebSocket-Extensions"] = outReq.Header["Sec-Websocket-Extensions"]
outReq.Header["Sec-WebSocket-Accept"] = outReq.Header["Sec-Websocket-Accept"]
outReq.Header["Sec-WebSocket-Protocol"] = outReq.Header["Sec-Websocket-Protocol"]
outReq.Header["Sec-WebSocket-Version"] = outReq.Header["Sec-Websocket-Version"]
delete(outReq.Header, "Sec-Websocket-Key")
delete(outReq.Header, "Sec-Websocket-Extensions")
delete(outReq.Header, "Sec-Websocket-Accept")
delete(outReq.Header, "Sec-Websocket-Protocol")
delete(outReq.Header, "Sec-Websocket-Version")
}
},
func buildSingleHostProxy(target *url.URL, passHostHeader bool, flushInterval time.Duration, roundTripper http.RoundTripper, bufferPool httputil.BufferPool) http.Handler {
return &httputil.ReverseProxy{
Director: directorBuilder(target, passHostHeader),
Transport: roundTripper,
FlushInterval: time.Duration(flushInterval),
FlushInterval: flushInterval,
BufferPool: bufferPool,
ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) {
statusCode := http.StatusInternalServerError
ErrorHandler: errorHandler,
}
}
switch {
case errors.Is(err, io.EOF):
statusCode = http.StatusBadGateway
case errors.Is(err, context.Canceled):
statusCode = StatusClientClosedRequest
default:
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
statusCode = http.StatusGatewayTimeout
} else {
statusCode = http.StatusBadGateway
}
}
}
func directorBuilder(target *url.URL, passHostHeader bool) func(req *http.Request) {
return func(outReq *http.Request) {
outReq.URL.Scheme = target.Scheme
outReq.URL.Host = target.Host
log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
w.WriteHeader(statusCode)
_, werr := w.Write([]byte(statusText(statusCode)))
if werr != nil {
log.Debugf("Error while writing status code", werr)
u := outReq.URL
if outReq.RequestURI != "" {
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
if err == nil {
u = parsedURL
}
},
}
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
outReq.URL.RawQuery = strings.ReplaceAll(u.RawQuery, ";", "&")
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.Proto = "HTTP/1.1"
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
if _, ok := outReq.Header["User-Agent"]; !ok {
outReq.Header.Set("User-Agent", "")
}
// Do not pass client Host header unless PassHostHeader is set.
if !passHostHeader {
outReq.Host = outReq.URL.Host
}
cleanWebSocketHeaders(outReq)
}
}
// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive,
// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
// https://tools.ietf.org/html/rfc6455#page-20
func cleanWebSocketHeaders(req *http.Request) {
if !isWebSocketUpgrade(req) {
return
}
return proxy, nil
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
delete(req.Header, "Sec-Websocket-Key")
req.Header["Sec-WebSocket-Extensions"] = req.Header["Sec-Websocket-Extensions"]
delete(req.Header, "Sec-Websocket-Extensions")
req.Header["Sec-WebSocket-Accept"] = req.Header["Sec-Websocket-Accept"]
delete(req.Header, "Sec-Websocket-Accept")
req.Header["Sec-WebSocket-Protocol"] = req.Header["Sec-Websocket-Protocol"]
delete(req.Header, "Sec-Websocket-Protocol")
req.Header["Sec-WebSocket-Version"] = req.Header["Sec-Websocket-Version"]
delete(req.Header, "Sec-Websocket-Version")
}
func isWebSocketUpgrade(req *http.Request) bool {
if !httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") {
return false
return httpguts.HeaderValuesContainsToken(req.Header["Connection"], "Upgrade") &&
strings.EqualFold(req.Header.Get("Upgrade"), "websocket")
}
func errorHandler(w http.ResponseWriter, req *http.Request, err error) {
statusCode := http.StatusInternalServerError
switch {
case errors.Is(err, io.EOF):
statusCode = http.StatusBadGateway
case errors.Is(err, context.Canceled):
statusCode = StatusClientClosedRequest
default:
var netErr net.Error
if errors.As(err, &netErr) {
if netErr.Timeout() {
statusCode = http.StatusGatewayTimeout
} else {
statusCode = http.StatusBadGateway
}
}
}
return strings.EqualFold(req.Header.Get("Upgrade"), "websocket")
logger := log.FromContext(req.Context())
logger.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
w.WriteHeader(statusCode)
if _, werr := w.Write([]byte(statusText(statusCode))); werr != nil {
logger.Debugf("Error while writing status code", werr)
}
}
func statusText(statusCode int) string {

View file

@ -28,7 +28,7 @@ func BenchmarkProxy(b *testing.B) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
pool := newBufferPool()
handler, _ := buildProxy(Bool(false), nil, &staticTransport{res}, pool)
handler := buildSingleHostProxy(req.URL, false, 0, &staticTransport{res}, pool)
b.ReportAllocs()
for i := 0; i < b.N; i++ {

View file

@ -18,12 +18,7 @@ import (
"golang.org/x/net/websocket"
)
func Bool(v bool) *bool { return &v }
func TestWebSocketTCPClose(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
errChan := make(chan error, 1)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -42,7 +37,7 @@ func TestWebSocketTCPClose(t *testing.T) {
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
proxyAddr := proxy.Listener.Addr().String()
_, conn, err := newWebsocketRequest(
@ -61,10 +56,6 @@ func TestWebSocketTCPClose(t *testing.T) {
}
func TestWebSocketPingPong(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{
HandshakeTimeout: 10 * time.Second,
CheckOrigin: func(*http.Request) bool {
@ -86,17 +77,10 @@ func TestWebSocketPingPong(t *testing.T) {
_, _, _ = ws.ReadMessage()
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
srv := httptest.NewServer(mux)
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
@ -127,9 +111,6 @@ func TestWebSocketPingPong(t *testing.T) {
}
func TestWebSocketEcho(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
msg := make([]byte, 4)
@ -145,17 +126,10 @@ func TestWebSocketEcho(t *testing.T) {
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
srv := httptest.NewServer(mux)
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
@ -193,10 +167,6 @@ func TestWebSocketPassHost(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
f, err := buildProxy(Bool(test.passHost), nil, http.DefaultTransport, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
req := conn.Request()
@ -208,7 +178,7 @@ func TestWebSocketPassHost(t *testing.T) {
}
msg := make([]byte, 4)
_, err = conn.Read(msg)
_, err := conn.Read(msg)
require.NoError(t, err)
fmt.Println(string(msg))
@ -219,16 +189,10 @@ func TestWebSocketPassHost(t *testing.T) {
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
srv := httptest.NewServer(mux)
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
serverAddr := proxy.Listener.Addr().String()
@ -252,9 +216,6 @@ func TestWebSocketPassHost(t *testing.T) {
}
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
return true
}}
@ -277,7 +238,7 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
@ -293,9 +254,6 @@ func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
}
func TestWebSocketRequestWithOrigin(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
@ -316,11 +274,11 @@ func TestWebSocketRequestWithOrigin(t *testing.T) {
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
_, err = newWebsocketRequest(
_, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
@ -339,9 +297,6 @@ func TestWebSocketRequestWithOrigin(t *testing.T) {
}
func TestWebSocketRequestWithQueryParams(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
@ -363,7 +318,7 @@ func TestWebSocketRequestWithQueryParams(t *testing.T) {
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
@ -379,18 +334,14 @@ func TestWebSocketRequestWithQueryParams(t *testing.T) {
}
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
conn.Close()
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
srv := httptest.NewServer(mux)
defer srv.Close()
f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
@ -403,6 +354,7 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
defer conn.Close()
@ -411,9 +363,6 @@ func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
}
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
@ -435,7 +384,7 @@ func TestWebSocketRequestWithEncodedChar(t *testing.T) {
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
@ -451,18 +400,14 @@ func TestWebSocketRequestWithEncodedChar(t *testing.T) {
}
func TestWebSocketUpgradeFailed(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(http.StatusBadRequest)
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
srv := httptest.NewServer(mux)
defer srv.Close()
f := buildSingleHostProxy(parseURI(t, srv.URL), true, 0, http.DefaultTransport, nil)
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
@ -501,9 +446,6 @@ func TestWebSocketUpgradeFailed(t *testing.T) {
}
func TestForwardsWebsocketTraffic(t *testing.T) {
f, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
_, err := conn.Write([]byte("ok"))
@ -512,12 +454,10 @@ func TestForwardsWebsocketTraffic(t *testing.T) {
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
srv := httptest.NewServer(mux)
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
proxy := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
@ -557,15 +497,12 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
srv := createTLSWebsocketServer()
defer srv.Close()
forwarderWithoutTLSConfig, err := buildProxy(Bool(true), nil, http.DefaultTransport, nil)
require.NoError(t, err)
proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL)
proxyWithoutTLSConfig := createProxyWithForwarder(t, srv.URL, http.DefaultTransport)
defer proxyWithoutTLSConfig.Close()
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
_, err = newWebsocketRequest(
_, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
@ -576,10 +513,8 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
forwarderWithTLSConfig, err := buildProxy(Bool(true), nil, transport, nil)
require.NoError(t, err)
proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL)
proxyWithTLSConfig := createProxyWithForwarder(t, srv.URL, transport)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
@ -597,10 +532,7 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
defaultTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(Bool(true), nil, defaultTransport, nil)
require.NoError(t, err)
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL)
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, srv.URL, defaultTransport)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
@ -705,15 +637,19 @@ func parseURI(t *testing.T, uri string) *url.URL {
return out
}
func createProxyWithForwarder(t *testing.T, proxy http.Handler, url string) *httptest.Server {
func createProxyWithForwarder(t *testing.T, uri string, transport http.RoundTripper) *httptest.Server {
t.Helper()
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
u := parseURI(t, uri)
proxy := buildSingleHostProxy(u, true, 0, transport, nil)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = parseURI(t, url)
req.URL = u
req.URL.Path = path
proxy.ServeHTTP(w, req)
}))
t.Cleanup(srv.Close)
return srv
}

View file

@ -4,36 +4,28 @@ import (
"context"
"errors"
"fmt"
"hash/fnv"
"math/rand"
"net/http"
"net/http/httputil"
"net/url"
"reflect"
"strings"
"time"
"github.com/containous/alice"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/config/runtime"
"github.com/traefik/traefik/v2/pkg/healthcheck"
"github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/metrics"
"github.com/traefik/traefik/v2/pkg/middlewares/accesslog"
"github.com/traefik/traefik/v2/pkg/middlewares/emptybackendhandler"
metricsMiddle "github.com/traefik/traefik/v2/pkg/middlewares/metrics"
"github.com/traefik/traefik/v2/pkg/middlewares/pipelining"
"github.com/traefik/traefik/v2/pkg/safe"
"github.com/traefik/traefik/v2/pkg/server/cookie"
"github.com/traefik/traefik/v2/pkg/server/provider"
"github.com/traefik/traefik/v2/pkg/server/service/loadbalancer/failover"
"github.com/traefik/traefik/v2/pkg/server/service/loadbalancer/mirror"
"github.com/traefik/traefik/v2/pkg/server/service/loadbalancer/wrr"
"github.com/vulcand/oxy/roundrobin"
"github.com/vulcand/oxy/roundrobin/stickycookie"
)
const (
defaultHealthCheckInterval = 30 * time.Second
defaultHealthCheckTimeout = 5 * time.Second
)
const defaultMaxBodySize int64 = -1
@ -43,6 +35,19 @@ type RoundTripperGetter interface {
Get(name string) (http.RoundTripper, error)
}
// Manager The service manager.
type Manager struct {
routinePool *safe.Pool
metricsRegistry metrics.Registry
bufferPool httputil.BufferPool
roundTripperManager RoundTripperGetter
services map[string]http.Handler
configs map[string]*runtime.ServiceInfo
healthCheckers map[string]*healthcheck.ServiceHealthChecker
rand *rand.Rand // For the initial shuffling of load-balancers.
}
// NewManager creates a new Manager.
func NewManager(configs map[string]*runtime.ServiceInfo, metricsRegistry metrics.Registry, routinePool *safe.Pool, roundTripperManager RoundTripperGetter) *Manager {
return &Manager{
@ -50,27 +55,13 @@ func NewManager(configs map[string]*runtime.ServiceInfo, metricsRegistry metrics
metricsRegistry: metricsRegistry,
bufferPool: newBufferPool(),
roundTripperManager: roundTripperManager,
balancers: make(map[string]healthcheck.Balancers),
services: make(map[string]http.Handler),
configs: configs,
healthCheckers: make(map[string]*healthcheck.ServiceHealthChecker),
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
// Manager The service manager.
type Manager struct {
routinePool *safe.Pool
metricsRegistry metrics.Registry
bufferPool httputil.BufferPool
roundTripperManager RoundTripperGetter
// balancers is the map of all Balancers, keyed by service name.
// There is one Balancer per service handler, and there is one service handler per reference to a service
// (e.g. if 2 routers refer to the same service name, 2 service handlers are created),
// which is why there is not just one Balancer per service name.
balancers map[string]healthcheck.Balancers
configs map[string]*runtime.ServiceInfo
rand *rand.Rand // For the initial shuffling of load-balancers.
}
// BuildHTTP Creates a http.Handler for a service configuration.
func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.Handler, error) {
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
@ -78,11 +69,20 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.H
serviceName = provider.GetQualifiedName(ctx, serviceName)
ctx = provider.AddInContext(ctx, serviceName)
handler, ok := m.services[serviceName]
if ok {
return handler, nil
}
conf, ok := m.configs[serviceName]
if !ok {
return nil, fmt.Errorf("the service %q does not exist", serviceName)
}
if conf.Status == runtime.StatusDisabled {
return nil, errors.New(strings.Join(conf.Err, ", "))
}
value := reflect.ValueOf(*conf.Service)
var count int
for i := 0; i < value.NumField(); i++ {
@ -101,7 +101,7 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.H
switch {
case conf.LoadBalancer != nil:
var err error
lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer)
lb, err = m.getLoadBalancerServiceHandler(ctx, serviceName, conf)
if err != nil {
conf.AddError(err, true)
return nil, err
@ -133,6 +133,8 @@ func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string) (http.H
return nil, sErr
}
m.services[serviceName] = lb
return lb, nil
}
@ -214,14 +216,14 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string,
config.Sticky.Cookie.Name = cookie.GetName(config.Sticky.Cookie.Name, serviceName)
}
balancer := wrr.New(config.Sticky, config.HealthCheck)
balancer := wrr.New(config.Sticky, config.HealthCheck != nil)
for _, service := range shuffle(config.Services, m.rand) {
serviceHandler, err := m.BuildHTTP(ctx, service.Name)
if err != nil {
return nil, err
}
balancer.AddService(service.Name, serviceHandler, service.Weight)
balancer.Add(service.Name, serviceHandler, service.Weight)
if config.HealthCheck == nil {
continue
@ -245,216 +247,91 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string,
return balancer, nil
}
func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName string, service *dynamic.ServersLoadBalancer) (http.Handler, error) {
if service.PassHostHeader == nil {
defaultPassHostHeader := true
service.PassHostHeader = &defaultPassHostHeader
func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName string, info *runtime.ServiceInfo) (http.Handler, error) {
service := info.LoadBalancer
logger := log.FromContext(ctx)
logger.Debug("Creating load-balancer")
// TODO: should we keep this config value as Go is now handling stream response correctly?
flushInterval := dynamic.DefaultFlushInterval
if service.ResponseForwarding != nil {
flushInterval = service.ResponseForwarding.FlushInterval
}
if len(service.ServersTransport) > 0 {
service.ServersTransport = provider.GetQualifiedName(ctx, service.ServersTransport)
}
if service.Sticky != nil && service.Sticky.Cookie != nil {
service.Sticky.Cookie.Name = cookie.GetName(service.Sticky.Cookie.Name, serviceName)
}
// We make sure that the PassHostHeader value is defined to avoid panics.
passHostHeader := dynamic.DefaultPassHostHeader
if service.PassHostHeader != nil {
passHostHeader = *service.PassHostHeader
}
roundTripper, err := m.roundTripperManager.Get(service.ServersTransport)
if err != nil {
return nil, err
}
fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, roundTripper, m.bufferPool)
if err != nil {
return nil, err
lb := wrr.New(service.Sticky, service.HealthCheck != nil)
healthCheckTargets := make(map[string]*url.URL)
for _, server := range shuffle(service.Servers, m.rand) {
hasher := fnv.New64a()
_, _ = hasher.Write([]byte(server.URL)) // this will never return an error.
proxyName := fmt.Sprintf("%x", hasher.Sum(nil))
target, err := url.Parse(server.URL)
if err != nil {
return nil, fmt.Errorf("error parsing server URL %s: %w", server.URL, err)
}
logger.WithField(log.ServerName, proxyName).Debugf("Creating server %s", target)
proxy := buildSingleHostProxy(target, passHostHeader, time.Duration(flushInterval), roundTripper, m.bufferPool)
proxy = accesslog.NewFieldHandler(proxy, accesslog.ServiceURL, target.String(), nil)
proxy = accesslog.NewFieldHandler(proxy, accesslog.ServiceAddr, target.Host, nil)
proxy = accesslog.NewFieldHandler(proxy, accesslog.ServiceName, serviceName, nil)
if m.metricsRegistry != nil && m.metricsRegistry.IsSvcEnabled() {
proxy = metricsMiddle.NewServiceMiddleware(ctx, proxy, m.metricsRegistry, serviceName)
}
lb.Add(proxyName, proxy, nil)
// servers are considered UP by default.
info.UpdateServerStatus(target.String(), runtime.StatusUp)
healthCheckTargets[proxyName] = target
}
alHandler := func(next http.Handler) (http.Handler, error) {
return accesslog.NewFieldHandler(next, accesslog.ServiceName, serviceName, accesslog.AddServiceFields), nil
}
chain := alice.New()
if m.metricsRegistry != nil && m.metricsRegistry.IsSvcEnabled() {
chain = chain.Append(metricsMiddle.WrapServiceHandler(ctx, m.metricsRegistry, serviceName))
if service.HealthCheck != nil {
m.healthCheckers[serviceName] = healthcheck.NewServiceHealthChecker(
ctx,
m.metricsRegistry,
service.HealthCheck,
lb,
info,
roundTripper,
healthCheckTargets,
)
}
handler, err := chain.Append(alHandler).Then(pipelining.New(ctx, fwd, "pipelining"))
if err != nil {
return nil, err
}
balancer, err := m.getLoadBalancer(ctx, serviceName, service, handler)
if err != nil {
return nil, err
}
// TODO rename and checks
m.balancers[serviceName] = append(m.balancers[serviceName], balancer)
// Empty (backend with no servers)
return emptybackendhandler.New(balancer), nil
return lb, nil
}
// LaunchHealthCheck launches the health checks.
func (m *Manager) LaunchHealthCheck() {
backendConfigs := make(map[string]*healthcheck.BackendConfig)
for serviceName, balancers := range m.balancers {
ctx := log.With(context.Background(), log.Str(log.ServiceName, serviceName))
service := m.configs[serviceName].LoadBalancer
// Health Check
hcOpts := buildHealthCheckOptions(ctx, balancers, serviceName, service.HealthCheck)
if hcOpts == nil {
continue
}
hcOpts.Transport, _ = m.roundTripperManager.Get(service.ServersTransport)
log.FromContext(ctx).Debugf("Setting up healthcheck for service %s with %s", serviceName, *hcOpts)
backendConfigs[serviceName] = healthcheck.NewBackendConfig(*hcOpts, serviceName)
}
healthcheck.GetHealthCheck(m.metricsRegistry).SetBackendsConfiguration(context.Background(), backendConfigs)
}
func buildHealthCheckOptions(ctx context.Context, lb healthcheck.Balancer, backend string, hc *dynamic.ServerHealthCheck) *healthcheck.Options {
if hc == nil {
return nil
}
logger := log.FromContext(ctx)
if hc.Path == "" {
logger.Errorf("Ignoring heath check configuration for '%s': no path provided", backend)
return nil
}
interval := defaultHealthCheckInterval
if hc.Interval != "" {
intervalOverride, err := time.ParseDuration(hc.Interval)
switch {
case err != nil:
logger.Errorf("Illegal health check interval for '%s': %s", backend, err)
case intervalOverride <= 0:
logger.Errorf("Health check interval smaller than zero for service '%s'", backend)
default:
interval = intervalOverride
}
}
timeout := defaultHealthCheckTimeout
if hc.Timeout != "" {
timeoutOverride, err := time.ParseDuration(hc.Timeout)
switch {
case err != nil:
logger.Errorf("Illegal health check timeout for backend '%s': %s", backend, err)
case timeoutOverride <= 0:
logger.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend)
default:
timeout = timeoutOverride
}
}
if timeout >= interval {
logger.Warnf("Health check timeout for backend '%s' should be lower than the health check interval. Interval set to timeout + 1 second (%s).", backend, interval)
}
followRedirects := true
if hc.FollowRedirects != nil {
followRedirects = *hc.FollowRedirects
}
mode := healthcheck.HTTPMode
switch hc.Mode {
case "":
mode = healthcheck.HTTPMode
case healthcheck.GRPCMode, healthcheck.HTTPMode:
mode = hc.Mode
default:
logger.Errorf("Illegal health check mode for backend '%s'", backend)
}
return &healthcheck.Options{
Scheme: hc.Scheme,
Mode: mode,
Path: hc.Path,
Method: hc.Method,
Port: hc.Port,
Interval: interval,
Timeout: timeout,
LB: lb,
Hostname: hc.Hostname,
Headers: hc.Headers,
FollowRedirects: followRedirects,
}
}
func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, service *dynamic.ServersLoadBalancer, fwd http.Handler) (healthcheck.BalancerStatusHandler, error) {
logger := log.FromContext(ctx)
logger.Debug("Creating load-balancer")
var options []roundrobin.LBOption
var cookieName string
if service.Sticky != nil && service.Sticky.Cookie != nil {
cookieName = cookie.GetName(service.Sticky.Cookie.Name, serviceName)
opts := roundrobin.CookieOptions{
HTTPOnly: service.Sticky.Cookie.HTTPOnly,
Secure: service.Sticky.Cookie.Secure,
SameSite: convertSameSite(service.Sticky.Cookie.SameSite),
}
// Sticky Cookie Value
cv, err := stickycookie.NewFallbackValue(&stickycookie.RawValue{}, &stickycookie.HashValue{})
if err != nil {
return nil, err
}
options = append(options, roundrobin.EnableStickySession(roundrobin.NewStickySessionWithOptions(cookieName, opts).SetCookieValue(cv)))
logger.Debugf("Sticky session cookie name: %v", cookieName)
}
lb, err := roundrobin.New(fwd, options...)
if err != nil {
return nil, err
}
lbsu := healthcheck.NewLBStatusUpdater(lb, m.configs[serviceName], service.HealthCheck)
if err := m.upsertServers(ctx, lbsu, service.Servers); err != nil {
return nil, fmt.Errorf("error configuring load balancer for service %s: %w", serviceName, err)
}
return lbsu, nil
}
func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []dynamic.Server) error {
logger := log.FromContext(ctx)
for name, srv := range shuffle(servers, m.rand) {
u, err := url.Parse(srv.URL)
if err != nil {
return fmt.Errorf("error parsing server URL %s: %w", srv.URL, err)
}
logger.WithField(log.ServerName, name).Debugf("Creating server %d %s", name, u)
if err := lb.UpsertServer(u, roundrobin.Weight(1)); err != nil {
return fmt.Errorf("error adding server %s to load balancer: %w", srv.URL, err)
}
// TODO Handle Metrics
}
return nil
}
func convertSameSite(sameSite string) http.SameSite {
switch sameSite {
case "none":
return http.SameSiteNoneMode
case "lax":
return http.SameSiteLaxMode
case "strict":
return http.SameSiteStrictMode
default:
return 0
func (m *Manager) LaunchHealthCheck(ctx context.Context) {
for serviceName, hc := range m.healthCheckers {
ctx = log.With(ctx, log.Str(log.ServiceName, serviceName))
go hc.Launch(ctx)
}
}

View file

@ -15,14 +15,10 @@ import (
"github.com/traefik/traefik/v2/pkg/testhelpers"
)
type MockForwarder struct{}
func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) {
panic("implement me")
}
func TestGetLoadBalancer(t *testing.T) {
sm := Manager{}
sm := Manager{
roundTripperManager: newRtMock(),
}
testCases := []struct {
desc string
@ -67,7 +63,8 @@ func TestGetLoadBalancer(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := sm.getLoadBalancer(context.Background(), test.serviceName, test.service, test.fwd)
serviceInfo := &runtime.ServiceInfo{Service: &dynamic.Service{LoadBalancer: test.service}}
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, serviceInfo)
if test.expectError {
require.Error(t, err)
assert.Nil(t, handler)
@ -129,6 +126,7 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
desc: "Load balances between the two servers",
serviceName: "test",
service: &dynamic.ServersLoadBalancer{
PassHostHeader: Bool(true),
Servers: []dynamic.Server{
{
URL: server1.URL,
@ -258,40 +256,13 @@ func TestGetLoadBalancerServiceHandler(t *testing.T) {
},
},
},
{
desc: "Cookie value is backward compatible",
serviceName: "test",
service: &dynamic.ServersLoadBalancer{
Sticky: &dynamic.Sticky{
Cookie: &dynamic.Cookie{},
},
Servers: []dynamic.Server{
{
URL: server1.URL,
},
{
URL: server2.URL,
},
},
},
cookieRawValue: "_6f743=" + server1.URL,
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "first",
},
{
StatusCode: http.StatusOK,
XFrom: "first",
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service)
serviceInfo := &runtime.ServiceInfo{Service: &dynamic.Service{LoadBalancer: test.service}}
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, serviceInfo)
assert.NoError(t, err)
assert.NotNil(t, handler)
@ -419,3 +390,21 @@ func TestMultipleTypeOnBuildHTTP(t *testing.T) {
_, err := manager.BuildHTTP(context.Background(), "test@file")
assert.Error(t, err, "cannot create service: multi-types service not supported, consider declaring two different pieces of service instead")
}
func Bool(v bool) *bool { return &v }
type MockForwarder struct{}
func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) {
panic("not available")
}
type rtMock struct{}
func newRtMock() RoundTripperGetter {
return &rtMock{}
}
func (r *rtMock) Get(_ string) (http.RoundTripper, error) {
return http.DefaultTransport, nil
}