Handle shutdown of Hijacked connections

This commit is contained in:
SALLEYRON Julien 2018-07-19 17:30:06 +02:00 committed by Traefiker Bot
parent d50b6a34bc
commit c8ae97fd38
4 changed files with 128 additions and 17 deletions

View file

@ -40,6 +40,59 @@ import (
var httpServerLogger = stdlog.New(log.WriterLevel(logrus.DebugLevel), "", 0)
func newHijackConnectionTracker() *hijackConnectionTracker {
return &hijackConnectionTracker{
conns: make(map[net.Conn]struct{}),
}
}
type hijackConnectionTracker struct {
conns map[net.Conn]struct{}
lock sync.RWMutex
}
// AddHijackedConnection add a connection in the tracked connections list
func (h *hijackConnectionTracker) AddHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
h.conns[conn] = struct{}{}
}
// RemoveHijackedConnection remove a connection from the tracked connections list
func (h *hijackConnectionTracker) RemoveHijackedConnection(conn net.Conn) {
h.lock.Lock()
defer h.lock.Unlock()
delete(h.conns, conn)
}
// Shutdown wait for the connection closing
func (h *hijackConnectionTracker) Shutdown(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
for {
h.lock.RLock()
if len(h.conns) == 0 {
return nil
}
h.lock.RUnlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-ticker.C:
}
}
}
// Close close all the connections in the tracked connections list
func (h *hijackConnectionTracker) Close() {
for conn := range h.conns {
if err := conn.Close(); err != nil {
log.Errorf("Error while closing Hijacked conn: %v", err)
}
delete(h.conns, conn)
}
}
// Server is the reverse-proxy/load-balancer engine
type Server struct {
serverEntryPoints serverEntryPoints
@ -74,12 +127,41 @@ type EntryPoint struct {
type serverEntryPoints map[string]*serverEntryPoint
type serverEntryPoint struct {
httpServer *h2c.Server
listener net.Listener
httpRouter *middlewares.HandlerSwitcher
certs *traefiktls.CertificateStore
onDemandListener func(string) (*tls.Certificate, error)
tlsALPNGetter func(string) (*tls.Certificate, error)
httpServer *h2c.Server
listener net.Listener
httpRouter *middlewares.HandlerSwitcher
certs *traefiktls.CertificateStore
onDemandListener func(string) (*tls.Certificate, error)
tlsALPNGetter func(string) (*tls.Certificate, error)
hijackConnectionTracker *hijackConnectionTracker
}
func (s serverEntryPoint) Shutdown(ctx context.Context) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if err := s.httpServer.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait server shutdown is over due to: %s", err)
err = s.httpServer.Close()
if err != nil {
log.Error(err)
}
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
if err := s.hijackConnectionTracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
log.Debugf("Wait hijack connection is over due to: %s", err)
s.hijackConnectionTracker.Close()
}
}
}()
wg.Wait()
}
// NewServer returns an initialized Server.
@ -187,13 +269,7 @@ func (s *Server) Stop() {
graceTimeOut := time.Duration(s.globalConfiguration.LifeCycle.GraceTimeOut)
ctx, cancel := context.WithTimeout(context.Background(), graceTimeOut)
log.Debugf("Waiting %s seconds before killing connections on entrypoint %s...", graceTimeOut, serverEntryPointName)
if err := serverEntryPoint.httpServer.Shutdown(ctx); err != nil {
log.Debugf("Wait is over due to: %s", err)
err = serverEntryPoint.httpServer.Close()
if err != nil {
log.Error(err)
}
}
serverEntryPoint.Shutdown(ctx)
cancel()
log.Debugf("Entrypoint %s closed", serverEntryPointName)
}(sepn, sep)
@ -447,6 +523,16 @@ func (s *Server) setupServerEntryPoint(newServerEntryPointName string, newServer
serverEntryPoint.httpServer = newSrv
serverEntryPoint.listener = listener
serverEntryPoint.hijackConnectionTracker = newHijackConnectionTracker()
serverEntryPoint.httpServer.ConnState = func(conn net.Conn, state http.ConnState) {
switch state {
case http.StateHijacked:
serverEntryPoint.hijackConnectionTracker.AddHijackedConnection(conn)
case http.StateClosed:
serverEntryPoint.hijackConnectionTracker.RemoveHijackedConnection(conn)
}
}
return serverEntryPoint
}