Add TCP Middlewares support
This commit is contained in:
parent
679def0151
commit
fc9f41b955
134 changed files with 5865 additions and 1852 deletions
86
pkg/tcp/chain.go
Normal file
86
pkg/tcp/chain.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Constructor A constructor for a piece of TCP middleware.
|
||||
// Some TCP middleware use this constructor out of the box,
|
||||
// so in most cases you can just pass somepackage.New.
|
||||
type Constructor func(Handler) (Handler, error)
|
||||
|
||||
// Chain is a chain for TCP handlers.
|
||||
// Chain acts as a list of tcp.Handler constructors.
|
||||
// Chain is effectively immutable:
|
||||
// once created, it will always hold
|
||||
// the same set of constructors in the same order.
|
||||
type Chain struct {
|
||||
constructors []Constructor
|
||||
}
|
||||
|
||||
// NewChain creates a new TCP chain,
|
||||
// memorizing the given list of TCP middleware constructors.
|
||||
// New serves no other function,
|
||||
// constructors are only called upon a call to Then().
|
||||
func NewChain(constructors ...Constructor) Chain {
|
||||
return Chain{constructors: constructors}
|
||||
}
|
||||
|
||||
// Then adds an handler at the end of the chain.
|
||||
func (c Chain) Then(h Handler) (Handler, error) {
|
||||
if h == nil {
|
||||
return nil, fmt.Errorf("cannot add a nil handler to the chain")
|
||||
}
|
||||
|
||||
for i := range c.constructors {
|
||||
handler, err := c.constructors[len(c.constructors)-1-i](h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
h = handler
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Append extends a chain, adding the specified constructors
|
||||
// as the last ones in the request flow.
|
||||
//
|
||||
// Append returns a new chain, leaving the original one untouched.
|
||||
//
|
||||
// stdChain := tcp.NewChain(m1, m2)
|
||||
// extChain := stdChain.Append(m3, m4)
|
||||
// // requests in stdChain go m1 -> m2
|
||||
// // requests in extChain go m1 -> m2 -> m3 -> m4
|
||||
func (c Chain) Append(constructors ...Constructor) Chain {
|
||||
newCons := make([]Constructor, 0, len(c.constructors)+len(constructors))
|
||||
newCons = append(newCons, c.constructors...)
|
||||
newCons = append(newCons, constructors...)
|
||||
|
||||
return Chain{newCons}
|
||||
}
|
||||
|
||||
// Extend extends a chain by adding the specified chain
|
||||
// as the last one in the request flow.
|
||||
//
|
||||
// Extend returns a new chain, leaving the original one untouched.
|
||||
//
|
||||
// stdChain := tcp.NewChain(m1, m2)
|
||||
// ext1Chain := tcp.NewChain(m3, m4)
|
||||
// ext2Chain := stdChain.Extend(ext1Chain)
|
||||
// // requests in stdChain go m1 -> m2
|
||||
// // requests in ext1Chain go m3 -> m4
|
||||
// // requests in ext2Chain go m1 -> m2 -> m3 -> m4
|
||||
//
|
||||
// Another example:
|
||||
// aHtmlAfterNosurf := tcp.NewChain(m2)
|
||||
// aHtml := tcp.NewChain(m1, func(h tcp.Handler) tcp.Handler {
|
||||
// csrf := nosurf.New(h)
|
||||
// csrf.SetFailureHandler(aHtmlAfterNosurf.ThenFunc(csrfFail))
|
||||
// return csrf
|
||||
// }).Extend(aHtmlAfterNosurf)
|
||||
// // requests to aHtml hitting nosurfs success handler go m1 -> nosurf -> m2 -> target-handler
|
||||
// // requests to aHtml hitting nosurfs failure handler go m1 -> nosurf -> m2 -> csrfFail
|
||||
func (c Chain) Extend(chain Chain) Chain {
|
||||
return c.Append(chain.constructors...)
|
||||
}
|
176
pkg/tcp/chain_test.go
Normal file
176
pkg/tcp/chain_test.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type HandlerTCPFunc func(WriteCloser)
|
||||
|
||||
// ServeTCP calls f(conn).
|
||||
func (f HandlerTCPFunc) ServeTCP(conn WriteCloser) {
|
||||
f(conn)
|
||||
}
|
||||
|
||||
// A constructor for middleware
|
||||
// that writes its own "tag" into the Conn and does nothing else.
|
||||
// Useful in checking if a chain is behaving in the right order.
|
||||
func tagMiddleware(tag string) Constructor {
|
||||
return func(h Handler) (Handler, error) {
|
||||
return HandlerTCPFunc(func(conn WriteCloser) {
|
||||
_, err := conn.Write([]byte(tag))
|
||||
if err != nil {
|
||||
panic("Unexpected")
|
||||
}
|
||||
h.ServeTCP(conn)
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
|
||||
var testApp = HandlerTCPFunc(func(conn WriteCloser) {
|
||||
_, err := conn.Write([]byte("app\n"))
|
||||
if err != nil {
|
||||
panic("unexpected")
|
||||
}
|
||||
})
|
||||
|
||||
type myWriter struct {
|
||||
data []byte
|
||||
}
|
||||
|
||||
func (mw *myWriter) Close() error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (mw *myWriter) LocalAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (mw *myWriter) RemoteAddr() net.Addr {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (mw *myWriter) SetDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (mw *myWriter) SetReadDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (mw *myWriter) SetWriteDeadline(t time.Time) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (mw *myWriter) Read(b []byte) (n int, err error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (mw *myWriter) Write(b []byte) (n int, err error) {
|
||||
mw.data = append(mw.data, b...)
|
||||
return len(mw.data), nil
|
||||
}
|
||||
|
||||
func (mw *myWriter) CloseWrite() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestNewChain(t *testing.T) {
|
||||
c1 := func(h Handler) (Handler, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
c2 := func(h Handler) (Handler, error) {
|
||||
return h, nil
|
||||
}
|
||||
|
||||
slice := []Constructor{c1, c2}
|
||||
|
||||
chain := NewChain(slice...)
|
||||
for k := range slice {
|
||||
assert.ObjectsAreEqual(chain.constructors[k], slice[k])
|
||||
}
|
||||
}
|
||||
|
||||
func TestThenWorksWithNoMiddleware(t *testing.T) {
|
||||
handler, err := NewChain().Then(testApp)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.ObjectsAreEqual(handler, testApp)
|
||||
}
|
||||
|
||||
func TestThenTreatsNilAsError(t *testing.T) {
|
||||
handler, err := NewChain().Then(nil)
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, handler)
|
||||
}
|
||||
|
||||
func TestThenOrdersHandlersCorrectly(t *testing.T) {
|
||||
t1 := tagMiddleware("t1\n")
|
||||
t2 := tagMiddleware("t2\n")
|
||||
t3 := tagMiddleware("t3\n")
|
||||
|
||||
chained, err := NewChain(t1, t2, t3).Then(testApp)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn := &myWriter{}
|
||||
chained.ServeTCP(conn)
|
||||
|
||||
assert.Equal(t, "t1\nt2\nt3\napp\n", string(conn.data))
|
||||
}
|
||||
|
||||
func TestAppendAddsHandlersCorrectly(t *testing.T) {
|
||||
chain := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||
newChain := chain.Append(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||
|
||||
assert.Len(t, chain.constructors, 2)
|
||||
assert.Len(t, newChain.constructors, 4)
|
||||
|
||||
chained, err := newChain.Then(testApp)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn := &myWriter{}
|
||||
chained.ServeTCP(conn)
|
||||
|
||||
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", string(conn.data))
|
||||
}
|
||||
|
||||
func TestAppendRespectsImmutability(t *testing.T) {
|
||||
chain := NewChain(tagMiddleware(""))
|
||||
newChain := chain.Append(tagMiddleware(""))
|
||||
|
||||
if &chain.constructors[0] == &newChain.constructors[0] {
|
||||
t.Error("Apppend does not respect immutability")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtendAddsHandlersCorrectly(t *testing.T) {
|
||||
chain1 := NewChain(tagMiddleware("t1\n"), tagMiddleware("t2\n"))
|
||||
chain2 := NewChain(tagMiddleware("t3\n"), tagMiddleware("t4\n"))
|
||||
newChain := chain1.Extend(chain2)
|
||||
|
||||
assert.Len(t, chain1.constructors, 2)
|
||||
assert.Len(t, chain2.constructors, 2)
|
||||
assert.Len(t, newChain.constructors, 4)
|
||||
|
||||
chained, err := newChain.Then(testApp)
|
||||
require.NoError(t, err)
|
||||
|
||||
conn := &myWriter{}
|
||||
chained.ServeTCP(conn)
|
||||
|
||||
assert.Equal(t, "t1\nt2\nt3\nt4\napp\n", string(conn.data))
|
||||
}
|
||||
|
||||
func TestExtendRespectsImmutability(t *testing.T) {
|
||||
chain := NewChain(tagMiddleware(""))
|
||||
newChain := chain.Extend(NewChain(tagMiddleware("")))
|
||||
|
||||
if &chain.constructors[0] == &newChain.constructors[0] {
|
||||
t.Error("Extend does not respect immutability")
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue