
Detect and drop broken conns in the fastproxy pool

Co-authored-by: Romain <rtribotte@users.noreply.github.com>
Kevin Pollet 2024-10-25 14:26:04 +02:00 committed by GitHub
parent b22e081c7c
commit e3ed52ba7c
7 changed files with 426 additions and 280 deletions
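The gist of the change: a connection taken from the fastproxy pool can have been closed or reset by the backend while it sat idle, and such a broken connection could previously go unnoticed until a request failed on it. As the diff below shows, each pooled connection now has a read loop that reports failures through an error channel (co.ErrCh) and receives the ResponseWriter via co.RWCh, so broken connections are dropped as soon as they break. For background only, here is one classic, self-contained way to probe an idle connection before reuse; this is an illustrative sketch, not the approach the commit takes, and the package and function names (isBrokenConn) are made up:

package poolcheck

import (
	"errors"
	"net"
	"time"
)

// isBrokenConn reports whether an idle pooled connection was closed or
// reset by the backend while sitting in the pool. It performs a read with
// an already-expired deadline: a healthy idle connection yields a timeout,
// a dead one yields io.EOF or a reset error.
func isBrokenConn(c net.Conn) bool {
	if err := c.SetReadDeadline(time.Now()); err != nil {
		return true
	}
	defer c.SetReadDeadline(time.Time{}) // clear the deadline before reuse

	var one [1]byte
	_, err := c.Read(one[:])
	if err == nil {
		// Unsolicited data on an idle connection: state unknown, drop it.
		return true
	}

	var nerr net.Error
	if errors.As(err, &nerr) && nerr.Timeout() {
		return false // nothing to read, connection still usable
	}
	return true // EOF, connection reset, etc.
}

The commit avoids the reuse-time probe entirely by keeping a read loop per connection, which notices the breakage the moment it happens instead of on the next Acquire.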


@@ -4,18 +4,14 @@ import (
"bufio"
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"net/http/httputil"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog/log"
proxyhttputil "github.com/traefik/traefik/v3/pkg/proxy/httputil"
@@ -57,15 +53,6 @@ func (p *pool[T]) Put(x T) {
p.pool.Put(x)
}
type buffConn struct {
*bufio.Reader
net.Conn
}
func (b buffConn) Read(p []byte) (int, error) {
return b.Reader.Read(p)
}
type writeDetector struct {
net.Conn
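Side note on the buffConn type shown in this hunk: it routes Read through the bufio.Reader wrapped around the connection, because that reader may have buffered bytes beyond the response headers, and those bytes would be lost if a later consumer (for example after a 101 upgrade, see handleUpgradeResponse further down) read the raw net.Conn directly. A standalone illustration of that read-ahead behavior, not proxy code:

package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// The bufio.Reader reads ahead of what the parser consumed; the
	// leftover bytes exist only in its buffer, so every later read must
	// go through it, which is what buffConn guarantees.
	src := strings.NewReader("101 Switching Protocols\r\n\r\nfirst websocket bytes")
	br := bufio.NewReader(src)

	statusLine, _ := br.ReadString('\n')
	fmt.Printf("parsed:   %q\n", statusLine)

	rest := make([]byte, 64)
	n, _ := br.Read(rest)
	fmt.Printf("leftover: %q\n", rest[:n]) // reachable only via br, not via src
}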
@@ -112,21 +99,17 @@ type ReverseProxy struct {
connPool *connPool
bufferPool pool[[]byte]
readerPool pool[*bufio.Reader]
writerPool pool[*bufio.Writer]
limitReaderPool pool[*io.LimitedReader]
writerPool pool[*bufio.Writer]
proxyAuth string
targetURL *url.URL
passHostHeader bool
preservePath bool
responseHeaderTimeout time.Duration
targetURL *url.URL
passHostHeader bool
preservePath bool
}
// NewReverseProxy creates a new ReverseProxy.
func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, responseHeaderTimeout time.Duration, connPool *connPool) (*ReverseProxy, error) {
func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preservePath bool, connPool *connPool) (*ReverseProxy, error) {
var proxyAuth string
if proxyURL != nil && proxyURL.User != nil && targetURL.Scheme == "http" {
username := proxyURL.User.Username()
@@ -135,13 +118,12 @@ func NewReverseProxy(targetURL, proxyURL *url.URL, debug, passHostHeader, preser
}
return &ReverseProxy{
debug: debug,
passHostHeader: passHostHeader,
preservePath: preservePath,
targetURL: targetURL,
proxyAuth: proxyAuth,
connPool: connPool,
responseHeaderTimeout: responseHeaderTimeout,
debug: debug,
passHostHeader: passHostHeader,
preservePath: preservePath,
targetURL: targetURL,
proxyAuth: proxyAuth,
connPool: connPool,
}, nil
}
@@ -273,8 +255,15 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
return fmt.Errorf("acquire connection: %w", err)
}
// Before writing the request,
// we mark the conn as expecting to handle a response.
co.expectedResponse.Store(true)
wd := &writeDetector{Conn: co}
// TODO: do not wait to write the full request before reading the response (to handle "100 Continue").
// TODO: this is currently impossible with fasthttp to write the request partially (headers only).
// Currently, writing the request fully is a mandatory step before handling the response.
err = p.writeRequest(wd, outReq)
if wd.written && trace != nil && trace.WroteRequest != nil {
// WroteRequest hook is used by the tracing middleware to detect if the request has been written.
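For context, the writeDetector wrapper used above (its full definition sits outside the displayed hunks) only needs to record whether any bytes reached the backend, so the httptrace WroteRequest hook fires solely when the request was actually written. A sketch consistent with how it is used here, not the verbatim source:

package fast

import "net"

// writeDetector wraps the backend connection and flags the first
// successful write; the tracing middleware uses this to decide whether
// to call the httptrace WroteRequest hook.
type writeDetector struct {
	net.Conn

	written bool
}

func (w *writeDetector) Write(p []byte) (int, error) {
	n, err := w.Conn.Write(p)
	if n > 0 {
		w.written = true
	}
	return n, err
}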
@@ -293,169 +282,17 @@ func (p *ReverseProxy) roundTrip(rw http.ResponseWriter, req *http.Request, outR
}
}
br := p.readerPool.Get()
if br == nil {
br = bufio.NewReaderSize(co, bufioSize)
}
defer p.readerPool.Put(br)
br.Reset(co)
res := fasthttp.AcquireResponse()
defer fasthttp.ReleaseResponse(res)
res.Header.SetNoDefaultContentType(true)
for {
var timer *time.Timer
errTimeout := atomic.Pointer[timeoutError]{}
if p.responseHeaderTimeout > 0 {
timer = time.AfterFunc(p.responseHeaderTimeout, func() {
errTimeout.Store(&timeoutError{errors.New("timeout awaiting response headers")})
co.Close()
})
}
res.Header.SetNoDefaultContentType(true)
if err := res.Header.Read(br); err != nil {
if p.responseHeaderTimeout > 0 {
if errT := errTimeout.Load(); errT != nil {
return errT
}
}
co.Close()
return err
}
if timer != nil {
timer.Stop()
}
fixPragmaCacheControl(&res.Header)
resCode := res.StatusCode()
is1xx := 100 <= resCode && resCode <= 199
// treat 101 as a terminal status, see issue 26161
is1xxNonTerminal := is1xx && resCode != http.StatusSwitchingProtocols
if is1xxNonTerminal {
removeConnectionHeaders(&res.Header)
h := rw.Header()
for _, header := range hopHeaders {
res.Header.Del(header)
}
res.Header.VisitAll(func(key, value []byte) {
rw.Header().Add(string(key), string(value))
})
rw.WriteHeader(res.StatusCode())
// Clear headers, it's not automatically done by ResponseWriter.WriteHeader() for 1xx responses
for k := range h {
delete(h, k)
}
res.Reset()
res.Header.Reset()
res.Header.SetNoDefaultContentType(true)
continue
}
break
// Sending the responseWriter unlocks the connection readLoop, to handle the response.
co.RWCh <- rwWithUpgrade{
RW: rw,
Upgrade: upgradeResponseHandler(req.Context(), reqUpType),
}
announcedTrailers := res.Header.Peek("Trailer")
// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
if res.StatusCode() == http.StatusSwitchingProtocols {
// As the connection has been hijacked, it cannot be added back to the pool.
handleUpgradeResponse(rw, req, reqUpType, res, buffConn{Conn: co, Reader: br})
return nil
}
removeConnectionHeaders(&res.Header)
for _, header := range hopHeaders {
res.Header.Del(header)
}
if len(announcedTrailers) > 0 {
res.Header.Add("Trailer", string(announcedTrailers))
}
res.Header.VisitAll(func(key, value []byte) {
rw.Header().Add(string(key), string(value))
})
rw.WriteHeader(res.StatusCode())
// Chunked response, Content-Length is set to -1 by FastProxy when "Transfer-Encoding: chunked" header is received.
if res.Header.ContentLength() == -1 {
cbr := httputil.NewChunkedReader(br)
b := p.bufferPool.Get()
if b == nil {
b = make([]byte, bufferSize)
}
defer p.bufferPool.Put(b)
if _, err := io.CopyBuffer(&writeFlusher{rw}, cbr, b); err != nil {
co.Close()
return err
}
res.Header.Reset()
res.Header.SetNoDefaultContentType(true)
if err := res.Header.ReadTrailer(br); err != nil {
co.Close()
return err
}
if res.Header.Len() > 0 {
var announcedTrailersKey []string
if len(announcedTrailers) > 0 {
announcedTrailersKey = strings.Split(string(announcedTrailers), ",")
}
res.Header.VisitAll(func(key, value []byte) {
for _, s := range announcedTrailersKey {
if strings.EqualFold(s, strings.TrimSpace(string(key))) {
rw.Header().Add(string(key), string(value))
return
}
}
rw.Header().Add(http.TrailerPrefix+string(key), string(value))
})
}
p.connPool.ReleaseConn(co)
return nil
}
brl := p.limitReaderPool.Get()
if brl == nil {
brl = &io.LimitedReader{}
}
defer p.limitReaderPool.Put(brl)
brl.R = br
brl.N = int64(res.Header.ContentLength())
b := p.bufferPool.Get()
if b == nil {
b = make([]byte, bufferSize)
}
defer p.bufferPool.Put(b)
if _, err := io.CopyBuffer(rw, brl, b); err != nil {
co.Close()
if err := <-co.ErrCh; err != nil {
return err
}
p.connPool.ReleaseConn(co)
return nil
}
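On the trailer handling near the end of roundTrip: net/http only emits a trailer that was announced in the Trailer header before the body was written; a field that was never announced has to be added with the http.TrailerPrefix prefix, which is the distinction the VisitAll callback above makes. A minimal standalone illustration of that rule, with a made-up handler and field names:

package main

import (
	"log"
	"net/http"
)

func handler(rw http.ResponseWriter, _ *http.Request) {
	// Announced trailer: declared before the body is written...
	rw.Header().Set("Trailer", "X-Checksum")

	rw.WriteHeader(http.StatusOK)
	rw.Write([]byte("body\n"))

	// ...and set like a normal header afterwards.
	rw.Header().Set("X-Checksum", "abc123")

	// A trailer that was never announced is only sent when it carries
	// the TrailerPrefix, mirroring the fallback branch in the proxy.
	rw.Header().Set(http.TrailerPrefix+"X-Extra", "42")
}

func main() {
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", http.HandlerFunc(handler)))
}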