Middlewares: add forwardAuth.authResponseHeadersRegex
This commit is contained in:
parent
b5198e63c4
commit
49cdb67ddc
12 changed files with 116 additions and 24 deletions
|
@ -6,6 +6,7 @@ import (
|
|||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
@ -25,13 +26,14 @@ const (
|
|||
)
|
||||
|
||||
type forwardAuth struct {
|
||||
address string
|
||||
authResponseHeaders []string
|
||||
next http.Handler
|
||||
name string
|
||||
client http.Client
|
||||
trustForwardHeader bool
|
||||
authRequestHeaders []string
|
||||
address string
|
||||
authResponseHeaders []string
|
||||
authResponseHeadersRegex *regexp.Regexp
|
||||
next http.Handler
|
||||
name string
|
||||
client http.Client
|
||||
trustForwardHeader bool
|
||||
authRequestHeaders []string
|
||||
}
|
||||
|
||||
// NewForward creates a forward auth middleware.
|
||||
|
@ -66,6 +68,14 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
|
|||
fa.client.Transport = tr
|
||||
}
|
||||
|
||||
if config.AuthResponseHeadersRegex != "" {
|
||||
re, err := regexp.Compile(config.AuthResponseHeadersRegex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error compiling regular expression %s: %w", config.AuthResponseHeadersRegex, err)
|
||||
}
|
||||
fa.authResponseHeadersRegex = re
|
||||
}
|
||||
|
||||
return fa, nil
|
||||
}
|
||||
|
||||
|
@ -156,6 +166,20 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
if fa.authResponseHeadersRegex != nil {
|
||||
for headerKey := range req.Header {
|
||||
if fa.authResponseHeadersRegex.MatchString(headerKey) {
|
||||
req.Header.Del(headerKey)
|
||||
}
|
||||
}
|
||||
|
||||
for headerKey, headerValues := range forwardResponse.Header {
|
||||
if fa.authResponseHeadersRegex.MatchString(headerKey) {
|
||||
req.Header[headerKey] = append([]string(nil), headerValues...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req.RequestURI = req.URL.RequestURI()
|
||||
fa.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
|
|
@ -57,6 +57,7 @@ func TestForwardAuthSuccess(t *testing.T) {
|
|||
w.Header().Set("X-Auth-Secret", "secret")
|
||||
w.Header().Add("X-Auth-Group", "group1")
|
||||
w.Header().Add("X-Auth-Group", "group2")
|
||||
w.Header().Add("Foo-Bar", "auth-value")
|
||||
fmt.Fprintln(w, "Success")
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
@ -65,12 +66,15 @@ func TestForwardAuthSuccess(t *testing.T) {
|
|||
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
|
||||
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
|
||||
assert.Equal(t, []string{"group1", "group2"}, r.Header["X-Auth-Group"])
|
||||
assert.Equal(t, "auth-value", r.Header.Get("Foo-Bar"))
|
||||
assert.Empty(t, r.Header.Get("Foo-Baz"))
|
||||
fmt.Fprintln(w, "traefik")
|
||||
})
|
||||
|
||||
auth := dynamic.ForwardAuth{
|
||||
Address: server.URL,
|
||||
AuthResponseHeaders: []string{"X-Auth-User", "X-Auth-Group"},
|
||||
Address: server.URL,
|
||||
AuthResponseHeaders: []string{"X-Auth-User", "X-Auth-Group"},
|
||||
AuthResponseHeadersRegex: "^Foo-",
|
||||
}
|
||||
middleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
@ -80,6 +84,8 @@ func TestForwardAuthSuccess(t *testing.T) {
|
|||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
|
||||
req.Header.Set("X-Auth-Group", "admin_group")
|
||||
req.Header.Set("Foo-Bar", "client-value")
|
||||
req.Header.Set("Foo-Baz", "client-value")
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue