Add WebSocket headers if they are present in the request

Co-authored-by: Romain <rtribotte@users.noreply.github.com>
This commit is contained in:
Kevin Pollet 2025-02-17 20:20:05 +01:00 committed by GitHub
parent 1cfcf0d318
commit 1ccbf743cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 135 additions and 63 deletions

View file

@ -171,6 +171,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if reqUpType != "" {
outReq.Header.Set("Connection", "Upgrade")
outReq.Header.Set("Upgrade", reqUpType)
if strings.EqualFold(reqUpType, "websocket") {
cleanWebSocketHeaders(&outReq.Header)
}
@ -351,7 +352,7 @@ func isGraphic(s string) bool {
type fasthttpHeader interface {
Peek(key string) []byte
Set(key string, value string)
SetBytesV(key string, value []byte)
SetCanonical(key []byte, value []byte)
DelBytes(key []byte)
Del(key string)
ConnectionUpgrade() bool
@ -382,18 +383,33 @@ func fixPragmaCacheControl(header fasthttpHeader) {
// Sec-WebSocket-Protocol and Sec-WebSocket-Version to be case-sensitive.
// https://tools.ietf.org/html/rfc6455#page-20
func cleanWebSocketHeaders(headers fasthttpHeader) {
headers.SetBytesV("Sec-WebSocket-Key", headers.Peek("Sec-Websocket-Key"))
headers.Del("Sec-Websocket-Key")
secWebsocketKey := headers.Peek("Sec-Websocket-Key")
if len(secWebsocketKey) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Key"), secWebsocketKey)
headers.Del("Sec-Websocket-Key")
}
headers.SetBytesV("Sec-WebSocket-Extensions", headers.Peek("Sec-Websocket-Extensions"))
headers.Del("Sec-Websocket-Extensions")
secWebsocketExtensions := headers.Peek("Sec-Websocket-Extensions")
if len(secWebsocketExtensions) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Extensions"), secWebsocketExtensions)
headers.Del("Sec-Websocket-Extensions")
}
headers.SetBytesV("Sec-WebSocket-Accept", headers.Peek("Sec-Websocket-Accept"))
headers.Del("Sec-Websocket-Accept")
secWebsocketAccept := headers.Peek("Sec-Websocket-Accept")
if len(secWebsocketAccept) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Accept"), secWebsocketAccept)
headers.Del("Sec-Websocket-Accept")
}
headers.SetBytesV("Sec-WebSocket-Protocol", headers.Peek("Sec-Websocket-Protocol"))
headers.Del("Sec-Websocket-Protocol")
secWebsocketProtocol := headers.Peek("Sec-Websocket-Protocol")
if len(secWebsocketProtocol) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Protocol"), secWebsocketProtocol)
headers.Del("Sec-Websocket-Protocol")
}
headers.SetBytesV("Sec-WebSocket-Version", headers.Peek("Sec-Websocket-Version"))
headers.DelBytes([]byte("Sec-Websocket-Version"))
secWebsocketVersion := headers.Peek("Sec-Websocket-Version")
if len(secWebsocketVersion) > 0 {
headers.SetCanonical([]byte("Sec-WebSocket-Version"), secWebsocketVersion)
headers.Del("Sec-Websocket-Version")
}
}