Send request body to authorization server for forward auth

This commit is contained in:
kyosuke 2024-12-12 18:18:05 +09:00 committed by GitHub
parent b1934231ca
commit 26738cbf93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 411 additions and 38 deletions

View file

@ -1,6 +1,7 @@
package auth
import (
"bytes"
"context"
"errors"
"fmt"
@ -22,13 +23,13 @@ import (
"go.opentelemetry.io/otel/trace"
)
const typeNameForward = "ForwardAuth"
const (
xForwardedURI = "X-Forwarded-Uri"
xForwardedMethod = "X-Forwarded-Method"
)
const typeNameForward = "ForwardAuth"
// hopHeaders Hop-by-hop headers to be removed in the authentication request.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
// Proxy-Authorization header is forwarded to the authentication server (see https://tools.ietf.org/html/rfc7235#section-4.4).
@ -52,6 +53,8 @@ type forwardAuth struct {
authRequestHeaders []string
addAuthCookiesToResponse map[string]struct{}
headerField string
forwardBody bool
maxBodySize int64
}
// NewForward creates a forward auth middleware.
@ -73,6 +76,12 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
authRequestHeaders: config.AuthRequestHeaders,
addAuthCookiesToResponse: addAuthCookiesToResponse,
headerField: config.HeaderField,
forwardBody: config.ForwardBody,
maxBodySize: dynamic.ForwardAuthDefaultMaxBodySize,
}
if config.MaxBodySize != nil {
fa.maxBodySize = *config.MaxBodySize
}
// Ensure our request client does not follow redirects
@ -125,13 +134,37 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
forwardReq, err := http.NewRequestWithContext(req.Context(), http.MethodGet, fa.address, nil)
if err != nil {
logger.Debug().Msgf("Error calling %s. Cause %s", fa.address, err)
logger.Debug().Err(err).Msgf("Error calling %s", fa.address)
observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause %s", fa.address, err)
rw.WriteHeader(http.StatusInternalServerError)
return
}
if fa.forwardBody {
bodyBytes, err := fa.readBodyBytes(req)
if errors.Is(err, errBodyTooLarge) {
logger.Debug().Msgf("Request body is too large, maxBodySize: %d", fa.maxBodySize)
observability.SetStatusErrorf(req.Context(), "Request body is too large, maxBodySize: %d", fa.maxBodySize)
rw.WriteHeader(http.StatusUnauthorized)
return
}
if err != nil {
logger.Debug().Err(err).Msg("Error while reading body")
observability.SetStatusErrorf(req.Context(), "Error while reading Body: %s", err)
rw.WriteHeader(http.StatusInternalServerError)
return
}
// bodyBytes is nil when the request has no body.
if bodyBytes != nil {
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
forwardReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}
}
writeHeader(req, forwardReq, fa.trustForwardHeader, fa.authRequestHeaders)
var forwardSpan trace.Span
@ -149,7 +182,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
forwardResponse, forwardErr := fa.client.Do(forwardReq)
if forwardErr != nil {
logger.Debug().Msgf("Error calling %s. Cause: %s", fa.address, forwardErr)
logger.Debug().Err(forwardErr).Msgf("Error calling %s", fa.address)
observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause: %s", fa.address, forwardErr)
rw.WriteHeader(http.StatusInternalServerError)
@ -159,7 +192,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
body, readError := io.ReadAll(forwardResponse.Body)
if readError != nil {
logger.Debug().Msgf("Error reading body %s. Cause: %s", fa.address, readError)
logger.Debug().Err(readError).Msgf("Error reading body %s", fa.address)
observability.SetStatusErrorf(req.Context(), "Error reading body %s. Cause: %s", fa.address, readError)
rw.WriteHeader(http.StatusInternalServerError)
@ -194,7 +227,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if err != nil {
if !errors.Is(err, http.ErrNoLocation) {
logger.Debug().Msgf("Error reading response location header %s. Cause: %s", fa.address, err)
logger.Debug().Err(err).Msgf("Error reading response location header %s", fa.address)
observability.SetStatusErrorf(req.Context(), "Error reading response location header %s. Cause: %s", fa.address, err)
rw.WriteHeader(http.StatusInternalServerError)
@ -270,6 +303,27 @@ func (fa *forwardAuth) buildModifier(authCookies []*http.Cookie) func(res *http.
}
}
var errBodyTooLarge = errors.New("request body too large")
func (fa *forwardAuth) readBodyBytes(req *http.Request) ([]byte, error) {
if fa.maxBodySize < 0 {
return io.ReadAll(req.Body)
}
body := make([]byte, fa.maxBodySize+1)
n, err := io.ReadFull(req.Body, body)
if errors.Is(err, io.EOF) {
return nil, nil
}
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
return nil, fmt.Errorf("reading body bytes: %w", err)
}
if errors.Is(err, io.ErrUnexpectedEOF) {
return body[:n], nil
}
return nil, errBodyTooLarge
}
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool, allowedHeaders []string) {
utils.CopyHeaders(forwardReq.Header, req.Header)

View file

@ -1,6 +1,7 @@
package auth
import (
"bytes"
"context"
"fmt"
"io"
@ -112,6 +113,154 @@ func TestForwardAuthSuccess(t *testing.T) {
assert.Equal(t, "traefik\n", string(body))
}
func TestForwardAuthForwardBody(t *testing.T) {
data := []byte("forwardBodyTest")
var serverCallCount int
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
serverCallCount++
forwardedData, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Equal(t, data, forwardedData)
}))
t.Cleanup(server.Close)
var nextCallCount int
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
nextCallCount++
})
maxBodySize := int64(len(data))
auth := dynamic.ForwardAuth{Address: server.URL, ForwardBody: true, MaxBodySize: &maxBodySize}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
t.Cleanup(ts.Close)
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, bytes.NewReader(data))
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, 1, serverCallCount)
assert.Equal(t, 1, nextCallCount)
}
func TestForwardAuthForwardBodyEmptyBody(t *testing.T) {
var serverCallCount int
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
serverCallCount++
forwardedData, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Empty(t, forwardedData)
}))
t.Cleanup(server.Close)
var nextCallCount int
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
nextCallCount++
})
auth := dynamic.ForwardAuth{Address: server.URL, ForwardBody: true}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
t.Cleanup(ts.Close)
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, http.NoBody)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, 1, serverCallCount)
assert.Equal(t, 1, nextCallCount)
}
func TestForwardAuthForwardBodySizeLimit(t *testing.T) {
data := []byte("forwardBodyTest")
var serverCallCount int
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
serverCallCount++
forwardedData, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Equal(t, data, forwardedData)
}))
t.Cleanup(server.Close)
var nextCallCount int
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
nextCallCount++
})
maxBodySize := int64(len(data)) - 1
auth := dynamic.ForwardAuth{Address: server.URL, ForwardBody: true, MaxBodySize: &maxBodySize}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
t.Cleanup(ts.Close)
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, bytes.NewReader(data))
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
assert.Equal(t, 0, serverCallCount)
assert.Equal(t, 0, nextCallCount)
}
func TestForwardAuthNotForwardBody(t *testing.T) {
data := []byte("forwardBodyTest")
var serverCallCount int
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
serverCallCount++
forwardedData, err := io.ReadAll(req.Body)
require.NoError(t, err)
assert.Empty(t, forwardedData)
}))
t.Cleanup(server.Close)
var nextCallCount int
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
nextCallCount++
})
auth := dynamic.ForwardAuth{Address: server.URL}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
t.Cleanup(ts.Close)
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, bytes.NewReader(data))
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, 1, serverCallCount)
assert.Equal(t, 1, nextCallCount)
}
func TestForwardAuthRedirect(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)