Remove Params() object from endpoints classes.

This commit is contained in:
Ignacio Fiorentino 2016-09-09 14:49:41 -03:00
parent 8a63c83514
commit c14d2f055a
4 changed files with 94 additions and 114 deletions

View file

@ -14,7 +14,6 @@ from oidc_provider.lib.errors import (
ClientIdError, ClientIdError,
RedirectUriError, RedirectUriError,
) )
from oidc_provider.lib.utils.params import Params
from oidc_provider.lib.utils.token import ( from oidc_provider.lib.utils.token import (
create_code, create_code,
create_id_token, create_id_token,
@ -35,22 +34,22 @@ class AuthorizeEndpoint(object):
def __init__(self, request): def __init__(self, request):
self.request = request self.request = request
self.params = Params() self.params = {}
self._extract_params() self._extract_params()
# Determine which flow to use. # Determine which flow to use.
if self.params.response_type in ['code']: if self.params['response_type'] in ['code']:
self.grant_type = 'authorization_code' self.grant_type = 'authorization_code'
elif self.params.response_type in ['id_token', 'id_token token', 'token']: elif self.params['response_type'] in ['id_token', 'id_token token', 'token']:
self.grant_type = 'implicit' self.grant_type = 'implicit'
elif self.params.response_type in ['code token', 'code id_token', 'code id_token token']: elif self.params['response_type'] in ['code token', 'code id_token', 'code id_token token']:
self.grant_type = 'hybrid' self.grant_type = 'hybrid'
else: else:
self.grant_type = None self.grant_type = None
# Determine if it's an OpenID Authentication request (or OAuth2). # Determine if it's an OpenID Authentication request (or OAuth2).
self.is_authentication = 'openid' in self.params.scope self.is_authentication = 'openid' in self.params['scope']
def _extract_params(self): def _extract_params(self):
""" """
@ -64,58 +63,54 @@ class AuthorizeEndpoint(object):
query_dict = (self.request.POST if self.request.method == 'POST' query_dict = (self.request.POST if self.request.method == 'POST'
else self.request.GET) else self.request.GET)
self.params.client_id = query_dict.get('client_id', '') self.params['client_id'] = query_dict.get('client_id', '')
self.params.redirect_uri = query_dict.get('redirect_uri', '') self.params['redirect_uri'] = query_dict.get('redirect_uri', '')
self.params.response_type = query_dict.get('response_type', '') self.params['response_type'] = query_dict.get('response_type', '')
self.params.scope = query_dict.get('scope', '').split() self.params['scope'] = query_dict.get('scope', '').split()
self.params.state = query_dict.get('state', '') self.params['state'] = query_dict.get('state', '')
self.params['nonce'] = query_dict.get('nonce', '')
self.params.nonce = query_dict.get('nonce', '') self.params['prompt'] = query_dict.get('prompt', '')
self.params.prompt = query_dict.get('prompt', '') self.params['code_challenge'] = query_dict.get('code_challenge', '')
self.params.code_challenge = query_dict.get('code_challenge', '') self.params['code_challenge_method'] = query_dict.get('code_challenge_method', '')
self.params.code_challenge_method = query_dict.get('code_challenge_method', '')
def validate_params(self): def validate_params(self):
# Client validation. # Client validation.
try: try:
self.client = Client.objects.get(client_id=self.params.client_id) self.client = Client.objects.get(client_id=self.params['client_id'])
except Client.DoesNotExist: except Client.DoesNotExist:
logger.debug('[Authorize] Invalid client identifier: %s', self.params.client_id) logger.debug('[Authorize] Invalid client identifier: %s', self.params['client_id'])
raise ClientIdError() raise ClientIdError()
# Redirect URI validation. # Redirect URI validation.
if self.is_authentication and not self.params.redirect_uri: if self.is_authentication and not self.params['redirect_uri']:
logger.debug('[Authorize] Missing redirect uri.') logger.debug('[Authorize] Missing redirect uri.')
raise RedirectUriError() raise RedirectUriError()
clean_redirect_uri = urlsplit(self.params.redirect_uri) clean_redirect_uri = urlsplit(self.params['redirect_uri'])
clean_redirect_uri = urlunsplit(clean_redirect_uri._replace(query='')) clean_redirect_uri = urlunsplit(clean_redirect_uri._replace(query=''))
if not (clean_redirect_uri in self.client.redirect_uris): if not (clean_redirect_uri in self.client.redirect_uris):
logger.debug('[Authorize] Invalid redirect uri: %s', self.params.redirect_uri) logger.debug('[Authorize] Invalid redirect uri: %s', self.params['redirect_uri'])
raise RedirectUriError() raise RedirectUriError()
# Grant type validation. # Grant type validation.
if not self.grant_type: if not self.grant_type:
logger.debug('[Authorize] Invalid response type: %s', self.params.response_type) logger.debug('[Authorize] Invalid response type: %s', self.params['response_type'])
raise AuthorizeError(self.params.redirect_uri, 'unsupported_response_type', raise AuthorizeError(self.params['redirect_uri'], 'unsupported_response_type', self.grant_type)
self.grant_type)
# Nonce parameter validation. # Nonce parameter validation.
if self.is_authentication and self.grant_type == 'implicit' and not self.params.nonce: if self.is_authentication and self.grant_type == 'implicit' and not self.params['nonce']:
raise AuthorizeError(self.params.redirect_uri, 'invalid_request', raise AuthorizeError(self.params['redirect_uri'], 'invalid_request', self.grant_type)
self.grant_type)
# Response type parameter validation. # Response type parameter validation.
if self.is_authentication and self.params.response_type != self.client.response_type: if self.is_authentication and self.params['response_type'] != self.client.response_type:
raise AuthorizeError(self.params.redirect_uri, 'invalid_request', raise AuthorizeError(self.params['redirect_uri'], 'invalid_request', self.grant_type)
self.grant_type)
# PKCE validation of the transformation method. # PKCE validation of the transformation method.
if self.params.code_challenge: if self.params['code_challenge']:
if not (self.params.code_challenge_method in ['plain', 'S256']): if not (self.params['code_challenge_method'] in ['plain', 'S256']):
raise AuthorizeError(self.params.redirect_uri, 'invalid_request', self.grant_type) raise AuthorizeError(self.params['redirect_uri'], 'invalid_request', self.grant_type)
def create_response_uri(self): def create_response_uri(self):
uri = urlsplit(self.params.redirect_uri) uri = urlsplit(self.params['redirect_uri'])
query_params = parse_qs(uri.query) query_params = parse_qs(uri.query)
query_fragment = parse_qs(uri.fragment) query_fragment = parse_qs(uri.fragment)
@ -124,24 +119,24 @@ class AuthorizeEndpoint(object):
code = create_code( code = create_code(
user=self.request.user, user=self.request.user,
client=self.client, client=self.client,
scope=self.params.scope, scope=self.params['scope'],
nonce=self.params.nonce, nonce=self.params['nonce'],
is_authentication=self.is_authentication, is_authentication=self.is_authentication,
code_challenge=self.params.code_challenge, code_challenge=self.params['code_challenge'],
code_challenge_method=self.params.code_challenge_method) code_challenge_method=self.params['code_challenge_method'])
code.save() code.save()
if self.grant_type == 'authorization_code': if self.grant_type == 'authorization_code':
query_params['code'] = code.code query_params['code'] = code.code
query_params['state'] = self.params.state if self.params.state else '' query_params['state'] = self.params['state'] if self.params['state'] else ''
elif self.grant_type in ['implicit', 'hybrid']: elif self.grant_type in ['implicit', 'hybrid']:
token = create_token( token = create_token(
user=self.request.user, user=self.request.user,
client=self.client, client=self.client,
scope=self.params.scope) scope=self.params['scope'])
# Check if response_type must include access_token in the response. # Check if response_type must include access_token in the response.
if self.params.response_type in ['id_token token', 'token', 'code token', 'code id_token token']: if self.params['response_type'] in ['id_token token', 'token', 'code token', 'code id_token token']:
query_fragment['access_token'] = token.access_token query_fragment['access_token'] = token.access_token
# We don't need id_token if it's an OAuth2 request. # We don't need id_token if it's an OAuth2 request.
@ -149,9 +144,9 @@ class AuthorizeEndpoint(object):
kwargs = { kwargs = {
'user': self.request.user, 'user': self.request.user,
'aud': self.client.client_id, 'aud': self.client.client_id,
'nonce': self.params.nonce, 'nonce': self.params['nonce'],
'request': self.request, 'request': self.request,
'scope': self.params.scope, 'scope': self.params['scope'],
} }
# Include at_hash when access_token is being returned. # Include at_hash when access_token is being returned.
if 'access_token' in query_fragment: if 'access_token' in query_fragment:
@ -159,7 +154,7 @@ class AuthorizeEndpoint(object):
id_token_dic = create_id_token(**kwargs) id_token_dic = create_id_token(**kwargs)
# Check if response_type must include id_token in the response. # Check if response_type must include id_token in the response.
if self.params.response_type in ['id_token', 'id_token token', 'code id_token', 'code id_token token']: if self.params['response_type'] in ['id_token', 'id_token token', 'code id_token', 'code id_token token']:
query_fragment['id_token'] = encode_id_token(id_token_dic, self.client) query_fragment['id_token'] = encode_id_token(id_token_dic, self.client)
else: else:
id_token_dic = {} id_token_dic = {}
@ -176,14 +171,11 @@ class AuthorizeEndpoint(object):
query_fragment['expires_in'] = settings.get('OIDC_TOKEN_EXPIRE') query_fragment['expires_in'] = settings.get('OIDC_TOKEN_EXPIRE')
query_fragment['state'] = self.params.state if self.params.state else '' query_fragment['state'] = self.params['state'] if self.params['state'] else ''
except Exception as error: except Exception as error:
logger.debug('[Authorize] Error when trying to create response uri: %s', error) logger.debug('[Authorize] Error when trying to create response uri: %s', error)
raise AuthorizeError( raise AuthorizeError(self.params['redirect_uri'], 'server_error', self.grant_type)
self.params.redirect_uri,
'server_error',
self.grant_type)
uri = uri._replace(query=urlencode(query_params, doseq=True)) uri = uri._replace(query=urlencode(query_params, doseq=True))
uri = uri._replace(fragment=urlencode(query_fragment, doseq=True)) uri = uri._replace(fragment=urlencode(query_fragment, doseq=True))
@ -208,7 +200,7 @@ class AuthorizeEndpoint(object):
'date_given': date_given, 'date_given': date_given,
} }
) )
uc.scope = self.params.scope uc.scope = self.params['scope']
# Rewrite expires_at and date_given if object already exists. # Rewrite expires_at and date_given if object already exists.
if not created: if not created:
@ -225,10 +217,8 @@ class AuthorizeEndpoint(object):
""" """
value = False value = False
try: try:
uc = UserConsent.objects.get(user=self.request.user, uc = UserConsent.objects.get(user=self.request.user, client=self.client)
client=self.client) if (set(self.params['scope']).issubset(uc.scope)) and not (uc.has_expired()):
if (set(self.params.scope).issubset(uc.scope)) and \
not (uc.has_expired()):
value = True value = True
except UserConsent.DoesNotExist: except UserConsent.DoesNotExist:
pass pass
@ -239,9 +229,9 @@ class AuthorizeEndpoint(object):
""" """
Return a list with the description of all the scopes requested. Return a list with the description of all the scopes requested.
""" """
scopes = StandardScopeClaims.get_scopes_info(self.params.scope) scopes = StandardScopeClaims.get_scopes_info(self.params['scope'])
if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'): if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'):
scopes_extra = settings.get('OIDC_EXTRA_SCOPE_CLAIMS', import_str=True).get_scopes_info(self.params.scope) scopes_extra = settings.get('OIDC_EXTRA_SCOPE_CLAIMS', import_str=True).get_scopes_info(self.params['scope'])
for index_extra, scope_extra in enumerate(scopes_extra): for index_extra, scope_extra in enumerate(scopes_extra):
for index, scope in enumerate(scopes[:]): for index, scope in enumerate(scopes[:]):
if scope_extra['scope'] == scope['scope']: if scope_extra['scope'] == scope['scope']:

View file

@ -12,7 +12,6 @@ from django.http import JsonResponse
from oidc_provider.lib.errors import ( from oidc_provider.lib.errors import (
TokenError, TokenError,
) )
from oidc_provider.lib.utils.params import Params
from oidc_provider.lib.utils.token import ( from oidc_provider.lib.utils.token import (
create_id_token, create_id_token,
create_token, create_token,
@ -33,23 +32,22 @@ class TokenEndpoint(object):
def __init__(self, request): def __init__(self, request):
self.request = request self.request = request
self.params = Params() self.params = {}
self._extract_params() self._extract_params()
def _extract_params(self): def _extract_params(self):
client_id, client_secret = self._extract_client_auth() client_id, client_secret = self._extract_client_auth()
self.params.client_id = client_id self.params['client_id'] = client_id
self.params.client_secret = client_secret self.params['client_secret'] = client_secret
self.params.redirect_uri = unquote(self.request.POST.get('redirect_uri', '')) self.params['redirect_uri'] = unquote(self.request.POST.get('redirect_uri', ''))
self.params.grant_type = self.request.POST.get('grant_type', '') self.params['grant_type'] = self.request.POST.get('grant_type', '')
self.params.code = self.request.POST.get('code', '') self.params['code'] = self.request.POST.get('code', '')
self.params.state = self.request.POST.get('state', '') self.params['state'] = self.request.POST.get('state', '')
self.params.scope = self.request.POST.get('scope', '') self.params['scope'] = self.request.POST.get('scope', '')
self.params.refresh_token = self.request.POST.get('refresh_token', '') self.params['refresh_token'] = self.request.POST.get('refresh_token', '')
# PKCE parameter.
# PKCE parameters. self.params['code_verifier'] = self.request.POST.get('code_verifier')
self.params.code_verifier = self.request.POST.get('code_verifier')
def _extract_client_auth(self): def _extract_client_auth(self):
""" """
@ -76,68 +74,68 @@ class TokenEndpoint(object):
def validate_params(self): def validate_params(self):
try: try:
self.client = Client.objects.get(client_id=self.params.client_id) self.client = Client.objects.get(client_id=self.params['client_id'])
except Client.DoesNotExist: except Client.DoesNotExist:
logger.debug('[Token] Client does not exist: %s', self.params.client_id) logger.debug('[Token] Client does not exist: %s', self.params['client_id'])
raise TokenError('invalid_client') raise TokenError('invalid_client')
if self.client.client_type == 'confidential': if self.client.client_type == 'confidential':
if not (self.client.client_secret == self.params.client_secret): if not (self.client.client_secret == self.params['client_secret']):
logger.debug('[Token] Invalid client secret: client %s do not have secret %s', logger.debug('[Token] Invalid client secret: client %s do not have secret %s',
self.client.client_id, self.client.client_secret) self.client.client_id, self.client.client_secret)
raise TokenError('invalid_client') raise TokenError('invalid_client')
if self.params.grant_type == 'authorization_code': if self.params['grant_type'] == 'authorization_code':
if not (self.params.redirect_uri in self.client.redirect_uris): if not (self.params['redirect_uri'] in self.client.redirect_uris):
logger.debug('[Token] Invalid redirect uri: %s', self.params.redirect_uri) logger.debug('[Token] Invalid redirect uri: %s', self.params['redirect_uri'])
raise TokenError('invalid_client') raise TokenError('invalid_client')
try: try:
self.code = Code.objects.get(code=self.params.code) self.code = Code.objects.get(code=self.params['code'])
except Code.DoesNotExist: except Code.DoesNotExist:
logger.debug('[Token] Code does not exist: %s', self.params.code) logger.debug('[Token] Code does not exist: %s', self.params['code'])
raise TokenError('invalid_grant') raise TokenError('invalid_grant')
if not (self.code.client == self.client) \ if not (self.code.client == self.client) \
or self.code.has_expired(): or self.code.has_expired():
logger.debug('[Token] Invalid code: invalid client or code has expired', logger.debug('[Token] Invalid code: invalid client or code has expired',
self.params.redirect_uri) self.params['redirect_uri'])
raise TokenError('invalid_grant') raise TokenError('invalid_grant')
# Validate PKCE parameters. # Validate PKCE parameters.
if self.params.code_verifier: if self.params['code_verifier']:
if self.code.code_challenge_method == 'S256': if self.code.code_challenge_method == 'S256':
new_code_challenge = urlsafe_b64encode( new_code_challenge = urlsafe_b64encode(
hashlib.sha256(self.params.code_verifier.encode('ascii')).digest() hashlib.sha256(self.params['code_verifier'].encode('ascii')).digest()
).decode('utf-8').replace('=', '') ).decode('utf-8').replace('=', '')
else: else:
new_code_challenge = self.params.code_verifier new_code_challenge = self.params['code_verifier']
# TODO: We should explain the error. # TODO: We should explain the error.
if not (new_code_challenge == self.code.code_challenge): if not (new_code_challenge == self.code.code_challenge):
raise TokenError('invalid_grant') raise TokenError('invalid_grant')
elif self.params.grant_type == 'refresh_token': elif self.params['grant_type'] == 'refresh_token':
if not self.params.refresh_token: if not self.params['refresh_token']:
logger.debug('[Token] Missing refresh token') logger.debug('[Token] Missing refresh token')
raise TokenError('invalid_grant') raise TokenError('invalid_grant')
try: try:
self.token = Token.objects.get(refresh_token=self.params.refresh_token, self.token = Token.objects.get(refresh_token=self.params['refresh_token'],
client=self.client) client=self.client)
except Token.DoesNotExist: except Token.DoesNotExist:
logger.debug('[Token] Refresh token does not exist: %s', self.params.refresh_token) logger.debug('[Token] Refresh token does not exist: %s', self.params['refresh_token'])
raise TokenError('invalid_grant') raise TokenError('invalid_grant')
else: else:
logger.debug('[Token] Invalid grant type: %s', self.params.grant_type) logger.debug('[Token] Invalid grant type: %s', self.params['grant_type'])
raise TokenError('unsupported_grant_type') raise TokenError('unsupported_grant_type')
def create_response_dic(self): def create_response_dic(self):
if self.params.grant_type == 'authorization_code': if self.params['grant_type'] == 'authorization_code':
return self.create_code_response_dic() return self.create_code_response_dic()
elif self.params.grant_type == 'refresh_token': elif self.params['grant_type'] == 'refresh_token':
return self.create_refresh_response_dic() return self.create_refresh_response_dic()
def create_code_response_dic(self): def create_code_response_dic(self):
@ -153,7 +151,7 @@ class TokenEndpoint(object):
nonce=self.code.nonce, nonce=self.code.nonce,
at_hash=token.at_hash, at_hash=token.at_hash,
request=self.request, request=self.request,
scope=self.params.scope, scope=self.params['scope'],
) )
else: else:
id_token_dic = {} id_token_dic = {}
@ -189,7 +187,7 @@ class TokenEndpoint(object):
nonce=None, nonce=None,
at_hash=token.at_hash, at_hash=token.at_hash,
request=self.request, request=self.request,
scope=self.params.scope, scope=self.params['scope'],
) )
else: else:
id_token_dic = {} id_token_dic = {}

View file

@ -1,7 +0,0 @@
class Params(object):
"""
The purpose of this class is for accesing params via dot notation.
"""
pass

View file

@ -46,24 +46,24 @@ class AuthorizeView(View):
return hook_resp return hook_resp
if settings.get('OIDC_SKIP_CONSENT_ALWAYS') and not (authorize.client.client_type == 'public') \ if settings.get('OIDC_SKIP_CONSENT_ALWAYS') and not (authorize.client.client_type == 'public') \
and not (authorize.params.prompt == 'consent'): and not (authorize.params['prompt'] == 'consent'):
return redirect(authorize.create_response_uri()) return redirect(authorize.create_response_uri())
if settings.get('OIDC_SKIP_CONSENT_ENABLE'): if settings.get('OIDC_SKIP_CONSENT_ENABLE'):
# Check if user previously give consent. # Check if user previously give consent.
if authorize.client_has_user_consent() and not (authorize.client.client_type == 'public') \ if authorize.client_has_user_consent() and not (authorize.client.client_type == 'public') \
and not (authorize.params.prompt == 'consent'): and not (authorize.params['prompt'] == 'consent'):
return redirect(authorize.create_response_uri()) return redirect(authorize.create_response_uri())
if authorize.params.prompt == 'none': if authorize.params['prompt'] == 'none':
raise AuthorizeError(authorize.params.redirect_uri, 'interaction_required', authorize.grant_type) raise AuthorizeError(authorize.params['redirect_uri'], 'interaction_required', authorize.grant_type)
if authorize.params.prompt == 'login': if authorize.params['prompt'] == 'login':
return redirect_to_login(request.get_full_path()) return redirect_to_login(request.get_full_path())
if authorize.params.prompt == 'select_account': if authorize.params['prompt'] == 'select_account':
# TODO: see how we can support multiple accounts for the end-user. # TODO: see how we can support multiple accounts for the end-user.
raise AuthorizeError(authorize.params.redirect_uri, 'account_selection_required', authorize.grant_type) raise AuthorizeError(authorize.params['redirect_uri'], 'account_selection_required', authorize.grant_type)
# Generate hidden inputs for the form. # Generate hidden inputs for the form.
context = { context = {
@ -73,8 +73,8 @@ class AuthorizeView(View):
# Remove `openid` from scope list # Remove `openid` from scope list
# since we don't need to print it. # since we don't need to print it.
if 'openid' in authorize.params.scope: if 'openid' in authorize.params['scope']:
authorize.params.scope.remove('openid') authorize.params['scope'].remove('openid')
context = { context = {
'client': authorize.client, 'client': authorize.client,
@ -85,8 +85,8 @@ class AuthorizeView(View):
return render(request, 'oidc_provider/authorize.html', context) return render(request, 'oidc_provider/authorize.html', context)
else: else:
if authorize.params.prompt == 'none': if authorize.params['prompt'] == 'none':
raise AuthorizeError(authorize.params.redirect_uri, 'login_required', authorize.grant_type) raise AuthorizeError(authorize.params['redirect_uri'], 'login_required', authorize.grant_type)
return redirect_to_login(request.get_full_path()) return redirect_to_login(request.get_full_path())
@ -100,8 +100,8 @@ class AuthorizeView(View):
except (AuthorizeError) as error: except (AuthorizeError) as error:
uri = error.create_uri( uri = error.create_uri(
authorize.params.redirect_uri, authorize.params['redirect_uri'],
authorize.params.state) authorize.params['state'])
return redirect(uri) return redirect(uri)
@ -112,7 +112,7 @@ class AuthorizeView(View):
authorize.validate_params() authorize.validate_params()
if not request.POST.get('allow'): if not request.POST.get('allow'):
raise AuthorizeError(authorize.params.redirect_uri, raise AuthorizeError(authorize.params['redirect_uri'],
'access_denied', 'access_denied',
authorize.grant_type) authorize.grant_type)
@ -125,8 +125,8 @@ class AuthorizeView(View):
except (AuthorizeError) as error: except (AuthorizeError) as error:
uri = error.create_uri( uri = error.create_uri(
authorize.params.redirect_uri, authorize.params['redirect_uri'],
authorize.params.state) authorize.params['state'])
return redirect(uri) return redirect(uri)
@ -134,7 +134,6 @@ class AuthorizeView(View):
class TokenView(View): class TokenView(View):
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
token = TokenEndpoint(request) token = TokenEndpoint(request)
try: try: