Filter ForwardAuth request headers

This commit is contained in:
Nikita Konev 2020-10-07 17:36:04 +03:00 committed by GitHub
parent f2e53a3569
commit 326be29568
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 171 additions and 3 deletions

View file

@ -31,6 +31,7 @@ type forwardAuth struct {
name string
client http.Client
trustForwardHeader bool
authRequestHeaders []string
}
// NewForward creates a forward auth middleware.
@ -43,6 +44,7 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
next: next,
name: name,
trustForwardHeader: config.TrustForwardHeader,
authRequestHeaders: config.AuthRequestHeaders,
}
// Ensure our request client does not follow redirects
@ -89,7 +91,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// forwardReq.
tracing.InjectRequestHeaders(req)
writeHeader(req, forwardReq, fa.trustForwardHeader)
writeHeader(req, forwardReq, fa.trustForwardHeader, fa.authRequestHeaders)
forwardResponse, forwardErr := fa.client.Do(forwardReq)
if forwardErr != nil {
@ -158,10 +160,12 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
fa.next.ServeHTTP(rw, req)
}
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool) {
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool, allowedHeaders []string) {
utils.CopyHeaders(forwardReq.Header, req.Header)
utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)
forwardReq.Header = filterForwardRequestHeaders(forwardReq.Header, allowedHeaders)
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if trustForwardHeader {
if prior, ok := req.Header[forward.XForwardedFor]; ok {
@ -215,3 +219,19 @@ func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool) {
forwardReq.Header.Del(xForwardedURI)
}
}
func filterForwardRequestHeaders(forwardRequestHeaders http.Header, allowedHeaders []string) http.Header {
if len(allowedHeaders) == 0 {
return forwardRequestHeaders
}
filteredHeaders := http.Header{}
for _, headerName := range allowedHeaders {
values := forwardRequestHeaders.Values(headerName)
if len(values) > 0 {
filteredHeaders[http.CanonicalHeaderKey(headerName)] = append([]string(nil), values...)
}
}
return filteredHeaders
}