diff --git a/docs/sections/settings.rst b/docs/sections/settings.rst index f6f3131..fbed9fc 100644 --- a/docs/sections/settings.rst +++ b/docs/sections/settings.rst @@ -81,12 +81,28 @@ Here you can add extra dictionary values specific for your app into id_token. The ``list`` or ``tuple`` is useful when you want to set multiple hooks, i.e. one for permissions and second for some special field. -The function receives a ``id_token`` dictionary and ``user`` instance -and returns it with additional fields. +The hook function receives following arguments: + + * ``id_token``: the ID token dictionary which contains at least the + basic claims (``iss``, ``sub``, ``aud``, ``exp``, ``iat``, + ``auth_time``), but may also contain other claims. If several + processing hooks are configured, then the claims of the previous hook + are also present in the passed dictionary. + * ``user``: User object of the authenticating user, + * ``scope``: the authorized scopes as list of strings or None, + * ``token``: the Token object created for the authentication request, and + * ``request``: Django request object of the authentication request. + +The hook function should return the modified ID token as dictionary. + +.. note:: + It is a good idea to add ``**kwargs`` to the hook function argument + list so that the hook function will work even if new arguments are + added to the hook function call signature. Default is:: - def default_idtoken_processing_hook(id_token, user): + def default_idtoken_processing_hook(id_token, user, scope, token, request, **kwargs): return id_token diff --git a/oidc_provider/lib/utils/common.py b/oidc_provider/lib/utils/common.py index b667b38..0ecc95f 100644 --- a/oidc_provider/lib/utils/common.py +++ b/oidc_provider/lib/utils/common.py @@ -107,9 +107,10 @@ def default_after_end_session_hook( return None -def default_idtoken_processing_hook(id_token, user): +def default_idtoken_processing_hook( + id_token, user, scope, token, request, **kwargs): """ - Hook to perform some additional actions ti `id_token` dictionary just before serialization. + Hook for modifying `id_token` just before serialization. :param id_token: dictionary contains values that going to be serialized into `id_token` :type id_token: dict @@ -117,8 +118,17 @@ def default_idtoken_processing_hook(id_token, user): :param user: user for whom id_token is generated :type user: User + :param scope: scope for the token + :type scope: list[str]|None + + :param token: the Token object created for the authentication request + :type token: oidc_provider.models.Token + + :param request: the request initiating this ID token processing + :type request: django.http.HttpRequest + :return: custom modified dictionary of values for `id_token` - :rtype dict + :rtype: dict """ return id_token @@ -144,10 +154,12 @@ def get_browser_state_or_default(request): def run_processing_hook(subject, hook_settings_name, **kwargs): - processing_hook = settings.get(hook_settings_name) - if isinstance(processing_hook, (list, tuple)): - for hook in processing_hook: - subject = settings.import_from_str(hook)(subject, **kwargs) - else: - subject = settings.import_from_str(processing_hook)(subject, **kwargs) + processing_hooks = settings.get(hook_settings_name) + if not isinstance(processing_hooks, (list, tuple)): + processing_hooks = [processing_hooks] + + for hook_string in processing_hooks: + hook = settings.import_from_str(hook_string) + subject = hook(subject, **kwargs) + return subject diff --git a/oidc_provider/lib/utils/token.py b/oidc_provider/lib/utils/token.py index a413bc8..089ce45 100644 --- a/oidc_provider/lib/utils/token.py +++ b/oidc_provider/lib/utils/token.py @@ -62,7 +62,9 @@ def create_id_token(token, user, aud, nonce='', at_hash='', request=None, scope= claims = StandardScopeClaims(token).create_response_dic() dic.update(claims) - dic = run_processing_hook(dic, 'OIDC_IDTOKEN_PROCESSING_HOOK', user=user) + dic = run_processing_hook( + dic, 'OIDC_IDTOKEN_PROCESSING_HOOK', + user=user, scope=scope, token=token, request=request) return dic diff --git a/oidc_provider/tests/app/utils.py b/oidc_provider/tests/app/utils.py index 457757d..63ddc8d 100644 --- a/oidc_provider/tests/app/utils.py +++ b/oidc_provider/tests/app/utils.py @@ -113,7 +113,7 @@ def fake_sub_generator(user): return user.email -def fake_idtoken_processing_hook(id_token, user): +def fake_idtoken_processing_hook(id_token, user, **kwargs): """ Fake function for inserting some keys into token. Testing OIDC_IDTOKEN_PROCESSING_HOOK. """ @@ -122,7 +122,7 @@ def fake_idtoken_processing_hook(id_token, user): return id_token -def fake_idtoken_processing_hook2(id_token, user): +def fake_idtoken_processing_hook2(id_token, user, **kwargs): """ Fake function for inserting some keys into token. Testing OIDC_IDTOKEN_PROCESSING_HOOK - tuple or list as param @@ -132,6 +132,25 @@ def fake_idtoken_processing_hook2(id_token, user): return id_token +def fake_idtoken_processing_hook3(id_token, user, scope=None, **kwargs): + """ + Fake function for checking scope is passed to processing hook. + """ + id_token['scope_passed_to_processing_hook'] = scope + return id_token + + +def fake_idtoken_processing_hook4(id_token, user, **kwargs): + """ + Fake function for checking kwargs passed to processing hook. + """ + id_token['kwargs_passed_to_processing_hook'] = { + key: repr(value) + for (key, value) in kwargs.items() + } + return id_token + + def fake_introspection_processing_hook(response_dict, client, id_token): response_dict['test_introspection_processing_hook'] = FAKE_RANDOM_STRING return response_dict diff --git a/oidc_provider/tests/cases/test_token_endpoint.py b/oidc_provider/tests/cases/test_token_endpoint.py index 0646046..fcedd35 100644 --- a/oidc_provider/tests/cases/test_token_endpoint.py +++ b/oidc_provider/tests/cases/test_token_endpoint.py @@ -728,6 +728,47 @@ class TokenTestCase(TestCase): self.assertEqual(id_token.get('test_idtoken_processing_hook2'), FAKE_RANDOM_STRING) self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email2'), self.user.email) + @override_settings( + OIDC_IDTOKEN_PROCESSING_HOOK=( + 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook3')) + def test_additional_idtoken_processing_hook_scope_param(self): + """ + Test scope parameter is passed to OIDC_IDTOKEN_PROCESSING_HOOK. + """ + id_token = self._request_id_token_with_scope( + ['openid', 'email', 'profile', 'dummy']) + self.assertEqual( + id_token.get('scope_passed_to_processing_hook'), + ['openid', 'email', 'profile', 'dummy']) + + @override_settings( + OIDC_IDTOKEN_PROCESSING_HOOK=( + 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook4')) + def test_additional_idtoken_processing_hook_kwargs(self): + """ + Test correct kwargs are passed to OIDC_IDTOKEN_PROCESSING_HOOK. + """ + id_token = self._request_id_token_with_scope(['openid', 'profile']) + kwargs_passed = id_token.get('kwargs_passed_to_processing_hook') + assert kwargs_passed + self.assertEqual(kwargs_passed.get('scope'), + repr([u'openid', u'profile'])) + self.assertEqual(kwargs_passed.get('token'), + '') + self.assertEqual(kwargs_passed.get('request'), + "") + + def _request_id_token_with_scope(self, scope): + code = self._create_code(scope) + + post_data = self._auth_code_post_data(code=code.code) + + response = self._post_request(post_data) + + response_dic = json.loads(response.content.decode('utf-8')) + id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + return id_token + def test_pkce_parameters(self): """ Test Proof Key for Code Exchange by OAuth Public Clients.