1
0
Fork 0

Middlewares: add forwardAuth.authResponseHeadersRegex

This commit is contained in:
iamolegga 2020-10-29 17:10:04 +03:00 committed by GitHub
parent b5198e63c4
commit 49cdb67ddc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 116 additions and 24 deletions

View file

@ -6,6 +6,7 @@ import (
"io/ioutil"
"net"
"net/http"
"regexp"
"strings"
"time"
@ -25,13 +26,14 @@ const (
)
type forwardAuth struct {
address string
authResponseHeaders []string
next http.Handler
name string
client http.Client
trustForwardHeader bool
authRequestHeaders []string
address string
authResponseHeaders []string
authResponseHeadersRegex *regexp.Regexp
next http.Handler
name string
client http.Client
trustForwardHeader bool
authRequestHeaders []string
}
// NewForward creates a forward auth middleware.
@ -66,6 +68,14 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
fa.client.Transport = tr
}
if config.AuthResponseHeadersRegex != "" {
re, err := regexp.Compile(config.AuthResponseHeadersRegex)
if err != nil {
return nil, fmt.Errorf("error compiling regular expression %s: %w", config.AuthResponseHeadersRegex, err)
}
fa.authResponseHeadersRegex = re
}
return fa, nil
}
@ -156,6 +166,20 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
if fa.authResponseHeadersRegex != nil {
for headerKey := range req.Header {
if fa.authResponseHeadersRegex.MatchString(headerKey) {
req.Header.Del(headerKey)
}
}
for headerKey, headerValues := range forwardResponse.Header {
if fa.authResponseHeadersRegex.MatchString(headerKey) {
req.Header[headerKey] = append([]string(nil), headerValues...)
}
}
}
req.RequestURI = req.URL.RequestURI()
fa.next.ServeHTTP(rw, req)
}

View file

@ -57,6 +57,7 @@ func TestForwardAuthSuccess(t *testing.T) {
w.Header().Set("X-Auth-Secret", "secret")
w.Header().Add("X-Auth-Group", "group1")
w.Header().Add("X-Auth-Group", "group2")
w.Header().Add("Foo-Bar", "auth-value")
fmt.Fprintln(w, "Success")
}))
t.Cleanup(server.Close)
@ -65,12 +66,15 @@ func TestForwardAuthSuccess(t *testing.T) {
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
assert.Equal(t, []string{"group1", "group2"}, r.Header["X-Auth-Group"])
assert.Equal(t, "auth-value", r.Header.Get("Foo-Bar"))
assert.Empty(t, r.Header.Get("Foo-Baz"))
fmt.Fprintln(w, "traefik")
})
auth := dynamic.ForwardAuth{
Address: server.URL,
AuthResponseHeaders: []string{"X-Auth-User", "X-Auth-Group"},
Address: server.URL,
AuthResponseHeaders: []string{"X-Auth-User", "X-Auth-Group"},
AuthResponseHeadersRegex: "^Foo-",
}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
@ -80,6 +84,8 @@ func TestForwardAuthSuccess(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.Header.Set("X-Auth-Group", "admin_group")
req.Header.Set("Foo-Bar", "client-value")
req.Header.Set("Foo-Baz", "client-value")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)