From 8d672cc1ba945c72bc402197e605a50ceff6ce90 Mon Sep 17 00:00:00 2001 From: Maarten van Schaik Date: Wed, 30 Sep 2015 14:55:48 +0200 Subject: [PATCH] Add support for refresh_token to token endpoint --- oidc_provider/lib/endpoints/token.py | 87 ++++++++++++--- oidc_provider/tests/test_token_endpoint.py | 123 +++++++++++++++++++-- setup.py | 1 + tox.ini | 1 + 4 files changed, 186 insertions(+), 26 deletions(-) diff --git a/oidc_provider/lib/endpoints/token.py b/oidc_provider/lib/endpoints/token.py index 3ba49f1..b98d40a 100644 --- a/oidc_provider/lib/endpoints/token.py +++ b/oidc_provider/lib/endpoints/token.py @@ -35,6 +35,8 @@ class TokenEndpoint(object): self.params.grant_type = self.request.POST.get('grant_type', '') self.params.code = self.request.POST.get('code', '') self.params.state = self.request.POST.get('state', '') + self.params.scope = self.request.POST.get('scope', '') + self.params.refresh_token = self.request.POST.get('refresh_token', '') def _extract_client_auth(self): """ @@ -60,23 +62,29 @@ class TokenEndpoint(object): return (client_id, client_secret) def validate_params(self): - if not (self.params.grant_type == 'authorization_code'): - logger.error('[Token] Invalid grant type: %s', self.params.grant_type) - raise TokenError('unsupported_grant_type') - try: self.client = Client.objects.get(client_id=self.params.client_id) - if not (self.client.client_secret == self.params.client_secret): - logger.error('[Token] Invalid client secret: client %s do not have secret %s', - self.client.client_id, self.client.client_secret) - raise TokenError('invalid_client') + except Client.DoesNotExist: + logger.error('[Token] Client does not exist: %s', self.params.client_id) + raise TokenError('invalid_client') + if not (self.client.client_secret == self.params.client_secret): + logger.error('[Token] Invalid client secret: client %s do not have secret %s', + self.client.client_id, self.client.client_secret) + raise TokenError('invalid_client') + + if self.params.grant_type == 'authorization_code': if not (self.params.redirect_uri in self.client.redirect_uris): logger.error('[Token] Invalid redirect uri: %s', self.params.redirect_uri) raise TokenError('invalid_client') - self.code = Code.objects.get(code=self.params.code) + try: + self.code = Code.objects.get(code=self.params.code) + + except Code.DoesNotExist: + logger.error('[Token] Code does not exist: %s', self.params.code) + raise TokenError('invalid_grant') if not (self.code.client == self.client) \ or self.code.has_expired(): @@ -84,15 +92,33 @@ class TokenEndpoint(object): self.params.redirect_uri) raise TokenError('invalid_grant') - except Client.DoesNotExist: - logger.error('[Token] Client does not exist: %s', self.params.client_id) - raise TokenError('invalid_client') + elif self.params.grant_type == 'refresh_token': + if not self.params.refresh_token: + logger.error('[Token] Missing refresh token') + raise TokenError('invalid_grant') - except Code.DoesNotExist: - logger.error('[Token] Code does not exist: %s', self.params.code) - raise TokenError('invalid_grant') + try: + self.token = Token.objects.get(refresh_token=self.params.refresh_token, + client=self.client) + + except Token.DoesNotExist: + logger.error('[Token] Refresh token does not exist: %s', self.params.refresh_token) + raise TokenError('invalid_grant') + + else: + logger.error('[Token] Invalid grant type: %s', self.params.grant_type) + raise TokenError('unsupported_grant_type') def create_response_dic(self): + if self.params.grant_type == 'authorization_code': + return self.create_code_response_dic() + elif self.params.grant_type == 'refresh_token': + return self.create_refresh_response_dic() + else: + # Should have already been catched by validate_params + raise RuntimeError('Invalid grant type') + + def create_code_response_dic(self): id_token_dic = create_id_token( user=self.code.user, aud=self.client.client_id, @@ -113,6 +139,37 @@ class TokenEndpoint(object): dic = { 'access_token': token.access_token, + 'refresh_token': token.refresh_token, + 'token_type': 'bearer', + 'expires_in': settings.get('OIDC_TOKEN_EXPIRE'), + 'id_token': encode_id_token(id_token_dic), + } + + return dic + + def create_refresh_response_dic(self): + id_token_dic = create_id_token( + user=self.token.user, + aud=self.client.client_id, + nonce=None, + ) + + token = create_token( + user=self.token.user, + client=self.token.client, + id_token_dic=id_token_dic, + scope=self.token.scope) + + # Store the token. + token.save() + + # We don't need to store the code anymore. + self.token.refresh_token = None + self.token.save() + + dic = { + 'access_token': token.access_token, + 'refresh_token': token.refresh_token, 'token_type': 'bearer', 'expires_in': settings.get('OIDC_TOKEN_EXPIRE'), 'id_token': encode_id_token(id_token_dic), diff --git a/oidc_provider/tests/test_token_endpoint.py b/oidc_provider/tests/test_token_endpoint.py index 1aa2e1e..6a373b4 100644 --- a/oidc_provider/tests/test_token_endpoint.py +++ b/oidc_provider/tests/test_token_endpoint.py @@ -7,7 +7,7 @@ except ImportError: import uuid from django.core.urlresolvers import reverse -from django.test import RequestFactory +from django.test import RequestFactory, override_settings from django.test import TestCase from jwkest.jwk import KEYS from jwkest.jws import JWS @@ -16,6 +16,7 @@ from jwkest.jwt import JWT from oidc_provider.lib.utils.token import * from oidc_provider.tests.app.utils import * from oidc_provider.views import * +from mock import patch class TokenTestCase(TestCase): @@ -30,7 +31,7 @@ class TokenTestCase(TestCase): self.user = create_fake_user() self.client = create_fake_client(response_type='code') - def _post_data(self, code): + def _auth_code_post_data(self, code): """ All the data that will be POSTed to the Token Endpoint. """ @@ -45,6 +46,19 @@ class TokenTestCase(TestCase): return post_data + def _refresh_token_post_data(self, refresh_token): + """ + All the data that will be POSTed to the Token Endpoint. + """ + post_data = { + 'client_id': self.client.client_id, + 'client_secret': self.client.client_secret, + 'grant_type': 'refresh_token', + 'refresh_token': refresh_token, + } + + return post_data + def _post_request(self, post_data, extras={}): """ Makes a request to the token endpoint by sending the @@ -75,6 +89,98 @@ class TokenTestCase(TestCase): return code + def _get_keys(self): + """ + Get public key from discovery. + """ + request = self.factory.get(reverse('oidc_provider:jwks')) + response = JwksView.as_view()(request) + jwks_dic = json.loads(response.content.decode('utf-8')) + SIGKEYS = KEYS() + SIGKEYS.load_dict(jwks_dic) + return SIGKEYS + + @override_settings(OIDC_TOKEN_EXPIRE=720) + def test_authorization_code(self): + """ + We MUST validate the signature of the ID Token according to JWS + using the algorithm specified in the alg Header Parameter of + the JOSE Header. + """ + SIGKEYS = self._get_keys() + code = self._create_code() + + post_data = self._auth_code_post_data(code=code.code) + + response = self._post_request(post_data) + response_dic = json.loads(response.content.decode('utf-8')) + + id_token = JWS().verify_compact(response_dic['id_token'].encode('utf-8'), SIGKEYS) + + token = Token.objects.get(user=self.user) + self.assertEqual(response_dic['access_token'], token.access_token) + self.assertEqual(response_dic['refresh_token'], token.refresh_token) + self.assertEqual(response_dic['token_type'], 'bearer') + self.assertEqual(response_dic['expires_in'], 720) + self.assertEqual(id_token['sub'], str(self.user.id)) + self.assertEqual(id_token['aud'], self.client.client_id) + + def test_refresh_token(self): + """ + A request to the Token Endpoint can also use a Refresh Token + by using the grant_type value refresh_token, as described in + Section 6 of OAuth 2.0 [RFC6749]. + """ + SIGKEYS = self._get_keys() + + # Retrieve refresh token + code = self._create_code() + post_data = self._auth_code_post_data(code=code.code) + real_now = timezone.now + with patch('oidc_provider.lib.utils.token.timezone.now') as now: + now.return_value = real_now() + response = self._post_request(post_data) + + response_dic1 = json.loads(response.content.decode('utf-8')) + id_token1 = JWS().verify_compact(response_dic1['id_token'].encode('utf-8'), SIGKEYS) + + # Use refresh token to obtain new token + post_data = self._refresh_token_post_data(response_dic1['refresh_token']) + with patch('oidc_provider.lib.utils.token.timezone.now') as now: + now.return_value = real_now() + timedelta(minutes=10) + response = self._post_request(post_data) + + response_dic2 = json.loads(response.content.decode('utf-8')) + id_token2 = JWS().verify_compact(response_dic2['id_token'].encode('utf-8'), SIGKEYS) + + self.assertNotEqual(response_dic1['id_token'], response_dic2['id_token']) + self.assertNotEqual(response_dic1['access_token'], response_dic2['access_token']) + self.assertNotEqual(response_dic1['refresh_token'], response_dic2['refresh_token']) + + # http://openid.net/specs/openid-connect-core-1_0.html#rfc.section.12.2 + self.assertEqual(id_token1['iss'], id_token2['iss']) + self.assertEqual(id_token1['sub'], id_token2['sub']) + self.assertNotEqual(id_token1['iat'], id_token2['iat']) + self.assertEqual(id_token1['aud'], id_token2['aud']) + self.assertEqual(id_token1['auth_time'], id_token2['auth_time']) + self.assertEqual(id_token1.get('azp'), id_token2.get('azp')) + + # Refresh token can't be reused + post_data = self._refresh_token_post_data(response_dic1['refresh_token']) + response = self._post_request(post_data) + self.assertIn('invalid_grant', response.content.decode('utf-8')) + + # Empty refresh token is invalid + post_data = self._refresh_token_post_data('') + response = self._post_request(post_data) + self.assertIn('invalid_grant', response.content.decode('utf-8')) + + # No refresh token is invalid + post_data = self._refresh_token_post_data('') + del post_data['refresh_token'] + response = self._post_request(post_data) + self.assertIn('invalid_grant', response.content.decode('utf-8')) + def test_request_methods(self): """ Client sends an HTTP POST request to the Token Endpoint. Other request @@ -112,7 +218,7 @@ class TokenTestCase(TestCase): code = self._create_code() # Test a valid request to the token endpoint. - post_data = self._post_data(code=code.code) + post_data = self._auth_code_post_data(code=code.code) response = self._post_request(post_data) @@ -168,7 +274,7 @@ class TokenTestCase(TestCase): """ code = self._create_code() - post_data = self._post_data(code=code.code) + post_data = self._auth_code_post_data(code=code.code) response = self._post_request(post_data) @@ -194,17 +300,12 @@ class TokenTestCase(TestCase): using the algorithm specified in the alg Header Parameter of the JOSE Header. """ - # Get public key from discovery. - request = self.factory.get(reverse('oidc_provider:jwks')) - response = JwksView.as_view()(request) - jwks_dic = json.loads(response.content.decode('utf-8')) - SIGKEYS = KEYS() - SIGKEYS.load_dict(jwks_dic) + SIGKEYS = self._get_keys() RSAKEYS = [ k for k in SIGKEYS if k.kty == 'RSA' ] code = self._create_code() - post_data = self._post_data(code=code.code) + post_data = self._auth_code_post_data(code=code.code) response = self._post_request(post_data) response_dic = json.loads(response.content.decode('utf-8')) diff --git a/setup.py b/setup.py index 1023408..fe06c08 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ setup( ], tests_require=[ 'pyjwkest==1.0.1', + 'mock==1.3.0', ], install_requires=[ diff --git a/tox.ini b/tox.ini index 2221008..857acb9 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,7 @@ deps = django17: django==1.7 django18: django==1.8 coverage + mock commands = pip install -e .