Merge from ByteInternet:refresh-tokens
This commit is contained in:
commit
9a685a7afa
|
@ -35,6 +35,8 @@ class TokenEndpoint(object):
|
|||
self.params.grant_type = self.request.POST.get('grant_type', '')
|
||||
self.params.code = self.request.POST.get('code', '')
|
||||
self.params.state = self.request.POST.get('state', '')
|
||||
self.params.scope = self.request.POST.get('scope', '')
|
||||
self.params.refresh_token = self.request.POST.get('refresh_token', '')
|
||||
|
||||
def _extract_client_auth(self):
|
||||
"""
|
||||
|
@ -60,23 +62,29 @@ class TokenEndpoint(object):
|
|||
return (client_id, client_secret)
|
||||
|
||||
def validate_params(self):
|
||||
if not (self.params.grant_type == 'authorization_code'):
|
||||
logger.error('[Token] Invalid grant type: %s', self.params.grant_type)
|
||||
raise TokenError('unsupported_grant_type')
|
||||
|
||||
try:
|
||||
self.client = Client.objects.get(client_id=self.params.client_id)
|
||||
|
||||
if not (self.client.client_secret == self.params.client_secret):
|
||||
logger.error('[Token] Invalid client secret: client %s do not have secret %s',
|
||||
self.client.client_id, self.client.client_secret)
|
||||
raise TokenError('invalid_client')
|
||||
except Client.DoesNotExist:
|
||||
logger.error('[Token] Client does not exist: %s', self.params.client_id)
|
||||
raise TokenError('invalid_client')
|
||||
|
||||
if not (self.client.client_secret == self.params.client_secret):
|
||||
logger.error('[Token] Invalid client secret: client %s do not have secret %s',
|
||||
self.client.client_id, self.client.client_secret)
|
||||
raise TokenError('invalid_client')
|
||||
|
||||
if self.params.grant_type == 'authorization_code':
|
||||
if not (self.params.redirect_uri in self.client.redirect_uris):
|
||||
logger.error('[Token] Invalid redirect uri: %s', self.params.redirect_uri)
|
||||
raise TokenError('invalid_client')
|
||||
|
||||
self.code = Code.objects.get(code=self.params.code)
|
||||
try:
|
||||
self.code = Code.objects.get(code=self.params.code)
|
||||
|
||||
except Code.DoesNotExist:
|
||||
logger.error('[Token] Code does not exist: %s', self.params.code)
|
||||
raise TokenError('invalid_grant')
|
||||
|
||||
if not (self.code.client == self.client) \
|
||||
or self.code.has_expired():
|
||||
|
@ -84,15 +92,33 @@ class TokenEndpoint(object):
|
|||
self.params.redirect_uri)
|
||||
raise TokenError('invalid_grant')
|
||||
|
||||
except Client.DoesNotExist:
|
||||
logger.error('[Token] Client does not exist: %s', self.params.client_id)
|
||||
raise TokenError('invalid_client')
|
||||
elif self.params.grant_type == 'refresh_token':
|
||||
if not self.params.refresh_token:
|
||||
logger.error('[Token] Missing refresh token')
|
||||
raise TokenError('invalid_grant')
|
||||
|
||||
except Code.DoesNotExist:
|
||||
logger.error('[Token] Code does not exist: %s', self.params.code)
|
||||
raise TokenError('invalid_grant')
|
||||
try:
|
||||
self.token = Token.objects.get(refresh_token=self.params.refresh_token,
|
||||
client=self.client)
|
||||
|
||||
except Token.DoesNotExist:
|
||||
logger.error('[Token] Refresh token does not exist: %s', self.params.refresh_token)
|
||||
raise TokenError('invalid_grant')
|
||||
|
||||
else:
|
||||
logger.error('[Token] Invalid grant type: %s', self.params.grant_type)
|
||||
raise TokenError('unsupported_grant_type')
|
||||
|
||||
def create_response_dic(self):
|
||||
if self.params.grant_type == 'authorization_code':
|
||||
return self.create_code_response_dic()
|
||||
elif self.params.grant_type == 'refresh_token':
|
||||
return self.create_refresh_response_dic()
|
||||
else:
|
||||
# Should have already been catched by validate_params
|
||||
raise RuntimeError('Invalid grant type')
|
||||
|
||||
def create_code_response_dic(self):
|
||||
id_token_dic = create_id_token(
|
||||
user=self.code.user,
|
||||
aud=self.client.client_id,
|
||||
|
@ -113,6 +139,36 @@ class TokenEndpoint(object):
|
|||
|
||||
dic = {
|
||||
'access_token': token.access_token,
|
||||
'refresh_token': token.refresh_token,
|
||||
'token_type': 'bearer',
|
||||
'expires_in': settings.get('OIDC_TOKEN_EXPIRE'),
|
||||
'id_token': encode_id_token(id_token_dic),
|
||||
}
|
||||
|
||||
return dic
|
||||
|
||||
def create_refresh_response_dic(self):
|
||||
id_token_dic = create_id_token(
|
||||
user=self.token.user,
|
||||
aud=self.client.client_id,
|
||||
nonce=None,
|
||||
)
|
||||
|
||||
token = create_token(
|
||||
user=self.token.user,
|
||||
client=self.token.client,
|
||||
id_token_dic=id_token_dic,
|
||||
scope=self.token.scope)
|
||||
|
||||
# Store the token.
|
||||
token.save()
|
||||
|
||||
# Forget the old token.
|
||||
self.token.delete()
|
||||
|
||||
dic = {
|
||||
'access_token': token.access_token,
|
||||
'refresh_token': token.refresh_token,
|
||||
'token_type': 'bearer',
|
||||
'expires_in': settings.get('OIDC_TOKEN_EXPIRE'),
|
||||
'id_token': encode_id_token(id_token_dic),
|
||||
|
|
20
oidc_provider/migrations/0005_token_refresh_token.py
Normal file
20
oidc_provider/migrations/0005_token_refresh_token.py
Normal file
|
@ -0,0 +1,20 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import models, migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('oidc_provider', '0004_remove_userinfo'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='token',
|
||||
name='refresh_token',
|
||||
field=models.CharField(max_length=255, unique=True, null=True),
|
||||
preserve_default=True,
|
||||
),
|
||||
]
|
|
@ -77,6 +77,7 @@ class Code(BaseCodeTokenModel):
|
|||
class Token(BaseCodeTokenModel):
|
||||
|
||||
access_token = models.CharField(max_length=255, unique=True)
|
||||
refresh_token = models.CharField(max_length=255, unique=True, null=True)
|
||||
_id_token = models.TextField()
|
||||
def id_token():
|
||||
def fget(self):
|
||||
|
|
|
@ -7,7 +7,7 @@ except ImportError:
|
|||
import uuid
|
||||
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.test import RequestFactory
|
||||
from django.test import RequestFactory, override_settings
|
||||
from django.test import TestCase
|
||||
from jwkest.jwk import KEYS
|
||||
from jwkest.jws import JWS
|
||||
|
@ -16,6 +16,7 @@ from jwkest.jwt import JWT
|
|||
from oidc_provider.lib.utils.token import *
|
||||
from oidc_provider.tests.app.utils import *
|
||||
from oidc_provider.views import *
|
||||
from mock import patch
|
||||
|
||||
|
||||
class TokenTestCase(TestCase):
|
||||
|
@ -30,7 +31,7 @@ class TokenTestCase(TestCase):
|
|||
self.user = create_fake_user()
|
||||
self.client = create_fake_client(response_type='code')
|
||||
|
||||
def _post_data(self, code):
|
||||
def _auth_code_post_data(self, code):
|
||||
"""
|
||||
All the data that will be POSTed to the Token Endpoint.
|
||||
"""
|
||||
|
@ -45,6 +46,19 @@ class TokenTestCase(TestCase):
|
|||
|
||||
return post_data
|
||||
|
||||
def _refresh_token_post_data(self, refresh_token):
|
||||
"""
|
||||
All the data that will be POSTed to the Token Endpoint.
|
||||
"""
|
||||
post_data = {
|
||||
'client_id': self.client.client_id,
|
||||
'client_secret': self.client.client_secret,
|
||||
'grant_type': 'refresh_token',
|
||||
'refresh_token': refresh_token,
|
||||
}
|
||||
|
||||
return post_data
|
||||
|
||||
def _post_request(self, post_data, extras={}):
|
||||
"""
|
||||
Makes a request to the token endpoint by sending the
|
||||
|
@ -75,6 +89,109 @@ class TokenTestCase(TestCase):
|
|||
|
||||
return code
|
||||
|
||||
def _get_keys(self):
|
||||
"""
|
||||
Get public key from discovery.
|
||||
"""
|
||||
request = self.factory.get(reverse('oidc_provider:jwks'))
|
||||
response = JwksView.as_view()(request)
|
||||
jwks_dic = json.loads(response.content.decode('utf-8'))
|
||||
SIGKEYS = KEYS()
|
||||
SIGKEYS.load_dict(jwks_dic)
|
||||
return SIGKEYS
|
||||
|
||||
def _get_userinfo(self, access_token):
|
||||
url = reverse('oidc_provider:userinfo')
|
||||
request = self.factory.get(url)
|
||||
request.META['HTTP_AUTHORIZATION'] = 'Bearer ' + access_token
|
||||
|
||||
return userinfo(request)
|
||||
|
||||
@override_settings(OIDC_TOKEN_EXPIRE=720)
|
||||
def test_authorization_code(self):
|
||||
"""
|
||||
We MUST validate the signature of the ID Token according to JWS
|
||||
using the algorithm specified in the alg Header Parameter of
|
||||
the JOSE Header.
|
||||
"""
|
||||
SIGKEYS = self._get_keys()
|
||||
code = self._create_code()
|
||||
|
||||
post_data = self._auth_code_post_data(code=code.code)
|
||||
|
||||
response = self._post_request(post_data)
|
||||
response_dic = json.loads(response.content.decode('utf-8'))
|
||||
|
||||
id_token = JWS().verify_compact(response_dic['id_token'].encode('utf-8'), SIGKEYS)
|
||||
|
||||
token = Token.objects.get(user=self.user)
|
||||
self.assertEqual(response_dic['access_token'], token.access_token)
|
||||
self.assertEqual(response_dic['refresh_token'], token.refresh_token)
|
||||
self.assertEqual(response_dic['token_type'], 'bearer')
|
||||
self.assertEqual(response_dic['expires_in'], 720)
|
||||
self.assertEqual(id_token['sub'], str(self.user.id))
|
||||
self.assertEqual(id_token['aud'], self.client.client_id)
|
||||
|
||||
def test_refresh_token(self):
|
||||
"""
|
||||
A request to the Token Endpoint can also use a Refresh Token
|
||||
by using the grant_type value refresh_token, as described in
|
||||
Section 6 of OAuth 2.0 [RFC6749].
|
||||
"""
|
||||
SIGKEYS = self._get_keys()
|
||||
|
||||
# Retrieve refresh token
|
||||
code = self._create_code()
|
||||
post_data = self._auth_code_post_data(code=code.code)
|
||||
real_now = timezone.now
|
||||
with patch('oidc_provider.lib.utils.token.timezone.now') as now:
|
||||
now.return_value = real_now()
|
||||
response = self._post_request(post_data)
|
||||
|
||||
response_dic1 = json.loads(response.content.decode('utf-8'))
|
||||
id_token1 = JWS().verify_compact(response_dic1['id_token'].encode('utf-8'), SIGKEYS)
|
||||
|
||||
# Use refresh token to obtain new token
|
||||
post_data = self._refresh_token_post_data(response_dic1['refresh_token'])
|
||||
with patch('oidc_provider.lib.utils.token.timezone.now') as now:
|
||||
now.return_value = real_now() + timedelta(minutes=10)
|
||||
response = self._post_request(post_data)
|
||||
|
||||
response_dic2 = json.loads(response.content.decode('utf-8'))
|
||||
id_token2 = JWS().verify_compact(response_dic2['id_token'].encode('utf-8'), SIGKEYS)
|
||||
|
||||
self.assertNotEqual(response_dic1['id_token'], response_dic2['id_token'])
|
||||
self.assertNotEqual(response_dic1['access_token'], response_dic2['access_token'])
|
||||
self.assertNotEqual(response_dic1['refresh_token'], response_dic2['refresh_token'])
|
||||
|
||||
# http://openid.net/specs/openid-connect-core-1_0.html#rfc.section.12.2
|
||||
self.assertEqual(id_token1['iss'], id_token2['iss'])
|
||||
self.assertEqual(id_token1['sub'], id_token2['sub'])
|
||||
self.assertNotEqual(id_token1['iat'], id_token2['iat'])
|
||||
self.assertEqual(id_token1['aud'], id_token2['aud'])
|
||||
self.assertEqual(id_token1['auth_time'], id_token2['auth_time'])
|
||||
self.assertEqual(id_token1.get('azp'), id_token2.get('azp'))
|
||||
|
||||
# Refresh token can't be reused
|
||||
post_data = self._refresh_token_post_data(response_dic1['refresh_token'])
|
||||
response = self._post_request(post_data)
|
||||
self.assertIn('invalid_grant', response.content.decode('utf-8'))
|
||||
|
||||
# Old access token is invalidated
|
||||
self.assertEqual(self._get_userinfo(response_dic1['access_token']).status_code, 401)
|
||||
self.assertEqual(self._get_userinfo(response_dic2['access_token']).status_code, 200)
|
||||
|
||||
# Empty refresh token is invalid
|
||||
post_data = self._refresh_token_post_data('')
|
||||
response = self._post_request(post_data)
|
||||
self.assertIn('invalid_grant', response.content.decode('utf-8'))
|
||||
|
||||
# No refresh token is invalid
|
||||
post_data = self._refresh_token_post_data('')
|
||||
del post_data['refresh_token']
|
||||
response = self._post_request(post_data)
|
||||
self.assertIn('invalid_grant', response.content.decode('utf-8'))
|
||||
|
||||
def test_request_methods(self):
|
||||
"""
|
||||
Client sends an HTTP POST request to the Token Endpoint. Other request
|
||||
|
@ -112,7 +229,7 @@ class TokenTestCase(TestCase):
|
|||
code = self._create_code()
|
||||
|
||||
# Test a valid request to the token endpoint.
|
||||
post_data = self._post_data(code=code.code)
|
||||
post_data = self._auth_code_post_data(code=code.code)
|
||||
|
||||
response = self._post_request(post_data)
|
||||
|
||||
|
@ -168,7 +285,7 @@ class TokenTestCase(TestCase):
|
|||
"""
|
||||
code = self._create_code()
|
||||
|
||||
post_data = self._post_data(code=code.code)
|
||||
post_data = self._auth_code_post_data(code=code.code)
|
||||
|
||||
response = self._post_request(post_data)
|
||||
|
||||
|
@ -194,17 +311,12 @@ class TokenTestCase(TestCase):
|
|||
using the algorithm specified in the alg Header Parameter of
|
||||
the JOSE Header.
|
||||
"""
|
||||
# Get public key from discovery.
|
||||
request = self.factory.get(reverse('oidc_provider:jwks'))
|
||||
response = JwksView.as_view()(request)
|
||||
jwks_dic = json.loads(response.content.decode('utf-8'))
|
||||
SIGKEYS = KEYS()
|
||||
SIGKEYS.load_dict(jwks_dic)
|
||||
SIGKEYS = self._get_keys()
|
||||
RSAKEYS = [ k for k in SIGKEYS if k.kty == 'RSA' ]
|
||||
|
||||
code = self._create_code()
|
||||
|
||||
post_data = self._post_data(code=code.code)
|
||||
post_data = self._auth_code_post_data(code=code.code)
|
||||
|
||||
response = self._post_request(post_data)
|
||||
response_dic = json.loads(response.content.decode('utf-8'))
|
||||
|
|
1
setup.py
1
setup.py
|
@ -38,6 +38,7 @@ setup(
|
|||
],
|
||||
tests_require=[
|
||||
'pyjwkest>=1.0.3,<1.1',
|
||||
'mock==1.3.0',
|
||||
],
|
||||
|
||||
install_requires=[
|
||||
|
|
Loading…
Reference in a new issue