Add WebSocket headers if they are present in the request

Co-authored-by: Romain <rtribotte@users.noreply.github.com>
This commit is contained in:
Kevin Pollet 2025-02-17 20:20:05 +01:00 committed by GitHub
parent 1cfcf0d318
commit 1ccbf743cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 135 additions and 63 deletions

View file

@ -18,9 +18,12 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/testhelpers"
"github.com/valyala/fasthttp"
"golang.org/x/net/websocket"
)
const dialTimeout = time.Second
func TestWebSocketUpgradeCase(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
challengeKey := r.Header.Get("Sec-Websocket-Key")
@ -49,6 +52,31 @@ func TestWebSocketUpgradeCase(t *testing.T) {
conn.Close()
}
func TestCleanWebSocketHeaders(t *testing.T) {
// Asserts that no headers are sent if the request contain anything.
req := fasthttp.AcquireRequest()
defer fasthttp.ReleaseRequest(req)
cleanWebSocketHeaders(&req.Header)
want := "GET / HTTP/1.1\r\n\r\n"
assert.Equal(t, want, req.Header.String())
// Asserts that the Sec-WebSocket-* is enforced.
req.Reset()
req.Header.Set("Sec-Websocket-Key", "key")
req.Header.Set("Sec-Websocket-Extensions", "extensions")
req.Header.Set("Sec-Websocket-Accept", "accept")
req.Header.Set("Sec-Websocket-Protocol", "protocol")
req.Header.Set("Sec-Websocket-Version", "version")
cleanWebSocketHeaders(&req.Header)
want = "GET / HTTP/1.1\r\nSec-WebSocket-Key: key\r\nSec-WebSocket-Extensions: extensions\r\nSec-WebSocket-Accept: accept\r\nSec-WebSocket-Protocol: protocol\r\nSec-WebSocket-Version: version\r\n\r\n"
assert.Equal(t, want, req.Header.String())
}
func TestWebSocketTCPClose(t *testing.T) {
errChan := make(chan error, 1)
upgrader := gorillawebsocket.Upgrader{}
@ -535,29 +563,6 @@ func TestForwardsWebsocketTraffic(t *testing.T) {
assert.Equal(t, "ok", resp)
}
func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
func TestWebSocketTransferTLSConfig(t *testing.T) {
srv := createTLSWebsocketServer()
defer srv.Close()
@ -592,7 +597,28 @@ func TestWebSocketTransferTLSConfig(t *testing.T) {
assert.Equal(t, "ok", resp)
}
const dialTimeout = time.Second
func createTLSWebsocketServer() *httptest.Server {
upgrader := gorillawebsocket.Upgrader{}
srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer conn.Close()
for {
mt, message, err := conn.ReadMessage()
if err != nil {
break
}
err = conn.WriteMessage(mt, message)
if err != nil {
break
}
}
}))
return srv
}
type websocketRequestOpt func(w *websocketRequest)