Merge remote-tracking branch 'brosner/fix-id-token-at-hash' into test-implicit-flow

This commit is contained in:
Graham Ullrich 2016-08-05 15:49:18 -06:00
commit 86fbfdba60
4 changed files with 63 additions and 37 deletions

View file

@ -121,35 +121,41 @@ class AuthorizeEndpoint(object):
query_params['state'] = self.params.state if self.params.state else '' query_params['state'] = self.params.state if self.params.state else ''
elif self.grant_type == 'implicit': elif self.grant_type == 'implicit':
# We don't need id_token if it's an OAuth2 request.
if self.is_authentication:
id_token_dic = create_id_token(
user=self.request.user,
aud=self.client.client_id,
nonce=self.params.nonce,
request=self.request)
query_fragment['id_token'] = encode_id_token(id_token_dic, self.client)
else:
id_token_dic = {}
token = create_token( token = create_token(
user=self.request.user, user=self.request.user,
client=self.client, client=self.client,
id_token_dic=id_token_dic,
scope=self.params.scope) scope=self.params.scope)
# Store the token.
token.save()
query_fragment['token_type'] = 'bearer'
# TODO: Create setting 'OIDC_TOKEN_EXPIRE'.
query_fragment['expires_in'] = 60 * 10
# Check if response_type is an OpenID request with value 'id_token token' # Check if response_type is an OpenID request with value 'id_token token'
# or it's an OAuth2 Implicit Flow request. # or it's an OAuth2 Implicit Flow request.
if self.params.response_type in ['id_token token', 'token']: if self.params.response_type in ['id_token token', 'token']:
query_fragment['access_token'] = token.access_token query_fragment['access_token'] = token.access_token
# We don't need id_token if it's an OAuth2 request.
if self.is_authentication:
kwargs = {
"user": self.request.user,
"aud": self.client.client_id,
"nonce": self.params.nonce,
"request": self.request
}
# Include at_hash when access_token is being returned.
if 'access_token' in query_fragment:
kwargs['at_hash'] = token.at_hash
id_token_dic = create_id_token(**kwargs)
query_fragment['id_token'] = encode_id_token(id_token_dic, self.client)
token.id_token = id_token_dic
else:
id_token_dic = {}
# Store the token.
token.id_token = id_token_dic
token.save()
query_fragment['token_type'] = 'bearer'
# TODO: Create setting 'OIDC_TOKEN_EXPIRE'.
query_fragment['expires_in'] = 60 * 10
query_fragment['state'] = self.params.state if self.params.state else '' query_fragment['state'] = self.params.state if self.params.state else ''
except Exception as error: except Exception as error:

View file

@ -131,23 +131,24 @@ class TokenEndpoint(object):
return self.create_refresh_response_dic() return self.create_refresh_response_dic()
def create_code_response_dic(self): def create_code_response_dic(self):
token = create_token(
user=self.code.user,
client=self.code.client,
scope=self.code.scope)
if self.code.is_authentication: if self.code.is_authentication:
id_token_dic = create_id_token( id_token_dic = create_id_token(
user=self.code.user, user=self.code.user,
aud=self.client.client_id, aud=self.client.client_id,
nonce=self.code.nonce, nonce=self.code.nonce,
at_hash=token.at_hash,
request=self.request, request=self.request,
) )
else: else:
id_token_dic = {} id_token_dic = {}
token = create_token(
user=self.code.user,
client=self.code.client,
id_token_dic=id_token_dic,
scope=self.code.scope)
# Store the token. # Store the token.
token.id_token = id_token_dic
token.save() token.save()
# We don't need to store the code anymore. # We don't need to store the code anymore.
@ -164,24 +165,25 @@ class TokenEndpoint(object):
return dic return dic
def create_refresh_response_dic(self): def create_refresh_response_dic(self):
token = create_token(
user=self.token.user,
client=self.token.client,
scope=self.token.scope)
# If the Token has an id_token it's an Authentication request. # If the Token has an id_token it's an Authentication request.
if self.token.id_token: if self.token.id_token:
id_token_dic = create_id_token( id_token_dic = create_id_token(
user=self.token.user, user=self.token.user,
aud=self.client.client_id, aud=self.client.client_id,
nonce=None, nonce=None,
at_hash=token.at_hash,
request=self.request, request=self.request,
) )
else: else:
id_token_dic = {} id_token_dic = {}
token = create_token(
user=self.token.user,
client=self.token.client,
id_token_dic=id_token_dic,
scope=self.token.scope)
# Store the token. # Store the token.
token.id_token = id_token_dic
token.save() token.save()
# Forget the old token. # Forget the old token.

View file

@ -13,7 +13,7 @@ from oidc_provider.models import *
from oidc_provider import settings from oidc_provider import settings
def create_id_token(user, aud, nonce, request=None): def create_id_token(user, aud, nonce, at_hash=None, request=None):
""" """
Receives a user object and aud (audience). Receives a user object and aud (audience).
Then creates the id_token dictionary. Then creates the id_token dictionary.
@ -44,6 +44,9 @@ def create_id_token(user, aud, nonce, request=None):
if nonce: if nonce:
dic['nonce'] = str(nonce) dic['nonce'] = str(nonce)
if at_hash:
dic['at_hash'] = at_hash
processing_hook = settings.get('OIDC_IDTOKEN_PROCESSING_HOOK') processing_hook = settings.get('OIDC_IDTOKEN_PROCESSING_HOOK')
if isinstance(processing_hook, (list, tuple)): if isinstance(processing_hook, (list, tuple)):
@ -73,13 +76,13 @@ def encode_id_token(payload, client):
keys = [SYMKey(key=client.client_secret, alg=alg)] keys = [SYMKey(key=client.client_secret, alg=alg)]
else: else:
raise Exception('Unsupported key algorithm.') raise Exception('Unsupported key algorithm.')
_jws = JWS(payload, alg=alg) _jws = JWS(payload, alg=alg)
return _jws.sign_compact(keys) return _jws.sign_compact(keys)
def create_token(user, client, id_token_dic, scope): def create_token(user, client, scope, id_token_dic=None):
""" """
Create and populate a Token object. Create and populate a Token object.
@ -90,7 +93,8 @@ def create_token(user, client, id_token_dic, scope):
token.client = client token.client = client
token.access_token = uuid.uuid4().hex token.access_token = uuid.uuid4().hex
token.id_token = id_token_dic if id_token_dic is not None:
token.id_token = id_token_dic
token.refresh_token = uuid.uuid4().hex token.refresh_token = uuid.uuid4().hex
token.expires_at = timezone.now() + timedelta( token.expires_at = timezone.now() + timedelta(
@ -112,7 +116,7 @@ def create_code(user, client, scope, nonce, is_authentication,
code.client = client code.client = client
code.code = uuid.uuid4().hex code.code = uuid.uuid4().hex
if code_challenge and code_challenge_method: if code_challenge and code_challenge_method:
code.code_challenge = code_challenge code.code_challenge = code_challenge
code.code_challenge_method = code_challenge_method code.code_challenge_method = code_challenge_method

View file

@ -1,5 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from hashlib import md5 import base64
import binascii
from hashlib import md5, sha256
import json import json
from django.db import models from django.db import models
@ -117,6 +119,18 @@ class Token(BaseCodeTokenModel):
verbose_name = _(u'Token') verbose_name = _(u'Token')
verbose_name_plural = _(u'Tokens') verbose_name_plural = _(u'Tokens')
@property
def at_hash(self):
# @@@ d-o-p only supports 256 bits (change this if that changes)
hashed_access_token = sha256(
self.access_token.encode('ascii')
).hexdigest().encode('ascii')
return base64.urlsafe_b64encode(
binascii.unhexlify(
hashed_access_token[:len(hashed_access_token) // 2]
)
).rstrip(b'=').decode('ascii')
class UserConsent(BaseCodeTokenModel): class UserConsent(BaseCodeTokenModel):