Enhancement: AuthorizeView's static method strip-prompt-login was moved to a new file oidc_provider/lib/utils/authorize.py in order to be more consistent with the implementation of other Views
This commit is contained in:
parent
035e7a3674
commit
eb2f272a0b
21
oidc_provider/lib/utils/authorize.py
Normal file
21
oidc_provider/lib/utils/authorize.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
try:
|
||||||
|
from urllib import urlencode
|
||||||
|
from urlparse import urlsplit, parse_qs, urlunsplit
|
||||||
|
except ImportError:
|
||||||
|
from urllib.parse import urlsplit, parse_qs, urlunsplit, urlencode
|
||||||
|
|
||||||
|
|
||||||
|
def strip_prompt_login(path):
|
||||||
|
"""
|
||||||
|
Strips 'login' from the 'prompt' query parameter.
|
||||||
|
"""
|
||||||
|
uri = urlsplit(path)
|
||||||
|
query_params = parse_qs(uri.query)
|
||||||
|
prompt_list = query_params.get('prompt', '')[0].split()
|
||||||
|
if 'login' in prompt_list:
|
||||||
|
prompt_list.remove('login')
|
||||||
|
query_params['prompt'] = ' '.join(prompt_list)
|
||||||
|
if not query_params['prompt']:
|
||||||
|
del query_params['prompt']
|
||||||
|
uri = uri._replace(query=urlencode(query_params, doseq=True))
|
||||||
|
return urlunsplit(uri)
|
|
@ -31,6 +31,7 @@ from oidc_provider.tests.app.utils import (
|
||||||
FAKE_CODE_CHALLENGE,
|
FAKE_CODE_CHALLENGE,
|
||||||
is_code_valid,
|
is_code_valid,
|
||||||
)
|
)
|
||||||
|
from oidc_provider.lib.utils.authorize import strip_prompt_login
|
||||||
from oidc_provider.views import AuthorizeView
|
from oidc_provider.views import AuthorizeView
|
||||||
from oidc_provider.lib.endpoints.authorize import AuthorizeEndpoint
|
from oidc_provider.lib.endpoints.authorize import AuthorizeEndpoint
|
||||||
|
|
||||||
|
@ -481,20 +482,20 @@ class AuthorizationCodeFlowTestCase(TestCase, AuthorizeEndpointMixin):
|
||||||
'_id=112233&prompt=login none&redirect_uri' +
|
'_id=112233&prompt=login none&redirect_uri' +
|
||||||
'=http://localhost:8000')
|
'=http://localhost:8000')
|
||||||
|
|
||||||
self.assertNotIn('prompt', AuthorizeView.strip_prompt_login(path0))
|
self.assertNotIn('prompt', strip_prompt_login(path0))
|
||||||
|
|
||||||
self.assertIn('prompt', AuthorizeView.strip_prompt_login(path1))
|
self.assertIn('prompt', strip_prompt_login(path1))
|
||||||
self.assertIn('consent', AuthorizeView.strip_prompt_login(path1))
|
self.assertIn('consent', strip_prompt_login(path1))
|
||||||
self.assertIn('none', AuthorizeView.strip_prompt_login(path1))
|
self.assertIn('none', strip_prompt_login(path1))
|
||||||
self.assertNotIn('login', AuthorizeView.strip_prompt_login(path1))
|
self.assertNotIn('login', strip_prompt_login(path1))
|
||||||
|
|
||||||
self.assertIn('prompt', AuthorizeView.strip_prompt_login(path2))
|
self.assertIn('prompt', strip_prompt_login(path2))
|
||||||
self.assertIn('consent', AuthorizeView.strip_prompt_login(path1))
|
self.assertIn('consent', strip_prompt_login(path1))
|
||||||
self.assertNotIn('login', AuthorizeView.strip_prompt_login(path2))
|
self.assertNotIn('login', strip_prompt_login(path2))
|
||||||
|
|
||||||
self.assertIn('prompt', AuthorizeView.strip_prompt_login(path3))
|
self.assertIn('prompt', strip_prompt_login(path3))
|
||||||
self.assertIn('none', AuthorizeView.strip_prompt_login(path3))
|
self.assertIn('none', strip_prompt_login(path3))
|
||||||
self.assertNotIn('login', AuthorizeView.strip_prompt_login(path3))
|
self.assertNotIn('login', strip_prompt_login(path3))
|
||||||
|
|
||||||
|
|
||||||
class AuthorizationImplicitFlowTestCase(TestCase, AuthorizeEndpointMixin):
|
class AuthorizationImplicitFlowTestCase(TestCase, AuthorizeEndpointMixin):
|
||||||
|
|
|
@ -39,6 +39,7 @@ from oidc_provider.lib.errors import (
|
||||||
TokenError,
|
TokenError,
|
||||||
UserAuthError,
|
UserAuthError,
|
||||||
TokenIntrospectionError)
|
TokenIntrospectionError)
|
||||||
|
from oidc_provider.lib.utils.authorize import strip_prompt_login
|
||||||
from oidc_provider.lib.utils.common import (
|
from oidc_provider.lib.utils.common import (
|
||||||
redirect,
|
redirect,
|
||||||
get_site_url,
|
get_site_url,
|
||||||
|
@ -84,7 +85,7 @@ class AuthorizeView(View):
|
||||||
authorize.grant_type)
|
authorize.grant_type)
|
||||||
else:
|
else:
|
||||||
django_user_logout(request)
|
django_user_logout(request)
|
||||||
next_page = self.strip_prompt_login(request.get_full_path())
|
next_page = strip_prompt_login(request.get_full_path())
|
||||||
return redirect_to_login(next_page, settings.get('OIDC_LOGIN_URL'))
|
return redirect_to_login(next_page, settings.get('OIDC_LOGIN_URL'))
|
||||||
|
|
||||||
if 'select_account' in authorize.params['prompt']:
|
if 'select_account' in authorize.params['prompt']:
|
||||||
|
@ -147,7 +148,7 @@ class AuthorizeView(View):
|
||||||
raise AuthorizeError(
|
raise AuthorizeError(
|
||||||
authorize.params['redirect_uri'], 'login_required', authorize.grant_type)
|
authorize.params['redirect_uri'], 'login_required', authorize.grant_type)
|
||||||
if 'login' in authorize.params['prompt']:
|
if 'login' in authorize.params['prompt']:
|
||||||
next_page = self.strip_prompt_login(request.get_full_path())
|
next_page = strip_prompt_login(request.get_full_path())
|
||||||
return redirect_to_login(next_page, settings.get('OIDC_LOGIN_URL'))
|
return redirect_to_login(next_page, settings.get('OIDC_LOGIN_URL'))
|
||||||
|
|
||||||
return redirect_to_login(request.get_full_path(), settings.get('OIDC_LOGIN_URL'))
|
return redirect_to_login(request.get_full_path(), settings.get('OIDC_LOGIN_URL'))
|
||||||
|
@ -200,22 +201,6 @@ class AuthorizeView(View):
|
||||||
|
|
||||||
return redirect(uri)
|
return redirect(uri)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def strip_prompt_login(path):
|
|
||||||
"""
|
|
||||||
Strips 'login' from the 'prompt' query parameter.
|
|
||||||
"""
|
|
||||||
uri = urlsplit(path)
|
|
||||||
query_params = parse_qs(uri.query)
|
|
||||||
prompt_list = query_params.get('prompt', '')[0].split()
|
|
||||||
if 'login' in prompt_list:
|
|
||||||
prompt_list.remove('login')
|
|
||||||
query_params['prompt'] = ' '.join(prompt_list)
|
|
||||||
if not query_params['prompt']:
|
|
||||||
del query_params['prompt']
|
|
||||||
uri = uri._replace(query=urlencode(query_params, doseq=True))
|
|
||||||
return urlunsplit(uri)
|
|
||||||
|
|
||||||
|
|
||||||
class TokenView(View):
|
class TokenView(View):
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
|
|
Loading…
Reference in a new issue