On client CloseWrite, do CloseWrite instead of Close for backend
Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
This commit is contained in:
parent
401b3afa3b
commit
b55be9fdea
25 changed files with 393 additions and 36 deletions
|
@ -45,7 +45,19 @@ type RouterTCPTLSConfig struct {
|
|||
|
||||
// TCPLoadBalancerService holds the LoadBalancerService configuration.
|
||||
type TCPLoadBalancerService struct {
|
||||
Servers []TCPServer `json:"servers,omitempty" toml:"servers,omitempty" yaml:"servers,omitempty" label-slice-as-struct:"server"`
|
||||
// TerminationDelay, corresponds to the deadline that the proxy sets, after one
|
||||
// of its connected peers indicates it has closed the writing capability of its
|
||||
// connection, to close the reading capability as well, hence fully terminating the
|
||||
// connection. It is a duration in milliseconds, defaulting to 100. A negative value
|
||||
// means an infinite deadline (i.e. the reading capability is never closed).
|
||||
TerminationDelay *int `json:"terminationDelay,omitempty" toml:"terminationDelay,omitempty" yaml:"terminationDelay,omitempty"`
|
||||
Servers []TCPServer `json:"servers,omitempty" toml:"servers,omitempty" yaml:"servers,omitempty" label-slice-as-struct:"server"`
|
||||
}
|
||||
|
||||
// SetDefaults Default values for a TCPLoadBalancerService
|
||||
func (l *TCPLoadBalancerService) SetDefaults() {
|
||||
defaultTerminationDelay := 100 // in milliseconds
|
||||
l.TerminationDelay = &defaultTerminationDelay
|
||||
}
|
||||
|
||||
// Mergeable tells if the given service is mergeable.
|
||||
|
|
|
@ -1155,6 +1155,11 @@ func (in *TCPConfiguration) DeepCopy() *TCPConfiguration {
|
|||
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
|
||||
func (in *TCPLoadBalancerService) DeepCopyInto(out *TCPLoadBalancerService) {
|
||||
*out = *in
|
||||
if in.TerminationDelay != nil {
|
||||
in, out := &in.TerminationDelay, &out.TerminationDelay
|
||||
*out = new(int)
|
||||
**out = **in
|
||||
}
|
||||
if in.Servers != nil {
|
||||
in, out := &in.Servers, &out.Servers
|
||||
*out = make([]TCPServer, len(*in))
|
||||
|
|
|
@ -170,7 +170,9 @@ func TestDecodeConfiguration(t *testing.T) {
|
|||
"traefik.tcp.routers.Router1.tls.options": "foo",
|
||||
"traefik.tcp.routers.Router1.tls.passthrough": "false",
|
||||
"traefik.tcp.services.Service0.loadbalancer.server.Port": "42",
|
||||
"traefik.tcp.services.Service0.loadbalancer.TerminationDelay": "42",
|
||||
"traefik.tcp.services.Service1.loadbalancer.server.Port": "42",
|
||||
"traefik.tcp.services.Service1.loadbalancer.TerminationDelay": "42",
|
||||
}
|
||||
|
||||
configuration, err := DecodeConfiguration(labels)
|
||||
|
@ -212,6 +214,7 @@ func TestDecodeConfiguration(t *testing.T) {
|
|||
Port: "42",
|
||||
},
|
||||
},
|
||||
TerminationDelay: func(i int) *int { return &i }(42),
|
||||
},
|
||||
},
|
||||
"Service1": {
|
||||
|
@ -221,6 +224,7 @@ func TestDecodeConfiguration(t *testing.T) {
|
|||
Port: "42",
|
||||
},
|
||||
},
|
||||
TerminationDelay: func(i int) *int { return &i }(42),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -79,6 +79,7 @@ func (p *Provider) buildTCPServiceConfiguration(ctx context.Context, container d
|
|||
if len(configuration.Services) == 0 {
|
||||
configuration.Services = make(map[string]*dynamic.TCPService)
|
||||
lb := &dynamic.TCPLoadBalancerService{}
|
||||
lb.SetDefaults()
|
||||
configuration.Services[serviceName] = &dynamic.TCPService{
|
||||
LoadBalancer: lb,
|
||||
}
|
||||
|
|
|
@ -13,6 +13,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Int(v int) *int { return &v }
|
||||
|
||||
func TestDefaultRule(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
|
@ -2092,6 +2094,7 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.1:80",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -2136,6 +2139,7 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.1:80",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -2190,6 +2194,7 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -2268,6 +2273,7 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.2:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -2331,6 +2337,53 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
HTTP: &dynamic.HTTPConfiguration{
|
||||
Routers: map[string]*dynamic.Router{},
|
||||
Middlewares: map[string]*dynamic.Middleware{},
|
||||
Services: map[string]*dynamic.Service{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tcp with label for tcp service, with termination delay",
|
||||
containers: []dockerData{
|
||||
{
|
||||
ServiceName: "Test",
|
||||
Name: "Test",
|
||||
Labels: map[string]string{
|
||||
"traefik.tcp.services.foo.loadbalancer.server.port": "8080",
|
||||
"traefik.tcp.services.foo.loadbalancer.terminationdelay": "200",
|
||||
},
|
||||
NetworkSettings: networkSettings{
|
||||
Ports: nat.PortMap{
|
||||
nat.Port("80/tcp"): []nat.PortBinding{},
|
||||
},
|
||||
Networks: map[string]*networkData{
|
||||
"bridge": {
|
||||
Name: "bridge",
|
||||
Addr: "127.0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: &dynamic.Configuration{
|
||||
TCP: &dynamic.TCPConfiguration{
|
||||
Routers: map[string]*dynamic.TCPRouter{},
|
||||
Services: map[string]*dynamic.TCPService{
|
||||
"foo": {
|
||||
LoadBalancer: &dynamic.TCPLoadBalancerService{
|
||||
Servers: []dynamic.TCPServer{
|
||||
{
|
||||
Address: "127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(200),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -141,6 +141,7 @@ func (p *Provider) buildTCPServiceConfiguration(ctx context.Context, app maratho
|
|||
if len(conf.Services) == 0 {
|
||||
conf.Services = make(map[string]*dynamic.TCPService)
|
||||
lb := &dynamic.TCPLoadBalancerService{}
|
||||
lb.SetDefaults()
|
||||
conf.Services[appName] = &dynamic.TCPService{
|
||||
LoadBalancer: lb,
|
||||
}
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Int(v int) *int { return &v }
|
||||
|
||||
func TestGetConfigurationAPIErrors(t *testing.T) {
|
||||
fakeClient := newFakeClient(true, marathon.Applications{})
|
||||
|
||||
|
@ -1240,6 +1242,7 @@ func TestBuildConfiguration(t *testing.T) {
|
|||
Address: "localhost:80",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1271,6 +1274,7 @@ func TestBuildConfiguration(t *testing.T) {
|
|||
Address: "localhost:80",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1310,6 +1314,48 @@ func TestBuildConfiguration(t *testing.T) {
|
|||
Address: "localhost:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
HTTP: &dynamic.HTTPConfiguration{
|
||||
Routers: map[string]*dynamic.Router{},
|
||||
Middlewares: map[string]*dynamic.Middleware{},
|
||||
Services: map[string]*dynamic.Service{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "one app with tcp labels with port, with termination delay",
|
||||
applications: withApplications(
|
||||
application(
|
||||
appID("/app"),
|
||||
appPorts(80, 81),
|
||||
withTasks(localhostTask(taskPorts(80, 81))),
|
||||
withLabel("traefik.tcp.routers.foo.rule", "HostSNI(`foo.bar`)"),
|
||||
withLabel("traefik.tcp.routers.foo.tls", "true"),
|
||||
withLabel("traefik.tcp.services.foo.loadbalancer.server.port", "8080"),
|
||||
withLabel("traefik.tcp.services.foo.loadbalancer.terminationdelay", "200"),
|
||||
)),
|
||||
expected: &dynamic.Configuration{
|
||||
TCP: &dynamic.TCPConfiguration{
|
||||
Routers: map[string]*dynamic.TCPRouter{
|
||||
"foo": {
|
||||
Service: "foo",
|
||||
Rule: "HostSNI(`foo.bar`)",
|
||||
TLS: &dynamic.RouterTCPTLSConfig{},
|
||||
},
|
||||
},
|
||||
Services: map[string]*dynamic.TCPService{
|
||||
"foo": {
|
||||
LoadBalancer: &dynamic.TCPLoadBalancerService{
|
||||
Servers: []dynamic.TCPServer{
|
||||
{
|
||||
Address: "localhost:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(200),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -1350,6 +1396,7 @@ func TestBuildConfiguration(t *testing.T) {
|
|||
Address: "localhost:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -75,6 +75,7 @@ func (p *Provider) buildTCPServiceConfiguration(ctx context.Context, service ran
|
|||
if len(configuration.Services) == 0 {
|
||||
configuration.Services = make(map[string]*dynamic.TCPService)
|
||||
lb := &dynamic.TCPLoadBalancerService{}
|
||||
lb.SetDefaults()
|
||||
configuration.Services[serviceName] = &dynamic.TCPService{
|
||||
LoadBalancer: lb,
|
||||
}
|
||||
|
|
|
@ -9,6 +9,8 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Int(v int) *int { return &v }
|
||||
|
||||
func Test_buildConfiguration(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
|
@ -512,6 +514,7 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.1:80",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -548,6 +551,7 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.1:80",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -590,6 +594,7 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -638,6 +643,7 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.2:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
@ -693,6 +699,45 @@ func Test_buildConfiguration(t *testing.T) {
|
|||
Address: "127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(100),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
HTTP: &dynamic.HTTPConfiguration{
|
||||
Routers: map[string]*dynamic.Router{},
|
||||
Middlewares: map[string]*dynamic.Middleware{},
|
||||
Services: map[string]*dynamic.Service{},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tcp with label for tcp service, with termination delay",
|
||||
containers: []rancherData{
|
||||
{
|
||||
Name: "Test",
|
||||
Labels: map[string]string{
|
||||
"traefik.tcp.services.foo.loadbalancer.server.port": "8080",
|
||||
"traefik.tcp.services.foo.loadbalancer.terminationdelay": "200",
|
||||
},
|
||||
Port: "80/tcp",
|
||||
Containers: []string{"127.0.0.1"},
|
||||
Health: "",
|
||||
State: "",
|
||||
},
|
||||
},
|
||||
expected: &dynamic.Configuration{
|
||||
TCP: &dynamic.TCPConfiguration{
|
||||
Routers: map[string]*dynamic.TCPRouter{},
|
||||
Services: map[string]*dynamic.TCPService{
|
||||
"foo": {
|
||||
LoadBalancer: &dynamic.TCPLoadBalancerService{
|
||||
Servers: []dynamic.TCPServer{
|
||||
{
|
||||
Address: "127.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
TerminationDelay: Int(200),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -37,7 +37,7 @@ func newHTTPForwarder(ln net.Listener) *httpForwarder {
|
|||
}
|
||||
|
||||
// ServeTCP uses the connection to serve it later in "Accept"
|
||||
func (h *httpForwarder) ServeTCP(conn net.Conn) {
|
||||
func (h *httpForwarder) ServeTCP(conn tcp.WriteCloser) {
|
||||
h.connChan <- conn
|
||||
}
|
||||
|
||||
|
@ -99,7 +99,36 @@ func NewTCPEntryPoint(ctx context.Context, configuration *static.EntryPoint) (*T
|
|||
}, nil
|
||||
}
|
||||
|
||||
// writeCloserWrapper wraps together a connection, and the concrete underlying
|
||||
// connection type that was found to satisfy WriteCloser.
|
||||
type writeCloserWrapper struct {
|
||||
net.Conn
|
||||
writeCloser tcp.WriteCloser
|
||||
}
|
||||
|
||||
func (c *writeCloserWrapper) CloseWrite() error {
|
||||
return c.writeCloser.CloseWrite()
|
||||
}
|
||||
|
||||
// writeCloser returns the given connection, augmented with the WriteCloser
|
||||
// implementation, if any was found within the underlying conn.
|
||||
func writeCloser(conn net.Conn) (tcp.WriteCloser, error) {
|
||||
switch typedConn := conn.(type) {
|
||||
case *proxyprotocol.Conn:
|
||||
underlying, err := writeCloser(typedConn.Conn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &writeCloserWrapper{writeCloser: underlying, Conn: typedConn}, nil
|
||||
case *net.TCPConn:
|
||||
return typedConn, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown connection type %T", typedConn)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *TCPEntryPoint) startTCP(ctx context.Context) {
|
||||
|
||||
log.FromContext(ctx).Debugf("Start TCP Server")
|
||||
|
||||
for {
|
||||
|
@ -109,8 +138,13 @@ func (e *TCPEntryPoint) startTCP(ctx context.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
writeCloser, err := writeCloser(conn)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
safe.Go(func() {
|
||||
e.switcher.ServeTCP(newTrackedConnection(conn, e.tracker))
|
||||
e.switcher.ServeTCP(newTrackedConnection(writeCloser, e.tracker))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -374,20 +408,20 @@ func createHTTPServer(ln net.Listener, configuration *static.EntryPoint, withH2c
|
|||
}, nil
|
||||
}
|
||||
|
||||
func newTrackedConnection(conn net.Conn, tracker *connectionTracker) *trackedConnection {
|
||||
func newTrackedConnection(conn tcp.WriteCloser, tracker *connectionTracker) *trackedConnection {
|
||||
tracker.AddConnection(conn)
|
||||
return &trackedConnection{
|
||||
Conn: conn,
|
||||
tracker: tracker,
|
||||
WriteCloser: conn,
|
||||
tracker: tracker,
|
||||
}
|
||||
}
|
||||
|
||||
type trackedConnection struct {
|
||||
tracker *connectionTracker
|
||||
net.Conn
|
||||
tcp.WriteCloser
|
||||
}
|
||||
|
||||
func (t *trackedConnection) Close() error {
|
||||
t.tracker.RemoveConnection(t.Conn)
|
||||
return t.Conn.Close()
|
||||
t.tracker.RemoveConnection(t.WriteCloser)
|
||||
return t.WriteCloser.Close()
|
||||
}
|
||||
|
|
|
@ -113,7 +113,7 @@ func TestShutdownTCPConn(t *testing.T) {
|
|||
go entryPoint.startTCP(context.Background())
|
||||
|
||||
router := &tcp.Router{}
|
||||
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn net.Conn) {
|
||||
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn tcp.WriteCloser) {
|
||||
_, err := http.ReadRequest(bufio.NewReader(conn))
|
||||
require.NoError(t, err)
|
||||
time.Sleep(1 * time.Second)
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/config/runtime"
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
|
@ -44,13 +45,19 @@ func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Han
|
|||
|
||||
loadBalancer := tcp.NewRRLoadBalancer()
|
||||
|
||||
if conf.LoadBalancer.TerminationDelay == nil {
|
||||
defaultTerminationDelay := 100
|
||||
conf.LoadBalancer.TerminationDelay = &defaultTerminationDelay
|
||||
}
|
||||
duration := time.Millisecond * time.Duration(*conf.LoadBalancer.TerminationDelay)
|
||||
|
||||
for name, server := range conf.LoadBalancer.Servers {
|
||||
if _, _, err := net.SplitHostPort(server.Address); err != nil {
|
||||
logger.Errorf("In service %q: %v", serviceQualifiedName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
handler, err := tcp.NewProxy(server.Address)
|
||||
handler, err := tcp.NewProxy(server.Address, duration)
|
||||
if err != nil {
|
||||
logger.Errorf("In service %q server %q: %v", serviceQualifiedName, server.Address, err)
|
||||
continue
|
||||
|
|
|
@ -6,14 +6,23 @@ import (
|
|||
|
||||
// Handler is the TCP Handlers interface
|
||||
type Handler interface {
|
||||
ServeTCP(conn net.Conn)
|
||||
ServeTCP(conn WriteCloser)
|
||||
}
|
||||
|
||||
// The HandlerFunc type is an adapter to allow the use of
|
||||
// ordinary functions as handlers.
|
||||
type HandlerFunc func(conn net.Conn)
|
||||
type HandlerFunc func(conn WriteCloser)
|
||||
|
||||
// ServeTCP serves tcp
|
||||
func (f HandlerFunc) ServeTCP(conn net.Conn) {
|
||||
func (f HandlerFunc) ServeTCP(conn WriteCloser) {
|
||||
f(conn)
|
||||
}
|
||||
|
||||
// WriteCloser describes a net.Conn with a CloseWrite method.
|
||||
type WriteCloser interface {
|
||||
net.Conn
|
||||
// CloseWrite on a network connection, indicates that the issuer of the call
|
||||
// has terminated sending on that connection.
|
||||
// It corresponds to sending a FIN packet.
|
||||
CloseWrite() error
|
||||
}
|
||||
|
|
|
@ -3,28 +3,32 @@ package tcp
|
|||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
)
|
||||
|
||||
// Proxy forwards a TCP request to a TCP service
|
||||
type Proxy struct {
|
||||
target *net.TCPAddr
|
||||
target *net.TCPAddr
|
||||
terminationDelay time.Duration
|
||||
}
|
||||
|
||||
// NewProxy creates a new Proxy
|
||||
func NewProxy(address string) (*Proxy, error) {
|
||||
func NewProxy(address string, terminationDelay time.Duration) (*Proxy, error) {
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Proxy{target: tcpAddr}, nil
|
||||
return &Proxy{target: tcpAddr, terminationDelay: terminationDelay}, nil
|
||||
}
|
||||
|
||||
// ServeTCP forwards the connection to a service
|
||||
func (p *Proxy) ServeTCP(conn net.Conn) {
|
||||
func (p *Proxy) ServeTCP(conn WriteCloser) {
|
||||
log.Debugf("Handling connection from %s", conn.RemoteAddr())
|
||||
|
||||
// needed because of e.g. server.trackedConnection
|
||||
defer conn.Close()
|
||||
|
||||
connBackend, err := net.DialTCP("tcp", nil, p.target)
|
||||
|
@ -32,19 +36,35 @@ func (p *Proxy) ServeTCP(conn net.Conn) {
|
|||
log.Errorf("Error while connection to backend: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// maybe not needed, but just in case
|
||||
defer connBackend.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go connCopy(conn, connBackend, errChan)
|
||||
go connCopy(connBackend, conn, errChan)
|
||||
errChan := make(chan error)
|
||||
go p.connCopy(conn, connBackend, errChan)
|
||||
go p.connCopy(connBackend, conn, errChan)
|
||||
|
||||
err = <-errChan
|
||||
if err != nil {
|
||||
log.Errorf("Error during connection: %v", err)
|
||||
log.WithoutContext().Errorf("Error during connection: %v", err)
|
||||
}
|
||||
|
||||
<-errChan
|
||||
}
|
||||
|
||||
func connCopy(dst, src net.Conn, errCh chan error) {
|
||||
func (p Proxy) connCopy(dst, src WriteCloser, errCh chan error) {
|
||||
_, err := io.Copy(dst, src)
|
||||
errCh <- err
|
||||
|
||||
errClose := dst.CloseWrite()
|
||||
if errClose != nil {
|
||||
log.WithoutContext().Errorf("Error while terminating connection: %v", errClose)
|
||||
}
|
||||
|
||||
if p.terminationDelay >= 0 {
|
||||
err := dst.SetReadDeadline(time.Now().Add(p.terminationDelay))
|
||||
if err != nil {
|
||||
log.WithoutContext().Errorf("Error while setting deadline: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
81
pkg/tcp/proxy_test.go
Normal file
81
pkg/tcp/proxy_test.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func fakeRedis(t *testing.T, listener net.Listener) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
fmt.Println("Accept on server")
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
withErr := false
|
||||
buf := make([]byte, 64)
|
||||
if _, err := conn.Read(buf); err != nil {
|
||||
withErr = true
|
||||
}
|
||||
|
||||
if string(buf[:4]) == "ping" {
|
||||
time.Sleep(time.Millisecond * 1)
|
||||
if _, err := conn.Write([]byte("PONG")); err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
if withErr {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseWrite(t *testing.T) {
|
||||
backendListener, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
go fakeRedis(t, backendListener)
|
||||
_, port, err := net.SplitHostPort(backendListener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy, err := NewProxy(":"+port, 10*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyListener, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := proxyListener.Accept()
|
||||
require.NoError(t, err)
|
||||
proxy.ServeTCP(conn.(*net.TCPConn))
|
||||
}
|
||||
}()
|
||||
|
||||
_, port, err = net.SplitHostPort(proxyListener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", ":"+port)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Write([]byte("ping\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.(*net.TCPConn).CloseWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf []byte
|
||||
buffer := bytes.NewBuffer(buf)
|
||||
n, err := io.Copy(buffer, conn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(4), n)
|
||||
require.Equal(t, "PONG", buffer.String())
|
||||
}
|
|
@ -25,7 +25,7 @@ type Router struct {
|
|||
}
|
||||
|
||||
// ServeTCP forwards the connection to the right TCP/HTTP handler
|
||||
func (r *Router) ServeTCP(conn net.Conn) {
|
||||
func (r *Router) ServeTCP(conn WriteCloser) {
|
||||
// FIXME -- Check if ProxyProtocol changes the first bytes of the request
|
||||
|
||||
if r.catchAllNoTLS != nil && len(r.routingTable) == 0 && r.httpsHandler == nil {
|
||||
|
@ -99,11 +99,11 @@ func (r *Router) AddCatchAllNoTLS(handler Handler) {
|
|||
}
|
||||
|
||||
// GetConn creates a connection proxy with a peeked string
|
||||
func (r *Router) GetConn(conn net.Conn, peeked string) net.Conn {
|
||||
func (r *Router) GetConn(conn WriteCloser, peeked string) WriteCloser {
|
||||
// FIXME should it really be on Router ?
|
||||
conn = &Conn{
|
||||
Peeked: []byte(peeked),
|
||||
Conn: conn,
|
||||
Peeked: []byte(peeked),
|
||||
WriteCloser: conn,
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ type Conn struct {
|
|||
// It can be type asserted against *net.TCPConn or other types
|
||||
// as needed. It should not be read from directly unless
|
||||
// Peeked is nil.
|
||||
net.Conn
|
||||
WriteCloser
|
||||
}
|
||||
|
||||
// Read reads bytes from the connection (using the buffer prior to actually reading)
|
||||
|
@ -170,7 +170,7 @@ func (c *Conn) Read(p []byte) (n int, err error) {
|
|||
}
|
||||
return n, nil
|
||||
}
|
||||
return c.Conn.Read(p)
|
||||
return c.WriteCloser.Read(p)
|
||||
}
|
||||
|
||||
// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
|
@ -20,7 +19,7 @@ func NewRRLoadBalancer() *RRLoadBalancer {
|
|||
}
|
||||
|
||||
// ServeTCP forwards the connection to the right service
|
||||
func (r *RRLoadBalancer) ServeTCP(conn net.Conn) {
|
||||
func (r *RRLoadBalancer) ServeTCP(conn WriteCloser) {
|
||||
if len(r.servers) == 0 {
|
||||
log.WithoutContext().Error("no available server")
|
||||
return
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/safe"
|
||||
)
|
||||
|
||||
|
@ -12,7 +10,7 @@ type HandlerSwitcher struct {
|
|||
}
|
||||
|
||||
// ServeTCP forwards the TCP connection to the current active handler
|
||||
func (s *HandlerSwitcher) ServeTCP(conn net.Conn) {
|
||||
func (s *HandlerSwitcher) ServeTCP(conn WriteCloser) {
|
||||
handler := s.router.Get()
|
||||
h, ok := handler.(Handler)
|
||||
if ok {
|
||||
|
|
|
@ -2,7 +2,6 @@ package tcp
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
// TLSHandler handles TLS connections
|
||||
|
@ -12,6 +11,6 @@ type TLSHandler struct {
|
|||
}
|
||||
|
||||
// ServeTCP terminates the TLS connection
|
||||
func (t *TLSHandler) ServeTCP(conn net.Conn) {
|
||||
func (t *TLSHandler) ServeTCP(conn WriteCloser) {
|
||||
t.Next.ServeTCP(tls.Server(conn, t.Config))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue