Filter ForwardAuth request headers
This commit is contained in:
parent
f2e53a3569
commit
326be29568
16 changed files with 171 additions and 3 deletions
|
@ -31,6 +31,7 @@ type forwardAuth struct {
|
|||
name string
|
||||
client http.Client
|
||||
trustForwardHeader bool
|
||||
authRequestHeaders []string
|
||||
}
|
||||
|
||||
// NewForward creates a forward auth middleware.
|
||||
|
@ -43,6 +44,7 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
|
|||
next: next,
|
||||
name: name,
|
||||
trustForwardHeader: config.TrustForwardHeader,
|
||||
authRequestHeaders: config.AuthRequestHeaders,
|
||||
}
|
||||
|
||||
// Ensure our request client does not follow redirects
|
||||
|
@ -89,7 +91,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
// forwardReq.
|
||||
tracing.InjectRequestHeaders(req)
|
||||
|
||||
writeHeader(req, forwardReq, fa.trustForwardHeader)
|
||||
writeHeader(req, forwardReq, fa.trustForwardHeader, fa.authRequestHeaders)
|
||||
|
||||
forwardResponse, forwardErr := fa.client.Do(forwardReq)
|
||||
if forwardErr != nil {
|
||||
|
@ -158,10 +160,12 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
fa.next.ServeHTTP(rw, req)
|
||||
}
|
||||
|
||||
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool) {
|
||||
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool, allowedHeaders []string) {
|
||||
utils.CopyHeaders(forwardReq.Header, req.Header)
|
||||
utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)
|
||||
|
||||
forwardReq.Header = filterForwardRequestHeaders(forwardReq.Header, allowedHeaders)
|
||||
|
||||
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
|
||||
if trustForwardHeader {
|
||||
if prior, ok := req.Header[forward.XForwardedFor]; ok {
|
||||
|
@ -215,3 +219,19 @@ func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool) {
|
|||
forwardReq.Header.Del(xForwardedURI)
|
||||
}
|
||||
}
|
||||
|
||||
func filterForwardRequestHeaders(forwardRequestHeaders http.Header, allowedHeaders []string) http.Header {
|
||||
if len(allowedHeaders) == 0 {
|
||||
return forwardRequestHeaders
|
||||
}
|
||||
|
||||
filteredHeaders := http.Header{}
|
||||
for _, headerName := range allowedHeaders {
|
||||
values := forwardRequestHeaders.Values(headerName)
|
||||
if len(values) > 0 {
|
||||
filteredHeaders[http.CanonicalHeaderKey(headerName)] = append([]string(nil), values...)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredHeaders
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue