1
0
Fork 0

Merge branch 'v2.1' into master

This commit is contained in:
Fernandez Ludovic 2020-01-07 17:35:07 +01:00
commit da3d814c8b
56 changed files with 2162 additions and 558 deletions

View file

@ -52,9 +52,9 @@ func (h *httpForwarder) Accept() (net.Conn, error) {
type TCPEntryPoints map[string]*TCPEntryPoint
// NewTCPEntryPoints creates a new TCPEntryPoints.
func NewTCPEntryPoints(staticConfiguration static.Configuration) (TCPEntryPoints, error) {
func NewTCPEntryPoints(entryPointsConfig static.EntryPoints) (TCPEntryPoints, error) {
serverEntryPointsTCP := make(TCPEntryPoints)
for entryPointName, config := range staticConfiguration.EntryPoints {
for entryPointName, config := range entryPointsConfig {
ctx := log.With(context.Background(), log.Str(log.EntryPointName, entryPointName))
var err error
@ -171,6 +171,23 @@ func (e *TCPEntryPoint) StartTCP(ctx context.Context) {
}
safe.Go(func() {
// Enforce read/write deadlines at the connection level,
// because when we're peeking the first byte to determine whether we are doing TLS,
// the deadlines at the server level are not taken into account.
if e.transportConfiguration.RespondingTimeouts.ReadTimeout > 0 {
err := writeCloser.SetReadDeadline(time.Now().Add(time.Duration(e.transportConfiguration.RespondingTimeouts.ReadTimeout)))
if err != nil {
logger.Errorf("Error while setting read deadline: %v", err)
}
}
if e.transportConfiguration.RespondingTimeouts.WriteTimeout > 0 {
err = writeCloser.SetWriteDeadline(time.Now().Add(time.Duration(e.transportConfiguration.RespondingTimeouts.WriteTimeout)))
if err != nil {
logger.Errorf("Error while setting write deadline: %v", err)
}
}
e.switcher.ServeTCP(newTrackedConnection(writeCloser, e.tracker))
})
}
@ -191,48 +208,48 @@ func (e *TCPEntryPoint) Shutdown(ctx context.Context) {
logger.Debugf("Waiting %s seconds before killing connections.", graceTimeOut)
var wg sync.WaitGroup
shutdownServer := func(server stoppableServer) {
defer wg.Done()
err := server.Shutdown(ctx)
if err == nil {
return
}
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Server failed to shutdown within deadline because: %s", err)
if err = server.Close(); err != nil {
logger.Error(err)
}
return
}
logger.Error(err)
// We expect Close to fail again because Shutdown most likely failed when trying to close a listener.
// We still call it however, to make sure that all connections get closed as well.
server.Close()
}
if e.httpServer.Server != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := e.httpServer.Server.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait server shutdown is overdue to: %s", err)
err = e.httpServer.Server.Close()
if err != nil {
logger.Error(err)
}
}
}
}()
go shutdownServer(e.httpServer.Server)
}
if e.httpsServer.Server != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := e.httpsServer.Server.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait server shutdown is overdue to: %s", err)
err = e.httpsServer.Server.Close()
if err != nil {
logger.Error(err)
}
}
}
}()
go shutdownServer(e.httpsServer.Server)
}
if e.tracker != nil {
wg.Add(1)
go func() {
defer wg.Done()
if err := e.tracker.Shutdown(ctx); err != nil {
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Wait hijack connection is overdue to: %s", err)
e.tracker.Close()
}
err := e.tracker.Shutdown(ctx)
if err == nil {
return
}
if ctx.Err() == context.DeadlineExceeded {
logger.Debugf("Server failed to shutdown before deadline because: %s", err)
}
e.tracker.Close()
}()
}
@ -459,8 +476,11 @@ func createHTTPServer(ctx context.Context, ln net.Listener, configuration *stati
}
serverHTTP := &http.Server{
Handler: handler,
ErrorLog: httpServerLogger,
Handler: handler,
ErrorLog: httpServerLogger,
ReadTimeout: time.Duration(configuration.Transport.RespondingTimeouts.ReadTimeout),
WriteTimeout: time.Duration(configuration.Transport.RespondingTimeouts.WriteTimeout),
IdleTimeout: time.Duration(configuration.Transport.RespondingTimeouts.IdleTimeout),
}
listener := newHTTPForwarder(ln)

View file

@ -3,8 +3,11 @@ package server
import (
"bufio"
"context"
"errors"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
@ -15,128 +18,206 @@ import (
"github.com/stretchr/testify/require"
)
func TestShutdownHTTP(t *testing.T) {
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: &static.EntryPointsTransport{
LifeCycle: &static.LifeCycle{
RequestAcceptGraceTimeout: 0,
GraceTimeOut: types.Duration(5 * time.Second),
},
},
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
go entryPoint.StartTCP(context.Background())
router := &tcp.Router{}
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
time.Sleep(1 * time.Second)
rw.WriteHeader(http.StatusOK)
}))
entryPoint.SwitchRouter(router)
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
require.NoError(t, err)
go entryPoint.Shutdown(context.Background())
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
require.NoError(t, err)
err = request.Write(conn)
require.NoError(t, err)
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
require.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
}
func TestShutdownHTTPHijacked(t *testing.T) {
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: &static.EntryPointsTransport{
LifeCycle: &static.LifeCycle{
RequestAcceptGraceTimeout: 0,
GraceTimeOut: types.Duration(5 * time.Second),
},
},
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
go entryPoint.StartTCP(context.Background())
func TestShutdownHijacked(t *testing.T) {
router := &tcp.Router{}
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
conn, _, err := rw.(http.Hijacker).Hijack()
require.NoError(t, err)
time.Sleep(1 * time.Second)
resp := http.Response{StatusCode: http.StatusOK}
err = resp.Write(conn)
require.NoError(t, err)
}))
entryPoint.SwitchRouter(router)
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
require.NoError(t, err)
go entryPoint.Shutdown(context.Background())
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
require.NoError(t, err)
err = request.Write(conn)
require.NoError(t, err)
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
require.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
testShutdown(t, router)
}
func TestShutdownTCPConn(t *testing.T) {
func TestShutdownHTTP(t *testing.T) {
router := &tcp.Router{}
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
time.Sleep(time.Second)
}))
testShutdown(t, router)
}
func TestShutdownTCP(t *testing.T) {
router := &tcp.Router{}
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn tcp.WriteCloser) {
for {
_, err := http.ReadRequest(bufio.NewReader(conn))
if err == io.EOF || (err != nil && strings.HasSuffix(err.Error(), "use of closed network connection")) {
return
}
require.NoError(t, err)
resp := http.Response{StatusCode: http.StatusOK}
err = resp.Write(conn)
require.NoError(t, err)
}
}))
testShutdown(t, router)
}
func testShutdown(t *testing.T, router *tcp.Router) {
epConfig := &static.EntryPointsTransport{}
epConfig.SetDefaults()
epConfig.LifeCycle.RequestAcceptGraceTimeout = 0
epConfig.LifeCycle.GraceTimeOut = types.Duration(5 * time.Second)
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: &static.EntryPointsTransport{
LifeCycle: &static.LifeCycle{
RequestAcceptGraceTimeout: 0,
GraceTimeOut: types.Duration(5 * time.Second),
},
},
// We explicitly use an IPV4 address because on Alpine, with an IPV6 address
// there seems to be shenanigans related to properly cleaning up file descriptors
Address: "127.0.0.1:0",
Transport: epConfig,
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
go entryPoint.StartTCP(context.Background())
conn, err := startEntrypoint(entryPoint, router)
require.NoError(t, err)
router := &tcp.Router{}
router.AddCatchAllNoTLS(tcp.HandlerFunc(func(conn tcp.WriteCloser) {
_, err := http.ReadRequest(bufio.NewReader(conn))
require.NoError(t, err)
time.Sleep(1 * time.Second)
epAddr := entryPoint.listener.Addr().String()
resp := http.Response{StatusCode: http.StatusOK}
err = resp.Write(conn)
require.NoError(t, err)
}))
request, err := http.NewRequest(http.MethodHead, "http://127.0.0.1:8082", nil)
require.NoError(t, err)
entryPoint.SwitchRouter(router)
time.Sleep(time.Millisecond * 100)
conn, err := net.Dial("tcp", entryPoint.listener.Addr().String())
// We need to do a write on the conn before the shutdown to make it "exist".
// Because the connection indeed exists as far as TCP is concerned,
// but since we only pass it along to the HTTP server after at least one byte is peaked,
// the HTTP server (and hence its shutdown) does not know about the connection until that first byte peaking.
err = request.Write(conn)
require.NoError(t, err)
go entryPoint.Shutdown(context.Background())
request, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:8082", nil)
require.NoError(t, err)
// Make sure that new connections are not permitted anymore.
// Note that this should be true not only after Shutdown has returned,
// but technically also as early as the Shutdown has closed the listener,
// i.e. during the shutdown and before the gracetime is over.
var testOk bool
for i := 0; i < 10; i++ {
loopConn, err := net.Dial("tcp", epAddr)
if err == nil {
loopConn.Close()
time.Sleep(time.Millisecond * 100)
continue
}
if !strings.HasSuffix(err.Error(), "connection refused") && !strings.HasSuffix(err.Error(), "reset by peer") {
t.Fatalf(`unexpected error: got %v, wanted "connection refused" or "reset by peer"`, err)
}
testOk = true
break
}
if !testOk {
t.Fatal("entry point never closed")
}
err = request.Write(conn)
require.NoError(t, err)
// And make sure that the connection we had opened before shutting things down is still operational
resp, err := http.ReadResponse(bufio.NewReader(conn), request)
require.NoError(t, err)
assert.Equal(t, resp.StatusCode, http.StatusOK)
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func startEntrypoint(entryPoint *TCPEntryPoint, router *tcp.Router) (net.Conn, error) {
go entryPoint.StartTCP(context.Background())
entryPoint.SwitchRouter(router)
var conn net.Conn
var err error
var epStarted bool
for i := 0; i < 10; i++ {
conn, err = net.Dial("tcp", entryPoint.listener.Addr().String())
if err != nil {
time.Sleep(time.Millisecond * 100)
continue
}
epStarted = true
break
}
if !epStarted {
return nil, errors.New("entry point never started")
}
return conn, err
}
func TestReadTimeoutWithoutFirstByte(t *testing.T) {
epConfig := &static.EntryPointsTransport{}
epConfig.SetDefaults()
epConfig.RespondingTimeouts.ReadTimeout = types.Duration(time.Second * 2)
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: epConfig,
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
router := &tcp.Router{}
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
conn, err := startEntrypoint(entryPoint, router)
require.NoError(t, err)
errChan := make(chan error)
go func() {
b := make([]byte, 2048)
_, err := conn.Read(b)
errChan <- err
}()
select {
case err := <-errChan:
require.Equal(t, io.EOF, err)
case <-time.Tick(time.Second * 5):
t.Error("Timeout while read")
}
}
func TestReadTimeoutWithFirstByte(t *testing.T) {
epConfig := &static.EntryPointsTransport{}
epConfig.SetDefaults()
epConfig.RespondingTimeouts.ReadTimeout = types.Duration(time.Second * 2)
entryPoint, err := NewTCPEntryPoint(context.Background(), &static.EntryPoint{
Address: ":0",
Transport: epConfig,
ForwardedHeaders: &static.ForwardedHeaders{},
})
require.NoError(t, err)
router := &tcp.Router{}
router.HTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
}))
conn, err := startEntrypoint(entryPoint, router)
require.NoError(t, err)
_, err = conn.Write([]byte("GET /some HTTP/1.1\r\n"))
require.NoError(t, err)
errChan := make(chan error)
go func() {
b := make([]byte, 2048)
_, err := conn.Read(b)
errChan <- err
}()
select {
case err := <-errChan:
require.Equal(t, io.EOF, err)
case <-time.Tick(time.Second * 5):
t.Error("Timeout while read")
}
}

View file

@ -52,6 +52,10 @@ func buildProxy(passHostHeader *bool, responseForwarding *dynamic.ResponseForwar
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
if _, ok := outReq.Header["User-Agent"]; !ok {
outReq.Header.Set("User-Agent", "")
}
// Do not pass client Host header unless optsetter PassHostHeader is set.
if passHostHeader != nil && !*passHostHeader {
outReq.Host = outReq.URL.Host