Add WebSocket headers if they are present in the request
Co-authored-by: Romain <rtribotte@users.noreply.github.com>
This commit is contained in:
parent
1cfcf0d318
commit
1ccbf743cb
4 changed files with 135 additions and 63 deletions
|
@ -171,6 +171,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
if reqUpType != "" {
|
if reqUpType != "" {
|
||||||
outReq.Header.Set("Connection", "Upgrade")
|
outReq.Header.Set("Connection", "Upgrade")
|
||||||
outReq.Header.Set("Upgrade", reqUpType)
|
outReq.Header.Set("Upgrade", reqUpType)
|
||||||
|
|
||||||
if strings.EqualFold(reqUpType, "websocket") {
|
if strings.EqualFold(reqUpType, "websocket") {
|
||||||
cleanWebSocketHeaders(&outReq.Header)
|
cleanWebSocketHeaders(&outReq.Header)
|
||||||
}
|
}
|
||||||
|
@ -351,7 +352,7 @@ func isGraphic(s string) bool {
|
||||||
type fasthttpHeader interface {
|
type fasthttpHeader interface {
|
||||||
Peek(key string) []byte
|
Peek(key string) []byte
|
||||||
Set(key string, value string)
|
Set(key string, value string)
|
||||||
SetBytesV(key string, value []byte)
|
SetCanonical(key []byte, value []byte)
|
||||||
DelBytes(key []byte)
|
DelBytes(key []byte)
|
||||||
Del(key string)
|
Del(key string)
|
||||||
ConnectionUpgrade() bool
|
ConnectionUpgrade() bool
|
||||||
|
@ -382,18 +383,33 @@ func fixPragmaCacheControl(header fasthttpHeader) {
|
||||||
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
|
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
|
||||||
// https://tools.ietf.org/html/rfc6455#page-20
|
// https://tools.ietf.org/html/rfc6455#page-20
|
||||||
func cleanWebSocketHeaders(headers fasthttpHeader) {
|
func cleanWebSocketHeaders(headers fasthttpHeader) {
|
||||||
headers.SetBytesV("Sec-WebSocket-Key", headers.Peek("Sec-Websocket-Key"))
|
secWebsocketKey := headers.Peek("Sec-Websocket-Key")
|
||||||
|
if len(secWebsocketKey) > 0 {
|
||||||
|
headers.SetCanonical([]byte("Sec-WebSocket-Key"), secWebsocketKey)
|
||||||
headers.Del("Sec-Websocket-Key")
|
headers.Del("Sec-Websocket-Key")
|
||||||
|
}
|
||||||
headers.SetBytesV("Sec-WebSocket-Extensions", headers.Peek("Sec-Websocket-Extensions"))
|
|
||||||
headers.Del("Sec-Websocket-Extensions")
|
secWebsocketExtensions := headers.Peek("Sec-Websocket-Extensions")
|
||||||
|
if len(secWebsocketExtensions) > 0 {
|
||||||
headers.SetBytesV("Sec-WebSocket-Accept", headers.Peek("Sec-Websocket-Accept"))
|
headers.SetCanonical([]byte("Sec-WebSocket-Extensions"), secWebsocketExtensions)
|
||||||
headers.Del("Sec-Websocket-Accept")
|
headers.Del("Sec-Websocket-Extensions")
|
||||||
|
}
|
||||||
headers.SetBytesV("Sec-WebSocket-Protocol", headers.Peek("Sec-Websocket-Protocol"))
|
|
||||||
headers.Del("Sec-Websocket-Protocol")
|
secWebsocketAccept := headers.Peek("Sec-Websocket-Accept")
|
||||||
|
if len(secWebsocketAccept) > 0 {
|
||||||
headers.SetBytesV("Sec-WebSocket-Version", headers.Peek("Sec-Websocket-Version"))
|
headers.SetCanonical([]byte("Sec-WebSocket-Accept"), secWebsocketAccept)
|
||||||
headers.DelBytes([]byte("Sec-Websocket-Version"))
|
headers.Del("Sec-Websocket-Accept")
|
||||||
|
}
|
||||||
|
|
||||||
|
secWebsocketProtocol := headers.Peek("Sec-Websocket-Protocol")
|
||||||
|
if len(secWebsocketProtocol) > 0 {
|
||||||
|
headers.SetCanonical([]byte("Sec-WebSocket-Protocol"), secWebsocketProtocol)
|
||||||
|
headers.Del("Sec-Websocket-Protocol")
|
||||||
|
}
|
||||||
|
|
||||||
|
secWebsocketVersion := headers.Peek("Sec-Websocket-Version")
|
||||||
|
if len(secWebsocketVersion) > 0 {
|
||||||
|
headers.SetCanonical([]byte("Sec-WebSocket-Version"), secWebsocketVersion)
|
||||||
|
headers.Del("Sec-Websocket-Version")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,9 +18,12 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
"github.com/traefik/traefik/v3/pkg/testhelpers"
|
||||||
|
"github.com/valyala/fasthttp"
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const dialTimeout = time.Second
|
||||||
|
|
||||||
func TestWebSocketUpgradeCase(t *testing.T) {
|
func TestWebSocketUpgradeCase(t *testing.T) {
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
challengeKey := r.Header.Get("Sec-Websocket-Key")
|
||||||
|
@ -49,6 +52,31 @@ func TestWebSocketUpgradeCase(t *testing.T) {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCleanWebSocketHeaders(t *testing.T) {
|
||||||
|
// Asserts that no headers are sent if the request contain anything.
|
||||||
|
req := fasthttp.AcquireRequest()
|
||||||
|
defer fasthttp.ReleaseRequest(req)
|
||||||
|
|
||||||
|
cleanWebSocketHeaders(&req.Header)
|
||||||
|
|
||||||
|
want := "GET / HTTP/1.1\r\n\r\n"
|
||||||
|
assert.Equal(t, want, req.Header.String())
|
||||||
|
|
||||||
|
// Asserts that the Sec-WebSocket-* is enforced.
|
||||||
|
req.Reset()
|
||||||
|
|
||||||
|
req.Header.Set("Sec-Websocket-Key", "key")
|
||||||
|
req.Header.Set("Sec-Websocket-Extensions", "extensions")
|
||||||
|
req.Header.Set("Sec-Websocket-Accept", "accept")
|
||||||
|
req.Header.Set("Sec-Websocket-Protocol", "protocol")
|
||||||
|
req.Header.Set("Sec-Websocket-Version", "version")
|
||||||
|
|
||||||
|
cleanWebSocketHeaders(&req.Header)
|
||||||
|
|
||||||
|
want = "GET / HTTP/1.1\r\nSec-WebSocket-Key: key\r\nSec-WebSocket-Extensions: extensions\r\nSec-WebSocket-Accept: accept\r\nSec-WebSocket-Protocol: protocol\r\nSec-WebSocket-Version: version\r\n\r\n"
|
||||||
|
assert.Equal(t, want, req.Header.String())
|
||||||
|
}
|
||||||
|
|
||||||
func TestWebSocketTCPClose(t *testing.T) {
|
func TestWebSocketTCPClose(t *testing.T) {
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
upgrader := gorillawebsocket.Upgrader{}
|
upgrader := gorillawebsocket.Upgrader{}
|
||||||
|
@ -535,29 +563,6 @@ func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||||
assert.Equal(t, "ok", resp)
|
assert.Equal(t, "ok", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTLSWebsocketServer() *httptest.Server {
|
|
||||||
upgrader := gorillawebsocket.Upgrader{}
|
|
||||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
for {
|
|
||||||
mt, message, err := conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
err = conn.WriteMessage(mt, message)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
return srv
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||||
srv := createTLSWebsocketServer()
|
srv := createTLSWebsocketServer()
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
@ -592,7 +597,28 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||||
assert.Equal(t, "ok", resp)
|
assert.Equal(t, "ok", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
const dialTimeout = time.Second
|
func createTLSWebsocketServer() *httptest.Server {
|
||||||
|
upgrader := gorillawebsocket.Upgrader{}
|
||||||
|
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
for {
|
||||||
|
mt, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
err = conn.WriteMessage(mt, message)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
return srv
|
||||||
|
}
|
||||||
|
|
||||||
type websocketRequestOpt func(w *websocketRequest)
|
type websocketRequestOpt func(w *websocketRequest)
|
||||||
|
|
||||||
|
|
|
@ -70,19 +70,17 @@ func directorBuilder(target *url.URL, passHostHeader bool, preservePath bool) fu
|
||||||
outReq.Host = outReq.URL.Host
|
outReq.Host = outReq.URL.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isWebSocketUpgrade(outReq) {
|
||||||
cleanWebSocketHeaders(outReq)
|
cleanWebSocketHeaders(outReq)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive,
|
// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive,
|
||||||
// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
|
// some servers need Sec-WebSocket-Key, Sec-WebSocket-Extensions, Sec-WebSocket-Accept,
|
||||||
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
|
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
|
||||||
// https://tools.ietf.org/html/rfc6455#page-20
|
// https://tools.ietf.org/html/rfc6455#page-20
|
||||||
func cleanWebSocketHeaders(req *http.Request) {
|
func cleanWebSocketHeaders(req *http.Request) {
|
||||||
if !isWebSocketUpgrade(req) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
|
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
|
||||||
delete(req.Header, "Sec-Websocket-Key")
|
delete(req.Header, "Sec-Websocket-Key")
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ package httputil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -18,6 +19,8 @@ import (
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const dialTimeout = time.Second
|
||||||
|
|
||||||
func TestWebSocketTCPClose(t *testing.T) {
|
func TestWebSocketTCPClose(t *testing.T) {
|
||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
upgrader := gorillawebsocket.Upgrader{}
|
upgrader := gorillawebsocket.Upgrader{}
|
||||||
|
@ -419,28 +422,6 @@ func TestForwardsWebsocketTraffic(t *testing.T) {
|
||||||
assert.Equal(t, "ok", resp)
|
assert.Equal(t, "ok", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createTLSWebsocketServer() *httptest.Server {
|
|
||||||
upgrader := gorillawebsocket.Upgrader{}
|
|
||||||
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
conn, err := upgrader.Upgrade(w, r, nil)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
for {
|
|
||||||
mt, message, err := conn.ReadMessage()
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
err = conn.WriteMessage(mt, message)
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}))
|
|
||||||
return srv
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||||
srv := createTLSWebsocketServer()
|
srv := createTLSWebsocketServer()
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
|
@ -495,7 +476,58 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
|
||||||
assert.Equal(t, "ok", resp)
|
assert.Equal(t, "ok", resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
const dialTimeout = time.Second
|
func TestCleanWebSocketHeaders(t *testing.T) {
|
||||||
|
// Asserts that no headers are sent if the request contain anything.
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
|
||||||
|
req.Header.Del("User-Agent")
|
||||||
|
|
||||||
|
cleanWebSocketHeaders(req)
|
||||||
|
|
||||||
|
b := bytes.NewBuffer(nil)
|
||||||
|
err := req.Header.Write(b)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Empty(t, b)
|
||||||
|
|
||||||
|
// Asserts that the Sec-WebSocket-* is enforced.
|
||||||
|
req.Header.Set("Sec-Websocket-Key", "key")
|
||||||
|
req.Header.Set("Sec-Websocket-Extensions", "extensions")
|
||||||
|
req.Header.Set("Sec-Websocket-Accept", "accept")
|
||||||
|
req.Header.Set("Sec-Websocket-Protocol", "protocol")
|
||||||
|
req.Header.Set("Sec-Websocket-Version", "version")
|
||||||
|
|
||||||
|
cleanWebSocketHeaders(req)
|
||||||
|
|
||||||
|
want := http.Header{
|
||||||
|
"Sec-WebSocket-Key": {"key"},
|
||||||
|
"Sec-WebSocket-Extensions": {"extensions"},
|
||||||
|
"Sec-WebSocket-Accept": {"accept"},
|
||||||
|
"Sec-WebSocket-Protocol": {"protocol"},
|
||||||
|
"Sec-WebSocket-Version": {"version"},
|
||||||
|
}
|
||||||
|
assert.Equal(t, want, req.Header)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTLSWebsocketServer() *httptest.Server {
|
||||||
|
var upgrader gorillawebsocket.Upgrader
|
||||||
|
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
for {
|
||||||
|
mt, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
err = conn.WriteMessage(mt, message)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
type websocketRequestOpt func(w *websocketRequest)
|
type websocketRequestOpt func(w *websocketRequest)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue