diff --git a/oidc_provider/lib/endpoints/authorize.py b/oidc_provider/lib/endpoints/authorize.py index 09365e7..43ff43f 100644 --- a/oidc_provider/lib/endpoints/authorize.py +++ b/oidc_provider/lib/endpoints/authorize.py @@ -1,4 +1,3 @@ -from datetime import timedelta import logging try: from urllib import urlencode @@ -126,7 +125,8 @@ class AuthorizeEndpoint(object): id_token_dic = create_id_token( user=self.request.user, aud=self.client.client_id, - nonce=self.params.nonce) + nonce=self.params.nonce, + request=self.request) query_fragment['id_token'] = encode_id_token(id_token_dic, self.client) else: id_token_dic = {} diff --git a/oidc_provider/lib/endpoints/token.py b/oidc_provider/lib/endpoints/token.py index b5652ba..200cb2c 100644 --- a/oidc_provider/lib/endpoints/token.py +++ b/oidc_provider/lib/endpoints/token.py @@ -1,4 +1,4 @@ -from base64 import b64decode, urlsafe_b64decode, urlsafe_b64encode +from base64 import b64decode, urlsafe_b64encode import hashlib import logging import re @@ -7,9 +7,7 @@ try: except ImportError: from urllib import unquote -from Crypto.Cipher import AES from django.http import JsonResponse -from django.conf import settings as django_settings from oidc_provider.lib.errors import * from oidc_provider.lib.utils.params import * @@ -138,6 +136,7 @@ class TokenEndpoint(object): user=self.code.user, aud=self.client.client_id, nonce=self.code.nonce, + request=self.request, ) else: id_token_dic = {} @@ -171,6 +170,7 @@ class TokenEndpoint(object): user=self.token.user, aud=self.client.client_id, nonce=None, + request=self.request, ) else: id_token_dic = {} diff --git a/oidc_provider/lib/utils/common.py b/oidc_provider/lib/utils/common.py index 9d37f4f..78305c6 100644 --- a/oidc_provider/lib/utils/common.py +++ b/oidc_provider/lib/utils/common.py @@ -13,12 +13,31 @@ def redirect(uri): return response -def get_issuer(): +def get_site_url(site_url=None, request=None): + """ + Construct the site url. + + Orders to decide site url: + 1. valid `site_url` parameter + 2. valid `SITE_URL` in settings + 3. construct from `request` object + """ + site_url = site_url or settings.get('SITE_URL') + if site_url: + return site_url + elif request: + return '{}://{}'.format(request.scheme, request.get_host()) + else: + raise Exception('Either pass `site_url`, ' + 'or set `SITE_URL` in settings, ' + 'or pass `request` object.') + +def get_issuer(site_url=None, request=None): """ Construct the issuer full url. Basically is the site url with some path appended. """ - site_url = settings.get('SITE_URL') + site_url = get_site_url(site_url=site_url, request=request) path = reverse('oidc_provider:provider_info') \ .split('/.well-known/openid-configuration')[0] issuer = site_url + path diff --git a/oidc_provider/lib/utils/token.py b/oidc_provider/lib/utils/token.py index 10f4a18..fc0880d 100644 --- a/oidc_provider/lib/utils/token.py +++ b/oidc_provider/lib/utils/token.py @@ -1,11 +1,9 @@ -from base64 import urlsafe_b64decode, urlsafe_b64encode from datetime import timedelta import time import uuid from Crypto.PublicKey.RSA import importKey from django.utils import timezone -from hashlib import md5 from jwkest.jwk import RSAKey as jwk_RSAKey from jwkest.jwk import SYMKey from jwkest.jws import JWS @@ -15,7 +13,7 @@ from oidc_provider.models import * from oidc_provider import settings -def create_id_token(user, aud, nonce): +def create_id_token(user, aud, nonce, request=None): """ Receives a user object and aud (audience). Then creates the id_token dictionary. @@ -35,7 +33,7 @@ def create_id_token(user, aud, nonce): auth_time = int(time.mktime(user_auth_time.timetuple())) dic = { - 'iss': get_issuer(), + 'iss': get_issuer(request=request), 'sub': sub, 'aud': str(aud), 'exp': exp_time, diff --git a/oidc_provider/settings.py b/oidc_provider/settings.py index b7890d1..49530f4 100644 --- a/oidc_provider/settings.py +++ b/oidc_provider/settings.py @@ -4,6 +4,9 @@ from django.conf import settings class DefaultSettings(object): + required_attrs = ( + 'LOGIN_URL', + ) @property def LOGIN_URL(self): @@ -15,7 +18,7 @@ class DefaultSettings(object): @property def SITE_URL(self): """ - REQUIRED. The OP server url. + OPTIONAL. The OP server url. """ return None @@ -131,7 +134,7 @@ def get(name, import_str=False): value = getattr(default_settings, name) value = getattr(settings, name) except AttributeError: - if value is None: + if value is None and value in default_settings.required_attrs: raise Exception('You must set ' + name + ' in your settings.') value = import_from_str(value) if import_str else value diff --git a/oidc_provider/tests/test_utils.py b/oidc_provider/tests/test_utils.py index 0357014..32bdf8d 100644 --- a/oidc_provider/tests/test_utils.py +++ b/oidc_provider/tests/test_utils.py @@ -1,13 +1,44 @@ -from django.conf import settings from django.test import TestCase from oidc_provider.lib.utils.common import get_issuer +class Request(object): + """ + Mock request object. + """ + scheme = 'http' + + def get_host(self): + return 'host-from-request:8888' + + class CommonTest(TestCase): """ Test cases for common utils. """ def test_get_issuer(self): - issuer = get_issuer() - self.assertEqual(issuer, settings.SITE_URL + '/openid') + request = Request() + + # from default settings + self.assertEqual(get_issuer(), + 'http://localhost:8000/openid') + + # from custom settings + with self.settings(SITE_URL='http://otherhost:8000'): + self.assertEqual(get_issuer(), + 'http://otherhost:8000/openid') + + # `SITE_URL` not set, from `request` + with self.settings(SITE_URL=''): + self.assertEqual(get_issuer(request=request), + 'http://host-from-request:8888/openid') + + # use settings first if both are provided + self.assertEqual(get_issuer(request=request), + 'http://localhost:8000/openid') + + # `site_url` can even be overridden manually + self.assertEqual(get_issuer(site_url='http://127.0.0.1:9000', + request=request), + 'http://127.0.0.1:9000/openid') diff --git a/oidc_provider/views.py b/oidc_provider/views.py index a60312c..1b1bba4 100644 --- a/oidc_provider/views.py +++ b/oidc_provider/views.py @@ -1,5 +1,3 @@ -import logging - from Crypto.PublicKey import RSA from django.contrib.auth.views import redirect_to_login, logout from django.core.urlresolvers import reverse @@ -14,7 +12,7 @@ from oidc_provider.lib.claims import StandardScopeClaims from oidc_provider.lib.endpoints.authorize import * from oidc_provider.lib.endpoints.token import * from oidc_provider.lib.errors import * -from oidc_provider.lib.utils.common import redirect, get_issuer +from oidc_provider.lib.utils.common import redirect, get_site_url, get_issuer from oidc_provider.lib.utils.oauth2 import protected_resource_view from oidc_provider.models import RESPONSE_TYPE_CHOICES, RSAKey from oidc_provider import settings @@ -178,19 +176,18 @@ class ProviderInfoView(View): def get(self, request, *args, **kwargs): dic = dict() - dic['issuer'] = get_issuer() + site_url = get_site_url(request=request) + dic['issuer'] = get_issuer(site_url=site_url, request=request) - SITE_URL = settings.get('SITE_URL') - - dic['authorization_endpoint'] = SITE_URL + reverse('oidc_provider:authorize') - dic['token_endpoint'] = SITE_URL + reverse('oidc_provider:token') - dic['userinfo_endpoint'] = SITE_URL + reverse('oidc_provider:userinfo') - dic['end_session_endpoint'] = SITE_URL + reverse('oidc_provider:logout') + dic['authorization_endpoint'] = site_url + reverse('oidc_provider:authorize') + dic['token_endpoint'] = site_url + reverse('oidc_provider:token') + dic['userinfo_endpoint'] = site_url + reverse('oidc_provider:userinfo') + dic['end_session_endpoint'] = site_url + reverse('oidc_provider:logout') types_supported = [x[0] for x in RESPONSE_TYPE_CHOICES] dic['response_types_supported'] = types_supported - dic['jwks_uri'] = SITE_URL + reverse('oidc_provider:jwks') + dic['jwks_uri'] = site_url + reverse('oidc_provider:jwks') dic['id_token_signing_alg_values_supported'] = ['HS256', 'RS256']