Add a new protocol
Co-authored-by: Gérald Croës <gerald@containo.us>
This commit is contained in:
parent
0ca2149408
commit
4a68d29ce2
231 changed files with 6895 additions and 4395 deletions
100
server/service/proxy.go
Normal file
100
server/service/proxy.go
Normal file
|
@ -0,0 +1,100 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/log"
|
||||
)
|
||||
|
||||
// StatusClientClosedRequest non-standard HTTP status code for client disconnection
|
||||
const StatusClientClosedRequest = 499
|
||||
|
||||
// StatusClientClosedRequestText non-standard HTTP status for client disconnection
|
||||
const StatusClientClosedRequestText = "Client Closed Request"
|
||||
|
||||
func buildProxy(passHostHeader bool, responseForwarding *config.ResponseForwarding, defaultRoundTripper http.RoundTripper, bufferPool httputil.BufferPool, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
var flushInterval parse.Duration
|
||||
if responseForwarding != nil {
|
||||
err := flushInterval.Set(responseForwarding.FlushInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating flush interval: %v", err)
|
||||
}
|
||||
}
|
||||
if flushInterval == 0 {
|
||||
flushInterval = parse.Duration(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Director: func(outReq *http.Request) {
|
||||
u := outReq.URL
|
||||
if outReq.RequestURI != "" {
|
||||
parsedURL, err := url.ParseRequestURI(outReq.RequestURI)
|
||||
if err == nil {
|
||||
u = parsedURL
|
||||
}
|
||||
}
|
||||
|
||||
outReq.URL.Path = u.Path
|
||||
outReq.URL.RawPath = u.RawPath
|
||||
outReq.URL.RawQuery = u.RawQuery
|
||||
outReq.RequestURI = "" // Outgoing request should not have RequestURI
|
||||
|
||||
outReq.Proto = "HTTP/1.1"
|
||||
outReq.ProtoMajor = 1
|
||||
outReq.ProtoMinor = 1
|
||||
|
||||
// Do not pass client Host header unless optsetter PassHostHeader is set.
|
||||
if !passHostHeader {
|
||||
outReq.Host = outReq.URL.Host
|
||||
}
|
||||
|
||||
},
|
||||
Transport: defaultRoundTripper,
|
||||
FlushInterval: time.Duration(flushInterval),
|
||||
ModifyResponse: responseModifier,
|
||||
BufferPool: bufferPool,
|
||||
ErrorHandler: func(w http.ResponseWriter, request *http.Request, err error) {
|
||||
statusCode := http.StatusInternalServerError
|
||||
|
||||
switch {
|
||||
case err == io.EOF:
|
||||
statusCode = http.StatusBadGateway
|
||||
case err == context.Canceled:
|
||||
statusCode = StatusClientClosedRequest
|
||||
default:
|
||||
if e, ok := err.(net.Error); ok {
|
||||
if e.Timeout() {
|
||||
statusCode = http.StatusGatewayTimeout
|
||||
} else {
|
||||
statusCode = http.StatusBadGateway
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("'%d %s' caused by: %v", statusCode, statusText(statusCode), err)
|
||||
w.WriteHeader(statusCode)
|
||||
_, werr := w.Write([]byte(statusText(statusCode)))
|
||||
if werr != nil {
|
||||
log.Debugf("Error while writing status code", werr)
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
func statusText(statusCode int) string {
|
||||
if statusCode == StatusClientClosedRequest {
|
||||
return StatusClientClosedRequestText
|
||||
}
|
||||
return http.StatusText(statusCode)
|
||||
}
|
37
server/service/proxy_test.go
Normal file
37
server/service/proxy_test.go
Normal file
|
@ -0,0 +1,37 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/containous/traefik/testhelpers"
|
||||
)
|
||||
|
||||
type staticTransport struct {
|
||||
res *http.Response
|
||||
}
|
||||
|
||||
func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return t.res, nil
|
||||
}
|
||||
|
||||
func BenchmarkProxy(b *testing.B) {
|
||||
res := &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader("")),
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/", nil)
|
||||
|
||||
pool := newBufferPool()
|
||||
handler, _ := buildProxy(false, nil, &staticTransport{res}, pool, nil)
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
handler.ServeHTTP(w, req)
|
||||
}
|
||||
}
|
713
server/service/proxy_websocket_test.go
Normal file
713
server/service/proxy_websocket_test.go
Normal file
|
@ -0,0 +1,713 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gorillawebsocket "github.com/gorilla/websocket"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
func TestWebSocketTCPClose(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
errChan := make(chan error, 1)
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
_, _, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, conn, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
).open()
|
||||
require.NoError(t, err)
|
||||
conn.Close()
|
||||
|
||||
serverErr := <-errChan
|
||||
|
||||
wsErr, ok := serverErr.(*gorillawebsocket.CloseError)
|
||||
assert.Equal(t, true, ok)
|
||||
assert.Equal(t, 1006, wsErr.Code)
|
||||
}
|
||||
|
||||
func TestWebSocketPingPong(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
var upgrader = gorillawebsocket.Upgrader{
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
CheckOrigin: func(*http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(writer http.ResponseWriter, request *http.Request) {
|
||||
ws, err := upgrader.Upgrade(writer, request, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
ws.SetPingHandler(func(appData string) error {
|
||||
err = ws.WriteMessage(gorillawebsocket.PongMessage, []byte(appData+"Pong"))
|
||||
require.NoError(t, err)
|
||||
return nil
|
||||
})
|
||||
|
||||
_, _, _ = ws.ReadMessage()
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
defer conn.Close()
|
||||
|
||||
goodErr := fmt.Errorf("signal: %s", "Good data")
|
||||
badErr := fmt.Errorf("signal: %s", "Bad data")
|
||||
conn.SetPongHandler(func(data string) error {
|
||||
if data == "PingPong" {
|
||||
return goodErr
|
||||
}
|
||||
return badErr
|
||||
})
|
||||
|
||||
err = conn.WriteControl(gorillawebsocket.PingMessage, []byte("Ping"), time.Now().Add(time.Second))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = conn.ReadMessage()
|
||||
|
||||
if err != goodErr {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketEcho(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
msg := make([]byte, 4)
|
||||
_, err := conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(string(msg))
|
||||
|
||||
_, err = conn.Write(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(conn.ReadMessage())
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
}
|
||||
|
||||
func TestWebSocketPassHost(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
passHost bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "PassHost false",
|
||||
passHost: false,
|
||||
},
|
||||
{
|
||||
desc: "PassHost true",
|
||||
passHost: true,
|
||||
expected: "example.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
f, err := buildProxy(test.passHost, nil, http.DefaultTransport, nil, nil)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
req := conn.Request()
|
||||
|
||||
if test.passHost {
|
||||
require.Equal(t, test.expected, req.Host)
|
||||
} else {
|
||||
require.NotEqual(t, test.expected, req.Host)
|
||||
}
|
||||
|
||||
msg := make([]byte, 4)
|
||||
_, err = conn.Read(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(string(msg))
|
||||
_, err = conn.Write(msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
}))
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
headers.Add("Host", "example.com")
|
||||
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", resp)
|
||||
|
||||
err = conn.WriteMessage(gorillawebsocket.TextMessage, []byte("OK"))
|
||||
require.NoError(t, err)
|
||||
|
||||
fmt.Println(conn.ReadMessage())
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketServerWithoutCheckOrigin(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
}}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithOrigin(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
_, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
withOrigin("http://127.0.0.2"),
|
||||
).send()
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithQueryParams(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
assert.Equal(t, "test", r.URL.Query().Get("query"))
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws?query=test"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithHeadersInResponseWriter(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
conn.Close()
|
||||
}))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
w.Header().Set("HEADER-KEY", "HEADER-VALUE")
|
||||
f.ServeHTTP(w, req)
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
serverAddr := proxy.Listener.Addr().String()
|
||||
|
||||
headers := http.Header{}
|
||||
webSocketURL := "ws://" + serverAddr + "/ws"
|
||||
headers.Add("Origin", webSocketURL)
|
||||
conn, resp, err := gorillawebsocket.DefaultDialer.Dial(webSocketURL, headers)
|
||||
require.NoError(t, err, "Error during Dial with response: %+v", err, resp)
|
||||
defer conn.Close()
|
||||
|
||||
assert.Equal(t, "HEADER-VALUE", resp.Header.Get("HEADER-KEY"))
|
||||
}
|
||||
|
||||
func TestWebSocketRequestWithEncodedChar(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
assert.Equal(t, "/%3A%2F%2F", r.URL.EscapedPath())
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/%3A%2F%2F"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func TestWebSocketUpgradeFailed(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/ws", func(w http.ResponseWriter, req *http.Request) {
|
||||
w.WriteHeader(400)
|
||||
})
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
|
||||
if path == "/ws" {
|
||||
// Set new backend URL
|
||||
req.URL = parseURI(t, srv.URL)
|
||||
req.URL.Path = path
|
||||
f.ServeHTTP(w, req)
|
||||
} else {
|
||||
w.WriteHeader(200)
|
||||
}
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, dialTimeout)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer conn.Close()
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, "ws://127.0.0.1/ws", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req.Header.Add("upgrade", "websocket")
|
||||
req.Header.Add("Connection", "upgrade")
|
||||
|
||||
err = req.Write(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// First request works with 400
|
||||
br := bufio.NewReader(conn)
|
||||
resp, err := http.ReadResponse(br, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 400, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||
f, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/ws", websocket.Handler(func(conn *websocket.Conn) {
|
||||
_, err := conn.Write([]byte("ok"))
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.Close()
|
||||
require.NoError(t, err)
|
||||
}))
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
mux.ServeHTTP(w, req)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
proxy := createProxyWithForwarder(t, f, srv.URL)
|
||||
defer proxy.Close()
|
||||
|
||||
proxyAddr := proxy.Listener.Addr().String()
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("echo"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
func createTLSWebsocketServer() *httptest.Server {
|
||||
upgrader := gorillawebsocket.Upgrader{}
|
||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
for {
|
||||
mt, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = conn.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}))
|
||||
return srv
|
||||
}
|
||||
|
||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||
srv := createTLSWebsocketServer()
|
||||
defer srv.Close()
|
||||
|
||||
forwarderWithoutTLSConfig, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyWithoutTLSConfig := createProxyWithForwarder(t, forwarderWithoutTLSConfig, srv.URL)
|
||||
defer proxyWithoutTLSConfig.Close()
|
||||
|
||||
proxyAddr := proxyWithoutTLSConfig.Listener.Addr().String()
|
||||
|
||||
_, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.EqualError(t, err, "bad status")
|
||||
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
forwarderWithTLSConfig, err := buildProxy(true, nil, transport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyWithTLSConfig := createProxyWithForwarder(t, forwarderWithTLSConfig, srv.URL)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfig.Listener.Addr().String()
|
||||
|
||||
resp, err := newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
|
||||
http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
forwarderWithTLSConfigFromDefaultTransport, err := buildProxy(true, nil, http.DefaultTransport, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
proxyWithTLSConfigFromDefaultTransport := createProxyWithForwarder(t, forwarderWithTLSConfigFromDefaultTransport, srv.URL)
|
||||
defer proxyWithTLSConfig.Close()
|
||||
|
||||
proxyAddr = proxyWithTLSConfigFromDefaultTransport.Listener.Addr().String()
|
||||
|
||||
resp, err = newWebsocketRequest(
|
||||
withServer(proxyAddr),
|
||||
withPath("/ws"),
|
||||
withData("ok"),
|
||||
).send()
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
}
|
||||
|
||||
const dialTimeout = time.Second
|
||||
|
||||
type websocketRequestOpt func(w *websocketRequest)
|
||||
|
||||
func withServer(server string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.ServerAddr = server
|
||||
}
|
||||
}
|
||||
|
||||
func withPath(path string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Path = path
|
||||
}
|
||||
}
|
||||
|
||||
func withData(data string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Data = data
|
||||
}
|
||||
}
|
||||
|
||||
func withOrigin(origin string) websocketRequestOpt {
|
||||
return func(w *websocketRequest) {
|
||||
w.Origin = origin
|
||||
}
|
||||
}
|
||||
|
||||
func newWebsocketRequest(opts ...websocketRequestOpt) *websocketRequest {
|
||||
wsrequest := &websocketRequest{}
|
||||
for _, opt := range opts {
|
||||
opt(wsrequest)
|
||||
}
|
||||
if wsrequest.Origin == "" {
|
||||
wsrequest.Origin = "http://" + wsrequest.ServerAddr
|
||||
}
|
||||
if wsrequest.Config == nil {
|
||||
wsrequest.Config, _ = websocket.NewConfig(fmt.Sprintf("ws://%s%s", wsrequest.ServerAddr, wsrequest.Path), wsrequest.Origin)
|
||||
}
|
||||
return wsrequest
|
||||
}
|
||||
|
||||
type websocketRequest struct {
|
||||
ServerAddr string
|
||||
Path string
|
||||
Data string
|
||||
Origin string
|
||||
Config *websocket.Config
|
||||
}
|
||||
|
||||
func (w *websocketRequest) send() (string, error) {
|
||||
conn, _, err := w.open()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
if _, err := conn.Write([]byte(w.Data)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var msg = make([]byte, 512)
|
||||
var n int
|
||||
n, err = conn.Read(msg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
received := string(msg[:n])
|
||||
return received, nil
|
||||
}
|
||||
|
||||
func (w *websocketRequest) open() (*websocket.Conn, net.Conn, error) {
|
||||
client, err := net.DialTimeout("tcp", w.ServerAddr, dialTimeout)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
conn, err := websocket.NewClient(w.Config, client)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return conn, client, err
|
||||
}
|
||||
|
||||
func parseURI(t *testing.T, uri string) *url.URL {
|
||||
out, err := url.ParseRequestURI(uri)
|
||||
require.NoError(t, err)
|
||||
return out
|
||||
}
|
||||
|
||||
func createProxyWithForwarder(t *testing.T, proxy http.Handler, url string) *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
path := req.URL.Path // keep the original path
|
||||
// Set new backend URL
|
||||
req.URL = parseURI(t, url)
|
||||
req.URL.Path = path
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
}))
|
||||
}
|
|
@ -3,14 +3,12 @@ package service
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/containous/alice"
|
||||
"github.com/containous/flaeg/parse"
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/healthcheck"
|
||||
"github.com/containous/traefik/log"
|
||||
|
@ -19,7 +17,6 @@ import (
|
|||
"github.com/containous/traefik/old/middlewares/pipelining"
|
||||
"github.com/containous/traefik/server/cookie"
|
||||
"github.com/containous/traefik/server/internal"
|
||||
"github.com/vulcand/oxy/forward"
|
||||
"github.com/vulcand/oxy/roundrobin"
|
||||
)
|
||||
|
||||
|
@ -46,8 +43,8 @@ type Manager struct {
|
|||
configs map[string]*config.Service
|
||||
}
|
||||
|
||||
// Build Creates a http.Handler for a service configuration.
|
||||
func (m *Manager) Build(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
// BuildHTTP Creates a http.Handler for a service configuration.
|
||||
func (m *Manager) BuildHTTP(rootCtx context.Context, serviceName string, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
|
||||
|
||||
serviceName = internal.GetQualifiedName(ctx, serviceName)
|
||||
|
@ -55,6 +52,7 @@ func (m *Manager) Build(rootCtx context.Context, serviceName string, responseMod
|
|||
|
||||
if conf, ok := m.configs[serviceName]; ok {
|
||||
// TODO Should handle multiple service types
|
||||
// FIXME Check if the service is declared multiple times with different types
|
||||
if conf.LoadBalancer != nil {
|
||||
return m.getLoadBalancerServiceHandler(ctx, serviceName, conf.LoadBalancer, responseModifier)
|
||||
}
|
||||
|
@ -69,8 +67,7 @@ func (m *Manager) getLoadBalancerServiceHandler(
|
|||
service *config.LoadBalancerService,
|
||||
responseModifier func(*http.Response) error,
|
||||
) (http.Handler, error) {
|
||||
|
||||
fwd, err := m.buildForwarder(service.PassHostHeader, service.ResponseForwarding, responseModifier)
|
||||
fwd, err := buildProxy(service.PassHostHeader, service.ResponseForwarding, m.defaultRoundTripper, m.bufferPool, responseModifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -258,32 +255,3 @@ func (m *Manager) upsertServers(ctx context.Context, lb healthcheck.BalancerHand
|
|||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildForwarder(passHostHeader bool, responseForwarding *config.ResponseForwarding, responseModifier func(*http.Response) error) (http.Handler, error) {
|
||||
|
||||
var flushInterval parse.Duration
|
||||
if responseForwarding != nil {
|
||||
err := flushInterval.Set(responseForwarding.FlushInterval)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating flush interval: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return forward.New(
|
||||
forward.Stream(true),
|
||||
forward.PassHostHeader(passHostHeader),
|
||||
forward.RoundTripper(m.defaultRoundTripper),
|
||||
forward.ResponseModifier(responseModifier),
|
||||
forward.BufferPool(m.bufferPool),
|
||||
forward.StreamingFlushInterval(time.Duration(flushInterval)),
|
||||
forward.WebsocketConnectionClosedHook(func(req *http.Request, conn net.Conn) {
|
||||
server := req.Context().Value(http.ServerContextKey).(*http.Server)
|
||||
if server != nil {
|
||||
connState := server.ConnState
|
||||
if connState != nil {
|
||||
connState(conn, http.StateClosed)
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -320,7 +320,7 @@ func TestManager_Build(t *testing.T) {
|
|||
ctx = internal.AddProviderInContext(ctx, test.providerName+".foobar")
|
||||
}
|
||||
|
||||
_, err := manager.Build(ctx, test.serviceName, nil)
|
||||
_, err := manager.BuildHTTP(ctx, test.serviceName, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
67
server/service/tcp/service.go
Normal file
67
server/service/tcp/service.go
Normal file
|
@ -0,0 +1,67 @@
|
|||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/containous/traefik/config"
|
||||
"github.com/containous/traefik/log"
|
||||
"github.com/containous/traefik/server/internal"
|
||||
"github.com/containous/traefik/tcp"
|
||||
)
|
||||
|
||||
// Manager is the TCPHandlers factory
|
||||
type Manager struct {
|
||||
configs map[string]*config.TCPService
|
||||
}
|
||||
|
||||
// NewManager creates a new manager
|
||||
func NewManager(configs map[string]*config.TCPService) *Manager {
|
||||
return &Manager{
|
||||
configs: configs,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildTCP Creates a tcp.Handler for a service configuration.
|
||||
func (m *Manager) BuildTCP(rootCtx context.Context, serviceName string) (tcp.Handler, error) {
|
||||
ctx := log.With(rootCtx, log.Str(log.ServiceName, serviceName))
|
||||
|
||||
serviceName = internal.GetQualifiedName(ctx, serviceName)
|
||||
ctx = internal.AddProviderInContext(ctx, serviceName)
|
||||
|
||||
if conf, ok := m.configs[serviceName]; ok {
|
||||
// FIXME Check if the service is declared multiple times with different types
|
||||
if conf.LoadBalancer != nil {
|
||||
loadBalancer := tcp.NewRRLoadBalancer()
|
||||
|
||||
var handler tcp.Handler
|
||||
for _, server := range conf.LoadBalancer.Servers {
|
||||
_, err := parseIP(server.Address)
|
||||
if err == nil {
|
||||
handler, _ = tcp.NewProxy(server.Address)
|
||||
loadBalancer.AddServer(handler)
|
||||
} else {
|
||||
log.FromContext(ctx).Errorf("Invalid IP address for a %s server %s: %v", serviceName, server.Address, err)
|
||||
}
|
||||
}
|
||||
return loadBalancer, nil
|
||||
}
|
||||
return nil, fmt.Errorf("the service %q doesn't have any TCP load balancer", serviceName)
|
||||
}
|
||||
return nil, fmt.Errorf("the service %q does not exits", serviceName)
|
||||
}
|
||||
|
||||
func parseIP(s string) (string, error) {
|
||||
ip, _, err := net.SplitHostPort(s)
|
||||
if err == nil {
|
||||
return ip, nil
|
||||
}
|
||||
|
||||
ipNoPort := net.ParseIP(s)
|
||||
if ipNoPort == nil {
|
||||
return "", fmt.Errorf("invalid IP Address %s", ipNoPort)
|
||||
}
|
||||
|
||||
return ipNoPort.String(), nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue