diff --git a/oidc_provider/lib/endpoints/token.py b/oidc_provider/lib/endpoints/token.py index 2c62f33..3ba49f1 100644 --- a/oidc_provider/lib/endpoints/token.py +++ b/oidc_provider/lib/endpoints/token.py @@ -1,4 +1,6 @@ +from base64 import b64decode import logging +import re try: from urllib.parse import unquote except ImportError: @@ -24,15 +26,38 @@ class TokenEndpoint(object): self._extract_params() def _extract_params(self): - query_dict = self.request.POST + client_id, client_secret = self._extract_client_auth() - self.params.client_id = query_dict.get('client_id', '') - self.params.client_secret = query_dict.get('client_secret', '') + self.params.client_id = client_id + self.params.client_secret = client_secret self.params.redirect_uri = unquote( - query_dict.get('redirect_uri', '')) - self.params.grant_type = query_dict.get('grant_type', '') - self.params.code = query_dict.get('code', '') - self.params.state = query_dict.get('state', '') + self.request.POST.get('redirect_uri', '')) + self.params.grant_type = self.request.POST.get('grant_type', '') + self.params.code = self.request.POST.get('code', '') + self.params.state = self.request.POST.get('state', '') + + def _extract_client_auth(self): + """ + Get client credentials using HTTP Basic Authentication method. + Or try getting parameters via POST. + See: http://tools.ietf.org/html/rfc6750#section-2.1 + + Return a string. + """ + auth_header = self.request.META.get('HTTP_AUTHORIZATION', '') + + if re.compile('^Basic\s{1}.+$').match(auth_header): + b64_user_pass = auth_header.split()[1] + try: + user_pass = b64decode(b64_user_pass).decode('utf-8').split(':') + client_id, client_secret = tuple(user_pass) + except: + client_id = client_secret = '' + else: + client_id = self.request.POST.get('client_id', '') + client_secret = self.request.POST.get('client_secret', '') + + return (client_id, client_secret) def validate_params(self): if not (self.params.grant_type == 'authorization_code'): diff --git a/oidc_provider/lib/endpoints/userinfo.py b/oidc_provider/lib/endpoints/userinfo.py index 6f3f1b6..eb91ca0 100644 --- a/oidc_provider/lib/endpoints/userinfo.py +++ b/oidc_provider/lib/endpoints/userinfo.py @@ -1,5 +1,5 @@ -import re import logging +import re from django.http import HttpResponse from django.http import JsonResponse diff --git a/oidc_provider/tests/test_token_endpoint.py b/oidc_provider/tests/test_token_endpoint.py index 0f4805e..1aa2e1e 100644 --- a/oidc_provider/tests/test_token_endpoint.py +++ b/oidc_provider/tests/test_token_endpoint.py @@ -1,3 +1,4 @@ +from base64 import b64encode import json try: from urllib.parse import urlencode @@ -44,7 +45,7 @@ class TokenTestCase(TestCase): return post_data - def _post_request(self, post_data): + def _post_request(self, post_data, extras={}): """ Makes a request to the token endpoint by sending the `post_data` parameters using the 'application/x-www-form-urlencoded' @@ -54,7 +55,8 @@ class TokenTestCase(TestCase): request = self.factory.post(url, data=urlencode(post_data), - content_type='application/x-www-form-urlencoded') + content_type='application/x-www-form-urlencoded', + **extras) response = TokenView.as_view()(request) @@ -113,12 +115,10 @@ class TokenTestCase(TestCase): post_data = self._post_data(code=code.code) response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) - self.assertEqual('access_token' in response_dic, True, - msg='"access_token" key is missing in response.') - self.assertEqual('error' in response_dic, False, - msg='"error" key should not exists in response.') + self.assertEqual('invalid_client' in response.content.decode('utf-8'), + False, + msg='Client authentication fails using request-body credentials.') # Now, test with an invalid client_id. invalid_data = post_data.copy() @@ -129,12 +129,32 @@ class TokenTestCase(TestCase): invalid_data['code'] = code.code response = self._post_request(invalid_data) - response_dic = json.loads(response.content.decode('utf-8')) - self.assertEqual('error' in response_dic, True, - msg='"error" key should exists in response.') - self.assertEqual(response_dic.get('error') == 'invalid_client', True, - msg='"error" key value should be "invalid_client".') + self.assertEqual('invalid_client' in response.content.decode('utf-8'), + True, + msg='Client authentication success with an invalid "client_id".') + + # Now, test using HTTP Basic Authentication method. + basicauth_data = post_data.copy() + + # Create another grant code. + code = self._create_code() + basicauth_data['code'] = code.code + + del basicauth_data['client_id'] + del basicauth_data['client_secret'] + + # Generate HTTP Basic Auth header with id and secret. + user_pass = self.client.client_id + ':' + self.client.client_secret + auth_header = b'Basic ' + b64encode(user_pass.encode('utf-8')) + response = self._post_request(basicauth_data, { + 'HTTP_AUTHORIZATION': auth_header.decode('utf-8'), + }) + response.content.decode('utf-8') + + self.assertEqual('invalid_client' in response.content.decode('utf-8'), + False, + msg='Client authentication fails using HTTP Basic Auth.') def test_access_token_contains_nonce(self): """