Add support for refresh_token to token endpoint
This commit is contained in:
parent
03d2770f5e
commit
8d672cc1ba
4 changed files with 186 additions and 26 deletions
|
@ -35,6 +35,8 @@ class TokenEndpoint(object):
|
||||||
self.params.grant_type = self.request.POST.get('grant_type', '')
|
self.params.grant_type = self.request.POST.get('grant_type', '')
|
||||||
self.params.code = self.request.POST.get('code', '')
|
self.params.code = self.request.POST.get('code', '')
|
||||||
self.params.state = self.request.POST.get('state', '')
|
self.params.state = self.request.POST.get('state', '')
|
||||||
|
self.params.scope = self.request.POST.get('scope', '')
|
||||||
|
self.params.refresh_token = self.request.POST.get('refresh_token', '')
|
||||||
|
|
||||||
def _extract_client_auth(self):
|
def _extract_client_auth(self):
|
||||||
"""
|
"""
|
||||||
|
@ -60,39 +62,63 @@ class TokenEndpoint(object):
|
||||||
return (client_id, client_secret)
|
return (client_id, client_secret)
|
||||||
|
|
||||||
def validate_params(self):
|
def validate_params(self):
|
||||||
if not (self.params.grant_type == 'authorization_code'):
|
|
||||||
logger.error('[Token] Invalid grant type: %s', self.params.grant_type)
|
|
||||||
raise TokenError('unsupported_grant_type')
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.client = Client.objects.get(client_id=self.params.client_id)
|
self.client = Client.objects.get(client_id=self.params.client_id)
|
||||||
|
|
||||||
|
except Client.DoesNotExist:
|
||||||
|
logger.error('[Token] Client does not exist: %s', self.params.client_id)
|
||||||
|
raise TokenError('invalid_client')
|
||||||
|
|
||||||
if not (self.client.client_secret == self.params.client_secret):
|
if not (self.client.client_secret == self.params.client_secret):
|
||||||
logger.error('[Token] Invalid client secret: client %s do not have secret %s',
|
logger.error('[Token] Invalid client secret: client %s do not have secret %s',
|
||||||
self.client.client_id, self.client.client_secret)
|
self.client.client_id, self.client.client_secret)
|
||||||
raise TokenError('invalid_client')
|
raise TokenError('invalid_client')
|
||||||
|
|
||||||
|
if self.params.grant_type == 'authorization_code':
|
||||||
if not (self.params.redirect_uri in self.client.redirect_uris):
|
if not (self.params.redirect_uri in self.client.redirect_uris):
|
||||||
logger.error('[Token] Invalid redirect uri: %s', self.params.redirect_uri)
|
logger.error('[Token] Invalid redirect uri: %s', self.params.redirect_uri)
|
||||||
raise TokenError('invalid_client')
|
raise TokenError('invalid_client')
|
||||||
|
|
||||||
|
try:
|
||||||
self.code = Code.objects.get(code=self.params.code)
|
self.code = Code.objects.get(code=self.params.code)
|
||||||
|
|
||||||
|
except Code.DoesNotExist:
|
||||||
|
logger.error('[Token] Code does not exist: %s', self.params.code)
|
||||||
|
raise TokenError('invalid_grant')
|
||||||
|
|
||||||
if not (self.code.client == self.client) \
|
if not (self.code.client == self.client) \
|
||||||
or self.code.has_expired():
|
or self.code.has_expired():
|
||||||
logger.error('[Token] Invalid code: invalid client or code has expired',
|
logger.error('[Token] Invalid code: invalid client or code has expired',
|
||||||
self.params.redirect_uri)
|
self.params.redirect_uri)
|
||||||
raise TokenError('invalid_grant')
|
raise TokenError('invalid_grant')
|
||||||
|
|
||||||
except Client.DoesNotExist:
|
elif self.params.grant_type == 'refresh_token':
|
||||||
logger.error('[Token] Client does not exist: %s', self.params.client_id)
|
if not self.params.refresh_token:
|
||||||
raise TokenError('invalid_client')
|
logger.error('[Token] Missing refresh token')
|
||||||
|
|
||||||
except Code.DoesNotExist:
|
|
||||||
logger.error('[Token] Code does not exist: %s', self.params.code)
|
|
||||||
raise TokenError('invalid_grant')
|
raise TokenError('invalid_grant')
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.token = Token.objects.get(refresh_token=self.params.refresh_token,
|
||||||
|
client=self.client)
|
||||||
|
|
||||||
|
except Token.DoesNotExist:
|
||||||
|
logger.error('[Token] Refresh token does not exist: %s', self.params.refresh_token)
|
||||||
|
raise TokenError('invalid_grant')
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.error('[Token] Invalid grant type: %s', self.params.grant_type)
|
||||||
|
raise TokenError('unsupported_grant_type')
|
||||||
|
|
||||||
def create_response_dic(self):
|
def create_response_dic(self):
|
||||||
|
if self.params.grant_type == 'authorization_code':
|
||||||
|
return self.create_code_response_dic()
|
||||||
|
elif self.params.grant_type == 'refresh_token':
|
||||||
|
return self.create_refresh_response_dic()
|
||||||
|
else:
|
||||||
|
# Should have already been catched by validate_params
|
||||||
|
raise RuntimeError('Invalid grant type')
|
||||||
|
|
||||||
|
def create_code_response_dic(self):
|
||||||
id_token_dic = create_id_token(
|
id_token_dic = create_id_token(
|
||||||
user=self.code.user,
|
user=self.code.user,
|
||||||
aud=self.client.client_id,
|
aud=self.client.client_id,
|
||||||
|
@ -113,6 +139,37 @@ class TokenEndpoint(object):
|
||||||
|
|
||||||
dic = {
|
dic = {
|
||||||
'access_token': token.access_token,
|
'access_token': token.access_token,
|
||||||
|
'refresh_token': token.refresh_token,
|
||||||
|
'token_type': 'bearer',
|
||||||
|
'expires_in': settings.get('OIDC_TOKEN_EXPIRE'),
|
||||||
|
'id_token': encode_id_token(id_token_dic),
|
||||||
|
}
|
||||||
|
|
||||||
|
return dic
|
||||||
|
|
||||||
|
def create_refresh_response_dic(self):
|
||||||
|
id_token_dic = create_id_token(
|
||||||
|
user=self.token.user,
|
||||||
|
aud=self.client.client_id,
|
||||||
|
nonce=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
token = create_token(
|
||||||
|
user=self.token.user,
|
||||||
|
client=self.token.client,
|
||||||
|
id_token_dic=id_token_dic,
|
||||||
|
scope=self.token.scope)
|
||||||
|
|
||||||
|
# Store the token.
|
||||||
|
token.save()
|
||||||
|
|
||||||
|
# We don't need to store the code anymore.
|
||||||
|
self.token.refresh_token = None
|
||||||
|
self.token.save()
|
||||||
|
|
||||||
|
dic = {
|
||||||
|
'access_token': token.access_token,
|
||||||
|
'refresh_token': token.refresh_token,
|
||||||
'token_type': 'bearer',
|
'token_type': 'bearer',
|
||||||
'expires_in': settings.get('OIDC_TOKEN_EXPIRE'),
|
'expires_in': settings.get('OIDC_TOKEN_EXPIRE'),
|
||||||
'id_token': encode_id_token(id_token_dic),
|
'id_token': encode_id_token(id_token_dic),
|
||||||
|
|
|
@ -7,7 +7,7 @@ except ImportError:
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from django.core.urlresolvers import reverse
|
from django.core.urlresolvers import reverse
|
||||||
from django.test import RequestFactory
|
from django.test import RequestFactory, override_settings
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from jwkest.jwk import KEYS
|
from jwkest.jwk import KEYS
|
||||||
from jwkest.jws import JWS
|
from jwkest.jws import JWS
|
||||||
|
@ -16,6 +16,7 @@ from jwkest.jwt import JWT
|
||||||
from oidc_provider.lib.utils.token import *
|
from oidc_provider.lib.utils.token import *
|
||||||
from oidc_provider.tests.app.utils import *
|
from oidc_provider.tests.app.utils import *
|
||||||
from oidc_provider.views import *
|
from oidc_provider.views import *
|
||||||
|
from mock import patch
|
||||||
|
|
||||||
|
|
||||||
class TokenTestCase(TestCase):
|
class TokenTestCase(TestCase):
|
||||||
|
@ -30,7 +31,7 @@ class TokenTestCase(TestCase):
|
||||||
self.user = create_fake_user()
|
self.user = create_fake_user()
|
||||||
self.client = create_fake_client(response_type='code')
|
self.client = create_fake_client(response_type='code')
|
||||||
|
|
||||||
def _post_data(self, code):
|
def _auth_code_post_data(self, code):
|
||||||
"""
|
"""
|
||||||
All the data that will be POSTed to the Token Endpoint.
|
All the data that will be POSTed to the Token Endpoint.
|
||||||
"""
|
"""
|
||||||
|
@ -45,6 +46,19 @@ class TokenTestCase(TestCase):
|
||||||
|
|
||||||
return post_data
|
return post_data
|
||||||
|
|
||||||
|
def _refresh_token_post_data(self, refresh_token):
|
||||||
|
"""
|
||||||
|
All the data that will be POSTed to the Token Endpoint.
|
||||||
|
"""
|
||||||
|
post_data = {
|
||||||
|
'client_id': self.client.client_id,
|
||||||
|
'client_secret': self.client.client_secret,
|
||||||
|
'grant_type': 'refresh_token',
|
||||||
|
'refresh_token': refresh_token,
|
||||||
|
}
|
||||||
|
|
||||||
|
return post_data
|
||||||
|
|
||||||
def _post_request(self, post_data, extras={}):
|
def _post_request(self, post_data, extras={}):
|
||||||
"""
|
"""
|
||||||
Makes a request to the token endpoint by sending the
|
Makes a request to the token endpoint by sending the
|
||||||
|
@ -75,6 +89,98 @@ class TokenTestCase(TestCase):
|
||||||
|
|
||||||
return code
|
return code
|
||||||
|
|
||||||
|
def _get_keys(self):
|
||||||
|
"""
|
||||||
|
Get public key from discovery.
|
||||||
|
"""
|
||||||
|
request = self.factory.get(reverse('oidc_provider:jwks'))
|
||||||
|
response = JwksView.as_view()(request)
|
||||||
|
jwks_dic = json.loads(response.content.decode('utf-8'))
|
||||||
|
SIGKEYS = KEYS()
|
||||||
|
SIGKEYS.load_dict(jwks_dic)
|
||||||
|
return SIGKEYS
|
||||||
|
|
||||||
|
@override_settings(OIDC_TOKEN_EXPIRE=720)
|
||||||
|
def test_authorization_code(self):
|
||||||
|
"""
|
||||||
|
We MUST validate the signature of the ID Token according to JWS
|
||||||
|
using the algorithm specified in the alg Header Parameter of
|
||||||
|
the JOSE Header.
|
||||||
|
"""
|
||||||
|
SIGKEYS = self._get_keys()
|
||||||
|
code = self._create_code()
|
||||||
|
|
||||||
|
post_data = self._auth_code_post_data(code=code.code)
|
||||||
|
|
||||||
|
response = self._post_request(post_data)
|
||||||
|
response_dic = json.loads(response.content.decode('utf-8'))
|
||||||
|
|
||||||
|
id_token = JWS().verify_compact(response_dic['id_token'].encode('utf-8'), SIGKEYS)
|
||||||
|
|
||||||
|
token = Token.objects.get(user=self.user)
|
||||||
|
self.assertEqual(response_dic['access_token'], token.access_token)
|
||||||
|
self.assertEqual(response_dic['refresh_token'], token.refresh_token)
|
||||||
|
self.assertEqual(response_dic['token_type'], 'bearer')
|
||||||
|
self.assertEqual(response_dic['expires_in'], 720)
|
||||||
|
self.assertEqual(id_token['sub'], str(self.user.id))
|
||||||
|
self.assertEqual(id_token['aud'], self.client.client_id)
|
||||||
|
|
||||||
|
def test_refresh_token(self):
|
||||||
|
"""
|
||||||
|
A request to the Token Endpoint can also use a Refresh Token
|
||||||
|
by using the grant_type value refresh_token, as described in
|
||||||
|
Section 6 of OAuth 2.0 [RFC6749].
|
||||||
|
"""
|
||||||
|
SIGKEYS = self._get_keys()
|
||||||
|
|
||||||
|
# Retrieve refresh token
|
||||||
|
code = self._create_code()
|
||||||
|
post_data = self._auth_code_post_data(code=code.code)
|
||||||
|
real_now = timezone.now
|
||||||
|
with patch('oidc_provider.lib.utils.token.timezone.now') as now:
|
||||||
|
now.return_value = real_now()
|
||||||
|
response = self._post_request(post_data)
|
||||||
|
|
||||||
|
response_dic1 = json.loads(response.content.decode('utf-8'))
|
||||||
|
id_token1 = JWS().verify_compact(response_dic1['id_token'].encode('utf-8'), SIGKEYS)
|
||||||
|
|
||||||
|
# Use refresh token to obtain new token
|
||||||
|
post_data = self._refresh_token_post_data(response_dic1['refresh_token'])
|
||||||
|
with patch('oidc_provider.lib.utils.token.timezone.now') as now:
|
||||||
|
now.return_value = real_now() + timedelta(minutes=10)
|
||||||
|
response = self._post_request(post_data)
|
||||||
|
|
||||||
|
response_dic2 = json.loads(response.content.decode('utf-8'))
|
||||||
|
id_token2 = JWS().verify_compact(response_dic2['id_token'].encode('utf-8'), SIGKEYS)
|
||||||
|
|
||||||
|
self.assertNotEqual(response_dic1['id_token'], response_dic2['id_token'])
|
||||||
|
self.assertNotEqual(response_dic1['access_token'], response_dic2['access_token'])
|
||||||
|
self.assertNotEqual(response_dic1['refresh_token'], response_dic2['refresh_token'])
|
||||||
|
|
||||||
|
# http://openid.net/specs/openid-connect-core-1_0.html#rfc.section.12.2
|
||||||
|
self.assertEqual(id_token1['iss'], id_token2['iss'])
|
||||||
|
self.assertEqual(id_token1['sub'], id_token2['sub'])
|
||||||
|
self.assertNotEqual(id_token1['iat'], id_token2['iat'])
|
||||||
|
self.assertEqual(id_token1['aud'], id_token2['aud'])
|
||||||
|
self.assertEqual(id_token1['auth_time'], id_token2['auth_time'])
|
||||||
|
self.assertEqual(id_token1.get('azp'), id_token2.get('azp'))
|
||||||
|
|
||||||
|
# Refresh token can't be reused
|
||||||
|
post_data = self._refresh_token_post_data(response_dic1['refresh_token'])
|
||||||
|
response = self._post_request(post_data)
|
||||||
|
self.assertIn('invalid_grant', response.content.decode('utf-8'))
|
||||||
|
|
||||||
|
# Empty refresh token is invalid
|
||||||
|
post_data = self._refresh_token_post_data('')
|
||||||
|
response = self._post_request(post_data)
|
||||||
|
self.assertIn('invalid_grant', response.content.decode('utf-8'))
|
||||||
|
|
||||||
|
# No refresh token is invalid
|
||||||
|
post_data = self._refresh_token_post_data('')
|
||||||
|
del post_data['refresh_token']
|
||||||
|
response = self._post_request(post_data)
|
||||||
|
self.assertIn('invalid_grant', response.content.decode('utf-8'))
|
||||||
|
|
||||||
def test_request_methods(self):
|
def test_request_methods(self):
|
||||||
"""
|
"""
|
||||||
Client sends an HTTP POST request to the Token Endpoint. Other request
|
Client sends an HTTP POST request to the Token Endpoint. Other request
|
||||||
|
@ -112,7 +218,7 @@ class TokenTestCase(TestCase):
|
||||||
code = self._create_code()
|
code = self._create_code()
|
||||||
|
|
||||||
# Test a valid request to the token endpoint.
|
# Test a valid request to the token endpoint.
|
||||||
post_data = self._post_data(code=code.code)
|
post_data = self._auth_code_post_data(code=code.code)
|
||||||
|
|
||||||
response = self._post_request(post_data)
|
response = self._post_request(post_data)
|
||||||
|
|
||||||
|
@ -168,7 +274,7 @@ class TokenTestCase(TestCase):
|
||||||
"""
|
"""
|
||||||
code = self._create_code()
|
code = self._create_code()
|
||||||
|
|
||||||
post_data = self._post_data(code=code.code)
|
post_data = self._auth_code_post_data(code=code.code)
|
||||||
|
|
||||||
response = self._post_request(post_data)
|
response = self._post_request(post_data)
|
||||||
|
|
||||||
|
@ -194,17 +300,12 @@ class TokenTestCase(TestCase):
|
||||||
using the algorithm specified in the alg Header Parameter of
|
using the algorithm specified in the alg Header Parameter of
|
||||||
the JOSE Header.
|
the JOSE Header.
|
||||||
"""
|
"""
|
||||||
# Get public key from discovery.
|
SIGKEYS = self._get_keys()
|
||||||
request = self.factory.get(reverse('oidc_provider:jwks'))
|
|
||||||
response = JwksView.as_view()(request)
|
|
||||||
jwks_dic = json.loads(response.content.decode('utf-8'))
|
|
||||||
SIGKEYS = KEYS()
|
|
||||||
SIGKEYS.load_dict(jwks_dic)
|
|
||||||
RSAKEYS = [ k for k in SIGKEYS if k.kty == 'RSA' ]
|
RSAKEYS = [ k for k in SIGKEYS if k.kty == 'RSA' ]
|
||||||
|
|
||||||
code = self._create_code()
|
code = self._create_code()
|
||||||
|
|
||||||
post_data = self._post_data(code=code.code)
|
post_data = self._auth_code_post_data(code=code.code)
|
||||||
|
|
||||||
response = self._post_request(post_data)
|
response = self._post_request(post_data)
|
||||||
response_dic = json.loads(response.content.decode('utf-8'))
|
response_dic = json.loads(response.content.decode('utf-8'))
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -38,6 +38,7 @@ setup(
|
||||||
],
|
],
|
||||||
tests_require=[
|
tests_require=[
|
||||||
'pyjwkest==1.0.1',
|
'pyjwkest==1.0.1',
|
||||||
|
'mock==1.3.0',
|
||||||
],
|
],
|
||||||
|
|
||||||
install_requires=[
|
install_requires=[
|
||||||
|
|
1
tox.ini
1
tox.ini
|
@ -9,6 +9,7 @@ deps =
|
||||||
django17: django==1.7
|
django17: django==1.7
|
||||||
django18: django==1.8
|
django18: django==1.8
|
||||||
coverage
|
coverage
|
||||||
|
mock
|
||||||
|
|
||||||
commands =
|
commands =
|
||||||
pip install -e .
|
pip install -e .
|
||||||
|
|
Loading…
Reference in a new issue