diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index de3c7f714..a22e50289 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -17,6 +17,7 @@ import ( "golang.org/x/oauth2" "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidcclient/nonce" @@ -42,10 +43,6 @@ const ( csrfCookieEncodingName = "csrf" ) -type IDPListGetter interface { - GetIDPList() []provider.UpstreamOIDCIdentityProvider -} - // This is the encoding side of the securecookie.Codec interface. type Encoder interface { Encode(name string, value interface{}) (string, error) @@ -53,7 +50,7 @@ type Encoder interface { func NewHandler( issuer string, - idpListGetter IDPListGetter, + idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, generateCSRF func() (csrftoken.CSRFToken, error), generatePKCE func() (pkce.Code, error), @@ -178,7 +175,7 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) { } } -func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { +func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { allUpstreamIDPs := idpListGetter.GetIDPList() if len(allUpstreamIDPs) == 0 { return nil, httperr.New( diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index e0fdbacd4..173021d84 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -23,6 +23,7 @@ import ( "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/pkce" + "go.pinniped.dev/internal/testutil" ) func TestAuthorizationEndpoint(t *testing.T) { @@ -210,7 +211,7 @@ func TestAuthorizationEndpoint(t *testing.T) { csrf = csrfValueOverride } encoded, err := happyStateEncoder.Encode("s", - expectedUpstreamStateParamFormat{ + testutil.ExpectedUpstreamStateParamFormat{ P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), N: happyNonce, C: csrf, @@ -270,7 +271,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET without a CSRF cookie", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -288,7 +289,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET with a CSRF cookie", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -306,7 +307,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using POST", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -326,7 +327,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path when downstream redirect uri matches what is configured for client except for the port number", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -348,7 +349,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream redirect uri does not match what is configured for client", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -365,7 +366,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream client does not exist", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -380,7 +381,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "response type is unsupported", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -396,7 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream scopes do not match what is configured for client", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -412,7 +413,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing response type in request", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -428,7 +429,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing client id in request", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -443,7 +444,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -459,7 +460,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -475,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -491,7 +492,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -509,7 +510,7 @@ func TestAuthorizationEndpoint(t *testing.T) { // through that part of the fosite library. name: "prompt param is not allowed to have none and another legal value at the same time", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -525,7 +526,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "OIDC validations are skipped when the openid scope was not requested", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -546,7 +547,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "state does not have enough entropy", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -562,7 +563,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding upstream state param", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -577,7 +578,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding CSRF cookie value for new cookie", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -592,7 +593,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating CSRF token", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -607,7 +608,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating nonce", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, @@ -622,7 +623,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating PKCE", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generateNonce: happyNonceGenerator, @@ -637,7 +638,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while decoding CSRF cookie", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -653,7 +654,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "no upstream providers are configured", issuer: issuer, - idpListGetter: newIDPListGetter(), // empty + idpListGetter: testutil.NewIDPListGetter(), // empty method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -663,7 +664,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "too many upstream providers are configured", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -673,7 +674,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PUT is a bad method", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodPut, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -683,7 +684,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PATCH is a bad method", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodPatch, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -693,7 +694,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "DELETE is a bad method", issuer: issuer, - idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodDelete, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -805,18 +806,6 @@ func TestAuthorizationEndpoint(t *testing.T) { }) } -// Declare a separate type from the production code to ensure that the state param's contents was serialized -// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of -// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality -// assertions about the redirect URL in this test. -type expectedUpstreamStateParamFormat struct { - P string `json:"p"` - N string `json:"n"` - C string `json:"c"` - K string `json:"k"` - V string `json:"v"` -} - type errorReturningEncoder struct { securecookie.Codec } @@ -850,13 +839,13 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL expectedQueryStateParam := expectedLocationURL.Query().Get("state") require.NotEmpty(t, expectedQueryStateParam) - var expectedDecodedStateParam expectedUpstreamStateParamFormat + var expectedDecodedStateParam testutil.ExpectedUpstreamStateParamFormat err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam) require.NoError(t, err) actualQueryStateParam := actualLocationURL.Query().Get("state") require.NotEmpty(t, actualQueryStateParam) - var actualDecodedStateParam expectedUpstreamStateParamFormat + var actualDecodedStateParam testutil.ExpectedUpstreamStateParamFormat err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam) require.NoError(t, err) @@ -884,9 +873,3 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore } require.Equal(t, expectedLocationQuery, actualLocationQuery) } - -func newIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { - idpProvider := provider.NewDynamicUpstreamIDPProvider() - idpProvider.SetIDPList(upstreamOIDCIdentityProviders) - return idpProvider -} diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 837c38c65..475008e5e 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -6,16 +6,43 @@ package callback import ( "net/http" + "path" "go.pinniped.dev/internal/httputil/httperr" + "go.pinniped.dev/internal/oidc" + "go.pinniped.dev/internal/oidc/provider" ) -func NewHandler() http.Handler { +func NewHandler( + idpListGetter oidc.IDPListGetter, +) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { if r.Method != http.MethodGet { return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) } - return nil + if r.FormValue("code") == "" { + return httperr.New(http.StatusBadRequest, "code param not found") + } + + if r.FormValue("state") == "" { + return httperr.New(http.StatusBadRequest, "state param not found") + } + + if findUpstreamIDPConfig(r, idpListGetter) == nil { + return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") + } + + return httperr.New(http.StatusBadRequest, "state param not valid") }) } + +func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider { + _, lastPathComponent := path.Split(r.URL.Path) + for _, p := range idpListGetter.GetIDPList() { + if p.Name == lastPathComponent { + return &p + } + } + return nil +} diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 2f88aa303..05635016f 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -4,18 +4,76 @@ package callback import ( + "fmt" "net/http" "net/http/httptest" + "net/url" "testing" + "github.com/gorilla/securecookie" "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/oidc/provider" + "go.pinniped.dev/internal/testutil" +) + +const ( + happyUpstreamIDPName = "upstream-idp-name" ) func TestCallbackEndpoint(t *testing.T) { + upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") + require.NoError(t, err) + otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth") + require.NoError(t, err) + + upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: "some-client-id", + AuthorizationURL: *upstreamAuthURL, + Scopes: []string{"scope1", "scope2"}, + } + + otherUpstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ + Name: "other-upstream-idp-name", + ClientID: "other-some-client-id", + AuthorizationURL: *otherUpstreamAuthURL, + Scopes: []string{"other-scope1", "other-scope2"}, + } + + var stateEncoderHashKey = []byte("fake-hash-secret") + var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES + var cookieEncoderHashKey = []byte("fake-hash-secret2") + var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES + require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey) + require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey) + + var happyStateEncoder = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey) + happyStateEncoder.SetSerializer(securecookie.JSONEncoder{}) + var happyCookieEncoder = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) + happyCookieEncoder.SetSerializer(securecookie.JSONEncoder{}) + + //happyCSRF := "test-csrf" + //happyPKCE := "test-pkce" + //happyNonce := "test-nonce" + // + //happyEncodedState, err := happyStateEncoder.Encode("s", + // testutil.ExpectedUpstreamStateParamFormat{ + // P: "todo query goes here", + // N: happyNonce, + // C: happyCSRF, + // K: happyPKCE, + // V: "1", + // }, + //) + //require.NoError(t, err) + tests := []struct { name string - method string + method string + path string + idpListGetter provider.DynamicUpstreamIDPProvider wantStatus int wantBody string @@ -27,24 +85,67 @@ func TestCallbackEndpoint(t *testing.T) { { name: "PUT method is invalid", method: http.MethodPut, + path: newRequestPath().String(), wantStatus: http.StatusMethodNotAllowed, wantBody: "Method Not Allowed: PUT (try GET)\n", }, - // TODO: POST/PATCH/DELETE is invalid - // TODO: request has body? maybe we don't need to do anything... - // TODO: code does not exist - // TODO: we got called twice with the same state and cookie...is this bad? might be ok if the client's first roundtrip failed - // TODO: we got called twice with the same state and cookie and the UpstreamOIDCProvider CRD has been deleted - // TODO: state does not exist - // TODO: invalid signature on state - // TODO: state is expired (the expiration is encoded in the state itself) - // TODO: state csrf value does not match csrf cookie - // TODO: cookie does not exist - // TODO: invalid signature on cookie - // TODO: state version does not match what we want + { + name: "POST method is invalid", + method: http.MethodPost, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: POST (try GET)\n", + }, + { + name: "PATCH method is invalid", + method: http.MethodPatch, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: PATCH (try GET)\n", + }, + { + name: "DELETE method is invalid", + method: http.MethodDelete, + path: newRequestPath().String(), + wantStatus: http.StatusMethodNotAllowed, + wantBody: "Method Not Allowed: DELETE (try GET)\n", + }, + { + name: "code param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithoutCode().String(), + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: code param not found\n", + }, + { + name: "state param was not included on request", + method: http.MethodGet, + path: newRequestPath().WithoutState().String(), + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: state param not found\n", + }, + { + name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", + method: http.MethodGet, + path: newRequestPath().WithState("this-will-not-decode").String(), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: state param not valid\n", + }, + { + name: "the UpstreamOIDCProvider CRD has been deleted", + method: http.MethodGet, + path: newRequestPath().String(), + idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: upstream provider not found\n", + }, + // TODO: csrf cookie does not exist on request + // TODO: csrf cookie value cannot be decoded (e.g. invalid signture or any other decoding problem) + // TODO: csrf value from inside state param does not match csrf cookie value + // TODO: state's internal version does not match what we want // Upstream exchange - // TODO: we can't figure out what the upstream token endpoint is (do we get this UpstreamOIDCProvider name from the path?) // TODO: network call to upstream token endpoint fails // TODO: the upstream token endpoint returns an error @@ -61,14 +162,15 @@ func TestCallbackEndpoint(t *testing.T) { // TODO: here (e.g., id jwt cannot be verified, nonce is wrong, we didn't get refresh token, we didn't get access token, we didn't get id token, access token expires too quickly) // Downstream redirect + // TODO: we grant the openid scope if it was requested, similar to what we did in auth_handler.go // TODO: cannot generate auth code // TODO: cannot persist downstream state } for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { - subject := NewHandler() - req := httptest.NewRequest(test.method, "/path-is-not-yet-tested", nil /* body not yet tested */) + subject := NewHandler(test.idpListGetter) + req := httptest.NewRequest(test.method, test.path, nil) rsp := httptest.NewRecorder() subject.ServeHTTP(rsp, req) @@ -77,3 +179,55 @@ func TestCallbackEndpoint(t *testing.T) { }) } } + +type requestPath struct { + upstreamIDPName, code, state *string +} + +func newRequestPath() *requestPath { + n := happyUpstreamIDPName + c := "1234" + s := "4321" + return &requestPath{ + upstreamIDPName: &n, + code: &c, + state: &s, + } +} + +func (r *requestPath) WithUpstreamIDPName(name string) *requestPath { + r.upstreamIDPName = &name + return r +} + +func (r *requestPath) WithCode(code string) *requestPath { + r.code = &code + return r +} + +func (r *requestPath) WithoutCode() *requestPath { + r.code = nil + return r +} + +func (r *requestPath) WithState(state string) *requestPath { + r.state = &state + return r +} + +func (r *requestPath) WithoutState() *requestPath { + r.state = nil + return r +} + +func (r *requestPath) String() string { + path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName) + params := url.Values{} + if r.code != nil { + params.Add("code", *r.code) + } + if r.state != nil { + params.Add("state", *r.state) + } + return path + params.Encode() +} diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 8d319c1d1..f0f74970f 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -7,6 +7,8 @@ package oidc import ( "github.com/ory/fosite" "github.com/ory/fosite/compose" + + "go.pinniped.dev/internal/oidc/provider" ) const ( @@ -49,3 +51,7 @@ func FositeOauth2Helper(oauthStore interface{}, hmacSecretOfLengthAtLeast32 []by compose.OAuth2PKCEFactory, ) } + +type IDPListGetter interface { + GetIDPList() []provider.UpstreamOIDCIdentityProvider +} diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 0687ba22d..f009693d0 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -30,14 +30,14 @@ type Manager struct { providerHandlers map[string]http.Handler // map of all routes for all providers nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data - idpListGetter auth.IDPListGetter // in-memory cache of upstream IDPs + idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs } // NewManager returns an empty Manager. // nextHandler will be invoked for any requests that could not be handled by this manager's providers. // dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data. // idpListGetter will be used as an in-memory cache of currently configured upstream IDPs. -func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter auth.IDPListGetter) *Manager { +func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter oidc.IDPListGetter) *Manager { return &Manager{ providerHandlers: make(map[string]http.Handler), nextHandler: nextHandler, diff --git a/internal/testutil/oidc.go b/internal/testutil/oidc.go new file mode 100644 index 000000000..0552ea95f --- /dev/null +++ b/internal/testutil/oidc.go @@ -0,0 +1,26 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package testutil + +import "go.pinniped.dev/internal/oidc/provider" + +// Test helpers for the OIDC package. + +func NewIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { + idpProvider := provider.NewDynamicUpstreamIDPProvider() + idpProvider.SetIDPList(upstreamOIDCIdentityProviders) + return idpProvider +} + +// Declare a separate type from the production code to ensure that the state param's contents was serialized +// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of +// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality +// assertions about the redirect URL in this test. +type ExpectedUpstreamStateParamFormat struct { + P string `json:"p"` + N string `json:"n"` + C string `json:"c"` + K string `json:"k"` + V string `json:"v"` +}