Make SITE_URL
optional.
This commit is contained in:
parent
497f2f3a68
commit
be5656bcf4
7 changed files with 75 additions and 27 deletions
|
@ -1,4 +1,3 @@
|
||||||
from datetime import timedelta
|
|
||||||
import logging
|
import logging
|
||||||
try:
|
try:
|
||||||
from urllib import urlencode
|
from urllib import urlencode
|
||||||
|
@ -126,7 +125,8 @@ class AuthorizeEndpoint(object):
|
||||||
id_token_dic = create_id_token(
|
id_token_dic = create_id_token(
|
||||||
user=self.request.user,
|
user=self.request.user,
|
||||||
aud=self.client.client_id,
|
aud=self.client.client_id,
|
||||||
nonce=self.params.nonce)
|
nonce=self.params.nonce,
|
||||||
|
request=self.request)
|
||||||
query_fragment['id_token'] = encode_id_token(id_token_dic, self.client)
|
query_fragment['id_token'] = encode_id_token(id_token_dic, self.client)
|
||||||
else:
|
else:
|
||||||
id_token_dic = {}
|
id_token_dic = {}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from base64 import b64decode, urlsafe_b64decode, urlsafe_b64encode
|
from base64 import b64decode, urlsafe_b64encode
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
@ -7,9 +7,7 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from urllib import unquote
|
from urllib import unquote
|
||||||
|
|
||||||
from Crypto.Cipher import AES
|
|
||||||
from django.http import JsonResponse
|
from django.http import JsonResponse
|
||||||
from django.conf import settings as django_settings
|
|
||||||
|
|
||||||
from oidc_provider.lib.errors import *
|
from oidc_provider.lib.errors import *
|
||||||
from oidc_provider.lib.utils.params import *
|
from oidc_provider.lib.utils.params import *
|
||||||
|
@ -138,6 +136,7 @@ class TokenEndpoint(object):
|
||||||
user=self.code.user,
|
user=self.code.user,
|
||||||
aud=self.client.client_id,
|
aud=self.client.client_id,
|
||||||
nonce=self.code.nonce,
|
nonce=self.code.nonce,
|
||||||
|
request=self.request,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
id_token_dic = {}
|
id_token_dic = {}
|
||||||
|
@ -171,6 +170,7 @@ class TokenEndpoint(object):
|
||||||
user=self.token.user,
|
user=self.token.user,
|
||||||
aud=self.client.client_id,
|
aud=self.client.client_id,
|
||||||
nonce=None,
|
nonce=None,
|
||||||
|
request=self.request,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
id_token_dic = {}
|
id_token_dic = {}
|
||||||
|
|
|
@ -13,12 +13,31 @@ def redirect(uri):
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
def get_issuer():
|
def get_site_url(site_url=None, request=None):
|
||||||
|
"""
|
||||||
|
Construct the site url.
|
||||||
|
|
||||||
|
Orders to decide site url:
|
||||||
|
1. valid `site_url` parameter
|
||||||
|
2. valid `SITE_URL` in settings
|
||||||
|
3. construct from `request` object
|
||||||
|
"""
|
||||||
|
site_url = site_url or settings.get('SITE_URL')
|
||||||
|
if site_url:
|
||||||
|
return site_url
|
||||||
|
elif request:
|
||||||
|
return '{}://{}'.format(request.scheme, request.get_host())
|
||||||
|
else:
|
||||||
|
raise Exception('Either pass `site_url`, '
|
||||||
|
'or set `SITE_URL` in settings, '
|
||||||
|
'or pass `request` object.')
|
||||||
|
|
||||||
|
def get_issuer(site_url=None, request=None):
|
||||||
"""
|
"""
|
||||||
Construct the issuer full url. Basically is the site url with some path
|
Construct the issuer full url. Basically is the site url with some path
|
||||||
appended.
|
appended.
|
||||||
"""
|
"""
|
||||||
site_url = settings.get('SITE_URL')
|
site_url = get_site_url(site_url=site_url, request=request)
|
||||||
path = reverse('oidc_provider:provider_info') \
|
path = reverse('oidc_provider:provider_info') \
|
||||||
.split('/.well-known/openid-configuration')[0]
|
.split('/.well-known/openid-configuration')[0]
|
||||||
issuer = site_url + path
|
issuer = site_url + path
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
from base64 import urlsafe_b64decode, urlsafe_b64encode
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from Crypto.PublicKey.RSA import importKey
|
from Crypto.PublicKey.RSA import importKey
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
from hashlib import md5
|
|
||||||
from jwkest.jwk import RSAKey as jwk_RSAKey
|
from jwkest.jwk import RSAKey as jwk_RSAKey
|
||||||
from jwkest.jwk import SYMKey
|
from jwkest.jwk import SYMKey
|
||||||
from jwkest.jws import JWS
|
from jwkest.jws import JWS
|
||||||
|
@ -15,7 +13,7 @@ from oidc_provider.models import *
|
||||||
from oidc_provider import settings
|
from oidc_provider import settings
|
||||||
|
|
||||||
|
|
||||||
def create_id_token(user, aud, nonce):
|
def create_id_token(user, aud, nonce, request=None):
|
||||||
"""
|
"""
|
||||||
Receives a user object and aud (audience).
|
Receives a user object and aud (audience).
|
||||||
Then creates the id_token dictionary.
|
Then creates the id_token dictionary.
|
||||||
|
@ -35,7 +33,7 @@ def create_id_token(user, aud, nonce):
|
||||||
auth_time = int(time.mktime(user_auth_time.timetuple()))
|
auth_time = int(time.mktime(user_auth_time.timetuple()))
|
||||||
|
|
||||||
dic = {
|
dic = {
|
||||||
'iss': get_issuer(),
|
'iss': get_issuer(request=request),
|
||||||
'sub': sub,
|
'sub': sub,
|
||||||
'aud': str(aud),
|
'aud': str(aud),
|
||||||
'exp': exp_time,
|
'exp': exp_time,
|
||||||
|
|
|
@ -4,6 +4,9 @@ from django.conf import settings
|
||||||
|
|
||||||
|
|
||||||
class DefaultSettings(object):
|
class DefaultSettings(object):
|
||||||
|
required_attrs = (
|
||||||
|
'LOGIN_URL',
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def LOGIN_URL(self):
|
def LOGIN_URL(self):
|
||||||
|
@ -15,7 +18,7 @@ class DefaultSettings(object):
|
||||||
@property
|
@property
|
||||||
def SITE_URL(self):
|
def SITE_URL(self):
|
||||||
"""
|
"""
|
||||||
REQUIRED. The OP server url.
|
OPTIONAL. The OP server url.
|
||||||
"""
|
"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -131,7 +134,7 @@ def get(name, import_str=False):
|
||||||
value = getattr(default_settings, name)
|
value = getattr(default_settings, name)
|
||||||
value = getattr(settings, name)
|
value = getattr(settings, name)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
if value is None:
|
if value is None and value in default_settings.required_attrs:
|
||||||
raise Exception('You must set ' + name + ' in your settings.')
|
raise Exception('You must set ' + name + ' in your settings.')
|
||||||
|
|
||||||
value = import_from_str(value) if import_str else value
|
value = import_from_str(value) if import_str else value
|
||||||
|
|
|
@ -1,13 +1,44 @@
|
||||||
from django.conf import settings
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from oidc_provider.lib.utils.common import get_issuer
|
from oidc_provider.lib.utils.common import get_issuer
|
||||||
|
|
||||||
|
|
||||||
|
class Request(object):
|
||||||
|
"""
|
||||||
|
Mock request object.
|
||||||
|
"""
|
||||||
|
scheme = 'http'
|
||||||
|
|
||||||
|
def get_host(self):
|
||||||
|
return 'host-from-request:8888'
|
||||||
|
|
||||||
|
|
||||||
class CommonTest(TestCase):
|
class CommonTest(TestCase):
|
||||||
"""
|
"""
|
||||||
Test cases for common utils.
|
Test cases for common utils.
|
||||||
"""
|
"""
|
||||||
def test_get_issuer(self):
|
def test_get_issuer(self):
|
||||||
issuer = get_issuer()
|
request = Request()
|
||||||
self.assertEqual(issuer, settings.SITE_URL + '/openid')
|
|
||||||
|
# from default settings
|
||||||
|
self.assertEqual(get_issuer(),
|
||||||
|
'http://localhost:8000/openid')
|
||||||
|
|
||||||
|
# from custom settings
|
||||||
|
with self.settings(SITE_URL='http://otherhost:8000'):
|
||||||
|
self.assertEqual(get_issuer(),
|
||||||
|
'http://otherhost:8000/openid')
|
||||||
|
|
||||||
|
# `SITE_URL` not set, from `request`
|
||||||
|
with self.settings(SITE_URL=''):
|
||||||
|
self.assertEqual(get_issuer(request=request),
|
||||||
|
'http://host-from-request:8888/openid')
|
||||||
|
|
||||||
|
# use settings first if both are provided
|
||||||
|
self.assertEqual(get_issuer(request=request),
|
||||||
|
'http://localhost:8000/openid')
|
||||||
|
|
||||||
|
# `site_url` can even be overridden manually
|
||||||
|
self.assertEqual(get_issuer(site_url='http://127.0.0.1:9000',
|
||||||
|
request=request),
|
||||||
|
'http://127.0.0.1:9000/openid')
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
import logging
|
|
||||||
|
|
||||||
from Crypto.PublicKey import RSA
|
from Crypto.PublicKey import RSA
|
||||||
from django.contrib.auth.views import redirect_to_login, logout
|
from django.contrib.auth.views import redirect_to_login, logout
|
||||||
from django.core.urlresolvers import reverse
|
from django.core.urlresolvers import reverse
|
||||||
|
@ -14,7 +12,7 @@ from oidc_provider.lib.claims import StandardScopeClaims
|
||||||
from oidc_provider.lib.endpoints.authorize import *
|
from oidc_provider.lib.endpoints.authorize import *
|
||||||
from oidc_provider.lib.endpoints.token import *
|
from oidc_provider.lib.endpoints.token import *
|
||||||
from oidc_provider.lib.errors import *
|
from oidc_provider.lib.errors import *
|
||||||
from oidc_provider.lib.utils.common import redirect, get_issuer
|
from oidc_provider.lib.utils.common import redirect, get_site_url, get_issuer
|
||||||
from oidc_provider.lib.utils.oauth2 import protected_resource_view
|
from oidc_provider.lib.utils.oauth2 import protected_resource_view
|
||||||
from oidc_provider.models import RESPONSE_TYPE_CHOICES, RSAKey
|
from oidc_provider.models import RESPONSE_TYPE_CHOICES, RSAKey
|
||||||
from oidc_provider import settings
|
from oidc_provider import settings
|
||||||
|
@ -178,19 +176,18 @@ class ProviderInfoView(View):
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
dic = dict()
|
dic = dict()
|
||||||
|
|
||||||
dic['issuer'] = get_issuer()
|
site_url = get_site_url(request=request)
|
||||||
|
dic['issuer'] = get_issuer(site_url=site_url, request=request)
|
||||||
|
|
||||||
SITE_URL = settings.get('SITE_URL')
|
dic['authorization_endpoint'] = site_url + reverse('oidc_provider:authorize')
|
||||||
|
dic['token_endpoint'] = site_url + reverse('oidc_provider:token')
|
||||||
dic['authorization_endpoint'] = SITE_URL + reverse('oidc_provider:authorize')
|
dic['userinfo_endpoint'] = site_url + reverse('oidc_provider:userinfo')
|
||||||
dic['token_endpoint'] = SITE_URL + reverse('oidc_provider:token')
|
dic['end_session_endpoint'] = site_url + reverse('oidc_provider:logout')
|
||||||
dic['userinfo_endpoint'] = SITE_URL + reverse('oidc_provider:userinfo')
|
|
||||||
dic['end_session_endpoint'] = SITE_URL + reverse('oidc_provider:logout')
|
|
||||||
|
|
||||||
types_supported = [x[0] for x in RESPONSE_TYPE_CHOICES]
|
types_supported = [x[0] for x in RESPONSE_TYPE_CHOICES]
|
||||||
dic['response_types_supported'] = types_supported
|
dic['response_types_supported'] = types_supported
|
||||||
|
|
||||||
dic['jwks_uri'] = SITE_URL + reverse('oidc_provider:jwks')
|
dic['jwks_uri'] = site_url + reverse('oidc_provider:jwks')
|
||||||
|
|
||||||
dic['id_token_signing_alg_values_supported'] = ['HS256', 'RS256']
|
dic['id_token_signing_alg_values_supported'] = ['HS256', 'RS256']
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue