1
0
Fork 0

Move code to pkg

This commit is contained in:
Ludovic Fernandez 2019-03-15 09:42:03 +01:00 committed by Traefiker Bot
parent bd4c822670
commit f1b085fa36
465 changed files with 656 additions and 680 deletions

View file

@ -0,0 +1,27 @@
package service
import "sync"
const bufferPoolSize = 32 * 1024
func newBufferPool() *bufferPool {
return &bufferPool{
pool: sync.Pool{
New: func() interface{} {
return make([]byte, bufferPoolSize)
},
},
}
}
type bufferPool struct {
pool sync.Pool
}
func (b *bufferPool) Get() []byte {
return b.pool.Get().([]byte)
}
func (b *bufferPool) Put(bytes []byte) {
b.pool.Put(bytes)
}

100
pkg/server/service/proxy.go Normal file
View file

@ -0,0 +1,100 @@
package service
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/containous/flaeg/parse"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
)
// StatusClientClosedRequest non-standard HTTP status code for client disconnection
const StatusClientClosedRequest = 499
// StatusClientClosedRequestText non-standard HTTP status for client disconnection
const StatusClientClosedRequestText = "Client Closed Request"
func buildProxy(passHostHeader bool, responseForwarding *config.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool, responseModifier func(*http.Response) error) (http.Handler, error) {
var flushInterval parse.Duration
if responseForwarding != nil {
err := flushInterval.Set(responseForwarding.FlushInterval)
if err != nil {
return nil, fmt.Errorf("error creating flush interval: %v", err)
}
}
if flushInterval == 0 {
flushInterval = parse.Duration(100 * time.Millisecond)
}
proxy := &httputil.ReverseProxy{
Director: func(outReq *http.Request) {
u := outReq.URL
if outReq.RequestURI != "" {
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
if err == nil {
u = parsedURL
}
}
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
outReq.URL.RawQuery = u.RawQuery
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.Proto = "HTTP/1.1"
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
// Do not pass client Host header unless optsetter PassHostHeader is set.
if !passHostHeader {
outReq.Host = outReq.URL.Host
}
},
Transport: defaultRoundTripper,
FlushInterval: time.Duration(flushInterval),
ModifyResponse: responseModifier,
BufferPool: bufferPool,
ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) {
statusCode := http.StatusInternalServerError
switch {
case err == io.EOF:
statusCode = http.StatusBadGateway
case err == context.Canceled:
statusCode = StatusClientClosedRequest
default:
if e, ok := err.(net.Error); ok {
if e.Timeout() {
statusCode = http.StatusGatewayTimeout
} else {
statusCode = http.StatusBadGateway
}
}
}
log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
w.WriteHeader(statusCode)
_, werr := w.Write([]byte(statusText(statusCode)))
if werr != nil {
log.Debugf("Error while writing status code", werr)
}
},
}
return proxy, nil
}
func statusText(statusCode int) string {
if statusCode == StatusClientClosedRequest {
return StatusClientClosedRequestText
}
return http.StatusText(statusCode)
}

View file

@ -0,0 +1,37 @@
package service
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/containous/traefik/pkg/testhelpers"
)
type staticTransport struct {
res *http.Response
}
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
return t.res, nil
}
func BenchmarkProxy(b *testing.B) {
res := &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("")),
}
w := httptest.NewRecorder()
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
pool := newBufferPool()
handler, _ := buildProxy(false, nil, &staticTransport{res}, pool, nil)
b.ReportAllocs()
for i := 0; i < b.N; i++ {
handler.ServeHTTP(w, req)
}
}

View file

@ -0,0 +1,713 @@
package service
import (
"bufio"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
gorillawebsocket "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"
)
func TestWebSocketTCPClose(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
errChan := make(chan error, 1)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
_, _, err := c.ReadMessage()
if err != nil {
errChan <- err
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
proxyAddr := proxy.Listener.Addr().String()
_, conn, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
).open()
require.NoError(t, err)
conn.Close()
serverErr := <-errChan
wsErr, ok := serverErr.(*gorillawebsocket.CloseError)
assert.Equal(t, true, ok)
assert.Equal(t, 1006, wsErr.Code)
}
func TestWebSocketPingPong(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
var upgrader = gorillawebsocket.Upgrader{
HandshakeTimeout: 10 * time.Second,
CheckOrigin: func(*http.Request) bool {
return true
},
}
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
ws, err := upgrader.Upgrade(writer, request, nil)
require.NoError(t, err)
ws.SetPingHandler(func(appData string) error {
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
require.NoError(t, err)
return nil
})
_, _, _ = ws.ReadMessage()
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
defer conn.Close()
goodErr := fmt.Errorf("signal: %s", "Good data")
badErr := fmt.Errorf("signal: %s", "Bad data")
conn.SetPongHandler(func(data string) error {
if data == "PingPong" {
return goodErr
}
return badErr
})
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
require.NoError(t, err)
_, _, err = conn.ReadMessage()
if err != goodErr {
require.NoError(t, err)
}
}
func TestWebSocketEcho(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
msg := make([]byte, 4)
_, err := conn.Read(msg)
require.NoError(t, err)
fmt.Println(string(msg))
_, err = conn.Write(msg)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
fmt.Println(conn.ReadMessage())
err = conn.Close()
require.NoError(t, err)
}
func TestWebSocketPassHost(t *testing.T) {
testCases := []struct {
desc string
passHost bool
expected string
}{
{
desc: "PassHost false",
passHost: false,
},
{
desc: "PassHost true",
passHost: true,
expected: "example.com",
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
f, err := buildProxy(test.passHost, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
req := conn.Request()
if test.passHost {
require.Equal(t, test.expected, req.Host)
} else {
require.NotEqual(t, test.expected, req.Host)
}
msg := make([]byte, 4)
_, err = conn.Read(msg)
require.NoError(t, err)
fmt.Println(string(msg))
_, err = conn.Write(msg)
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
headers.Add("Host", "example.com")
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", resp)
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
require.NoError(t, err)
fmt.Println(conn.ReadMessage())
err = conn.Close()
require.NoError(t, err)
})
}
}
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
return true
}}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
withOrigin("http://127.0.0.2"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithOrigin(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
break
}
err = c.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
_, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
withOrigin("http://127.0.0.2"),
).send()
require.EqualError(t, err, "bad status")
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithQueryParams(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "test", r.URL.Query().Get("query"))
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws?query=test"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
conn.Close()
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req.URL = parseURI(t, srv.URL)
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
f.ServeHTTP(w, req)
}))
defer proxy.Close()
serverAddr := proxy.Listener.Addr().String()
headers := http.Header{}
webSocketURL := "ws://" + serverAddr + "/ws"
headers.Add("Origin", webSocketURL)
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
defer conn.Close()
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
}
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/%3A%2F%2F"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func TestWebSocketUpgradeFailed(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(400)
})
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
if path == "/ws" {
// Set new backend URL
req.URL = parseURI(t, srv.URL)
req.URL.Path = path
f.ServeHTTP(w, req)
} else {
w.WriteHeader(200)
}
}))
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
require.NoError(t, err)
defer conn.Close()
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
require.NoError(t, err)
req.Header.Add("upgrade", "websocket")
req.Header.Add("Connection", "upgrade")
err = req.Write(conn)
require.NoError(t, err)
// First request works with 400
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, req)
require.NoError(t, err)
assert.Equal(t, 400, resp.StatusCode)
}
func TestForwardsWebsocketTraffic(t *testing.T) {
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
mux := http.NewServeMux()
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
_, err := conn.Write([]byte("ok"))
require.NoError(t, err)
err = conn.Close()
require.NoError(t, err)
}))
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
mux.ServeHTTP(w, req)
}))
defer srv.Close()
proxy := createProxyWithForwarder(t, f, srv.URL)
defer proxy.Close()
proxyAddr := proxy.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("echo"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
func TestWebSocketTransferTLSConfig(t *testing.T) {
srv := createTLSWebsocketServer()
defer srv.Close()
forwarderWithoutTLSConfig, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL)
defer proxyWithoutTLSConfig.Close()
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
_, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.EqualError(t, err, "bad status")
transport := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
forwarderWithTLSConfig, err := buildProxy(true, nil, transport, nil, nil)
require.NoError(t, err)
proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
resp, err := newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
require.NoError(t, err)
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL)
defer proxyWithTLSConfig.Close()
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
resp, err = newWebsocketRequest(
withServer(proxyAddr),
withPath("/ws"),
withData("ok"),
).send()
require.NoError(t, err)
assert.Equal(t, "ok", resp)
}
const dialTimeout = time.Second
type websocketRequestOpt func(w *websocketRequest)
func withServer(server string) websocketRequestOpt {
return func(w *websocketRequest) {
w.ServerAddr = server
}
}
func withPath(path string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Path = path
}
}
func withData(data string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Data = data
}
}
func withOrigin(origin string) websocketRequestOpt {
return func(w *websocketRequest) {
w.Origin = origin
}
}
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
wsrequest := &websocketRequest{}
for _, opt := range opts {
opt(wsrequest)
}
if wsrequest.Origin == "" {
wsrequest.Origin = "http://" + wsrequest.ServerAddr
}
if wsrequest.Config == nil {
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
}
return wsrequest
}
type websocketRequest struct {
ServerAddr string
Path string
Data string
Origin string
Config *websocket.Config
}
func (w *websocketRequest) send() (string, error) {
conn, _, err := w.open()
if err != nil {
return "", err
}
defer conn.Close()
if _, err := conn.Write([]byte(w.Data)); err != nil {
return "", err
}
var msg = make([]byte, 512)
var n int
n, err = conn.Read(msg)
if err != nil {
return "", err
}
received := string(msg[:n])
return received, nil
}
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
if err != nil {
return nil, nil, err
}
conn, err := websocket.NewClient(w.Config, client)
if err != nil {
return nil, nil, err
}
return conn, client, err
}
func parseURI(t *testing.T, uri string) *url.URL {
out, err := url.ParseRequestURI(uri)
require.NoError(t, err)
return out
}
func createProxyWithForwarder(t *testing.T, proxy http.Handler, url string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
path := req.URL.Path // keep the original path
// Set new backend URL
req.URL = parseURI(t, url)
req.URL.Path = path
proxy.ServeHTTP(w, req)
}))
}

View file

@ -0,0 +1,257 @@
package service
import (
"context"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/containous/alice"
"github.com/containous/traefik/old/middlewares/pipelining"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/healthcheck"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/middlewares/emptybackendhandler"
"github.com/containous/traefik/pkg/server/cookie"
"github.com/containous/traefik/pkg/server/internal"
"github.com/vulcand/oxy/roundrobin"
)
const (
defaultHealthCheckInterval = 30 * time.Second
defaultHealthCheckTimeout = 5 * time.Second
)
// NewManager creates a new Manager
func NewManager(configs map[string]*config.Service, defaultRoundTripper http.RoundTripper) *Manager {
return &Manager{
bufferPool: newBufferPool(),
defaultRoundTripper: defaultRoundTripper,
balancers: make(map[string][]healthcheck.BalancerHandler),
configs: configs,
}
}
// Manager The service manager
type Manager struct {
bufferPool httputil.BufferPool
defaultRoundTripper http.RoundTripper
balancers map[string][]healthcheck.BalancerHandler
configs map[string]*config.Service
}
// BuildHTTP Creates a http.Handler for a service configuration.
func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
serviceName = internal.GetQualifiedName(ctx, serviceName)
ctx = internal.AddProviderInContext(ctx, serviceName)
if conf, ok := m.configs[serviceName]; ok {
// TODO Should handle multiple service types
// FIXME Check if the service is declared multiple times with different types
if conf.LoadBalancer != nil {
return m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier)
}
return nil, fmt.Errorf("the service %q doesn't have any load balancer", serviceName)
}
return nil, fmt.Errorf("the service %q does not exits", serviceName)
}
func (m *Manager) getLoadBalancerServiceHandler(
ctx context.Context,
serviceName string,
service *config.LoadBalancerService,
responseModifier func(*http.Response) error,
) (http.Handler, error) {
fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool, responseModifier)
if err != nil {
return nil, err
}
alHandler := func(next http.Handler) (http.Handler, error) {
return accesslog.NewFieldHandler(next, accesslog.ServiceName, serviceName, accesslog.AddServiceFields), nil
}
handler, err := alice.New().Append(alHandler).Then(pipelining.NewPipelining(fwd))
if err != nil {
return nil, err
}
balancer, err := m.getLoadBalancer(ctx, serviceName, service, handler)
if err != nil {
return nil, err
}
// TODO rename and checks
m.balancers[serviceName] = append(m.balancers[serviceName], balancer)
// Empty (backend with no servers)
return emptybackendhandler.New(balancer), nil
}
// LaunchHealthCheck Launches the health checks.
func (m *Manager) LaunchHealthCheck() {
backendConfigs := make(map[string]*healthcheck.BackendConfig)
for serviceName, balancers := range m.balancers {
ctx := log.With(context.Background(), log.Str(log.ServiceName, serviceName))
// FIXME aggregate
balancer := balancers[0]
// FIXME Should all the services handle healthcheck? Handle different types
service := m.configs[serviceName].LoadBalancer
// Health Check
var backendHealthCheck *healthcheck.BackendConfig
if hcOpts := buildHealthCheckOptions(ctx, balancer, serviceName, service.HealthCheck); hcOpts != nil {
log.FromContext(ctx).Debugf("Setting up healthcheck for service %s with %s", serviceName, *hcOpts)
hcOpts.Transport = m.defaultRoundTripper
backendHealthCheck = healthcheck.NewBackendConfig(*hcOpts, serviceName)
}
if backendHealthCheck != nil {
backendConfigs[serviceName] = backendHealthCheck
}
}
// FIXME metrics and context
healthcheck.GetHealthCheck().SetBackendsConfiguration(context.TODO(), backendConfigs)
}
func buildHealthCheckOptions(ctx context.Context, lb healthcheck.BalancerHandler, backend string, hc *config.HealthCheck) *healthcheck.Options {
if hc == nil || hc.Path == "" {
return nil
}
logger := log.FromContext(ctx)
interval := defaultHealthCheckInterval
if hc.Interval != "" {
intervalOverride, err := time.ParseDuration(hc.Interval)
switch {
case err != nil:
logger.Errorf("Illegal health check interval for '%s': %s", backend, err)
case intervalOverride <= 0:
logger.Errorf("Health check interval smaller than zero for service '%s'", backend)
default:
interval = intervalOverride
}
}
timeout := defaultHealthCheckTimeout
if hc.Timeout != "" {
timeoutOverride, err := time.ParseDuration(hc.Timeout)
switch {
case err != nil:
logger.Errorf("Illegal health check timeout for backend '%s': %s", backend, err)
case timeoutOverride <= 0:
logger.Errorf("Health check timeout smaller than zero for backend '%s', backend", backend)
default:
timeout = timeoutOverride
}
}
if timeout >= interval {
logger.Warnf("Health check timeout for backend '%s' should be lower than the health check interval. Interval set to timeout + 1 second (%s).", backend)
}
return &healthcheck.Options{
Scheme: hc.Scheme,
Path: hc.Path,
Port: hc.Port,
Interval: interval,
Timeout: timeout,
LB: lb,
Hostname: hc.Hostname,
Headers: hc.Headers,
}
}
func (m *Manager) getLoadBalancer(ctx context.Context, serviceName string, service *config.LoadBalancerService, fwd http.Handler) (healthcheck.BalancerHandler, error) {
logger := log.FromContext(ctx)
var stickySession *roundrobin.StickySession
var cookieName string
if stickiness := service.Stickiness; stickiness != nil {
cookieName = cookie.GetName(stickiness.CookieName, serviceName)
stickySession = roundrobin.NewStickySession(cookieName)
}
var lb healthcheck.BalancerHandler
if service.Method == "drr" {
logger.Debug("Creating drr load-balancer")
rr, err := roundrobin.New(fwd)
if err != nil {
return nil, err
}
if stickySession != nil {
logger.Debugf("Sticky session cookie name: %v", cookieName)
lb, err = roundrobin.NewRebalancer(rr, roundrobin.RebalancerStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
lb, err = roundrobin.NewRebalancer(rr)
if err != nil {
return nil, err
}
}
} else {
if service.Method != "wrr" {
logger.Warnf("Invalid load-balancing method %q, fallback to 'wrr' method", service.Method)
}
logger.Debug("Creating wrr load-balancer")
if stickySession != nil {
logger.Debugf("Sticky session cookie name: %v", cookieName)
var err error
lb, err = roundrobin.New(fwd, roundrobin.EnableStickySession(stickySession))
if err != nil {
return nil, err
}
} else {
var err error
lb, err = roundrobin.New(fwd)
if err != nil {
return nil, err
}
}
}
if err := m.upsertServers(ctx, lb, service.Servers); err != nil {
return nil, fmt.Errorf("error configuring load balancer for service %s: %v", serviceName, err)
}
return lb, nil
}
func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHandler, servers []config.Server) error {
logger := log.FromContext(ctx)
for name, srv := range servers {
u, err := url.Parse(srv.URL)
if err != nil {
return fmt.Errorf("error parsing server URL %s: %v", srv.URL, err)
}
logger.WithField(log.ServerName, name).Debugf("Creating server %d at %s with weight %d", name, u, srv.Weight)
if err := lb.UpsertServer(u, roundrobin.Weight(srv.Weight)); err != nil {
return fmt.Errorf("error adding server %s to load balancer: %v", srv.URL, err)
}
// FIXME Handle Metrics
}
return nil
}

View file

@ -0,0 +1,329 @@
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/server/internal"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type MockForwarder struct{}
func (MockForwarder) ServeHTTP(http.ResponseWriter, *http.Request) {
panic("implement me")
}
func TestGetLoadBalancer(t *testing.T) {
sm := Manager{}
testCases := []struct {
desc string
serviceName string
service *config.LoadBalancerService
fwd http.Handler
expectError bool
}{
{
desc: "Fails when provided an invalid URL",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: ":",
Weight: 0,
},
},
},
fwd: &MockForwarder{},
expectError: true,
},
{
desc: "Succeeds when there are no servers",
serviceName: "test",
service: &config.LoadBalancerService{},
fwd: &MockForwarder{},
expectError: false,
},
{
desc: "Succeeds when stickiness is set",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
},
fwd: &MockForwarder{},
expectError: false,
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
handler, err := sm.getLoadBalancer(context.Background(), test.serviceName, test.service, test.fwd)
if test.expectError {
require.Error(t, err)
assert.Nil(t, handler)
} else {
require.NoError(t, err)
assert.NotNil(t, handler)
}
})
}
}
func TestGetLoadBalancerServiceHandler(t *testing.T) {
sm := NewManager(nil, http.DefaultTransport)
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "first")
}))
defer server1.Close()
server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "second")
}))
defer server2.Close()
serverPassHost := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "passhost")
assert.Equal(t, "callme", r.Host)
}))
defer serverPassHost.Close()
serverPassHostFalse := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-From", "passhostfalse")
assert.NotEqual(t, "callme", r.Host)
}))
defer serverPassHostFalse.Close()
type ExpectedResult struct {
StatusCode int
XFrom string
}
testCases := []struct {
desc string
serviceName string
service *config.LoadBalancerService
responseModifier func(*http.Response) error
expected []ExpectedResult
}{
{
desc: "Load balances between the two servers",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: server1.URL,
Weight: 50,
},
{
URL: server2.URL,
Weight: 50,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "first",
},
{
StatusCode: http.StatusOK,
XFrom: "second",
},
},
},
{
desc: "StatusBadGateway when the server is not reachable",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{
{
URL: "http://foo",
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusBadGateway,
},
},
},
{
desc: "ServiceUnavailable when no servers are available",
serviceName: "test",
service: &config.LoadBalancerService{
Servers: []config.Server{},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusServiceUnavailable,
},
},
},
{
desc: "Always call the same server when stickiness is true",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
Servers: []config.Server{
{
URL: server1.URL,
Weight: 1,
},
{
URL: server2.URL,
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "first",
},
{
StatusCode: http.StatusOK,
XFrom: "first",
},
},
},
{
desc: "PassHost passes the host instead of the IP",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
PassHostHeader: true,
Servers: []config.Server{
{
URL: serverPassHost.URL,
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "passhost",
},
},
},
{
desc: "PassHost doesn't passe the host instead of the IP",
serviceName: "test",
service: &config.LoadBalancerService{
Stickiness: &config.Stickiness{},
Servers: []config.Server{
{
URL: serverPassHostFalse.URL,
Weight: 1,
},
},
Method: "wrr",
},
expected: []ExpectedResult{
{
StatusCode: http.StatusOK,
XFrom: "passhostfalse",
},
},
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
handler, err := sm.getLoadBalancerServiceHandler(context.Background(), test.serviceName, test.service, test.responseModifier)
assert.NoError(t, err)
assert.NotNil(t, handler)
req := testhelpers.MustNewRequest(http.MethodGet, "http://callme", nil)
for _, expected := range test.expected {
recorder := httptest.NewRecorder()
handler.ServeHTTP(recorder, req)
assert.Equal(t, expected.StatusCode, recorder.Code)
assert.Equal(t, expected.XFrom, recorder.Header().Get("X-From"))
if len(recorder.Header().Get("Set-Cookie")) > 0 {
req.Header.Set("Cookie", recorder.Header().Get("Set-Cookie"))
}
}
})
}
}
func TestManager_Build(t *testing.T) {
testCases := []struct {
desc string
serviceName string
configs map[string]*config.Service
providerName string
}{
{
desc: "Simple service name",
serviceName: "serviceName",
configs: map[string]*config.Service{
"serviceName": {
LoadBalancer: &config.LoadBalancerService{Method: "wrr"},
},
},
},
{
desc: "Service name with provider",
serviceName: "provider-1.serviceName",
configs: map[string]*config.Service{
"provider-1.serviceName": {
LoadBalancer: &config.LoadBalancerService{Method: "wrr"},
},
},
},
{
desc: "Service name with provider in context",
serviceName: "serviceName",
configs: map[string]*config.Service{
"provider-1.serviceName": {
LoadBalancer: &config.LoadBalancerService{Method: "wrr"},
},
},
providerName: "provider-1",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
manager := NewManager(test.configs, http.DefaultTransport)
ctx := context.Background()
if len(test.providerName) > 0 {
ctx = internal.AddProviderInContext(ctx, test.providerName+".foobar")
}
_, err := manager.BuildHTTP(ctx, test.serviceName, nil)
require.NoError(t, err)
})
}
}
// FIXME Add healthcheck tests

View file

@ -0,0 +1,67 @@
package tcp
import (
"context"
"fmt"
"net"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/log"
"github.com/containous/traefik/pkg/server/internal"
"github.com/containous/traefik/pkg/tcp"
)
// Manager is the TCPHandlers factory
type Manager struct {
configs map[string]*config.TCPService
}
// NewManager creates a new manager
func NewManager(configs map[string]*config.TCPService) *Manager {
return &Manager{
configs: configs,
}
}
// BuildTCP Creates a tcp.Handler for a service configuration.
func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Handler, error) {
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
serviceName = internal.GetQualifiedName(ctx, serviceName)
ctx = internal.AddProviderInContext(ctx, serviceName)
if conf, ok := m.configs[serviceName]; ok {
// FIXME Check if the service is declared multiple times with different types
if conf.LoadBalancer != nil {
loadBalancer := tcp.NewRRLoadBalancer()
var handler tcp.Handler
for _, server := range conf.LoadBalancer.Servers {
_, err := parseIP(server.Address)
if err == nil {
handler, _ = tcp.NewProxy(server.Address)
loadBalancer.AddServer(handler)
} else {
log.FromContext(ctx).Errorf("Invalid IP address for a %s server %s: %v", serviceName, server.Address, err)
}
}
return loadBalancer, nil
}
return nil, fmt.Errorf("the service %q doesn't have any TCP load balancer", serviceName)
}
return nil, fmt.Errorf("the service %q does not exits", serviceName)
}
func parseIP(s string) (string, error) {
ip, _, err := net.SplitHostPort(s)
if err == nil {
return ip, nil
}
ipNoPort := net.ParseIP(s)
if ipNoPort == nil {
return "", fmt.Errorf("invalid IP Address %s", ipNoPort)
}
return ipNoPort.String(), nil
}