perf: improve forwarded header and recovery middlewares

Co-authored-by: Ludovic Fernandez <ldez@users.noreply.github.com>
This commit is contained in:
Julien Salleyron 2021-01-21 10:04:04 +01:00 committed by GitHub
parent c74918321d
commit a90b2a672e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 59 additions and 36 deletions

View file

@ -84,19 +84,28 @@ func (x *XForwarded) isTrustedIP(ip string) bool {
// removeIPv6Zone removes the zone if the given IP is an ipv6 address and it has {zone} information in it,
// like "[fe80::d806:a55d:eb1b:49cc%vEthernet (vmxnet3 Ethernet Adapter - Virtual Switch)]:64692".
func removeIPv6Zone(clientIP string) string {
return strings.Split(clientIP, "%")[0]
if idx := strings.Index(clientIP, "%"); idx != -1 {
return clientIP[:idx]
}
return clientIP
}
// isWebsocketRequest returns whether the specified HTTP request is a websocket handshake request.
func isWebsocketRequest(req *http.Request) bool {
containsHeader := func(name, value string) bool {
items := strings.Split(req.Header.Get(name), ",")
for _, item := range items {
if value == strings.ToLower(strings.TrimSpace(item)) {
h := unsafeHeader(req.Header).Get(name)
for {
pos := strings.Index(h, ",")
if pos == -1 {
return strings.EqualFold(value, strings.TrimSpace(h))
}
if strings.EqualFold(value, strings.TrimSpace(h[:pos])) {
return true
}
h = h[pos:]
}
return false
}
return containsHeader(connection, "upgrade") && containsHeader(upgrade, "websocket")
}
@ -110,7 +119,7 @@ func forwardedPort(req *http.Request) string {
return port
}
if req.Header.Get(xForwardedProto) == "https" || req.Header.Get(xForwardedProto) == "wss" {
if unsafeHeader(req.Header).Get(xForwardedProto) == "https" || unsafeHeader(req.Header).Get(xForwardedProto) == "wss" {
return "443"
}
@ -125,38 +134,38 @@ func (x *XForwarded) rewrite(outreq *http.Request) {
if clientIP, _, err := net.SplitHostPort(outreq.RemoteAddr); err == nil {
clientIP = removeIPv6Zone(clientIP)
if outreq.Header.Get(xRealIP) == "" {
outreq.Header.Set(xRealIP, clientIP)
if unsafeHeader(outreq.Header).Get(xRealIP) == "" {
unsafeHeader(outreq.Header).Set(xRealIP, clientIP)
}
}
xfProto := outreq.Header.Get(xForwardedProto)
xfProto := unsafeHeader(outreq.Header).Get(xForwardedProto)
if xfProto == "" {
if isWebsocketRequest(outreq) {
if outreq.TLS != nil {
outreq.Header.Set(xForwardedProto, "wss")
unsafeHeader(outreq.Header).Set(xForwardedProto, "wss")
} else {
outreq.Header.Set(xForwardedProto, "ws")
unsafeHeader(outreq.Header).Set(xForwardedProto, "ws")
}
} else {
if outreq.TLS != nil {
outreq.Header.Set(xForwardedProto, "https")
unsafeHeader(outreq.Header).Set(xForwardedProto, "https")
} else {
outreq.Header.Set(xForwardedProto, "http")
unsafeHeader(outreq.Header).Set(xForwardedProto, "http")
}
}
}
if xfPort := outreq.Header.Get(xForwardedPort); xfPort == "" {
outreq.Header.Set(xForwardedPort, forwardedPort(outreq))
if xfPort := unsafeHeader(outreq.Header).Get(xForwardedPort); xfPort == "" {
unsafeHeader(outreq.Header).Set(xForwardedPort, forwardedPort(outreq))
}
if xfHost := outreq.Header.Get(xForwardedHost); xfHost == "" && outreq.Host != "" {
outreq.Header.Set(xForwardedHost, outreq.Host)
if xfHost := unsafeHeader(outreq.Header).Get(xForwardedHost); xfHost == "" && outreq.Host != "" {
unsafeHeader(outreq.Header).Set(xForwardedHost, outreq.Host)
}
if x.hostname != "" {
outreq.Header.Set(xForwardedServer, x.hostname)
unsafeHeader(outreq.Header).Set(xForwardedServer, x.hostname)
}
}
@ -164,7 +173,7 @@ func (x *XForwarded) rewrite(outreq *http.Request) {
func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !x.insecure && !x.isTrustedIP(r.RemoteAddr) {
for _, h := range xHeaders {
r.Header.Del(h)
unsafeHeader(r.Header).Del(h)
}
}
@ -172,3 +181,22 @@ func (x *XForwarded) ServeHTTP(w http.ResponseWriter, r *http.Request) {
x.next.ServeHTTP(w, r)
}
// unsafeHeader allows to manage Header values.
// Must be used only when the header name is already a canonical key.
type unsafeHeader map[string][]string
func (h unsafeHeader) Set(key, value string) {
h[key] = []string{value}
}
func (h unsafeHeader) Get(key string) string {
if len(h[key]) == 0 {
return ""
}
return h[key][0]
}
func (h unsafeHeader) Del(key string) {
delete(h, key)
}