diff --git a/oidc_provider/lib/endpoints/token.py b/oidc_provider/lib/endpoints/token.py index cf1e10c..327ed81 100644 --- a/oidc_provider/lib/endpoints/token.py +++ b/oidc_provider/lib/endpoints/token.py @@ -2,7 +2,7 @@ from base64 import b64decode, urlsafe_b64encode import hashlib import logging import re - +from django.contrib.auth import authenticate from oidc_provider.lib.utils.common import cleanup_url_from_query_string try: @@ -34,6 +34,7 @@ class TokenEndpoint(object): def __init__(self, request): self.request = request self.params = {} + self.user = None self._extract_params() def _extract_params(self): @@ -122,23 +123,15 @@ class TokenEndpoint(object): raise TokenError('invalid_grant') elif self.params['grant_type'] == 'password': - from django.contrib.auth import authenticate - user = authenticate(username=self.params['username'], password=self.params['password']) + user = authenticate( + username=self.params['username'], + password=self.params['password'] + ) + if not user: raise TokenError('Invalid user credentials') - self.token = create_token(user, self.client, self.params['scope'].split(' ')) - - self.token.id_token = create_id_token( - user=user, - aud=self.client.client_id, - nonce='self.code.nonce', - at_hash=self.token.at_hash, - request=self.request, - scope=self.params['scope'], - ) - - self.token.save() + self.user = user elif self.params['grant_type'] == 'refresh_token': if not self.params['refresh_token']: @@ -163,7 +156,30 @@ class TokenEndpoint(object): elif self.params['grant_type'] == 'refresh_token': return self.create_refresh_response_dic() elif self.params['grant_type'] == 'password': - return {'access_token': self.token.access_token} + return self.create_access_token_response_dic() + + def create_access_token_response_dic(self): + token = create_token( + self.user, + self.client, + self.params['scope'].split(' ')) + + token.id_token = create_id_token( + user=self.user, + aud=self.client.client_id, + nonce='self.code.nonce', + at_hash=token.at_hash, + request=self.request, + scope=self.params['scope'], + ) + + token.save() + return { + 'access_token': token.access_token, + 'refresh_token': token.refresh_token, + 'expires_in': settings.get('OIDC_TOKEN_EXPIRE'), + 'token_type': 'bearer' + } def create_code_response_dic(self): token = create_token( diff --git a/oidc_provider/tests/test_token_endpoint.py b/oidc_provider/tests/test_token_endpoint.py index 50f556e..1da2c07 100644 --- a/oidc_provider/tests/test_token_endpoint.py +++ b/oidc_provider/tests/test_token_endpoint.py @@ -19,7 +19,7 @@ from django.utils import timezone from jwkest.jwk import KEYS from jwkest.jws import JWS from jwkest.jwt import JWT -from mock import patch +from mock import patch, Mock from oidc_provider.lib.utils.token import create_code from oidc_provider.models import Token @@ -207,6 +207,29 @@ class TokenTestCase(TestCase): self.assertEqual(400, response.status_code) + @patch('oidc_provider.lib.utils.token.uuid') + @override_settings(OIDC_TOKEN_EXPIRE=120) + def test_password_grant_full_response(self, mock_uuid): + test_hex = 'fake_token' + mock_uuid4 = Mock(spec=uuid.uuid4) + mock_uuid4.hex = test_hex + mock_uuid.uuid4.return_value = mock_uuid4 + + response = self._post_request( + post_data=self._password_grant_post_data(), + extras=self._auth_header() + ) + + response_dict = json.loads(response.content.decode('utf-8')) + expected_response_dic = { + "access_token": 'fake_token', + "refresh_token": 'fake_token', + "expires_in": 120, + "token_type": "bearer", + } + + self.assertDictEqual(expected_response_dic, response_dict) + @override_settings(OIDC_TOKEN_EXPIRE=720) def test_authorization_code(self): """