Resync oxy with original repository

This commit is contained in:
SALLEYRON Julien 2017-11-22 18:20:03 +01:00 committed by Traefiker
parent da5e4a13bf
commit bee8ebb00b
31 changed files with 650 additions and 808 deletions

View file

@ -7,13 +7,15 @@ import (
"crypto/tls"
"io"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"os"
"reflect"
"strconv"
"strings"
"time"
log "github.com/Sirupsen/logrus"
"github.com/gorilla/websocket"
"github.com/vulcand/oxy/utils"
)
@ -29,15 +31,7 @@ type optSetter func(f *Forwarder) error
// be delegated
func PassHostHeader(b bool) optSetter {
return func(f *Forwarder) error {
f.passHost = b
return nil
}
}
// StreamResponse forces streaming body (flushes response directly to client)
func StreamResponse(b bool) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.streamResponse = b
f.httpForwarder.passHost = b
return nil
}
}
@ -46,7 +40,7 @@ func StreamResponse(b bool) optSetter {
// Forwarder will use http.DefaultTransport as a default round tripper
func RoundTripper(r http.RoundTripper) optSetter {
return func(f *Forwarder) error {
f.roundTripper = r
f.httpForwarder.roundTripper = r
return nil
}
}
@ -59,10 +53,11 @@ func Rewriter(r ReqRewriter) optSetter {
}
}
// WebsocketRewriter defines a request rewriter for the websocket forwarder
func WebsocketRewriter(r ReqRewriter) optSetter {
// PassHostHeader specifies if a client's Host header field should
// be delegated
func WebsocketTLSClientConfig(tcc *tls.Config) optSetter {
return func(f *Forwarder) error {
f.websocketForwarder.rewriter = r
f.httpForwarder.tlsClientConfig = tcc
return nil
}
}
@ -75,27 +70,74 @@ func ErrorHandler(h utils.ErrorHandler) optSetter {
}
}
// Logger specifies the logger to use.
// Forwarder will default to oxyutils.NullLogger if no logger has been specified
func Logger(l utils.Logger) optSetter {
// Stream specifies if HTTP responses should be streamed.
func Stream(stream bool) optSetter {
return func(f *Forwarder) error {
f.stream = stream
return nil
}
}
// Logger defines the logger the forwarder will use.
//
// It defaults to logrus.StandardLogger(), the global logger used by logrus.
func Logger(l *log.Logger) optSetter {
return func(f *Forwarder) error {
f.log = l
return nil
}
}
func StateListener(stateListener UrlForwardingStateListener) optSetter {
return func(f *Forwarder) error {
f.stateListener = stateListener
return nil
}
}
func ResponseModifier(responseModifier func(*http.Response) error) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.modifyResponse = responseModifier
return nil
}
}
func StreamingFlushInterval(flushInterval time.Duration) optSetter {
return func(f *Forwarder) error {
f.httpForwarder.flushInterval = flushInterval
return nil
}
}
type ErrorHandlingRoundTripper struct {
http.RoundTripper
errorHandler utils.ErrorHandler
}
func (rt ErrorHandlingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
res, err := rt.RoundTripper.RoundTrip(req)
if err != nil {
// We use the recorder from httptest because there isn't another `public` implementation of a recorder.
recorder := httptest.NewRecorder()
rt.errorHandler.ServeHTTP(recorder, req, err)
res = recorder.Result()
err = nil
}
return res, err
}
// Forwarder wraps two traffic forwarding implementations: HTTP and websockets.
// It decides based on the specified request which implementation to use
type Forwarder struct {
*httpForwarder
*websocketForwarder
*handlerContext
stateListener UrlForwardingStateListener
stream bool
}
// handlerContext defines a handler context for error reporting and logging
type handlerContext struct {
errHandler utils.ErrorHandler
log utils.Logger
}
// httpForwarder is a handler that can reverse proxy
@ -104,32 +146,40 @@ type httpForwarder struct {
roundTripper http.RoundTripper
rewriter ReqRewriter
passHost bool
streamResponse bool
flushInterval time.Duration
modifyResponse func(*http.Response) error
tlsClientConfig *tls.Config
log *log.Logger
}
// websocketForwarder is a handler that can reverse proxy
// websocket traffic
type websocketForwarder struct {
rewriter ReqRewriter
TLSClientConfig *tls.Config
}
const (
defaultFlushInterval = time.Duration(100) * time.Millisecond
StateConnected = iota
StateDisconnected
)
type UrlForwardingStateListener func(*url.URL, int)
// New creates an instance of Forwarder based on the provided list of configuration options
func New(setters ...optSetter) (*Forwarder, error) {
f := &Forwarder{
httpForwarder: &httpForwarder{},
websocketForwarder: &websocketForwarder{},
handlerContext: &handlerContext{},
httpForwarder: &httpForwarder{log: log.StandardLogger()},
handlerContext: &handlerContext{},
}
for _, s := range setters {
if err := s(f); err != nil {
return nil, err
}
}
if f.httpForwarder.roundTripper == nil {
f.httpForwarder.roundTripper = http.DefaultTransport
if !f.stream {
f.flushInterval = 0
} else if f.flushInterval == 0 {
f.flushInterval = defaultFlushInterval
}
f.websocketForwarder.TLSClientConfig = f.httpForwarder.roundTripper.(*http.Transport).TLSClientConfig
if f.httpForwarder.rewriter == nil {
h, err := os.Hostname()
if err != nil {
@ -137,136 +187,104 @@ func New(setters ...optSetter) (*Forwarder, error) {
}
f.httpForwarder.rewriter = &HeaderRewriter{TrustForwardHeader: true, Hostname: h}
}
if f.log == nil {
f.log = utils.NullLogger
if f.httpForwarder.roundTripper == nil {
f.httpForwarder.roundTripper = http.DefaultTransport
}
if f.errHandler == nil {
f.errHandler = utils.DefaultHandler
}
if f.tlsClientConfig == nil {
f.tlsClientConfig = f.httpForwarder.roundTripper.(*http.Transport).TLSClientConfig
}
f.httpForwarder.roundTripper = ErrorHandlingRoundTripper{
RoundTripper: f.httpForwarder.roundTripper,
errorHandler: f.errHandler,
}
return f, nil
}
// ServeHTTP decides which forwarder to use based on the specified
// request and delegates to the proper implementation
func (f *Forwarder) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if isWebsocketRequest(req) {
f.websocketForwarder.serveHTTP(w, req, f.handlerContext)
if f.log.Level >= log.DebugLevel {
logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/forward: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/forward: competed ServeHttp on request")
}
if f.stateListener != nil {
f.stateListener(req.URL, StateConnected)
defer f.stateListener(req.URL, StateDisconnected)
}
if IsWebsocketRequest(req) {
f.httpForwarder.serveWebSocket(w, req, f.handlerContext)
} else {
f.httpForwarder.serveHTTP(w, req, f.handlerContext)
}
}
// serveHTTP forwards HTTP traffic using the configured transport
func (f *httpForwarder) serveHTTP(w http.ResponseWriter, req *http.Request, ctx *handlerContext) {
start := time.Now().UTC()
response, err := f.roundTripper.RoundTrip(f.copyRequest(req, req.URL))
if err != nil {
ctx.log.Errorf("Error forwarding to %v, err: %v", req.URL, err)
ctx.errHandler.ServeHTTP(w, req, err)
return
}
utils.CopyHeaders(w.Header(), response.Header)
// Remove hop-by-hop headers.
utils.RemoveHeaders(w.Header(), HopHeaders...)
announcedTrailerKeyCount := len(response.Trailer)
if announcedTrailerKeyCount > 0 {
trailerKeys := make([]string, 0, announcedTrailerKeyCount)
for k := range response.Trailer {
trailerKeys = append(trailerKeys, k)
}
w.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
}
w.WriteHeader(response.StatusCode)
stream := f.streamResponse
if !stream {
contentType, err := utils.GetHeaderMediaType(response.Header, ContentType)
func (f *httpForwarder) getUrlFromRequest(req *http.Request) *url.URL {
// If the Request was created by Go via a real HTTP request, RequestURI will
// contain the original query string. If the Request was created in code, RequestURI
// will be empty, and we will use the URL object instead
u := req.URL
if req.RequestURI != "" {
parsedURL, err := url.ParseRequestURI(req.RequestURI)
if err == nil {
stream = contentType == "text/event-stream"
u = parsedURL
} else {
f.log.Warnf("vulcand/oxy/forward: error when parsing RequestURI: %s", err)
}
}
flush := stream || req.ProtoMajor == 2
written, err := io.Copy(newResponseFlusher(w, flush), response.Body)
if err != nil {
ctx.log.Errorf("Error copying upstream response body: %v", err)
ctx.errHandler.ServeHTTP(w, req, err)
return
}
defer response.Body.Close()
forceSetTrailers := len(response.Trailer) != announcedTrailerKeyCount
shallowCopyTrailers(w.Header(), response.Trailer, forceSetTrailers)
if written != 0 {
w.Header().Set(ContentLength, strconv.FormatInt(written, 10))
}
if req.TLS != nil {
ctx.log.Infof("Round trip: %v, code: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
req.URL, response.StatusCode, time.Now().UTC().Sub(start),
req.TLS.Version,
req.TLS.DidResume,
req.TLS.CipherSuite,
req.TLS.ServerName)
} else {
ctx.log.Infof("Round trip: %v, code: %v, duration: %v",
req.URL, response.StatusCode, time.Now().UTC().Sub(start))
}
return u
}
// copyRequest makes a copy of the specified request to be sent using the configured
// transport
func (f *httpForwarder) copyRequest(req *http.Request, u *url.URL) *http.Request {
outReq := new(http.Request)
*outReq = *req // includes shallow copies of maps, but we handle this below
// Modify the request to handle the target URL
func (f *httpForwarder) modifyRequest(outReq *http.Request, target *url.URL) {
outReq.URL = utils.CopyURL(outReq.URL)
outReq.URL.Scheme = target.Scheme
outReq.URL.Host = target.Host
u := f.getUrlFromRequest(outReq)
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
outReq.URL.RawQuery = u.RawQuery
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.URL = utils.CopyURL(req.URL)
outReq.URL.Scheme = u.Scheme
outReq.URL.Host = u.Host
outReq.URL.Opaque = req.RequestURI
// raw query is already included in RequestURI, so ignore it to avoid dupes
outReq.URL.RawQuery = ""
// Do not pass client Host header unless optsetter PassHostHeader is set.
if !f.passHost {
outReq.Host = u.Host
outReq.Host = target.Host
}
outReq.Proto = "HTTP/1.1"
outReq.ProtoMajor = 1
outReq.ProtoMinor = 1
// Overwrite close flag so we can keep persistent connection for the backend servers
outReq.Close = false
outReq.Header = make(http.Header)
utils.CopyHeaders(outReq.Header, req.Header)
if f.rewriter != nil {
f.rewriter.Rewrite(outReq)
}
if req.ContentLength == 0 {
// https://github.com/golang/go/issues/16036: nil Body for http.Transport retries
outReq.Body = nil
}
return outReq
}
// serveHTTP forwards websocket traffic
func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request, ctx *handlerContext) {
outReq := f.copyRequest(req, req.URL)
func (f *httpForwarder) serveWebSocket(w http.ResponseWriter, req *http.Request, ctx *handlerContext) {
if f.log.Level >= log.DebugLevel {
logEntry := f.log.WithField("Request", utils.DumpHttpRequest(req))
logEntry.Debug("vulcand/oxy/forward/websocket: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/forward/websocket: competed ServeHttp on request")
}
outReq := f.copyWebSocketRequest(req)
dialer := websocket.DefaultDialer
if outReq.URL.Scheme == "wss" && f.TLSClientConfig != nil {
dialer.TLSClientConfig = f.TLSClientConfig.Clone()
if outReq.URL.Scheme == "wss" && f.tlsClientConfig != nil {
dialer.TLSClientConfig = f.tlsClientConfig.Clone()
// WebSocket is only in http/1.1
dialer.TLSClientConfig.NextProtos = []string{"http/1.1"}
}
targetConn, resp, err := dialer.Dial(outReq.URL.String(), outReq.Header)
@ -274,33 +292,33 @@ func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request,
if resp == nil {
ctx.errHandler.ServeHTTP(w, req, err)
} else {
ctx.log.Errorf("Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status)
log.Errorf("vulcand/oxy/forward/websocket: Error dialing %q: %v with resp: %d %s", outReq.Host, err, resp.StatusCode, resp.Status)
hijacker, ok := w.(http.Hijacker)
if !ok {
ctx.log.Errorf("%s can not be hijack", reflect.TypeOf(w))
log.Errorf("vulcand/oxy/forward/websocket: %s can not be hijack", reflect.TypeOf(w))
ctx.errHandler.ServeHTTP(w, req, err)
return
}
conn, _, err := hijacker.Hijack()
if err != nil {
ctx.log.Errorf("Failed to hijack responseWriter")
ctx.errHandler.ServeHTTP(w, req, err)
conn, _, errHijack := hijacker.Hijack()
if errHijack != nil {
log.Errorf("vulcand/oxy/forward/websocket: Failed to hijack responseWriter")
ctx.errHandler.ServeHTTP(w, req, errHijack)
return
}
defer conn.Close()
err = resp.Write(conn)
if err != nil {
ctx.log.Errorf("Failed to forward response")
ctx.errHandler.ServeHTTP(w, req, err)
errWrite := resp.Write(conn)
if errWrite != nil {
log.Errorf("vulcand/oxy/forward/websocket: Failed to forward response")
ctx.errHandler.ServeHTTP(w, req, errWrite)
return
}
}
return
}
//Only the targetConn choose to CheckOrigin or not
// Only the targetConn choose to CheckOrigin or not
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool {
return true
}}
@ -308,62 +326,64 @@ func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request,
utils.RemoveHeaders(resp.Header, WebsocketUpgradeHeaders...)
underlyingConn, err := upgrader.Upgrade(w, req, resp.Header)
if err != nil {
ctx.log.Errorf("Error while upgrading connection : %v", err)
log.Errorf("vulcand/oxy/forward/websocket: Error while upgrading connection : %v", err)
return
}
defer underlyingConn.Close()
defer targetConn.Close()
errc := make(chan error, 2)
replicate := func(dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src)
errc <- err
replicate := func(dst io.Writer, src io.Reader, dstName string, srcName string) {
_, errCopy := io.Copy(dst, src)
if errCopy != nil {
f.log.Errorf("vulcand/oxy/forward/websocket: Error when copying from %s to %s using io.Copy: %v", srcName, dstName, errCopy)
} else {
f.log.Infof("vulcand/oxy/forward/websocket: Copying from %s to %s using io.Copy completed without error.", srcName, dstName)
}
errc <- errCopy
}
go replicate(targetConn.UnderlyingConn(), underlyingConn.UnderlyingConn())
go replicate(targetConn.UnderlyingConn(), underlyingConn.UnderlyingConn(), "backend", "client")
// Try to read the first message
t, msg, err := targetConn.ReadMessage()
msgType, msg, err := targetConn.ReadMessage()
if err != nil {
ctx.log.Errorf("Couldn't read first message : %v", err)
log.Errorf("vulcand/oxy/forward/websocket: Couldn't read first message : %v", err)
} else {
underlyingConn.WriteMessage(t, msg)
underlyingConn.WriteMessage(msgType, msg)
}
go replicate(underlyingConn.UnderlyingConn(), targetConn.UnderlyingConn())
go replicate(underlyingConn.UnderlyingConn(), targetConn.UnderlyingConn(), "client", "backend")
<-errc
}
// copyRequest makes a copy of the specified request.
func (f *websocketForwarder) copyRequest(req *http.Request, u *url.URL) (outReq *http.Request) {
// copyWebsocketRequest makes a copy of the specified request.
func (f *httpForwarder) copyWebSocketRequest(req *http.Request) (outReq *http.Request) {
outReq = new(http.Request)
*outReq = *req // includes shallow copies of maps, but we handle this below
outReq.URL = utils.CopyURL(req.URL)
outReq.URL.Scheme = u.Scheme
outReq.URL.Scheme = req.URL.Scheme
//sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs.
switch u.Scheme {
// sometimes backends might be registered as HTTP/HTTPS servers so translate URLs to websocket URLs.
switch req.URL.Scheme {
case "https":
outReq.URL.Scheme = "wss"
case "http":
outReq.URL.Scheme = "ws"
}
if requestURI, err := url.ParseRequestURI(outReq.RequestURI); err == nil {
if requestURI.RawPath != "" {
outReq.URL.Path = requestURI.RawPath
} else {
outReq.URL.Path = requestURI.Path
}
outReq.URL.RawQuery = requestURI.RawQuery
}
u := f.getUrlFromRequest(outReq)
outReq.URL.Host = u.Host
outReq.URL.Path = u.Path
outReq.URL.RawPath = u.RawPath
outReq.URL.RawQuery = u.RawQuery
outReq.RequestURI = "" // Outgoing request should not have RequestURI
outReq.URL.Host = req.URL.Host
outReq.Header = make(http.Header)
//gorilla websocket use this header to set the request.Host tested in checkSameOrigin
// gorilla websocket use this header to set the request.Host tested in checkSameOrigin
outReq.Header.Set("Host", outReq.Host)
utils.CopyHeaders(outReq.Header, req.Header)
utils.RemoveHeaders(outReq.Header, WebsocketDialHeaders...)
@ -374,9 +394,48 @@ func (f *websocketForwarder) copyRequest(req *http.Request, u *url.URL) (outReq
return outReq
}
// serveHTTP forwards HTTP traffic using the configured transport
func (f *httpForwarder) serveHTTP(w http.ResponseWriter, inReq *http.Request, ctx *handlerContext) {
if f.log.Level >= log.DebugLevel {
logEntry := f.log.WithField("Request", utils.DumpHttpRequest(inReq))
logEntry.Debug("vulcand/oxy/forward/http: begin ServeHttp on request")
defer logEntry.Debug("vulcand/oxy/forward/http: completed ServeHttp on request")
}
pw := &utils.ProxyWriter{
W: w,
}
start := time.Now().UTC()
outReq := new(http.Request)
*outReq = *inReq // includes shallow copies of maps, but we handle this in Director
revproxy := httputil.ReverseProxy{
Director: func(req *http.Request) {
f.modifyRequest(req, inReq.URL)
},
Transport: f.roundTripper,
FlushInterval: f.flushInterval,
ModifyResponse: f.modifyResponse,
}
revproxy.ServeHTTP(pw, outReq)
if inReq.TLS != nil {
f.log.Infof("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v tls:version: %x, tls:resume:%t, tls:csuite:%x, tls:server:%v",
inReq.URL, pw.Code, pw.Length, time.Now().UTC().Sub(start),
inReq.TLS.Version,
inReq.TLS.DidResume,
inReq.TLS.CipherSuite,
inReq.TLS.ServerName)
} else {
f.log.Infof("vulcand/oxy/forward/http: Round trip: %v, code: %v, Length: %v, duration: %v",
inReq.URL, pw.Code, pw.Length, time.Now().UTC().Sub(start))
}
}
// isWebsocketRequest determines if the specified HTTP request is a
// websocket handshake request
func isWebsocketRequest(req *http.Request) bool {
func IsWebsocketRequest(req *http.Request) bool {
containsHeader := func(name, value string) bool {
items := strings.Split(req.Header.Get(name), ",")
for _, item := range items {
@ -388,12 +447,3 @@ func isWebsocketRequest(req *http.Request) bool {
}
return containsHeader(Connection, "upgrade") && containsHeader(Upgrade, "websocket")
}
func shallowCopyTrailers(dstHeader, srcTrailer http.Header, forceSetTrailers bool) {
for k, vv := range srcTrailer {
if forceSetTrailers {
k = http.TrailerPrefix + k
}
dstHeader[k] = vv
}
}