Move extract_client_auth to oauth2 utils.
This commit is contained in:
parent
61d88014c9
commit
21a64b262c
2 changed files with 29 additions and 26 deletions
|
@ -1,7 +1,6 @@
|
||||||
from base64 import b64decode, urlsafe_b64encode
|
from base64 import urlsafe_b64encode
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
|
||||||
from django.contrib.auth import authenticate
|
from django.contrib.auth import authenticate
|
||||||
|
|
||||||
from django.http import JsonResponse
|
from django.http import JsonResponse
|
||||||
|
@ -10,6 +9,7 @@ from oidc_provider.lib.errors import (
|
||||||
TokenError,
|
TokenError,
|
||||||
UserAuthError,
|
UserAuthError,
|
||||||
)
|
)
|
||||||
|
from oidc_provider.lib.utils.oauth2 import extract_client_auth
|
||||||
from oidc_provider.lib.utils.token import (
|
from oidc_provider.lib.utils.token import (
|
||||||
create_id_token,
|
create_id_token,
|
||||||
create_token,
|
create_token,
|
||||||
|
@ -26,6 +26,7 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TokenEndpoint(object):
|
class TokenEndpoint(object):
|
||||||
|
|
||||||
def __init__(self, request):
|
def __init__(self, request):
|
||||||
self.request = request
|
self.request = request
|
||||||
self.params = {}
|
self.params = {}
|
||||||
|
@ -33,7 +34,7 @@ class TokenEndpoint(object):
|
||||||
self._extract_params()
|
self._extract_params()
|
||||||
|
|
||||||
def _extract_params(self):
|
def _extract_params(self):
|
||||||
client_id, client_secret = self._extract_client_auth()
|
client_id, client_secret = extract_client_auth(self.request)
|
||||||
|
|
||||||
self.params['client_id'] = client_id
|
self.params['client_id'] = client_id
|
||||||
self.params['client_secret'] = client_secret
|
self.params['client_secret'] = client_secret
|
||||||
|
@ -49,29 +50,6 @@ class TokenEndpoint(object):
|
||||||
self.params['username'] = self.request.POST.get('username', '')
|
self.params['username'] = self.request.POST.get('username', '')
|
||||||
self.params['password'] = self.request.POST.get('password', '')
|
self.params['password'] = self.request.POST.get('password', '')
|
||||||
|
|
||||||
def _extract_client_auth(self):
|
|
||||||
"""
|
|
||||||
Get client credentials using HTTP Basic Authentication method.
|
|
||||||
Or try getting parameters via POST.
|
|
||||||
See: http://tools.ietf.org/html/rfc6750#section-2.1
|
|
||||||
|
|
||||||
Return a string.
|
|
||||||
"""
|
|
||||||
auth_header = self.request.META.get('HTTP_AUTHORIZATION', '')
|
|
||||||
|
|
||||||
if re.compile('^Basic\s{1}.+$').match(auth_header):
|
|
||||||
b64_user_pass = auth_header.split()[1]
|
|
||||||
try:
|
|
||||||
user_pass = b64decode(b64_user_pass).decode('utf-8').split(':')
|
|
||||||
client_id, client_secret = tuple(user_pass)
|
|
||||||
except Exception:
|
|
||||||
client_id = client_secret = ''
|
|
||||||
else:
|
|
||||||
client_id = self.request.POST.get('client_id', '')
|
|
||||||
client_secret = self.request.POST.get('client_secret', '')
|
|
||||||
|
|
||||||
return (client_id, client_secret)
|
|
||||||
|
|
||||||
def validate_params(self):
|
def validate_params(self):
|
||||||
try:
|
try:
|
||||||
self.client = Client.objects.get(client_id=self.params['client_id'])
|
self.client = Client.objects.get(client_id=self.params['client_id'])
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from base64 import b64decode
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
@ -28,6 +29,30 @@ def extract_access_token(request):
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
|
|
||||||
|
def extract_client_auth(request):
|
||||||
|
"""
|
||||||
|
Get client credentials using HTTP Basic Authentication method.
|
||||||
|
Or try getting parameters via POST.
|
||||||
|
See: http://tools.ietf.org/html/rfc6750#section-2.1
|
||||||
|
|
||||||
|
Return a tuple `(client_id, client_secret)`.
|
||||||
|
"""
|
||||||
|
auth_header = request.META.get('HTTP_AUTHORIZATION', '')
|
||||||
|
|
||||||
|
if re.compile('^Basic\s{1}.+$').match(auth_header):
|
||||||
|
b64_user_pass = auth_header.split()[1]
|
||||||
|
try:
|
||||||
|
user_pass = b64decode(b64_user_pass).decode('utf-8').split(':')
|
||||||
|
client_id, client_secret = tuple(user_pass)
|
||||||
|
except Exception:
|
||||||
|
client_id = client_secret = ''
|
||||||
|
else:
|
||||||
|
client_id = request.POST.get('client_id', '')
|
||||||
|
client_secret = request.POST.get('client_secret', '')
|
||||||
|
|
||||||
|
return (client_id, client_secret)
|
||||||
|
|
||||||
|
|
||||||
def protected_resource_view(scopes=None):
|
def protected_resource_view(scopes=None):
|
||||||
"""
|
"""
|
||||||
View decorator. The client accesses protected resources by presenting the
|
View decorator. The client accesses protected resources by presenting the
|
||||||
|
|
Loading…
Reference in a new issue