1
0
Fork 0

Add forwardAuth.addAuthCookiesToResponse

This commit is contained in:
Thomas Gunsch 2024-01-15 16:14:05 +01:00 committed by GitHub
parent 980dac4572
commit 81ce45271d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 155 additions and 19 deletions

View file

@ -223,6 +223,8 @@ type ForwardAuth struct {
// AuthRequestHeaders defines the list of the headers to copy from the request to the authentication server.
// If not set or empty then all request headers are passed.
AuthRequestHeaders []string `json:"authRequestHeaders,omitempty" toml:"authRequestHeaders,omitempty" yaml:"authRequestHeaders,omitempty" export:"true"`
// AddAuthCookiesToResponse defines the list of cookies to copy from the authentication server response to the response.
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse,omitempty" toml:"addAuthCookiesToResponse,omitempty" yaml:"addAuthCookiesToResponse,omitempty" export:"true"`
}
// +k8s:deepcopy-gen=true

View file

@ -324,6 +324,11 @@ func (in *ForwardAuth) DeepCopyInto(out *ForwardAuth) {
*out = make([]string, len(*in))
copy(*out, *in)
}
if in.AddAuthCookiesToResponse != nil {
in, out := &in.AddAuthCookiesToResponse, &out.AddAuthCookiesToResponse
*out = make([]string, len(*in))
copy(*out, *in)
}
return
}

View file

@ -48,19 +48,26 @@ type forwardAuth struct {
client http.Client
trustForwardHeader bool
authRequestHeaders []string
addAuthCookiesToResponse map[string]struct{}
}
// NewForward creates a forward auth middleware.
func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, typeNameForward).Debug().Msg("Creating middleware")
addAuthCookiesToResponse := make(map[string]struct{})
for _, cookieName := range config.AddAuthCookiesToResponse {
addAuthCookiesToResponse[cookieName] = struct{}{}
}
fa := &forwardAuth{
address: config.Address,
authResponseHeaders: config.AuthResponseHeaders,
next: next,
name: name,
trustForwardHeader: config.TrustForwardHeader,
authRequestHeaders: config.AuthRequestHeaders,
address: config.Address,
authResponseHeaders: config.AuthResponseHeaders,
next: next,
name: name,
trustForwardHeader: config.TrustForwardHeader,
authRequestHeaders: config.AuthRequestHeaders,
addAuthCookiesToResponse: addAuthCookiesToResponse,
}
// Ensure our request client does not follow redirects
@ -211,7 +218,35 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
tracing.LogResponseCode(forwardSpan, forwardResponse.StatusCode, trace.SpanKindClient)
req.RequestURI = req.URL.RequestURI()
fa.next.ServeHTTP(rw, req)
authCookies := forwardResponse.Cookies()
if len(authCookies) == 0 {
fa.next.ServeHTTP(rw, req)
return
}
fa.next.ServeHTTP(middlewares.NewResponseModifier(rw, req, fa.buildModifier(authCookies)), req)
}
func (fa *forwardAuth) buildModifier(authCookies []*http.Cookie) func(res *http.Response) error {
return func(res *http.Response) error {
cookies := res.Cookies()
res.Header.Del("Set-Cookie")
for _, cookie := range cookies {
if _, found := fa.addAuthCookiesToResponse[cookie.Name]; !found {
res.Header.Add("Set-Cookie", cookie.String())
}
}
for _, cookie := range authCookies {
if _, found := fa.addAuthCookiesToResponse[cookie.Name]; found {
res.Header.Add("Set-Cookie", cookie.String())
}
}
return nil
}
}
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool, allowedHeaders []string) {

View file

@ -66,6 +66,8 @@ func TestForwardAuthSuccess(t *testing.T) {
w.Header().Add("X-Auth-Group", "group1")
w.Header().Add("X-Auth-Group", "group2")
w.Header().Add("Foo-Bar", "auth-value")
w.Header().Add("Set-Cookie", "authCookie=Auth")
w.Header().Add("Set-Cookie", "authCookieNotAdded=Auth")
fmt.Fprintln(w, "Success")
}))
t.Cleanup(server.Close)
@ -76,6 +78,9 @@ func TestForwardAuthSuccess(t *testing.T) {
assert.Equal(t, []string{"group1", "group2"}, r.Header["X-Auth-Group"])
assert.Equal(t, "auth-value", r.Header.Get("Foo-Bar"))
assert.Empty(t, r.Header.Get("Foo-Baz"))
w.Header().Add("Set-Cookie", "authCookie=Backend")
w.Header().Add("Set-Cookie", "backendCookie=Backend")
w.Header().Add("Other-Header", "BackendHeaderValue")
fmt.Fprintln(w, "traefik")
})
@ -83,6 +88,7 @@ func TestForwardAuthSuccess(t *testing.T) {
Address: server.URL,
AuthResponseHeaders: []string{"X-Auth-User", "X-Auth-Group"},
AuthResponseHeadersRegex: "^Foo-",
AddAuthCookiesToResponse: []string{"authCookie"},
}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
@ -97,6 +103,8 @@ func TestForwardAuthSuccess(t *testing.T) {
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, []string{"backendCookie=Backend", "authCookie=Auth"}, res.Header["Set-Cookie"])
assert.Equal(t, []string{"BackendHeaderValue"}, res.Header["Other-Header"])
body, err := io.ReadAll(res.Body)
require.NoError(t, err)

View file

@ -8,6 +8,7 @@ import (
"strings"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/vulcand/oxy/v2/forward"
)
@ -58,7 +59,7 @@ func (s *Header) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// If there is a next, call it.
if s.next != nil {
s.next.ServeHTTP(newResponseModifier(rw, req, s.PostRequestModifyResponseHeaders), req)
s.next.ServeHTTP(middlewares.NewResponseModifier(rw, req, s.PostRequestModifyResponseHeaders), req)
}
}

View file

@ -4,6 +4,7 @@ import (
"net/http"
"github.com/traefik/traefik/v3/pkg/config/dynamic"
"github.com/traefik/traefik/v3/pkg/middlewares"
"github.com/unrolled/secure"
)
@ -45,6 +46,6 @@ func newSecure(next http.Handler, cfg dynamic.Headers, contextKey string) *secur
func (s secureHeader) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
s.secure.HandlerFuncWithNextForRequestOnly(rw, req, func(writer http.ResponseWriter, request *http.Request) {
s.next.ServeHTTP(newResponseModifier(writer, request, s.secure.ModifyResponseHeaders), request)
s.next.ServeHTTP(middlewares.NewResponseModifier(writer, request, s.secure.ModifyResponseHeaders), request)
})
}

View file

@ -1,4 +1,4 @@
package headers
package middlewares
import (
"bufio"
@ -9,7 +9,8 @@ import (
"github.com/rs/zerolog/log"
)
type responseModifier struct {
// ResponseModifier is a ResponseWriter to modify the response headers before sending them.
type ResponseModifier struct {
req *http.Request
rw http.ResponseWriter
@ -21,9 +22,10 @@ type responseModifier struct {
modifierErr error // returned by modifier call
}
// modifier can be nil.
func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) http.ResponseWriter {
return &responseModifier{
// NewResponseModifier returns a new ResponseModifier instance.
// The given modifier can be nil.
func NewResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*http.Response) error) http.ResponseWriter {
return &ResponseModifier{
req: r,
rw: w,
modifier: modifier,
@ -33,7 +35,7 @@ func newResponseModifier(w http.ResponseWriter, r *http.Request, modifier func(*
// WriteHeader is, in the specific case of 1xx status codes, a direct call to the wrapped ResponseWriter, without marking headers as sent,
// allowing so further calls.
func (r *responseModifier) WriteHeader(code int) {
func (r *ResponseModifier) WriteHeader(code int) {
if r.headersSent {
return
}
@ -73,11 +75,11 @@ func (r *responseModifier) WriteHeader(code int) {
r.rw.WriteHeader(code)
}
func (r *responseModifier) Header() http.Header {
func (r *ResponseModifier) Header() http.Header {
return r.rw.Header()
}
func (r *responseModifier) Write(b []byte) (int, error) {
func (r *ResponseModifier) Write(b []byte) (int, error) {
r.WriteHeader(r.code)
if r.modifierErr != nil {
return 0, r.modifierErr
@ -87,7 +89,7 @@ func (r *responseModifier) Write(b []byte) (int, error) {
}
// Hijack hijacks the connection.
func (r *responseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
func (r *ResponseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if h, ok := r.rw.(http.Hijacker); ok {
return h.Hijack()
}
@ -96,7 +98,7 @@ func (r *responseModifier) Hijack() (net.Conn, *bufio.ReadWriter, error) {
}
// Flush sends any buffered data to the client.
func (r *responseModifier) Flush() {
func (r *ResponseModifier) Flush() {
if flusher, ok := r.rw.(http.Flusher); ok {
flusher.Flush()
}

View file

@ -728,6 +728,7 @@ func createForwardAuthMiddleware(k8sClient Client, namespace string, auth *traef
AuthResponseHeaders: auth.AuthResponseHeaders,
AuthResponseHeadersRegex: auth.AuthResponseHeadersRegex,
AuthRequestHeaders: auth.AuthRequestHeaders,
AddAuthCookiesToResponse: auth.AddAuthCookiesToResponse,
}
if auth.TLS == nil {

View file

@ -157,6 +157,8 @@ type ForwardAuth struct {
AuthRequestHeaders []string `json:"authRequestHeaders,omitempty"`
// TLS defines the configuration used to secure the connection to the authentication server.
TLS *ClientTLS `json:"tls,omitempty"`
// AddAuthCookiesToResponse defines the list of cookies to copy from the authentication server response to the response.
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse,omitempty"`
}
// ClientTLS holds the client TLS configuration.

View file

@ -215,6 +215,11 @@ func (in *ForwardAuth) DeepCopyInto(out *ForwardAuth) {
*out = new(ClientTLS)
**out = **in
}
if in.AddAuthCookiesToResponse != nil {
in, out := &in.AddAuthCookiesToResponse, &out.AddAuthCookiesToResponse
*out = make([]string, len(*in))
copy(*out, *in)
}
return
}