Files
pinniped/internal/oidc/callback/callback_handler.go
Ryan Richard 83101eefce callback_handler.go: start to test upstream token corner cases
Also refactor to get rid of duplicate test structs.

Also also don't default groups ID token claim because there is no standard one.

Also also also add some logging that will hopefully help us in debugging in the
future.

Signed-off-by: Andrew Keesler <akeesler@vmware.com>
2020-11-19 14:19:01 -05:00

260 lines
7.8 KiB
Go

// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package callback provides a handler for the OIDC callback endpoint.
package callback
import (
"net/http"
"net/url"
"path"
"time"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/plog"
)
const (
// defaultUpstreamUsernameClaim is the upstream OIDC ID token claim from which the username is
// read when the upstream IDP configuration does not name a username claim, i.e. when its
// GetUsernameClaim() returns the empty string.
defaultUpstreamUsernameClaim = "sub"
// downstreamGroupsClaim is the key under which the user's groups are stored in the extra
// claims of the downstream OIDC ID token.
downstreamGroupsClaim = "groups"
)
// NewHandler returns the http.Handler for the OIDC callback endpoint.
//
// For each request the handler:
//  1. validates the request via validateRequest and obtains the decoded state param,
//  2. finds the upstream IDP whose name matches the last component of the request path
//     (responds 422 when there is none),
//  3. parses the downstream auth params saved in the state and uses them to rebuild the
//     original authorize request for oauthHelper.NewAuthorizeRequest (responds 400 on failure),
//  4. calls ExchangeAuthcodeAndValidateTokens on the upstream IDP with the "code" param and the
//     PKCE code and nonce saved in the state (responds 502 on failure),
//  5. derives the username and groups from the upstream ID token claims, and
//  6. creates and writes the downstream authorize response using oauthHelper
//     (responds 500 when the response cannot be created).
//
// downstreamIssuer is used as the issuer of the downstream session. stateDecoder decodes the
// state param and cookieDecoder decodes the CSRF cookie.
func NewHandler(
	downstreamIssuer string,
	idpListGetter oidc.IDPListGetter,
	oauthHelper fosite.OAuth2Provider,
	stateDecoder, cookieDecoder oidc.Decoder,
) http.Handler {
	return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
		state, err := validateRequest(r, stateDecoder, cookieDecoder)
		if err != nil {
			return err
		}

		upstreamIDPConfig := findUpstreamIDPConfig(r, idpListGetter)
		if upstreamIDPConfig == nil {
			plog.Warning("upstream provider not found")
			return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
		}

		downstreamAuthParams, err := url.ParseQuery(state.AuthParams)
		if err != nil {
			plog.Error("error reading state downstream auth params", err)
			return httperr.New(http.StatusBadRequest, "error reading state downstream auth params")
		}

		// Recreate enough of the original authorize request so we can pass it to NewAuthorizeRequest().
		reconstitutedAuthRequest := &http.Request{Form: downstreamAuthParams}
		authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), reconstitutedAuthRequest)
		if err != nil {
			plog.Error("error using state downstream auth params", err)
			return httperr.New(http.StatusBadRequest, "error using state downstream auth params")
		}

		// Grant the openid scope only if it was requested.
		grantOpenIDScopeIfRequested(authorizeRequester)

		// Only the ID token claims are needed here; the first return value is not used.
		_, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens(
			r.Context(),
			authcode(r),
			state.PKCECode,
			state.Nonce,
		)
		if err != nil {
			plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName())
			return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
		}

		username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims)
		if err != nil {
			return err
		}

		groups := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims)

		openIDSession := makeDownstreamSession(downstreamIssuer, downstreamAuthParams.Get("client_id"), username, groups)
		authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
		if err != nil {
			// Report the failure like every other error path in this handler instead of
			// panicking inside the request goroutine.
			plog.WarningErr("error creating downstream authorize response", err, "upstreamName", upstreamIDPConfig.GetName())
			return httperr.Wrap(http.StatusInternalServerError, "error creating downstream authorize response", err)
		}

		oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder)

		return nil
	})
}
// authcode extracts the value of the "code" parameter from the request, as reported by
// r.FormValue. The result is the empty string when the parameter is absent or empty.
func authcode(r *http.Request) string {
	code := r.FormValue("code")
	return code
}
// validateRequest checks that the incoming callback request is acceptable and returns the
// decoded state param on success.
//
// The checks are performed in this order, and the first failure is returned as an error:
//   - the method must be GET (405 otherwise),
//   - the CSRF cookie must be present and decodable with cookieDecoder (see readCSRFCookie),
//   - the "code" param must be non-empty (400 otherwise),
//   - the "state" param must be non-empty (400 otherwise),
//   - the "state" param must be decodable with stateDecoder (see readState),
//   - the CSRF token inside the state must equal the one from the cookie (403 otherwise).
func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) {
	if r.Method != http.MethodGet {
		return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method)
	}

	csrfValue, err := readCSRFCookie(r, cookieDecoder)
	if err != nil {
		plog.InfoErr("error reading CSRF cookie", err)
		return nil, err
	}

	if authcode(r) == "" {
		plog.Info("code param not found")
		return nil, httperr.New(http.StatusBadRequest, "code param not found")
	}

	if r.FormValue("state") == "" {
		plog.Info("state param not found")
		return nil, httperr.New(http.StatusBadRequest, "state param not found")
	}

	state, err := readState(r, stateDecoder)
	if err != nil {
		plog.InfoErr("error reading state", err)
		return nil, err
	}

	if state.CSRFToken != csrfValue {
		// err is always nil at this point, so there is no underlying cause to log or wrap.
		plog.Info("CSRF value does not match")
		return nil, httperr.New(http.StatusForbidden, "CSRF value does not match")
	}

	return state, nil
}
// findUpstreamIDPConfig looks up the upstream IDP whose name equals the final component of the
// request's URL path (the part after the last slash). It returns nil when no IDP in the list
// provided by idpListGetter has that name.
func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) provider.UpstreamOIDCIdentityProviderI {
	_, idpName := path.Split(r.URL.Path)

	candidates := idpListGetter.GetIDPList()
	for _, candidate := range candidates {
		if candidate.GetName() != idpName {
			continue
		}
		return candidate
	}

	return nil
}
// readCSRFCookie reads the cookie named oidc.CSRFCookieName from the request and decodes its
// value into a CSRF token using cookieDecoder. Both a missing cookie and a cookie that cannot
// be decoded are reported as 403 errors that wrap the underlying cause.
func readCSRFCookie(r *http.Request, cookieDecoder oidc.Decoder) (csrftoken.CSRFToken, error) {
	cookie, err := r.Cookie(oidc.CSRFCookieName)
	if err != nil {
		// The request did not carry the cookie at all.
		return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err)
	}

	var token csrftoken.CSRFToken
	if decodeErr := cookieDecoder.Decode(oidc.CSRFCookieEncodingName, cookie.Value, &token); decodeErr != nil {
		return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", decodeErr)
	}

	return token, nil
}
// readState decodes the request's "state" param with stateDecoder and checks its format
// version. A value that cannot be decoded yields a 400 error; a decoded value whose
// FormatVersion differs from oidc.UpstreamStateParamFormatVersion yields a 422 error.
func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) {
	decoded := new(oidc.UpstreamStateParamData)
	encoded := r.FormValue("state")

	err := stateDecoder.Decode(oidc.UpstreamStateParamEncodingName, encoded, decoded)
	if err != nil {
		return nil, httperr.New(http.StatusBadRequest, "error reading state")
	}

	if decoded.FormatVersion == oidc.UpstreamStateParamFormatVersion {
		return decoded, nil
	}

	return nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid")
}
// grantOpenIDScopeIfRequested walks the scopes requested on authorizeRequester and grants
// every entry that is exactly "openid". All other requested scopes are left ungranted.
func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
	requestedScopes := authorizeRequester.GetRequestedScopes()
	for _, requested := range requestedScopes {
		if requested != "openid" {
			continue
		}
		authorizeRequester.GrantScope(requested)
	}
}
// getUsernameFromUpstreamIDToken returns the username found in the upstream ID token claims.
//
// The claim to read is the one configured on upstreamIDPConfig, falling back to
// defaultUpstreamUsernameClaim when none is configured. A 422 error is returned, and a warning
// is logged, when the claim is absent from idTokenClaims or when its value is not a string.
func getUsernameFromUpstreamIDToken(
	upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
	idTokenClaims map[string]interface{},
) (string, error) {
	usernameClaim := upstreamIDPConfig.GetUsernameClaim()
	if usernameClaim == "" {
		// TODO: if we use the default "sub" claim, maybe we should create the username with the issuer
		// since the spec says the "sub" claim is only unique per issuer.
		usernameClaim = defaultUpstreamUsernameClaim
	}

	usernameAsInterface, ok := idTokenClaims[usernameClaim]
	if !ok {
		plog.Warning(
			"no username claim in upstream ID token",
			"upstreamName", upstreamIDPConfig.GetName(),
			"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
			"usernameClaim", usernameClaim,
		)
		return "", httperr.New(http.StatusUnprocessableEntity, "no username claim in upstream ID token")
	}

	username, ok := usernameAsInterface.(string)
	if !ok {
		// The token content comes from the upstream IDP, so an unexpected value type is a
		// request-level failure and must not panic.
		plog.Warning(
			"username claim in upstream ID token has invalid format",
			"upstreamName", upstreamIDPConfig.GetName(),
			"configuredUsernameClaim", upstreamIDPConfig.GetUsernameClaim(),
			"usernameClaim", usernameClaim,
		)
		return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
	}

	return username, nil
}
// getGroupsFromUpstreamIDToken returns the groups found in the upstream ID token claims.
//
// It returns nil when upstreamIDPConfig has no groups claim configured. Because this function
// has no error return, it also returns nil (after logging a warning) when the configured
// claim is absent from idTokenClaims or when its value is not a list of strings; it never
// panics on token content.
//
// The claim value may be a []string or a []interface{} whose elements are all strings. The
// latter is the form a JSON array takes when encoding/json decodes it into an interface{}
// value (NOTE(review): presumably how idTokenClaims is produced; confirm against the
// upstream IDP implementation).
func getGroupsFromUpstreamIDToken(
	upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI,
	idTokenClaims map[string]interface{},
) []string {
	groupsClaim := upstreamIDPConfig.GetGroupsClaim()
	if groupsClaim == "" {
		return nil
	}

	groupsAsInterface, ok := idTokenClaims[groupsClaim]
	if !ok {
		plog.Warning(
			"no groups claim in upstream ID token",
			"upstreamName", upstreamIDPConfig.GetName(),
			"configuredGroupsClaim", groupsClaim,
		)
		return nil
	}

	switch groups := groupsAsInterface.(type) {
	case []string:
		return groups
	case []interface{}:
		converted := make([]string, 0, len(groups))
		for _, element := range groups {
			group, isString := element.(string)
			if !isString {
				// One bad element invalidates the whole claim; do not return a partial list.
				converted = nil
				break
			}
			converted = append(converted, group)
		}
		if converted != nil {
			return converted
		}
	}

	plog.Warning(
		"groups claim in upstream ID token has invalid format",
		"upstreamName", upstreamIDPConfig.GetName(),
		"configuredGroupsClaim", groupsClaim,
	)
	return nil
}
// makeDownstreamSession builds the downstream OpenID session for the given user.
//
// The ID token claims carry issuer as the issuer, username as the subject and clientID as the
// only audience. IssuedAt, RequestedAt and AuthTime are all set to the current time and
// ExpiresAt to 30 minutes after it. When groups is non-nil it is stored in the extra claims
// under downstreamGroupsClaim; otherwise no extra claims are set.
func makeDownstreamSession(issuer, clientID, username string, groups []string) *openid.DefaultSession {
	issuedAt := time.Now()

	claims := &jwt.IDTokenClaims{
		Issuer:      issuer,
		Subject:     username,
		Audience:    []string{clientID},
		ExpiresAt:   issuedAt.Add(30 * time.Minute), // TODO use the right value here
		IssuedAt:    issuedAt,                       // TODO test this
		RequestedAt: issuedAt,                       // TODO test this
		AuthTime:    issuedAt,                       // TODO test this
	}

	if groups != nil {
		extra := make(map[string]interface{}, 1)
		extra[downstreamGroupsClaim] = groups
		claims.Extra = extra
	}

	return &openid.DefaultSession{Claims: claims}
}