diff --git a/oidc_provider/lib/endpoints/authorize.py b/oidc_provider/lib/endpoints/authorize.py index d2d1951..b3ea536 100644 --- a/oidc_provider/lib/endpoints/authorize.py +++ b/oidc_provider/lib/endpoints/authorize.py @@ -121,35 +121,41 @@ class AuthorizeEndpoint(object): query_params['state'] = self.params.state if self.params.state else '' elif self.grant_type == 'implicit': - # We don't need id_token if it's an OAuth2 request. - if self.is_authentication: - id_token_dic = create_id_token( - user=self.request.user, - aud=self.client.client_id, - nonce=self.params.nonce, - request=self.request) - query_fragment['id_token'] = encode_id_token(id_token_dic, self.client) - else: - id_token_dic = {} - token = create_token( user=self.request.user, client=self.client, - id_token_dic=id_token_dic, scope=self.params.scope) - # Store the token. - token.save() - - query_fragment['token_type'] = 'bearer' - # TODO: Create setting 'OIDC_TOKEN_EXPIRE'. - query_fragment['expires_in'] = 60 * 10 - # Check if response_type is an OpenID request with value 'id_token token' # or it's an OAuth2 Implicit Flow request. if self.params.response_type in ['id_token token', 'token']: query_fragment['access_token'] = token.access_token + # We don't need id_token if it's an OAuth2 request. + if self.is_authentication: + kwargs = { + "user": self.request.user, + "aud": self.client.client_id, + "nonce": self.params.nonce, + "request": self.request + } + # Include at_hash when access_token is being returned. + if 'access_token' in query_fragment: + kwargs['at_hash'] = token.at_hash + id_token_dic = create_id_token(**kwargs) + query_fragment['id_token'] = encode_id_token(id_token_dic, self.client) + token.id_token = id_token_dic + else: + id_token_dic = {} + + # Store the token. + token.id_token = id_token_dic + token.save() + + query_fragment['token_type'] = 'bearer' + # TODO: Create setting 'OIDC_TOKEN_EXPIRE'. + query_fragment['expires_in'] = 60 * 10 + query_fragment['state'] = self.params.state if self.params.state else '' except Exception as error: diff --git a/oidc_provider/lib/endpoints/token.py b/oidc_provider/lib/endpoints/token.py index 200cb2c..6aa3b83 100644 --- a/oidc_provider/lib/endpoints/token.py +++ b/oidc_provider/lib/endpoints/token.py @@ -131,23 +131,24 @@ class TokenEndpoint(object): return self.create_refresh_response_dic() def create_code_response_dic(self): + token = create_token( + user=self.code.user, + client=self.code.client, + scope=self.code.scope) + if self.code.is_authentication: id_token_dic = create_id_token( user=self.code.user, aud=self.client.client_id, nonce=self.code.nonce, + at_hash=token.at_hash, request=self.request, ) else: id_token_dic = {} - token = create_token( - user=self.code.user, - client=self.code.client, - id_token_dic=id_token_dic, - scope=self.code.scope) - # Store the token. + token.id_token = id_token_dic token.save() # We don't need to store the code anymore. @@ -164,24 +165,25 @@ class TokenEndpoint(object): return dic def create_refresh_response_dic(self): + token = create_token( + user=self.token.user, + client=self.token.client, + scope=self.token.scope) + # If the Token has an id_token it's an Authentication request. if self.token.id_token: id_token_dic = create_id_token( user=self.token.user, aud=self.client.client_id, nonce=None, + at_hash=token.at_hash, request=self.request, ) else: id_token_dic = {} - token = create_token( - user=self.token.user, - client=self.token.client, - id_token_dic=id_token_dic, - scope=self.token.scope) - # Store the token. + token.id_token = id_token_dic token.save() # Forget the old token. diff --git a/oidc_provider/lib/utils/token.py b/oidc_provider/lib/utils/token.py index fc0880d..83291ec 100644 --- a/oidc_provider/lib/utils/token.py +++ b/oidc_provider/lib/utils/token.py @@ -13,7 +13,7 @@ from oidc_provider.models import * from oidc_provider import settings -def create_id_token(user, aud, nonce, request=None): +def create_id_token(user, aud, nonce, at_hash=None, request=None): """ Receives a user object and aud (audience). Then creates the id_token dictionary. @@ -44,6 +44,9 @@ def create_id_token(user, aud, nonce, request=None): if nonce: dic['nonce'] = str(nonce) + if at_hash: + dic['at_hash'] = at_hash + processing_hook = settings.get('OIDC_IDTOKEN_PROCESSING_HOOK') if isinstance(processing_hook, (list, tuple)): @@ -73,13 +76,13 @@ def encode_id_token(payload, client): keys = [SYMKey(key=client.client_secret, alg=alg)] else: raise Exception('Unsupported key algorithm.') - + _jws = JWS(payload, alg=alg) return _jws.sign_compact(keys) -def create_token(user, client, id_token_dic, scope): +def create_token(user, client, scope, id_token_dic=None): """ Create and populate a Token object. @@ -90,7 +93,8 @@ def create_token(user, client, id_token_dic, scope): token.client = client token.access_token = uuid.uuid4().hex - token.id_token = id_token_dic + if id_token_dic is not None: + token.id_token = id_token_dic token.refresh_token = uuid.uuid4().hex token.expires_at = timezone.now() + timedelta( @@ -112,7 +116,7 @@ def create_code(user, client, scope, nonce, is_authentication, code.client = client code.code = uuid.uuid4().hex - + if code_challenge and code_challenge_method: code.code_challenge = code_challenge code.code_challenge_method = code_challenge_method diff --git a/oidc_provider/models.py b/oidc_provider/models.py index 09b36a2..284e858 100644 --- a/oidc_provider/models.py +++ b/oidc_provider/models.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- -from hashlib import md5 +import base64 +import binascii +from hashlib import md5, sha256 import json from django.db import models @@ -117,6 +119,18 @@ class Token(BaseCodeTokenModel): verbose_name = _(u'Token') verbose_name_plural = _(u'Tokens') + @property + def at_hash(self): + # @@@ d-o-p only supports 256 bits (change this if that changes) + hashed_access_token = sha256( + self.access_token.encode('ascii') + ).hexdigest().encode('ascii') + return base64.urlsafe_b64encode( + binascii.unhexlify( + hashed_access_token[:len(hashed_access_token) // 2] + ) + ).rstrip(b'=').decode('ascii') + class UserConsent(BaseCodeTokenModel):