diff --git a/middlewares/stripPrefix.go b/middlewares/stripPrefix.go index 585e60305..291dda40d 100644 --- a/middlewares/stripPrefix.go +++ b/middlewares/stripPrefix.go @@ -17,17 +17,29 @@ type StripPrefix struct { func (s *StripPrefix) ServeHTTP(w http.ResponseWriter, r *http.Request) { for _, prefix := range s.Prefixes { - if p := strings.TrimPrefix(r.URL.Path, strings.TrimSpace(prefix)); len(p) < len(r.URL.Path) { - r.URL.Path = p - r.Header[forwardedPrefixHeader] = []string{prefix} - r.RequestURI = r.URL.RequestURI() - s.Handler.ServeHTTP(w, r) + origPrefix := strings.TrimSpace(prefix) + if origPrefix == r.URL.Path { + r.URL.Path = "/" + s.serveRequest(w, r, origPrefix) + return + } + + prefix = strings.TrimSuffix(origPrefix, "/") + "/" + if p := strings.TrimPrefix(r.URL.Path, prefix); len(p) < len(r.URL.Path) { + r.URL.Path = "/" + strings.TrimPrefix(p, "/") + s.serveRequest(w, r, origPrefix) return } } http.NotFound(w, r) } +func (s *StripPrefix) serveRequest(w http.ResponseWriter, r *http.Request, prefix string) { + r.Header[forwardedPrefixHeader] = []string{prefix} + r.RequestURI = r.URL.RequestURI() + s.Handler.ServeHTTP(w, r) +} + // SetHandler sets handler func (s *StripPrefix) SetHandler(Handler http.Handler) { s.Handler = Handler diff --git a/middlewares/stripPrefix_test.go b/middlewares/stripPrefix_test.go new file mode 100644 index 000000000..d7fe348f5 --- /dev/null +++ b/middlewares/stripPrefix_test.go @@ -0,0 +1,103 @@ +package middlewares + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStripPrefix(t *testing.T) { + tests := []struct { + desc string + prefixes []string + path string + expectedStatusCode int + expectedPath string + }{ + { + desc: "no prefixes configured", + prefixes: []string{}, + path: "/noprefixes", + expectedStatusCode: http.StatusNotFound, + }, + { + desc: "wildcard (.*) requests", + prefixes: []string{"/"}, + path: "/", + expectedStatusCode: http.StatusOK, + expectedPath: "/", + }, + { + desc: "prefix and path matching", + prefixes: []string{"/stat"}, + path: "/stat", + expectedStatusCode: http.StatusOK, + expectedPath: "/", + }, + { + desc: "path prefix on exactly matching path", + prefixes: []string{"/stat/"}, + path: "/stat/", + expectedStatusCode: http.StatusOK, + expectedPath: "/", + }, + { + desc: "path prefix on matching longer path", + prefixes: []string{"/stat/"}, + path: "/stat/us", + expectedStatusCode: http.StatusOK, + expectedPath: "/us", + }, + { + desc: "path prefix on mismatching path", + prefixes: []string{"/stat/"}, + path: "/status", + expectedStatusCode: http.StatusNotFound, + }, + { + desc: "general prefix on matching path", + prefixes: []string{"/stat"}, + path: "/stat/", + expectedStatusCode: http.StatusOK, + expectedPath: "/", + }, + { + desc: "earlier prefix matching", + prefixes: []string{"/stat", "/stat/us"}, + path: "/stat/us", + expectedStatusCode: http.StatusOK, + expectedPath: "/us", + }, + { + desc: "later prefix matching", + prefixes: []string{"/mismatch", "/stat"}, + path: "/stat", + expectedStatusCode: http.StatusOK, + expectedPath: "/", + }, + } + + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + var gotPath string + server := httptest.NewServer(&StripPrefix{ + Prefixes: test.prefixes, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + }), + }) + defer server.Close() + + resp, err := http.Get(server.URL + test.path) + require.NoError(t, err, "Failed to send GET request") + assert.Equal(t, test.expectedStatusCode, resp.StatusCode, "Unexpected status code") + + assert.Equal(t, test.expectedPath, gotPath, "Unexpected path") + }) + } +}