1
0
Fork 0

Send request body to authorization server for forward auth

This commit is contained in:
kyosuke 2024-12-12 18:18:05 +09:00 committed by GitHub
parent b1934231ca
commit 26738cbf93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 411 additions and 38 deletions

View file

@ -21,6 +21,11 @@ const (
// DefaultFlushInterval is the default value for the ResponseForwarding flush interval.
DefaultFlushInterval = ptypes.Duration(100 * time.Millisecond)
// MirroringDefaultMirrorBody is the Mirroring.MirrorBody option default value.
MirroringDefaultMirrorBody = true
// MirroringDefaultMaxBodySize is the Mirroring.MaxBodySize option default value.
MirroringDefaultMaxBodySize int64 = -1
)
// +k8s:deepcopy-gen=true
@ -100,9 +105,9 @@ type Mirroring struct {
// SetDefaults Default values for a WRRService.
func (m *Mirroring) SetDefaults() {
defaultMirrorBody := true
defaultMirrorBody := MirroringDefaultMirrorBody
m.MirrorBody = &defaultMirrorBody
var defaultMaxBodySize int64 = -1
defaultMaxBodySize := MirroringDefaultMaxBodySize
m.MaxBodySize = &defaultMaxBodySize
}

View file

@ -9,6 +9,9 @@ import (
"github.com/traefik/traefik/v3/pkg/ip"
)
// ForwardAuthDefaultMaxBodySize is the ForwardAuth.MaxBodySize option default value.
const ForwardAuthDefaultMaxBodySize int64 = -1
// +k8s:deepcopy-gen=true
// Middleware holds the Middleware configuration.
@ -251,6 +254,15 @@ type ForwardAuth struct {
// HeaderField defines a header field to store the authenticated user.
// More info: https://doc.traefik.io/traefik/v3.0/middlewares/http/forwardauth/#headerfield
HeaderField string `json:"headerField,omitempty" toml:"headerField,omitempty" yaml:"headerField,omitempty" export:"true"`
// ForwardBody defines whether to send the request body to the authentication server.
ForwardBody bool `json:"forwardBody,omitempty" toml:"forwardBody,omitempty" yaml:"forwardBody,omitempty" export:"true"`
// MaxBodySize defines the maximum body size in bytes allowed to be forwarded to the authentication server.
MaxBodySize *int64 `json:"maxBodySize,omitempty" toml:"maxBodySize,omitempty" yaml:"maxBodySize,omitempty" export:"true"`
}
func (f *ForwardAuth) SetDefaults() {
defaultMaxBodySize := ForwardAuthDefaultMaxBodySize
f.MaxBodySize = &defaultMaxBodySize
}
// +k8s:deepcopy-gen=true

View file

@ -370,6 +370,11 @@ func (in *ForwardAuth) DeepCopyInto(out *ForwardAuth) {
*out = make([]string, len(*in))
copy(*out, *in)
}
if in.MaxBodySize != nil {
in, out := &in.MaxBodySize, &out.MaxBodySize
*out = new(int64)
**out = **in
}
return
}

View file

@ -51,6 +51,8 @@ func TestDecodeConfiguration(t *testing.T) {
"traefik.http.middlewares.Middleware7.forwardauth.tls.insecureskipverify": "true",
"traefik.http.middlewares.Middleware7.forwardauth.tls.key": "foobar",
"traefik.http.middlewares.Middleware7.forwardauth.trustforwardheader": "true",
"traefik.http.middlewares.Middleware7.forwardauth.forwardbody": "true",
"traefik.http.middlewares.Middleware7.forwardauth.maxbodysize": "42",
"traefik.http.middlewares.Middleware8.headers.accesscontrolallowcredentials": "true",
"traefik.http.middlewares.Middleware8.headers.allowedhosts": "foobar, fiibar",
"traefik.http.middlewares.Middleware8.headers.accesscontrolallowheaders": "X-foobar, X-fiibar",
@ -572,6 +574,8 @@ func TestDecodeConfiguration(t *testing.T) {
"foobar",
"fiibar",
},
ForwardBody: true,
MaxBodySize: pointer(int64(42)),
},
},
"Middleware8": {
@ -1114,6 +1118,8 @@ func TestEncodeConfiguration(t *testing.T) {
"foobar",
"fiibar",
},
ForwardBody: true,
MaxBodySize: pointer(int64(42)),
},
},
"Middleware8": {
@ -1315,6 +1321,8 @@ func TestEncodeConfiguration(t *testing.T) {
"traefik.HTTP.Middlewares.Middleware7.ForwardAuth.Address": "foobar",
"traefik.HTTP.Middlewares.Middleware7.ForwardAuth.AuthResponseHeaders": "foobar, fiibar",
"traefik.HTTP.Middlewares.Middleware7.ForwardAuth.AuthRequestHeaders": "foobar, fiibar",
"traefik.HTTP.Middlewares.Middleware7.ForwardAuth.ForwardBody": "true",
"traefik.HTTP.Middlewares.Middleware7.ForwardAuth.MaxBodySize": "42",
"traefik.HTTP.Middlewares.Middleware7.ForwardAuth.TLS.CA": "foobar",
"traefik.HTTP.Middlewares.Middleware7.ForwardAuth.TLS.CAOptional": "true",
"traefik.HTTP.Middlewares.Middleware7.ForwardAuth.TLS.Cert": "foobar",

View file

@ -1,6 +1,7 @@
package auth
import (
"bytes"
"context"
"errors"
"fmt"
@ -22,13 +23,13 @@ import (
"go.opentelemetry.io/otel/trace"
)
const typeNameForward = "ForwardAuth"
const (
xForwardedURI = "X-Forwarded-Uri"
xForwardedMethod = "X-Forwarded-Method"
)
const typeNameForward = "ForwardAuth"
// hopHeaders Hop-by-hop headers to be removed in the authentication request.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
// Proxy-Authorization header is forwarded to the authentication server (see https://tools.ietf.org/html/rfc7235#section-4.4).
@ -52,6 +53,8 @@ type forwardAuth struct {
authRequestHeaders []string
addAuthCookiesToResponse map[string]struct{}
headerField string
forwardBody bool
maxBodySize int64
}
// NewForward creates a forward auth middleware.
@ -73,6 +76,12 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
authRequestHeaders: config.AuthRequestHeaders,
addAuthCookiesToResponse: addAuthCookiesToResponse,
headerField: config.HeaderField,
forwardBody: config.ForwardBody,
maxBodySize: dynamic.ForwardAuthDefaultMaxBodySize,
}
if config.MaxBodySize != nil {
fa.maxBodySize = *config.MaxBodySize
}
// Ensure our request client does not follow redirects
@ -125,13 +134,37 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
forwardReq, err := http.NewRequestWithContext(req.Context(), http.MethodGet, fa.address, nil)
if err != nil {
logger.Debug().Msgf("Error calling %s. Cause %s", fa.address, err)
logger.Debug().Err(err).Msgf("Error calling %s", fa.address)
observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause %s", fa.address, err)
rw.WriteHeader(http.StatusInternalServerError)
return
}
if fa.forwardBody {
bodyBytes, err := fa.readBodyBytes(req)
if errors.Is(err, errBodyTooLarge) {
logger.Debug().Msgf("Request body is too large, maxBodySize: %d", fa.maxBodySize)
observability.SetStatusErrorf(req.Context(), "Request body is too large, maxBodySize: %d", fa.maxBodySize)
rw.WriteHeader(http.StatusUnauthorized)
return
}
if err != nil {
logger.Debug().Err(err).Msg("Error while reading body")
observability.SetStatusErrorf(req.Context(), "Error while reading Body: %s", err)
rw.WriteHeader(http.StatusInternalServerError)
return
}
// bodyBytes is nil when the request has no body.
if bodyBytes != nil {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
forwardReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
}
writeHeader(req, forwardReq, fa.trustForwardHeader, fa.authRequestHeaders)
var forwardSpan trace.Span
@ -149,7 +182,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
forwardResponse, forwardErr := fa.client.Do(forwardReq)
if forwardErr != nil {
logger.Debug().Msgf("Error calling %s. Cause: %s", fa.address, forwardErr)
logger.Debug().Err(forwardErr).Msgf("Error calling %s", fa.address)
observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause: %s", fa.address, forwardErr)
rw.WriteHeader(http.StatusInternalServerError)
@ -159,7 +192,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
body, readError := io.ReadAll(forwardResponse.Body)
if readError != nil {
logger.Debug().Msgf("Error reading body %s. Cause: %s", fa.address, readError)
logger.Debug().Err(readError).Msgf("Error reading body %s", fa.address)
observability.SetStatusErrorf(req.Context(), "Error reading body %s. Cause: %s", fa.address, readError)
rw.WriteHeader(http.StatusInternalServerError)
@ -194,7 +227,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if err != nil {
if !errors.Is(err, http.ErrNoLocation) {
logger.Debug().Msgf("Error reading response location header %s. Cause: %s", fa.address, err)
logger.Debug().Err(err).Msgf("Error reading response location header %s", fa.address)
observability.SetStatusErrorf(req.Context(), "Error reading response location header %s. Cause: %s", fa.address, err)
rw.WriteHeader(http.StatusInternalServerError)
@ -270,6 +303,27 @@ func (fa *forwardAuth) buildModifier(authCookies []*http.Cookie) func(res *http.
}
}
var errBodyTooLarge = errors.New("request body too large")
func (fa *forwardAuth) readBodyBytes(req *http.Request) ([]byte, error) {
if fa.maxBodySize < 0 {
return io.ReadAll(req.Body)
}
body := make([]byte, fa.maxBodySize+1)
n, err := io.ReadFull(req.Body, body)
if errors.Is(err, io.EOF) {
return nil, nil
}
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("reading body bytes: %w", err)
}
if errors.Is(err, io.ErrUnexpectedEOF) {
return body[:n], nil
}
return nil, errBodyTooLarge
}
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool, allowedHeaders []string) {
utils.CopyHeaders(forwardReq.Header, req.Header)

View file

@ -1,6 +1,7 @@
package auth
import (
"bytes"
"context"
"fmt"
"io"
@ -112,6 +113,154 @@ func TestForwardAuthSuccess(t *testing.T) {
assert.Equal(t, "traefik\n", string(body))
}
func TestForwardAuthForwardBody(t *testing.T) {
data := []byte("forwardBodyTest")
var serverCallCount int
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
serverCallCount++
forwardedData, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Equal(t, data, forwardedData)
}))
t.Cleanup(server.Close)
var nextCallCount int
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
nextCallCount++
})
maxBodySize := int64(len(data))
auth := dynamic.ForwardAuth{Address: server.URL, ForwardBody: true, MaxBodySize: &maxBodySize}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
t.Cleanup(ts.Close)
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, bytes.NewReader(data))
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, 1, serverCallCount)
assert.Equal(t, 1, nextCallCount)
}
func TestForwardAuthForwardBodyEmptyBody(t *testing.T) {
var serverCallCount int
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
serverCallCount++
forwardedData, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Empty(t, forwardedData)
}))
t.Cleanup(server.Close)
var nextCallCount int
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
nextCallCount++
})
auth := dynamic.ForwardAuth{Address: server.URL, ForwardBody: true}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
t.Cleanup(ts.Close)
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, http.NoBody)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, 1, serverCallCount)
assert.Equal(t, 1, nextCallCount)
}
func TestForwardAuthForwardBodySizeLimit(t *testing.T) {
data := []byte("forwardBodyTest")
var serverCallCount int
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
serverCallCount++
forwardedData, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Equal(t, data, forwardedData)
}))
t.Cleanup(server.Close)
var nextCallCount int
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
nextCallCount++
})
maxBodySize := int64(len(data)) - 1
auth := dynamic.ForwardAuth{Address: server.URL, ForwardBody: true, MaxBodySize: &maxBodySize}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
t.Cleanup(ts.Close)
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, bytes.NewReader(data))
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
assert.Equal(t, 0, serverCallCount)
assert.Equal(t, 0, nextCallCount)
}
func TestForwardAuthNotForwardBody(t *testing.T) {
data := []byte("forwardBodyTest")
var serverCallCount int
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
serverCallCount++
forwardedData, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Empty(t, forwardedData)
}))
t.Cleanup(server.Close)
var nextCallCount int
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
nextCallCount++
})
auth := dynamic.ForwardAuth{Address: server.URL}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
t.Cleanup(ts.Close)
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, bytes.NewReader(data))
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, 1, serverCallCount)
assert.Equal(t, 1, nextCallCount)
}
func TestForwardAuthRedirect(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)

View file

@ -789,34 +789,38 @@ func createForwardAuthMiddleware(k8sClient Client, namespace string, auth *traef
AuthResponseHeadersRegex: auth.AuthResponseHeadersRegex,
AuthRequestHeaders: auth.AuthRequestHeaders,
AddAuthCookiesToResponse: auth.AddAuthCookiesToResponse,
ForwardBody: auth.ForwardBody,
}
forwardAuth.SetDefaults()
if auth.MaxBodySize != nil {
forwardAuth.MaxBodySize = auth.MaxBodySize
}
if auth.TLS == nil {
return forwardAuth, nil
}
forwardAuth.TLS = &dynamic.ClientTLS{
InsecureSkipVerify: auth.TLS.InsecureSkipVerify,
}
if len(auth.TLS.CASecret) > 0 {
caSecret, err := loadCASecret(namespace, auth.TLS.CASecret, k8sClient)
if err != nil {
return nil, fmt.Errorf("failed to load auth ca secret: %w", err)
if auth.TLS != nil {
forwardAuth.TLS = &dynamic.ClientTLS{
InsecureSkipVerify: auth.TLS.InsecureSkipVerify,
}
forwardAuth.TLS.CA = caSecret
}
if len(auth.TLS.CertSecret) > 0 {
authSecretCert, authSecretKey, err := loadAuthTLSSecret(namespace, auth.TLS.CertSecret, k8sClient)
if err != nil {
return nil, fmt.Errorf("failed to load auth secret: %w", err)
if len(auth.TLS.CASecret) > 0 {
caSecret, err := loadCASecret(namespace, auth.TLS.CASecret, k8sClient)
if err != nil {
return nil, fmt.Errorf("failed to load auth ca secret: %w", err)
}
forwardAuth.TLS.CA = caSecret
}
forwardAuth.TLS.Cert = authSecretCert
forwardAuth.TLS.Key = authSecretKey
}
forwardAuth.TLS.CAOptional = auth.TLS.CAOptional
if len(auth.TLS.CertSecret) > 0 {
authSecretCert, authSecretKey, err := loadAuthTLSSecret(namespace, auth.TLS.CertSecret, k8sClient)
if err != nil {
return nil, fmt.Errorf("failed to load auth secret: %w", err)
}
forwardAuth.TLS.Cert = authSecretCert
forwardAuth.TLS.Key = authSecretKey
}
forwardAuth.TLS.CAOptional = auth.TLS.CAOptional
}
return forwardAuth, nil
}

View file

@ -3915,7 +3915,8 @@ func TestLoadIngressRoutes(t *testing.T) {
},
"default-forwardauth": {
ForwardAuth: &dynamic.ForwardAuth{
Address: "test.com",
Address: "test.com",
MaxBodySize: pointer(int64(-1)),
TLS: &dynamic.ClientTLS{
CA: "-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----",
Cert: "-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----",

View file

@ -161,6 +161,10 @@ type ForwardAuth struct {
TLS *ClientTLS `json:"tls,omitempty"`
// AddAuthCookiesToResponse defines the list of cookies to copy from the authentication server response to the response.
AddAuthCookiesToResponse []string `json:"addAuthCookiesToResponse,omitempty"`
// ForwardBody defines whether to send the request body to the authentication server.
ForwardBody bool `json:"forwardBody,omitempty"`
// MaxBodySize defines the maximum body size in bytes allowed to be forwarded to the authentication server.
MaxBodySize *int64 `json:"maxBodySize,omitempty"`
}
// ClientTLS holds the client TLS configuration.

View file

@ -266,6 +266,11 @@ func (in *ForwardAuth) DeepCopyInto(out *ForwardAuth) {
*out = make([]string, len(*in))
copy(*out, *in)
}
if in.MaxBodySize != nil {
in, out := &in.MaxBodySize, &out.MaxBodySize
*out = new(int64)
**out = **in
}
return
}

View file

@ -88,6 +88,8 @@ func Test_buildConfiguration(t *testing.T) {
"traefik/http/middlewares/Middleware08/forwardAuth/tls/cert": "foobar",
"traefik/http/middlewares/Middleware08/forwardAuth/address": "foobar",
"traefik/http/middlewares/Middleware08/forwardAuth/trustForwardHeader": "true",
"traefik/http/middlewares/Middleware08/forwardAuth/forwardBody": "true",
"traefik/http/middlewares/Middleware08/forwardAuth/maxBodySize": "42",
"traefik/http/middlewares/Middleware15/redirectScheme/scheme": "foobar",
"traefik/http/middlewares/Middleware15/redirectScheme/port": "foobar",
"traefik/http/middlewares/Middleware15/redirectScheme/permanent": "true",
@ -440,6 +442,8 @@ func Test_buildConfiguration(t *testing.T) {
"foobar",
"foobar",
},
ForwardBody: true,
MaxBodySize: pointer(int64(42)),
},
},
"Middleware06": {

View file

@ -35,11 +35,6 @@ import (
"google.golang.org/grpc/status"
)
const (
defaultMirrorBody = true
defaultMaxBodySize int64 = -1
)
// ProxyBuilder builds reverse proxy handlers.
type ProxyBuilder interface {
Build(cfgName string, targetURL *url.URL, shouldObserve, passHostHeader, preservePath bool, flushInterval time.Duration) (http.Handler, error)
@ -221,12 +216,12 @@ func (m *Manager) getMirrorServiceHandler(ctx context.Context, config *dynamic.M
return nil, err
}
mirrorBody := defaultMirrorBody
mirrorBody := dynamic.MirroringDefaultMirrorBody
if config.MirrorBody != nil {
mirrorBody = *config.MirrorBody
}
maxBodySize := defaultMaxBodySize
maxBodySize := dynamic.MirroringDefaultMaxBodySize
if config.MaxBodySize != nil {
maxBodySize = *config.MaxBodySize
}