1
0
Fork 0

feat: add in flight connection middleware

This commit is contained in:
Tom Moulard 2021-11-29 17:12:06 +01:00 committed by GitHub
parent 95fabeae73
commit 93de7cf0c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 326 additions and 4 deletions

View file

@ -0,0 +1,92 @@
package tcpinflightconn
import (
"context"
"fmt"
"net"
"sync"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/log"
"github.com/traefik/traefik/v2/pkg/middlewares"
"github.com/traefik/traefik/v2/pkg/tcp"
)
const typeName = "InFlightConnTCP"
type inFlightConn struct {
name string
next tcp.Handler
maxConnections int64
mu sync.Mutex
connections map[string]int64 // current number of connections by remote IP.
}
// New creates a max connections middleware.
// The connections are identified and grouped by remote IP.
func New(ctx context.Context, next tcp.Handler, config dynamic.TCPInFlightConn, name string) (tcp.Handler, error) {
logger := log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName))
logger.Debug("Creating middleware")
return &inFlightConn{
name: name,
next: next,
connections: make(map[string]int64),
maxConnections: config.Amount,
}, nil
}
// ServeTCP serves the given TCP connection.
func (i *inFlightConn) ServeTCP(conn tcp.WriteCloser) {
ctx := middlewares.GetLoggerCtx(context.Background(), i.name, typeName)
logger := log.FromContext(ctx)
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
logger.Errorf("Cannot parse IP from remote addr: %v", err)
conn.Close()
return
}
if err = i.increment(ip); err != nil {
logger.Errorf("Connection rejected: %v", err)
conn.Close()
return
}
defer i.decrement(ip)
i.next.ServeTCP(conn)
}
// increment increases the counter for the number of connections tracked for the
// given IP.
// It returns an error if the counter would go above the max allowed number of
// connections.
func (i *inFlightConn) increment(ip string) error {
i.mu.Lock()
defer i.mu.Unlock()
if i.connections[ip] >= i.maxConnections {
return fmt.Errorf("max number of connections reached for %s", ip)
}
i.connections[ip]++
return nil
}
// decrement decreases the counter for the number of connections tracked for the
// given IP.
// It ensures that the counter does not go below zero.
func (i *inFlightConn) decrement(ip string) {
i.mu.Lock()
defer i.mu.Unlock()
if i.connections[ip] <= 0 {
return
}
i.connections[ip]--
}

View file

@ -0,0 +1,95 @@
package tcpinflightconn
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v2/pkg/config/dynamic"
"github.com/traefik/traefik/v2/pkg/tcp"
)
func TestInFlightConn_ServeTCP(t *testing.T) {
proceedCh := make(chan struct{})
waitCh := make(chan struct{})
finishCh := make(chan struct{})
next := tcp.HandlerFunc(func(conn tcp.WriteCloser) {
proceedCh <- struct{}{}
if fc, ok := conn.(fakeConn); !ok || !fc.wait {
return
}
<-waitCh
finishCh <- struct{}{}
})
middleware, err := New(context.Background(), next, dynamic.TCPInFlightConn{Amount: 1}, "foo")
require.NoError(t, err)
// The first connection should succeed and wait.
go middleware.ServeTCP(fakeConn{addr: "127.0.0.1:9000", wait: true})
requireMessage(t, proceedCh)
closeCh := make(chan struct{})
// The second connection from the same remote address should be closed as the maximum number of connections is exceeded.
go middleware.ServeTCP(fakeConn{addr: "127.0.0.1:9000", closeCh: closeCh})
requireMessage(t, closeCh)
// The connection from another remote address should succeed.
go middleware.ServeTCP(fakeConn{addr: "127.0.0.2:9000"})
requireMessage(t, proceedCh)
// Once the first connection is closed, next connection with the same remote address should succeed.
close(waitCh)
requireMessage(t, finishCh)
go middleware.ServeTCP(fakeConn{addr: "127.0.0.1:9000"})
requireMessage(t, proceedCh)
}
func requireMessage(t *testing.T, c chan struct{}) {
t.Helper()
select {
case <-c:
case <-time.After(time.Second):
t.Fatal("Timeout waiting for message")
}
}
type fakeConn struct {
net.Conn
addr string
wait bool
closeCh chan struct{}
}
func (c fakeConn) RemoteAddr() net.Addr {
return fakeAddr{addr: c.addr}
}
func (c fakeConn) Close() error {
close(c.closeCh)
return nil
}
func (c fakeConn) CloseWrite() error {
panic("implement me")
}
type fakeAddr struct {
addr string
}
func (a fakeAddr) Network() string {
return "tcp"
}
func (a fakeAddr) String() string {
return a.addr
}