1
0
Fork 0

Add TCP Healthcheck

This commit is contained in:
Douglas De Toni Machado 2025-10-22 06:42:05 -03:00 committed by GitHub
parent d1ab6ed489
commit 8392503df7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 2416 additions and 307 deletions

View file

@ -33,16 +33,21 @@ type serviceInfoRepresentation struct {
ServerStatus map[string]string `json:"serverStatus,omitempty"`
}
type tcpServiceInfoRepresentation struct {
*runtime.TCPServiceInfo
ServerStatus map[string]string `json:"serverStatus,omitempty"`
}
// RunTimeRepresentation is the configuration information exposed by the API handler.
type RunTimeRepresentation struct {
Routers map[string]*runtime.RouterInfo `json:"routers,omitempty"`
Middlewares map[string]*runtime.MiddlewareInfo `json:"middlewares,omitempty"`
Services map[string]*serviceInfoRepresentation `json:"services,omitempty"`
TCPRouters map[string]*runtime.TCPRouterInfo `json:"tcpRouters,omitempty"`
TCPMiddlewares map[string]*runtime.TCPMiddlewareInfo `json:"tcpMiddlewares,omitempty"`
TCPServices map[string]*runtime.TCPServiceInfo `json:"tcpServices,omitempty"`
UDPRouters map[string]*runtime.UDPRouterInfo `json:"udpRouters,omitempty"`
UDPServices map[string]*runtime.UDPServiceInfo `json:"udpServices,omitempty"`
Routers map[string]*runtime.RouterInfo `json:"routers,omitempty"`
Middlewares map[string]*runtime.MiddlewareInfo `json:"middlewares,omitempty"`
Services map[string]*serviceInfoRepresentation `json:"services,omitempty"`
TCPRouters map[string]*runtime.TCPRouterInfo `json:"tcpRouters,omitempty"`
TCPMiddlewares map[string]*runtime.TCPMiddlewareInfo `json:"tcpMiddlewares,omitempty"`
TCPServices map[string]*tcpServiceInfoRepresentation `json:"tcpServices,omitempty"`
UDPRouters map[string]*runtime.UDPRouterInfo `json:"udpRouters,omitempty"`
UDPServices map[string]*runtime.UDPServiceInfo `json:"udpServices,omitempty"`
}
// Handler serves the configuration and status of Traefik on API endpoints.
@ -127,13 +132,21 @@ func (h Handler) getRuntimeConfiguration(rw http.ResponseWriter, request *http.R
}
}
tcpSIRepr := make(map[string]*tcpServiceInfoRepresentation, len(h.runtimeConfiguration.Services))
for k, v := range h.runtimeConfiguration.TCPServices {
tcpSIRepr[k] = &tcpServiceInfoRepresentation{
TCPServiceInfo: v,
ServerStatus: v.GetAllStatus(),
}
}
result := RunTimeRepresentation{
Routers: h.runtimeConfiguration.Routers,
Middlewares: h.runtimeConfiguration.Middlewares,
Services: siRepr,
TCPRouters: h.runtimeConfiguration.TCPRouters,
TCPMiddlewares: h.runtimeConfiguration.TCPMiddlewares,
TCPServices: h.runtimeConfiguration.TCPServices,
TCPServices: tcpSIRepr,
UDPRouters: h.runtimeConfiguration.UDPRouters,
UDPServices: h.runtimeConfiguration.UDPServices,
}

View file

@ -34,10 +34,10 @@ func newRouterRepresentation(name string, rt *runtime.RouterInfo) routerRepresen
type serviceRepresentation struct {
*runtime.ServiceInfo
ServerStatus map[string]string `json:"serverStatus,omitempty"`
Name string `json:"name,omitempty"`
Provider string `json:"provider,omitempty"`
Type string `json:"type,omitempty"`
ServerStatus map[string]string `json:"serverStatus,omitempty"`
}
func newServiceRepresentation(name string, si *runtime.ServiceInfo) serviceRepresentation {
@ -45,8 +45,8 @@ func newServiceRepresentation(name string, si *runtime.ServiceInfo) serviceRepre
ServiceInfo: si,
Name: name,
Provider: getProviderName(name),
ServerStatus: si.GetAllStatus(),
Type: strings.ToLower(extractType(si.Service)),
ServerStatus: si.GetAllStatus(),
}
}

View file

@ -29,9 +29,10 @@ func newTCPRouterRepresentation(name string, rt *runtime.TCPRouterInfo) tcpRoute
type tcpServiceRepresentation struct {
*runtime.TCPServiceInfo
Name string `json:"name,omitempty"`
Provider string `json:"provider,omitempty"`
Type string `json:"type,omitempty"`
Name string `json:"name,omitempty"`
Provider string `json:"provider,omitempty"`
Type string `json:"type,omitempty"`
ServerStatus map[string]string `json:"serverStatus,omitempty"`
}
func newTCPServiceRepresentation(name string, si *runtime.TCPServiceInfo) tcpServiceRepresentation {
@ -40,6 +41,7 @@ func newTCPServiceRepresentation(name string, si *runtime.TCPServiceInfo) tcpSer
Name: name,
Provider: getProviderName(name),
Type: strings.ToLower(extractType(si.TCPService)),
ServerStatus: si.GetAllStatus(),
}
}

View file

@ -355,45 +355,57 @@ func TestHandler_TCP(t *testing.T) {
path: "/api/tcp/services",
conf: runtime.Configuration{
TCPServices: map[string]*runtime.TCPServiceInfo{
"bar@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
"bar@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
Status: runtime.StatusEnabled,
},
"baz@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
UsedBy: []string{"foo@myprovider", "test@myprovider"},
Status: runtime.StatusEnabled,
}
si.UpdateServerStatus("127.0.0.1:2345", "UP")
return si
}(),
"baz@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusWarning,
},
"foz@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusWarning,
}
si.UpdateServerStatus("127.0.0.2:2345", "UP")
return si
}(),
"foz@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.3:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusDisabled,
},
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusDisabled,
}
si.UpdateServerStatus("127.0.0.3:2345", "UP")
return si
}(),
},
},
expected: expected{
@ -407,45 +419,57 @@ func TestHandler_TCP(t *testing.T) {
path: "/api/tcp/services?status=enabled",
conf: runtime.Configuration{
TCPServices: map[string]*runtime.TCPServiceInfo{
"bar@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
"bar@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
Status: runtime.StatusEnabled,
},
"baz@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
UsedBy: []string{"foo@myprovider", "test@myprovider"},
Status: runtime.StatusEnabled,
}
si.UpdateServerStatus("127.0.0.1:2345", "UP")
return si
}(),
"baz@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusWarning,
},
"foz@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusWarning,
}
si.UpdateServerStatus("127.0.0.2:2345", "UP")
return si
}(),
"foz@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.3:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusDisabled,
},
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusDisabled,
}
si.UpdateServerStatus("127.0.0.3:2345", "UP")
return si
}(),
},
},
expected: expected{
@ -459,45 +483,57 @@ func TestHandler_TCP(t *testing.T) {
path: "/api/tcp/services?search=baz@my",
conf: runtime.Configuration{
TCPServices: map[string]*runtime.TCPServiceInfo{
"bar@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
"bar@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
Status: runtime.StatusEnabled,
},
"baz@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
UsedBy: []string{"foo@myprovider", "test@myprovider"},
Status: runtime.StatusEnabled,
}
si.UpdateServerStatus("127.0.0.1:2345", "UP")
return si
}(),
"baz@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusWarning,
},
"foz@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusWarning,
}
si.UpdateServerStatus("127.0.0.2:2345", "UP")
return si
}(),
"foz@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.3:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusDisabled,
},
UsedBy: []string{"foo@myprovider"},
Status: runtime.StatusDisabled,
}
si.UpdateServerStatus("127.0.0.3:2345", "UP")
return si
}(),
},
},
expected: expected{
@ -511,41 +547,53 @@ func TestHandler_TCP(t *testing.T) {
path: "/api/tcp/services?page=2&per_page=1",
conf: runtime.Configuration{
TCPServices: map[string]*runtime.TCPServiceInfo{
"bar@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
"bar@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
},
"baz@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
UsedBy: []string{"foo@myprovider", "test@myprovider"},
}
si.UpdateServerStatus("127.0.0.1:2345", "UP")
return si
}(),
"baz@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.2:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider"},
},
"test@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.3:2345",
UsedBy: []string{"foo@myprovider"},
}
si.UpdateServerStatus("127.0.0.2:2345", "UP")
return si
}(),
"test@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.3:2345",
},
},
},
},
},
},
}
si.UpdateServerStatus("127.0.0.3:2345", "UP")
return si
}(),
},
},
expected: expected{
@ -559,18 +607,22 @@ func TestHandler_TCP(t *testing.T) {
path: "/api/tcp/services/bar@myprovider",
conf: runtime.Configuration{
TCPServices: map[string]*runtime.TCPServiceInfo{
"bar@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
"bar@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
}
si.UpdateServerStatus("127.0.0.1:2345", "UP")
return si
}(),
},
},
expected: expected{
@ -583,18 +635,22 @@ func TestHandler_TCP(t *testing.T) {
path: "/api/tcp/services/" + url.PathEscape("foo / bar@myprovider"),
conf: runtime.Configuration{
TCPServices: map[string]*runtime.TCPServiceInfo{
"foo / bar@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
"foo / bar@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
}
si.UpdateServerStatus("127.0.0.1:2345", "UP")
return si
}(),
},
},
expected: expected{
@ -607,18 +663,22 @@ func TestHandler_TCP(t *testing.T) {
path: "/api/tcp/services/nono@myprovider",
conf: runtime.Configuration{
TCPServices: map[string]*runtime.TCPServiceInfo{
"bar@myprovider": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
"bar@myprovider": func() *runtime.TCPServiceInfo {
si := &runtime.TCPServiceInfo{
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "127.0.0.1:2345",
},
},
},
},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
},
UsedBy: []string{"foo@myprovider", "test@myprovider"},
}
si.UpdateServerStatus("127.0.0.1:2345", "UP")
return si
}(),
},
},
expected: expected{

View file

@ -8,6 +8,9 @@
},
"name": "bar@myprovider",
"provider": "myprovider",
"serverStatus": {
"127.0.0.1:2345": "UP"
},
"status": "enabled",
"type": "loadbalancer",
"usedBy": [

View file

@ -8,6 +8,9 @@
},
"name": "foo / bar@myprovider",
"provider": "myprovider",
"serverStatus": {
"127.0.0.1:2345": "UP"
},
"status": "enabled",
"type": "loadbalancer",
"usedBy": [

View file

@ -9,6 +9,9 @@
},
"name": "baz@myprovider",
"provider": "myprovider",
"serverStatus": {
"127.0.0.2:2345": "UP"
},
"status": "warning",
"type": "loadbalancer",
"usedBy": [

View file

@ -9,6 +9,9 @@
},
"name": "bar@myprovider",
"provider": "myprovider",
"serverStatus": {
"127.0.0.1:2345": "UP"
},
"status": "enabled",
"type": "loadbalancer",
"usedBy": [

View file

@ -9,6 +9,9 @@
},
"name": "baz@myprovider",
"provider": "myprovider",
"serverStatus": {
"127.0.0.2:2345": "UP"
},
"status": "enabled",
"type": "loadbalancer",
"usedBy": [

View file

@ -9,6 +9,9 @@
},
"name": "bar@myprovider",
"provider": "myprovider",
"serverStatus": {
"127.0.0.1:2345": "UP"
},
"status": "enabled",
"type": "loadbalancer",
"usedBy": [
@ -26,6 +29,9 @@
},
"name": "baz@myprovider",
"provider": "myprovider",
"serverStatus": {
"127.0.0.2:2345": "UP"
},
"status": "warning",
"type": "loadbalancer",
"usedBy": [
@ -36,12 +42,15 @@
"loadBalancer": {
"servers": [
{
"address": "127.0.0.2:2345"
"address": "127.0.0.3:2345"
}
]
},
"name": "foz@myprovider",
"provider": "myprovider",
"serverStatus": {
"127.0.0.3:2345": "UP"
},
"status": "disabled",
"type": "loadbalancer",
"usedBy": [

View file

@ -39,7 +39,8 @@ type TCPService struct {
// TCPWeightedRoundRobin is a weighted round robin tcp load-balancer of services.
type TCPWeightedRoundRobin struct {
Services []TCPWRRService `json:"services,omitempty" toml:"services,omitempty" yaml:"services,omitempty" export:"true"`
Services []TCPWRRService `json:"services,omitempty" toml:"services,omitempty" yaml:"services,omitempty" export:"true"`
HealthCheck *HealthCheck `json:"healthCheck,omitempty" toml:"healthCheck,omitempty" yaml:"healthCheck,omitempty" label:"allowEmpty" file:"allowEmpty" kv:"allowEmpty" export:"true"`
}
// +k8s:deepcopy-gen=true
@ -86,7 +87,6 @@ type RouterTCPTLSConfig struct {
type TCPServersLoadBalancer struct {
Servers []TCPServer `json:"servers,omitempty" toml:"servers,omitempty" yaml:"servers,omitempty" label-slice-as-struct:"server" export:"true"`
ServersTransport string `json:"serversTransport,omitempty" toml:"serversTransport,omitempty" yaml:"serversTransport,omitempty" export:"true"`
// ProxyProtocol holds the PROXY Protocol configuration.
// Deprecated: use ServersTransport to configure ProxyProtocol instead.
ProxyProtocol *ProxyProtocol `json:"proxyProtocol,omitempty" toml:"proxyProtocol,omitempty" yaml:"proxyProtocol,omitempty" label:"allowEmpty" file:"allowEmpty" kv:"allowEmpty" export:"true"`
@ -96,7 +96,8 @@ type TCPServersLoadBalancer struct {
// connection. It is a duration in milliseconds, defaulting to 100. A negative value
// means an infinite deadline (i.e. the reading capability is never closed).
// Deprecated: use ServersTransport to configure the TerminationDelay instead.
TerminationDelay *int `json:"terminationDelay,omitempty" toml:"terminationDelay,omitempty" yaml:"terminationDelay,omitempty" export:"true"`
TerminationDelay *int `json:"terminationDelay,omitempty" toml:"terminationDelay,omitempty" yaml:"terminationDelay,omitempty" export:"true"`
HealthCheck *TCPServerHealthCheck `json:"healthCheck,omitempty" toml:"healthCheck,omitempty" yaml:"healthCheck,omitempty" label:"allowEmpty" file:"allowEmpty" kv:"allowEmpty" export:"true"`
}
// Mergeable tells if the given service is mergeable.
@ -176,3 +177,21 @@ type TLSClientConfig struct {
PeerCertURI string `description:"Defines the URI used to match against SAN URI during the peer certificate verification." json:"peerCertURI,omitempty" toml:"peerCertURI,omitempty" yaml:"peerCertURI,omitempty" export:"true"`
Spiffe *Spiffe `description:"Defines the SPIFFE TLS configuration." json:"spiffe,omitempty" toml:"spiffe,omitempty" yaml:"spiffe,omitempty" label:"allowEmpty" file:"allowEmpty" export:"true"`
}
// +k8s:deepcopy-gen=true

// TCPServerHealthCheck holds the HealthCheck configuration.
type TCPServerHealthCheck struct {
	// Port overrides the port used for the health check connection; 0 keeps the server's own port.
	Port int `json:"port,omitempty" toml:"port,omitempty,omitzero" yaml:"port,omitempty" export:"true"`
	// Send is an optional payload written to the connection after dialing.
	Send string `json:"send,omitempty" toml:"send,omitempty" yaml:"send,omitempty" export:"true"`
	// Expect is an optional response the server must send back for the check to pass.
	Expect string `json:"expect,omitempty" toml:"expect,omitempty" yaml:"expect,omitempty" export:"true"`
	// Interval is the period between checks of healthy servers (see SetDefaults for the default).
	Interval ptypes.Duration `json:"interval,omitempty" toml:"interval,omitempty" yaml:"interval,omitempty" export:"true"`
	// UnhealthyInterval is the period between checks of unhealthy servers; when nil, Interval is used.
	UnhealthyInterval *ptypes.Duration `json:"unhealthyInterval,omitempty" toml:"unhealthyInterval,omitempty" yaml:"unhealthyInterval,omitempty" export:"true"`
	// Timeout bounds the duration of a single check (see SetDefaults for the default).
	Timeout ptypes.Duration `json:"timeout,omitempty" toml:"timeout,omitempty" yaml:"timeout,omitempty" export:"true"`
}
// SetDefaults sets the default values for a TCPServerHealthCheck.
func (t *TCPServerHealthCheck) SetDefaults() {
	// Reuse the package-level health check defaults for interval and timeout.
	t.Interval = DefaultHealthCheckInterval
	t.Timeout = DefaultHealthCheckTimeout
}

View file

@ -2001,6 +2001,27 @@ func (in *TCPServer) DeepCopy() *TCPServer {
return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *TCPServerHealthCheck) DeepCopyInto(out *TCPServerHealthCheck) {
	*out = *in
	if in.UnhealthyInterval != nil {
		// Allocate a fresh Duration so the copy does not share the pointer with the source.
		in, out := &in.UnhealthyInterval, &out.UnhealthyInterval
		*out = new(paersertypes.Duration)
		**out = **in
	}
	return
}
// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TCPServerHealthCheck.
func (in *TCPServerHealthCheck) DeepCopy() *TCPServerHealthCheck {
	if in == nil {
		return nil
	}
	out := new(TCPServerHealthCheck)
	in.DeepCopyInto(out)
	return out
}
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *TCPServersLoadBalancer) DeepCopyInto(out *TCPServersLoadBalancer) {
*out = *in
@ -2019,6 +2040,11 @@ func (in *TCPServersLoadBalancer) DeepCopyInto(out *TCPServersLoadBalancer) {
*out = new(int)
**out = **in
}
if in.HealthCheck != nil {
in, out := &in.HealthCheck, &out.HealthCheck
*out = new(TCPServerHealthCheck)
(*in).DeepCopyInto(*out)
}
return
}
@ -2115,6 +2141,11 @@ func (in *TCPWeightedRoundRobin) DeepCopyInto(out *TCPWeightedRoundRobin) {
(*in)[i].DeepCopyInto(&(*out)[i])
}
}
if in.HealthCheck != nil {
in, out := &in.HealthCheck, &out.HealthCheck
*out = new(HealthCheck)
**out = **in
}
return
}

View file

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"slices"
"sync"
"github.com/rs/zerolog/log"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
@ -87,6 +88,9 @@ type TCPServiceInfo struct {
// It is the caller's responsibility to set the initial status.
Status string `json:"status,omitempty"`
UsedBy []string `json:"usedBy,omitempty"` // list of routers using that service
serverStatusMu sync.RWMutex
serverStatus map[string]string // keyed by server address
}
// AddError adds err to s.Err, if it does not already exist.
@ -110,6 +114,33 @@ func (s *TCPServiceInfo) AddError(err error, critical bool) {
}
}
// UpdateServerStatus sets the status of the server in the TCPServiceInfo.
func (s *TCPServiceInfo) UpdateServerStatus(server, status string) {
	s.serverStatusMu.Lock()
	defer s.serverStatusMu.Unlock()

	// Lazily create the map on first use; the zero TCPServiceInfo has a nil map.
	if s.serverStatus == nil {
		s.serverStatus = map[string]string{server: status}
		return
	}

	s.serverStatus[server] = status
}
// GetAllStatus returns all the statuses of all the servers in TCPServiceInfo.
func (s *TCPServiceInfo) GetAllStatus() map[string]string {
	s.serverStatusMu.RLock()
	defer s.serverStatusMu.RUnlock()

	// Preserve the nil return for "no statuses recorded" so JSON omitempty works.
	if len(s.serverStatus) == 0 {
		return nil
	}

	// Hand back a copy so callers cannot mutate the guarded map.
	statuses := make(map[string]string, len(s.serverStatus))
	for addr, status := range s.serverStatus {
		statuses[addr] = status
	}

	return statuses
}
// TCPMiddlewareInfo holds information about a currently running middleware.
type TCPMiddlewareInfo struct {
*dynamic.TCPMiddleware // dynamic configuration

View file

@ -0,0 +1,212 @@
package healthcheck
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"strconv"
	"time"

	"github.com/rs/zerolog/log"
	"github.com/traefik/traefik/v3/pkg/config/dynamic"
	"github.com/traefik/traefik/v3/pkg/config/runtime"
	"github.com/traefik/traefik/v3/pkg/tcp"
)
// maxPayloadSize is the maximum payload size that can be sent during health checks.
const maxPayloadSize = 65535
// TCPHealthCheckTarget describes one backend server to probe.
type TCPHealthCheckTarget struct {
	// Address is the server address used for the probe (host:port; port may be overridden by the check config).
	Address string
	// TLS indicates whether the target should be dialed over TLS — presumably consumed by the Dialer; confirm against its implementation.
	TLS bool
	// Dialer establishes the probe connection.
	Dialer tcp.Dialer
}
// ServiceTCPHealthChecker periodically probes the targets of a TCP service,
// reports each target's status to the balancer, and records it on the
// runtime service info.
type ServiceTCPHealthChecker struct {
	balancer StatusSetter            // receives per-target up/down updates
	info     *runtime.TCPServiceInfo // runtime info updated with per-server status strings
	config   *dynamic.TCPServerHealthCheck

	// Sanitized copies of the config durations (see NewServiceTCPHealthChecker).
	interval          time.Duration // polling period for healthy targets
	unhealthyInterval time.Duration // polling period for unhealthy targets
	timeout           time.Duration // per-probe deadline

	// Targets migrate between these two channels according to probe results.
	healthyTargets   chan *TCPHealthCheckTarget
	unhealthyTargets chan *TCPHealthCheckTarget

	serviceName string
}
// NewServiceTCPHealthChecker creates a health checker for the given TCP service.
// Non-positive intervals/timeouts fall back to the package defaults, and
// oversized send/expect payloads are dropped; every sanitization is logged
// rather than returned as an error. All targets start on the healthy channel.
func NewServiceTCPHealthChecker(ctx context.Context, config *dynamic.TCPServerHealthCheck, service StatusSetter, info *runtime.TCPServiceInfo, targets []TCPHealthCheckTarget, serviceName string) *ServiceTCPHealthChecker {
	logger := log.Ctx(ctx)

	interval := time.Duration(config.Interval)
	if interval <= 0 {
		logger.Error().Msg("Health check interval smaller than zero, default value will be used instead.")
		interval = time.Duration(dynamic.DefaultHealthCheckInterval)
	}

	// If the unhealthyInterval option is not set, we use the interval option value,
	// to check the unhealthy targets as often as the healthy ones.
	var unhealthyInterval time.Duration
	if config.UnhealthyInterval == nil {
		unhealthyInterval = interval
	} else {
		unhealthyInterval = time.Duration(*config.UnhealthyInterval)
		if unhealthyInterval <= 0 {
			logger.Error().Msg("Health check unhealthy interval smaller than zero, default value will be used instead.")
			unhealthyInterval = time.Duration(dynamic.DefaultHealthCheckInterval)
		}
	}

	timeout := time.Duration(config.Timeout)
	if timeout <= 0 {
		logger.Error().Msg("Health check timeout smaller than zero, default value will be used instead.")
		timeout = time.Duration(dynamic.DefaultHealthCheckTimeout)
	}

	// Oversized payloads are cleared (degrading to a connect-only check / no
	// response verification) to bound the memory used per probe.
	if config.Send != "" && len(config.Send) > maxPayloadSize {
		logger.Error().Msgf("Health check payload size exceeds maximum allowed size of %d bytes, falling back to connect only check.", maxPayloadSize)
		config.Send = ""
	}
	if config.Expect != "" && len(config.Expect) > maxPayloadSize {
		logger.Error().Msgf("Health check expected response size exceeds maximum allowed size of %d bytes, falling back to close without response.", maxPayloadSize)
		config.Expect = ""
	}

	healthyTargets := make(chan *TCPHealthCheckTarget, len(targets))
	for i := range targets {
		// Index into the slice instead of taking the address of the range
		// variable: on Go < 1.22 &target would alias a single loop variable
		// and every channel entry would point at the last target.
		healthyTargets <- &targets[i]
	}

	unhealthyTargets := make(chan *TCPHealthCheckTarget, len(targets))

	return &ServiceTCPHealthChecker{
		balancer:          service,
		info:              info,
		config:            config,
		interval:          interval,
		unhealthyInterval: unhealthyInterval,
		timeout:           timeout,
		healthyTargets:    healthyTargets,
		unhealthyTargets:  unhealthyTargets,
		serviceName:       serviceName,
	}
}
// Launch starts both check loops: unhealthy targets are polled in a separate
// goroutine at unhealthyInterval, while healthy targets are polled on the
// calling goroutine at interval. It blocks until ctx is canceled.
func (thc *ServiceTCPHealthChecker) Launch(ctx context.Context) {
	go thc.healthcheck(ctx, thc.unhealthyTargets, thc.unhealthyInterval)

	thc.healthcheck(ctx, thc.healthyTargets, thc.interval)
}
// healthcheck runs one polling loop: on every tick it drains the given
// channel, probes each drained target, reports the result to the balancer and
// the runtime info, and re-queues the target on the healthy or unhealthy
// channel according to the outcome. It returns when ctx is canceled.
func (thc *ServiceTCPHealthChecker) healthcheck(ctx context.Context, targets chan *TCPHealthCheckTarget, interval time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return

		case <-ticker.C:
			// We collect the targets to check once for all,
			// to avoid rechecking a target that has been moved during the health check.
			var targetsToCheck []*TCPHealthCheckTarget
			hasMoreTargets := true
			for hasMoreTargets {
				// Non-blocking drain: the default case exits as soon as the
				// channel is empty.
				select {
				case <-ctx.Done():
					return
				case target := <-targets:
					targetsToCheck = append(targetsToCheck, target)
				default:
					hasMoreTargets = false
				}
			}

			// Now we can check the targets.
			for _, target := range targetsToCheck {
				// Bail out between probes if the configuration was refreshed.
				select {
				case <-ctx.Done():
					return
				default:
				}

				up := true

				if err := thc.executeHealthCheck(ctx, thc.config, target); err != nil {
					// The context is canceled when the dynamic configuration is refreshed.
					if errors.Is(err, context.Canceled) {
						return
					}

					log.Ctx(ctx).Warn().
						Str("targetAddress", target.Address).
						Err(err).
						Msg("Health check failed.")

					up = false
				}

				thc.balancer.SetStatus(ctx, target.Address, up)

				// Re-queue the target on the channel matching its new status,
				// so the next round polls it at the right interval.
				var statusStr string
				if up {
					statusStr = runtime.StatusUp
					thc.healthyTargets <- target
				} else {
					statusStr = runtime.StatusDown
					thc.unhealthyTargets <- target
				}

				thc.info.UpdateServerStatus(target.Address, statusStr)

				// TODO: add a TCP server up metric (like for HTTP).
			}
		}
	}
}
// executeHealthCheck performs a single probe against target: it dials the
// address (optionally with the port overridden by config.Port), optionally
// writes config.Send, and optionally reads back and verifies config.Expect.
// A nil return means the target is considered up.
func (thc *ServiceTCPHealthChecker) executeHealthCheck(ctx context.Context, config *dynamic.TCPServerHealthCheck, target *TCPHealthCheckTarget) error {
	addr := target.Address
	if config.Port != 0 {
		// Keep the host, replace only the port.
		host, _, err := net.SplitHostPort(target.Address)
		if err != nil {
			return fmt.Errorf("parsing address %q: %w", target.Address, err)
		}

		addr = net.JoinHostPort(host, strconv.Itoa(config.Port))
	}

	// Use the sanitized thc.timeout rather than the raw config.Timeout: the
	// raw value may be zero or negative (the constructor only sanitizes its
	// copy), which would yield an already-expired deadline.
	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(thc.timeout))
	defer cancel()

	conn, err := target.Dialer.DialContext(ctx, "tcp", addr, nil)
	if err != nil {
		return fmt.Errorf("connecting to %s: %w", addr, err)
	}
	defer conn.Close()

	// Bound the send/receive phases as well; the dial deadline above only
	// covers connection establishment.
	if err := conn.SetDeadline(time.Now().Add(thc.timeout)); err != nil {
		return fmt.Errorf("setting timeout to %s: %w", thc.timeout, err)
	}

	if config.Send != "" {
		if _, err = conn.Write([]byte(config.Send)); err != nil {
			return fmt.Errorf("sending to %s: %w", addr, err)
		}
	}

	if config.Expect != "" {
		buf := make([]byte, len(config.Expect))
		// A single conn.Read may return fewer bytes than requested;
		// io.ReadFull keeps reading until buf is filled or an error occurs,
		// so a legitimate response split across packets is not rejected.
		if _, err = io.ReadFull(conn, buf); err != nil {
			return fmt.Errorf("reading from %s: %w", addr, err)
		}

		if string(buf) != config.Expect {
			return errors.New("unexpected health check response")
		}
	}

	return nil
}

View file

@ -0,0 +1,754 @@
package healthcheck
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"strings"
"sync"
"testing"
"time"
"github.com/rs/zerolog/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
truntime "github.com/traefik/traefik/v3/pkg/config/runtime"
"github.com/traefik/traefik/v3/pkg/tcp"
)
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
MIIDJzCCAg+gAwIBAgIUe3vnWg3cTbflL6kz2TyPUxmV8Y4wDQYJKoZIhvcNAQEL
BQAwFjEUMBIGA1UEAwwLZXhhbXBsZS5jb20wIBcNMjUwMzA1MjAwOTM4WhgPMjA1
NTAyMjYyMDA5MzhaMBYxFDASBgNVBAMMC2V4YW1wbGUuY29tMIIBIjANBgkqhkiG
9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4Mm4Sp6xzJvFZJWAv/KVmI1krywiuef8Fhlf
JR2M0caKixjBcNt4U8KwrzIrqL+8nilbps1QuwpQ09+6ztlbUXUL6DqR8ZC+4oCp
gOZ3yyVX2vhMigkATbQyJrX/WVjWSHD5rIUBP2BrsaYLt1qETnFP9wwQ3YEi7V4l
c4+jDrZOtJvrv+tRClt9gQJVgkr7Y30X+dx+rsh+ROaA2+/VTDX0qtoqd/4fjhcJ
OY9VLm0eU66VUMyOTNeUm6ZAXRBp/EonIM1FXOlj82S0pZQbPrvyWWqWoAjtPvLU
qRzqp/BQJqx3EHz1dP6s+xUjP999B+7jhiHoFhZ/bfVVlx8XkwIDAQABo2swaTAd
BgNVHQ4EFgQUhJiJ37LW6RODCpBPAApG1zQxFtAwHwYDVR0jBBgwFoAUhJiJ37LW
6RODCpBPAApG1zQxFtAwDwYDVR0TAQH/BAUwAwEB/zAWBgNVHREEDzANggtleGFt
cGxlLmNvbTANBgkqhkiG9w0BAQsFAAOCAQEAfnDPHllA1TFlQ6zY46tqM20d68bR
kXeGMKLoaATFPbDea5H8/GM5CU6CPD7RUuEB9CvxvaM0aOInxkgstozG7BOr8hcs
WS9fMgM0oO5yGiSOv+Qa0Rc0BFb6A1fUJRta5MI5DTdTJLoyoRX/5aocSI34T67x
ULbkJvVXw6hnx/KZ65apNobfmVQSy7DR8Fo82eB4hSoaLpXyUUTLmctGgrRCoKof
GVUJfKsDJ4Ts8WIR1np74flSoxksWSHEOYk79AZOPANYgJwPMMiiZKsKm17GBoGu
DxI0om4eX8GaSSZAtG6TOt3O3v1oCjKNsAC+u585HN0x0MFA33TUzC15NA==
-----END CERTIFICATE-----`)
var localhostKey = []byte(`-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQDgybhKnrHMm8Vk
lYC/8pWYjWSvLCK55/wWGV8lHYzRxoqLGMFw23hTwrCvMiuov7yeKVumzVC7ClDT
37rO2VtRdQvoOpHxkL7igKmA5nfLJVfa+EyKCQBNtDImtf9ZWNZIcPmshQE/YGux
pgu3WoROcU/3DBDdgSLtXiVzj6MOtk60m+u/61EKW32BAlWCSvtjfRf53H6uyH5E
5oDb79VMNfSq2ip3/h+OFwk5j1UubR5TrpVQzI5M15SbpkBdEGn8SicgzUVc6WPz
ZLSllBs+u/JZapagCO0+8tSpHOqn8FAmrHcQfPV0/qz7FSM/330H7uOGIegWFn9t
9VWXHxeTAgMBAAECggEALinfGhv7Iaz/3cdCOKlGBZ1MBxmGTC2TPKqbOpEWAWLH
wwcjetznmjQKewBPrQkrYEPYGapioPbeYJS61Y4XzeO+vUOCA10ZhoSrytgJ1ANo
RoTlmxd8I3kVL5QCy8ONxjTFYaOy/OP9We9iypXhRAbLSE4HDKZfmOXTxSbDctql
Kq7uV3LX1KCfr9C6M8d79a0Rdr4p8IXp8MOg3tXq6n75vZbepRFyAujhg7o/kkTp
lgv87h89lrK97K+AjqtvCIT3X3VXfA+LYp3AoQFdOluKgyJT221MyHkTeI/7gggt
Z57lVGD71UJH/LGUJWrraJqXd9uDxZWprD/s66BIAQKBgQD8CtHUJ/VuS7gP0ebN
688zrmRtENj6Gqi+URm/Pwgr9b7wKKlf9jjhg5F/ue+BgB7/nK6N7yJ4Xx3JJ5ox
LqsRGLFa4fDBxogF/FN27obD8naOxe2wS1uTjM6LSrvdJ+HjeNEwHYhjuDjTAHj5
VVEMagZWgkE4jBiFUYefiYLsAQKBgQDkUVdW8cXaYri5xxDW86JNUzI1tUPyd6I+
AkOHV/V0y2zpwTHVLcETRpdVGpc5TH3J5vWf+5OvSz6RDTGjv7blDb8vB/kVkFmn
uXTi0dB9P+SYTsm+X3V7hOAFsyVYZ1D9IFsKUyMgxMdF+qgERjdPKx5IdLV/Jf3q
P9pQ922TkwKBgCKllhyU9Z8Y14+NKi4qeUxAb9uyUjFnUsT+vwxULNpmKL44yLfB
UCZoAKtPMwZZR2mZ70Dhm5pycNTDFeYm5Ssvesnkf0UT9oTkH9EcjvgGr5eGy9rN
MSSCWa46MsL/BYVQiWkU1jfnDiCrUvXrbX3IYWCo/TA5yfEhuQQMUiwBAoGADyzo
5TqEsBNHu/FjSSZAb2tMNw2pSoBxJDX6TxClm/G5d4AD0+uKncFfZaSy0HgpFDZp
tQx/sHML4ZBC8GNZwLe9MV8SS0Cg9Oj6v+i6Ntj8VLNH7YNix6b5TOevX8TeOTTh
WDpWZ2Ms65XRfRc9reFrzd0UAzN/QQaleCQ6AEkCgYBe4Ucows7JGbv7fNkz3nb1
kyH+hk9ecnq/evDKX7UUxKO1wwTi74IYKgcRB2uPLpHKL35gPz+LAfCphCW5rwpR
lvDhS+Pi/1KCBJxLHMv+V/WrckDRgHFnAhDaBZ+2vI/s09rKDnpjcTzV7x22kL0b
XIJCEEE8JZ4AXIZ+IcB6LA==
-----END PRIVATE KEY-----`)
// TestNewServiceTCPHealthChecker verifies that the constructor normalizes the
// configured durations: zero or negative interval/timeout values fall back to
// the package defaults, while explicit positive values are kept untouched —
// even when the interval is shorter than the timeout.
func TestNewServiceTCPHealthChecker(t *testing.T) {
	tests := []struct {
		desc             string
		config           *dynamic.TCPServerHealthCheck
		expectedInterval time.Duration
		expectedTimeout  time.Duration
	}{
		{
			desc:             "default values",
			config:           &dynamic.TCPServerHealthCheck{},
			expectedInterval: time.Duration(dynamic.DefaultHealthCheckInterval),
			expectedTimeout:  time.Duration(dynamic.DefaultHealthCheckTimeout),
		},
		{
			desc: "out of range values",
			config: &dynamic.TCPServerHealthCheck{
				Interval: ptypes.Duration(-time.Second),
				Timeout:  ptypes.Duration(-time.Second),
			},
			expectedInterval: time.Duration(dynamic.DefaultHealthCheckInterval),
			expectedTimeout:  time.Duration(dynamic.DefaultHealthCheckTimeout),
		},
		{
			desc: "custom durations",
			config: &dynamic.TCPServerHealthCheck{
				Interval: ptypes.Duration(time.Second * 10),
				Timeout:  ptypes.Duration(time.Second * 5),
			},
			expectedInterval: time.Second * 10,
			expectedTimeout:  time.Second * 5,
		},
		{
			desc: "interval shorter than timeout",
			config: &dynamic.TCPServerHealthCheck{
				Interval: ptypes.Duration(time.Second),
				Timeout:  ptypes.Duration(time.Second * 5),
			},
			expectedInterval: time.Second,
			expectedTimeout:  time.Second * 5,
		},
	}

	for _, tc := range tests {
		t.Run(tc.desc, func(t *testing.T) {
			t.Parallel()

			// No load balancer, service info, or targets are required to
			// observe how the durations are resolved.
			hc := NewServiceTCPHealthChecker(t.Context(), tc.config, nil, nil, nil, "")

			assert.Equal(t, tc.expectedInterval, hc.interval)
			assert.Equal(t, tc.expectedTimeout, hc.timeout)
		})
	}
}
// TestServiceTCPHealthChecker_executeHealthCheck_connection verifies which
// address a health check dials: the target address as-is when no port
// override is configured, or the same host with the overridden port otherwise
// (including the bracketed IPv6 form).
func TestServiceTCPHealthChecker_executeHealthCheck_connection(t *testing.T) {
	testCases := []struct {
		desc            string
		address         string
		config          *dynamic.TCPServerHealthCheck
		expectedAddress string
	}{
		{
			desc:            "no port override - uses original address",
			address:         "127.0.0.1:8080",
			config:          &dynamic.TCPServerHealthCheck{Port: 0},
			expectedAddress: "127.0.0.1:8080",
		},
		{
			desc:            "port override - uses overridden port",
			address:         "127.0.0.1:8080",
			config:          &dynamic.TCPServerHealthCheck{Port: 9090},
			expectedAddress: "127.0.0.1:9090",
		},
		{
			desc:            "IPv6 address with port override",
			address:         "[::1]:8080",
			config:          &dynamic.TCPServerHealthCheck{Port: 9090},
			expectedAddress: "[::1]:9090",
		},
		{
			desc:            "successful connection without port override",
			address:         "localhost:3306",
			config:          &dynamic.TCPServerHealthCheck{Port: 0},
			expectedAddress: "localhost:3306",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			t.Parallel()

			// The dialer records the address it is asked to connect to and
			// hands back an inert connection.
			var dialedAddr string
			dialer := &dialerMock{
				onDial: func(_, addr string) (net.Conn, error) {
					dialedAddr = addr
					return &connMock{}, nil
				},
			}

			targets := []TCPHealthCheckTarget{{
				Address: tc.address,
				Dialer:  dialer,
			}}
			hc := NewServiceTCPHealthChecker(t.Context(), tc.config, nil, nil, targets, "test")

			// Run one health check and observe the address it dialed.
			require.NoError(t, hc.executeHealthCheck(t.Context(), tc.config, &targets[0]))

			assert.Equal(t, tc.expectedAddress, dialedAddr)
		})
	}
}
// TestServiceTCPHealthChecker_executeHealthCheck_payloadHandling verifies the
// send/expect protocol of a health check: what bytes are written to the
// connection, how the scripted response is matched, and that oversized
// send/expect payloads are truncated rather than rejected.
func TestServiceTCPHealthChecker_executeHealthCheck_payloadHandling(t *testing.T) {
	testCases := []struct {
		desc             string
		config           *dynamic.TCPServerHealthCheck
		mockResponse     string
		expectedSentData string
		expectedSuccess  bool
	}{
		{
			desc: "successful send and expect",
			config: &dynamic.TCPServerHealthCheck{
				Send:   "PING",
				Expect: "PONG",
			},
			mockResponse:     "PONG",
			expectedSentData: "PING",
			expectedSuccess:  true,
		},
		{
			desc: "send without expect",
			config: &dynamic.TCPServerHealthCheck{
				Send:   "STATUS",
				Expect: "",
			},
			expectedSentData: "STATUS",
			expectedSuccess:  true,
		},
		{
			desc: "send without expect, ignores response",
			config: &dynamic.TCPServerHealthCheck{
				Send: "STATUS",
			},
			mockResponse:     strings.Repeat("A", maxPayloadSize+1),
			expectedSentData: "STATUS",
			expectedSuccess:  true,
		},
		{
			desc: "expect without send",
			config: &dynamic.TCPServerHealthCheck{
				Expect: "READY",
			},
			mockResponse:    "READY",
			expectedSuccess: true,
		},
		{
			desc: "wrong response received",
			config: &dynamic.TCPServerHealthCheck{
				Send:   "PING",
				Expect: "PONG",
			},
			mockResponse:     "WRONG",
			expectedSentData: "PING",
			expectedSuccess:  false,
		},
		{
			desc: "send payload too large - gets truncated",
			config: &dynamic.TCPServerHealthCheck{
				Send:   strings.Repeat("A", maxPayloadSize+1), // Will be truncated to empty
				Expect: "OK",
			},
			mockResponse:    "OK",
			expectedSuccess: true,
		},
		{
			desc: "expect payload too large - gets truncated",
			config: &dynamic.TCPServerHealthCheck{
				Send:   "PING",
				Expect: strings.Repeat("B", maxPayloadSize+1), // Will be truncated to empty
			},
			expectedSentData: "PING",
			expectedSuccess:  true,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.desc, func(t *testing.T) {
			t.Parallel()

			// Record everything written to the connection and replay the
			// scripted response on every read.
			var written []byte
			conn := &connMock{
				writeFunc: func(data []byte) (int, error) {
					written = append([]byte(nil), data...)
					return len(data), nil
				},
				readFunc: func(buf []byte) (int, error) {
					return copy(buf, tc.mockResponse), nil
				},
			}
			dialer := &dialerMock{
				onDial: func(_, _ string) (net.Conn, error) {
					return conn, nil
				},
			}

			targets := []TCPHealthCheckTarget{{
				Address: "127.0.0.1:8080",
				TLS:     false,
				Dialer:  dialer,
			}}
			hc := NewServiceTCPHealthChecker(t.Context(), tc.config, nil, nil, targets, "test")

			err := hc.executeHealthCheck(t.Context(), tc.config, &targets[0])
			if tc.expectedSuccess {
				assert.NoError(t, err, "Health check should succeed")
			} else {
				assert.Error(t, err, "Health check should fail")
			}

			assert.Equal(t, tc.expectedSentData, string(written), "Should send the expected data")
		})
	}
}
// TestServiceTCPHealthChecker_Launch runs the health checker end-to-end
// against the sequenced mock TCP server and verifies, per scenario, how many
// load-balancer updates happen (upserts on healthy checks, removals on
// failures) and the final status recorded for the target in the service info.
func TestServiceTCPHealthChecker_Launch(t *testing.T) {
	testCases := []struct {
		desc   string
		server *sequencedTCPServer
		config *dynamic.TCPServerHealthCheck
		// expNumRemovedServers and expNumUpsertedServers count the
		// load-balancer transitions observed over the whole sequence.
		expNumRemovedServers  int
		expNumUpsertedServers int
		// targetStatus is the status expected for the single target once the
		// sequence has been fully consumed.
		targetStatus string
	}{
		{
			desc: "connection-only healthy server staying healthy",
			server: newTCPServer(t,
				false,
				tcpMockSequence{accept: true},
				tcpMockSequence{accept: true},
				tcpMockSequence{accept: true},
			),
			config: &dynamic.TCPServerHealthCheck{
				Interval: ptypes.Duration(time.Millisecond * 50),
				Timeout:  ptypes.Duration(time.Millisecond * 40),
			},
			expNumRemovedServers:  0,
			expNumUpsertedServers: 3, // 3 health check sequences
			targetStatus:          truntime.StatusUp,
		},
		{
			desc: "connection-only healthy server becoming unhealthy",
			server: newTCPServer(t,
				false,
				tcpMockSequence{accept: true},
				tcpMockSequence{accept: false},
			),
			config: &dynamic.TCPServerHealthCheck{
				Interval: ptypes.Duration(time.Millisecond * 50),
				Timeout:  ptypes.Duration(time.Millisecond * 40),
			},
			expNumRemovedServers:  1,
			expNumUpsertedServers: 1,
			targetStatus:          truntime.StatusDown,
		},
		{
			desc: "connection-only server toggling unhealthy to healthy",
			server: newTCPServer(t,
				false,
				tcpMockSequence{accept: false},
				tcpMockSequence{accept: true},
			),
			config: &dynamic.TCPServerHealthCheck{
				Interval: ptypes.Duration(time.Millisecond * 50),
				Timeout:  ptypes.Duration(time.Millisecond * 40),
			},
			expNumRemovedServers:  1, // 1 failure call
			expNumUpsertedServers: 1, // 1 success call
			targetStatus:          truntime.StatusUp,
		},
		{
			desc: "connection-only server toggling healthy to unhealthy to healthy",
			server: newTCPServer(t,
				false,
				tcpMockSequence{accept: true},
				tcpMockSequence{accept: false},
				tcpMockSequence{accept: true},
			),
			config: &dynamic.TCPServerHealthCheck{
				Interval: ptypes.Duration(time.Millisecond * 50),
				Timeout:  ptypes.Duration(time.Millisecond * 40),
			},
			expNumRemovedServers:  1,
			expNumUpsertedServers: 2,
			targetStatus:          truntime.StatusUp,
		},
		{
			desc: "send/expect healthy server staying healthy",
			server: newTCPServer(t,
				false,
				tcpMockSequence{accept: true, payloadIn: "PING", payloadOut: "PONG"},
				tcpMockSequence{accept: true, payloadIn: "PING", payloadOut: "PONG"},
			),
			config: &dynamic.TCPServerHealthCheck{
				Send:     "PING",
				Expect:   "PONG",
				Interval: ptypes.Duration(time.Millisecond * 50),
				Timeout:  ptypes.Duration(time.Millisecond * 40),
			},
			expNumRemovedServers:  0,
			expNumUpsertedServers: 2, // 2 successful health checks
			targetStatus:          truntime.StatusUp,
		},
		{
			desc: "send/expect server with wrong response",
			server: newTCPServer(t,
				false,
				tcpMockSequence{accept: true, payloadIn: "PING", payloadOut: "PONG"},
				tcpMockSequence{accept: true, payloadIn: "PING", payloadOut: "WRONG"},
			),
			config: &dynamic.TCPServerHealthCheck{
				Send:     "PING",
				Expect:   "PONG",
				Interval: ptypes.Duration(time.Millisecond * 50),
				Timeout:  ptypes.Duration(time.Millisecond * 40),
			},
			expNumRemovedServers:  1,
			expNumUpsertedServers: 1,
			targetStatus:          truntime.StatusDown,
		},
		{
			desc: "TLS healthy server staying healthy",
			server: newTCPServer(t,
				true,
				tcpMockSequence{accept: true, payloadIn: "HELLO", payloadOut: "WORLD"},
			),
			config: &dynamic.TCPServerHealthCheck{
				Send:     "HELLO",
				Expect:   "WORLD",
				Interval: ptypes.Duration(time.Millisecond * 500),
				Timeout:  ptypes.Duration(time.Millisecond * 2000), // Even longer timeout for TLS handshake
			},
			expNumRemovedServers:  0,
			expNumUpsertedServers: 1, // 1 TLS health check sequence
			targetStatus:          truntime.StatusUp,
		},
		{
			desc: "send-only healthcheck (no expect)",
			server: newTCPServer(t,
				false,
				tcpMockSequence{accept: true, payloadIn: "STATUS"},
				tcpMockSequence{accept: true, payloadIn: "STATUS"},
			),
			config: &dynamic.TCPServerHealthCheck{
				Send:     "STATUS",
				Interval: ptypes.Duration(time.Millisecond * 50),
				Timeout:  ptypes.Duration(time.Millisecond * 40),
			},
			expNumRemovedServers:  0,
			expNumUpsertedServers: 2,
			targetStatus:          truntime.StatusUp,
		},
	}
	for _, test := range testCases {
		t.Run(test.desc, func(t *testing.T) {
			t.Parallel()
			ctx, cancel := context.WithCancel(log.Logger.WithContext(t.Context()))
			defer cancel()
			test.server.Start(t)
			// Build the dialer through the real DialerManager so the TLS case
			// exercises the same client path as production.
			dialerManager := tcp.NewDialerManager(nil)
			dialerManager.Update(map[string]*dynamic.TCPServersTransport{"default@internal": {
				TLS: &dynamic.TLSClientConfig{
					InsecureSkipVerify: true,
					ServerName:         "example.com",
				},
			}})
			dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{}, test.server.TLS)
			require.NoError(t, err)
			targets := []TCPHealthCheckTarget{
				{
					Address: test.server.Addr.String(),
					TLS:     test.server.TLS,
					Dialer:  dialer,
				},
			}
			lb := &testLoadBalancer{}
			serviceInfo := &truntime.TCPServiceInfo{}
			service := NewServiceTCPHealthChecker(ctx, test.config, lb, serviceInfo, targets, "serviceName")
			go service.Launch(ctx)
			// How much time to wait for the health check to actually complete.
			// NOTE(review): this deadline is a single overall budget shared by
			// all sequence steps below, not a per-step timeout — confirm this
			// is intended.
			deadline := time.Now().Add(200 * time.Millisecond)
			// TLS handshake can take much longer.
			if test.server.TLS {
				deadline = time.Now().Add(1000 * time.Millisecond)
			}
			// Wait for all health checks to complete deterministically
			// NOTE(review): lb counters are read here while Launch's goroutine
			// may write them; presumably testLoadBalancer makes this safe —
			// verify under -race.
			for range test.server.StatusSequence {
				// Release the mock server for the next scripted step, then
				// poll until the load balancer records a transition.
				test.server.Next()
				initialUpserted := lb.numUpsertedServers
				initialRemoved := lb.numRemovedServers
				for time.Now().Before(deadline) {
					time.Sleep(5 * time.Millisecond)
					if lb.numUpsertedServers > initialUpserted || lb.numRemovedServers > initialRemoved {
						break
					}
				}
			}
			assert.Equal(t, test.expNumRemovedServers, lb.numRemovedServers, "removed servers")
			assert.Equal(t, test.expNumUpsertedServers, lb.numUpsertedServers, "upserted servers")
			assert.Equal(t, map[string]string{test.server.Addr.String(): test.targetStatus}, serviceInfo.GetAllStatus())
		})
	}
}
// TestServiceTCPHealthChecker_differentIntervals verifies that when
// UnhealthyInterval is lower than Interval, an unhealthy target is probed more
// often than a healthy one: after a fixed observation window the number of
// "removed" events (failed checks of the unhealthy server) must exceed the
// number of "upserted" events (successful checks of the healthy server).
func TestServiceTCPHealthChecker_differentIntervals(t *testing.T) {
	// Test that unhealthy servers are checked more frequently than healthy servers
	// when UnhealthyInterval is set to a lower value than Interval
	ctx, cancel := context.WithCancel(t.Context())
	t.Cleanup(cancel)
	// Create a healthy TCP server that always accepts connections
	healthyServer := newTCPServer(t, false,
		tcpMockSequence{accept: true}, tcpMockSequence{accept: true}, tcpMockSequence{accept: true},
		tcpMockSequence{accept: true}, tcpMockSequence{accept: true},
	)
	healthyServer.Start(t)
	// Create an unhealthy TCP server that always rejects connections
	unhealthyServer := newTCPServer(t, false,
		tcpMockSequence{accept: false}, tcpMockSequence{accept: false}, tcpMockSequence{accept: false},
		tcpMockSequence{accept: false}, tcpMockSequence{accept: false}, tcpMockSequence{accept: false},
		tcpMockSequence{accept: false}, tcpMockSequence{accept: false}, tcpMockSequence{accept: false},
		tcpMockSequence{accept: false},
	)
	unhealthyServer.Start(t)
	// NOTE(review): unlike the Launch test, the mutex is initialized here
	// explicitly — presumably because this test reads the counters while the
	// checker is still running; confirm testLoadBalancer requires it.
	lb := &testLoadBalancer{RWMutex: &sync.RWMutex{}}
	// Set normal interval to 500ms but unhealthy interval to 50ms
	// This means unhealthy servers should be checked 10x more frequently
	config := &dynamic.TCPServerHealthCheck{
		Interval:          ptypes.Duration(500 * time.Millisecond),
		UnhealthyInterval: pointer(ptypes.Duration(50 * time.Millisecond)),
		Timeout:           ptypes.Duration(100 * time.Millisecond),
	}
	// Set up dialer manager
	dialerManager := tcp.NewDialerManager(nil)
	dialerManager.Update(map[string]*dynamic.TCPServersTransport{
		"default@internal": {
			DialTimeout:   ptypes.Duration(100 * time.Millisecond),
			DialKeepAlive: ptypes.Duration(100 * time.Millisecond),
		},
	})
	// Get dialer for targets
	dialer, err := dialerManager.Build(&dynamic.TCPServersLoadBalancer{}, false)
	require.NoError(t, err)
	targets := []TCPHealthCheckTarget{
		{Address: healthyServer.Addr.String(), TLS: false, Dialer: dialer},
		{Address: unhealthyServer.Addr.String(), TLS: false, Dialer: dialer},
	}
	serviceInfo := &truntime.TCPServiceInfo{}
	hc := NewServiceTCPHealthChecker(ctx, config, lb, serviceInfo, targets, "test-service")
	// Run the checker in the background and wait for it to exit after cancel.
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		hc.Launch(ctx)
		wg.Done()
	}()
	// Let it run for 2 seconds to see the different check frequencies
	select {
	case <-time.After(2 * time.Second):
		cancel()
	case <-ctx.Done():
	}
	wg.Wait()
	lb.Lock()
	defer lb.Unlock()
	// The unhealthy server should be checked more frequently (50ms interval)
	// compared to the healthy server (500ms interval), so we should see
	// significantly more "removed" events than "upserted" events
	assert.Greater(t, lb.numRemovedServers, lb.numUpsertedServers, "unhealthy servers checked more frequently")
}
// tcpMockSequence describes one scripted step of the sequenced mock TCP
// server: whether a connection is accepted at all and, if so, which payload
// exchange takes place.
type tcpMockSequence struct {
	// accept controls whether a listener is open for this step; when false
	// the previous listener is closed so connections are refused.
	accept bool
	// payloadIn is the payload the server expects to read from the client;
	// empty means no payload exchange happens for this step.
	payloadIn string
	// payloadOut is the reply written back when payloadIn is received.
	payloadOut string
}
// sequencedTCPServer is a mock TCP server whose availability and payload
// behavior are scripted step by step: each call to Next releases exactly one
// entry of StatusSequence inside the goroutine started by Start.
type sequencedTCPServer struct {
	// Addr is the fixed local address the server serves on; it is reserved
	// once in newTCPServer so it stays stable across listener restarts.
	Addr *net.TCPAddr
	// StatusSequence is the ordered list of scripted steps to play.
	StatusSequence []tcpMockSequence
	// TLS wraps each listener in TLS using the test certificate when true.
	TLS bool
	// release gates the serving goroutine; Next sends one token per step.
	release chan struct{}
}
// newTCPServer reserves a free loopback port and returns a sequenced mock
// server bound to it. The probe listener is closed immediately so that Start
// can re-bind the same address for each accepting step; the port is not held
// in between, so another process could in theory grab it in the meantime.
func newTCPServer(t *testing.T, tlsEnabled bool, statusSequence ...tcpMockSequence) *sequencedTCPServer {
	t.Helper()

	// Ask the kernel for an ephemeral loopback port.
	addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	probe, err := net.ListenTCP("tcp", addr)
	require.NoError(t, err)

	tcpAddr, ok := probe.Addr().(*net.TCPAddr)
	require.True(t, ok)

	// Release the port right away; Start re-binds it on demand.
	_ = probe.Close()

	return &sequencedTCPServer{
		Addr:           tcpAddr,
		TLS:            tlsEnabled,
		StatusSequence: statusSequence,
		release:        make(chan struct{}),
	}
}
// Next releases the serving goroutine for exactly one sequence step; it
// blocks until the goroutine started by Start is ready to consume the token.
func (s *sequencedTCPServer) Next() {
	s.release <- struct{}{}
}
// Start runs the mock server in a background goroutine. For every entry in
// StatusSequence it blocks until Next is called, then either refuses
// connections (accept == false, achieved by keeping the listener closed) or
// accepts exactly one connection — optionally wrapped in TLS — and performs
// the scripted payload exchange.
//
// Fix: the original used require.NoError/require.True inside this goroutine;
// those call t.FailNow, which the testing package forbids outside the
// goroutine running the test. Errors are now reported with t.Errorf followed
// by a return, which fails the test without the illegal FailNow.
func (s *sequencedTCPServer) Start(t *testing.T) {
	t.Helper()
	go func() {
		// Signal end-of-script (and unblock nothing further) when the
		// goroutine exits, whether normally or on an internal error.
		defer close(s.release)

		var listener net.Listener
		for _, seq := range s.StatusSequence {
			<-s.release
			// Close the previous step's listener first so "accept: false"
			// steps actually refuse connections.
			if listener != nil {
				listener.Close()
			}
			if !seq.accept {
				continue
			}

			lis, err := net.ListenTCP("tcp", s.Addr)
			if err != nil {
				t.Errorf("listening on %s: %v", s.Addr, err)
				return
			}
			listener = lis

			if s.TLS {
				cert, err := tls.X509KeyPair(localhostCert, localhostKey)
				if err != nil {
					t.Errorf("loading test key pair: %v", err)
					return
				}
				x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
				if err != nil {
					t.Errorf("parsing test certificate: %v", err)
					return
				}
				certpool := x509.NewCertPool()
				certpool.AddCert(x509Cert)
				listener = tls.NewListener(
					listener,
					&tls.Config{
						RootCAs:            certpool,
						Certificates:       []tls.Certificate{cert},
						InsecureSkipVerify: true,
						ServerName:         "example.com",
						MinVersion:         tls.VersionTLS12,
						MaxVersion:         tls.VersionTLS12,
						ClientAuth:         tls.VerifyClientCertIfGiven,
						ClientCAs:          certpool,
					},
				)
			}

			conn, err := listener.Accept()
			if err != nil {
				t.Errorf("accepting connection: %v", err)
				return
			}
			// NOTE(review): t.Cleanup is called from this goroutine; this
			// can race with test completion — confirm the checker is always
			// stopped before the test returns.
			t.Cleanup(func() {
				_ = conn.Close()
			})

			// For TLS connections, perform handshake first
			if s.TLS {
				if tlsConn, ok := conn.(*tls.Conn); ok {
					if err := tlsConn.Handshake(); err != nil {
						continue // Skip this sequence on handshake failure
					}
				}
			}

			// No payload exchange scripted for this step.
			if seq.payloadIn == "" {
				continue
			}

			buf := make([]byte, len(seq.payloadIn))
			n, err := conn.Read(buf)
			if err != nil {
				t.Errorf("reading payload: %v", err)
				return
			}
			recv := strings.TrimSpace(string(buf[:n]))
			switch recv {
			case seq.payloadIn:
				// Expected payload received: reply with the scripted answer.
				if _, err := conn.Write([]byte(seq.payloadOut)); err != nil {
					t.Errorf("failed to write payload: %v", err)
				}
			default:
				// Unexpected payload: reply with a sentinel the client's
				// Expect matching will reject.
				if _, err := conn.Write([]byte("FAULT\n")); err != nil {
					t.Errorf("failed to write payload: %v", err)
				}
			}
		}
	}()
}
// dialerMock implements the tcp.Dialer interface for tests by delegating both
// Dial and DialContext to the onDial hook; the context and client connection
// arguments are ignored.
type dialerMock struct {
	// onDial is invoked with the network and address for every dial attempt.
	onDial func(network, addr string) (net.Conn, error)
}

// Dial forwards to the onDial hook, ignoring the client connection.
func (dm *dialerMock) Dial(network, addr string, _ tcp.ClientConn) (net.Conn, error) {
	return dm.onDial(network, addr)
}

// DialContext forwards to the onDial hook; the context is ignored.
func (dm *dialerMock) DialContext(_ context.Context, network, addr string, _ tcp.ClientConn) (net.Conn, error) {
	return dm.onDial(network, addr)
}

// TerminationDelay reports no termination delay for the mock.
func (dm *dialerMock) TerminationDelay() time.Duration {
	return 0
}
// connMock is a scripted net.Conn for tests: Read and Write delegate to the
// optional readFunc/writeFunc hooks, and every other net.Conn method is an
// inert no-op succeeding immediately.
type connMock struct {
	writeFunc func([]byte) (int, error)
	readFunc  func([]byte) (int, error)
}

// Read delegates to readFunc when set; otherwise it reports an empty read.
func (cm *connMock) Read(b []byte) (n int, err error) {
	if cm.readFunc == nil {
		return 0, nil
	}
	return cm.readFunc(b)
}

// Write delegates to writeFunc when set; otherwise it pretends the whole
// buffer was written.
func (cm *connMock) Write(b []byte) (n int, err error) {
	if cm.writeFunc == nil {
		return len(b), nil
	}
	return cm.writeFunc(b)
}

// The remaining net.Conn methods are stubs that always succeed.
func (cm *connMock) Close() error { return nil }

func (cm *connMock) LocalAddr() net.Addr { return &net.TCPAddr{} }

func (cm *connMock) RemoteAddr() net.Addr { return &net.TCPAddr{} }

func (cm *connMock) SetDeadline(_ time.Time) error { return nil }

func (cm *connMock) SetReadDeadline(_ time.Time) error { return nil }

func (cm *connMock) SetWriteDeadline(_ time.Time) error { return nil }

View file

@ -66,6 +66,8 @@ func TestNewServiceHealthChecker_durations(t *testing.T) {
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
healthChecker := NewServiceHealthChecker(t.Context(), nil, test.config, nil, nil, http.DefaultTransport, nil, "")
assert.Equal(t, test.expInterval, healthChecker.interval)
assert.Equal(t, test.expTimeout, healthChecker.timeout)

View file

@ -124,6 +124,8 @@ func (f *RouterFactory) CreateRouters(rtConf *runtime.Configuration) (map[string
}
}
svcTCPManager.LaunchHealthCheck(ctx)
// UDP
svcUDPManager := udpsvc.NewManager(rtConf)
rtUDPManager := udprouter.NewManager(rtConf, svcUDPManager)

View file

@ -39,6 +39,7 @@ type Balancer struct {
status map[string]struct{}
// updaters is the list of hooks that are run (to update the Balancer
// parent(s)), whenever the Balancer status changes.
// No mutex is needed, as it is modified only during the configuration build.
updaters []func(bool)
// fenced is the list of terminating yet still serving child services.
fenced map[string]struct{}

View file

@ -56,6 +56,7 @@ type Balancer struct {
// updaters is the list of hooks that are run (to update the Balancer
// parent(s)), whenever the Balancer status changes.
// No mutex is needed, as it is modified only during the configuration build.
updaters []func(bool)
sticky *loadbalancer.Sticky

View file

@ -40,6 +40,7 @@ type Balancer struct {
// updaters is the list of hooks that are run (to update the Balancer
// parent(s)), whenever the Balancer status changes.
// No mutex is needed, as it is modified only during the configuration build.
updaters []func(bool)
sticky *loadbalancer.Sticky

View file

@ -266,19 +266,18 @@ func (m *Manager) getWRRServiceHandler(ctx context.Context, serviceName string,
continue
}
childName := service.Name
updater, ok := serviceHandler.(healthcheck.StatusUpdater)
if !ok {
return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", childName, serviceName, serviceHandler)
return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", service.Name, serviceName, serviceHandler)
}
if err := updater.RegisterStatusUpdater(func(up bool) {
balancer.SetStatus(ctx, childName, up)
balancer.SetStatus(ctx, service.Name, up)
}); err != nil {
return nil, fmt.Errorf("cannot register %v as updater for %v: %w", childName, serviceName, err)
return nil, fmt.Errorf("cannot register %v as updater for %v: %w", service.Name, serviceName, err)
}
log.Ctx(ctx).Debug().Str("parent", serviceName).Str("child", childName).
log.Ctx(ctx).Debug().Str("parent", serviceName).Str("child", service.Name).
Msg("Child service will update parent on status change")
}
@ -342,19 +341,18 @@ func (m *Manager) getHRWServiceHandler(ctx context.Context, serviceName string,
continue
}
childName := service.Name
updater, ok := serviceHandler.(healthcheck.StatusUpdater)
if !ok {
return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", childName, serviceName, serviceHandler)
return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", service.Name, serviceName, serviceHandler)
}
if err := updater.RegisterStatusUpdater(func(up bool) {
balancer.SetStatus(ctx, childName, up)
balancer.SetStatus(ctx, service.Name, up)
}); err != nil {
return nil, fmt.Errorf("cannot register %v as updater for %v: %w", childName, serviceName, err)
return nil, fmt.Errorf("cannot register %v as updater for %v: %w", service.Name, serviceName, err)
}
log.Ctx(ctx).Debug().Str("parent", serviceName).Str("child", childName).
log.Ctx(ctx).Debug().Str("parent", serviceName).Str("child", service.Name).
Msg("Child service will update parent on status change")
}
@ -466,7 +464,7 @@ func (m *Manager) getLoadBalancerServiceHandler(ctx context.Context, serviceName
lb.AddServer(server.URL, proxy, server)
// servers are considered UP by default.
// Servers are considered UP by default.
info.UpdateServerStatus(target.String(), runtime.StatusUp)
healthCheckTargets[server.URL] = target

View file

@ -4,12 +4,15 @@ import (
"context"
"errors"
"fmt"
"maps"
"math/rand"
"net"
"slices"
"time"
"github.com/rs/zerolog/log"
"github.com/traefik/traefik/v3/pkg/config/runtime"
"github.com/traefik/traefik/v3/pkg/healthcheck"
"github.com/traefik/traefik/v3/pkg/observability/logs"
"github.com/traefik/traefik/v3/pkg/server/provider"
"github.com/traefik/traefik/v3/pkg/tcp"
@ -17,17 +20,19 @@ import (
// Manager is the TCPHandlers factory.
type Manager struct {
dialerManager *tcp.DialerManager
configs map[string]*runtime.TCPServiceInfo
rand *rand.Rand // For the initial shuffling of load-balancers.
dialerManager *tcp.DialerManager
configs map[string]*runtime.TCPServiceInfo
rand *rand.Rand // For the initial shuffling of load-balancers.
healthCheckers map[string]*healthcheck.ServiceTCPHealthChecker
}
// NewManager creates a new manager.
func NewManager(conf *runtime.Configuration, dialerManager *tcp.DialerManager) *Manager {
return &Manager{
dialerManager: dialerManager,
configs: conf.TCPServices,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
dialerManager: dialerManager,
healthCheckers: make(map[string]*healthcheck.ServiceTCPHealthChecker),
configs: conf.TCPServices,
rand: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}
@ -51,7 +56,7 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han
switch {
case conf.LoadBalancer != nil:
loadBalancer := tcp.NewWRRLoadBalancer()
loadBalancer := tcp.NewWRRLoadBalancer(conf.LoadBalancer.HealthCheck != nil)
if conf.LoadBalancer.TerminationDelay != nil {
log.Ctx(ctx).Warn().Msgf("Service %q load balancer uses `TerminationDelay`, but this option is deprecated, please use ServersTransport configuration instead.", serviceName)
@ -65,6 +70,8 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han
conf.LoadBalancer.ServersTransport = provider.GetQualifiedName(ctx, conf.LoadBalancer.ServersTransport)
}
uniqHealthCheckTargets := make(map[string]healthcheck.TCPHealthCheckTarget, len(conf.LoadBalancer.Servers))
for index, server := range shuffle(conf.LoadBalancer.Servers, m.rand) {
srvLogger := logger.With().
Int(logs.ServerIndex, index).
@ -86,14 +93,34 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han
continue
}
loadBalancer.AddServer(handler)
loadBalancer.Add(server.Address, handler, nil)
// Servers are considered UP by default.
conf.UpdateServerStatus(server.Address, runtime.StatusUp)
uniqHealthCheckTargets[server.Address] = healthcheck.TCPHealthCheckTarget{
Address: server.Address,
TLS: server.TLS,
Dialer: dialer,
}
logger.Debug().Msg("Creating TCP server")
}
if conf.LoadBalancer.HealthCheck != nil {
m.healthCheckers[serviceName] = healthcheck.NewServiceTCPHealthChecker(
ctx,
conf.LoadBalancer.HealthCheck,
loadBalancer,
conf,
slices.Collect(maps.Values(uniqHealthCheckTargets)),
serviceQualifiedName)
}
return loadBalancer, nil
case conf.Weighted != nil:
loadBalancer := tcp.NewWRRLoadBalancer()
loadBalancer := tcp.NewWRRLoadBalancer(conf.Weighted.HealthCheck != nil)
for _, service := range shuffle(conf.Weighted.Services, m.rand) {
handler, err := m.BuildTCP(ctx, service.Name)
@ -102,7 +129,25 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han
return nil, err
}
loadBalancer.AddWeightServer(handler, service.Weight)
loadBalancer.Add(service.Name, handler, service.Weight)
if conf.Weighted.HealthCheck == nil {
continue
}
updater, ok := handler.(healthcheck.StatusUpdater)
if !ok {
return nil, fmt.Errorf("child service %v of %v not a healthcheck.StatusUpdater (%T)", service.Name, serviceName, handler)
}
if err := updater.RegisterStatusUpdater(func(up bool) {
loadBalancer.SetStatus(ctx, service.Name, up)
}); err != nil {
return nil, fmt.Errorf("cannot register %v as updater for %v: %w", service.Name, serviceName, err)
}
log.Ctx(ctx).Debug().Str("parent", serviceName).Str("child", service.Name).
Msg("Child service will update parent on status change")
}
return loadBalancer, nil
@ -114,6 +159,14 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han
}
}
// LaunchHealthCheck launches the health checks.
func (m *Manager) LaunchHealthCheck(ctx context.Context) {
for serviceName, hc := range m.healthCheckers {
logger := log.Ctx(ctx).With().Str(logs.ServiceName, serviceName).Logger()
go hc.Launch(logger.WithContext(ctx))
}
}
func shuffle[T any](values []T, r *rand.Rand) []T {
shuffled := make([]T, len(values))
copy(shuffled, values)

View file

@ -233,6 +233,49 @@ func TestManager_BuildTCP(t *testing.T) {
providerName: "provider-1",
expectedError: "no transport configuration found for \"myServersTransport@provider-1\"",
},
{
desc: "WRR with healthcheck enabled",
stConfigs: map[string]*dynamic.TCPServersTransport{"default@internal": {}},
serviceName: "serviceName",
configs: map[string]*runtime.TCPServiceInfo{
"serviceName@provider-1": {
TCPService: &dynamic.TCPService{
Weighted: &dynamic.TCPWeightedRoundRobin{
Services: []dynamic.TCPWRRService{
{Name: "foobar@provider-1", Weight: new(int)},
{Name: "foobar2@provider-1", Weight: new(int)},
},
HealthCheck: &dynamic.HealthCheck{},
},
},
},
"foobar@provider-1": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "192.168.0.12:80",
},
},
HealthCheck: &dynamic.TCPServerHealthCheck{},
},
},
},
"foobar2@provider-1": {
TCPService: &dynamic.TCPService{
LoadBalancer: &dynamic.TCPServersLoadBalancer{
Servers: []dynamic.TCPServer{
{
Address: "192.168.0.13:80",
},
},
HealthCheck: &dynamic.TCPServerHealthCheck{},
},
},
},
},
providerName: "provider-1",
},
}
for _, test := range testCases {

View file

@ -1,6 +1,7 @@
package tcp
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
@ -33,6 +34,7 @@ type ClientConn interface {
// Dialer is an interface to dial a network connection, with support for PROXY protocol and termination delay.
type Dialer interface {
Dial(network, addr string, clientConn ClientConn) (c net.Conn, err error)
DialContext(ctx context.Context, network, addr string, clientConn ClientConn) (c net.Conn, err error)
TerminationDelay() time.Duration
}
@ -49,7 +51,12 @@ func (d tcpDialer) TerminationDelay() time.Duration {
// Dial dials a network connection and optionally sends a PROXY protocol header.
func (d tcpDialer) Dial(network, addr string, clientConn ClientConn) (net.Conn, error) {
conn, err := d.dialer.Dial(network, addr)
return d.DialContext(context.Background(), network, addr, clientConn)
}
// DialContext dials a network connection and optionally sends a PROXY protocol header, with context.
func (d tcpDialer) DialContext(ctx context.Context, network, addr string, clientConn ClientConn) (net.Conn, error) {
conn, err := d.dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
@ -72,7 +79,12 @@ type tcpTLSDialer struct {
// Dial dials a network connection with the wrapped tcpDialer and performs a TLS handshake.
func (d tcpTLSDialer) Dial(network, addr string, clientConn ClientConn) (net.Conn, error) {
conn, err := d.tcpDialer.Dial(network, addr, clientConn)
return d.DialContext(context.Background(), network, addr, clientConn)
}
// DialContext dials a network connection with the wrapped tcpDialer and performs a TLS handshake, with context.
func (d tcpTLSDialer) DialContext(ctx context.Context, network, addr string, clientConn ClientConn) (net.Conn, error) {
conn, err := d.tcpDialer.DialContext(ctx, network, addr, clientConn)
if err != nil {
return nil, err
}

View file

@ -1,6 +1,7 @@
package tcp
import (
"context"
"errors"
"sync"
@ -11,30 +12,42 @@ var errNoServersInPool = errors.New("no servers in the pool")
type server struct {
Handler
name string
weight int
}
// WRRLoadBalancer is a naive RoundRobin load balancer for TCP services.
type WRRLoadBalancer struct {
servers []server
lock sync.Mutex
currentWeight int
index int
// serversMu is a mutex to protect the handlers slice and the status.
serversMu sync.Mutex
servers []server
// status is a record of which child services of the Balancer are healthy, keyed
// by name of child service. A service is initially added to the map when it is
// created via Add, and it is later removed or added to the map as needed,
// through the SetStatus method.
status map[string]struct{}
// updaters is the list of hooks that are run (to update the Balancer parent(s)), whenever the Balancer status changes.
// No mutex is needed, as it is modified only during the configuration build.
updaters []func(bool)
index int
currentWeight int
wantsHealthCheck bool
}
// NewWRRLoadBalancer creates a new WRRLoadBalancer.
func NewWRRLoadBalancer() *WRRLoadBalancer {
func NewWRRLoadBalancer(wantsHealthCheck bool) *WRRLoadBalancer {
return &WRRLoadBalancer{
index: -1,
status: make(map[string]struct{}),
index: -1,
wantsHealthCheck: wantsHealthCheck,
}
}
// ServeTCP forwards the connection to the right service.
func (b *WRRLoadBalancer) ServeTCP(conn WriteCloser) {
b.lock.Lock()
next, err := b.next()
b.lock.Unlock()
next, err := b.nextServer()
if err != nil {
if !errors.Is(err, errNoServersInPool) {
log.Error().Err(err).Msg("Error during load balancing")
@ -46,22 +59,103 @@ func (b *WRRLoadBalancer) ServeTCP(conn WriteCloser) {
next.ServeTCP(conn)
}
// AddServer appends a server to the existing list.
func (b *WRRLoadBalancer) AddServer(serverHandler Handler) {
w := 1
b.AddWeightServer(serverHandler, &w)
}
// AddWeightServer appends a server to the existing list with a weight.
func (b *WRRLoadBalancer) AddWeightServer(serverHandler Handler, weight *int) {
b.lock.Lock()
defer b.lock.Unlock()
// Add appends a server to the existing list with a name and weight.
func (b *WRRLoadBalancer) Add(name string, handler Handler, weight *int) {
w := 1
if weight != nil {
w = *weight
}
b.servers = append(b.servers, server{Handler: serverHandler, weight: w})
b.serversMu.Lock()
b.servers = append(b.servers, server{Handler: handler, name: name, weight: w})
b.status[name] = struct{}{}
b.serversMu.Unlock()
}
// SetStatus sets status (UP or DOWN) of a target server.
// When the aggregated status of the balancer changes as a result, the change
// is propagated to every registered status updater.
func (b *WRRLoadBalancer) SetStatus(ctx context.Context, childName string, up bool) {
	b.serversMu.Lock()
	defer b.serversMu.Unlock()

	// The balancer is healthy as long as at least one server is UP.
	healthyBefore := len(b.status) > 0

	childStatus := "DOWN"
	if up {
		childStatus = "UP"
	}
	log.Ctx(ctx).Debug().Msgf("Setting status of %s to %v", childName, childStatus)

	switch {
	case up:
		b.status[childName] = struct{}{}
	default:
		delete(b.status, childName)
	}

	healthyAfter := len(b.status) > 0
	aggStatus := "DOWN"
	if healthyAfter {
		aggStatus = "UP"
	}

	if healthyBefore == healthyAfter {
		// Aggregated health did not change; nothing to propagate upstream.
		log.Ctx(ctx).Debug().Msgf("Still %s, no need to propagate", aggStatus)
		return
	}

	log.Ctx(ctx).Debug().Msgf("Propagating new %s status", aggStatus)
	for _, updater := range b.updaters {
		updater(healthyAfter)
	}
}
// RegisterStatusUpdater adds fn to the hooks invoked whenever the aggregated
// status of this balancer changes. It fails when the health check is not
// enabled in the configuration for this weighted service.
func (b *WRRLoadBalancer) RegisterStatusUpdater(fn func(up bool)) error {
	if b.wantsHealthCheck {
		b.updaters = append(b.updaters, fn)
		return nil
	}

	return errors.New("healthCheck not enabled in config for this weighted service")
}
// nextServer returns the next server to forward a connection to, following
// an interleaved weighted round-robin order and skipping servers that are
// not marked UP in b.status.
func (b *WRRLoadBalancer) nextServer() (Handler, error) {
	b.serversMu.Lock()
	defer b.serversMu.Unlock()

	// No servers registered, or none of them is currently UP.
	if len(b.servers) == 0 || len(b.status) == 0 {
		return nil, errNoServersInPool
	}

	// The algo below may look messy, but is actually very simple
	// it calculates the GCD and subtracts it on every iteration, what interleaves servers
	// and allows us not to build an iterator every time we readjust weights.
	// Maximum weight across all enabled servers.
	maximum := b.maxWeight()
	if maximum == 0 {
		return nil, errors.New("all servers have 0 weight")
	}

	// GCD across all enabled servers
	gcd := b.weightGcd()

	for {
		b.index = (b.index + 1) % len(b.servers)
		if b.index == 0 {
			// Completed a full pass: lower the weight threshold by the GCD,
			// wrapping back to the maximum once it reaches zero.
			b.currentWeight -= gcd
			if b.currentWeight <= 0 {
				b.currentWeight = maximum
			}
		}

		srv := b.servers[b.index]
		// Pick the server only if it is UP and heavy enough for the current
		// threshold.
		// NOTE(review): if b.status ever held only names absent from
		// b.servers this loop would not terminate — presumably Add keeps
		// both in sync; verify against callers.
		if _, ok := b.status[srv.name]; ok && srv.weight >= b.currentWeight {
			return srv, nil
		}
	}
}
func (b *WRRLoadBalancer) maxWeight() int {
@ -92,36 +186,3 @@ func gcd(a, b int) int {
}
return a
}
// next returns the next server to forward a connection to, following an
// interleaved weighted round-robin order. Callers are expected to hold the
// balancer's lock.
func (b *WRRLoadBalancer) next() (Handler, error) {
	if len(b.servers) == 0 {
		return nil, errNoServersInPool
	}

	// The algo below may look messy, but is actually very simple
	// it calculates the GCD and subtracts it on every iteration, what interleaves servers
	// and allows us not to build an iterator every time we readjust weights

	// Maximum weight across all enabled servers
	maximum := b.maxWeight()
	if maximum == 0 {
		return nil, errors.New("all servers have 0 weight")
	}

	// GCD across all enabled servers
	gcd := b.weightGcd()

	for {
		b.index = (b.index + 1) % len(b.servers)
		if b.index == 0 {
			// Completed a full pass: lower the weight threshold by the GCD,
			// wrapping back to the maximum once it reaches zero.
			b.currentWeight -= gcd
			if b.currentWeight <= 0 {
				b.currentWeight = maximum
			}
		}

		srv := b.servers[b.index]
		// Pick the first server whose weight meets the current threshold.
		if srv.weight >= b.currentWeight {
			return srv, nil
		}
	}
}

View file

@ -9,50 +9,7 @@ import (
"github.com/stretchr/testify/require"
)
// fakeConn is a WriteCloser test double recording what was written to it and
// how many times it was closed.
type fakeConn struct {
	writeCall map[string]int // payload -> number of times it was written
	closeCall int            // number of Close calls
}

// Read is not exercised by the tests.
func (f *fakeConn) Read(b []byte) (n int, err error) {
	panic("implement me")
}

// Write records the payload and reports it as fully written.
func (f *fakeConn) Write(b []byte) (n int, err error) {
	f.writeCall[string(b)]++
	return len(b), nil
}

// Close counts the call and always succeeds.
func (f *fakeConn) Close() error {
	f.closeCall++
	return nil
}

// LocalAddr is not exercised by the tests.
func (f *fakeConn) LocalAddr() net.Addr {
	panic("implement me")
}

// RemoteAddr is not exercised by the tests.
func (f *fakeConn) RemoteAddr() net.Addr {
	panic("implement me")
}

// SetDeadline is not exercised by the tests.
func (f *fakeConn) SetDeadline(t time.Time) error {
	panic("implement me")
}

// SetReadDeadline is not exercised by the tests.
func (f *fakeConn) SetReadDeadline(t time.Time) error {
	panic("implement me")
}

// SetWriteDeadline is not exercised by the tests.
func (f *fakeConn) SetWriteDeadline(t time.Time) error {
	panic("implement me")
}

// CloseWrite is not exercised by the tests.
func (f *fakeConn) CloseWrite() error {
	panic("implement me")
}
func TestLoadBalancing(t *testing.T) {
func TestWRRLoadBalancer_LoadBalancing(t *testing.T) {
testCases := []struct {
desc string
serversWeight map[string]int
@ -124,9 +81,9 @@ func TestLoadBalancing(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
balancer := NewWRRLoadBalancer()
balancer := NewWRRLoadBalancer(false)
for server, weight := range test.serversWeight {
balancer.AddWeightServer(HandlerFunc(func(conn WriteCloser) {
balancer.Add(server, HandlerFunc(func(conn WriteCloser) {
_, err := conn.Write([]byte(server))
require.NoError(t, err)
}), &weight)
@ -142,3 +99,196 @@ func TestLoadBalancing(t *testing.T) {
})
}
}
// TestWRRLoadBalancer_NoServiceUp checks that a connection is closed without
// being forwarded when every server is marked DOWN.
func TestWRRLoadBalancer_NoServiceUp(t *testing.T) {
	lb := NewWRRLoadBalancer(false)

	for _, name := range []string{"first", "second"} {
		lb.Add(name, HandlerFunc(func(conn WriteCloser) {
			_, err := conn.Write([]byte(name))
			require.NoError(t, err)
		}), pointer(1))
	}

	lb.SetStatus(t.Context(), "first", false)
	lb.SetStatus(t.Context(), "second", false)

	conn := &fakeConn{writeCall: make(map[string]int)}
	lb.ServeTCP(conn)

	assert.Empty(t, conn.writeCall)
	assert.Equal(t, 1, conn.closeCall)
}
// TestWRRLoadBalancer_OneServerDown checks that a DOWN server is skipped and
// all traffic goes to the remaining UP server.
func TestWRRLoadBalancer_OneServerDown(t *testing.T) {
	lb := NewWRRLoadBalancer(false)

	lb.Add("first", HandlerFunc(func(conn WriteCloser) {
		_, err := conn.Write([]byte("first"))
		require.NoError(t, err)
	}), pointer(1))
	lb.Add("second", HandlerFunc(func(conn WriteCloser) {
		_, err := conn.Write([]byte("second"))
		require.NoError(t, err)
	}), pointer(1))

	lb.SetStatus(t.Context(), "second", false)

	conn := &fakeConn{writeCall: make(map[string]int)}
	for i := 0; i < 3; i++ {
		lb.ServeTCP(conn)
	}

	assert.Equal(t, 3, conn.writeCall["first"])
}
// TestWRRLoadBalancer_DownThenUp checks that a server marked DOWN and then UP
// again rejoins the round-robin rotation.
func TestWRRLoadBalancer_DownThenUp(t *testing.T) {
	lb := NewWRRLoadBalancer(false)

	lb.Add("first", HandlerFunc(func(conn WriteCloser) {
		_, err := conn.Write([]byte("first"))
		require.NoError(t, err)
	}), pointer(1))
	lb.Add("second", HandlerFunc(func(conn WriteCloser) {
		_, err := conn.Write([]byte("second"))
		require.NoError(t, err)
	}), pointer(1))

	// With "second" DOWN, everything lands on "first".
	lb.SetStatus(t.Context(), "second", false)

	conn := &fakeConn{writeCall: make(map[string]int)}
	for i := 0; i < 3; i++ {
		lb.ServeTCP(conn)
	}
	assert.Equal(t, 3, conn.writeCall["first"])

	// Once "second" is back UP, traffic is shared again.
	lb.SetStatus(t.Context(), "second", true)

	conn = &fakeConn{writeCall: make(map[string]int)}
	for i := 0; i < 2; i++ {
		lb.ServeTCP(conn)
	}
	assert.Equal(t, 1, conn.writeCall["first"])
	assert.Equal(t, 1, conn.writeCall["second"])
}
// TestWRRLoadBalancer_Propagate checks that a child balancer propagates its
// aggregated status to its parent via RegisterStatusUpdater: the parent only
// marks a child DOWN once every server of that child is DOWN.
func TestWRRLoadBalancer_Propagate(t *testing.T) {
	// First child balancer, with two servers.
	balancer1 := NewWRRLoadBalancer(true)

	balancer1.Add("first", HandlerFunc(func(conn WriteCloser) {
		_, err := conn.Write([]byte("first"))
		require.NoError(t, err)
	}), pointer(1))
	balancer1.Add("second", HandlerFunc(func(conn WriteCloser) {
		_, err := conn.Write([]byte("second"))
		require.NoError(t, err)
	}), pointer(1))

	// Second child balancer, with two servers.
	balancer2 := NewWRRLoadBalancer(true)

	balancer2.Add("third", HandlerFunc(func(conn WriteCloser) {
		_, err := conn.Write([]byte("third"))
		require.NoError(t, err)
	}), pointer(1))
	balancer2.Add("fourth", HandlerFunc(func(conn WriteCloser) {
		_, err := conn.Write([]byte("fourth"))
		require.NoError(t, err)
	}), pointer(1))

	// Parent balancer: each child reports its aggregated status upward.
	topBalancer := NewWRRLoadBalancer(true)

	topBalancer.Add("balancer1", balancer1, pointer(1))
	_ = balancer1.RegisterStatusUpdater(func(up bool) {
		topBalancer.SetStatus(t.Context(), "balancer1", up)
	})

	topBalancer.Add("balancer2", balancer2, pointer(1))
	_ = balancer2.RegisterStatusUpdater(func(up bool) {
		topBalancer.SetStatus(t.Context(), "balancer2", up)
	})

	// All four servers UP: traffic is spread evenly.
	conn := &fakeConn{writeCall: make(map[string]int)}
	for range 8 {
		topBalancer.ServeTCP(conn)
	}

	assert.Equal(t, 2, conn.writeCall["first"])
	assert.Equal(t, 2, conn.writeCall["second"])
	assert.Equal(t, 2, conn.writeCall["third"])
	assert.Equal(t, 2, conn.writeCall["fourth"])

	// fourth gets downed, but balancer2 still up since third is still up.
	balancer2.SetStatus(t.Context(), "fourth", false)

	conn = &fakeConn{writeCall: make(map[string]int)}
	for range 8 {
		topBalancer.ServeTCP(conn)
	}

	assert.Equal(t, 2, conn.writeCall["first"])
	assert.Equal(t, 2, conn.writeCall["second"])
	assert.Equal(t, 4, conn.writeCall["third"])
	assert.Equal(t, 0, conn.writeCall["fourth"])

	// third gets downed, and the propagation triggers balancer2 to be marked as
	// down as well for topBalancer.
	balancer2.SetStatus(t.Context(), "third", false)

	conn = &fakeConn{writeCall: make(map[string]int)}
	for range 8 {
		topBalancer.ServeTCP(conn)
	}

	assert.Equal(t, 4, conn.writeCall["first"])
	assert.Equal(t, 4, conn.writeCall["second"])
	assert.Equal(t, 0, conn.writeCall["third"])
	assert.Equal(t, 0, conn.writeCall["fourth"])
}
// pointer returns a pointer to a copy of v, handy for inline optional arguments.
func pointer[T any](v T) *T {
	p := v
	return &p
}
// fakeConn is a WriteCloser test double recording what was written to it and
// how many times it was closed.
type fakeConn struct {
	writeCall map[string]int // payload -> number of times it was written
	closeCall int            // number of Close calls
}

// Read is not exercised by the tests.
func (f *fakeConn) Read(b []byte) (n int, err error) {
	panic("implement me")
}

// Write records the payload and reports it as fully written.
func (f *fakeConn) Write(b []byte) (n int, err error) {
	f.writeCall[string(b)]++
	return len(b), nil
}

// Close counts the call and always succeeds.
func (f *fakeConn) Close() error {
	f.closeCall++
	return nil
}

// LocalAddr is not exercised by the tests.
func (f *fakeConn) LocalAddr() net.Addr {
	panic("implement me")
}

// RemoteAddr is not exercised by the tests.
func (f *fakeConn) RemoteAddr() net.Addr {
	panic("implement me")
}

// SetDeadline is not exercised by the tests.
func (f *fakeConn) SetDeadline(t time.Time) error {
	panic("implement me")
}

// SetReadDeadline is not exercised by the tests.
func (f *fakeConn) SetReadDeadline(t time.Time) error {
	panic("implement me")
}

// SetWriteDeadline is not exercised by the tests.
func (f *fakeConn) SetWriteDeadline(t time.Time) error {
	panic("implement me")
}

// CloseWrite is not exercised by the tests.
func (f *fakeConn) CloseWrite() error {
	panic("implement me")
}