diff --git a/oidc_provider/tests/cases/test_authorize_endpoint.py b/oidc_provider/tests/cases/test_authorize_endpoint.py index a062856..3bbf74b 100644 --- a/oidc_provider/tests/cases/test_authorize_endpoint.py +++ b/oidc_provider/tests/cases/test_authorize_endpoint.py @@ -468,6 +468,34 @@ class AuthorizationCodeFlowTestCase(TestCase, AuthorizeEndpointMixin): response = self._auth_request('get', data, is_user_authenticated=True) self.assertIn('consent_required', response['Location']) + def test_strip_prompt_login(self): + """ + Test for helper method test_strip_prompt_login. + """ + # Original paths + path0 = 'http://idp.com/?prompt=login' + path1 = 'http://idp.com/?prompt=consent login none' + path2 = ('http://idp.com/?response_type=code&client' + + '_id=112233&prompt=consent login') + path3 = ('http://idp.com/?response_type=code&client' + + '_id=112233&prompt=login none&redirect_uri' + + '=http://localhost:8000') + + self.assertNotIn('prompt', AuthorizeView.strip_prompt_login(path0)) + + self.assertIn('prompt', AuthorizeView.strip_prompt_login(path1)) + self.assertIn('consent', AuthorizeView.strip_prompt_login(path1)) + self.assertIn('none', AuthorizeView.strip_prompt_login(path1)) + self.assertNotIn('login', AuthorizeView.strip_prompt_login(path1)) + + self.assertIn('prompt', AuthorizeView.strip_prompt_login(path2)) + self.assertIn('consent', AuthorizeView.strip_prompt_login(path1)) + self.assertNotIn('login', AuthorizeView.strip_prompt_login(path2)) + + self.assertIn('prompt', AuthorizeView.strip_prompt_login(path3)) + self.assertIn('none', AuthorizeView.strip_prompt_login(path3)) + self.assertNotIn('login', AuthorizeView.strip_prompt_login(path3)) + class AuthorizationImplicitFlowTestCase(TestCase, AuthorizeEndpointMixin): """ diff --git a/oidc_provider/views.py b/oidc_provider/views.py index f84ecba..d0c8d0c 100644 --- a/oidc_provider/views.py +++ b/oidc_provider/views.py @@ -207,8 +207,10 @@ class AuthorizeView(View): """ uri = urlsplit(path) query_params = parse_qs(uri.query) - if 'login' in query_params['prompt']: - query_params['prompt'].remove('login') + prompt_list = query_params.get('prompt', '')[0].split() + if 'login' in prompt_list: + prompt_list.remove('login') + query_params['prompt'] = ' '.join(prompt_list) if not query_params['prompt']: del query_params['prompt'] uri = uri._replace(query=urlencode(query_params, doseq=True))