Make SITE_URL optional.

This commit is contained in:
Si Feng 2016-05-25 14:58:58 -07:00
parent 497f2f3a68
commit be5656bcf4
7 changed files with 75 additions and 27 deletions

View file

@ -1,4 +1,3 @@
from datetime import timedelta
import logging import logging
try: try:
from urllib import urlencode from urllib import urlencode
@ -126,7 +125,8 @@ class AuthorizeEndpoint(object):
id_token_dic = create_id_token( id_token_dic = create_id_token(
user=self.request.user, user=self.request.user,
aud=self.client.client_id, aud=self.client.client_id,
nonce=self.params.nonce) nonce=self.params.nonce,
request=self.request)
query_fragment['id_token'] = encode_id_token(id_token_dic, self.client) query_fragment['id_token'] = encode_id_token(id_token_dic, self.client)
else: else:
id_token_dic = {} id_token_dic = {}

View file

@ -1,4 +1,4 @@
from base64 import b64decode, urlsafe_b64decode, urlsafe_b64encode from base64 import b64decode, urlsafe_b64encode
import hashlib import hashlib
import logging import logging
import re import re
@ -7,9 +7,7 @@ try:
except ImportError: except ImportError:
from urllib import unquote from urllib import unquote
from Crypto.Cipher import AES
from django.http import JsonResponse from django.http import JsonResponse
from django.conf import settings as django_settings
from oidc_provider.lib.errors import * from oidc_provider.lib.errors import *
from oidc_provider.lib.utils.params import * from oidc_provider.lib.utils.params import *
@ -138,6 +136,7 @@ class TokenEndpoint(object):
user=self.code.user, user=self.code.user,
aud=self.client.client_id, aud=self.client.client_id,
nonce=self.code.nonce, nonce=self.code.nonce,
request=self.request,
) )
else: else:
id_token_dic = {} id_token_dic = {}
@ -171,6 +170,7 @@ class TokenEndpoint(object):
user=self.token.user, user=self.token.user,
aud=self.client.client_id, aud=self.client.client_id,
nonce=None, nonce=None,
request=self.request,
) )
else: else:
id_token_dic = {} id_token_dic = {}

View file

@ -13,12 +13,31 @@ def redirect(uri):
return response return response
def get_issuer(): def get_site_url(site_url=None, request=None):
"""
Construct the site url.
Orders to decide site url:
1. valid `site_url` parameter
2. valid `SITE_URL` in settings
3. construct from `request` object
"""
site_url = site_url or settings.get('SITE_URL')
if site_url:
return site_url
elif request:
return '{}://{}'.format(request.scheme, request.get_host())
else:
raise Exception('Either pass `site_url`, '
'or set `SITE_URL` in settings, '
'or pass `request` object.')
def get_issuer(site_url=None, request=None):
""" """
Construct the issuer full url. Basically is the site url with some path Construct the issuer full url. Basically is the site url with some path
appended. appended.
""" """
site_url = settings.get('SITE_URL') site_url = get_site_url(site_url=site_url, request=request)
path = reverse('oidc_provider:provider_info') \ path = reverse('oidc_provider:provider_info') \
.split('/.well-known/openid-configuration')[0] .split('/.well-known/openid-configuration')[0]
issuer = site_url + path issuer = site_url + path

View file

@ -1,11 +1,9 @@
from base64 import urlsafe_b64decode, urlsafe_b64encode
from datetime import timedelta from datetime import timedelta
import time import time
import uuid import uuid
from Crypto.PublicKey.RSA import importKey from Crypto.PublicKey.RSA import importKey
from django.utils import timezone from django.utils import timezone
from hashlib import md5
from jwkest.jwk import RSAKey as jwk_RSAKey from jwkest.jwk import RSAKey as jwk_RSAKey
from jwkest.jwk import SYMKey from jwkest.jwk import SYMKey
from jwkest.jws import JWS from jwkest.jws import JWS
@ -15,7 +13,7 @@ from oidc_provider.models import *
from oidc_provider import settings from oidc_provider import settings
def create_id_token(user, aud, nonce): def create_id_token(user, aud, nonce, request=None):
""" """
Receives a user object and aud (audience). Receives a user object and aud (audience).
Then creates the id_token dictionary. Then creates the id_token dictionary.
@ -35,7 +33,7 @@ def create_id_token(user, aud, nonce):
auth_time = int(time.mktime(user_auth_time.timetuple())) auth_time = int(time.mktime(user_auth_time.timetuple()))
dic = { dic = {
'iss': get_issuer(), 'iss': get_issuer(request=request),
'sub': sub, 'sub': sub,
'aud': str(aud), 'aud': str(aud),
'exp': exp_time, 'exp': exp_time,

View file

@ -4,6 +4,9 @@ from django.conf import settings
class DefaultSettings(object): class DefaultSettings(object):
required_attrs = (
'LOGIN_URL',
)
@property @property
def LOGIN_URL(self): def LOGIN_URL(self):
@ -15,7 +18,7 @@ class DefaultSettings(object):
@property @property
def SITE_URL(self): def SITE_URL(self):
""" """
REQUIRED. The OP server url. OPTIONAL. The OP server url.
""" """
return None return None
@ -131,7 +134,7 @@ def get(name, import_str=False):
value = getattr(default_settings, name) value = getattr(default_settings, name)
value = getattr(settings, name) value = getattr(settings, name)
except AttributeError: except AttributeError:
if value is None: if value is None and value in default_settings.required_attrs:
raise Exception('You must set ' + name + ' in your settings.') raise Exception('You must set ' + name + ' in your settings.')
value = import_from_str(value) if import_str else value value = import_from_str(value) if import_str else value

View file

@ -1,13 +1,44 @@
from django.conf import settings
from django.test import TestCase from django.test import TestCase
from oidc_provider.lib.utils.common import get_issuer from oidc_provider.lib.utils.common import get_issuer
class Request(object):
"""
Mock request object.
"""
scheme = 'http'
def get_host(self):
return 'host-from-request:8888'
class CommonTest(TestCase): class CommonTest(TestCase):
""" """
Test cases for common utils. Test cases for common utils.
""" """
def test_get_issuer(self): def test_get_issuer(self):
issuer = get_issuer() request = Request()
self.assertEqual(issuer, settings.SITE_URL + '/openid')
# from default settings
self.assertEqual(get_issuer(),
'http://localhost:8000/openid')
# from custom settings
with self.settings(SITE_URL='http://otherhost:8000'):
self.assertEqual(get_issuer(),
'http://otherhost:8000/openid')
# `SITE_URL` not set, from `request`
with self.settings(SITE_URL=''):
self.assertEqual(get_issuer(request=request),
'http://host-from-request:8888/openid')
# use settings first if both are provided
self.assertEqual(get_issuer(request=request),
'http://localhost:8000/openid')
# `site_url` can even be overridden manually
self.assertEqual(get_issuer(site_url='http://127.0.0.1:9000',
request=request),
'http://127.0.0.1:9000/openid')

View file

@ -1,5 +1,3 @@
import logging
from Crypto.PublicKey import RSA from Crypto.PublicKey import RSA
from django.contrib.auth.views import redirect_to_login, logout from django.contrib.auth.views import redirect_to_login, logout
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
@ -14,7 +12,7 @@ from oidc_provider.lib.claims import StandardScopeClaims
from oidc_provider.lib.endpoints.authorize import * from oidc_provider.lib.endpoints.authorize import *
from oidc_provider.lib.endpoints.token import * from oidc_provider.lib.endpoints.token import *
from oidc_provider.lib.errors import * from oidc_provider.lib.errors import *
from oidc_provider.lib.utils.common import redirect, get_issuer from oidc_provider.lib.utils.common import redirect, get_site_url, get_issuer
from oidc_provider.lib.utils.oauth2 import protected_resource_view from oidc_provider.lib.utils.oauth2 import protected_resource_view
from oidc_provider.models import RESPONSE_TYPE_CHOICES, RSAKey from oidc_provider.models import RESPONSE_TYPE_CHOICES, RSAKey
from oidc_provider import settings from oidc_provider import settings
@ -178,19 +176,18 @@ class ProviderInfoView(View):
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
dic = dict() dic = dict()
dic['issuer'] = get_issuer() site_url = get_site_url(request=request)
dic['issuer'] = get_issuer(site_url=site_url, request=request)
SITE_URL = settings.get('SITE_URL') dic['authorization_endpoint'] = site_url + reverse('oidc_provider:authorize')
dic['token_endpoint'] = site_url + reverse('oidc_provider:token')
dic['authorization_endpoint'] = SITE_URL + reverse('oidc_provider:authorize') dic['userinfo_endpoint'] = site_url + reverse('oidc_provider:userinfo')
dic['token_endpoint'] = SITE_URL + reverse('oidc_provider:token') dic['end_session_endpoint'] = site_url + reverse('oidc_provider:logout')
dic['userinfo_endpoint'] = SITE_URL + reverse('oidc_provider:userinfo')
dic['end_session_endpoint'] = SITE_URL + reverse('oidc_provider:logout')
types_supported = [x[0] for x in RESPONSE_TYPE_CHOICES] types_supported = [x[0] for x in RESPONSE_TYPE_CHOICES]
dic['response_types_supported'] = types_supported dic['response_types_supported'] = types_supported
dic['jwks_uri'] = SITE_URL + reverse('oidc_provider:jwks') dic['jwks_uri'] = site_url + reverse('oidc_provider:jwks')
dic['id_token_signing_alg_values_supported'] = ['HS256', 'RS256'] dic['id_token_signing_alg_values_supported'] = ['HS256', 'RS256']