Now passing along the token to create_id_token function.
This commit is contained in:
parent
8c736b8b08
commit
900cc9e5df
9 changed files with 30 additions and 18 deletions
|
@ -16,11 +16,12 @@ STANDARD_CLAIMS = {
|
||||||
|
|
||||||
class ScopeClaims(object):
|
class ScopeClaims(object):
|
||||||
|
|
||||||
def __init__(self, user, scope):
|
def __init__(self, token):
|
||||||
self.user = user
|
self.user = token.user
|
||||||
claims = copy.deepcopy(STANDARD_CLAIMS)
|
claims = copy.deepcopy(STANDARD_CLAIMS)
|
||||||
self.userinfo = settings.get('OIDC_USERINFO', import_str=True)(claims, self.user)
|
self.userinfo = settings.get('OIDC_USERINFO', import_str=True)(claims, self.user)
|
||||||
self.scopes = scope
|
self.scopes = token.scope
|
||||||
|
self.client = token.client
|
||||||
|
|
||||||
def create_response_dic(self):
|
def create_response_dic(self):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -155,6 +155,7 @@ class AuthorizeEndpoint(object):
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'user': self.request.user,
|
'user': self.request.user,
|
||||||
'aud': self.client.client_id,
|
'aud': self.client.client_id,
|
||||||
|
'token': token,
|
||||||
'nonce': self.params['nonce'],
|
'nonce': self.params['nonce'],
|
||||||
'request': self.request,
|
'request': self.request,
|
||||||
'scope': self.params['scope'],
|
'scope': self.params['scope'],
|
||||||
|
|
|
@ -164,6 +164,7 @@ class TokenEndpoint(object):
|
||||||
id_token_dic = create_id_token(
|
id_token_dic = create_id_token(
|
||||||
user=self.user,
|
user=self.user,
|
||||||
aud=self.client.client_id,
|
aud=self.client.client_id,
|
||||||
|
token=token,
|
||||||
nonce='self.code.nonce',
|
nonce='self.code.nonce',
|
||||||
at_hash=token.at_hash,
|
at_hash=token.at_hash,
|
||||||
request=self.request,
|
request=self.request,
|
||||||
|
@ -193,6 +194,7 @@ class TokenEndpoint(object):
|
||||||
id_token_dic = create_id_token(
|
id_token_dic = create_id_token(
|
||||||
user=self.code.user,
|
user=self.code.user,
|
||||||
aud=self.client.client_id,
|
aud=self.client.client_id,
|
||||||
|
token=token,
|
||||||
nonce=self.code.nonce,
|
nonce=self.code.nonce,
|
||||||
at_hash=token.at_hash,
|
at_hash=token.at_hash,
|
||||||
request=self.request,
|
request=self.request,
|
||||||
|
@ -237,6 +239,7 @@ class TokenEndpoint(object):
|
||||||
id_token_dic = create_id_token(
|
id_token_dic = create_id_token(
|
||||||
user=self.token.user,
|
user=self.token.user,
|
||||||
aud=self.client.client_id,
|
aud=self.client.client_id,
|
||||||
|
token=token,
|
||||||
nonce=None,
|
nonce=None,
|
||||||
at_hash=token.at_hash,
|
at_hash=token.at_hash,
|
||||||
request=self.request,
|
request=self.request,
|
||||||
|
|
|
@ -19,7 +19,7 @@ from oidc_provider.models import (
|
||||||
from oidc_provider import settings
|
from oidc_provider import settings
|
||||||
|
|
||||||
|
|
||||||
def create_id_token(user, aud, nonce='', at_hash='', request=None, scope=None):
|
def create_id_token(user, aud, token, nonce='', at_hash='', request=None, scope=None):
|
||||||
"""
|
"""
|
||||||
Creates the id_token dictionary.
|
Creates the id_token dictionary.
|
||||||
See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken
|
||||||
|
@ -54,10 +54,10 @@ def create_id_token(user, aud, nonce='', at_hash='', request=None, scope=None):
|
||||||
dic['at_hash'] = at_hash
|
dic['at_hash'] = at_hash
|
||||||
|
|
||||||
if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'):
|
if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'):
|
||||||
custom_claims = settings.get('OIDC_EXTRA_SCOPE_CLAIMS', import_str=True)(user, scope)
|
custom_claims = settings.get('OIDC_EXTRA_SCOPE_CLAIMS', import_str=True)(token)
|
||||||
claims = custom_claims.create_response_dic()
|
claims = custom_claims.create_response_dic()
|
||||||
else:
|
else:
|
||||||
claims = StandardScopeClaims(user=user, scope=scope).create_response_dic()
|
claims = StandardScopeClaims(token).create_response_dic()
|
||||||
|
|
||||||
dic.update(claims) # modifies dic, adding all requested claims
|
dic.update(claims) # modifies dic, adding all requested claims
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ class ClaimsTestCase(TestCase):
|
||||||
self.scopes = ['openid', 'address', 'email', 'phone', 'profile']
|
self.scopes = ['openid', 'address', 'email', 'phone', 'profile']
|
||||||
self.client = create_fake_client('code')
|
self.client = create_fake_client('code')
|
||||||
self.token = create_fake_token(self.user, self.scopes, self.client)
|
self.token = create_fake_token(self.user, self.scopes, self.client)
|
||||||
self.scopeClaims = ScopeClaims(self.token.user, self.token.scope)
|
self.scopeClaims = ScopeClaims(self.token)
|
||||||
|
|
||||||
def test_empty_standard_claims(self):
|
def test_empty_standard_claims(self):
|
||||||
for v in [v for k, v in STANDARD_CLAIMS.items() if k != 'address']:
|
for v in [v for k, v in STANDARD_CLAIMS.items() if k != 'address']:
|
||||||
|
|
|
@ -3,6 +3,7 @@ from django.core.urlresolvers import reverse
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from oidc_provider.lib.utils.token import (
|
from oidc_provider.lib.utils.token import (
|
||||||
|
create_token,
|
||||||
create_id_token,
|
create_id_token,
|
||||||
encode_id_token,
|
encode_id_token,
|
||||||
)
|
)
|
||||||
|
@ -41,8 +42,9 @@ class EndSessionTestCase(TestCase):
|
||||||
response, settings.get('OIDC_LOGIN_URL'),
|
response, settings.get('OIDC_LOGIN_URL'),
|
||||||
fetch_redirect_response=False)
|
fetch_redirect_response=False)
|
||||||
|
|
||||||
|
token = create_token(self.user, self.oidc_client, [])
|
||||||
id_token_dic = create_id_token(
|
id_token_dic = create_id_token(
|
||||||
user=self.user, aud=self.oidc_client.client_id)
|
user=self.user, aud=self.oidc_client.client_id, token=token)
|
||||||
id_token = encode_id_token(id_token_dic, self.oidc_client)
|
id_token = encode_id_token(id_token_dic, self.oidc_client)
|
||||||
|
|
||||||
query_params['id_token_hint'] = id_token
|
query_params['id_token_hint'] = id_token
|
||||||
|
@ -56,8 +58,9 @@ class EndSessionTestCase(TestCase):
|
||||||
query_params = {
|
query_params = {
|
||||||
'post_logout_redirect_uri': self.LOGOUT_URL,
|
'post_logout_redirect_uri': self.LOGOUT_URL,
|
||||||
}
|
}
|
||||||
|
token = create_token(self.user, self.oidc_client, [])
|
||||||
id_token_dic = create_id_token(
|
id_token_dic = create_id_token(
|
||||||
user=self.user, aud=self.oidc_client.client_id)
|
user=self.user, aud=self.oidc_client.client_id, token=token)
|
||||||
id_token_dic['aud'] = [id_token_dic['aud']]
|
id_token_dic['aud'] = [id_token_dic['aud']]
|
||||||
id_token = encode_id_token(id_token_dic, self.oidc_client)
|
id_token = encode_id_token(id_token_dic, self.oidc_client)
|
||||||
query_params['id_token_hint'] = id_token
|
query_params['id_token_hint'] = id_token
|
||||||
|
|
|
@ -38,18 +38,20 @@ class UserInfoTestCase(TestCase):
|
||||||
extra_scope = []
|
extra_scope = []
|
||||||
scope = ['openid', 'email'] + extra_scope
|
scope = ['openid', 'email'] + extra_scope
|
||||||
|
|
||||||
|
token = create_token(
|
||||||
|
user=self.user,
|
||||||
|
client=self.client,
|
||||||
|
scope=scope)
|
||||||
|
|
||||||
id_token_dic = create_id_token(
|
id_token_dic = create_id_token(
|
||||||
user=self.user,
|
user=self.user,
|
||||||
aud=self.client.client_id,
|
aud=self.client.client_id,
|
||||||
|
token=token,
|
||||||
nonce=FAKE_NONCE,
|
nonce=FAKE_NONCE,
|
||||||
scope=scope,
|
scope=scope,
|
||||||
)
|
)
|
||||||
|
|
||||||
token = create_token(
|
token.id_token=id_token_dic
|
||||||
user=self.user,
|
|
||||||
client=self.client,
|
|
||||||
id_token_dic=id_token_dic,
|
|
||||||
scope=scope)
|
|
||||||
token.save()
|
token.save()
|
||||||
|
|
||||||
return token
|
return token
|
||||||
|
|
|
@ -8,8 +8,8 @@ from django.utils import timezone
|
||||||
from mock import mock
|
from mock import mock
|
||||||
|
|
||||||
from oidc_provider.lib.utils.common import get_issuer, get_browser_state_or_default
|
from oidc_provider.lib.utils.common import get_issuer, get_browser_state_or_default
|
||||||
from oidc_provider.lib.utils.token import create_id_token
|
from oidc_provider.lib.utils.token import create_token, create_id_token
|
||||||
from oidc_provider.tests.app.utils import create_fake_user
|
from oidc_provider.tests.app.utils import create_fake_user, create_fake_client
|
||||||
|
|
||||||
|
|
||||||
class Request(object):
|
class Request(object):
|
||||||
|
@ -67,7 +67,9 @@ class TokenTest(TestCase):
|
||||||
start_time = int(time.time())
|
start_time = int(time.time())
|
||||||
login_timestamp = start_time - 1234
|
login_timestamp = start_time - 1234
|
||||||
self.user.last_login = timestamp_to_datetime(login_timestamp)
|
self.user.last_login = timestamp_to_datetime(login_timestamp)
|
||||||
id_token_data = create_id_token(self.user, aud='test-aud')
|
client = create_fake_client("code")
|
||||||
|
token = create_token(self.user, client, [])
|
||||||
|
id_token_data = create_id_token(self.user, aud='test-aud', token=token)
|
||||||
iat = id_token_data['iat']
|
iat = id_token_data['iat']
|
||||||
self.assertEqual(type(iat), int)
|
self.assertEqual(type(iat), int)
|
||||||
self.assertGreaterEqual(iat, start_time)
|
self.assertGreaterEqual(iat, start_time)
|
||||||
|
|
|
@ -234,7 +234,7 @@ def userinfo(request, *args, **kwargs):
|
||||||
'sub': token.id_token.get('sub'),
|
'sub': token.id_token.get('sub'),
|
||||||
}
|
}
|
||||||
|
|
||||||
standard_claims = StandardScopeClaims(user=token.user, scope=token.scope)
|
standard_claims = StandardScopeClaims(token)
|
||||||
dic.update(standard_claims.create_response_dic())
|
dic.update(standard_claims.create_response_dic())
|
||||||
|
|
||||||
if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'):
|
if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'):
|
||||||
|
|
Loading…
Reference in a new issue