1
0
Fork 0

Support mirroring request body

Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
Dmytro Tananayskiy 2020-03-05 18:03:08 +01:00 committed by Traefiker Bot
parent 09c07f45ee
commit cf7f0f878a
20 changed files with 454 additions and 44 deletions

View file

@ -1,7 +1,9 @@
package mirror
import (
"bytes"
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
"sync/atomic"
@ -11,13 +13,15 @@ import (
"github.com/stretchr/testify/assert"
)
const defaultMaxBodySize int64 = -1
func TestMirroringOn100(t *testing.T) {
var countMirror1, countMirror2 int32
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
mirror := New(handler, pool, defaultMaxBodySize)
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
atomic.AddInt32(&countMirror1, 1)
}), 10)
@ -46,7 +50,7 @@ func TestMirroringOn10(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
mirror := New(handler, pool, defaultMaxBodySize)
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
atomic.AddInt32(&countMirror1, 1)
}), 10)
@ -70,7 +74,7 @@ func TestMirroringOn10(t *testing.T) {
}
func TestInvalidPercent(t *testing.T) {
mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background()))
mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background()), defaultMaxBodySize)
err := mirror.AddMirror(nil, -1)
assert.Error(t, err)
@ -89,7 +93,7 @@ func TestHijack(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
mirror := New(handler, pool, defaultMaxBodySize)
var mirrorRequest bool
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -113,7 +117,7 @@ func TestFlush(t *testing.T) {
rw.WriteHeader(http.StatusOK)
})
pool := safe.NewPool(context.Background())
mirror := New(handler, pool)
mirror := New(handler, pool, defaultMaxBodySize)
var mirrorRequest bool
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
@ -131,3 +135,121 @@ func TestFlush(t *testing.T) {
pool.Stop()
assert.Equal(t, true, mirrorRequest)
}
func TestMirroringWithBody(t *testing.T) {
const numMirrors = 10
var (
countMirror int32
body = []byte(`body`)
)
pool := safe.NewPool(context.Background())
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assert.NotNil(t, r.Body)
bb, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, body, bb)
rw.WriteHeader(http.StatusOK)
})
mirror := New(handler, pool, defaultMaxBodySize)
for i := 0; i < numMirrors; i++ {
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
assert.NotNil(t, r.Body)
bb, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, body, bb)
atomic.AddInt32(&countMirror, 1)
}), 100)
assert.NoError(t, err)
}
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(body))
mirror.ServeHTTP(httptest.NewRecorder(), req)
pool.Stop()
val := atomic.LoadInt32(&countMirror)
assert.Equal(t, numMirrors, int(val))
}
func TestCloneRequest(t *testing.T) {
t.Run("http request body is nil", func(t *testing.T) {
req, err := http.NewRequest(http.MethodPost, "/", nil)
assert.NoError(t, err)
ctx := req.Context()
rr, _, err := newReusableRequest(req, defaultMaxBodySize)
assert.NoError(t, err)
// first call
cloned := rr.clone(ctx)
assert.Equal(t, cloned, req)
assert.Nil(t, cloned.Body)
// second call
cloned = rr.clone(ctx)
assert.Equal(t, cloned, req)
assert.Nil(t, cloned.Body)
})
t.Run("http request body is not nil", func(t *testing.T) {
bb := []byte(`¯\_(ツ)_/¯`)
contentLength := len(bb)
buf := bytes.NewBuffer(bb)
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
ctx := req.Context()
req.ContentLength = int64(contentLength)
rr, _, err := newReusableRequest(req, defaultMaxBodySize)
assert.NoError(t, err)
// first call
cloned := rr.clone(ctx)
body, err := ioutil.ReadAll(cloned.Body)
assert.NoError(t, err)
assert.Equal(t, bb, body)
// second call
cloned = rr.clone(ctx)
body, err = ioutil.ReadAll(cloned.Body)
assert.NoError(t, err)
assert.Equal(t, bb, body)
})
t.Run("failed case", func(t *testing.T) {
bb := []byte(`1234567890`)
buf := bytes.NewBuffer(bb)
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
_, expectedBytes, err := newReusableRequest(req, 2)
assert.Error(t, err)
assert.Equal(t, bb[:3], expectedBytes)
})
t.Run("valid case with maxBodySize", func(t *testing.T) {
bb := []byte(`1234567890`)
buf := bytes.NewBuffer(bb)
req, err := http.NewRequest(http.MethodPost, "/", buf)
assert.NoError(t, err)
_, expectedBytes, err := newReusableRequest(req, 20)
assert.NoError(t, err)
assert.Nil(t, expectedBytes)
})
t.Run("no request given", func(t *testing.T) {
_, _, err := newReusableRequest(nil, defaultMaxBodySize)
assert.Error(t, err)
})
}