1
0
Fork 0

Support mirroring request body

Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
Dmytro Tananayskiy 2020-03-05 18:03:08 +01:00 committed by Traefiker Bot
parent 09c07f45ee
commit cf7f0f878a
20 changed files with 454 additions and 44 deletions

View file

@ -2,12 +2,17 @@ package mirror
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"sync"
"github.com/containous/traefik/v2/pkg/log"
"github.com/containous/traefik/v2/pkg/middlewares/accesslog"
"github.com/containous/traefik/v2/pkg/safe"
)
@ -19,16 +24,19 @@ type Mirroring struct {
rw http.ResponseWriter
routinePool *safe.Pool
maxBodySize int64
lock sync.RWMutex
total uint64
}
// New returns a new instance of *Mirroring.
func New(handler http.Handler, pool *safe.Pool) *Mirroring {
func New(handler http.Handler, pool *safe.Pool, maxBodySize int64) *Mirroring {
return &Mirroring{
routinePool: pool,
handler: handler,
rw: blackholeResponseWriter{},
rw: blackHoleResponseWriter{},
maxBodySize: maxBodySize,
}
}
@ -47,41 +55,73 @@ type mirrorHandler struct {
count uint64
}
func (m *Mirroring) getActiveMirrors() []http.Handler {
total := m.inc()
var mirrors []http.Handler
for _, handler := range m.mirrorHandlers {
handler.lock.Lock()
if handler.count*100 < total*uint64(handler.percent) {
handler.count++
handler.lock.Unlock()
mirrors = append(mirrors, handler)
} else {
handler.lock.Unlock()
}
}
return mirrors
}
func (m *Mirroring) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
m.handler.ServeHTTP(rw, req)
mirrors := m.getActiveMirrors()
if len(mirrors) == 0 {
m.handler.ServeHTTP(rw, req)
return
}
logger := log.FromContext(req.Context())
rr, bytesRead, err := newReusableRequest(req, m.maxBodySize)
if err != nil && err != errBodyTooLarge {
http.Error(rw, http.StatusText(http.StatusInternalServerError)+
fmt.Sprintf("error creating reusable request: %v", err), http.StatusInternalServerError)
return
}
if err == errBodyTooLarge {
req.Body = ioutil.NopCloser(io.MultiReader(bytes.NewReader(bytesRead), req.Body))
m.handler.ServeHTTP(rw, req)
logger.Debugf("no mirroring, request body larger than allowed size")
return
}
m.handler.ServeHTTP(rw, rr.clone(req.Context()))
select {
case <-req.Context().Done():
// No mirroring if request has been canceled during main handler ServeHTTP
logger.Warn("no mirroring, request has been canceled during main handler ServeHTTP")
return
default:
}
m.routinePool.GoCtx(func(_ context.Context) {
total := m.inc()
for _, handler := range m.mirrorHandlers {
handler.lock.Lock()
if handler.count*100 < total*uint64(handler.percent) {
handler.count++
handler.lock.Unlock()
for _, handler := range mirrors {
// prepare request, update body from buffer
r := rr.clone(req.Context())
// In ServeHTTP, we rely on the presence of the accesslog datatable found in the
// request's context to know whether we should mutate said datatable (and
// contribute some fields to the log). In this instance, we do not want the mirrors
// mutating (i.e. changing the service name in) the logs related to the mirrored
// server. Especially since it would result in unguarded concurrent reads/writes on
// the datatable. Therefore, we reset any potential datatable key in the new
// context that we pass around.
ctx := context.WithValue(req.Context(), accesslog.DataTableKey, nil)
// In ServeHTTP, we rely on the presence of the accessLog datatable found in the request's context
// to know whether we should mutate said datatable (and contribute some fields to the log).
// In this instance, we do not want the mirrors mutating (i.e. changing the service name in)
// the logs related to the mirrored server.
// Especially since it would result in unguarded concurrent reads/writes on the datatable.
// Therefore, we reset any potential datatable key in the new context that we pass around.
ctx := context.WithValue(r.Context(), accesslog.DataTableKey, nil)
// When a request served by m.handler is successful, req.Context will be canceled,
// which would trigger a cancellation of the ongoing mirrored requests.
// Therefore, we give a new, non-cancellable context to each of the mirrored calls,
// so they can terminate by themselves.
handler.ServeHTTP(m.rw, req.WithContext(contextStopPropagation{ctx}))
} else {
handler.lock.Unlock()
}
// When a request served by m.handler is successful, req.Context will be canceled,
// which would trigger a cancellation of the ongoing mirrored requests.
// Therefore, we give a new, non-cancellable context to each of the mirrored calls,
// so they can terminate by themselves.
handler.ServeHTTP(m.rw, r.WithContext(contextStopPropagation{ctx}))
}
})
}
@ -95,23 +135,23 @@ func (m *Mirroring) AddMirror(handler http.Handler, percent int) error {
return nil
}
type blackholeResponseWriter struct{}
type blackHoleResponseWriter struct{}
func (b blackholeResponseWriter) Flush() {}
func (b blackHoleResponseWriter) Flush() {}
func (b blackholeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("connection on blackholeResponseWriter cannot be hijacked")
func (b blackHoleResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errors.New("connection on blackHoleResponseWriter cannot be hijacked")
}
func (b blackholeResponseWriter) Header() http.Header {
func (b blackHoleResponseWriter) Header() http.Header {
return http.Header{}
}
func (b blackholeResponseWriter) Write(bytes []byte) (int, error) {
func (b blackHoleResponseWriter) Write(bytes []byte) (int, error) {
return len(bytes), nil
}
func (b blackholeResponseWriter) WriteHeader(statusCode int) {}
func (b blackHoleResponseWriter) WriteHeader(statusCode int) {}
type contextStopPropagation struct {
context.Context
@ -120,3 +160,65 @@ type contextStopPropagation struct {
func (c contextStopPropagation) Done() <-chan struct{} {
return make(chan struct{})
}
// reusableRequest keeps in memory the body of the given request,
// so that the request can be fully cloned by each mirror.
type reusableRequest struct {
req *http.Request
body []byte
}
var errBodyTooLarge = errors.New("request body too large")
// if the returned error is errBodyTooLarge, newReusableRequest also returns the
// bytes that were already consumed from the request's body.
func newReusableRequest(req *http.Request, maxBodySize int64) (*reusableRequest, []byte, error) {
if req == nil {
return nil, nil, errors.New("nil input request")
}
if req.Body == nil {
return &reusableRequest{req: req}, nil, nil
}
// unbounded body size
if maxBodySize < 0 {
body, err := ioutil.ReadAll(req.Body)
if err != nil {
return nil, nil, err
}
return &reusableRequest{
req: req,
body: body,
}, nil, nil
}
// we purposefully try to read _more_ than maxBodySize to detect whether
// the request body is larger than what we allow for the mirrors.
body := make([]byte, maxBodySize+1)
n, err := io.ReadFull(req.Body, body)
if err != nil && err != io.ErrUnexpectedEOF {
return nil, nil, err
}
// we got an ErrUnexpectedEOF, which means there was less than maxBodySize data to read,
// which permits us sending also to all the mirrors later.
if err == io.ErrUnexpectedEOF {
return &reusableRequest{
req: req,
body: body[:n],
}, nil, nil
}
// err == nil , which means data size > maxBodySize
return nil, body[:n], errBodyTooLarge
}
func (rr reusableRequest) clone(ctx context.Context) *http.Request {
req := rr.req.Clone(ctx)
if rr.body != nil {
req.Body = ioutil.NopCloser(bytes.NewReader(rr.body))
}
return req
}

View file

@ -1,7 +1,9 @@
package mirror
import (
"bytes"
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"sync/atomic"
@ -11,13 +13,15 @@ import (
"github.com/stretchr/testify/assert"
)
const defaultMaxBodySize int64 = -1
func TestMirroringOn100(t *testing.T) {
var countMirror1, countMirror2 int32
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
mirror := New(handler, pool, defaultMaxBodySize)
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
atomic.AddInt32(&countMirror1, 1)
}), 10)
@ -46,7 +50,7 @@ func TestMirroringOn10(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
mirror := New(handler, pool, defaultMaxBodySize)
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
atomic.AddInt32(&countMirror1, 1)
}), 10)
@ -70,7 +74,7 @@ func TestMirroringOn10(t *testing.T) {
}
func TestInvalidPercent(t *testing.T) {
mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background()))
mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background()), defaultMaxBodySize)
err := mirror.AddMirror(nil, -1)
assert.Error(t, err)
@ -89,7 +93,7 @@ func TestHijack(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
mirror := New(handler, pool, defaultMaxBodySize)
var mirrorRequest bool
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -113,7 +117,7 @@ func TestFlush(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
mirror := New(handler, pool, defaultMaxBodySize)
var mirrorRequest bool
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -131,3 +135,121 @@ func TestFlush(t *testing.T) {
pool.Stop()
assert.Equal(t, true, mirrorRequest)
}
func TestMirroringWithBody(t *testing.T) {
const numMirrors = 10
var (
countMirror int32
body = []byte(`body`)
)
pool := safe.NewPool(context.Background())
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assert.NotNil(t, r.Body)
bb, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, body, bb)
rw.WriteHeader(http.StatusOK)
})
mirror := New(handler, pool, defaultMaxBodySize)
for i := 0; i < numMirrors; i++ {
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assert.NotNil(t, r.Body)
bb, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, body, bb)
atomic.AddInt32(&countMirror, 1)
}), 100)
assert.NoError(t, err)
}
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(body))
mirror.ServeHTTP(httptest.NewRecorder(), req)
pool.Stop()
val := atomic.LoadInt32(&countMirror)
assert.Equal(t, numMirrors, int(val))
}
func TestCloneRequest(t *testing.T) {
t.Run("http request body is nil", func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "/", nil)
assert.NoError(t, err)
ctx := req.Context()
rr, _, err := newReusableRequest(req, defaultMaxBodySize)
assert.NoError(t, err)
// first call
cloned := rr.clone(ctx)
assert.Equal(t, cloned, req)
assert.Nil(t, cloned.Body)
// second call
cloned = rr.clone(ctx)
assert.Equal(t, cloned, req)
assert.Nil(t, cloned.Body)
})
t.Run("http request body is not nil", func(t *testing.T) {
bb := []byte(`¯\_(ツ)_/¯`)
contentLength := len(bb)
buf := bytes.NewBuffer(bb)
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
ctx := req.Context()
req.ContentLength = int64(contentLength)
rr, _, err := newReusableRequest(req, defaultMaxBodySize)
assert.NoError(t, err)
// first call
cloned := rr.clone(ctx)
body, err := ioutil.ReadAll(cloned.Body)
assert.NoError(t, err)
assert.Equal(t, bb, body)
// second call
cloned = rr.clone(ctx)
body, err = ioutil.ReadAll(cloned.Body)
assert.NoError(t, err)
assert.Equal(t, bb, body)
})
t.Run("failed case", func(t *testing.T) {
bb := []byte(`1234567890`)
buf := bytes.NewBuffer(bb)
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
_, expectedBytes, err := newReusableRequest(req, 2)
assert.Error(t, err)
assert.Equal(t, bb[:3], expectedBytes)
})
t.Run("valid case with maxBodySize", func(t *testing.T) {
bb := []byte(`1234567890`)
buf := bytes.NewBuffer(bb)
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
_, expectedBytes, err := newReusableRequest(req, 20)
assert.NoError(t, err)
assert.Nil(t, expectedBytes)
})
t.Run("no request given", func(t *testing.T) {
_, _, err := newReusableRequest(nil, defaultMaxBodySize)
assert.Error(t, err)
})
}

View file

@ -33,6 +33,8 @@ const (
defaultHealthCheckTimeout = 5 * time.Second
)
const defaultMaxBodySize int64 = -1
// NewManager creates a new Manager
func NewManager(configs map[string]*runtime.ServiceInfo, defaultRoundTripper http.RoundTripper, metricsRegistry metrics.Registry, routinePool *safe.Pool) *Manager {
return &Manager{
@ -123,7 +125,11 @@ func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.M
return nil, err
}
handler := mirror.New(serviceHandler, m.routinePool)
maxBodySize := defaultMaxBodySize
if config.MaxBodySize != nil {
maxBodySize = *config.MaxBodySize
}
handler := mirror.New(serviceHandler, m.routinePool, maxBodySize)
for _, mirrorConfig := range config.Mirrors {
mirrorHandler, err := m.BuildHTTP(ctx, mirrorConfig.Name, responseModifier)
if err != nil {