Add support for UDP routing in systemd socket activation

This commit is contained in:
tsiid 2025-01-21 11:38:09 +03:00 committed by GitHub
parent 95dd17e020
commit 261e4395f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 158 additions and 54 deletions

View file

@ -48,15 +48,8 @@ const (
var (
clientConnectionStates = map[string]*connState{}
clientConnectionStatesMu = sync.RWMutex{}
socketActivationListeners map[string]net.Listener
)
func init() {
// Populates pre-defined socketActivationListeners by socket activation.
populateSocketActivationListeners()
}
type connState struct {
State string
KeepAliveState string
@ -204,7 +197,7 @@ func NewTCPEntryPoint(ctx context.Context, name string, config *static.EntryPoin
return nil, fmt.Errorf("error preparing https server: %w", err)
}
h3Server, err := newHTTP3Server(ctx, config, httpsServer)
h3Server, err := newHTTP3Server(ctx, name, config, httpsServer)
if err != nil {
return nil, fmt.Errorf("error preparing http3 server: %w", err)
}
@ -476,13 +469,14 @@ func buildListener(ctx context.Context, name string, config *static.EntryPoint)
var err error
// if we have predefined listener from socket activation
if ln, ok := socketActivationListeners[name]; ok {
listener = ln
} else {
if len(socketActivationListeners) > 0 {
log.Warn().Str("name", name).Msg("Unable to find socket activation listener for entryPoint")
if socketActivation.isEnabled() {
listener, err = socketActivation.getListener(name)
if err != nil {
log.Ctx(ctx).Warn().Err(err).Str("name", name).Msg("Unable to use socket activation for entrypoint")
}
}
if listener == nil {
listenConfig := newListenConfig(config)
listener, err = listenConfig.Listen(ctx, "tcp", config.GetAddress())
if err != nil {

View file

@ -25,19 +25,32 @@ type http3server struct {
getter func(info *tls.ClientHelloInfo) (*tls.Config, error)
}
func newHTTP3Server(ctx context.Context, configuration *static.EntryPoint, httpsServer *httpServer) (*http3server, error) {
if configuration.HTTP3 == nil {
func newHTTP3Server(ctx context.Context, name string, config *static.EntryPoint, httpsServer *httpServer) (*http3server, error) {
var conn net.PacketConn
var err error
if config.HTTP3 == nil {
return nil, nil
}
if configuration.HTTP3.AdvertisedPort < 0 {
if config.HTTP3.AdvertisedPort < 0 {
return nil, errors.New("advertised port must be greater than or equal to zero")
}
listenConfig := newListenConfig(configuration)
conn, err := listenConfig.ListenPacket(ctx, "udp", configuration.GetAddress())
if err != nil {
return nil, fmt.Errorf("starting listener: %w", err)
// if we have predefined connections from socket activation
if socketActivation.isEnabled() {
conn, err = socketActivation.getConn(name)
if err != nil {
log.Ctx(ctx).Warn().Err(err).Str("name", name).Msg("Unable to use socket activation for entrypoint")
}
}
if conn == nil {
listenConfig := newListenConfig(config)
conn, err = listenConfig.ListenPacket(ctx, "udp", config.GetAddress())
if err != nil {
return nil, fmt.Errorf("starting listener: %w", err)
}
}
h3 := &http3server{
@ -48,8 +61,8 @@ func newHTTP3Server(ctx context.Context, configuration *static.EntryPoint, https
}
h3.Server = &http3.Server{
Addr: configuration.GetAddress(),
Port: configuration.HTTP3.AdvertisedPort,
Addr: config.GetAddress(),
Port: config.HTTP3.AdvertisedPort,
Handler: httpsServer.Server.(*http.Server).Handler,
TLSConfig: &tls.Config{GetConfigForClient: h3.getGetConfigForClient},
QUICConfig: &quic.Config{

View file

@ -16,9 +16,9 @@ import (
type UDPEntryPoints map[string]*UDPEntryPoint
// NewUDPEntryPoints returns all the UDP entry points, keyed by name.
func NewUDPEntryPoints(cfg static.EntryPoints) (UDPEntryPoints, error) {
func NewUDPEntryPoints(config static.EntryPoints) (UDPEntryPoints, error) {
entryPoints := make(UDPEntryPoints)
for entryPointName, entryPoint := range cfg {
for entryPointName, entryPoint := range config {
protocol, err := entryPoint.GetProtocol()
if err != nil {
return nil, fmt.Errorf("error while building entryPoint %s: %w", entryPointName, err)
@ -28,7 +28,7 @@ func NewUDPEntryPoints(cfg static.EntryPoints) (UDPEntryPoints, error) {
continue
}
ep, err := NewUDPEntryPoint(entryPoint)
ep, err := NewUDPEntryPoint(entryPoint, entryPointName)
if err != nil {
return nil, fmt.Errorf("error while building entryPoint %s: %w", entryPointName, err)
}
@ -85,14 +85,33 @@ type UDPEntryPoint struct {
}
// NewUDPEntryPoint returns a UDP entry point.
func NewUDPEntryPoint(cfg *static.EntryPoint) (*UDPEntryPoint, error) {
listenConfig := newListenConfig(cfg)
listener, err := udp.Listen(listenConfig, "udp", cfg.GetAddress(), time.Duration(cfg.UDP.Timeout))
if err != nil {
return nil, err
func NewUDPEntryPoint(config *static.EntryPoint, name string) (*UDPEntryPoint, error) {
var listener *udp.Listener
var err error
timeout := time.Duration(config.UDP.Timeout)
// if we have predefined connections from socket activation
if socketActivation.isEnabled() {
if conn, err := socketActivation.getConn(name); err == nil {
listener, err = udp.ListenPacketConn(conn, timeout)
if err != nil {
log.Warn().Err(err).Str("name", name).Msg("Unable to create socket activation listener")
}
} else {
log.Warn().Err(err).Str("name", name).Msg("Unable to use socket activation for entrypoint")
}
}
return &UDPEntryPoint{listener: listener, switcher: &udp.HandlerSwitcher{}, transportConfiguration: cfg.Transport}, nil
if listener == nil {
listenConfig := newListenConfig(config)
listener, err = udp.Listen(listenConfig, "udp", config.GetAddress(), timeout)
if err != nil {
return nil, fmt.Errorf("error creating listener: %w", err)
}
}
return &UDPEntryPoint{listener: listener, switcher: &udp.HandlerSwitcher{}, transportConfiguration: config.Transport}, nil
}
// Start commences the listening for ep.

View file

@ -24,7 +24,7 @@ func TestShutdownUDPConn(t *testing.T) {
}
ep.SetDefaults()
entryPoint, err := NewUDPEntryPoint(&ep)
entryPoint, err := NewUDPEntryPoint(&ep, "")
require.NoError(t, err)
go entryPoint.Start(context.Background())

View file

@ -0,0 +1,41 @@
package server
import (
"errors"
"net"
)
type SocketActivation struct {
enabled bool
listeners map[string]net.Listener
conns map[string]net.PacketConn
}
func (s *SocketActivation) isEnabled() bool {
return s.enabled
}
func (s *SocketActivation) getListener(name string) (net.Listener, error) {
listener, ok := s.listeners[name]
if !ok {
return nil, errors.New("unable to find socket activation TCP listener for entryPoint")
}
return listener, nil
}
func (s *SocketActivation) getConn(name string) (net.PacketConn, error) {
conn, ok := s.conns[name]
if !ok {
return nil, errors.New("unable to find socket activation UDP listener for entryPoint")
}
return conn, nil
}
var socketActivation *SocketActivation
func init() {
// Populates pre-defined TCP and UDP listeners provided by systemd socket activation.
socketActivation = populateSocketActivationListeners()
}

View file

@ -9,16 +9,36 @@ import (
"github.com/rs/zerolog/log"
)
func populateSocketActivationListeners() {
listenersWithName, _ := activation.ListenersWithNames()
func populateSocketActivationListeners() *SocketActivation {
// We use Files api due to activation not providing method for get PacketConn with names
files := activation.Files(true)
sa := &SocketActivation{enabled: false}
sa.listeners = make(map[string]net.Listener)
sa.conns = make(map[string]net.PacketConn)
socketActivationListeners = make(map[string]net.Listener)
for name, lns := range listenersWithName {
if len(lns) != 1 {
log.Error().Str("listenersName", name).Msg("Socket activation listeners must have one and only one listener per name")
continue
if len(files) > 0 {
sa.enabled = true
for _, f := range files {
if lc, err := net.FileListener(f); err == nil {
_, ok := sa.listeners[f.Name()]
if ok {
log.Error().Str("listenersName", f.Name()).Msg("Socket activation TCP listeners must have one and only one listener per name")
} else {
sa.listeners[f.Name()] = lc
}
f.Close()
} else if pc, err := net.FilePacketConn(f); err == nil {
_, ok := sa.conns[f.Name()]
if ok {
log.Error().Str("listenersName", f.Name()).Msg("Socket activation UDP listeners must have one and only one listener per name")
} else {
sa.conns[f.Name()] = pc
}
f.Close()
}
}
socketActivationListeners[name] = lns[0]
}
return sa
}

View file

@ -2,4 +2,6 @@
package server
func populateSocketActivationListeners() {}
func populateSocketActivationListeners() *SocketActivation {
return &SocketActivation{enabled: false}
}