diff --git a/src/Lifecycle.ts b/src/Lifecycle.ts index ce7d7b5e2a..cbc1f19915 100644 --- a/src/Lifecycle.ts +++ b/src/Lifecycle.ts @@ -289,7 +289,7 @@ export async function attemptDelegatedAuthLogin( */ async function attemptOidcNativeLogin(queryParams: QueryDict): Promise { try { - const { accessToken, refreshToken, homeserverUrl, identityServerUrl, idTokenClaims, clientId, issuer } = + const { accessToken, refreshToken, homeserverUrl, identityServerUrl, idToken, clientId, issuer } = await completeOidcLogin(queryParams); const { @@ -311,7 +311,7 @@ async function attemptOidcNativeLogin(queryParams: QueryDict): Promise logger.debug("Logged in via OIDC native flow"); await onSuccessfulDelegatedAuthLogin(credentials); // this needs to happen after success handler which clears storages - persistOidcAuthenticatedSettings(clientId, issuer, idTokenClaims); + persistOidcAuthenticatedSettings(clientId, issuer, idToken); return true; } catch (error) { logger.error("Failed to login via OIDC", error); diff --git a/src/stores/oidc/OidcClientStore.ts b/src/stores/oidc/OidcClientStore.ts index 57fe1adcd1..04328dfc94 100644 --- a/src/stores/oidc/OidcClientStore.ts +++ b/src/stores/oidc/OidcClientStore.ts @@ -18,7 +18,11 @@ import { MatrixClient, discoverAndValidateOIDCIssuerWellKnown } from "matrix-js- import { logger } from "matrix-js-sdk/src/logger"; import { OidcClient } from "oidc-client-ts"; -import { getStoredOidcTokenIssuer, getStoredOidcClientId } from "../../utils/oidc/persistOidcSettings"; +import { + getStoredOidcTokenIssuer, + getStoredOidcClientId, + getStoredOidcIdToken, +} from "../../utils/oidc/persistOidcSettings"; import PlatformPeg from "../../PlatformPeg"; /** @@ -58,7 +62,7 @@ export class OidcClientStore { const { accountManagementEndpoint, metadata } = await discoverAndValidateOIDCIssuerWellKnown( authIssuer.issuer, ); - this._accountManagementEndpoint = accountManagementEndpoint ?? metadata.issuer; + this.setAccountManagementEndpoint(accountManagementEndpoint, metadata.issuer); } catch (e) { console.log("Auth issuer not found", e); } @@ -72,6 +76,16 @@ export class OidcClientStore { return !!this.authenticatedIssuer; } + private setAccountManagementEndpoint(endpoint: string | undefined, issuer: string): void { + // if no account endpoint is configured default to the issuer + const url = new URL(endpoint ?? issuer); + const idToken = getStoredOidcIdToken(); + if (idToken) { + url.searchParams.set("id_token_hint", idToken); + } + this._accountManagementEndpoint = url.toString(); + } + public get accountManagementEndpoint(): string | undefined { return this._accountManagementEndpoint; } @@ -150,8 +164,7 @@ export class OidcClientStore { const { accountManagementEndpoint, metadata, signingKeys } = await discoverAndValidateOIDCIssuerWellKnown( this.authenticatedIssuer, ); - // if no account endpoint is configured default to the issuer - this._accountManagementEndpoint = accountManagementEndpoint ?? metadata.issuer; + this.setAccountManagementEndpoint(accountManagementEndpoint, metadata.issuer); this.oidcClient = new OidcClient({ ...metadata, authority: metadata.issuer, diff --git a/src/utils/oidc/authorize.ts b/src/utils/oidc/authorize.ts index 8bbdd9894a..3cb4147680 100644 --- a/src/utils/oidc/authorize.ts +++ b/src/utils/oidc/authorize.ts @@ -86,6 +86,8 @@ type CompleteOidcLoginResponse = { accessToken: string; // refreshToken gained from OIDC token issuer, when falsy token cannot be refreshed refreshToken?: string; + // idToken gained from OIDC token issuer + idToken: string; // this client's id as registered with the OIDC issuer clientId: string; // issuer used during authentication @@ -109,6 +111,7 @@ export const completeOidcLogin = async (queryParams: QueryDict): Promise { +export const persistOidcAuthenticatedSettings = (clientId: string, issuer: string, idToken: string): void => { localStorage.setItem(clientIdStorageKey, clientId); localStorage.setItem(tokenIssuerStorageKey, issuer); - localStorage.setItem(idTokenClaimsStorageKey, JSON.stringify(idTokenClaims)); + localStorage.setItem(idTokenStorageKey, idToken); }; /** @@ -59,13 +62,26 @@ export const getStoredOidcClientId = (): string => { }; /** - * Retrieve stored id token claims from local storage - * @returns idtokenclaims or undefined + * Retrieve stored id token claims from stored id token or local storage + * @returns idTokenClaims or undefined */ export const getStoredOidcIdTokenClaims = (): IdTokenClaims | undefined => { + const idToken = getStoredOidcIdToken(); + if (idToken) { + return decodeIdToken(idToken); + } + const idTokenClaims = localStorage.getItem(idTokenClaimsStorageKey); if (!idTokenClaims) { return; } return JSON.parse(idTokenClaims) as IdTokenClaims; }; + +/** + * Retrieve stored id token from local storage + * @returns idToken or undefined + */ +export const getStoredOidcIdToken = (): string | undefined => { + return localStorage.getItem(idTokenStorageKey) ?? undefined; +}; diff --git a/test/Lifecycle-test.ts b/test/Lifecycle-test.ts index 271cae8b79..4a6122f470 100644 --- a/test/Lifecycle-test.ts +++ b/test/Lifecycle-test.ts @@ -657,13 +657,8 @@ describe("Lifecycle", () => { const issuer = "https://auth.com/"; const delegatedAuthConfig = makeDelegatedAuthConfig(issuer); - const idTokenClaims = { - aud: "123", - iss: issuer, - sub: "123", - exp: 123, - iat: 456, - }; + const idToken = + "eyJhbGciOiJSUzI1NiIsImtpZCI6Imh4ZEhXb0Y5bW4ifQ.eyJzdWIiOiIwMUhQUDJGU0JZREU5UDlFTU04REQ3V1pIUiIsImlzcyI6Imh0dHBzOi8vYXV0aC1vaWRjLmxhYi5lbGVtZW50LmRldi8iLCJpYXQiOjE3MTUwNzE5ODUsImF1dGhfdGltZSI6MTcwNzk5MDMxMiwiY19oYXNoIjoidGt5R1RhUjU5aTk3YXoyTU4yMGdidyIsImV4cCI6MTcxNTA3NTU4NSwibm9uY2UiOiJxaXhwM0hFMmVaIiwiYXVkIjoiMDFIWDk0Mlg3QTg3REgxRUs2UDRaNjI4WEciLCJhdF9oYXNoIjoiNFlFUjdPRlVKTmRTeEVHV2hJUDlnZyJ9.HxODneXvSTfWB5Vc4cf7b8GiN2gdwUuTiyVqZuupWske2HkZiJZUt5Lsxg9BW3gz28POkE0Ln17snlkmy02B_AD3DQxKOOxQCzIIARHdfFvZxgGWsMdFcVQZDW7rtXcqgj-SpVaUQ_8acsgxSrz_DF2o0O4tto0PT6wVUiw8KlBmgWTscWPeAWe-39T-8EiQ8Wi16h6oSPcz2NzOQ7eOM_S9fDkOorgcBkRGLl1nrahrPSdWJSGAeruk5mX4YxN714YThFDyEA2t9YmKpjaiSQ2tT-Xkd7tgsZqeirNs2ni9mIiFX3bRX6t2AhUNzA7MaX9ZyizKGa6go3BESO_oDg"; beforeAll(() => { fetchMock.get( @@ -682,7 +677,7 @@ describe("Lifecycle", () => { beforeEach(() => { initSessionStorageMock(); // set values in session storage as they would be after a successful oidc authentication - persistOidcAuthenticatedSettings(clientId, issuer, idTokenClaims); + persistOidcAuthenticatedSettings(clientId, issuer, idToken); }); it("should not try to create a token refresher without a refresh token", async () => { @@ -712,7 +707,7 @@ describe("Lifecycle", () => { clientId, // @ts-ignore set undefined issuer undefined, - idTokenClaims, + idToken, ); await setLoggedIn({ ...credentials, @@ -744,7 +739,7 @@ describe("Lifecycle", () => { it("should create a client when creating token refresher fails", async () => { // set invalid value in session storage for a malformed oidc authentication - persistOidcAuthenticatedSettings(null as any, issuer, idTokenClaims); + persistOidcAuthenticatedSettings(null as any, issuer, idToken); // succeeded expect( diff --git a/test/components/structures/MatrixChat-test.tsx b/test/components/structures/MatrixChat-test.tsx index d112cebe81..00a44c44bd 100644 --- a/test/components/structures/MatrixChat-test.tsx +++ b/test/components/structures/MatrixChat-test.tsx @@ -284,6 +284,7 @@ describe("", () => { const tokenResponse: BearerTokenResponse = { access_token: accessToken, refresh_token: "def456", + id_token: "ghi789", scope: "test", token_type: "Bearer", expires_at: 12345, diff --git a/test/utils/oidc/authorize-test.ts b/test/utils/oidc/authorize-test.ts index a323fc95a1..4ee13b55ed 100644 --- a/test/utils/oidc/authorize-test.ts +++ b/test/utils/oidc/authorize-test.ts @@ -115,6 +115,7 @@ describe("OIDC authorization", () => { const tokenResponse: BearerTokenResponse = { access_token: "abc123", refresh_token: "def456", + id_token: "ghi789", scope: "test", token_type: "Bearer", expires_at: 12345, @@ -163,6 +164,7 @@ describe("OIDC authorization", () => { identityServerUrl, issuer, clientId, + idToken: "ghi789", idTokenClaims: result.idTokenClaims, }); }); diff --git a/test/utils/oidc/persistOidcSettings-test.ts b/test/utils/oidc/persistOidcSettings-test.ts index 3585c1576e..2904f38c69 100644 --- a/test/utils/oidc/persistOidcSettings-test.ts +++ b/test/utils/oidc/persistOidcSettings-test.ts @@ -15,14 +15,19 @@ limitations under the License. */ import { IdTokenClaims } from "oidc-client-ts"; +import { decodeIdToken } from "matrix-js-sdk/src/matrix"; +import { mocked } from "jest-mock"; import { getStoredOidcClientId, + getStoredOidcIdToken, getStoredOidcIdTokenClaims, getStoredOidcTokenIssuer, persistOidcAuthenticatedSettings, } from "../../../src/utils/oidc/persistOidcSettings"; +jest.mock("matrix-js-sdk/src/matrix"); + describe("persist OIDC settings", () => { jest.spyOn(Storage.prototype, "getItem"); jest.spyOn(Storage.prototype, "setItem"); @@ -33,6 +38,7 @@ describe("persist OIDC settings", () => { const clientId = "test-client-id"; const issuer = "https://auth.org/"; + const idToken = "test-id-token"; const idTokenClaims: IdTokenClaims = { // audience is this client aud: "123", @@ -44,45 +50,65 @@ describe("persist OIDC settings", () => { }; describe("persistOidcAuthenticatedSettings", () => { - it("should set clientId and issuer in session storage", () => { - persistOidcAuthenticatedSettings(clientId, issuer, idTokenClaims); + it("should set clientId and issuer in localStorage", () => { + persistOidcAuthenticatedSettings(clientId, issuer, idToken); expect(localStorage.setItem).toHaveBeenCalledWith("mx_oidc_client_id", clientId); expect(localStorage.setItem).toHaveBeenCalledWith("mx_oidc_token_issuer", issuer); - expect(localStorage.setItem).toHaveBeenCalledWith("mx_oidc_id_token_claims", JSON.stringify(idTokenClaims)); + expect(localStorage.setItem).toHaveBeenCalledWith("mx_oidc_id_token", idToken); }); }); describe("getStoredOidcTokenIssuer()", () => { - it("should return issuer from session storage", () => { + it("should return issuer from localStorage", () => { localStorage.setItem("mx_oidc_token_issuer", issuer); expect(getStoredOidcTokenIssuer()).toEqual(issuer); expect(localStorage.getItem).toHaveBeenCalledWith("mx_oidc_token_issuer"); }); - it("should return undefined when no issuer in session storage", () => { + it("should return undefined when no issuer in localStorage", () => { expect(getStoredOidcTokenIssuer()).toBeUndefined(); }); }); describe("getStoredOidcClientId()", () => { - it("should return clientId from session storage", () => { + it("should return clientId from localStorage", () => { localStorage.setItem("mx_oidc_client_id", clientId); expect(getStoredOidcClientId()).toEqual(clientId); expect(localStorage.getItem).toHaveBeenCalledWith("mx_oidc_client_id"); }); - it("should throw when no clientId in session storage", () => { + it("should throw when no clientId in localStorage", () => { expect(() => getStoredOidcClientId()).toThrow("Oidc client id not found in storage"); }); }); + describe("getStoredOidcIdToken()", () => { + it("should return token from localStorage", () => { + localStorage.setItem("mx_oidc_id_token", idToken); + expect(getStoredOidcIdToken()).toEqual(idToken); + expect(localStorage.getItem).toHaveBeenCalledWith("mx_oidc_id_token"); + }); + + it("should return undefined when no token in localStorage", () => { + expect(getStoredOidcIdToken()).toBeUndefined(); + }); + }); + describe("getStoredOidcIdTokenClaims()", () => { - it("should return issuer from session storage", () => { + it("should return claims from localStorage", () => { localStorage.setItem("mx_oidc_id_token_claims", JSON.stringify(idTokenClaims)); expect(getStoredOidcIdTokenClaims()).toEqual(idTokenClaims); expect(localStorage.getItem).toHaveBeenCalledWith("mx_oidc_id_token_claims"); }); - it("should return undefined when no issuer in session storage", () => { + it("should return claims extracted from id_token in localStorage", () => { + localStorage.setItem("mx_oidc_id_token", idToken); + mocked(decodeIdToken).mockReturnValue(idTokenClaims); + expect(getStoredOidcIdTokenClaims()).toEqual(idTokenClaims); + expect(decodeIdToken).toHaveBeenCalledWith(idToken); + expect(localStorage.getItem).toHaveBeenCalledWith("mx_oidc_id_token_claims"); + }); + + it("should return undefined when no claims in localStorage", () => { expect(getStoredOidcIdTokenClaims()).toBeUndefined(); }); });