Add WebSocket headers if they are present in the request

Co-authored-by: Romain <rtribotte@users.noreply.github.com>
This commit is contained in:
Kevin Pollet 2025-02-17 20:20:05 +01:00 committed by GitHub
parent 1cfcf0d318
commit 1ccbf743cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 135 additions and 63 deletions

View file

@ -171,6 +171,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if reqUpType != "" { if reqUpType != "" {
outReq.Header.Set("Connection", "Upgrade") outReq.Header.Set("Connection", "Upgrade")
outReq.Header.Set("Upgrade", reqUpType) outReq.Header.Set("Upgrade", reqUpType)
if strings.EqualFold(reqUpType, "websocket") { if strings.EqualFold(reqUpType, "websocket") {
cleanWebSocketHeaders(&outReq.Header) cleanWebSocketHeaders(&outReq.Header)
} }
@ -351,7 +352,7 @@ func isGraphic(s string) bool {
type fasthttpHeader interface { type fasthttpHeader interface {
Peek(key string) []byte Peek(key string) []byte
Set(key string, value string) Set(key string, value string)
SetBytesV(key string, value []byte) SetCanonical(key []byte, value []byte)
DelBytes(key []byte) DelBytes(key []byte)
Del(key string) Del(key string)
ConnectionUpgrade() bool ConnectionUpgrade() bool
@ -382,18 +383,33 @@ func fixPragmaCacheControl(header fasthttpHeader) {
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive. // Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
// https://tools.ietf.org/html/rfc6455#page-20 // https://tools.ietf.org/html/rfc6455#page-20
func cleanWebSocketHeaders(headers fasthttpHeader) { func cleanWebSocketHeaders(headers fasthttpHeader) {
headers.SetBytesV("Sec-WebSocket-Key", headers.Peek("Sec-Websocket-Key")) secWebsocketKey := headers.Peek("Sec-Websocket-Key")
if len(secWebsocketKey) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Key"), secWebsocketKey)
headers.Del("Sec-Websocket-Key") headers.Del("Sec-Websocket-Key")
}
headers.SetBytesV("Sec-WebSocket-Extensions", headers.Peek("Sec-Websocket-Extensions")) secWebsocketExtensions := headers.Peek("Sec-Websocket-Extensions")
if len(secWebsocketExtensions) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Extensions"), secWebsocketExtensions)
headers.Del("Sec-Websocket-Extensions") headers.Del("Sec-Websocket-Extensions")
}
headers.SetBytesV("Sec-WebSocket-Accept", headers.Peek("Sec-Websocket-Accept")) secWebsocketAccept := headers.Peek("Sec-Websocket-Accept")
if len(secWebsocketAccept) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Accept"), secWebsocketAccept)
headers.Del("Sec-Websocket-Accept") headers.Del("Sec-Websocket-Accept")
}
headers.SetBytesV("Sec-WebSocket-Protocol", headers.Peek("Sec-Websocket-Protocol")) secWebsocketProtocol := headers.Peek("Sec-Websocket-Protocol")
if len(secWebsocketProtocol) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Protocol"), secWebsocketProtocol)
headers.Del("Sec-Websocket-Protocol") headers.Del("Sec-Websocket-Protocol")
}
headers.SetBytesV("Sec-WebSocket-Version", headers.Peek("Sec-Websocket-Version")) secWebsocketVersion := headers.Peek("Sec-Websocket-Version")
headers.DelBytes([]byte("Sec-Websocket-Version")) if len(secWebsocketVersion) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Version"), secWebsocketVersion)
headers.Del("Sec-Websocket-Version")
}
} }

View file

@ -18,9 +18,12 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/testhelpers" "github.com/traefik/traefik/v3/pkg/testhelpers"
"github.com/valyala/fasthttp"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
const dialTimeout = time.Second
func TestWebSocketUpgradeCase(t *testing.T) { func TestWebSocketUpgradeCase(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
challengeKey := r.Header.Get("Sec-Websocket-Key") challengeKey := r.Header.Get("Sec-Websocket-Key")
@ -49,6 +52,31 @@ func TestWebSocketUpgradeCase(t *testing.T) {
conn.Close() conn.Close()
} }
func TestCleanWebSocketHeaders(t *testing.T) {
// Asserts that no headers are sent if the request contain anything.
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
cleanWebSocketHeaders(&req.Header)
want := "GET / HTTP/1.1\r\n\r\n"
assert.Equal(t, want, req.Header.String())
// Asserts that the Sec-WebSocket-* is enforced.
req.Reset()
req.Header.Set("Sec-Websocket-Key", "key")
req.Header.Set("Sec-Websocket-Extensions", "extensions")
req.Header.Set("Sec-Websocket-Accept", "accept")
req.Header.Set("Sec-Websocket-Protocol", "protocol")
req.Header.Set("Sec-Websocket-Version", "version")
cleanWebSocketHeaders(&req.Header)
want = "GET / HTTP/1.1\r\nSec-WebSocket-Key: key\r\nSec-WebSocket-Extensions: extensions\r\nSec-WebSocket-Accept: accept\r\nSec-WebSocket-Protocol: protocol\r\nSec-WebSocket-Version: version\r\n\r\n"
assert.Equal(t, want, req.Header.String())
}
func TestWebSocketTCPClose(t *testing.T) { func TestWebSocketTCPClose(t *testing.T) {
errChan := make(chan error, 1) errChan := make(chan error, 1)
upgrader := gorillawebsocket.Upgrader{} upgrader := gorillawebsocket.Upgrader{}
@ -535,29 +563,6 @@ func TestForwardsWebsocketTraffic(t *testing.T) {
assert.Equal(t, "ok", resp) assert.Equal(t, "ok", resp)
} }
func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
func TestWebSocketTransferTLSConfig(t *testing.T) { func TestWebSocketTransferTLSConfig(t *testing.T) {
srv := createTLSWebsocketServer() srv := createTLSWebsocketServer()
defer srv.Close() defer srv.Close()
@ -592,7 +597,28 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
assert.Equal(t, "ok", resp) assert.Equal(t, "ok", resp)
} }
const dialTimeout = time.Second func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
type websocketRequestOpt func(w *websocketRequest) type websocketRequestOpt func(w *websocketRequest)

View file

@ -70,8 +70,10 @@ func directorBuilder(target *url.URL, passHostHeader bool, preservePath bool) fu
outReq.Host = outReq.URL.Host outReq.Host = outReq.URL.Host
} }
if isWebSocketUpgrade(outReq) {
cleanWebSocketHeaders(outReq) cleanWebSocketHeaders(outReq)
} }
}
} }
// cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive, // cleanWebSocketHeaders Even if the websocket RFC says that headers should be case-insensitive,
@ -79,10 +81,6 @@ func directorBuilder(target *url.URL, passHostHeader bool, preservePath bool) fu
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive. // Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
// https://tools.ietf.org/html/rfc6455#page-20 // https://tools.ietf.org/html/rfc6455#page-20
func cleanWebSocketHeaders(req *http.Request) { func cleanWebSocketHeaders(req *http.Request) {
if !isWebSocketUpgrade(req) {
return
}
req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"] req.Header["Sec-WebSocket-Key"] = req.Header["Sec-Websocket-Key"]
delete(req.Header, "Sec-Websocket-Key") delete(req.Header, "Sec-Websocket-Key")

View file

@ -2,6 +2,7 @@ package httputil
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -18,6 +19,8 @@ import (
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
const dialTimeout = time.Second
func TestWebSocketTCPClose(t *testing.T) { func TestWebSocketTCPClose(t *testing.T) {
errChan := make(chan error, 1) errChan := make(chan error, 1)
upgrader := gorillawebsocket.Upgrader{} upgrader := gorillawebsocket.Upgrader{}
@ -419,28 +422,6 @@ func TestForwardsWebsocketTraffic(t *testing.T) {
assert.Equal(t, "ok", resp) assert.Equal(t, "ok", resp)
} }
func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
func TestWebSocketTransferTLSConfig(t *testing.T) { func TestWebSocketTransferTLSConfig(t *testing.T) {
srv := createTLSWebsocketServer() srv := createTLSWebsocketServer()
defer srv.Close() defer srv.Close()
@ -495,7 +476,58 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
assert.Equal(t, "ok", resp) assert.Equal(t, "ok", resp)
} }
const dialTimeout = time.Second func TestCleanWebSocketHeaders(t *testing.T) {
// Asserts that no headers are sent if the request contain anything.
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Del("User-Agent")
cleanWebSocketHeaders(req)
b := bytes.NewBuffer(nil)
err := req.Header.Write(b)
require.NoError(t, err)
assert.Empty(t, b)
// Asserts that the Sec-WebSocket-* is enforced.
req.Header.Set("Sec-Websocket-Key", "key")
req.Header.Set("Sec-Websocket-Extensions", "extensions")
req.Header.Set("Sec-Websocket-Accept", "accept")
req.Header.Set("Sec-Websocket-Protocol", "protocol")
req.Header.Set("Sec-Websocket-Version", "version")
cleanWebSocketHeaders(req)
want := http.Header{
"Sec-WebSocket-Key": {"key"},
"Sec-WebSocket-Extensions": {"extensions"},
"Sec-WebSocket-Accept": {"accept"},
"Sec-WebSocket-Protocol": {"protocol"},
"Sec-WebSocket-Version": {"version"},
}
assert.Equal(t, want, req.Header)
}
func createTLSWebsocketServer() *httptest.Server {
var upgrader gorillawebsocket.Upgrader
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
}
type websocketRequestOpt func(w *websocketRequest) type websocketRequestOpt func(w *websocketRequest)