diff --git a/middlewares/errorpages/error_pages.go b/middlewares/errorpages/error_pages.go index a2e9d426c..d73aa1eb0 100644 --- a/middlewares/errorpages/error_pages.go +++ b/middlewares/errorpages/error_pages.go @@ -3,6 +3,7 @@ package errorpages import ( "bufio" "bytes" + "errors" "fmt" "net" "net/http" @@ -13,7 +14,6 @@ import ( "github.com/containous/traefik/log" "github.com/containous/traefik/middlewares" "github.com/containous/traefik/types" - "github.com/pkg/errors" "github.com/vulcand/oxy/forward" "github.com/vulcand/oxy/utils" ) diff --git a/middlewares/ip_whitelister.go b/middlewares/ip_whitelister.go index 60ff8a30e..a352aad12 100644 --- a/middlewares/ip_whitelister.go +++ b/middlewares/ip_whitelister.go @@ -38,21 +38,15 @@ func NewIPWhiteLister(whiteList []string, useXForwardedFor bool) (*IPWhiteLister } func (wl *IPWhiteLister) handle(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) { - allowed, ip, err := wl.whiteLister.IsAuthorized(r) + err := wl.whiteLister.IsAuthorized(r) if err != nil { - tracing.SetErrorAndDebugLog(r, "request %+v matched none of the white list - rejecting", r) + tracing.SetErrorAndDebugLog(r, "request %+v - rejecting: %v", r, err) reject(w) return } - if allowed { - tracing.SetErrorAndDebugLog(r, "request %+v matched white list %s - passing", r, wl.whiteLister) - next.ServeHTTP(w, r) - return - } - - tracing.SetErrorAndDebugLog(r, "source-IP %s matched none of the white list - rejecting", ip) - reject(w) + tracing.SetErrorAndDebugLog(r, "request %+v matched white list %s - passing", r, wl.whiteLister) + next.ServeHTTP(w, r) } func (wl *IPWhiteLister) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { @@ -63,5 +57,8 @@ func reject(w http.ResponseWriter) { statusCode := http.StatusForbidden w.WriteHeader(statusCode) - w.Write([]byte(http.StatusText(statusCode))) + _, err := w.Write([]byte(http.StatusText(statusCode))) + if err != nil { + log.Error(err) + } } diff --git a/middlewares/ip_whitelister_test.go b/middlewares/ip_whitelister_test.go index 33b0c56e5..07ee800ba 100644 --- a/middlewares/ip_whitelister_test.go +++ b/middlewares/ip_whitelister_test.go @@ -88,6 +88,13 @@ func TestIPWhiteLister_ServeHTTP(t *testing.T) { xForwardedFor: []string{"30.30.30.30", "40.40.40.40"}, expected: 200, }, + { + desc: "authorized with only one X-Forwarded-For", + whiteList: []string{"30.30.30.30"}, + useXForwardedFor: true, + xForwardedFor: []string{"30.30.30.30"}, + expected: 200, + }, { desc: "non authorized with X-Forwarded-For", whiteList: []string{"30.30.30.30"}, diff --git a/server/header_rewriter.go b/server/header_rewriter.go index e4531bc64..193220385 100644 --- a/server/header_rewriter.go +++ b/server/header_rewriter.go @@ -11,20 +11,20 @@ import ( // NewHeaderRewriter Create a header rewriter func NewHeaderRewriter(trustedIPs []string, insecure bool) (forward.ReqRewriter, error) { - IPs, err := whitelist.NewIP(trustedIPs, insecure, true) + ips, err := whitelist.NewIP(trustedIPs, insecure, true) if err != nil { return nil, err } - h, err := os.Hostname() + hostname, err := os.Hostname() if err != nil { - h = "localhost" + hostname = "localhost" } return &headerRewriter{ - secureRewriter: &forward.HeaderRewriter{TrustForwardHeader: true, Hostname: h}, - insecureRewriter: &forward.HeaderRewriter{TrustForwardHeader: false, Hostname: h}, - ips: IPs, + secureRewriter: &forward.HeaderRewriter{TrustForwardHeader: false, Hostname: hostname}, + insecureRewriter: &forward.HeaderRewriter{TrustForwardHeader: true, Hostname: hostname}, + ips: ips, insecure: insecure, }, nil } @@ -37,16 +37,17 @@ type headerRewriter struct { } func (h *headerRewriter) Rewrite(req *http.Request) { - authorized, _, err := h.ips.IsAuthorized(req) + if h.insecure { + h.insecureRewriter.Rewrite(req) + return + } + + err := h.ips.IsAuthorized(req) if err != nil { log.Error(err) h.secureRewriter.Rewrite(req) return } - if h.insecure || authorized { - h.secureRewriter.Rewrite(req) - } else { - h.insecureRewriter.Rewrite(req) - } + h.insecureRewriter.Rewrite(req) } diff --git a/server/header_rewriter_test.go b/server/header_rewriter_test.go new file mode 100644 index 000000000..7e5df3bbf --- /dev/null +++ b/server/header_rewriter_test.go @@ -0,0 +1,104 @@ +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestHeaderRewriter_Rewrite(t *testing.T) { + testCases := []struct { + desc string + remoteAddr string + trustedIPs []string + insecure bool + expected map[string]string + }{ + { + desc: "Secure & authorized", + remoteAddr: "10.10.10.10:80", + trustedIPs: []string{"10.10.10.10"}, + insecure: false, + expected: map[string]string{ + "X-Foo": "bar", + "X-Forwarded-For": "30.30.30.30", + }, + }, + { + desc: "Secure & unauthorized", + remoteAddr: "50.50.50.50:80", + trustedIPs: []string{"10.10.10.10"}, + insecure: false, + expected: map[string]string{ + "X-Foo": "bar", + "X-Forwarded-For": "", + }, + }, + { + desc: "Secure & authorized error", + remoteAddr: "10.10.10.10", + trustedIPs: []string{"10.10.10.10"}, + insecure: false, + expected: map[string]string{ + "X-Foo": "bar", + "X-Forwarded-For": "", + }, + }, + { + desc: "insecure & authorized", + remoteAddr: "10.10.10.10:80", + trustedIPs: []string{"10.10.10.10"}, + insecure: true, + expected: map[string]string{ + "X-Foo": "bar", + "X-Forwarded-For": "30.30.30.30", + }, + }, + { + desc: "insecure & unauthorized", + remoteAddr: "50.50.50.50:80", + trustedIPs: []string{"10.10.10.10"}, + insecure: true, + expected: map[string]string{ + "X-Foo": "bar", + "X-Forwarded-For": "30.30.30.30", + }, + }, + { + desc: "insecure & authorized error", + remoteAddr: "10.10.10.10", + trustedIPs: []string{"10.10.10.10"}, + insecure: true, + expected: map[string]string{ + "X-Foo": "bar", + "X-Forwarded-For": "30.30.30.30", + }, + }, + } + + for _, test := range testCases { + test := test + t.Run(test.desc, func(t *testing.T) { + t.Parallel() + + rewriter, err := NewHeaderRewriter(test.trustedIPs, test.insecure) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodGet, "http://20.20.20.20/foo", nil) + require.NoError(t, err) + req.RemoteAddr = test.remoteAddr + + req.Header.Set("X-Foo", "bar") + req.Header.Set("X-Forwarded-For", "30.30.30.30") + + rewriter.Rewrite(req) + + for key, value := range test.expected { + assert.Equal(t, value, req.Header.Get(key)) + } + }) + } +} diff --git a/server/server.go b/server/server.go index 8d95ed938..201368be9 100644 --- a/server/server.go +++ b/server/server.go @@ -808,7 +808,8 @@ func (s *Server) prepareServer(entryPointName string, entryPoint *configuration. if !ok { return false, fmt.Errorf("type error %v", addr) } - return IPs.ContainsIP(ip.IP) + + return IPs.ContainsIP(ip.IP), nil }, } } diff --git a/whitelist/ip.go b/whitelist/ip.go index 3b1c4a0b7..33a988eba 100644 --- a/whitelist/ip.go +++ b/whitelist/ip.go @@ -1,11 +1,11 @@ package whitelist import ( + "errors" "fmt" "net" "net/http" - - "github.com/pkg/errors" + "strings" ) const ( @@ -50,64 +50,78 @@ func NewIP(whiteList []string, insecure bool, useXForwardedFor bool) (*IP, error } // IsAuthorized checks if provided request is authorized by the white list -func (ip *IP) IsAuthorized(req *http.Request) (bool, net.IP, error) { +func (ip *IP) IsAuthorized(req *http.Request) error { if ip.insecure { - return true, nil, nil + return nil } + var invalidMatches []string + if ip.useXForwardedFor { xFFs := req.Header[XForwardedFor] - if len(xFFs) > 1 { + if len(xFFs) > 0 { for _, xFF := range xFFs { - ok, i, err := ip.contains(parseHost(xFF)) + ok, err := ip.contains(parseHost(xFF)) if err != nil { - return false, nil, err + return err } if ok { - return ok, i, nil + return nil } + + invalidMatches = append(invalidMatches, xFF) } } } host, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { - return false, nil, err + return err } - return ip.contains(host) + + ok, err := ip.contains(host) + if err != nil { + return err + } + + if !ok { + invalidMatches = append(invalidMatches, req.RemoteAddr) + return fmt.Errorf("%q matched none of the white list", strings.Join(invalidMatches, ", ")) + } + + return nil } // contains checks if provided address is in the white list -func (ip *IP) contains(addr string) (bool, net.IP, error) { +func (ip *IP) contains(addr string) (bool, error) { ipAddr, err := parseIP(addr) if err != nil { - return false, nil, fmt.Errorf("unable to parse address: %s: %s", addr, err) + return false, fmt.Errorf("unable to parse address: %s: %s", addr, err) } - contains, err := ip.ContainsIP(ipAddr) - return contains, ipAddr, err + return ip.ContainsIP(ipAddr), nil } // ContainsIP checks if provided address is in the white list -func (ip *IP) ContainsIP(addr net.IP) (bool, error) { +func (ip *IP) ContainsIP(addr net.IP) bool { if ip.insecure { - return true, nil + return true } for _, whiteListIP := range ip.whiteListsIPs { if whiteListIP.Equal(addr) { - return true, nil + return true } } for _, whiteListNet := range ip.whiteListsNet { if whiteListNet.Contains(addr) { - return true, nil + return true } } - return false, nil + return false } func parseIP(addr string) (net.IP, error) { diff --git a/whitelist/ip_test.go b/whitelist/ip_test.go index f4dc0f022..f29c09335 100644 --- a/whitelist/ip_test.go +++ b/whitelist/ip_test.go @@ -17,7 +17,7 @@ func TestIsAuthorized(t *testing.T) { allowXForwardedFor bool remoteAddr string xForwardedForValues []string - expected bool + authorized bool }{ { desc: "allow UseXForwardedFor, remoteAddr not in range, UseXForwardedFor in range", @@ -25,7 +25,7 @@ func TestIsAuthorized(t *testing.T) { allowXForwardedFor: true, remoteAddr: "10.2.3.1:123", xForwardedForValues: []string{"1.2.3.1", "10.2.3.1"}, - expected: true, + authorized: true, }, { desc: "allow UseXForwardedFor, remoteAddr in range, UseXForwardedFor in range", @@ -33,7 +33,7 @@ func TestIsAuthorized(t *testing.T) { allowXForwardedFor: true, remoteAddr: "1.2.3.1:123", xForwardedForValues: []string{"1.2.3.1", "10.2.3.1"}, - expected: true, + authorized: true, }, { desc: "allow UseXForwardedFor, remoteAddr in range, UseXForwardedFor not in range", @@ -41,7 +41,7 @@ func TestIsAuthorized(t *testing.T) { allowXForwardedFor: true, remoteAddr: "1.2.3.1:123", xForwardedForValues: []string{"10.2.3.1", "10.2.3.1"}, - expected: true, + authorized: true, }, { desc: "allow UseXForwardedFor, remoteAddr not in range, UseXForwardedFor not in range", @@ -49,7 +49,7 @@ func TestIsAuthorized(t *testing.T) { allowXForwardedFor: true, remoteAddr: "10.2.3.1:123", xForwardedForValues: []string{"10.2.3.1", "10.2.3.1"}, - expected: false, + authorized: false, }, { desc: "don't allow UseXForwardedFor, remoteAddr not in range, UseXForwardedFor in range", @@ -57,7 +57,7 @@ func TestIsAuthorized(t *testing.T) { allowXForwardedFor: false, remoteAddr: "10.2.3.1:123", xForwardedForValues: []string{"1.2.3.1", "10.2.3.1"}, - expected: false, + authorized: false, }, { desc: "don't allow UseXForwardedFor, remoteAddr in range, UseXForwardedFor in range", @@ -65,7 +65,7 @@ func TestIsAuthorized(t *testing.T) { allowXForwardedFor: false, remoteAddr: "1.2.3.1:123", xForwardedForValues: []string{"1.2.3.1", "10.2.3.1"}, - expected: true, + authorized: true, }, { desc: "don't allow UseXForwardedFor, remoteAddr in range, UseXForwardedFor not in range", @@ -73,7 +73,7 @@ func TestIsAuthorized(t *testing.T) { allowXForwardedFor: false, remoteAddr: "1.2.3.1:123", xForwardedForValues: []string{"10.2.3.1", "10.2.3.1"}, - expected: true, + authorized: true, }, { desc: "don't allow UseXForwardedFor, remoteAddr not in range, UseXForwardedFor not in range", @@ -81,7 +81,7 @@ func TestIsAuthorized(t *testing.T) { allowXForwardedFor: false, remoteAddr: "10.2.3.1:123", xForwardedForValues: []string{"10.2.3.1", "10.2.3.1"}, - expected: false, + authorized: false, }, } @@ -95,11 +95,12 @@ func TestIsAuthorized(t *testing.T) { whiteLister, err := NewIP(test.whiteList, false, test.allowXForwardedFor) require.NoError(t, err) - authorized, ips, err := whiteLister.IsAuthorized(req) - require.NoError(t, err) - assert.NotNil(t, ips) - - assert.Equal(t, test.expected, authorized) + err = whiteLister.IsAuthorized(req) + if test.authorized { + require.NoError(t, err) + } else { + require.Error(t, err) + } }) } } @@ -349,16 +350,14 @@ func TestContainsIsAllowed(t *testing.T) { require.NotNil(t, whiteLister) for _, testIP := range test.passIPs { - allowed, ip, err := whiteLister.contains(testIP) + allowed, err := whiteLister.contains(testIP) require.NoError(t, err) - require.NotNil(t, ip, err) assert.Truef(t, allowed, "%s should have passed.", testIP) } for _, testIP := range test.rejectIPs { - allowed, ip, err := whiteLister.contains(testIP) + allowed, err := whiteLister.contains(testIP) require.NoError(t, err) - require.NotNil(t, ip, err) assert.Falsef(t, allowed, "%s should not have passed.", testIP) } }) @@ -405,7 +404,7 @@ func TestContainsInsecure(t *testing.T) { t.Run(test.desc, func(t *testing.T) { t.Parallel() - ok, _, err := test.whiteLister.contains(test.ip) + ok, err := test.whiteLister.contains(test.ip) require.NoError(t, err) assert.Equal(t, test.expected, ok) @@ -426,9 +425,8 @@ func TestContainsBrokenIPs(t *testing.T) { require.NoError(t, err) for _, testIP := range brokenIPs { - _, ip, err := whiteLister.contains(testIP) + _, err := whiteLister.contains(testIP) assert.Error(t, err) - require.Nil(t, ip, err) } }