UDP support
Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
parent
8988c8f9af
commit
115d42e0f0
72 changed files with 4730 additions and 321 deletions
265
pkg/udp/conn.go
Normal file
265
pkg/udp/conn.go
Normal file
|
@ -0,0 +1,265 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// receiveMTU is the size of the buffer used to read a single datagram
// from the underlying socket.
const receiveMTU = 8192

// closeRetryInterval is how often Shutdown re-checks whether all sessions
// have terminated during the grace period.
const closeRetryInterval = 500 * time.Millisecond

// connTimeout determines how long to wait on an idle session,
// before releasing all resources related to that session.
const connTimeout = time.Second * 3

// errClosedListener is returned by Accept, and by new-session creation,
// once the listener has stopped accepting.
var errClosedListener = errors.New("udp: listener closed")
|
||||
|
||||
// Listener augments a session-oriented Listener over a UDP PacketConn.
type Listener struct {
	pConn *net.UDPConn // single shared socket; all sessions read/write through it

	mu    sync.RWMutex
	conns map[string]*Conn // active sessions, keyed by remote address string
	// accepting signifies whether the listener is still accepting new sessions.
	// It also serves as a sentinel for Shutdown to be idempotent.
	accepting bool

	acceptCh chan *Conn // no need for a Once, already indirectly guarded by accepting.
}
|
||||
|
||||
// Listen creates a new listener.
|
||||
func Listen(network string, laddr *net.UDPAddr) (*Listener, error) {
|
||||
conn, err := net.ListenUDP(network, laddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l := &Listener{
|
||||
pConn: conn,
|
||||
acceptCh: make(chan *Conn),
|
||||
conns: make(map[string]*Conn),
|
||||
accepting: true,
|
||||
}
|
||||
|
||||
go l.readLoop()
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (l *Listener) Accept() (*Conn, error) {
|
||||
c := <-l.acceptCh
|
||||
if c == nil {
|
||||
// l.acceptCh got closed
|
||||
return nil, errClosedListener
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
// Since all sessions share pConn, this is also the local address of every Conn.
func (l *Listener) Addr() net.Addr {
	return l.pConn.LocalAddr()
}
|
||||
|
||||
// Close closes the listener.
// It is like Shutdown with a zero graceTimeout,
// i.e. it force-closes all live sessions immediately.
func (l *Listener) Close() error {
	return l.Shutdown(0)
}
|
||||
|
||||
// close should not be called more than once.
func (l *Listener) close() error {
	l.mu.Lock()
	defer l.mu.Unlock()
	// Closing pConn makes the listener's readLoop goroutine return.
	err := l.pConn.Close()
	// Force-close whatever sessions are still registered.
	for k, v := range l.conns {
		v.close()
		delete(l.conns, k)
	}
	// Wakes up any pending Accept, which will observe the closed channel.
	close(l.acceptCh)
	return err
}
|
||||
|
||||
// Shutdown closes the listener.
// It immediately stops accepting new sessions,
// and it waits for all existing sessions to terminate,
// and a maximum of graceTimeout.
// Then it forces close any session left.
func (l *Listener) Shutdown(graceTimeout time.Duration) error {
	l.mu.Lock()
	if !l.accepting {
		// Already shut down (or shutting down): idempotent no-op.
		l.mu.Unlock()
		return nil
	}
	l.accepting = false
	l.mu.Unlock()

	// Never poll less often than the grace period itself.
	retryInterval := closeRetryInterval
	if retryInterval > graceTimeout {
		retryInterval = graceTimeout
	}
	start := time.Now()
	end := start.Add(graceTimeout)
	for {
		// Grace period exhausted: stop waiting and force-close what is left.
		if time.Now().After(end) {
			break
		}

		// All sessions terminated on their own: nothing left to wait for.
		l.mu.RLock()
		if len(l.conns) == 0 {
			l.mu.RUnlock()
			break
		}
		l.mu.RUnlock()

		time.Sleep(retryInterval)
	}
	return l.close()
}
|
||||
|
||||
// readLoop receives all packets from all remotes.
|
||||
// If a packet comes from a remote that is already known to us (i.e. a "session"),
|
||||
// we find that session, and otherwise we create a new one.
|
||||
// We then send the data the session's readLoop.
|
||||
func (l *Listener) readLoop() {
|
||||
buf := make([]byte, receiveMTU)
|
||||
|
||||
for {
|
||||
n, raddr, err := l.pConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn, err := l.getConn(raddr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.receiveCh <- buf[:n]:
|
||||
case <-conn.doneCh:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getConn returns the ongoing session with raddr if it exists, or creates a new
// one otherwise.
func (l *Listener) getConn(raddr net.Addr) (*Conn, error) {
	l.mu.Lock()
	defer l.mu.Unlock()

	conn, ok := l.conns[raddr.String()]
	if ok {
		return conn, nil
	}

	// Packets from unknown remotes are refused once the listener stops accepting.
	if !l.accepting {
		return nil, errClosedListener
	}
	conn = l.newConn(raddr)
	l.conns[raddr.String()] = conn
	// NOTE(review): this send happens while holding l.mu, so session creation
	// blocks until someone is ready in Accept — confirm callers always keep an
	// Accept loop running, otherwise the listener's readLoop can stall here.
	l.acceptCh <- conn
	go conn.readLoop()

	return conn, nil
}
|
||||
|
||||
func (l *Listener) newConn(rAddr net.Addr) *Conn {
|
||||
return &Conn{
|
||||
listener: l,
|
||||
rAddr: rAddr,
|
||||
receiveCh: make(chan []byte),
|
||||
readCh: make(chan []byte),
|
||||
sizeCh: make(chan int),
|
||||
doneCh: make(chan struct{}),
|
||||
timer: time.NewTimer(connTimeout),
|
||||
}
|
||||
}
|
||||
|
||||
// Conn represents an on-going session with a client, over UDP packets.
type Conn struct {
	listener *Listener // the Listener owning the shared socket this session writes through
	rAddr    net.Addr  // the remote peer of this session

	receiveCh chan []byte // to receive the data from the listener's readLoop
	readCh    chan []byte // to receive the buffer into which we should Read
	sizeCh    chan int    // to synchronize with the end of a Read
	msgs      [][]byte    // to store data from listener, to be consumed by Reads

	timer    *time.Timer // for timeouts
	doneOnce sync.Once   // guards the close of doneCh
	doneCh   chan struct{}
}
|
||||
|
||||
// readLoop waits for data to come from the listener's readLoop.
// It then waits for a Read operation to be ready to consume said data,
// that is to say it waits on readCh to receive the slice of bytes that the Read operation wants to read onto.
// The Read operation receives the signal that the data has been written to the slice of bytes through the sizeCh.
func (c *Conn) readLoop() {
	for {
		if len(c.msgs) == 0 {
			// No queued data: block until either a datagram arrives
			// or the idle timeout fires.
			select {
			case msg := <-c.receiveCh:
				c.msgs = append(c.msgs, msg)
			case <-c.timer.C:
				c.Close()
				return
			}
		}

		select {
		case cBuf := <-c.readCh:
			// A Read is waiting: hand it the oldest message (FIFO) and
			// report how many bytes were copied.
			msg := c.msgs[0]
			c.msgs = c.msgs[1:]
			n := copy(cBuf, msg)
			c.sizeCh <- n
		case msg := <-c.receiveCh:
			// More data arrived before any Read showed up: keep queueing.
			c.msgs = append(c.msgs, msg)
		case <-c.timer.C:
			// Session idle for too long: tear it down.
			c.Close()
			return
		}
	}
}
|
||||
|
||||
// Read implements io.Reader for a Conn.
func (c *Conn) Read(p []byte) (int, error) {
	select {
	case c.readCh <- p:
		// readLoop copies the oldest queued message into p and reports the size.
		n := <-c.sizeCh
		// A successful read counts as session activity: postpone the idle timeout.
		c.timer.Reset(connTimeout)
		return n, nil
	case <-c.doneCh:
		// Session was closed (idle timeout or explicit Close).
		return 0, io.EOF
	}
}
|
||||
|
||||
// Write implements io.Writer for a Conn.
func (c *Conn) Write(p []byte) (n int, err error) {
	l := c.listener
	if l == nil {
		return 0, io.EOF
	}

	// Outgoing traffic also counts as session activity.
	c.timer.Reset(connTimeout)
	// All sessions share the listener's socket; the stored remote address
	// selects the peer for this datagram.
	return l.pConn.WriteTo(p, c.rAddr)
}
|
||||
|
||||
// close marks the session as done, idempotently.
func (c *Conn) close() {
	c.doneOnce.Do(func() {
		// Broadcasts termination to the listener's readLoop and any blocked Read.
		close(c.doneCh)
	})
}
|
||||
|
||||
// Close releases resources related to the Conn.
|
||||
func (c *Conn) Close() error {
|
||||
c.close()
|
||||
|
||||
c.listener.mu.Lock()
|
||||
defer c.listener.mu.Unlock()
|
||||
delete(c.listener.conns, c.rAddr.String())
|
||||
return nil
|
||||
}
|
270
pkg/udp/conn_test.go
Normal file
270
pkg/udp/conn_test.go
Normal file
|
@ -0,0 +1,270 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestListenNotBlocking checks that one session's slow handler does not
// block a second, independent session from being served.
func TestListenNotBlocking(t *testing.T) {
	addr, err := net.ResolveUDPAddr("udp", ":0")

	require.NoError(t, err)

	ln, err := Listen("udp", addr)
	require.NoError(t, err)
	defer func() {
		err := ln.Close()
		require.NoError(t, err)
	}()

	go func() {
		for {
			conn, err := ln.Accept()
			if err == errClosedListener {
				return
			}
			require.NoError(t, err)

			// Each session echoes two packets, then goes idle.
			go func() {
				b := make([]byte, 2048)
				n, err := conn.Read(b)
				require.NoError(t, err)
				_, err = conn.Write(b[:n])
				require.NoError(t, err)

				n, err = conn.Read(b)
				require.NoError(t, err)
				_, err = conn.Write(b[:n])
				require.NoError(t, err)

				// This should not block second call
				time.Sleep(time.Second * 10)
			}()
		}
	}()

	udpConn, err := net.Dial("udp", ln.Addr().String())
	require.NoError(t, err)

	_, err = udpConn.Write([]byte("TEST"))
	require.NoError(t, err)

	b := make([]byte, 2048)
	n, err := udpConn.Read(b)
	require.NoError(t, err)
	require.Equal(t, "TEST", string(b[:n]))

	_, err = udpConn.Write([]byte("TEST2"))
	require.NoError(t, err)

	n, err = udpConn.Read(b)
	require.NoError(t, err)
	require.Equal(t, "TEST2", string(b[:n]))

	// Third packet leaves the first session's handler sleeping.
	_, err = udpConn.Write([]byte("TEST"))
	require.NoError(t, err)

	done := make(chan struct{})
	go func() {
		// A second session must be served while the first one sleeps.
		udpConn2, err := net.Dial("udp", ln.Addr().String())
		require.NoError(t, err)

		_, err = udpConn2.Write([]byte("TEST"))
		require.NoError(t, err)

		n, err = udpConn2.Read(b)
		require.NoError(t, err)

		assert.Equal(t, "TEST", string(b[:n]))

		_, err = udpConn2.Write([]byte("TEST2"))
		require.NoError(t, err)

		n, err = udpConn2.Read(b)
		require.NoError(t, err)

		assert.Equal(t, "TEST2", string(b[:n]))

		close(done)
	}()

	select {
	// NOTE(review): time.Tick leaks its ticker (staticcheck SA1015);
	// time.After would be the non-leaking equivalent here.
	case <-time.Tick(time.Second):
		t.Error("Timeout")
	case <-done:
	}
}
|
||||
|
||||
// TestTimeoutWithRead checks that idle sessions are reaped
// when the server does read the incoming data.
func TestTimeoutWithRead(t *testing.T) {
	testTimeout(t, true)
}

// TestTimeoutWithoutRead checks that idle sessions are reaped
// even when the incoming data is never read.
func TestTimeoutWithoutRead(t *testing.T) {
	testTimeout(t, false)
}
|
||||
|
||||
// testTimeout opens 10 sessions and verifies that they are all released
// after connTimeout, regardless of whether their data was read.
func testTimeout(t *testing.T, withRead bool) {
	addr, err := net.ResolveUDPAddr("udp", ":0")
	require.NoError(t, err)

	ln, err := Listen("udp", addr)
	require.NoError(t, err)
	defer func() {
		err := ln.Close()
		require.NoError(t, err)
	}()

	go func() {
		for {
			conn, err := ln.Accept()
			if err == errClosedListener {
				return
			}
			require.NoError(t, err)

			if withRead {
				buf := make([]byte, 1024)
				_, err = conn.Read(buf)

				require.NoError(t, err)
			}
		}
	}()

	// Open 10 distinct sessions, one datagram each.
	for i := 0; i < 10; i++ {
		udpConn2, err := net.Dial("udp", ln.Addr().String())
		require.NoError(t, err)

		_, err = udpConn2.Write([]byte("TEST"))
		require.NoError(t, err)
	}

	time.Sleep(10 * time.Millisecond)

	// NOTE(review): ln.conns is read here without holding ln.mu — a data race
	// under -race if the listener is still registering sessions; verify.
	assert.Equal(t, 10, len(ln.conns))

	// connTimeout is 3s: by now every idle session should have been reaped.
	time.Sleep(3 * time.Second)
	assert.Equal(t, 0, len(ln.conns))
}
|
||||
|
||||
// TestShutdown exercises the graceful-shutdown contract:
// existing sessions keep working, new sessions are refused,
// and Shutdown returns once all sessions have terminated.
func TestShutdown(t *testing.T) {
	addr, err := net.ResolveUDPAddr("udp", ":0")
	require.NoError(t, err)

	l, err := Listen("udp", addr)
	require.NoError(t, err)

	go func() {
		for {
			conn, err := l.Accept()
			if err != nil {
				// Accept only errors once the listener is closed.
				return
			}

			// Echo server with an explicit "CLOSE" termination command.
			go func() {
				conn := conn
				for {
					b := make([]byte, 1024*1024)
					n, err := conn.Read(b)
					require.NoError(t, err)
					// We control the termination,
					// otherwise we would block on the Read above,
					// until conn is closed by a timeout.
					// Which means we would get an error,
					// and even though we are in a goroutine and the current test might be over,
					// go test would still yell at us if this happens while other tests are still running.
					if string(b[:n]) == "CLOSE" {
						return
					}
					_, err = conn.Write(b[:n])
					require.NoError(t, err)
				}
			}()
		}
	}()

	conn, err := net.Dial("udp", l.Addr().String())
	require.NoError(t, err)

	// Start sending packets, to create a "session" with the server.
	requireEcho(t, "TEST", conn, time.Second)

	doneChan := make(chan struct{})
	go func() {
		err := l.Shutdown(5 * time.Second)
		require.NoError(t, err)
		close(doneChan)
	}()

	// Make sure that our session is still live even after the shutdown.
	requireEcho(t, "TEST2", conn, time.Second)

	// And make sure that on the other hand, opening new sessions is not possible anymore.
	conn2, err := net.Dial("udp", l.Addr().String())
	require.NoError(t, err)

	_, err = conn2.Write([]byte("TEST"))
	// Packet is accepted, but dropped
	require.NoError(t, err)

	// Make sure that our session is yet again still live.
	// This is specifically to make sure we don't create a regression in listener's readLoop,
	// i.e. that we only terminate the listener's readLoop goroutine by closing its pConn.
	requireEcho(t, "TEST3", conn, time.Second)

	done := make(chan bool)
	go func() {
		defer close(done)
		b := make([]byte, 1024*1024)
		// This read should fail once conn2 is closed below.
		n, err := conn2.Read(b)
		require.Error(t, err)
		assert.Equal(t, 0, n)
	}()

	conn2.Close()

	select {
	case <-done:
	case <-time.Tick(time.Second):
		t.Fatal("Timeout")
	}

	// Terminate the last live session so Shutdown can finish gracefully.
	_, err = conn.Write([]byte("CLOSE"))
	require.NoError(t, err)

	select {
	case <-doneChan:
	case <-time.Tick(time.Second * 5):
		// In case we introduce a regression that would make the test wait forever.
		t.Fatal("Timeout during shutdown")
	}
}
|
||||
|
||||
// requireEcho tests that the conn session is live and functional,
|
||||
// by writing data through it, and expecting the same data as a response when reading on it.
|
||||
// It fatals if the read blocks longer than timeout,
|
||||
// which is useful to detect regressions that would make a test wait forever.
|
||||
func requireEcho(t *testing.T, data string, conn io.ReadWriter, timeout time.Duration) {
|
||||
_, err := conn.Write([]byte(data))
|
||||
require.NoError(t, err)
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
go func() {
|
||||
b := make([]byte, 1024*1024)
|
||||
n, err := conn.Read(b)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, data, string(b[:n]))
|
||||
close(doneChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-doneChan:
|
||||
case <-time.Tick(timeout):
|
||||
t.Fatalf("Timeout during echo for: %s", data)
|
||||
}
|
||||
}
|
14
pkg/udp/handler.go
Normal file
14
pkg/udp/handler.go
Normal file
|
@ -0,0 +1,14 @@
|
|||
package udp
|
||||
|
||||
// Handler is the UDP counterpart of the usual HTTP handler.
type Handler interface {
	ServeUDP(conn *Conn)
}

// The HandlerFunc type is an adapter to allow the use of ordinary functions as handlers.
// It mirrors net/http's http.HandlerFunc pattern.
type HandlerFunc func(conn *Conn)

// ServeUDP implements the Handler interface for UDP.
// It simply calls f(conn).
func (f HandlerFunc) ServeUDP(conn *Conn) {
	f(conn)
}
|
56
pkg/udp/proxy.go
Normal file
56
pkg/udp/proxy.go
Normal file
|
@ -0,0 +1,56 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
)
|
||||
|
||||
// Proxy is a reverse-proxy implementation of the Handler interface.
type Proxy struct {
	// TODO: maybe optimize by pre-resolving it at proxy creation time
	target string // backend address, dialed anew on every ServeUDP
}

// NewProxy creates a new Proxy
// forwarding to the given backend address.
// The error return is always nil at present.
func NewProxy(address string) (*Proxy, error) {
	return &Proxy{target: address}, nil
}
|
||||
|
||||
// ServeUDP implements the Handler interface.
func (p *Proxy) ServeUDP(conn *Conn) {
	log.Debugf("Handling connection from %s", conn.rAddr)

	// needed because of e.g. server.trackedConnection
	defer conn.Close()

	connBackend, err := net.Dial("udp", p.target)
	if err != nil {
		log.Errorf("Error while connecting to backend: %v", err)
		return
	}

	// maybe not needed, but just in case
	defer connBackend.Close()

	// One goroutine per direction; each reports its io.Copy result on errChan.
	errChan := make(chan error)
	go p.connCopy(conn, connBackend, errChan)
	go p.connCopy(connBackend, conn, errChan)

	// Wait for the first direction to finish (or fail).
	err = <-errChan
	if err != nil {
		log.WithoutContext().Errorf("Error while serving UDP: %v", err)
	}

	// Drain the second goroutine so it does not block on errChan forever.
	<-errChan
}
|
||||
|
||||
func (p Proxy) connCopy(dst io.WriteCloser, src io.Reader, errCh chan error) {
|
||||
_, err := io.Copy(dst, src)
|
||||
errCh <- err
|
||||
|
||||
if err := dst.Close(); err != nil {
|
||||
log.WithoutContext().Debugf("Error while terminating connection: %v", err)
|
||||
}
|
||||
}
|
55
pkg/udp/proxy_test.go
Normal file
55
pkg/udp/proxy_test.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUDPProxy(t *testing.T) {
|
||||
backendAddr := ":8081"
|
||||
go newServer(t, ":8081", HandlerFunc(func(conn *Conn) {
|
||||
for {
|
||||
b := make([]byte, 1024*1024)
|
||||
n, err := conn.Read(b)
|
||||
require.NoError(t, err)
|
||||
_, err = conn.Write(b[:n])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}))
|
||||
|
||||
proxy, err := NewProxy(backendAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyAddr := ":8080"
|
||||
go newServer(t, proxyAddr, proxy)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
udpConn, err := net.Dial("udp", proxyAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = udpConn.Write([]byte("DATAWRITE"))
|
||||
require.NoError(t, err)
|
||||
|
||||
b := make([]byte, 1024*1024)
|
||||
n, err := udpConn.Read(b)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "DATAWRITE", string(b[:n]))
|
||||
}
|
||||
|
||||
// newServer starts a UDP listener on addr and serves every accepted session
// with handler. It loops forever under normal operation.
func newServer(t *testing.T, addr string, handler Handler) {
	addrL, err := net.ResolveUDPAddr("udp", addr)
	require.NoError(t, err)

	listener, err := Listen("udp", addrL)
	require.NoError(t, err)

	// NOTE(review): Accept errors once the listener closes, which would fail
	// the test via require — confirm these listeners are meant to outlive
	// the test run.
	for {
		conn, err := listener.Accept()
		require.NoError(t, err)
		go handler.ServeUDP(conn)
	}
}
|
26
pkg/udp/switcher.go
Normal file
26
pkg/udp/switcher.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"github.com/containous/traefik/v2/pkg/safe"
|
||||
)
|
||||
|
||||
// HandlerSwitcher is a switcher implementation of the Handler interface.
// It allows hot-swapping the underlying handler while connections are served.
type HandlerSwitcher struct {
	handler safe.Safe // holds the current Handler; anything else makes ServeUDP drop the conn
}
|
||||
|
||||
// ServeUDP implements the Handler interface.
|
||||
func (s *HandlerSwitcher) ServeUDP(conn *Conn) {
|
||||
handler := s.handler.Get()
|
||||
h, ok := handler.(Handler)
|
||||
if ok {
|
||||
h.ServeUDP(conn)
|
||||
} else {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Switch replaces s handler with the given handler.
// Connections already handed to the previous handler are unaffected;
// only future calls to ServeUDP see the new handler.
func (s *HandlerSwitcher) Switch(handler Handler) {
	s.handler.Set(handler)
}
|
122
pkg/udp/wrr_load_balancer.go
Normal file
122
pkg/udp/wrr_load_balancer.go
Normal file
|
@ -0,0 +1,122 @@
|
|||
package udp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/containous/traefik/v2/pkg/log"
|
||||
)
|
||||
|
||||
// server couples a Handler with its load-balancing weight.
type server struct {
	Handler
	weight int // relative share of connections this server should receive
}

// WRRLoadBalancer is a naive RoundRobin load balancer for UDP services
type WRRLoadBalancer struct {
	servers       []server     // registered backends, in registration order
	lock          sync.RWMutex // guards the scheduling state used by next()
	currentWeight int          // current weight threshold of the WRR scan
	index         int          // index of the last server picked; -1 before the first pick
}
|
||||
|
||||
// NewWRRLoadBalancer creates a new WRRLoadBalancer
func NewWRRLoadBalancer() *WRRLoadBalancer {
	return &WRRLoadBalancer{
		// Start at -1 so the first call to next() lands on index 0.
		index: -1,
	}
}
|
||||
|
||||
// ServeUDP forwards the connection to the right service
|
||||
func (b *WRRLoadBalancer) ServeUDP(conn *Conn) {
|
||||
if len(b.servers) == 0 {
|
||||
log.WithoutContext().Error("no available server")
|
||||
return
|
||||
}
|
||||
|
||||
next, err := b.next()
|
||||
if err != nil {
|
||||
log.WithoutContext().Errorf("Error during load balancing: %v", err)
|
||||
conn.Close()
|
||||
}
|
||||
next.ServeUDP(conn)
|
||||
}
|
||||
|
||||
// AddServer appends a handler to the existing list
// with a default weight of 1.
func (b *WRRLoadBalancer) AddServer(serverHandler Handler) {
	w := 1
	b.AddWeightedServer(serverHandler, &w)
}
|
||||
|
||||
// AddWeightedServer appends a handler to the existing list with a weight
|
||||
func (b *WRRLoadBalancer) AddWeightedServer(serverHandler Handler, weight *int) {
|
||||
w := 1
|
||||
if weight != nil {
|
||||
w = *weight
|
||||
}
|
||||
b.servers = append(b.servers, server{Handler: serverHandler, weight: w})
|
||||
}
|
||||
|
||||
func (b *WRRLoadBalancer) maxWeight() int {
|
||||
max := -1
|
||||
for _, s := range b.servers {
|
||||
if s.weight > max {
|
||||
max = s.weight
|
||||
}
|
||||
}
|
||||
return max
|
||||
}
|
||||
|
||||
func (b *WRRLoadBalancer) weightGcd() int {
|
||||
divisor := -1
|
||||
for _, s := range b.servers {
|
||||
if divisor == -1 {
|
||||
divisor = s.weight
|
||||
} else {
|
||||
divisor = gcd(divisor, s.weight)
|
||||
}
|
||||
}
|
||||
return divisor
|
||||
}
|
||||
|
||||
// gcd returns the greatest common divisor of a and b,
// computed with the Euclidean algorithm (gcd(a, 0) == a).
func gcd(a, b int) int {
	if b == 0 {
		return a
	}
	return gcd(b, a%b)
}
|
||||
|
||||
// next returns the server that should handle the next connection,
// according to the interleaved weighted round-robin scan below.
func (b *WRRLoadBalancer) next() (Handler, error) {
	b.lock.Lock()
	defer b.lock.Unlock()

	if len(b.servers) == 0 {
		return nil, fmt.Errorf("no servers in the pool")
	}

	// The algorithm below may look messy,
	// but is actually very simple it calculates the GCD and subtracts it on every iteration,
	// what interleaves servers and allows us not to build an iterator every time we readjust weights.

	// GCD across all enabled servers
	gcd := b.weightGcd()
	// Maximum weight across all enabled servers
	max := b.maxWeight()

	for {
		b.index = (b.index + 1) % len(b.servers)
		if b.index == 0 {
			// Completed a full pass: lower the current weight threshold,
			// wrapping back to max once it drops to zero or below.
			b.currentWeight -= gcd
			if b.currentWeight <= 0 {
				b.currentWeight = max
				if b.currentWeight == 0 {
					// Every server has weight 0: nothing can ever be picked.
					return nil, fmt.Errorf("all servers have 0 weight")
				}
			}
		}
		// Only servers whose weight reaches the current threshold are eligible.
		srv := b.servers[b.index]
		if srv.weight >= b.currentWeight {
			return srv, nil
		}
	}
}
|
Loading…
Add table
Add a link
Reference in a new issue