Merge v1.4.1 into master

This commit is contained in:
Fernandez Ludovic 2017-10-25 11:15:50 +02:00
commit a0c72cdf00
24 changed files with 531 additions and 113 deletions

View file

@ -57,11 +57,6 @@ calling mux.Vars():
vars := mux.Vars(request)
category := vars["category"]
Note that if any capturing groups are present, mux will panic() during parsing. To prevent
this, convert any capturing groups to non-capturing, e.g. change "/{sort:(asc|desc)}" to
"/{sort:(?:asc|desc)}". This is a change from prior versions which behaved unpredictably
when capturing groups were present.
And this is all you need to know about the basic usage. More advanced options
are explained below.

View file

@ -11,7 +11,10 @@ import (
"path"
"regexp"
"sort"
"strings"
)
var (
ErrMethodMismatch = errors.New("method is not allowed")
)
// NewRouter returns a new router instance.
@ -40,6 +43,10 @@ func NewRouter() *Router {
type Router struct {
// Configurable Handler to be used when no route matches.
NotFoundHandler http.Handler
// Configurable Handler to be used when the request method does not match the route.
MethodNotAllowedHandler http.Handler
// Parent route, if this is a subrouter.
parent parentRoute
// Routes to be matched, in order.
@ -66,6 +73,11 @@ func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
}
}
if match.MatchErr == ErrMethodMismatch && r.MethodNotAllowedHandler != nil {
match.Handler = r.MethodNotAllowedHandler
return true
}
// Closest match for a router (includes sub-routers)
if r.NotFoundHandler != nil {
match.Handler = r.NotFoundHandler
@ -82,7 +94,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if !r.skipClean {
path := req.URL.Path
if r.useEncodedPath {
path = getPath(req)
path = req.URL.EscapedPath()
}
// Clean path to canonical form and redirect.
if p := cleanPath(path); p != path {
@ -106,9 +118,15 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
req = setVars(req, match.Vars)
req = setCurrentRoute(req, match.Route)
}
if handler == nil && match.MatchErr == ErrMethodMismatch {
handler = methodNotAllowedHandler()
}
if handler == nil {
handler = http.NotFoundHandler()
}
if !r.KeepContext {
defer contextClear(req)
}
@ -356,6 +374,11 @@ type RouteMatch struct {
Route *Route
Handler http.Handler
Vars map[string]string
// MatchErr is set to appropriate matching error
// It is set to ErrMethodMismatch if there is a mismatch in
// the request method and route method
MatchErr error
}
type contextKey int
@ -397,28 +420,6 @@ func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
// Helpers
// ----------------------------------------------------------------------------
// getPath returns the escaped path if possible; doing what URL.EscapedPath()
// which was added in go1.5 does
func getPath(req *http.Request) string {
if req.RequestURI != "" {
// Extract the path from RequestURI (which is escaped unlike URL.Path)
// as detailed here as detailed in https://golang.org/pkg/net/url/#URL
// for < 1.5 server side workaround
// http://localhost/path/here?v=1 -> /path/here
path := req.RequestURI
path = strings.TrimPrefix(path, req.URL.Scheme+`://`)
path = strings.TrimPrefix(path, req.URL.Host)
if i := strings.LastIndex(path, "?"); i > -1 {
path = path[:i]
}
if i := strings.LastIndex(path, "#"); i > -1 {
path = path[:i]
}
return path
}
return req.URL.Path
}
// cleanPath returns the canonical path for p, eliminating . and .. elements.
// Borrowed from the net/http package.
func cleanPath(p string) string {
@ -557,3 +558,12 @@ func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]s
}
return true
}
// methodNotAllowed replies to the request with an HTTP status code 405.
func methodNotAllowed(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusMethodNotAllowed)
}
// methodNotAllowedHandler returns a simple request handler
// that replies to each request with a status code 405.
func methodNotAllowedHandler() http.Handler { return http.HandlerFunc(methodNotAllowed) }

View file

@ -109,13 +109,6 @@ func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash,
if errCompile != nil {
return nil, errCompile
}
// Check for capturing groups which used to work in older versions
if reg.NumSubexp() != len(idxs)/2 {
panic(fmt.Sprintf("route %s contains capture groups in its regexp. ", template) +
"Only non-capturing groups are accepted: e.g. (?:pattern) instead of (pattern)")
}
// Done!
return &routeRegexp{
template: template,
@ -141,7 +134,7 @@ type routeRegexp struct {
matchQuery bool
// The strictSlash value defined on the route, but disabled if PathPrefix was used.
strictSlash bool
// Determines whether to use encoded path from getPath function or unencoded
// Determines whether to use encoded req.URL.EnscapedPath() or unencoded
// req.URL.Path for path matching
useEncodedPath bool
// Expanded regexp.
@ -162,7 +155,7 @@ func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
}
path := req.URL.Path
if r.useEncodedPath {
path = getPath(req)
path = req.URL.EscapedPath()
}
return r.regexp.MatchString(path)
}
@ -272,7 +265,7 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
}
path := req.URL.Path
if r.useEncodedPath {
path = getPath(req)
path = req.URL.EscapedPath()
}
// Store path variables.
if v.path != nil {
@ -320,7 +313,14 @@ func getHost(r *http.Request) string {
}
func extractVars(input string, matches []int, names []string, output map[string]string) {
for i, name := range names {
output[name] = input[matches[2*i+2]:matches[2*i+3]]
matchesCount := 0
prevEnd := -1
for i := 2; i < len(matches) && matchesCount < len(names); i += 2 {
if prevEnd < matches[i+1] {
value := input[matches[i]:matches[i+1]]
output[names[matchesCount]] = value
prevEnd = matches[i+1]
matchesCount++
}
}
}

View file

@ -54,12 +54,27 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
if r.buildOnly || r.err != nil {
return false
}
var matchErr error
// Match everything.
for _, m := range r.matchers {
if matched := m.Match(req, match); !matched {
if _, ok := m.(methodMatcher); ok {
matchErr = ErrMethodMismatch
continue
}
matchErr = nil
return false
}
}
if matchErr != nil {
match.MatchErr = matchErr
return false
}
match.MatchErr = nil
// Yay, we have a match. Let's collect some info about it.
if match.Route == nil {
match.Route = r
@ -70,6 +85,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
if match.Vars == nil {
match.Vars = make(map[string]string)
}
// Set variables.
if r.regexp != nil {
r.regexp.setMatch(req, match, r)
@ -607,6 +623,44 @@ func (r *Route) GetPathRegexp() (string, error) {
return r.regexp.path.regexp.String(), nil
}
// GetQueriesRegexp returns the expanded regular expressions used to match the
// route queries.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An empty list will be returned if the route does not have queries.
func (r *Route) GetQueriesRegexp() ([]string, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries")
}
var queries []string
for _, query := range r.regexp.queries {
queries = append(queries, query.regexp.String())
}
return queries, nil
}
// GetQueriesTemplates returns the templates used to build the
// query matching.
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.
// An empty list will be returned if the route does not define queries.
func (r *Route) GetQueriesTemplates() ([]string, error) {
if r.err != nil {
return nil, r.err
}
if r.regexp == nil || r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries")
}
var queries []string
for _, query := range r.regexp.queries {
queries = append(queries, query.template)
}
return queries, nil
}
// GetMethods returns the methods the route matches against
// This is useful for building simple REST API documentation and for instrumentation
// against third-party services.

View file

@ -190,7 +190,9 @@ func (f *httpForwarder) serveHTTP(w http.ResponseWriter, req *http.Request, ctx
stream = contentType == "text/event-stream"
}
}
written, err := io.Copy(newResponseFlusher(w, stream), response.Body)
flush := stream || req.ProtoMajor == 2
written, err := io.Copy(newResponseFlusher(w, flush), response.Body)
if err != nil {
ctx.log.Errorf("Error copying upstream response body: %v", err)
ctx.errHandler.ServeHTTP(w, req, err)
@ -264,7 +266,8 @@ func (f *websocketForwarder) serveHTTP(w http.ResponseWriter, req *http.Request,
dialer := websocket.DefaultDialer
if outReq.URL.Scheme == "wss" && f.TLSClientConfig != nil {
dialer.TLSClientConfig = f.TLSClientConfig
dialer.TLSClientConfig = f.TLSClientConfig.Clone()
dialer.TLSClientConfig.NextProtos = []string{"http/1.1"}
}
targetConn, resp, err := dialer.Dial(outReq.URL.String(), outReq.Header)
if err != nil {

View file

@ -6,6 +6,7 @@ const (
XForwardedHost = "X-Forwarded-Host"
XForwardedPort = "X-Forwarded-Port"
XForwardedServer = "X-Forwarded-Server"
XRealIp = "X-Real-Ip"
Connection = "Connection"
KeepAlive = "Keep-Alive"
ProxyAuthenticate = "Proxy-Authenticate"
@ -50,3 +51,12 @@ var WebsocketUpgradeHeaders = []string{
Connection,
SecWebsocketAccept,
}
var XHeaders = []string{
XForwardedProto,
XForwardedFor,
XForwardedHost,
XForwardedPort,
XForwardedServer,
XRealIp,
}

View file

@ -15,30 +15,36 @@ type HeaderRewriter struct {
}
func (rw *HeaderRewriter) Rewrite(req *http.Request) {
if !rw.TrustForwardHeader {
utils.RemoveHeaders(req.Header, XHeaders...)
}
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if rw.TrustForwardHeader {
if prior, ok := req.Header[XForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
if prior, ok := req.Header[XForwardedFor]; ok {
req.Header.Set(XForwardedFor, strings.Join(prior, ", ")+", "+clientIP)
} else {
req.Header.Set(XForwardedFor, clientIP)
}
if req.Header.Get(XRealIp) == "" {
req.Header.Set(XRealIp, clientIP)
}
req.Header.Set(XForwardedFor, clientIP)
}
if xfp := req.Header.Get(XForwardedProto); xfp != "" && rw.TrustForwardHeader {
req.Header.Set(XForwardedProto, xfp)
} else if req.TLS != nil {
req.Header.Set(XForwardedProto, "https")
} else {
req.Header.Set(XForwardedProto, "http")
xfProto := req.Header.Get(XForwardedProto)
if xfProto == "" {
if req.TLS != nil {
req.Header.Set(XForwardedProto, "https")
} else {
req.Header.Set(XForwardedProto, "http")
}
}
if xfp := req.Header.Get(XForwardedPort); xfp != "" && rw.TrustForwardHeader {
req.Header.Set(XForwardedPort, xfp)
if xfp := req.Header.Get(XForwardedPort); xfp == "" {
req.Header.Set(XForwardedPort, forwardedPort(req))
}
if xfh := req.Header.Get(XForwardedHost); xfh != "" && rw.TrustForwardHeader {
req.Header.Set(XForwardedHost, xfh)
} else if req.Host != "" {
if xfHost := req.Header.Get(XForwardedHost); xfHost == "" && req.Host != "" {
req.Header.Set(XForwardedHost, req.Host)
}
@ -50,3 +56,19 @@ func (rw *HeaderRewriter) Rewrite(req *http.Request) {
// connection, regardless of what the client sent to us.
utils.RemoveHeaders(req.Header, HopHeaders...)
}
func forwardedPort(req *http.Request) string {
if req == nil {
return ""
}
if _, port, err := net.SplitHostPort(req.Host); err == nil && port != "" {
return port
}
if req.TLS != nil {
return "443"
}
return "80"
}