Add KeepAliveMaxTime and KeepAliveMaxRequests features to entrypoints

This commit is contained in:
Julien Salleyron 2024-01-02 16:40:06 +01:00 committed by GitHub
parent 3dfaa3d5fa
commit 9662cdca64
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 258 additions and 7 deletions

View file

@ -0,0 +1,29 @@
package server
import (
"net/http"
"time"
ptypes "github.com/traefik/paerser/types"
"github.com/traefik/traefik/v2/pkg/log"
)
func newKeepAliveMiddleware(next http.Handler, maxRequests int, maxTime ptypes.Duration) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
state, ok := req.Context().Value(connStateKey).(*connState)
if ok {
state.HTTPRequestCount++
if maxRequests > 0 && state.HTTPRequestCount >= maxRequests {
log.WithoutContext().Debug("Close because of too many requests")
state.KeepAliveState = "Close because of too many requests"
rw.Header().Set("Connection", "close")
}
if maxTime > 0 && time.Now().After(state.Start.Add(time.Duration(maxTime))) {
log.WithoutContext().Debug("Close because of too long connection")
state.KeepAliveState = "Close because of too long connection"
rw.Header().Set("Connection", "close")
}
}
next.ServeHTTP(rw, req)
})
}

View file

@ -3,6 +3,7 @@ package server
import (
"context"
"errors"
"expvar"
"fmt"
stdlog "log"
"net"
@ -34,6 +35,25 @@ import (
var httpServerLogger = stdlog.New(log.WithoutContext().WriterLevel(logrus.DebugLevel), "", 0)
type key string
const (
connStateKey key = "connState"
debugConnectionEnv string = "DEBUG_CONNECTION"
)
var (
clientConnectionStates = map[string]*connState{}
clientConnectionStatesMu = sync.RWMutex{}
)
type connState struct {
State string
KeepAliveState string
Start time.Time
HTTPRequestCount int
}
type httpForwarder struct {
net.Listener
connChan chan net.Conn
@ -68,6 +88,12 @@ type TCPEntryPoints map[string]*TCPEntryPoint
// NewTCPEntryPoints creates a new TCPEntryPoints.
func NewTCPEntryPoints(entryPointsConfig static.EntryPoints, hostResolverConfig *types.HostResolverConfig) (TCPEntryPoints, error) {
if os.Getenv(debugConnectionEnv) != "" {
expvar.Publish("clientConnectionStates", expvar.Func(func() any {
return clientConnectionStates
}))
}
serverEntryPointsTCP := make(TCPEntryPoints)
for entryPointName, config := range entryPointsConfig {
protocol, err := config.GetProtocol()
@ -548,6 +574,11 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
})
}
debugConnection := os.Getenv(debugConnectionEnv) != ""
if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) {
handler = newKeepAliveMiddleware(handler, configuration.Transport.KeepAliveMaxRequests, configuration.Transport.KeepAliveMaxTime)
}
serverHTTP := &http.Server{
Handler: handler,
ErrorLog: httpServerLogger,
@ -555,6 +586,27 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
WriteTimeout: time.Duration(configuration.Transport.RespondingTimeouts.WriteTimeout),
IdleTimeout: time.Duration(configuration.Transport.RespondingTimeouts.IdleTimeout),
}
if debugConnection || (configuration.Transport != nil && (configuration.Transport.KeepAliveMaxTime > 0 || configuration.Transport.KeepAliveMaxRequests > 0)) {
serverHTTP.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
cState := &connState{Start: time.Now()}
if debugConnection {
clientConnectionStatesMu.Lock()
clientConnectionStates[getConnKey(c)] = cState
clientConnectionStatesMu.Unlock()
}
return context.WithValue(ctx, connStateKey, cState)
}
if debugConnection {
serverHTTP.ConnState = func(c net.Conn, state http.ConnState) {
clientConnectionStatesMu.Lock()
if clientConnectionStates[getConnKey(c)] != nil {
clientConnectionStates[getConnKey(c)].State = state.String()
}
clientConnectionStatesMu.Unlock()
}
}
}
// ConfigureServer configures HTTP/2 with the MaxConcurrentStreams option for the given server.
// Also keeping behavior the same as
@ -584,6 +636,10 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
}, nil
}
func getConnKey(conn net.Conn) string {
return fmt.Sprintf("%s => %s", conn.RemoteAddr(), conn.LocalAddr())
}
func newTrackedConnection(conn tcp.WriteCloser, tracker *connectionTracker) *trackedConnection {
tracker.AddConnection(conn)
return &trackedConnection{

View file

@ -230,3 +230,91 @@ func TestReadTimeoutWithFirstByte(t *testing.T) {
t.Error("Timeout while read")
}
}
func TestKeepAliveMaxRequests(t *testing.T) {
epConfig := &static.EntryPointsTransport{}
epConfig.SetDefaults()
epConfig.KeepAliveMaxRequests = 3
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: epConfig,
ForwardedHeaders: &static.ForwardedHeaders{},
HTTP2: &static.HTTP2Config{},
}, nil)
require.NoError(t, err)
router := &tcprouter.Router{}
router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
conn, err := startEntrypoint(entryPoint, router)
require.NoError(t, err)
http.DefaultClient.Transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return conn, nil
},
}
resp, err := http.Get("http://" + entryPoint.listener.Addr().String())
require.NoError(t, err)
require.False(t, resp.Close)
err = resp.Body.Close()
require.NoError(t, err)
resp, err = http.Get("http://" + entryPoint.listener.Addr().String())
require.NoError(t, err)
require.False(t, resp.Close)
err = resp.Body.Close()
require.NoError(t, err)
resp, err = http.Get("http://" + entryPoint.listener.Addr().String())
require.NoError(t, err)
require.True(t, resp.Close)
err = resp.Body.Close()
require.NoError(t, err)
}
func TestKeepAliveMaxTime(t *testing.T) {
epConfig := &static.EntryPointsTransport{}
epConfig.SetDefaults()
epConfig.KeepAliveMaxTime = ptypes.Duration(time.Millisecond)
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: epConfig,
ForwardedHeaders: &static.ForwardedHeaders{},
HTTP2: &static.HTTP2Config{},
}, nil)
require.NoError(t, err)
router := &tcprouter.Router{}
router.SetHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
conn, err := startEntrypoint(entryPoint, router)
require.NoError(t, err)
http.DefaultClient.Transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return conn, nil
},
}
resp, err := http.Get("http://" + entryPoint.listener.Addr().String())
require.NoError(t, err)
require.False(t, resp.Close)
err = resp.Body.Close()
require.NoError(t, err)
time.Sleep(time.Millisecond)
resp, err = http.Get("http://" + entryPoint.listener.Addr().String())
require.NoError(t, err)
require.True(t, resp.Close)
err = resp.Body.Close()
require.NoError(t, err)
}