Support mirroring request body
Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com> Co-authored-by: Julien Salleyron <julien.salleyron@gmail.com>
This commit is contained in:
parent
09c07f45ee
commit
cf7f0f878a
20 changed files with 454 additions and 44 deletions
|
@ -1,7 +1,9 @@
|
|||
package mirror
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
|
@ -11,13 +13,15 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const defaultMaxBodySize int64 = -1
|
||||
|
||||
func TestMirroringOn100(t *testing.T) {
|
||||
var countMirror1, countMirror2 int32
|
||||
handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
pool := safe.NewPool(context.Background())
|
||||
mirror := New(handler, pool)
|
||||
mirror := New(handler, pool, defaultMaxBodySize)
|
||||
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
atomic.AddInt32(&countMirror1, 1)
|
||||
}), 10)
|
||||
|
@ -46,7 +50,7 @@ func TestMirroringOn10(t *testing.T) {
|
|||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
pool := safe.NewPool(context.Background())
|
||||
mirror := New(handler, pool)
|
||||
mirror := New(handler, pool, defaultMaxBodySize)
|
||||
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
atomic.AddInt32(&countMirror1, 1)
|
||||
}), 10)
|
||||
|
@ -70,7 +74,7 @@ func TestMirroringOn10(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInvalidPercent(t *testing.T) {
|
||||
mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background()))
|
||||
mirror := New(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {}), safe.NewPool(context.Background()), defaultMaxBodySize)
|
||||
err := mirror.AddMirror(nil, -1)
|
||||
assert.Error(t, err)
|
||||
|
||||
|
@ -89,7 +93,7 @@ func TestHijack(t *testing.T) {
|
|||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
pool := safe.NewPool(context.Background())
|
||||
mirror := New(handler, pool)
|
||||
mirror := New(handler, pool, defaultMaxBodySize)
|
||||
|
||||
var mirrorRequest bool
|
||||
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -113,7 +117,7 @@ func TestFlush(t *testing.T) {
|
|||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
pool := safe.NewPool(context.Background())
|
||||
mirror := New(handler, pool)
|
||||
mirror := New(handler, pool, defaultMaxBodySize)
|
||||
|
||||
var mirrorRequest bool
|
||||
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
|
@ -131,3 +135,121 @@ func TestFlush(t *testing.T) {
|
|||
pool.Stop()
|
||||
assert.Equal(t, true, mirrorRequest)
|
||||
}
|
||||
|
||||
func TestMirroringWithBody(t *testing.T) {
|
||||
const numMirrors = 10
|
||||
|
||||
var (
|
||||
countMirror int32
|
||||
body = []byte(`body`)
|
||||
)
|
||||
|
||||
pool := safe.NewPool(context.Background())
|
||||
|
||||
handler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
assert.NotNil(t, r.Body)
|
||||
bb, err := ioutil.ReadAll(r.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, bb)
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
mirror := New(handler, pool, defaultMaxBodySize)
|
||||
|
||||
for i := 0; i < numMirrors; i++ {
|
||||
err := mirror.AddMirror(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
assert.NotNil(t, r.Body)
|
||||
bb, err := ioutil.ReadAll(r.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, body, bb)
|
||||
atomic.AddInt32(&countMirror, 1)
|
||||
}), 100)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/", bytes.NewBuffer(body))
|
||||
|
||||
mirror.ServeHTTP(httptest.NewRecorder(), req)
|
||||
|
||||
pool.Stop()
|
||||
|
||||
val := atomic.LoadInt32(&countMirror)
|
||||
assert.Equal(t, numMirrors, int(val))
|
||||
}
|
||||
|
||||
func TestCloneRequest(t *testing.T) {
|
||||
t.Run("http request body is nil", func(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "/", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx := req.Context()
|
||||
rr, _, err := newReusableRequest(req, defaultMaxBodySize)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// first call
|
||||
cloned := rr.clone(ctx)
|
||||
assert.Equal(t, cloned, req)
|
||||
assert.Nil(t, cloned.Body)
|
||||
|
||||
// second call
|
||||
cloned = rr.clone(ctx)
|
||||
assert.Equal(t, cloned, req)
|
||||
assert.Nil(t, cloned.Body)
|
||||
})
|
||||
|
||||
t.Run("http request body is not nil", func(t *testing.T) {
|
||||
bb := []byte(`¯\_(ツ)_/¯`)
|
||||
contentLength := len(bb)
|
||||
|
||||
buf := bytes.NewBuffer(bb)
|
||||
req, err := http.NewRequest(http.MethodPost, "/", buf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
ctx := req.Context()
|
||||
req.ContentLength = int64(contentLength)
|
||||
|
||||
rr, _, err := newReusableRequest(req, defaultMaxBodySize)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// first call
|
||||
cloned := rr.clone(ctx)
|
||||
body, err := ioutil.ReadAll(cloned.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bb, body)
|
||||
|
||||
// second call
|
||||
cloned = rr.clone(ctx)
|
||||
body, err = ioutil.ReadAll(cloned.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, bb, body)
|
||||
})
|
||||
|
||||
t.Run("failed case", func(t *testing.T) {
|
||||
bb := []byte(`1234567890`)
|
||||
buf := bytes.NewBuffer(bb)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "/", buf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, expectedBytes, err := newReusableRequest(req, 2)
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, bb[:3], expectedBytes)
|
||||
})
|
||||
|
||||
t.Run("valid case with maxBodySize", func(t *testing.T) {
|
||||
bb := []byte(`1234567890`)
|
||||
buf := bytes.NewBuffer(bb)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "/", buf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, expectedBytes, err := newReusableRequest(req, 20)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, expectedBytes)
|
||||
})
|
||||
|
||||
t.Run("no request given", func(t *testing.T) {
|
||||
_, _, err := newReusableRequest(nil, defaultMaxBodySize)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue