On client CloseWrite, do CloseWrite instead of Close for backend
Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
This commit is contained in:
parent
401b3afa3b
commit
b55be9fdea
25 changed files with 393 additions and 36 deletions
|
@ -6,14 +6,23 @@ import (
|
|||
|
||||
// Handler is the TCP Handlers interface
|
||||
type Handler interface {
|
||||
ServeTCP(conn net.Conn)
|
||||
ServeTCP(conn WriteCloser)
|
||||
}
|
||||
|
||||
// The HandlerFunc type is an adapter to allow the use of
|
||||
// ordinary functions as handlers.
|
||||
type HandlerFunc func(conn net.Conn)
|
||||
type HandlerFunc func(conn WriteCloser)
|
||||
|
||||
// ServeTCP serves tcp
|
||||
func (f HandlerFunc) ServeTCP(conn net.Conn) {
|
||||
func (f HandlerFunc) ServeTCP(conn WriteCloser) {
|
||||
f(conn)
|
||||
}
|
||||
|
||||
// WriteCloser describes a net.Conn with a CloseWrite method.
|
||||
type WriteCloser interface {
|
||||
net.Conn
|
||||
// CloseWrite on a network connection, indicates that the issuer of the call
|
||||
// has terminated sending on that connection.
|
||||
// It corresponds to sending a FIN packet.
|
||||
CloseWrite() error
|
||||
}
|
||||
|
|
|
@ -3,28 +3,32 @@ package tcp
|
|||
import (
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
)
|
||||
|
||||
// Proxy forwards a TCP request to a TCP service
|
||||
type Proxy struct {
|
||||
target *net.TCPAddr
|
||||
target *net.TCPAddr
|
||||
terminationDelay time.Duration
|
||||
}
|
||||
|
||||
// NewProxy creates a new Proxy
|
||||
func NewProxy(address string) (*Proxy, error) {
|
||||
func NewProxy(address string, terminationDelay time.Duration) (*Proxy, error) {
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Proxy{target: tcpAddr}, nil
|
||||
return &Proxy{target: tcpAddr, terminationDelay: terminationDelay}, nil
|
||||
}
|
||||
|
||||
// ServeTCP forwards the connection to a service
|
||||
func (p *Proxy) ServeTCP(conn net.Conn) {
|
||||
func (p *Proxy) ServeTCP(conn WriteCloser) {
|
||||
log.Debugf("Handling connection from %s", conn.RemoteAddr())
|
||||
|
||||
// needed because of e.g. server.trackedConnection
|
||||
defer conn.Close()
|
||||
|
||||
connBackend, err := net.DialTCP("tcp", nil, p.target)
|
||||
|
@ -32,19 +36,35 @@ func (p *Proxy) ServeTCP(conn net.Conn) {
|
|||
log.Errorf("Error while connection to backend: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// maybe not needed, but just in case
|
||||
defer connBackend.Close()
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
go connCopy(conn, connBackend, errChan)
|
||||
go connCopy(connBackend, conn, errChan)
|
||||
errChan := make(chan error)
|
||||
go p.connCopy(conn, connBackend, errChan)
|
||||
go p.connCopy(connBackend, conn, errChan)
|
||||
|
||||
err = <-errChan
|
||||
if err != nil {
|
||||
log.Errorf("Error during connection: %v", err)
|
||||
log.WithoutContext().Errorf("Error during connection: %v", err)
|
||||
}
|
||||
|
||||
<-errChan
|
||||
}
|
||||
|
||||
func connCopy(dst, src net.Conn, errCh chan error) {
|
||||
func (p Proxy) connCopy(dst, src WriteCloser, errCh chan error) {
|
||||
_, err := io.Copy(dst, src)
|
||||
errCh <- err
|
||||
|
||||
errClose := dst.CloseWrite()
|
||||
if errClose != nil {
|
||||
log.WithoutContext().Errorf("Error while terminating connection: %v", errClose)
|
||||
}
|
||||
|
||||
if p.terminationDelay >= 0 {
|
||||
err := dst.SetReadDeadline(time.Now().Add(p.terminationDelay))
|
||||
if err != nil {
|
||||
log.WithoutContext().Errorf("Error while setting deadline: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
81
pkg/tcp/proxy_test.go
Normal file
81
pkg/tcp/proxy_test.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func fakeRedis(t *testing.T, listener net.Listener) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
fmt.Println("Accept on server")
|
||||
require.NoError(t, err)
|
||||
for {
|
||||
withErr := false
|
||||
buf := make([]byte, 64)
|
||||
if _, err := conn.Read(buf); err != nil {
|
||||
withErr = true
|
||||
}
|
||||
|
||||
if string(buf[:4]) == "ping" {
|
||||
time.Sleep(time.Millisecond * 1)
|
||||
if _, err := conn.Write([]byte("PONG")); err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
if withErr {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseWrite(t *testing.T) {
|
||||
backendListener, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
go fakeRedis(t, backendListener)
|
||||
_, port, err := net.SplitHostPort(backendListener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
proxy, err := NewProxy(":"+port, 10*time.Millisecond)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyListener, err := net.Listen("tcp", ":0")
|
||||
require.NoError(t, err)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := proxyListener.Accept()
|
||||
require.NoError(t, err)
|
||||
proxy.ServeTCP(conn.(*net.TCPConn))
|
||||
}
|
||||
}()
|
||||
|
||||
_, port, err = net.SplitHostPort(proxyListener.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
conn, err := net.Dial("tcp", ":"+port)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = conn.Write([]byte("ping\n"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.(*net.TCPConn).CloseWrite()
|
||||
require.NoError(t, err)
|
||||
|
||||
var buf []byte
|
||||
buffer := bytes.NewBuffer(buf)
|
||||
n, err := io.Copy(buffer, conn)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(4), n)
|
||||
require.Equal(t, "PONG", buffer.String())
|
||||
}
|
|
@ -25,7 +25,7 @@ type Router struct {
|
|||
}
|
||||
|
||||
// ServeTCP forwards the connection to the right TCP/HTTP handler
|
||||
func (r *Router) ServeTCP(conn net.Conn) {
|
||||
func (r *Router) ServeTCP(conn WriteCloser) {
|
||||
// FIXME -- Check if ProxyProtocol changes the first bytes of the request
|
||||
|
||||
if r.catchAllNoTLS != nil && len(r.routingTable) == 0 && r.httpsHandler == nil {
|
||||
|
@ -99,11 +99,11 @@ func (r *Router) AddCatchAllNoTLS(handler Handler) {
|
|||
}
|
||||
|
||||
// GetConn creates a connection proxy with a peeked string
|
||||
func (r *Router) GetConn(conn net.Conn, peeked string) net.Conn {
|
||||
func (r *Router) GetConn(conn WriteCloser, peeked string) WriteCloser {
|
||||
// FIXME should it really be on Router ?
|
||||
conn = &Conn{
|
||||
Peeked: []byte(peeked),
|
||||
Conn: conn,
|
||||
Peeked: []byte(peeked),
|
||||
WriteCloser: conn,
|
||||
}
|
||||
return conn
|
||||
}
|
||||
|
@ -157,7 +157,7 @@ type Conn struct {
|
|||
// It can be type asserted against *net.TCPConn or other types
|
||||
// as needed. It should not be read from directly unless
|
||||
// Peeked is nil.
|
||||
net.Conn
|
||||
WriteCloser
|
||||
}
|
||||
|
||||
// Read reads bytes from the connection (using the buffer prior to actually reading)
|
||||
|
@ -170,7 +170,7 @@ func (c *Conn) Read(p []byte) (n int, err error) {
|
|||
}
|
||||
return n, nil
|
||||
}
|
||||
return c.Conn.Read(p)
|
||||
return c.WriteCloser.Read(p)
|
||||
}
|
||||
|
||||
// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
|
@ -20,7 +19,7 @@ func NewRRLoadBalancer() *RRLoadBalancer {
|
|||
}
|
||||
|
||||
// ServeTCP forwards the connection to the right service
|
||||
func (r *RRLoadBalancer) ServeTCP(conn net.Conn) {
|
||||
func (r *RRLoadBalancer) ServeTCP(conn WriteCloser) {
|
||||
if len(r.servers) == 0 {
|
||||
log.WithoutContext().Error("no available server")
|
||||
return
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/safe"
|
||||
)
|
||||
|
||||
|
@ -12,7 +10,7 @@ type HandlerSwitcher struct {
|
|||
}
|
||||
|
||||
// ServeTCP forwards the TCP connection to the current active handler
|
||||
func (s *HandlerSwitcher) ServeTCP(conn net.Conn) {
|
||||
func (s *HandlerSwitcher) ServeTCP(conn WriteCloser) {
|
||||
handler := s.router.Get()
|
||||
h, ok := handler.(Handler)
|
||||
if ok {
|
||||
|
|
|
@ -2,7 +2,6 @@ package tcp
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
)
|
||||
|
||||
// TLSHandler handles TLS connections
|
||||
|
@ -12,6 +11,6 @@ type TLSHandler struct {
|
|||
}
|
||||
|
||||
// ServeTCP terminates the TLS connection
|
||||
func (t *TLSHandler) ServeTCP(conn net.Conn) {
|
||||
func (t *TLSHandler) ServeTCP(conn WriteCloser) {
|
||||
t.Next.ServeTCP(tls.Server(conn, t.Config))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue