1
0
Fork 0

Move code to pkg

This commit is contained in:
Ludovic Fernandez 2019-03-15 09:42:03 +01:00 committed by Traefiker Bot
parent bd4c822670
commit f1b085fa36
465 changed files with 656 additions and 680 deletions

View file

@ -0,0 +1,65 @@
package auth
import (
"io/ioutil"
"strings"
)
// UserParser Parses a string and return a userName/userHash. An error if the format of the string is incorrect.
type UserParser func(user string) (string, string, error)
const (
defaultRealm = "traefik"
authorizationHeader = "Authorization"
)
func getUsers(fileName string, appendUsers []string, parser UserParser) (map[string]string, error) {
users, err := loadUsers(fileName, appendUsers)
if err != nil {
return nil, err
}
userMap := make(map[string]string)
for _, user := range users {
userName, userHash, err := parser(user)
if err != nil {
return nil, err
}
userMap[userName] = userHash
}
return userMap, nil
}
func loadUsers(fileName string, appendUsers []string) ([]string, error) {
var users []string
var err error
if fileName != "" {
users, err = getLinesFromFile(fileName)
if err != nil {
return nil, err
}
}
return append(users, appendUsers...), nil
}
func getLinesFromFile(filename string) ([]string, error) {
dat, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
// Trim lines and filter out blanks
rawLines := strings.Split(string(dat), "\n")
var filteredLines []string
for _, rawLine := range rawLines {
line := strings.TrimSpace(rawLine)
if line != "" && !strings.HasPrefix(line, "#") {
filteredLines = append(filteredLines, line)
}
}
return filteredLines, nil
}

View file

@ -0,0 +1,102 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
basicTypeName = "BasicAuth"
)
type basicAuth struct {
next http.Handler
auth *goauth.BasicAuth
users map[string]string
headerField string
removeHeader bool
name string
}
// NewBasic creates a basicAuth middleware.
func NewBasic(ctx context.Context, next http.Handler, authConfig config.BasicAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, basicTypeName).Debug("Creating middleware")
users, err := getUsers(authConfig.UsersFile, authConfig.Users, basicUserParser)
if err != nil {
return nil, err
}
ba := &basicAuth{
next: next,
users: users,
headerField: authConfig.HeaderField,
removeHeader: authConfig.RemoveHeader,
name: name,
}
realm := defaultRealm
if len(authConfig.Realm) > 0 {
realm = authConfig.Realm
}
ba.auth = goauth.NewBasicAuthenticator(realm, ba.secretBasic)
return ba, nil
}
func (b *basicAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return b.name, tracing.SpanKindNoneEnum
}
func (b *basicAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), b.name, basicTypeName)
if username := b.auth.CheckAuth(req); username == "" {
logger.Debug("Authentication failed")
tracing.SetErrorWithEvent(req, "Authentication failed")
b.auth.RequireAuth(rw, req)
} else {
logger.Debug("Authentication succeeded")
req.URL.User = url.User(username)
logData := accesslog.GetLogData(req)
if logData != nil {
logData.Core[accesslog.ClientUsername] = username
}
if b.headerField != "" {
req.Header[b.headerField] = []string{username}
}
if b.removeHeader {
logger.Debug("Removing authorization header")
req.Header.Del(authorizationHeader)
}
b.next.ServeHTTP(rw, req)
}
}
func (b *basicAuth) secretBasic(user, realm string) string {
if secret, ok := b.users[user]; ok {
return secret
}
return ""
}
func basicUserParser(user string) (string, string, error) {
split := strings.Split(user, ":")
if len(split) != 2 {
return "", "", fmt.Errorf("error parsing BasicUser: %v", user)
}
return split[0], split[1], nil
}

View file

@ -0,0 +1,284 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBasicAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test"},
}
_, err := NewBasic(context.Background(), next, auth, "authName")
require.Error(t, err)
auth2 := config.BasicAuth{
Users: []string{"test:test"},
}
authMiddleware, err := NewBasic(context.Background(), next, auth2, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode, "they should be equal")
}
func TestBasicAuthSuccess(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
authMiddleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode, "they should be equal")
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, "traefik\n", string(body), "they should be equal")
}
func TestBasicAuthUserHeader(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "test", r.Header["X-Webauth-User"][0], "auth user should be set")
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
HeaderField: "X-Webauth-User",
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthHeaderRemoved(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Empty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
RemoveHeader: true,
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthHeaderPresent(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NotEmpty(t, r.Header.Get(authorizationHeader))
fmt.Fprintln(w, "traefik")
})
auth := config.BasicAuth{
Users: []string{"test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/"},
}
middleware, err := NewBasic(context.Background(), next, auth, "authName")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestBasicAuthUsersFromFile(t *testing.T) {
testCases := []struct {
desc string
userFileContent string
expectedUsers map[string]string
givenUsers []string
realm string
}{
{
desc: "Finds the users in the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
},
{
desc: "Merges given users with users from the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\n",
givenUsers: []string{"test2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0", "test3:$apr1$3rJbDP0q$RfzJiorTk78jQ1EcKqWso0"},
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
},
{
desc: "Given users have priority over users in the file",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
},
{
desc: "Should authenticate the correct user based on the realm",
userFileContent: "test:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{"test2:$apr1$mK.GtItK$ncnLYvNLek0weXdxo68690"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
realm: "traefik",
},
{
desc: "Should skip comments",
userFileContent: "#Comment\ntest:$apr1$H6uskkkW$IgXLP6ewTrSuBkTrqE8wj/\ntest2:$apr1$d9hr9HBB$4HxwgUir3HP4EsggP/QNo0\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
realm: "traefiker",
},
}
for _, test := range testCases {
if test.desc != "Should skip comments" {
continue
}
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
// Creates the temporary configuration file with the users
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.userFileContent))
require.NoError(t, err)
// Creates the configuration for our Authenticator
authenticatorConfiguration := config.BasicAuth{
Users: test.givenUsers,
UsersFile: usersFile.Name(),
Realm: test.realm,
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
authenticator, err := NewBasic(context.Background(), next, authenticatorConfiguration, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authenticator)
defer ts.Close()
for userName, userPwd := range test.expectedUsers {
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth(userName, userPwd)
var res *http.Response
res, err = http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.Equal(t, "traefik\n", string(body))
}
// Checks that user foo doesn't work
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("foo", "foo")
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
if len(test.realm) > 0 {
require.Equal(t, `Basic realm="`+test.realm+`"`, res.Header.Get("WWW-Authenticate"))
}
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.NotContains(t, "traefik", string(body))
})
}
}

View file

@ -0,0 +1,102 @@
package auth
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
goauth "github.com/abbot/go-http-auth"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/middlewares/accesslog"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
)
const (
digestTypeName = "digestAuth"
)
type digestAuth struct {
next http.Handler
auth *goauth.DigestAuth
users map[string]string
headerField string
removeHeader bool
name string
}
// NewDigest creates a digest auth middleware.
func NewDigest(ctx context.Context, next http.Handler, authConfig config.DigestAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, digestTypeName).Debug("Creating middleware")
users, err := getUsers(authConfig.UsersFile, authConfig.Users, digestUserParser)
if err != nil {
return nil, err
}
da := &digestAuth{
next: next,
users: users,
headerField: authConfig.HeaderField,
removeHeader: authConfig.RemoveHeader,
name: name,
}
realm := defaultRealm
if len(authConfig.Realm) > 0 {
realm = authConfig.Realm
}
da.auth = goauth.NewDigestAuthenticator(realm, da.secretDigest)
return da, nil
}
func (d *digestAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return d.name, tracing.SpanKindNoneEnum
}
func (d *digestAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), d.name, digestTypeName)
if username, _ := d.auth.CheckAuth(req); username == "" {
logger.Debug("Digest authentication failed")
tracing.SetErrorWithEvent(req, "Digest authentication failed")
d.auth.RequireAuth(rw, req)
} else {
logger.Debug("Digest authentication succeeded")
req.URL.User = url.User(username)
logData := accesslog.GetLogData(req)
if logData != nil {
logData.Core[accesslog.ClientUsername] = username
}
if d.headerField != "" {
req.Header[d.headerField] = []string{username}
}
if d.removeHeader {
logger.Debug("Removing the Authorization header")
req.Header.Del(authorizationHeader)
}
d.next.ServeHTTP(rw, req)
}
}
func (d *digestAuth) secretDigest(user, realm string) string {
if secret, ok := d.users[user+":"+realm]; ok {
return secret
}
return ""
}
func digestUserParser(user string) (string, string, error) {
split := strings.Split(user, ":")
if len(split) != 3 {
return "", "", fmt.Errorf("error parsing DigestUser: %v", user)
}
return split[0] + ":" + split[1], split[2], nil
}

View file

@ -0,0 +1,141 @@
package auth
import (
"crypto/md5"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
)
const (
algorithm = "algorithm"
authorization = "Authorization"
nonce = "nonce"
opaque = "opaque"
qop = "qop"
realm = "realm"
wwwAuthenticate = "Www-Authenticate"
)
// DigestRequest is a client for digest authentication requests
type digestRequest struct {
client *http.Client
username, password string
nonceCount nonceCount
}
type nonceCount int
func (nc nonceCount) String() string {
return fmt.Sprintf("%08x", int(nc))
}
var wanted = []string{algorithm, nonce, opaque, qop, realm}
// New makes a DigestRequest instance
func newDigestRequest(username, password string, client *http.Client) *digestRequest {
return &digestRequest{
client: client,
username: username,
password: password,
}
}
// Do does requests as http.Do does
func (r *digestRequest) Do(req *http.Request) (*http.Response, error) {
parts, err := r.makeParts(req)
if err != nil {
return nil, err
}
if parts != nil {
req.Header.Set(authorization, r.makeAuthorization(req, parts))
}
return r.client.Do(req)
}
func (r *digestRequest) makeParts(req *http.Request) (map[string]string, error) {
authReq, err := http.NewRequest(req.Method, req.URL.String(), nil)
if err != nil {
return nil, err
}
resp, err := r.client.Do(authReq)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
return nil, nil
}
if len(resp.Header[wwwAuthenticate]) == 0 {
return nil, fmt.Errorf("headers do not have %s", wwwAuthenticate)
}
headers := strings.Split(resp.Header[wwwAuthenticate][0], ",")
parts := make(map[string]string, len(wanted))
for _, r := range headers {
for _, w := range wanted {
if strings.Contains(r, w) {
parts[w] = strings.Split(r, `"`)[1]
}
}
}
if len(parts) != len(wanted) {
return nil, fmt.Errorf("header is invalid: %+v", parts)
}
return parts, nil
}
func getMD5(texts []string) string {
h := md5.New()
_, _ = io.WriteString(h, strings.Join(texts, ":"))
return hex.EncodeToString(h.Sum(nil))
}
func (r *digestRequest) getNonceCount() string {
r.nonceCount++
return r.nonceCount.String()
}
func (r *digestRequest) makeAuthorization(req *http.Request, parts map[string]string) string {
ha1 := getMD5([]string{r.username, parts[realm], r.password})
ha2 := getMD5([]string{req.Method, req.URL.String()})
cnonce := generateRandom(16)
nc := r.getNonceCount()
response := getMD5([]string{
ha1,
parts[nonce],
nc,
cnonce,
parts[qop],
ha2,
})
return fmt.Sprintf(
`Digest username="%s", realm="%s", nonce="%s", uri="%s", algorithm=%s, qop=%s, nc=%s, cnonce="%s", response="%s", opaque="%s"`,
r.username,
parts[realm],
parts[nonce],
req.URL.String(),
parts[algorithm],
parts[qop],
nc,
cnonce,
response,
parts[opaque],
)
}
// GenerateRandom generates random string
func generateRandom(n int) string {
b := make([]byte, 8)
_, _ = io.ReadFull(rand.Reader, b)
return fmt.Sprintf("%x", b)[:n]
}

View file

@ -0,0 +1,156 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestDigestAuthError(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.DigestAuth{
Users: []string{"test"},
}
_, err := NewDigest(context.Background(), next, auth, "authName")
assert.Error(t, err)
}
func TestDigestAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.DigestAuth{
Users: []string{"test:traefik:a2688e031edb4be6a3797f3882655c05"},
}
authMiddleware, err := NewDigest(context.Background(), next, auth, "authName")
require.NoError(t, err)
assert.NotNil(t, authMiddleware, "this should not be nil")
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := http.DefaultClient
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
req.SetBasicAuth("test", "test")
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}
func TestDigestAuthUsersFromFile(t *testing.T) {
testCases := []struct {
desc string
userFileContent string
expectedUsers map[string]string
givenUsers []string
realm string
}{
{
desc: "Finds the users in the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test", "test2": "test2"},
},
{
desc: "Merges given users with users from the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\n",
givenUsers: []string{"test2:traefik:518845800f9e2bfb1f1f740ec24f074e", "test3:traefik:c8e9f57ce58ecb4424407f665a91646c"},
expectedUsers: map[string]string{"test": "test", "test2": "test2", "test3": "test3"},
},
{
desc: "Given users have priority over users in the file",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest2:traefik:518845800f9e2bfb1f1f740ec24f074e\n",
givenUsers: []string{"test2:traefik:8de60a1c52da68ccf41f0c0ffb7c51a0"},
expectedUsers: map[string]string{"test": "test", "test2": "overridden"},
},
{
desc: "Should authenticate the correct user based on the realm",
userFileContent: "test:traefik:a2688e031edb4be6a3797f3882655c05\ntest:traefiker:a3d334dff2645b914918de78bec50bf4\n",
givenUsers: []string{},
expectedUsers: map[string]string{"test": "test2"},
realm: "traefiker",
},
}
for _, test := range testCases {
test := test
t.Run(test.desc, func(t *testing.T) {
t.Parallel()
// Creates the temporary configuration file with the users
usersFile, err := ioutil.TempFile("", "auth-users")
require.NoError(t, err)
defer os.Remove(usersFile.Name())
_, err = usersFile.Write([]byte(test.userFileContent))
require.NoError(t, err)
// Creates the configuration for our Authenticator
authenticatorConfiguration := config.DigestAuth{
Users: test.givenUsers,
UsersFile: usersFile.Name(),
Realm: test.realm,
}
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
authenticator, err := NewDigest(context.Background(), next, authenticatorConfiguration, "authName")
require.NoError(t, err)
ts := httptest.NewServer(authenticator)
defer ts.Close()
for userName, userPwd := range test.expectedUsers {
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
digestRequest := newDigestRequest(userName, userPwd, http.DefaultClient)
var res *http.Response
res, err = digestRequest.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, res.StatusCode, "Cannot authenticate user "+userName)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.Equal(t, "traefik\n", string(body))
}
// Checks that user foo doesn't work
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
digestRequest := newDigestRequest("foo", "foo", http.DefaultClient)
var res *http.Response
res, err = digestRequest.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusUnauthorized, res.StatusCode)
var body []byte
body, err = ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
require.NotContains(t, "traefik", string(body))
})
}
}

View file

@ -0,0 +1,213 @@
package auth
import (
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"strings"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/middlewares"
"github.com/containous/traefik/pkg/tracing"
"github.com/opentracing/opentracing-go/ext"
"github.com/vulcand/oxy/forward"
"github.com/vulcand/oxy/utils"
)
const (
xForwardedURI = "X-Forwarded-Uri"
xForwardedMethod = "X-Forwarded-Method"
forwardedTypeName = "ForwardedAuthType"
)
type forwardAuth struct {
address string
authResponseHeaders []string
next http.Handler
name string
tlsConfig *tls.Config
trustForwardHeader bool
}
// NewForward creates a forward auth middleware.
func NewForward(ctx context.Context, next http.Handler, config config.ForwardAuth, name string) (http.Handler, error) {
middlewares.GetLogger(ctx, name, forwardedTypeName).Debug("Creating middleware")
fa := &forwardAuth{
address: config.Address,
authResponseHeaders: config.AuthResponseHeaders,
next: next,
name: name,
trustForwardHeader: config.TrustForwardHeader,
}
if config.TLS != nil {
tlsConfig, err := config.TLS.CreateTLSConfig()
if err != nil {
return nil, err
}
fa.tlsConfig = tlsConfig
}
return fa, nil
}
func (fa *forwardAuth) GetTracingInformation() (string, ext.SpanKindEnum) {
return fa.name, ext.SpanKindRPCClientEnum
}
func (fa *forwardAuth) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
logger := middlewares.GetLogger(req.Context(), fa.name, forwardedTypeName)
// Ensure our request client does not follow redirects
httpClient := http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
if fa.tlsConfig != nil {
httpClient.Transport = &http.Transport{
TLSClientConfig: fa.tlsConfig,
}
}
forwardReq, err := http.NewRequest(http.MethodGet, fa.address, nil)
tracing.LogRequest(tracing.GetSpan(req), forwardReq)
if err != nil {
logMessage := fmt.Sprintf("Error calling %s. Cause %s", fa.address, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
writeHeader(req, forwardReq, fa.trustForwardHeader)
tracing.InjectRequestHeaders(forwardReq)
forwardResponse, forwardErr := httpClient.Do(forwardReq)
if forwardErr != nil {
logMessage := fmt.Sprintf("Error calling %s. Cause: %s", fa.address, forwardErr)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
body, readError := ioutil.ReadAll(forwardResponse.Body)
if readError != nil {
logMessage := fmt.Sprintf("Error reading body %s. Cause: %s", fa.address, readError)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
defer forwardResponse.Body.Close()
// Pass the forward response's body and selected headers if it
// didn't return a response within the range of [200, 300).
if forwardResponse.StatusCode < http.StatusOK || forwardResponse.StatusCode >= http.StatusMultipleChoices {
logger.Debugf("Remote error %s. StatusCode: %d", fa.address, forwardResponse.StatusCode)
utils.CopyHeaders(rw.Header(), forwardResponse.Header)
utils.RemoveHeaders(rw.Header(), forward.HopHeaders...)
// Grab the location header, if any.
redirectURL, err := forwardResponse.Location()
if err != nil {
if err != http.ErrNoLocation {
logMessage := fmt.Sprintf("Error reading response location header %s. Cause: %s", fa.address, err)
logger.Debug(logMessage)
tracing.SetErrorWithEvent(req, logMessage)
rw.WriteHeader(http.StatusInternalServerError)
return
}
} else if redirectURL.String() != "" {
// Set the location in our response if one was sent back.
rw.Header().Set("Location", redirectURL.String())
}
tracing.LogResponseCode(tracing.GetSpan(req), forwardResponse.StatusCode)
rw.WriteHeader(forwardResponse.StatusCode)
if _, err = rw.Write(body); err != nil {
logger.Error(err)
}
return
}
for _, headerName := range fa.authResponseHeaders {
req.Header.Set(headerName, forwardResponse.Header.Get(headerName))
}
req.RequestURI = req.URL.RequestURI()
fa.next.ServeHTTP(rw, req)
}
func writeHeader(req *http.Request, forwardReq *http.Request, trustForwardHeader bool) {
utils.CopyHeaders(forwardReq.Header, req.Header)
utils.RemoveHeaders(forwardReq.Header, forward.HopHeaders...)
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
if trustForwardHeader {
if prior, ok := req.Header[forward.XForwardedFor]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
}
forwardReq.Header.Set(forward.XForwardedFor, clientIP)
}
xMethod := req.Header.Get(xForwardedMethod)
switch {
case xMethod != "" && trustForwardHeader:
forwardReq.Header.Set(xForwardedMethod, xMethod)
case req.Method != "":
forwardReq.Header.Set(xForwardedMethod, req.Method)
default:
forwardReq.Header.Del(xForwardedMethod)
}
xfp := req.Header.Get(forward.XForwardedProto)
switch {
case xfp != "" && trustForwardHeader:
forwardReq.Header.Set(forward.XForwardedProto, xfp)
case req.TLS != nil:
forwardReq.Header.Set(forward.XForwardedProto, "https")
default:
forwardReq.Header.Set(forward.XForwardedProto, "http")
}
if xfp := req.Header.Get(forward.XForwardedPort); xfp != "" && trustForwardHeader {
forwardReq.Header.Set(forward.XForwardedPort, xfp)
}
xfh := req.Header.Get(forward.XForwardedHost)
switch {
case xfh != "" && trustForwardHeader:
forwardReq.Header.Set(forward.XForwardedHost, xfh)
case req.Host != "":
forwardReq.Header.Set(forward.XForwardedHost, req.Host)
default:
forwardReq.Header.Del(forward.XForwardedHost)
}
xfURI := req.Header.Get(xForwardedURI)
switch {
case xfURI != "" && trustForwardHeader:
forwardReq.Header.Set(xForwardedURI, xfURI)
case req.URL.RequestURI() != "":
forwardReq.Header.Set(xForwardedURI, req.URL.RequestURI())
default:
forwardReq.Header.Del(xForwardedURI)
}
}

View file

@ -0,0 +1,393 @@
package auth
import (
"context"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/containous/traefik/pkg/config"
"github.com/containous/traefik/pkg/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/vulcand/oxy/forward"
)
func TestForwardAuthFail(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer server.Close()
middleware, err := NewForward(context.Background(), next, config.ForwardAuth{
Address: server.URL,
}, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(body))
}
func TestForwardAuthSuccess(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Auth-User", "user@example.com")
w.Header().Set("X-Auth-Secret", "secret")
fmt.Fprintln(w, "Success")
}))
defer server.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "user@example.com", r.Header.Get("X-Auth-User"))
assert.Empty(t, r.Header.Get("X-Auth-Secret"))
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: server.URL,
AuthResponseHeaders: []string{"X-Auth-User"},
}
middleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(middleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusOK, res.StatusCode)
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "traefik\n", string(body))
}
func TestForwardAuthRedirect(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
}))
defer authTs.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := &http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusFound, res.StatusCode)
location, err := res.Location()
require.NoError(t, err)
assert.Equal(t, "http://example.com/redirect-test", location.String())
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.NotEmpty(t, string(body))
}
func TestForwardAuthRemoveHopByHopHeaders(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headers := w.Header()
for _, header := range forward.HopHeaders {
if header == forward.TransferEncoding {
headers.Add(header, "identity")
} else {
headers.Add(header, "test")
}
}
http.Redirect(w, r, "http://example.com/redirect-test", http.StatusFound)
}))
defer authTs.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
assert.NoError(t, err, "there should be no error")
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
client := &http.Client{
CheckRedirect: func(r *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := client.Do(req)
assert.NoError(t, err, "there should be no error")
assert.Equal(t, http.StatusFound, res.StatusCode, "they should be equal")
for _, header := range forward.HopHeaders {
assert.Equal(t, "", res.Header.Get(header), "hop-by-hop header '%s' mustn't be set", header)
}
location, err := res.Location()
assert.NoError(t, err, "there should be no error")
assert.Equal(t, "http://example.com/redirect-test", location.String(), "they should be equal")
body, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err, "there should be no error")
assert.NotEmpty(t, string(body), "there should be something in the body")
}
func TestForwardAuthFailResponseHeaders(t *testing.T) {
authTs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie := &http.Cookie{Name: "example", Value: "testing", Path: "/"}
http.SetCookie(w, cookie)
w.Header().Add("X-Foo", "bar")
http.Error(w, "Forbidden", http.StatusForbidden)
}))
defer authTs.Close()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "traefik")
})
auth := config.ForwardAuth{
Address: authTs.URL,
}
authMiddleware, err := NewForward(context.Background(), next, auth, "authTest")
require.NoError(t, err)
ts := httptest.NewServer(authMiddleware)
defer ts.Close()
req := testhelpers.MustNewRequest(http.MethodGet, ts.URL, nil)
res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusForbidden, res.StatusCode)
require.Len(t, res.Cookies(), 1)
for _, cookie := range res.Cookies() {
assert.Equal(t, "testing", cookie.Value)
}
expectedHeaders := http.Header{
"Content-Length": []string{"10"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"X-Foo": []string{"bar"},
"Set-Cookie": []string{"example=testing; Path=/"},
"X-Content-Type-Options": []string{"nosniff"},
}
assert.Len(t, res.Header, 6)
for key, value := range expectedHeaders {
assert.Equal(t, value, res.Header[key])
}
body, err := ioutil.ReadAll(res.Body)
require.NoError(t, err)
err = res.Body.Close()
require.NoError(t, err)
assert.Equal(t, "Forbidden\n", string(body))
}
func Test_writeHeader(t *testing.T) {
testCases := []struct {
name string
headers map[string]string
trustForwardHeader bool
emptyHost bool
expectedHeaders map[string]string
checkForUnexpectedHeaders bool
}{
{
name: "trust Forward Header",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
},
{
name: "not trust Forward Header",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "foo.bar",
},
},
{
name: "trust Forward Header with empty Host",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: true,
emptyHost: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
},
{
name: "not trust Forward Header with empty Host",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
},
trustForwardHeader: false,
emptyHost: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "",
},
},
{
name: "trust Forward Header with forwarded URI",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
},
{
name: "not trust Forward Header with forward requested URI",
headers: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "fii.bir",
"X-Forwarded-Uri": "/forward?q=1",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"Accept": "application/json",
"X-Forwarded-Host": "foo.bar",
"X-Forwarded-Uri": "/path?q=1",
},
}, {
name: "trust Forward Header with forwarded request Method",
headers: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
trustForwardHeader: true,
expectedHeaders: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
},
{
name: "not trust Forward Header with forward request Method",
headers: map[string]string{
"X-Forwarded-Method": "OPTIONS",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"X-Forwarded-Method": "GET",
},
},
{
name: "remove hop-by-hop headers",
headers: map[string]string{
forward.Connection: "Connection",
forward.KeepAlive: "KeepAlive",
forward.ProxyAuthenticate: "ProxyAuthenticate",
forward.ProxyAuthorization: "ProxyAuthorization",
forward.Te: "Te",
forward.Trailers: "Trailers",
forward.TransferEncoding: "TransferEncoding",
forward.Upgrade: "Upgrade",
"X-CustomHeader": "CustomHeader",
},
trustForwardHeader: false,
expectedHeaders: map[string]string{
"X-CustomHeader": "CustomHeader",
"X-Forwarded-Proto": "http",
"X-Forwarded-Host": "foo.bar",
"X-Forwarded-Uri": "/path?q=1",
"X-Forwarded-Method": "GET",
},
checkForUnexpectedHeaders: true,
},
}
for _, test := range testCases {
t.Run(test.name, func(t *testing.T) {
req := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
for key, value := range test.headers {
req.Header.Set(key, value)
}
if test.emptyHost {
req.Host = ""
}
forwardReq := testhelpers.MustNewRequest(http.MethodGet, "http://foo.bar/path?q=1", nil)
writeHeader(req, forwardReq, test.trustForwardHeader)
actualHeaders := forwardReq.Header
expectedHeaders := test.expectedHeaders
for key, value := range expectedHeaders {
assert.Equal(t, value, actualHeaders.Get(key))
actualHeaders.Del(key)
}
if test.checkForUnexpectedHeaders {
for key := range actualHeaders {
assert.Fail(t, "Unexpected header found", key)
}
}
})
}
}