Fix TCP-TLS/HTTPS routing precedence
Co-authored-by: Mathieu Lonjaret <mathieu.lonjaret@gmail.com>
This commit is contained in:
parent
ede2be1f66
commit
ac4086d0ac
4 changed files with 1031 additions and 36 deletions
|
@ -95,32 +95,21 @@ func NewMuxer() (*Muxer, error) {
|
|||
return &Muxer{parser: parser}, nil
|
||||
}
|
||||
|
||||
// Match returns the handler of the first route matching the connection metadata.
|
||||
func (m Muxer) Match(meta ConnData) tcp.Handler {
|
||||
// Match returns the handler of the first route matching the connection metadata,
|
||||
// and whether the match is exactly from the rule HostSNI(*).
|
||||
func (m Muxer) Match(meta ConnData) (tcp.Handler, bool) {
|
||||
for _, route := range m.routes {
|
||||
if route.matchers.match(meta) {
|
||||
return route.handler
|
||||
return route.handler, route.catchAll
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// AddRoute adds a new route, associated to the given handler, at the given
|
||||
// priority, to the muxer.
|
||||
func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error {
|
||||
// Special case for when the catchAll fallback is present.
|
||||
// When no user-defined priority is found, the lowest computable priority minus one is used,
|
||||
// in order to make the fallback the last to be evaluated.
|
||||
if priority == 0 && rule == "HostSNI(`*`)" {
|
||||
priority = -1
|
||||
}
|
||||
|
||||
// Default value, which means the user has not set it, so we'll compute it.
|
||||
if priority == 0 {
|
||||
priority = len(rule)
|
||||
}
|
||||
|
||||
parse, err := m.parser.Parse(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error while parsing rule %s: %w", rule, err)
|
||||
|
@ -131,16 +120,36 @@ func (m *Muxer) AddRoute(rule string, priority int, handler tcp.Handler) error {
|
|||
return fmt.Errorf("error while parsing rule %s", rule)
|
||||
}
|
||||
|
||||
ruleTree := buildTree()
|
||||
|
||||
var matchers matchersTree
|
||||
err = addRule(&matchers, buildTree())
|
||||
err = addRule(&matchers, ruleTree)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var catchAll bool
|
||||
if ruleTree.RuleLeft == nil && ruleTree.RuleRight == nil && len(ruleTree.Value) == 1 {
|
||||
catchAll = ruleTree.Value[0] == "*" && strings.EqualFold(ruleTree.Matcher, "HostSNI")
|
||||
}
|
||||
|
||||
// Special case for when the catchAll fallback is present.
|
||||
// When no user-defined priority is found, the lowest computable priority minus one is used,
|
||||
// in order to make the fallback the last to be evaluated.
|
||||
if priority == 0 && catchAll {
|
||||
priority = -1
|
||||
}
|
||||
|
||||
// Default value, which means the user has not set it, so we'll compute it.
|
||||
if priority == 0 {
|
||||
priority = len(rule)
|
||||
}
|
||||
|
||||
newRoute := &route{
|
||||
handler: handler,
|
||||
priority: priority,
|
||||
matchers: matchers,
|
||||
catchAll: catchAll,
|
||||
priority: priority,
|
||||
}
|
||||
m.routes = append(m.routes, newRoute)
|
||||
|
||||
|
@ -207,9 +216,10 @@ type route struct {
|
|||
matchers matchersTree
|
||||
// handler responsible for handling the route.
|
||||
handler tcp.Handler
|
||||
|
||||
// Used to disambiguate between two (or more) rules that would both match for a
|
||||
// given request.
|
||||
// catchAll indicates whether the route rule has exactly the catchAll value (HostSNI(`*`)).
|
||||
catchAll bool
|
||||
// priority is used to disambiguate between two (or more) rules that would
|
||||
// all match for a given request.
|
||||
// Computed from the matching rule length, if not user-set.
|
||||
priority int
|
||||
}
|
||||
|
|
|
@ -474,7 +474,7 @@ func Test_addTCPRoute(t *testing.T) {
|
|||
connData, err := NewConnData(test.serverName, conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
matchingHandler := router.Match(connData)
|
||||
matchingHandler, _ := router.Match(connData)
|
||||
if test.matchErr {
|
||||
require.Nil(t, matchingHandler)
|
||||
return
|
||||
|
@ -568,6 +568,54 @@ func TestParseHostSNI(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func Test_HostSNICatchAll(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
rule string
|
||||
isCatchAll bool
|
||||
}{
|
||||
{
|
||||
desc: "HostSNI(`foobar`) is not catchAll",
|
||||
rule: "HostSNI(`foobar`)",
|
||||
},
|
||||
{
|
||||
desc: "HostSNI(`*`) is catchAll",
|
||||
rule: "HostSNI(`*`)",
|
||||
isCatchAll: true,
|
||||
},
|
||||
{
|
||||
desc: "HOSTSNI(`*`) is catchAll",
|
||||
rule: "HOSTSNI(`*`)",
|
||||
isCatchAll: true,
|
||||
},
|
||||
{
|
||||
desc: `HostSNI("*") is catchAll`,
|
||||
rule: `HostSNI("*")`,
|
||||
isCatchAll: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range testCases {
|
||||
test := test
|
||||
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
muxer, err := NewMuxer()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = muxer.AddRoute(test.rule, 0, tcp.HandlerFunc(func(conn tcp.WriteCloser) {}))
|
||||
require.NoError(t, err)
|
||||
|
||||
handler, catchAll := muxer.Match(ConnData{
|
||||
serverName: "foobar",
|
||||
})
|
||||
require.NotNil(t, handler)
|
||||
assert.Equal(t, test.isCatchAll, catchAll)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HostSNI(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
|
@ -934,7 +982,7 @@ func Test_Priority(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
handler := muxer.Match(ConnData{
|
||||
handler, _ := muxer.Match(ConnData{
|
||||
serverName: test.serverName,
|
||||
remoteIP: test.remoteIP,
|
||||
})
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue