Send request body to authorization server for forward auth
This commit is contained in:
parent
b1934231ca
commit
26738cbf93
20 changed files with 411 additions and 38 deletions
|
@ -1,6 +1,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -22,13 +23,13 @@ import (
|
|||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
const typeNameForward = "ForwardAuth"
|
||||
|
||||
const (
|
||||
xForwardedURI = "X-Forwarded-Uri"
|
||||
xForwardedMethod = "X-Forwarded-Method"
|
||||
)
|
||||
|
||||
const typeNameForward = "ForwardAuth"
|
||||
|
||||
// hopHeaders Hop-by-hop headers to be removed in the authentication request.
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
|
||||
// Proxy-Authorization header is forwarded to the authentication server (see https://tools.ietf.org/html/rfc7235#section-4.4).
|
||||
|
@ -52,6 +53,8 @@ type forwardAuth struct {
|
|||
authRequestHeaders []string
|
||||
addAuthCookiesToResponse map[string]struct{}
|
||||
headerField string
|
||||
forwardBody bool
|
||||
maxBodySize int64
|
||||
}
|
||||
|
||||
// NewForward creates a forward auth middleware.
|
||||
|
@ -73,6 +76,12 @@ func NewForward(ctx context.Context, next http.Handler, config dynamic.ForwardAu
|
|||
authRequestHeaders: config.AuthRequestHeaders,
|
||||
addAuthCookiesToResponse: addAuthCookiesToResponse,
|
||||
headerField: config.HeaderField,
|
||||
forwardBody: config.ForwardBody,
|
||||
maxBodySize: dynamic.ForwardAuthDefaultMaxBodySize,
|
||||
}
|
||||
|
||||
if config.MaxBodySize != nil {
|
||||
fa.maxBodySize = *config.MaxBodySize
|
||||
}
|
||||
|
||||
// Ensure our request client does not follow redirects
|
||||
|
@ -125,13 +134,37 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
forwardReq, err := http.NewRequestWithContext(req.Context(), http.MethodGet, fa.address, nil)
|
||||
if err != nil {
|
||||
logger.Debug().Msgf("Error calling %s. Cause %s", fa.address, err)
|
||||
logger.Debug().Err(err).Msgf("Error calling %s", fa.address)
|
||||
observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause %s", fa.address, err)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if fa.forwardBody {
|
||||
bodyBytes, err := fa.readBodyBytes(req)
|
||||
if errors.Is(err, errBodyTooLarge) {
|
||||
logger.Debug().Msgf("Request body is too large, maxBodySize: %d", fa.maxBodySize)
|
||||
|
||||
observability.SetStatusErrorf(req.Context(), "Request body is too large, maxBodySize: %d", fa.maxBodySize)
|
||||
rw.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
logger.Debug().Err(err).Msg("Error while reading body")
|
||||
|
||||
observability.SetStatusErrorf(req.Context(), "Error while reading Body: %s", err)
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// bodyBytes is nil when the request has no body.
|
||||
if bodyBytes != nil {
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
forwardReq.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
writeHeader(req, forwardReq, fa.trustForwardHeader, fa.authRequestHeaders)
|
||||
|
||||
var forwardSpan trace.Span
|
||||
|
@ -149,7 +182,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
forwardResponse, forwardErr := fa.client.Do(forwardReq)
|
||||
if forwardErr != nil {
|
||||
logger.Debug().Msgf("Error calling %s. Cause: %s", fa.address, forwardErr)
|
||||
logger.Debug().Err(forwardErr).Msgf("Error calling %s", fa.address)
|
||||
observability.SetStatusErrorf(req.Context(), "Error calling %s. Cause: %s", fa.address, forwardErr)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
|
@ -159,7 +192,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
body, readError := io.ReadAll(forwardResponse.Body)
|
||||
if readError != nil {
|
||||
logger.Debug().Msgf("Error reading body %s. Cause: %s", fa.address, readError)
|
||||
logger.Debug().Err(readError).Msgf("Error reading body %s", fa.address)
|
||||
observability.SetStatusErrorf(req.Context(), "Error reading body %s. Cause: %s", fa.address, readError)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
|
@ -194,7 +227,7 @@ func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||
|
||||
if err != nil {
|
||||
if !errors.Is(err, http.ErrNoLocation) {
|
||||
logger.Debug().Msgf("Error reading response location header %s. Cause: %s", fa.address, err)
|
||||
logger.Debug().Err(err).Msgf("Error reading response location header %s", fa.address)
|
||||
observability.SetStatusErrorf(req.Context(), "Error reading response location header %s. Cause: %s", fa.address, err)
|
||||
|
||||
rw.WriteHeader(http.StatusInternalServerError)
|
||||
|
@ -270,6 +303,27 @@ func (fa *forwardAuth) buildModifier(authCookies []*http.Cookie) func(res *http.
|
|||
}
|
||||
}
|
||||
|
||||
var errBodyTooLarge = errors.New("request body too large")
|
||||
|
||||
func (fa *forwardAuth) readBodyBytes(req *http.Request) ([]byte, error) {
|
||||
if fa.maxBodySize < 0 {
|
||||
return io.ReadAll(req.Body)
|
||||
}
|
||||
|
||||
body := make([]byte, fa.maxBodySize+1)
|
||||
n, err := io.ReadFull(req.Body, body)
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, fmt.Errorf("reading body bytes: %w", err)
|
||||
}
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return body[:n], nil
|
||||
}
|
||||
return nil, errBodyTooLarge
|
||||
}
|
||||
|
||||
func writeHeader(req, forwardReq *http.Request, trustForwardHeader bool, allowedHeaders []string) {
|
||||
utils.CopyHeaders(forwardReq.Header, req.Header)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -112,6 +113,154 @@ func TestForwardAuthSuccess(t *testing.T) {
|
|||
assert.Equal(t, "traefik\n", string(body))
|
||||
}
|
||||
|
||||
func TestForwardAuthForwardBody(t *testing.T) {
|
||||
data := []byte("forwardBodyTest")
|
||||
|
||||
var serverCallCount int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
serverCallCount++
|
||||
|
||||
forwardedData, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, data, forwardedData)
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
var nextCallCount int
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
nextCallCount++
|
||||
})
|
||||
|
||||
maxBodySize := int64(len(data))
|
||||
auth := dynamic.ForwardAuth{Address: server.URL, ForwardBody: true, MaxBodySize: &maxBodySize}
|
||||
|
||||
middleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, bytes.NewReader(data))
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, 1, serverCallCount)
|
||||
assert.Equal(t, 1, nextCallCount)
|
||||
}
|
||||
|
||||
func TestForwardAuthForwardBodyEmptyBody(t *testing.T) {
|
||||
var serverCallCount int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
serverCallCount++
|
||||
|
||||
forwardedData, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, forwardedData)
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
var nextCallCount int
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
nextCallCount++
|
||||
})
|
||||
|
||||
auth := dynamic.ForwardAuth{Address: server.URL, ForwardBody: true}
|
||||
|
||||
middleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, http.NoBody)
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, 1, serverCallCount)
|
||||
assert.Equal(t, 1, nextCallCount)
|
||||
}
|
||||
|
||||
func TestForwardAuthForwardBodySizeLimit(t *testing.T) {
|
||||
data := []byte("forwardBodyTest")
|
||||
|
||||
var serverCallCount int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
serverCallCount++
|
||||
|
||||
forwardedData, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, data, forwardedData)
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
var nextCallCount int
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
nextCallCount++
|
||||
})
|
||||
|
||||
maxBodySize := int64(len(data)) - 1
|
||||
auth := dynamic.ForwardAuth{Address: server.URL, ForwardBody: true, MaxBodySize: &maxBodySize}
|
||||
|
||||
middleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, bytes.NewReader(data))
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
assert.Equal(t, 0, serverCallCount)
|
||||
assert.Equal(t, 0, nextCallCount)
|
||||
}
|
||||
|
||||
func TestForwardAuthNotForwardBody(t *testing.T) {
|
||||
data := []byte("forwardBodyTest")
|
||||
|
||||
var serverCallCount int
|
||||
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
serverCallCount++
|
||||
|
||||
forwardedData, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, forwardedData)
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
var nextCallCount int
|
||||
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
nextCallCount++
|
||||
})
|
||||
|
||||
auth := dynamic.ForwardAuth{Address: server.URL}
|
||||
|
||||
middleware, err := NewForward(context.Background(), next, auth, "authTest")
|
||||
require.NoError(t, err)
|
||||
|
||||
ts := httptest.NewServer(middleware)
|
||||
t.Cleanup(ts.Close)
|
||||
|
||||
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, bytes.NewReader(data))
|
||||
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, 1, serverCallCount)
|
||||
assert.Equal(t, 1, nextCallCount)
|
||||
}
|
||||
|
||||
func TestForwardAuthRedirect(t *testing.T) {
|
||||
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue