Ability to use "X-Forwarded-For" as a source of IP for white list.
This commit is contained in:
parent
4802484729
commit
d2766b1b4f
50 changed files with 1496 additions and 599 deletions
|
@ -3,36 +3,45 @@ package whitelist
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
// XForwardedFor Header name
|
||||
XForwardedFor = "X-Forwarded-For"
|
||||
)
|
||||
|
||||
// IP allows to check that addresses are in a white list
|
||||
type IP struct {
|
||||
whiteListsIPs []*net.IP
|
||||
whiteListsNet []*net.IPNet
|
||||
insecure bool
|
||||
whiteListsIPs []*net.IP
|
||||
whiteListsNet []*net.IPNet
|
||||
insecure bool
|
||||
useXForwardedFor bool
|
||||
}
|
||||
|
||||
// NewIP builds a new IP given a list of CIDR-Strings to whitelist
|
||||
func NewIP(whitelistStrings []string, insecure bool) (*IP, error) {
|
||||
if len(whitelistStrings) == 0 && !insecure {
|
||||
// NewIP builds a new IP given a list of CIDR-Strings to white list
|
||||
func NewIP(whiteList []string, insecure bool, useXForwardedFor bool) (*IP, error) {
|
||||
if len(whiteList) == 0 && !insecure {
|
||||
return nil, errors.New("no white list provided")
|
||||
}
|
||||
|
||||
ip := IP{insecure: insecure}
|
||||
ip := IP{
|
||||
insecure: insecure,
|
||||
useXForwardedFor: useXForwardedFor,
|
||||
}
|
||||
|
||||
if !insecure {
|
||||
for _, whitelistString := range whitelistStrings {
|
||||
ipAddr := net.ParseIP(whitelistString)
|
||||
if ipAddr != nil {
|
||||
for _, ipMask := range whiteList {
|
||||
if ipAddr := net.ParseIP(ipMask); ipAddr != nil {
|
||||
ip.whiteListsIPs = append(ip.whiteListsIPs, &ipAddr)
|
||||
} else {
|
||||
_, whitelist, err := net.ParseCIDR(whitelistString)
|
||||
_, ipAddr, err := net.ParseCIDR(ipMask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing CIDR whitelist %s: %v", whitelist, err)
|
||||
return nil, fmt.Errorf("parsing CIDR white list %s: %v", ipAddr, err)
|
||||
}
|
||||
ip.whiteListsNet = append(ip.whiteListsNet, whitelist)
|
||||
ip.whiteListsNet = append(ip.whiteListsNet, ipAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -40,13 +49,38 @@ func NewIP(whitelistStrings []string, insecure bool) (*IP, error) {
|
|||
return &ip, nil
|
||||
}
|
||||
|
||||
// Contains checks if provided address is in the white list
|
||||
func (ip *IP) Contains(addr string) (bool, net.IP, error) {
|
||||
// IsAuthorized checks if provided request is authorized by the white list
|
||||
func (ip *IP) IsAuthorized(req *http.Request) (bool, net.IP, error) {
|
||||
if ip.insecure {
|
||||
return true, nil, nil
|
||||
}
|
||||
|
||||
ipAddr, err := ipFromRemoteAddr(addr)
|
||||
if ip.useXForwardedFor {
|
||||
xFFs := req.Header[XForwardedFor]
|
||||
if len(xFFs) > 1 {
|
||||
for _, xFF := range xFFs {
|
||||
ok, i, err := ip.contains(parseHost(xFF))
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
if ok {
|
||||
return ok, i, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
host, _, err := net.SplitHostPort(req.RemoteAddr)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
return ip.contains(host)
|
||||
}
|
||||
|
||||
// contains checks if provided address is in the white list
|
||||
func (ip *IP) contains(addr string) (bool, net.IP, error) {
|
||||
ipAddr, err := parseIP(addr)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("unable to parse address: %s: %s", addr, err)
|
||||
}
|
||||
|
@ -76,7 +110,7 @@ func (ip *IP) ContainsIP(addr net.IP) (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
func ipFromRemoteAddr(addr string) (net.IP, error) {
|
||||
func parseIP(addr string) (net.IP, error) {
|
||||
userIP := net.ParseIP(addr)
|
||||
if userIP == nil {
|
||||
return nil, fmt.Errorf("can't parse IP from address %s", addr)
|
||||
|
@ -84,3 +118,11 @@ func ipFromRemoteAddr(addr string) (net.IP, error) {
|
|||
|
||||
return userIP, nil
|
||||
}
|
||||
|
||||
func parseHost(addr string) string {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
|
|
@ -2,55 +2,151 @@ package whitelist
|
|||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsAuthorized(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whiteList []string
|
||||
allowXForwardedFor bool
|
||||
remoteAddr string
|
||||
xForwardedForValues []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
desc: "allow UseXForwardedFor, remoteAddr not in range, UseXForwardedFor in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
allowXForwardedFor: true,
|
||||
remoteAddr: "10.2.3.1:123",
|
||||
xForwardedForValues: []string{"1.2.3.1", "10.2.3.1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "allow UseXForwardedFor, remoteAddr in range, UseXForwardedFor in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
allowXForwardedFor: true,
|
||||
remoteAddr: "1.2.3.1:123",
|
||||
xForwardedForValues: []string{"1.2.3.1", "10.2.3.1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "allow UseXForwardedFor, remoteAddr in range, UseXForwardedFor not in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
allowXForwardedFor: true,
|
||||
remoteAddr: "1.2.3.1:123",
|
||||
xForwardedForValues: []string{"10.2.3.1", "10.2.3.1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "allow UseXForwardedFor, remoteAddr not in range, UseXForwardedFor not in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
allowXForwardedFor: true,
|
||||
remoteAddr: "10.2.3.1:123",
|
||||
xForwardedForValues: []string{"10.2.3.1", "10.2.3.1"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "don't allow UseXForwardedFor, remoteAddr not in range, UseXForwardedFor in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
allowXForwardedFor: false,
|
||||
remoteAddr: "10.2.3.1:123",
|
||||
xForwardedForValues: []string{"1.2.3.1", "10.2.3.1"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
desc: "don't allow UseXForwardedFor, remoteAddr in range, UseXForwardedFor in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
allowXForwardedFor: false,
|
||||
remoteAddr: "1.2.3.1:123",
|
||||
xForwardedForValues: []string{"1.2.3.1", "10.2.3.1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "don't allow UseXForwardedFor, remoteAddr in range, UseXForwardedFor not in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
allowXForwardedFor: false,
|
||||
remoteAddr: "1.2.3.1:123",
|
||||
xForwardedForValues: []string{"10.2.3.1", "10.2.3.1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
desc: "don't allow UseXForwardedFor, remoteAddr not in range, UseXForwardedFor not in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
allowXForwardedFor: false,
|
||||
remoteAddr: "10.2.3.1:123",
|
||||
xForwardedForValues: []string{"10.2.3.1", "10.2.3.1"},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := NewRequest(test.remoteAddr, test.xForwardedForValues)
|
||||
|
||||
whiteLister, err := NewIP(test.whiteList, false, test.allowXForwardedFor)
|
||||
require.NoError(t, err)
|
||||
|
||||
authorized, ips, err := whiteLister.IsAuthorized(req)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, ips)
|
||||
|
||||
assert.Equal(t, test.expected, authorized)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
cases := []struct {
|
||||
desc string
|
||||
whitelistStrings []string
|
||||
whiteList []string
|
||||
expectedWhitelists []*net.IPNet
|
||||
errMessage string
|
||||
}{
|
||||
{
|
||||
desc: "nil whitelist",
|
||||
whitelistStrings: nil,
|
||||
whiteList: nil,
|
||||
expectedWhitelists: nil,
|
||||
errMessage: "no white list provided",
|
||||
}, {
|
||||
desc: "empty whitelist",
|
||||
whitelistStrings: []string{},
|
||||
whiteList: []string{},
|
||||
expectedWhitelists: nil,
|
||||
errMessage: "no white list provided",
|
||||
}, {
|
||||
desc: "whitelist containing empty string",
|
||||
whitelistStrings: []string{
|
||||
whiteList: []string{
|
||||
"1.2.3.4/24",
|
||||
"",
|
||||
"fe80::/16",
|
||||
},
|
||||
expectedWhitelists: nil,
|
||||
errMessage: "parsing CIDR whitelist <nil>: invalid CIDR address: ",
|
||||
errMessage: "parsing CIDR white list <nil>: invalid CIDR address: ",
|
||||
}, {
|
||||
desc: "whitelist containing only an empty string",
|
||||
whitelistStrings: []string{
|
||||
whiteList: []string{
|
||||
"",
|
||||
},
|
||||
expectedWhitelists: nil,
|
||||
errMessage: "parsing CIDR whitelist <nil>: invalid CIDR address: ",
|
||||
errMessage: "parsing CIDR white list <nil>: invalid CIDR address: ",
|
||||
}, {
|
||||
desc: "whitelist containing an invalid string",
|
||||
whitelistStrings: []string{
|
||||
whiteList: []string{
|
||||
"foo",
|
||||
},
|
||||
expectedWhitelists: nil,
|
||||
errMessage: "parsing CIDR whitelist <nil>: invalid CIDR address: foo",
|
||||
errMessage: "parsing CIDR white list <nil>: invalid CIDR address: foo",
|
||||
}, {
|
||||
desc: "IPv4 & IPv6 whitelist",
|
||||
whitelistStrings: []string{
|
||||
whiteList: []string{
|
||||
"1.2.3.4/24",
|
||||
"fe80::/16",
|
||||
},
|
||||
|
@ -61,7 +157,7 @@ func TestNew(t *testing.T) {
|
|||
errMessage: "",
|
||||
}, {
|
||||
desc: "IPv4 only",
|
||||
whitelistStrings: []string{
|
||||
whiteList: []string{
|
||||
"127.0.0.1/8",
|
||||
},
|
||||
expectedWhitelists: []*net.IPNet{
|
||||
|
@ -75,12 +171,12 @@ func TestNew(t *testing.T) {
|
|||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
whitelister, err := NewIP(test.whitelistStrings, false)
|
||||
whiteLister, err := NewIP(test.whiteList, false, false)
|
||||
if test.errMessage != "" {
|
||||
require.EqualError(t, err, test.errMessage)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
for index, actual := range whitelister.whiteListsNet {
|
||||
for index, actual := range whiteLister.whiteListsNet {
|
||||
expected := test.expectedWhitelists[index]
|
||||
assert.Equal(t, expected.IP, actual.IP)
|
||||
assert.Equal(t, expected.Mask.String(), actual.Mask.String())
|
||||
|
@ -98,10 +194,8 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
rejectIPs []string
|
||||
}{
|
||||
{
|
||||
desc: "IPv4",
|
||||
whitelistStrings: []string{
|
||||
"1.2.3.4/24",
|
||||
},
|
||||
desc: "IPv4",
|
||||
whitelistStrings: []string{"1.2.3.4/24"},
|
||||
passIPs: []string{
|
||||
"1.2.3.1",
|
||||
"1.2.3.32",
|
||||
|
@ -116,13 +210,9 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv4 single IP",
|
||||
whitelistStrings: []string{
|
||||
"8.8.8.8",
|
||||
},
|
||||
passIPs: []string{
|
||||
"8.8.8.8",
|
||||
},
|
||||
desc: "IPv4 single IP",
|
||||
whitelistStrings: []string{"8.8.8.8"},
|
||||
passIPs: []string{"8.8.8.8"},
|
||||
rejectIPs: []string{
|
||||
"8.8.8.7",
|
||||
"8.8.8.9",
|
||||
|
@ -133,13 +223,9 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv4 Net single IP",
|
||||
whitelistStrings: []string{
|
||||
"8.8.8.8/32",
|
||||
},
|
||||
passIPs: []string{
|
||||
"8.8.8.8",
|
||||
},
|
||||
desc: "IPv4 Net single IP",
|
||||
whitelistStrings: []string{"8.8.8.8/32"},
|
||||
passIPs: []string{"8.8.8.8"},
|
||||
rejectIPs: []string{
|
||||
"8.8.8.7",
|
||||
"8.8.8.9",
|
||||
|
@ -150,11 +236,8 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple IPv4",
|
||||
whitelistStrings: []string{
|
||||
"1.2.3.4/24",
|
||||
"8.8.8.8/8",
|
||||
},
|
||||
desc: "multiple IPv4",
|
||||
whitelistStrings: []string{"1.2.3.4/24", "8.8.8.8/8"},
|
||||
passIPs: []string{
|
||||
"1.2.3.1",
|
||||
"1.2.3.32",
|
||||
|
@ -174,10 +257,8 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv6",
|
||||
whitelistStrings: []string{
|
||||
"2a03:4000:6:d080::/64",
|
||||
},
|
||||
desc: "IPv6",
|
||||
whitelistStrings: []string{"2a03:4000:6:d080::/64"},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::",
|
||||
"2a03:4000:6:d080::1",
|
||||
|
@ -192,13 +273,9 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv6 single IP",
|
||||
whitelistStrings: []string{
|
||||
"2a03:4000:6:d080::42/128",
|
||||
},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::42",
|
||||
},
|
||||
desc: "IPv6 single IP",
|
||||
whitelistStrings: []string{"2a03:4000:6:d080::42/128"},
|
||||
passIPs: []string{"2a03:4000:6:d080::42"},
|
||||
rejectIPs: []string{
|
||||
"2a03:4000:6:d080::1",
|
||||
"2a03:4000:6:d080:dead:beef:ffff:ffff",
|
||||
|
@ -206,11 +283,8 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple IPv6",
|
||||
whitelistStrings: []string{
|
||||
"2a03:4000:6:d080::/64",
|
||||
"fe80::/16",
|
||||
},
|
||||
desc: "multiple IPv6",
|
||||
whitelistStrings: []string{"2a03:4000:6:d080::/64", "fe80::/16"},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::",
|
||||
"2a03:4000:6:d080::1",
|
||||
|
@ -227,13 +301,8 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple IPv6 & IPv4",
|
||||
whitelistStrings: []string{
|
||||
"2a03:4000:6:d080::/64",
|
||||
"fe80::/16",
|
||||
"1.2.3.4/24",
|
||||
"8.8.8.8/8",
|
||||
},
|
||||
desc: "multiple IPv6 & IPv4",
|
||||
whitelistStrings: []string{"2a03:4000:6:d080::/64", "fe80::/16", "1.2.3.4/24", "8.8.8.8/8"},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::",
|
||||
"2a03:4000:6:d080::1",
|
||||
|
@ -263,11 +332,9 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
},
|
||||
},
|
||||
{
|
||||
desc: "broken IP-addresses",
|
||||
whitelistStrings: []string{
|
||||
"127.0.0.1/32",
|
||||
},
|
||||
passIPs: nil,
|
||||
desc: "broken IP-addresses",
|
||||
whitelistStrings: []string{"127.0.0.1/32"},
|
||||
passIPs: nil,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -276,23 +343,23 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
whiteLister, err := NewIP(test.whitelistStrings, false)
|
||||
whiteLister, err := NewIP(test.whitelistStrings, false, false)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, whiteLister)
|
||||
|
||||
for _, testIP := range test.passIPs {
|
||||
allowed, ip, err := whiteLister.Contains(testIP)
|
||||
allowed, ip, err := whiteLister.contains(testIP)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ip, err)
|
||||
assert.True(t, allowed, testIP+" should have passed "+test.desc)
|
||||
assert.Truef(t, allowed, "%s should have passed.", testIP)
|
||||
}
|
||||
|
||||
for _, testIP := range test.rejectIPs {
|
||||
allowed, ip, err := whiteLister.Contains(testIP)
|
||||
allowed, ip, err := whiteLister.contains(testIP)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ip, err)
|
||||
assert.False(t, allowed, testIP+" should not have passed "+test.desc)
|
||||
assert.Falsef(t, allowed, "%s should not have passed.", testIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -300,7 +367,7 @@ func TestContainsIsAllowed(t *testing.T) {
|
|||
|
||||
func TestContainsInsecure(t *testing.T) {
|
||||
mustNewIP := func(whitelistStrings []string, insecure bool) *IP {
|
||||
ip, err := NewIP(whitelistStrings, insecure)
|
||||
ip, err := NewIP(whitelistStrings, insecure, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -338,7 +405,7 @@ func TestContainsInsecure(t *testing.T) {
|
|||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ok, _, err := test.whiteLister.Contains(test.ip)
|
||||
ok, _, err := test.whiteLister.contains(test.ip)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, test.expected, ok)
|
||||
|
@ -355,13 +422,25 @@ func TestContainsBrokenIPs(t *testing.T) {
|
|||
"\\&$§&/(",
|
||||
}
|
||||
|
||||
whiteLister, err := NewIP([]string{"1.2.3.4/24"}, false)
|
||||
whiteLister, err := NewIP([]string{"1.2.3.4/24"}, false, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, testIP := range brokenIPs {
|
||||
_, ip, err := whiteLister.Contains(testIP)
|
||||
_, ip, err := whiteLister.contains(testIP)
|
||||
assert.Error(t, err)
|
||||
require.Nil(t, ip, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func NewRequest(remoteAddr string, xForwardedFor []string) *http.Request {
|
||||
req := httptest.NewRequest(http.MethodGet, "http://example.com/foo", nil)
|
||||
if len(remoteAddr) > 0 {
|
||||
req.RemoteAddr = remoteAddr
|
||||
}
|
||||
if len(xForwardedFor) > 0 {
|
||||
for _, xff := range xForwardedFor {
|
||||
req.Header.Add(XForwardedFor, xff)
|
||||
}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue