IPStrategy for selecting IP in whitelist
This commit is contained in:
parent
1ec4e03738
commit
00728e711c
65 changed files with 2444 additions and 1837 deletions
99
ip/checker.go
Normal file
99
ip/checker.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package ip
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Checker allows to check that addresses are in a trusted IPs
|
||||
type Checker struct {
|
||||
authorizedIPs []*net.IP
|
||||
authorizedIPsNet []*net.IPNet
|
||||
}
|
||||
|
||||
// NewChecker builds a new Checker given a list of CIDR-Strings to trusted IPs
|
||||
func NewChecker(trustedIPs []string) (*Checker, error) {
|
||||
if len(trustedIPs) == 0 {
|
||||
return nil, errors.New("no trusted IPs provided")
|
||||
}
|
||||
|
||||
checker := &Checker{}
|
||||
|
||||
for _, ipMask := range trustedIPs {
|
||||
if ipAddr := net.ParseIP(ipMask); ipAddr != nil {
|
||||
checker.authorizedIPs = append(checker.authorizedIPs, &ipAddr)
|
||||
} else {
|
||||
_, ipAddr, err := net.ParseCIDR(ipMask)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing CIDR trusted IPs %s: %v", ipAddr, err)
|
||||
}
|
||||
checker.authorizedIPsNet = append(checker.authorizedIPsNet, ipAddr)
|
||||
}
|
||||
}
|
||||
|
||||
return checker, nil
|
||||
}
|
||||
|
||||
// IsAuthorized checks if provided request is authorized by the trusted IPs
|
||||
func (ip *Checker) IsAuthorized(addr string) error {
|
||||
var invalidMatches []string
|
||||
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
host = addr
|
||||
}
|
||||
|
||||
ok, err := ip.Contains(host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
invalidMatches = append(invalidMatches, addr)
|
||||
return fmt.Errorf("%q matched none of the trusted IPs", strings.Join(invalidMatches, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Contains checks if provided address is in the trusted IPs
|
||||
func (ip *Checker) Contains(addr string) (bool, error) {
|
||||
if len(addr) <= 0 {
|
||||
return false, errors.New("empty IP address")
|
||||
}
|
||||
|
||||
ipAddr, err := parseIP(addr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("unable to parse address: %s: %s", addr, err)
|
||||
}
|
||||
|
||||
return ip.ContainsIP(ipAddr), nil
|
||||
}
|
||||
|
||||
// ContainsIP checks if provided address is in the trusted IPs
|
||||
func (ip *Checker) ContainsIP(addr net.IP) bool {
|
||||
for _, authorizedIP := range ip.authorizedIPs {
|
||||
if authorizedIP.Equal(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, authorizedNet := range ip.authorizedIPsNet {
|
||||
if authorizedNet.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func parseIP(addr string) (net.IP, error) {
|
||||
userIP := net.ParseIP(addr)
|
||||
if userIP == nil {
|
||||
return nil, fmt.Errorf("can't parse IP from address %s", addr)
|
||||
}
|
||||
|
||||
return userIP, nil
|
||||
}
|
326
ip/checker_test.go
Normal file
326
ip/checker_test.go
Normal file
|
@ -0,0 +1,326 @@
|
|||
package ip
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsAuthorized(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
whiteList []string
|
||||
remoteAddr string
|
||||
authorized bool
|
||||
}{
|
||||
{
|
||||
desc: "remoteAddr not in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "10.2.3.1:123",
|
||||
authorized: false,
|
||||
},
|
||||
{
|
||||
desc: "remoteAddr in range",
|
||||
whiteList: []string{"1.2.3.4/24"},
|
||||
remoteAddr: "1.2.3.1:123",
|
||||
authorized: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipChecker, err := NewChecker(test.whiteList)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = ipChecker.IsAuthorized(test.remoteAddr)
|
||||
if test.authorized {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
trustedIPs []string
|
||||
expectedAuthorizedIPs []*net.IPNet
|
||||
errMessage string
|
||||
}{
|
||||
{
|
||||
desc: "nil trusted IPs",
|
||||
trustedIPs: nil,
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "no trusted IPs provided",
|
||||
}, {
|
||||
desc: "empty trusted IPs",
|
||||
trustedIPs: []string{},
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "no trusted IPs provided",
|
||||
}, {
|
||||
desc: "trusted IPs containing empty string",
|
||||
trustedIPs: []string{
|
||||
"1.2.3.4/24",
|
||||
"",
|
||||
"fe80::/16",
|
||||
},
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "parsing CIDR trusted IPs <nil>: invalid CIDR address: ",
|
||||
}, {
|
||||
desc: "trusted IPs containing only an empty string",
|
||||
trustedIPs: []string{
|
||||
"",
|
||||
},
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "parsing CIDR trusted IPs <nil>: invalid CIDR address: ",
|
||||
}, {
|
||||
desc: "trusted IPs containing an invalid string",
|
||||
trustedIPs: []string{
|
||||
"foo",
|
||||
},
|
||||
expectedAuthorizedIPs: nil,
|
||||
errMessage: "parsing CIDR trusted IPs <nil>: invalid CIDR address: foo",
|
||||
}, {
|
||||
desc: "IPv4 & IPv6 trusted IPs",
|
||||
trustedIPs: []string{
|
||||
"1.2.3.4/24",
|
||||
"fe80::/16",
|
||||
},
|
||||
expectedAuthorizedIPs: []*net.IPNet{
|
||||
{IP: net.IPv4(1, 2, 3, 0).To4(), Mask: net.IPv4Mask(255, 255, 255, 0)},
|
||||
{IP: net.ParseIP("fe80::"), Mask: net.IPMask(net.ParseIP("ffff::"))},
|
||||
},
|
||||
errMessage: "",
|
||||
}, {
|
||||
desc: "IPv4 only",
|
||||
trustedIPs: []string{
|
||||
"127.0.0.1/8",
|
||||
},
|
||||
expectedAuthorizedIPs: []*net.IPNet{
|
||||
{IP: net.IPv4(127, 0, 0, 0).To4(), Mask: net.IPv4Mask(255, 0, 0, 0)},
|
||||
},
|
||||
errMessage: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipChecker, err := NewChecker(test.trustedIPs)
|
||||
if test.errMessage != "" {
|
||||
require.EqualError(t, err, test.errMessage)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
for index, actual := range ipChecker.authorizedIPsNet {
|
||||
expected := test.expectedAuthorizedIPs[index]
|
||||
assert.Equal(t, expected.IP, actual.IP)
|
||||
assert.Equal(t, expected.Mask.String(), actual.Mask.String())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsIsAllowed(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
trustedIPs []string
|
||||
passIPs []string
|
||||
rejectIPs []string
|
||||
}{
|
||||
{
|
||||
desc: "IPv4",
|
||||
trustedIPs: []string{"1.2.3.4/24"},
|
||||
passIPs: []string{
|
||||
"1.2.3.1",
|
||||
"1.2.3.32",
|
||||
"1.2.3.156",
|
||||
"1.2.3.255",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"1.2.16.1",
|
||||
"1.2.32.1",
|
||||
"127.0.0.1",
|
||||
"8.8.8.8",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv4 single IP",
|
||||
trustedIPs: []string{"8.8.8.8"},
|
||||
passIPs: []string{"8.8.8.8"},
|
||||
rejectIPs: []string{
|
||||
"8.8.8.7",
|
||||
"8.8.8.9",
|
||||
"8.8.8.0",
|
||||
"8.8.8.255",
|
||||
"4.4.4.4",
|
||||
"127.0.0.1",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv4 Net single IP",
|
||||
trustedIPs: []string{"8.8.8.8/32"},
|
||||
passIPs: []string{"8.8.8.8"},
|
||||
rejectIPs: []string{
|
||||
"8.8.8.7",
|
||||
"8.8.8.9",
|
||||
"8.8.8.0",
|
||||
"8.8.8.255",
|
||||
"4.4.4.4",
|
||||
"127.0.0.1",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple IPv4",
|
||||
trustedIPs: []string{"1.2.3.4/24", "8.8.8.8/8"},
|
||||
passIPs: []string{
|
||||
"1.2.3.1",
|
||||
"1.2.3.32",
|
||||
"1.2.3.156",
|
||||
"1.2.3.255",
|
||||
"8.8.4.4",
|
||||
"8.0.0.1",
|
||||
"8.32.42.128",
|
||||
"8.255.255.255",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"1.2.16.1",
|
||||
"1.2.32.1",
|
||||
"127.0.0.1",
|
||||
"4.4.4.4",
|
||||
"4.8.8.8",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv6",
|
||||
trustedIPs: []string{"2a03:4000:6:d080::/64"},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::",
|
||||
"2a03:4000:6:d080::1",
|
||||
"2a03:4000:6:d080:dead:beef:ffff:ffff",
|
||||
"2a03:4000:6:d080::42",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"2a03:4000:7:d080::",
|
||||
"2a03:4000:7:d080::1",
|
||||
"fe80::",
|
||||
"4242::1",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "IPv6 single IP",
|
||||
trustedIPs: []string{"2a03:4000:6:d080::42/128"},
|
||||
passIPs: []string{"2a03:4000:6:d080::42"},
|
||||
rejectIPs: []string{
|
||||
"2a03:4000:6:d080::1",
|
||||
"2a03:4000:6:d080:dead:beef:ffff:ffff",
|
||||
"2a03:4000:6:d080::43",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple IPv6",
|
||||
trustedIPs: []string{"2a03:4000:6:d080::/64", "fe80::/16"},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::",
|
||||
"2a03:4000:6:d080::1",
|
||||
"2a03:4000:6:d080:dead:beef:ffff:ffff",
|
||||
"2a03:4000:6:d080::42",
|
||||
"fe80::1",
|
||||
"fe80:aa00:00bb:4232:ff00:eeee:00ff:1111",
|
||||
"fe80::fe80",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"2a03:4000:7:d080::",
|
||||
"2a03:4000:7:d080::1",
|
||||
"4242::1",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "multiple IPv6 & IPv4",
|
||||
trustedIPs: []string{"2a03:4000:6:d080::/64", "fe80::/16", "1.2.3.4/24", "8.8.8.8/8"},
|
||||
passIPs: []string{
|
||||
"2a03:4000:6:d080::",
|
||||
"2a03:4000:6:d080::1",
|
||||
"2a03:4000:6:d080:dead:beef:ffff:ffff",
|
||||
"2a03:4000:6:d080::42",
|
||||
"fe80::1",
|
||||
"fe80:aa00:00bb:4232:ff00:eeee:00ff:1111",
|
||||
"fe80::fe80",
|
||||
"1.2.3.1",
|
||||
"1.2.3.32",
|
||||
"1.2.3.156",
|
||||
"1.2.3.255",
|
||||
"8.8.4.4",
|
||||
"8.0.0.1",
|
||||
"8.32.42.128",
|
||||
"8.255.255.255",
|
||||
},
|
||||
rejectIPs: []string{
|
||||
"2a03:4000:7:d080::",
|
||||
"2a03:4000:7:d080::1",
|
||||
"4242::1",
|
||||
"1.2.16.1",
|
||||
"1.2.32.1",
|
||||
"127.0.0.1",
|
||||
"4.4.4.4",
|
||||
"4.8.8.8",
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "broken IP-addresses",
|
||||
trustedIPs: []string{"127.0.0.1/32"},
|
||||
passIPs: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ipChecker, err := NewChecker(test.trustedIPs)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, ipChecker)
|
||||
|
||||
for _, testIP := range test.passIPs {
|
||||
allowed, err := ipChecker.Contains(testIP)
|
||||
require.NoError(t, err)
|
||||
assert.Truef(t, allowed, "%s should have passed.", testIP)
|
||||
}
|
||||
|
||||
for _, testIP := range test.rejectIPs {
|
||||
allowed, err := ipChecker.Contains(testIP)
|
||||
require.NoError(t, err)
|
||||
assert.Falsef(t, allowed, "%s should not have passed.", testIP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestContainsBrokenIPs(t *testing.T) {
|
||||
brokenIPs := []string{
|
||||
"foo",
|
||||
"10.0.0.350",
|
||||
"fe:::80",
|
||||
"",
|
||||
"\\&$§&/(",
|
||||
}
|
||||
|
||||
ipChecker, err := NewChecker([]string{"1.2.3.4/24"})
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, testIP := range brokenIPs {
|
||||
_, err := ipChecker.Contains(testIP)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
62
ip/strategy.go
Normal file
62
ip/strategy.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package ip
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
xForwardedFor = "X-Forwarded-For"
|
||||
)
|
||||
|
||||
// Strategy a strategy for IP selection
|
||||
type Strategy interface {
|
||||
GetIP(req *http.Request) string
|
||||
}
|
||||
|
||||
// RemoteAddrStrategy a strategy that always return the remote address
|
||||
type RemoteAddrStrategy struct{}
|
||||
|
||||
// GetIP return the selected IP
|
||||
func (s *RemoteAddrStrategy) GetIP(req *http.Request) string {
|
||||
return req.RemoteAddr
|
||||
}
|
||||
|
||||
// DepthStrategy a strategy based on the depth inside the X-Forwarded-For from right to left
|
||||
type DepthStrategy struct {
|
||||
Depth int
|
||||
}
|
||||
|
||||
// GetIP return the selected IP
|
||||
func (s *DepthStrategy) GetIP(req *http.Request) string {
|
||||
xff := req.Header.Get(xForwardedFor)
|
||||
xffs := strings.Split(xff, ",")
|
||||
|
||||
if len(xffs) < s.Depth {
|
||||
return ""
|
||||
}
|
||||
return xffs[len(xffs)-s.Depth]
|
||||
}
|
||||
|
||||
// CheckerStrategy a strategy based on an IP Checker
|
||||
// allows to check that addresses are in a trusted IPs
|
||||
type CheckerStrategy struct {
|
||||
Checker *Checker
|
||||
}
|
||||
|
||||
// GetIP return the selected IP
|
||||
func (s *CheckerStrategy) GetIP(req *http.Request) string {
|
||||
if s.Checker == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
xff := req.Header.Get(xForwardedFor)
|
||||
xffs := strings.Split(xff, ",")
|
||||
|
||||
for i := len(xffs) - 1; i >= 0; i-- {
|
||||
if contain, _ := s.Checker.Contains(xffs[i]); !contain {
|
||||
return xffs[i]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
125
ip/strategy_test.go
Normal file
125
ip/strategy_test.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
package ip
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRemoteAddrStrategy_GetIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Use RemoteAddr",
|
||||
expected: "192.0.2.1:1234",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strategy := RemoteAddrStrategy{}
|
||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||
actual := strategy.GetIP(req)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDepthStrategy_GetIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
depth int
|
||||
xForwardedFor string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Use depth",
|
||||
depth: 3,
|
||||
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "10.0.0.3",
|
||||
},
|
||||
{
|
||||
desc: "Use non existing depth in XForwardedFor",
|
||||
depth: 2,
|
||||
xForwardedFor: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
desc: "Use depth that match the first IP in XForwardedFor",
|
||||
depth: 2,
|
||||
xForwardedFor: "10.0.0.2,10.0.0.1",
|
||||
expected: "10.0.0.2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
strategy := DepthStrategy{Depth: test.depth}
|
||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||
req.Header.Set(xForwardedFor, test.xForwardedFor)
|
||||
actual := strategy.GetIP(req)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExcludedIPsStrategy_GetIP(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
excludedIPs []string
|
||||
xForwardedFor string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
desc: "Use excluded all IPs",
|
||||
excludedIPs: []string{"10.0.0.4", "10.0.0.3", "10.0.0.2", "10.0.0.1"},
|
||||
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
desc: "Use excluded IPs",
|
||||
excludedIPs: []string{"10.0.0.2", "10.0.0.1"},
|
||||
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "10.0.0.3",
|
||||
},
|
||||
{
|
||||
desc: "Use excluded IPs CIDR",
|
||||
excludedIPs: []string{"10.0.0.1/24"},
|
||||
xForwardedFor: "127.0.0.1,10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
desc: "Use excluded all IPs CIDR",
|
||||
excludedIPs: []string{"10.0.0.1/24"},
|
||||
xForwardedFor: "10.0.0.4,10.0.0.3,10.0.0.2,10.0.0.1",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
checker, err := NewChecker(test.excludedIPs)
|
||||
require.NoError(t, err)
|
||||
|
||||
strategy := CheckerStrategy{Checker: checker}
|
||||
req := httptest.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
|
||||
req.Header.Set(xForwardedFor, test.xForwardedFor)
|
||||
actual := strategy.GetIP(req)
|
||||
assert.Equal(t, test.expected, actual)
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue