From 7cb5b4d54e6611dae14e8edd08cc40560919ca92 Mon Sep 17 00:00:00 2001 From: Wojciech Bartosiak Date: Tue, 1 Mar 2016 17:54:57 +0000 Subject: [PATCH] str or list or tuple for OIDC_ID_TOKEN_PROCESSING_HOOK --- docs/sections/settings.rst | 6 +- oidc_provider/lib/utils/token.py | 8 +- oidc_provider/tests/app/utils.py | 9 +++ oidc_provider/tests/test_token_endpoint.py | 92 ++++++++++++++++++++++ 4 files changed, 113 insertions(+), 2 deletions(-) diff --git a/docs/sections/settings.rst b/docs/sections/settings.rst index f3e4329..23f295b 100644 --- a/docs/sections/settings.rst +++ b/docs/sections/settings.rst @@ -95,9 +95,13 @@ Expressed in seconds. Default is ``60*10``. OIDC_IDTOKEN_PROCESSING_HOOK ============================ -OPTIONAL. ``str``. A string with the location of your function hook. +OPTIONAL. ``str`` or ``(list, tuple)``. + +A string with the location of your function hook or ``list`` or ``tuple`` with hook functions. Here you can add extra dictionary values specific for your app into id_token. +The ``list`` or ``tuple`` is useful when You want to set multiple hooks, i.e. one for permissions and second for some special field. + The function receives a ``id_token`` dictionary and ``user`` instance and returns it with additional fields. diff --git a/oidc_provider/lib/utils/token.py b/oidc_provider/lib/utils/token.py index 4708c28..e512326 100644 --- a/oidc_provider/lib/utils/token.py +++ b/oidc_provider/lib/utils/token.py @@ -44,7 +44,13 @@ def create_id_token(user, aud, nonce): if nonce: dic['nonce'] = str(nonce) - dic = settings.get('OIDC_IDTOKEN_PROCESSING_HOOK', import_str=True)(dic, user=user) + processing_hook = settings.get('OIDC_IDTOKEN_PROCESSING_HOOK') + + if isinstance(processing_hook, (list, tuple)): + for hook in processing_hook: + dic = settings.import_from_str(hook)(dic, user=user) + else: + dic = settings.import_from_str(processing_hook)(dic, user=user) return dic diff --git a/oidc_provider/tests/app/utils.py b/oidc_provider/tests/app/utils.py index 9b76233..bd3989d 100644 --- a/oidc_provider/tests/app/utils.py +++ b/oidc_provider/tests/app/utils.py @@ -115,3 +115,12 @@ def fake_idtoken_processing_hook(id_token, user): id_token['test_idtoken_processing_hook'] = FAKE_RANDOM_STRING id_token['test_idtoken_processing_hook_user_email'] = user.email return id_token + + +def fake_idtoken_processing_hook2(id_token, user): + """ + Fake function for inserting some keys into token. Testing OIDC_IDTOKEN_PROCESSING_HOOK - tuple or list as param + """ + id_token['test_idtoken_processing_hook2'] = FAKE_RANDOM_STRING + id_token['test_idtoken_processing_hook_user_email2'] = user.email + return id_token diff --git a/oidc_provider/tests/test_token_endpoint.py b/oidc_provider/tests/test_token_endpoint.py index bb9e772..b17408d 100644 --- a/oidc_provider/tests/test_token_endpoint.py +++ b/oidc_provider/tests/test_token_endpoint.py @@ -353,3 +353,95 @@ class TokenTestCase(TestCase): self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) + + @override_settings( + OIDC_IDTOKEN_PROCESSING_HOOK=( + 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook', + ) + ) + def test_additional_idtoken_processing_hook_one_element_in_tuple(self): + """ + Test custom function for setting OIDC_IDTOKEN_PROCESSING_HOOK. + """ + code = self._create_code() + + post_data = self._auth_code_post_data(code=code.code) + + response = self._post_request(post_data) + + response_dic = json.loads(response.content.decode('utf-8')) + id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + + self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) + + @override_settings( + OIDC_IDTOKEN_PROCESSING_HOOK=[ + 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook', + ] + ) + def test_additional_idtoken_processing_hook_one_element_in_list(self): + """ + Test custom function for setting OIDC_IDTOKEN_PROCESSING_HOOK. + """ + code = self._create_code() + + post_data = self._auth_code_post_data(code=code.code) + + response = self._post_request(post_data) + + response_dic = json.loads(response.content.decode('utf-8')) + id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + + self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) + + @override_settings( + OIDC_IDTOKEN_PROCESSING_HOOK=[ + 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook', + 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook2', + ] + ) + def test_additional_idtoken_processing_hook_two_elements_in_list(self): + """ + Test custom function for setting OIDC_IDTOKEN_PROCESSING_HOOK. + """ + code = self._create_code() + + post_data = self._auth_code_post_data(code=code.code) + + response = self._post_request(post_data) + + response_dic = json.loads(response.content.decode('utf-8')) + id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + + self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) + + self.assertEqual(id_token.get('test_idtoken_processing_hook2'), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email2'), self.user.email) + + @override_settings( + OIDC_IDTOKEN_PROCESSING_HOOK=( + 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook', + 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook2', + ) + ) + def test_additional_idtoken_processing_hook_two_elements_in_tuple(self): + """ + Test custom function for setting OIDC_IDTOKEN_PROCESSING_HOOK. + """ + code = self._create_code() + + post_data = self._auth_code_post_data(code=code.code) + + response = self._post_request(post_data) + + response_dic = json.loads(response.content.decode('utf-8')) + id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + + self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) + + self.assertEqual(id_token.get('test_idtoken_processing_hook2'), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email2'), self.user.email)