1
0
Fork 0

Fix HTTPRoute Redirect Filter with port and scheme

Co-authored-by: Kevin Pollet <pollet.kevin@gmail.com>
This commit is contained in:
Romain 2024-06-06 10:56:03 +02:00 committed by GitHub
parent 7eac92f49c
commit 28d40e7f3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 431 additions and 93 deletions

View file

@ -0,0 +1,56 @@
package headermodifier
import (
"context"
"net/http"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
"go.opentelemetry.io/otel/trace"
)
const typeName = "RequestHeaderModifier"
// requestHeaderModifier is a middleware used to modify the headers of an HTTP request.
type requestHeaderModifier struct {
next http.Handler
name string
set map[string]string
add map[string]string
remove []string
}
// NewRequestHeaderModifier creates a new request header modifier middleware.
func NewRequestHeaderModifier(ctx context.Context, next http.Handler, config dynamic.RequestHeaderModifier, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug().Msg("Creating middleware")
return &requestHeaderModifier{
next: next,
name: name,
set: config.Set,
add: config.Add,
remove: config.Remove,
}, nil
}
func (r *requestHeaderModifier) GetTracingInformation() (string, string, trace.SpanKind) {
return r.name, typeName, trace.SpanKindUnspecified
}
func (r *requestHeaderModifier) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
for headerName, headerValue := range r.set {
req.Header.Set(headerName, headerValue)
}
for headerName, headerValue := range r.add {
req.Header.Add(headerName, headerValue)
}
for _, headerName := range r.remove {
req.Header.Del(headerName)
}
r.next.ServeHTTP(rw, req)
}

View file

@ -0,0 +1,121 @@
package headermodifier
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/testhelpers"
)
func TestRequestHeaderModifier(t *testing.T) {
testCases := []struct {
desc string
config dynamic.RequestHeaderModifier
requestHeaders http.Header
expectedHeaders http.Header
}{
{
desc: "no config",
config: dynamic.RequestHeaderModifier{},
expectedHeaders: map[string][]string{},
},
{
desc: "set header",
config: dynamic.RequestHeaderModifier{
Set: map[string]string{"Foo": "Bar"},
},
expectedHeaders: map[string][]string{"Foo": {"Bar"}},
},
{
desc: "set header with existing headers",
config: dynamic.RequestHeaderModifier{
Set: map[string]string{"Foo": "Bar"},
},
requestHeaders: map[string][]string{"Foo": {"Baz"}, "Bar": {"Foo"}},
expectedHeaders: map[string][]string{"Foo": {"Bar"}, "Bar": {"Foo"}},
},
{
desc: "set multiple headers with existing headers",
config: dynamic.RequestHeaderModifier{
Set: map[string]string{"Foo": "Bar", "Bar": "Foo"},
},
requestHeaders: map[string][]string{"Foo": {"Baz"}, "Bar": {"Foobar"}},
expectedHeaders: map[string][]string{"Foo": {"Bar"}, "Bar": {"Foo"}},
},
{
desc: "add header",
config: dynamic.RequestHeaderModifier{
Add: map[string]string{"Foo": "Bar"},
},
expectedHeaders: map[string][]string{"Foo": {"Bar"}},
},
{
desc: "add header with existing headers",
config: dynamic.RequestHeaderModifier{
Add: map[string]string{"Foo": "Bar"},
},
requestHeaders: map[string][]string{"Foo": {"Baz"}, "Bar": {"Foo"}},
expectedHeaders: map[string][]string{"Foo": {"Baz", "Bar"}, "Bar": {"Foo"}},
},
{
desc: "add multiple headers with existing headers",
config: dynamic.RequestHeaderModifier{
Add: map[string]string{"Foo": "Bar", "Bar": "Foo"},
},
requestHeaders: map[string][]string{"Foo": {"Baz"}, "Bar": {"Foobar"}},
expectedHeaders: map[string][]string{"Foo": {"Baz", "Bar"}, "Bar": {"Foobar", "Foo"}},
},
{
desc: "remove header",
config: dynamic.RequestHeaderModifier{
Remove: []string{"Foo"},
},
expectedHeaders: map[string][]string{},
},
{
desc: "remove header with existing headers",
config: dynamic.RequestHeaderModifier{
Remove: []string{"Foo"},
},
requestHeaders: map[string][]string{"Foo": {"Baz"}, "Bar": {"Foo"}},
expectedHeaders: map[string][]string{"Bar": {"Foo"}},
},
{
desc: "remove multiple headers with existing headers",
config: dynamic.RequestHeaderModifier{
Remove: []string{"Foo", "Bar"},
},
requestHeaders: map[string][]string{"Foo": {"Bar"}, "Bar": {"Foo"}, "Baz": {"Bar"}},
expectedHeaders: map[string][]string{"Baz": {"Bar"}},
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
var gotHeaders http.Header
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotHeaders = r.Header
})
handler, err := NewRequestHeaderModifier(context.Background(), next, test.config, "foo-request-header-modifier")
require.NoError(t, err)
req := testhelpers.MustNewRequest(http.MethodGet, "http://localhost", nil)
if test.requestHeaders != nil {
req.Header = test.requestHeaders
}
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, test.expectedHeaders, gotHeaders)
})
}
}

View file

@ -0,0 +1,134 @@
package redirect
import (
"context"
"net/http"
"net/url"
"regexp"
"strings"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/vulcand/oxy/v2/utils"
"go.opentelemetry.io/otel/trace"
)
const (
schemeHTTP = "http"
schemeHTTPS = "https"
typeName = "RequestRedirect"
)
var uriRegexp = regexp.MustCompile(`^(https?):\/\/(\[[\w:.]+\]|[\w\._-]+)?(:\d+)?(.*)$`)
// NewRequestRedirect creates a redirect middleware.
func NewRequestRedirect(ctx context.Context, next http.Handler, conf dynamic.RequestRedirect, name string) (http.Handler, error) {
logger := middlewares.GetLogger(ctx, name, typeName)
logger.Debug().Msg("Creating middleware")
logger.Debug().Msgf("Setting up redirection from %s to %s", conf.Regex, conf.Replacement)
re, err := regexp.Compile(conf.Regex)
if err != nil {
return nil, err
}
return &redirect{
regex: re,
replacement: conf.Replacement,
permanent: conf.Permanent,
errHandler: utils.DefaultHandler,
next: next,
name: name,
rawURL: rawURL,
}, nil
}
type redirect struct {
next http.Handler
regex *regexp.Regexp
replacement string
permanent bool
errHandler utils.ErrorHandler
name string
rawURL func(*http.Request) string
}
func rawURL(req *http.Request) string {
scheme := schemeHTTP
host := req.Host
port := ""
uri := req.RequestURI
if match := uriRegexp.FindStringSubmatch(req.RequestURI); len(match) > 0 {
scheme = match[1]
if len(match[2]) > 0 {
host = match[2]
}
if len(match[3]) > 0 {
port = match[3]
}
uri = match[4]
}
if req.TLS != nil {
scheme = schemeHTTPS
}
return strings.Join([]string{scheme, "://", host, port, uri}, "")
}
func (r *redirect) GetTracingInformation() (string, string, trace.SpanKind) {
return r.name, typeName, trace.SpanKindInternal
}
func (r *redirect) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
oldURL := r.rawURL(req)
// If the Regexp doesn't match, skip to the next handler.
if !r.regex.MatchString(oldURL) {
r.next.ServeHTTP(rw, req)
return
}
// Apply a rewrite regexp to the URL.
newURL := r.regex.ReplaceAllString(oldURL, r.replacement)
// Parse the rewritten URL and replace request URL with it.
parsedURL, err := url.Parse(newURL)
if err != nil {
r.errHandler.ServeHTTP(rw, req, err)
return
}
handler := &moveHandler{location: parsedURL, permanent: r.permanent}
handler.ServeHTTP(rw, req)
}
type moveHandler struct {
location *url.URL
permanent bool
}
func (m *moveHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
rw.Header().Set("Location", m.location.String())
status := http.StatusFound
if req.Method != http.MethodGet {
status = http.StatusTemporaryRedirect
}
if m.permanent {
status = http.StatusMovedPermanently
if req.Method != http.MethodGet {
status = http.StatusPermanentRedirect
}
}
rw.WriteHeader(status)
_, err := rw.Write([]byte(http.StatusText(status)))
if err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}

View file

@ -0,0 +1,203 @@
package redirect
import (
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
)
func TestRequestRedirectHandler(t *testing.T) {
testCases := []struct {
desc string
config dynamic.RequestRedirect
method string
url string
headers map[string]string
secured bool
expectedURL string
expectedStatus int
errorExpected bool
}{
{
desc: "simple redirection",
config: dynamic.RequestRedirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: "https://${1}bar$2:443$4",
},
url: "http://foo.com:80",
expectedURL: "https://foobar.com:443",
expectedStatus: http.StatusFound,
},
{
desc: "URL doesn't match regex",
config: dynamic.RequestRedirect{
Regex: `^(?:http?:\/\/)(foo)(\.com)(:\d+)(.*)$`,
Replacement: "https://${1}bar$2:443$4",
},
url: "http://bar.com:80",
expectedStatus: http.StatusOK,
},
{
desc: "invalid rewritten URL",
config: dynamic.RequestRedirect{
Regex: `^(.*)$`,
Replacement: "http://192.168.0.%31/",
},
url: "http://foo.com:80",
expectedStatus: http.StatusBadGateway,
},
{
desc: "invalid regex",
config: dynamic.RequestRedirect{
Regex: `^(.*`,
Replacement: "$1",
},
url: "http://foo.com:80",
errorExpected: true,
},
{
desc: "HTTP to HTTPS permanent",
config: dynamic.RequestRedirect{
Regex: `^http://`,
Replacement: "https://$1",
Permanent: true,
},
url: "http://foo",
expectedURL: "https://foo",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTPS to HTTP permanent",
config: dynamic.RequestRedirect{
Regex: `https://foo`,
Replacement: "http://foo",
Permanent: true,
},
secured: true,
url: "https://foo",
expectedURL: "http://foo",
expectedStatus: http.StatusMovedPermanently,
},
{
desc: "HTTP to HTTPS",
config: dynamic.RequestRedirect{
Regex: `http://foo:80`,
Replacement: "https://foo:443",
},
url: "http://foo:80",
expectedURL: "https://foo:443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTPS, with X-Forwarded-Proto",
config: dynamic.RequestRedirect{
Regex: `http://foo:80`,
Replacement: "https://foo:443",
},
url: "http://foo:80",
headers: map[string]string{
"X-Forwarded-Proto": "https",
},
expectedURL: "https://foo:443",
expectedStatus: http.StatusFound,
},
{
desc: "HTTPS to HTTP",
config: dynamic.RequestRedirect{
Regex: `https://foo:443`,
Replacement: "http://foo:80",
},
secured: true,
url: "https://foo:443",
expectedURL: "http://foo:80",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP",
config: dynamic.RequestRedirect{
Regex: `http://foo:80`,
Replacement: "http://foo:88",
},
url: "http://foo:80",
expectedURL: "http://foo:88",
expectedStatus: http.StatusFound,
},
{
desc: "HTTP to HTTP POST",
config: dynamic.RequestRedirect{
Regex: `^http://`,
Replacement: "https://$1",
},
url: "http://foo",
method: http.MethodPost,
expectedURL: "https://foo",
expectedStatus: http.StatusTemporaryRedirect,
},
{
desc: "HTTP to HTTP POST permanent",
config: dynamic.RequestRedirect{
Regex: `^http://`,
Replacement: "https://$1",
Permanent: true,
},
url: "http://foo",
method: http.MethodPost,
expectedURL: "https://foo",
expectedStatus: http.StatusPermanentRedirect,
},
}
for _, test := range testCases {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
handler, err := NewRequestRedirect(context.Background(), next, test.config, "traefikTest")
if test.errorExpected {
require.Error(t, err)
require.Nil(t, handler)
} else {
require.NoError(t, err)
require.NotNil(t, handler)
recorder := httptest.NewRecorder()
method := http.MethodGet
if test.method != "" {
method = test.method
}
req := httptest.NewRequest(method, test.url, nil)
if test.secured {
req.TLS = &tls.ConnectionState{}
}
for k, v := range test.headers {
req.Header.Set(k, v)
}
req.Header.Set("X-Foo", "bar")
handler.ServeHTTP(recorder, req)
assert.Equal(t, test.expectedStatus, recorder.Code)
switch test.expectedStatus {
case http.StatusMovedPermanently, http.StatusFound, http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
location, err := recorder.Result().Location()
require.NoError(t, err)
assert.Equal(t, test.expectedURL, location.String())
default:
location, err := recorder.Result().Location()
require.Errorf(t, err, "Location %v", location)
}
}
})
}
}