diff --git a/oidc_provider/lib/endpoints/authorize.py b/oidc_provider/lib/endpoints/authorize.py index 56972a4..3bb6409 100644 --- a/oidc_provider/lib/endpoints/authorize.py +++ b/oidc_provider/lib/endpoints/authorize.py @@ -56,6 +56,10 @@ class AuthorizeEndpoint(object): self.params.state = query_dict.get('state', '') self.params.nonce = query_dict.get('nonce', '') + # PKCE parameters. + self.params.code_challenge = query_dict.get('code_challenge') + self.params.code_challenge_method = query_dict.get('code_challenge_method') + def validate_params(self): try: self.client = Client.objects.get(client_id=self.params.client_id) @@ -85,7 +89,11 @@ class AuthorizeEndpoint(object): if not (clean_redirect_uri in self.client.redirect_uris): logger.debug('[Authorize] Invalid redirect uri: %s', self.params.redirect_uri) raise RedirectUriError() - + + # PKCE validation of the transformation method. + if self.params.code_challenge and self.params.code_challenge_method: + if not (self.params.code_challenge_method in ['plain', 'S256']): + raise AuthorizeError(self.params.redirect_uri, 'invalid_request', self.grant_type) def create_response_uri(self): uri = urlsplit(self.params.redirect_uri) @@ -99,7 +107,9 @@ class AuthorizeEndpoint(object): client=self.client, scope=self.params.scope, nonce=self.params.nonce, - is_authentication=self.is_authentication) + is_authentication=self.is_authentication, + code_challenge=self.params.code_challenge, + code_challenge_method=self.params.code_challenge_method) code.save() diff --git a/oidc_provider/lib/endpoints/token.py b/oidc_provider/lib/endpoints/token.py index a981eee..4acb005 100644 --- a/oidc_provider/lib/endpoints/token.py +++ b/oidc_provider/lib/endpoints/token.py @@ -1,4 +1,5 @@ -from base64 import b64decode +from base64 import b64decode, urlsafe_b64encode +import hashlib import logging import re try: @@ -6,6 +7,7 @@ try: except ImportError: from urllib import unquote +from Crypto.Cipher import AES from django.http import JsonResponse from oidc_provider.lib.errors import * @@ -30,14 +32,16 @@ class TokenEndpoint(object): self.params.client_id = client_id self.params.client_secret = client_secret - self.params.redirect_uri = unquote( - self.request.POST.get('redirect_uri', '')) + self.params.redirect_uri = unquote(self.request.POST.get('redirect_uri', '')) self.params.grant_type = self.request.POST.get('grant_type', '') self.params.code = self.request.POST.get('code', '') self.params.state = self.request.POST.get('state', '') self.params.scope = self.request.POST.get('scope', '') self.params.refresh_token = self.request.POST.get('refresh_token', '') + # PKCE parameters. + self.params.code_verifier = self.request.POST.get('code_verifier') + def _extract_client_auth(self): """ Get client credentials using HTTP Basic Authentication method. @@ -90,6 +94,20 @@ class TokenEndpoint(object): self.params.redirect_uri) raise TokenError('invalid_grant') + # Validate PKCE parameters. + if self.params.code_verifier: + obj = AES.new(settings.SECRET_KEY, AES.MODE_CBC) + code_challenge, code_challenge_method = tuple(obj.decrypt(self.code.code.decode('hex')).split(':')) + + if code_challenge_method == 'S256': + new_code_challenge = urlsafe_b64encode(hashlib.sha256(self.params.code_verifier.encode('ascii')).digest()).replace('=', '') + else: + new_code_challenge = self.params.code_verifier + + # TODO: We should explain the error. + if not (new_code_challenge == code_challenge): + raise TokenError('invalid_grant') + elif self.params.grant_type == 'refresh_token': if not self.params.refresh_token: logger.debug('[Token] Missing refresh token') diff --git a/oidc_provider/lib/utils/token.py b/oidc_provider/lib/utils/token.py index e512326..308364f 100644 --- a/oidc_provider/lib/utils/token.py +++ b/oidc_provider/lib/utils/token.py @@ -2,6 +2,7 @@ from datetime import timedelta import time import uuid +from Crypto.Cipher import AES from Crypto.PublicKey.RSA import importKey from django.utils import timezone from hashlib import md5 @@ -95,7 +96,8 @@ def create_token(user, client, id_token_dic, scope): return token -def create_code(user, client, scope, nonce, is_authentication): +def create_code(user, client, scope, nonce, is_authentication, + code_challenge=None, code_challenge_method=None): """ Create and populate a Code object. @@ -104,7 +106,18 @@ def create_code(user, client, scope, nonce, is_authentication): code = Code() code.user = user code.client = client - code.code = uuid.uuid4().hex + + if not code_challenge: + code.code = uuid.uuid4().hex + else: + obj = AES.new(settings.SECRET_KEY, AES.MODE_CBC) + + # Default is 'plain' method. + code_challenge_method = 'plain' if not code_challenge_method else code_challenge_method + + ciphertext = obj.encrypt(code_challenge + ':' + code_challenge_method) + code.code = ciphertext.encode('hex') + code.expires_at = timezone.now() + timedelta( seconds=settings.get('OIDC_CODE_EXPIRE')) code.scope = scope