On client CloseWrite, do CloseWrite instead of Close for backend

Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
This commit is contained in:
Julien Salleyron 2019-09-13 17:46:04 +02:00 committed by Traefiker Bot
parent 401b3afa3b
commit b55be9fdea
25 changed files with 393 additions and 36 deletions

View file

@ -6,14 +6,23 @@ import (
// Handler is the TCP Handlers interface
type Handler interface {
ServeTCP(conn net.Conn)
ServeTCP(conn WriteCloser)
}
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as handlers.
type HandlerFunc func(conn net.Conn)
type HandlerFunc func(conn WriteCloser)
// ServeTCP serves tcp
func (f HandlerFunc) ServeTCP(conn net.Conn) {
func (f HandlerFunc) ServeTCP(conn WriteCloser) {
f(conn)
}
// WriteCloser describes a net.Conn with a CloseWrite method.
type WriteCloser interface {
net.Conn
// CloseWrite on a network connection, indicates that the issuer of the call
// has terminated sending on that connection.
// It corresponds to sending a FIN packet.
CloseWrite() error
}

View file

@ -3,28 +3,32 @@ package tcp
import (
"io"
"net"
"time"
"github.com/containous/traefik/v2/pkg/log"
)
// Proxy forwards a TCP request to a TCP service
type Proxy struct {
target *net.TCPAddr
target *net.TCPAddr
terminationDelay time.Duration
}
// NewProxy creates a new Proxy
func NewProxy(address string) (*Proxy, error) {
func NewProxy(address string, terminationDelay time.Duration) (*Proxy, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return nil, err
}
return &Proxy{target: tcpAddr}, nil
return &Proxy{target: tcpAddr, terminationDelay: terminationDelay}, nil
}
// ServeTCP forwards the connection to a service
func (p *Proxy) ServeTCP(conn net.Conn) {
func (p *Proxy) ServeTCP(conn WriteCloser) {
log.Debugf("Handling connection from %s", conn.RemoteAddr())
// needed because of e.g. server.trackedConnection
defer conn.Close()
connBackend, err := net.DialTCP("tcp", nil, p.target)
@ -32,19 +36,35 @@ func (p *Proxy) ServeTCP(conn net.Conn) {
log.Errorf("Error while connection to backend: %v", err)
return
}
// maybe not needed, but just in case
defer connBackend.Close()
errChan := make(chan error, 1)
go connCopy(conn, connBackend, errChan)
go connCopy(connBackend, conn, errChan)
errChan := make(chan error)
go p.connCopy(conn, connBackend, errChan)
go p.connCopy(connBackend, conn, errChan)
err = <-errChan
if err != nil {
log.Errorf("Error during connection: %v", err)
log.WithoutContext().Errorf("Error during connection: %v", err)
}
<-errChan
}
func connCopy(dst, src net.Conn, errCh chan error) {
func (p Proxy) connCopy(dst, src WriteCloser, errCh chan error) {
_, err := io.Copy(dst, src)
errCh <- err
errClose := dst.CloseWrite()
if errClose != nil {
log.WithoutContext().Errorf("Error while terminating connection: %v", errClose)
}
if p.terminationDelay >= 0 {
err := dst.SetReadDeadline(time.Now().Add(p.terminationDelay))
if err != nil {
log.WithoutContext().Errorf("Error while setting deadline: %v", err)
}
}
}

81
pkg/tcp/proxy_test.go Normal file
View file

@ -0,0 +1,81 @@
package tcp
import (
"bytes"
"fmt"
"io"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func fakeRedis(t *testing.T, listener net.Listener) {
for {
conn, err := listener.Accept()
fmt.Println("Accept on server")
require.NoError(t, err)
for {
withErr := false
buf := make([]byte, 64)
if _, err := conn.Read(buf); err != nil {
withErr = true
}
if string(buf[:4]) == "ping" {
time.Sleep(time.Millisecond * 1)
if _, err := conn.Write([]byte("PONG")); err != nil {
conn.Close()
return
}
}
if withErr {
conn.Close()
return
}
}
}
}
func TestCloseWrite(t *testing.T) {
backendListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
go fakeRedis(t, backendListener)
_, port, err := net.SplitHostPort(backendListener.Addr().String())
require.NoError(t, err)
proxy, err := NewProxy(":"+port, 10*time.Millisecond)
require.NoError(t, err)
proxyListener, err := net.Listen("tcp", ":0")
require.NoError(t, err)
go func() {
for {
conn, err := proxyListener.Accept()
require.NoError(t, err)
proxy.ServeTCP(conn.(*net.TCPConn))
}
}()
_, port, err = net.SplitHostPort(proxyListener.Addr().String())
require.NoError(t, err)
conn, err := net.Dial("tcp", ":"+port)
require.NoError(t, err)
_, err = conn.Write([]byte("ping\n"))
require.NoError(t, err)
err = conn.(*net.TCPConn).CloseWrite()
require.NoError(t, err)
var buf []byte
buffer := bytes.NewBuffer(buf)
n, err := io.Copy(buffer, conn)
require.NoError(t, err)
require.Equal(t, int64(4), n)
require.Equal(t, "PONG", buffer.String())
}

View file

@ -25,7 +25,7 @@ type Router struct {
}
// ServeTCP forwards the connection to the right TCP/HTTP handler
func (r *Router) ServeTCP(conn net.Conn) {
func (r *Router) ServeTCP(conn WriteCloser) {
// FIXME -- Check if ProxyProtocol changes the first bytes of the request
if r.catchAllNoTLS != nil && len(r.routingTable) == 0 && r.httpsHandler == nil {
@ -99,11 +99,11 @@ func (r *Router) AddCatchAllNoTLS(handler Handler) {
}
// GetConn creates a connection proxy with a peeked string
func (r *Router) GetConn(conn net.Conn, peeked string) net.Conn {
func (r *Router) GetConn(conn WriteCloser, peeked string) WriteCloser {
// FIXME should it really be on Router ?
conn = &Conn{
Peeked: []byte(peeked),
Conn: conn,
Peeked: []byte(peeked),
WriteCloser: conn,
}
return conn
}
@ -157,7 +157,7 @@ type Conn struct {
// It can be type asserted against *net.TCPConn or other types
// as needed. It should not be read from directly unless
// Peeked is nil.
net.Conn
WriteCloser
}
// Read reads bytes from the connection (using the buffer prior to actually reading)
@ -170,7 +170,7 @@ func (c *Conn) Read(p []byte) (n int, err error) {
}
return n, nil
}
return c.Conn.Read(p)
return c.WriteCloser.Read(p)
}
// clientHelloServerName returns the SNI server name inside the TLS ClientHello,

View file

@ -1,7 +1,6 @@
package tcp
import (
"net"
"sync"
"github.com/containous/traefik/v2/pkg/log"
@ -20,7 +19,7 @@ func NewRRLoadBalancer() *RRLoadBalancer {
}
// ServeTCP forwards the connection to the right service
func (r *RRLoadBalancer) ServeTCP(conn net.Conn) {
func (r *RRLoadBalancer) ServeTCP(conn WriteCloser) {
if len(r.servers) == 0 {
log.WithoutContext().Error("no available server")
return

View file

@ -1,8 +1,6 @@
package tcp
import (
"net"
"github.com/containous/traefik/v2/pkg/safe"
)
@ -12,7 +10,7 @@ type HandlerSwitcher struct {
}
// ServeTCP forwards the TCP connection to the current active handler
func (s *HandlerSwitcher) ServeTCP(conn net.Conn) {
func (s *HandlerSwitcher) ServeTCP(conn WriteCloser) {
handler := s.router.Get()
h, ok := handler.(Handler)
if ok {

View file

@ -2,7 +2,6 @@ package tcp
import (
"crypto/tls"
"net"
)
// TLSHandler handles TLS connections
@ -12,6 +11,6 @@ type TLSHandler struct {
}
// ServeTCP terminates the TLS connection
func (t *TLSHandler) ServeTCP(conn net.Conn) {
func (t *TLSHandler) ServeTCP(conn WriteCloser) {
t.Next.ServeTCP(tls.Server(conn, t.Config))
}