Add support for refresh_token to token endpoint

This commit is contained in:
Maarten van Schaik 2015-09-30 14:55:48 +02:00
parent 03d2770f5e
commit 8d672cc1ba
4 changed files with 186 additions and 26 deletions

View file

@ -35,6 +35,8 @@ class TokenEndpoint(object):
self.params.grant_type = self.request.POST.get('grant_type', '')
self.params.code = self.request.POST.get('code', '')
self.params.state = self.request.POST.get('state', '')
self.params.scope = self.request.POST.get('scope', '')
self.params.refresh_token = self.request.POST.get('refresh_token', '')
def _extract_client_auth(self):
"""
@ -60,23 +62,29 @@ class TokenEndpoint(object):
return (client_id, client_secret)
def validate_params(self):
if not (self.params.grant_type == 'authorization_code'):
logger.error('[Token] Invalid grant type: %s', self.params.grant_type)
raise TokenError('unsupported_grant_type')
try:
self.client = Client.objects.get(client_id=self.params.client_id)
if not (self.client.client_secret == self.params.client_secret):
logger.error('[Token] Invalid client secret: client %s do not have secret %s',
self.client.client_id, self.client.client_secret)
raise TokenError('invalid_client')
except Client.DoesNotExist:
logger.error('[Token] Client does not exist: %s', self.params.client_id)
raise TokenError('invalid_client')
if not (self.client.client_secret == self.params.client_secret):
logger.error('[Token] Invalid client secret: client %s do not have secret %s',
self.client.client_id, self.client.client_secret)
raise TokenError('invalid_client')
if self.params.grant_type == 'authorization_code':
if not (self.params.redirect_uri in self.client.redirect_uris):
logger.error('[Token] Invalid redirect uri: %s', self.params.redirect_uri)
raise TokenError('invalid_client')
self.code = Code.objects.get(code=self.params.code)
try:
self.code = Code.objects.get(code=self.params.code)
except Code.DoesNotExist:
logger.error('[Token] Code does not exist: %s', self.params.code)
raise TokenError('invalid_grant')
if not (self.code.client == self.client) \
or self.code.has_expired():
@ -84,15 +92,33 @@ class TokenEndpoint(object):
self.params.redirect_uri)
raise TokenError('invalid_grant')
except Client.DoesNotExist:
logger.error('[Token] Client does not exist: %s', self.params.client_id)
raise TokenError('invalid_client')
elif self.params.grant_type == 'refresh_token':
if not self.params.refresh_token:
logger.error('[Token] Missing refresh token')
raise TokenError('invalid_grant')
except Code.DoesNotExist:
logger.error('[Token] Code does not exist: %s', self.params.code)
raise TokenError('invalid_grant')
try:
self.token = Token.objects.get(refresh_token=self.params.refresh_token,
client=self.client)
except Token.DoesNotExist:
logger.error('[Token] Refresh token does not exist: %s', self.params.refresh_token)
raise TokenError('invalid_grant')
else:
logger.error('[Token] Invalid grant type: %s', self.params.grant_type)
raise TokenError('unsupported_grant_type')
def create_response_dic(self):
if self.params.grant_type == 'authorization_code':
return self.create_code_response_dic()
elif self.params.grant_type == 'refresh_token':
return self.create_refresh_response_dic()
else:
# Should have already been catched by validate_params
raise RuntimeError('Invalid grant type')
def create_code_response_dic(self):
id_token_dic = create_id_token(
user=self.code.user,
aud=self.client.client_id,
@ -113,6 +139,37 @@ class TokenEndpoint(object):
dic = {
'access_token': token.access_token,
'refresh_token': token.refresh_token,
'token_type': 'bearer',
'expires_in': settings.get('OIDC_TOKEN_EXPIRE'),
'id_token': encode_id_token(id_token_dic),
}
return dic
def create_refresh_response_dic(self):
id_token_dic = create_id_token(
user=self.token.user,
aud=self.client.client_id,
nonce=None,
)
token = create_token(
user=self.token.user,
client=self.token.client,
id_token_dic=id_token_dic,
scope=self.token.scope)
# Store the token.
token.save()
# We don't need to store the code anymore.
self.token.refresh_token = None
self.token.save()
dic = {
'access_token': token.access_token,
'refresh_token': token.refresh_token,
'token_type': 'bearer',
'expires_in': settings.get('OIDC_TOKEN_EXPIRE'),
'id_token': encode_id_token(id_token_dic),

View file

@ -7,7 +7,7 @@ except ImportError:
import uuid
from django.core.urlresolvers import reverse
from django.test import RequestFactory
from django.test import RequestFactory, override_settings
from django.test import TestCase
from jwkest.jwk import KEYS
from jwkest.jws import JWS
@ -16,6 +16,7 @@ from jwkest.jwt import JWT
from oidc_provider.lib.utils.token import *
from oidc_provider.tests.app.utils import *
from oidc_provider.views import *
from mock import patch
class TokenTestCase(TestCase):
@ -30,7 +31,7 @@ class TokenTestCase(TestCase):
self.user = create_fake_user()
self.client = create_fake_client(response_type='code')
def _post_data(self, code):
def _auth_code_post_data(self, code):
"""
All the data that will be POSTed to the Token Endpoint.
"""
@ -45,6 +46,19 @@ class TokenTestCase(TestCase):
return post_data
def _refresh_token_post_data(self, refresh_token):
"""
All the data that will be POSTed to the Token Endpoint.
"""
post_data = {
'client_id': self.client.client_id,
'client_secret': self.client.client_secret,
'grant_type': 'refresh_token',
'refresh_token': refresh_token,
}
return post_data
def _post_request(self, post_data, extras={}):
"""
Makes a request to the token endpoint by sending the
@ -75,6 +89,98 @@ class TokenTestCase(TestCase):
return code
def _get_keys(self):
"""
Get public key from discovery.
"""
request = self.factory.get(reverse('oidc_provider:jwks'))
response = JwksView.as_view()(request)
jwks_dic = json.loads(response.content.decode('utf-8'))
SIGKEYS = KEYS()
SIGKEYS.load_dict(jwks_dic)
return SIGKEYS
@override_settings(OIDC_TOKEN_EXPIRE=720)
def test_authorization_code(self):
"""
We MUST validate the signature of the ID Token according to JWS
using the algorithm specified in the alg Header Parameter of
the JOSE Header.
"""
SIGKEYS = self._get_keys()
code = self._create_code()
post_data = self._auth_code_post_data(code=code.code)
response = self._post_request(post_data)
response_dic = json.loads(response.content.decode('utf-8'))
id_token = JWS().verify_compact(response_dic['id_token'].encode('utf-8'), SIGKEYS)
token = Token.objects.get(user=self.user)
self.assertEqual(response_dic['access_token'], token.access_token)
self.assertEqual(response_dic['refresh_token'], token.refresh_token)
self.assertEqual(response_dic['token_type'], 'bearer')
self.assertEqual(response_dic['expires_in'], 720)
self.assertEqual(id_token['sub'], str(self.user.id))
self.assertEqual(id_token['aud'], self.client.client_id)
def test_refresh_token(self):
"""
A request to the Token Endpoint can also use a Refresh Token
by using the grant_type value refresh_token, as described in
Section 6 of OAuth 2.0 [RFC6749].
"""
SIGKEYS = self._get_keys()
# Retrieve refresh token
code = self._create_code()
post_data = self._auth_code_post_data(code=code.code)
real_now = timezone.now
with patch('oidc_provider.lib.utils.token.timezone.now') as now:
now.return_value = real_now()
response = self._post_request(post_data)
response_dic1 = json.loads(response.content.decode('utf-8'))
id_token1 = JWS().verify_compact(response_dic1['id_token'].encode('utf-8'), SIGKEYS)
# Use refresh token to obtain new token
post_data = self._refresh_token_post_data(response_dic1['refresh_token'])
with patch('oidc_provider.lib.utils.token.timezone.now') as now:
now.return_value = real_now() + timedelta(minutes=10)
response = self._post_request(post_data)
response_dic2 = json.loads(response.content.decode('utf-8'))
id_token2 = JWS().verify_compact(response_dic2['id_token'].encode('utf-8'), SIGKEYS)
self.assertNotEqual(response_dic1['id_token'], response_dic2['id_token'])
self.assertNotEqual(response_dic1['access_token'], response_dic2['access_token'])
self.assertNotEqual(response_dic1['refresh_token'], response_dic2['refresh_token'])
# http://openid.net/specs/openid-connect-core-1_0.html#rfc.section.12.2
self.assertEqual(id_token1['iss'], id_token2['iss'])
self.assertEqual(id_token1['sub'], id_token2['sub'])
self.assertNotEqual(id_token1['iat'], id_token2['iat'])
self.assertEqual(id_token1['aud'], id_token2['aud'])
self.assertEqual(id_token1['auth_time'], id_token2['auth_time'])
self.assertEqual(id_token1.get('azp'), id_token2.get('azp'))
# Refresh token can't be reused
post_data = self._refresh_token_post_data(response_dic1['refresh_token'])
response = self._post_request(post_data)
self.assertIn('invalid_grant', response.content.decode('utf-8'))
# Empty refresh token is invalid
post_data = self._refresh_token_post_data('')
response = self._post_request(post_data)
self.assertIn('invalid_grant', response.content.decode('utf-8'))
# No refresh token is invalid
post_data = self._refresh_token_post_data('')
del post_data['refresh_token']
response = self._post_request(post_data)
self.assertIn('invalid_grant', response.content.decode('utf-8'))
def test_request_methods(self):
"""
Client sends an HTTP POST request to the Token Endpoint. Other request
@ -112,7 +218,7 @@ class TokenTestCase(TestCase):
code = self._create_code()
# Test a valid request to the token endpoint.
post_data = self._post_data(code=code.code)
post_data = self._auth_code_post_data(code=code.code)
response = self._post_request(post_data)
@ -168,7 +274,7 @@ class TokenTestCase(TestCase):
"""
code = self._create_code()
post_data = self._post_data(code=code.code)
post_data = self._auth_code_post_data(code=code.code)
response = self._post_request(post_data)
@ -194,17 +300,12 @@ class TokenTestCase(TestCase):
using the algorithm specified in the alg Header Parameter of
the JOSE Header.
"""
# Get public key from discovery.
request = self.factory.get(reverse('oidc_provider:jwks'))
response = JwksView.as_view()(request)
jwks_dic = json.loads(response.content.decode('utf-8'))
SIGKEYS = KEYS()
SIGKEYS.load_dict(jwks_dic)
SIGKEYS = self._get_keys()
RSAKEYS = [ k for k in SIGKEYS if k.kty == 'RSA' ]
code = self._create_code()
post_data = self._post_data(code=code.code)
post_data = self._auth_code_post_data(code=code.code)
response = self._post_request(post_data)
response_dic = json.loads(response.content.decode('utf-8'))

View file

@ -38,6 +38,7 @@ setup(
],
tests_require=[
'pyjwkest==1.0.1',
'mock==1.3.0',
],
install_requires=[

View file

@ -9,6 +9,7 @@ deps =
django17: django==1.7
django18: django==1.8
coverage
mock
commands =
pip install -e .