First intent to implement PKCE.
This commit is contained in:
parent
2c4ab6695e
commit
6e8af74f76
3 changed files with 48 additions and 7 deletions
|
@ -56,6 +56,10 @@ class AuthorizeEndpoint(object):
|
||||||
self.params.state = query_dict.get('state', '')
|
self.params.state = query_dict.get('state', '')
|
||||||
self.params.nonce = query_dict.get('nonce', '')
|
self.params.nonce = query_dict.get('nonce', '')
|
||||||
|
|
||||||
|
# PKCE parameters.
|
||||||
|
self.params.code_challenge = query_dict.get('code_challenge')
|
||||||
|
self.params.code_challenge_method = query_dict.get('code_challenge_method')
|
||||||
|
|
||||||
def validate_params(self):
|
def validate_params(self):
|
||||||
try:
|
try:
|
||||||
self.client = Client.objects.get(client_id=self.params.client_id)
|
self.client = Client.objects.get(client_id=self.params.client_id)
|
||||||
|
@ -85,7 +89,11 @@ class AuthorizeEndpoint(object):
|
||||||
if not (clean_redirect_uri in self.client.redirect_uris):
|
if not (clean_redirect_uri in self.client.redirect_uris):
|
||||||
logger.debug('[Authorize] Invalid redirect uri: %s', self.params.redirect_uri)
|
logger.debug('[Authorize] Invalid redirect uri: %s', self.params.redirect_uri)
|
||||||
raise RedirectUriError()
|
raise RedirectUriError()
|
||||||
|
|
||||||
|
# PKCE validation of the transformation method.
|
||||||
|
if self.params.code_challenge and self.params.code_challenge_method:
|
||||||
|
if not (self.params.code_challenge_method in ['plain', 'S256']):
|
||||||
|
raise AuthorizeError(self.params.redirect_uri, 'invalid_request', self.grant_type)
|
||||||
|
|
||||||
def create_response_uri(self):
|
def create_response_uri(self):
|
||||||
uri = urlsplit(self.params.redirect_uri)
|
uri = urlsplit(self.params.redirect_uri)
|
||||||
|
@ -99,7 +107,9 @@ class AuthorizeEndpoint(object):
|
||||||
client=self.client,
|
client=self.client,
|
||||||
scope=self.params.scope,
|
scope=self.params.scope,
|
||||||
nonce=self.params.nonce,
|
nonce=self.params.nonce,
|
||||||
is_authentication=self.is_authentication)
|
is_authentication=self.is_authentication,
|
||||||
|
code_challenge=self.params.code_challenge,
|
||||||
|
code_challenge_method=self.params.code_challenge_method)
|
||||||
|
|
||||||
code.save()
|
code.save()
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from base64 import b64decode
|
from base64 import b64decode, urlsafe_b64encode
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
try:
|
try:
|
||||||
|
@ -6,6 +7,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from urllib import unquote
|
from urllib import unquote
|
||||||
|
|
||||||
|
from Crypto.Cipher import AES
|
||||||
from django.http import JsonResponse
|
from django.http import JsonResponse
|
||||||
|
|
||||||
from oidc_provider.lib.errors import *
|
from oidc_provider.lib.errors import *
|
||||||
|
@ -30,14 +32,16 @@ class TokenEndpoint(object):
|
||||||
|
|
||||||
self.params.client_id = client_id
|
self.params.client_id = client_id
|
||||||
self.params.client_secret = client_secret
|
self.params.client_secret = client_secret
|
||||||
self.params.redirect_uri = unquote(
|
self.params.redirect_uri = unquote(self.request.POST.get('redirect_uri', ''))
|
||||||
self.request.POST.get('redirect_uri', ''))
|
|
||||||
self.params.grant_type = self.request.POST.get('grant_type', '')
|
self.params.grant_type = self.request.POST.get('grant_type', '')
|
||||||
self.params.code = self.request.POST.get('code', '')
|
self.params.code = self.request.POST.get('code', '')
|
||||||
self.params.state = self.request.POST.get('state', '')
|
self.params.state = self.request.POST.get('state', '')
|
||||||
self.params.scope = self.request.POST.get('scope', '')
|
self.params.scope = self.request.POST.get('scope', '')
|
||||||
self.params.refresh_token = self.request.POST.get('refresh_token', '')
|
self.params.refresh_token = self.request.POST.get('refresh_token', '')
|
||||||
|
|
||||||
|
# PKCE parameters.
|
||||||
|
self.params.code_verifier = self.request.POST.get('code_verifier')
|
||||||
|
|
||||||
def _extract_client_auth(self):
|
def _extract_client_auth(self):
|
||||||
"""
|
"""
|
||||||
Get client credentials using HTTP Basic Authentication method.
|
Get client credentials using HTTP Basic Authentication method.
|
||||||
|
@ -90,6 +94,20 @@ class TokenEndpoint(object):
|
||||||
self.params.redirect_uri)
|
self.params.redirect_uri)
|
||||||
raise TokenError('invalid_grant')
|
raise TokenError('invalid_grant')
|
||||||
|
|
||||||
|
# Validate PKCE parameters.
|
||||||
|
if self.params.code_verifier:
|
||||||
|
obj = AES.new(settings.SECRET_KEY, AES.MODE_CBC)
|
||||||
|
code_challenge, code_challenge_method = tuple(obj.decrypt(self.code.code.decode('hex')).split(':'))
|
||||||
|
|
||||||
|
if code_challenge_method == 'S256':
|
||||||
|
new_code_challenge = urlsafe_b64encode(hashlib.sha256(self.params.code_verifier.encode('ascii')).digest()).replace('=', '')
|
||||||
|
else:
|
||||||
|
new_code_challenge = self.params.code_verifier
|
||||||
|
|
||||||
|
# TODO: We should explain the error.
|
||||||
|
if not (new_code_challenge == code_challenge):
|
||||||
|
raise TokenError('invalid_grant')
|
||||||
|
|
||||||
elif self.params.grant_type == 'refresh_token':
|
elif self.params.grant_type == 'refresh_token':
|
||||||
if not self.params.refresh_token:
|
if not self.params.refresh_token:
|
||||||
logger.debug('[Token] Missing refresh token')
|
logger.debug('[Token] Missing refresh token')
|
||||||
|
|
|
@ -2,6 +2,7 @@ from datetime import timedelta
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from Crypto.Cipher import AES
|
||||||
from Crypto.PublicKey.RSA import importKey
|
from Crypto.PublicKey.RSA import importKey
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
|
@ -95,7 +96,8 @@ def create_token(user, client, id_token_dic, scope):
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
def create_code(user, client, scope, nonce, is_authentication):
|
def create_code(user, client, scope, nonce, is_authentication,
|
||||||
|
code_challenge=None, code_challenge_method=None):
|
||||||
"""
|
"""
|
||||||
Create and populate a Code object.
|
Create and populate a Code object.
|
||||||
|
|
||||||
|
@ -104,7 +106,18 @@ def create_code(user, client, scope, nonce, is_authentication):
|
||||||
code = Code()
|
code = Code()
|
||||||
code.user = user
|
code.user = user
|
||||||
code.client = client
|
code.client = client
|
||||||
code.code = uuid.uuid4().hex
|
|
||||||
|
if not code_challenge:
|
||||||
|
code.code = uuid.uuid4().hex
|
||||||
|
else:
|
||||||
|
obj = AES.new(settings.SECRET_KEY, AES.MODE_CBC)
|
||||||
|
|
||||||
|
# Default is 'plain' method.
|
||||||
|
code_challenge_method = 'plain' if not code_challenge_method else code_challenge_method
|
||||||
|
|
||||||
|
ciphertext = obj.encrypt(code_challenge + ':' + code_challenge_method)
|
||||||
|
code.code = ciphertext.encode('hex')
|
||||||
|
|
||||||
code.expires_at = timezone.now() + timedelta(
|
code.expires_at = timezone.now() + timedelta(
|
||||||
seconds=settings.get('OIDC_CODE_EXPIRE'))
|
seconds=settings.get('OIDC_CODE_EXPIRE'))
|
||||||
code.scope = scope
|
code.scope = scope
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue