Add pep8 compliance and checker

This commit is contained in:
Philippe Savoie 2017-08-08 15:41:42 -07:00 committed by Phil Savoie
parent f78e2be3c5
commit 5dcd6a10b0
33 changed files with 365 additions and 231 deletions

View file

@ -12,8 +12,8 @@
# All configuration values have a default; values that are commented out # All configuration values have a default; values that are commented out
# serve to show the default. # serve to show the default.
import sys # import sys
import os # import os
# If extensions (or modules to document with autodoc) are in another directory, # If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the # add these directories to sys.path here. If the directory is relative to the

View file

@ -1,5 +1,6 @@
import os import os
from django.core.wsgi import get_wsgi_application
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'myapp.settings') os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'myapp.settings')
from django.core.wsgi import get_wsgi_application
application = get_wsgi_application() application = get_wsgi_application()

View file

@ -51,7 +51,9 @@ class ClientAdmin(admin.ModelAdmin):
fieldsets = [ fieldsets = [
[_(u''), { [_(u''), {
'fields': ('name', 'client_type', 'response_type','_redirect_uris', 'jwt_alg', 'require_consent', 'reuse_consent'), 'fields': (
'name', 'client_type', 'response_type', '_redirect_uris', 'jwt_alg', 'require_consent',
'reuse_consent'),
}], }],
[_(u'Credentials'), { [_(u'Credentials'), {
'fields': ('client_id', 'client_secret'), 'fields': ('client_id', 'client_secret'),

View file

@ -9,8 +9,8 @@ STANDARD_CLAIMS = {
'name': '', 'given_name': '', 'family_name': '', 'middle_name': '', 'nickname': '', 'name': '', 'given_name': '', 'family_name': '', 'middle_name': '', 'nickname': '',
'preferred_username': '', 'profile': '', 'picture': '', 'website': '', 'gender': '', 'preferred_username': '', 'profile': '', 'picture': '', 'website': '', 'gender': '',
'birthdate': '', 'zoneinfo': '', 'locale': '', 'updated_at': '', 'email': '', 'email_verified': '', 'birthdate': '', 'zoneinfo': '', 'locale': '', 'updated_at': '', 'email': '', 'email_verified': '',
'phone_number': '', 'phone_number_verified': '', 'address': { 'formatted': '', 'phone_number': '', 'phone_number_verified': '', 'address': {
'street_address': '', 'locality': '', 'region': '', 'postal_code': '', 'country': '', }, 'formatted': '', 'street_address': '', 'locality': '', 'region': '', 'postal_code': '', 'country': '', },
} }
@ -72,7 +72,9 @@ class ScopeClaims(object):
return aux_dic return aux_dic
@classmethod @classmethod
def get_scopes_info(cls, scopes=[]): def get_scopes_info(cls, scopes=None):
if scopes is None:
scopes = []
scopes_info = [] scopes_info = []
for name in cls.__dict__: for name in cls.__dict__:
@ -99,6 +101,7 @@ class StandardScopeClaims(ScopeClaims):
_(u'Basic profile'), _(u'Basic profile'),
_(u'Access to your basic information. Includes names, gender, birthdate and other information.'), _(u'Access to your basic information. Includes names, gender, birthdate and other information.'),
) )
def scope_profile(self): def scope_profile(self):
dic = { dic = {
'name': self.userinfo.get('name'), 'name': self.userinfo.get('name'),
@ -123,6 +126,7 @@ class StandardScopeClaims(ScopeClaims):
_(u'Email'), _(u'Email'),
_(u'Access to your email address.'), _(u'Access to your email address.'),
) )
def scope_email(self): def scope_email(self):
dic = { dic = {
'email': self.userinfo.get('email') or getattr(self.user, 'email', None), 'email': self.userinfo.get('email') or getattr(self.user, 'email', None),
@ -135,6 +139,7 @@ class StandardScopeClaims(ScopeClaims):
_(u'Phone number'), _(u'Phone number'),
_(u'Access to your phone number.'), _(u'Access to your phone number.'),
) )
def scope_phone(self): def scope_phone(self):
dic = { dic = {
'phone_number': self.userinfo.get('phone_number'), 'phone_number': self.userinfo.get('phone_number'),
@ -147,6 +152,7 @@ class StandardScopeClaims(ScopeClaims):
_(u'Address information'), _(u'Address information'),
_(u'Access to your address. Includes country, locality, street and other information.'), _(u'Access to your address. Includes country, locality, street and other information.'),
) )
def scope_address(self): def scope_address(self):
dic = { dic = {
'address': { 'address': {

View file

@ -102,8 +102,8 @@ class AuthorizeEndpoint(object):
logger.debug('[Authorize] Invalid response type: %s', self.params['response_type']) logger.debug('[Authorize] Invalid response type: %s', self.params['response_type'])
raise AuthorizeError(self.params['redirect_uri'], 'unsupported_response_type', self.grant_type) raise AuthorizeError(self.params['redirect_uri'], 'unsupported_response_type', self.grant_type)
if not self.is_authentication and \ if (not self.is_authentication and
(self.grant_type == 'hybrid' or self.params['response_type'] in ['id_token', 'id_token token']): (self.grant_type == 'hybrid' or self.params['response_type'] in ['id_token', 'id_token token'])):
logger.debug('[Authorize] Missing openid scope.') logger.debug('[Authorize] Missing openid scope.')
raise AuthorizeError(self.params['redirect_uri'], 'invalid_scope', self.grant_type) raise AuthorizeError(self.params['redirect_uri'], 'invalid_scope', self.grant_type)
@ -165,7 +165,8 @@ class AuthorizeEndpoint(object):
id_token_dic = create_id_token(**kwargs) id_token_dic = create_id_token(**kwargs)
# Check if response_type must include id_token in the response. # Check if response_type must include id_token in the response.
if self.params['response_type'] in ['id_token', 'id_token token', 'code id_token', 'code id_token token']: if self.params['response_type'] in [
'id_token', 'id_token token', 'code id_token', 'code id_token token']:
query_fragment['id_token'] = encode_id_token(id_token_dic, self.client) query_fragment['id_token'] = encode_id_token(id_token_dic, self.client)
else: else:
id_token_dic = {} id_token_dic = {}
@ -211,7 +212,8 @@ class AuthorizeEndpoint(object):
logger.exception('[Authorize] Error when trying to create response uri: %s', error) logger.exception('[Authorize] Error when trying to create response uri: %s', error)
raise AuthorizeError(self.params['redirect_uri'], 'server_error', self.grant_type) raise AuthorizeError(self.params['redirect_uri'], 'server_error', self.grant_type)
uri = uri._replace(query=urlencode(query_params, doseq=True), fragment=uri.fragment + urlencode(query_fragment, doseq=True)) uri = uri._replace(
query=urlencode(query_params, doseq=True), fragment=uri.fragment + urlencode(query_fragment, doseq=True))
return urlunsplit(uri) return urlunsplit(uri)
@ -264,7 +266,8 @@ class AuthorizeEndpoint(object):
""" """
scopes = StandardScopeClaims.get_scopes_info(self.params['scope']) scopes = StandardScopeClaims.get_scopes_info(self.params['scope'])
if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'): if settings.get('OIDC_EXTRA_SCOPE_CLAIMS'):
scopes_extra = settings.get('OIDC_EXTRA_SCOPE_CLAIMS', import_str=True).get_scopes_info(self.params['scope']) scopes_extra = settings.get('OIDC_EXTRA_SCOPE_CLAIMS', import_str=True).get_scopes_info(
self.params['scope'])
for index_extra, scope_extra in enumerate(scopes_extra): for index_extra, scope_extra in enumerate(scopes_extra):
for index, scope in enumerate(scopes[:]): for index, scope in enumerate(scopes[:]):
if scope_extra['scope'] == scope['scope']: if scope_extra['scope'] == scope['scope']:

View file

@ -4,11 +4,6 @@ import logging
import re import re
from django.contrib.auth import authenticate from django.contrib.auth import authenticate
try:
from urllib.parse import unquote
except ImportError:
from urllib import unquote
from django.http import JsonResponse from django.http import JsonResponse
from oidc_provider.lib.errors import ( from oidc_provider.lib.errors import (

View file

@ -31,6 +31,7 @@ class UserAuthError(Exception):
'error_description': self.description, 'error_description': self.description,
} }
class AuthorizeError(Exception): class AuthorizeError(Exception):
_errors = { _errors = {

View file

@ -5,11 +5,6 @@ from django.http import HttpResponse
from oidc_provider import settings from oidc_provider import settings
try:
from urlparse import urlsplit, urlunsplit
except ImportError:
from urllib.parse import urlsplit, urlunsplit
def redirect(uri): def redirect(uri):
""" """
@ -75,7 +70,8 @@ def default_after_userlogin_hook(request, user, client):
return None return None
def default_after_end_session_hook(request, id_token=None, post_logout_redirect_uri=None, state=None, client=None, next_page=None): def default_after_end_session_hook(
request, id_token=None, post_logout_redirect_uri=None, state=None, client=None, next_page=None):
""" """
Default function for setting OIDC_AFTER_END_SESSION_HOOK. Default function for setting OIDC_AFTER_END_SESSION_HOOK.
@ -91,7 +87,8 @@ def default_after_end_session_hook(request, id_token=None, post_logout_redirect_
:param state: state param from url query params :param state: state param from url query params
:type state: str :type state: str
:param client: If id_token has `aud` param and associated Client exists, this is an instance of it - do NOT trust this param :param client: If id_token has `aud` param and associated Client exists,
this is an instance of it - do NOT trust this param
:type client: oidc_provider.models.Client :type client: oidc_provider.models.Client
:param next_page: calculated next_page redirection target :param next_page: calculated next_page redirection target

View file

@ -28,12 +28,15 @@ def extract_access_token(request):
return access_token return access_token
def protected_resource_view(scopes=[]): def protected_resource_view(scopes=None):
""" """
View decorator. The client accesses protected resources by presenting the View decorator. The client accesses protected resources by presenting the
access token to the resource server. access token to the resource server.
https://tools.ietf.org/html/rfc6749#section-7 https://tools.ietf.org/html/rfc6749#section-7
""" """
if scopes is None:
scopes = []
def wrapper(view): def wrapper(view):
def view_wrapper(request, *args, **kwargs): def view_wrapper(request, *args, **kwargs):
access_token = extract_access_token(request) access_token = extract_access_token(request)
@ -52,9 +55,10 @@ def protected_resource_view(scopes=[]):
if not set(scopes).issubset(set(kwargs['token'].scope)): if not set(scopes).issubset(set(kwargs['token'].scope)):
logger.debug('[UserInfo] Missing openid scope.') logger.debug('[UserInfo] Missing openid scope.')
raise BearerTokenError('insufficient_scope') raise BearerTokenError('insufficient_scope')
except (BearerTokenError) as error: except BearerTokenError as error:
response = HttpResponse(status=error.status) response = HttpResponse(status=error.status)
response['WWW-Authenticate'] = 'error="{0}", error_description="{1}"'.format(error.code, error.description) response['WWW-Authenticate'] = 'error="{0}", error_description="{1}"'.format(
error.code, error.description)
return response return response
return view(request, *args, **kwargs) return view(request, *args, **kwargs)

View file

@ -18,12 +18,14 @@ from oidc_provider.models import (
from oidc_provider import settings from oidc_provider import settings
def create_id_token(user, aud, nonce='', at_hash='', request=None, scope=[]): def create_id_token(user, aud, nonce='', at_hash='', request=None, scope=None):
""" """
Creates the id_token dictionary. Creates the id_token dictionary.
See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken See: http://openid.net/specs/openid-connect-core-1_0.html#IDToken
Return a dic. Return a dic.
""" """
if scope is None:
scope = []
sub = settings.get('OIDC_IDTOKEN_SUB_GENERATOR', import_str=True)(user=user) sub = settings.get('OIDC_IDTOKEN_SUB_GENERATOR', import_str=True)(user=user)
expires_in = settings.get('OIDC_IDTOKEN_EXPIRE') expires_in = settings.get('OIDC_IDTOKEN_EXPIRE')
@ -63,6 +65,7 @@ def create_id_token(user, aud, nonce='', at_hash='', request=None, scope=[]):
return dic return dic
def encode_id_token(payload, client): def encode_id_token(payload, client):
""" """
Represent the ID Token as a JSON Web Token (JWT). Represent the ID Token as a JSON Web Token (JWT).
@ -72,6 +75,7 @@ def encode_id_token(payload, client):
_jws = JWS(payload, alg=client.jwt_alg) _jws = JWS(payload, alg=client.jwt_alg)
return _jws.sign_compact(keys) return _jws.sign_compact(keys)
def decode_id_token(token, client): def decode_id_token(token, client):
""" """
Represent the ID Token as a JSON Web Token (JWT). Represent the ID Token as a JSON Web Token (JWT).
@ -80,6 +84,7 @@ def decode_id_token(token, client):
keys = get_client_alg_keys(client) keys = get_client_alg_keys(client)
return JWS().verify_compact(token, keys=keys) return JWS().verify_compact(token, keys=keys)
def client_id_from_id_token(id_token): def client_id_from_id_token(id_token):
""" """
Extracts the client id from a JSON Web Token (JWT). Extracts the client id from a JSON Web Token (JWT).
@ -88,6 +93,7 @@ def client_id_from_id_token(id_token):
payload = JWT().unpack(id_token).payload() payload = JWT().unpack(id_token).payload()
return payload.get('aud', None) return payload.get('aud', None)
def create_token(user, client, scope, id_token_dic=None): def create_token(user, client, scope, id_token_dic=None):
""" """
Create and populate a Token object. Create and populate a Token object.
@ -108,6 +114,7 @@ def create_token(user, client, scope, id_token_dic=None):
return token return token
def create_code(user, client, scope, nonce, is_authentication, def create_code(user, client, scope, nonce, is_authentication,
code_challenge=None, code_challenge_method=None): code_challenge=None, code_challenge_method=None):
""" """
@ -132,6 +139,7 @@ def create_code(user, client, scope, nonce, is_authentication,
return code return code
def get_client_alg_keys(client): def get_client_alg_keys(client):
""" """
Takes a client and returns the set of keys associated with it. Takes a client and returns the set of keys associated with it.

View file

@ -1,9 +1,6 @@
import os
from Cryptodome.PublicKey import RSA from Cryptodome.PublicKey import RSA
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from oidc_provider import settings
from oidc_provider.models import RSAKey from oidc_provider.models import RSAKey

View file

@ -20,7 +20,9 @@ class Migration(migrations.Migration):
('name', models.CharField(default=b'', max_length=100)), ('name', models.CharField(default=b'', max_length=100)),
('client_id', models.CharField(unique=True, max_length=255)), ('client_id', models.CharField(unique=True, max_length=255)),
('client_secret', models.CharField(unique=True, max_length=255)), ('client_secret', models.CharField(unique=True, max_length=255)),
('response_type', models.CharField(max_length=30, choices=[(b'code', b'code (Authorization Code Flow)'), (b'id_token', b'id_token (Implicit Flow)'), (b'id_token token', b'id_token token (Implicit Flow)')])), ('response_type', models.CharField(max_length=30, choices=[
(b'code', b'code (Authorization Code Flow)'), (b'id_token', b'id_token (Implicit Flow)'),
(b'id_token token', b'id_token token (Implicit Flow)')])),
('_redirect_uris', models.TextField(default=b'')), ('_redirect_uris', models.TextField(default=b'')),
], ],
options={ options={

View file

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from django.db import models, migrations from django.db import migrations
class Migration(migrations.Migration): class Migration(migrations.Migration):

View file

@ -29,7 +29,8 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name='client', model_name='client',
name='date_created', name='date_created',
field=models.DateField(auto_now_add=True, default=datetime.datetime(2016, 1, 11, 18, 44, 32, 192477, tzinfo=utc)), field=models.DateField(
auto_now_add=True, default=datetime.datetime(2016, 1, 11, 18, 44, 32, 192477, tzinfo=utc)),
preserve_default=False, preserve_default=False,
), ),
migrations.AlterField( migrations.AlterField(

View file

@ -15,6 +15,11 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name='client', model_name='client',
name='client_type', name='client_type',
field=models.CharField(choices=[(b'confidential', b'Confidential'), (b'public', b'Public')], default=b'confidential', help_text='<b>Confidential</b> clients are capable of maintaining the confidentiality of their credentials. <b>Public</b> clients are incapable.', max_length=30), field=models.CharField(
choices=[(b'confidential', b'Confidential'), (b'public', b'Public')],
default=b'confidential',
help_text='<b>Confidential</b> clients are capable of maintaining the confidentiality of their '
'credentials. <b>Public</b> clients are incapable.',
max_length=30),
), ),
] ]

View file

@ -15,6 +15,10 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name='client', model_name='client',
name='jwt_alg', name='jwt_alg',
field=models.CharField(choices=[(b'HS256', b'HS256'), (b'RS256', b'RS256')], default=b'RS256', max_length=10, verbose_name='JWT Algorithm'), field=models.CharField(
choices=[(b'HS256', b'HS256'), (b'RS256', b'RS256')],
default=b'RS256',
max_length=10,
verbose_name='JWT Algorithm'),
), ),
] ]

View file

@ -25,12 +25,21 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='client_type', name='client_type',
field=models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], default='confidential', help_text='<b>Confidential</b> clients are capable of maintaining the confidentiality of their credentials. <b>Public</b> clients are incapable.', max_length=30), field=models.CharField(
choices=[('confidential', 'Confidential'), ('public', 'Public')],
default='confidential',
help_text='<b>Confidential</b> clients are capable of maintaining the confidentiality of their'
' credentials. <b>Public</b> clients are incapable.',
max_length=30),
), ),
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='jwt_alg', name='jwt_alg',
field=models.CharField(choices=[('HS256', 'HS256'), ('RS256', 'RS256')], default='RS256', max_length=10, verbose_name='JWT Algorithm'), field=models.CharField(
choices=[('HS256', 'HS256'), ('RS256', 'RS256')],
default='RS256',
max_length=10,
verbose_name='JWT Algorithm'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
@ -40,7 +49,11 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='response_type', name='response_type',
field=models.CharField(choices=[('code', 'code (Authorization Code Flow)'), ('id_token', 'id_token (Implicit Flow)'), ('id_token token', 'id_token token (Implicit Flow)')], max_length=30), field=models.CharField(
choices=[
('code', 'code (Authorization Code Flow)'), ('id_token', 'id_token (Implicit Flow)'),
('id_token token', 'id_token token (Implicit Flow)')],
max_length=30),
), ),
migrations.AlterField( migrations.AlterField(
model_name='code', model_name='code',

View file

@ -19,13 +19,15 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name='userconsent', model_name='userconsent',
name='date_given', name='date_given',
field=models.DateTimeField(default=datetime.datetime(2016, 6, 10, 17, 53, 48, 889808, tzinfo=utc), verbose_name='Date Given'), field=models.DateTimeField(
default=datetime.datetime(2016, 6, 10, 17, 53, 48, 889808, tzinfo=utc), verbose_name='Date Given'),
preserve_default=False, preserve_default=False,
), ),
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='_redirect_uris', name='_redirect_uris',
field=models.TextField(default=b'', help_text='Enter each URI on a new line.', verbose_name='Redirect URIs'), field=models.TextField(
default=b'', help_text='Enter each URI on a new line.', verbose_name='Redirect URIs'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
@ -40,7 +42,13 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='client_type', name='client_type',
field=models.CharField(choices=[(b'confidential', b'Confidential'), (b'public', b'Public')], default=b'confidential', help_text='<b>Confidential</b> clients are capable of maintaining the confidentiality of their credentials. <b>Public</b> clients are incapable.', max_length=30, verbose_name='Client Type'), field=models.CharField(
choices=[(b'confidential', b'Confidential'), (b'public', b'Public')],
default=b'confidential',
help_text='<b>Confidential</b> clients are capable of maintaining the confidentiality of their '
'credentials. <b>Public</b> clients are incapable.',
max_length=30,
verbose_name='Client Type'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
@ -55,7 +63,12 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='response_type', name='response_type',
field=models.CharField(choices=[(b'code', b'code (Authorization Code Flow)'), (b'id_token', b'id_token (Implicit Flow)'), (b'id_token token', b'id_token token (Implicit Flow)')], max_length=30, verbose_name='Response Type'), field=models.CharField(
choices=[
(b'code', b'code (Authorization Code Flow)'), (b'id_token', b'id_token (Implicit Flow)'),
(b'id_token token', b'id_token token (Implicit Flow)')],
max_length=30,
verbose_name='Response Type'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='code', model_name='code',
@ -65,7 +78,8 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='code', model_name='code',
name='client', name='client',
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='oidc_provider.Client', verbose_name='Client'), field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to='oidc_provider.Client', verbose_name='Client'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='code', model_name='code',
@ -100,7 +114,8 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='code', model_name='code',
name='user', name='user',
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='User'), field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='User'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='rsakey', model_name='rsakey',
@ -125,7 +140,8 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='token', model_name='token',
name='client', name='client',
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='oidc_provider.Client', verbose_name='Client'), field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to='oidc_provider.Client', verbose_name='Client'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='token', model_name='token',
@ -140,7 +156,8 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='token', model_name='token',
name='user', name='user',
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='User'), field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='User'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='userconsent', model_name='userconsent',
@ -150,7 +167,8 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='userconsent', model_name='userconsent',
name='client', name='client',
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='oidc_provider.Client', verbose_name='Client'), field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to='oidc_provider.Client', verbose_name='Client'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='userconsent', model_name='userconsent',
@ -160,6 +178,7 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='userconsent', model_name='userconsent',
name='user', name='user',
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='User'), field=models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL, verbose_name='User'),
), ),
] ]

View file

@ -25,7 +25,13 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='client_type', name='client_type',
field=models.CharField(choices=[('confidential', 'Confidential'), ('public', 'Public')], default='confidential', help_text='<b>Confidential</b> clients are capable of maintaining the confidentiality of their credentials. <b>Public</b> clients are incapable.', max_length=30, verbose_name='Client Type'), field=models.CharField(
choices=[('confidential', 'Confidential'), ('public', 'Public')],
default='confidential',
help_text='<b>Confidential</b> clients are capable of maintaining the confidentiality of their '
'credentials. <b>Public</b> clients are incapable.',
max_length=30,
verbose_name='Client Type'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
@ -35,7 +41,12 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='response_type', name='response_type',
field=models.CharField(choices=[('code', 'code (Authorization Code Flow)'), ('id_token', 'id_token (Implicit Flow)'), ('id_token token', 'id_token token (Implicit Flow)')], max_length=30, verbose_name='Response Type'), field=models.CharField(
choices=[
('code', 'code (Authorization Code Flow)'), ('id_token', 'id_token (Implicit Flow)'),
('id_token token', 'id_token token (Implicit Flow)')],
max_length=30,
verbose_name='Response Type'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='code', model_name='code',

View file

@ -20,12 +20,18 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name='client', model_name='client',
name='logo', name='logo',
field=models.FileField(blank=True, default='', upload_to='oidc_provider/clients', verbose_name='Logo Image'), field=models.FileField(
blank=True, default='', upload_to='oidc_provider/clients', verbose_name='Logo Image'),
), ),
migrations.AddField( migrations.AddField(
model_name='client', model_name='client',
name='terms_url', name='terms_url',
field=models.CharField(blank=True, default='', help_text='External reference to the privacy policy of the client.', max_length=255, verbose_name='Terms URL'), field=models.CharField(
blank=True,
default='',
help_text='External reference to the privacy policy of the client.',
max_length=255,
verbose_name='Terms URL'),
), ),
migrations.AddField( migrations.AddField(
model_name='client', model_name='client',
@ -35,11 +41,23 @@ class Migration(migrations.Migration):
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='jwt_alg', name='jwt_alg',
field=models.CharField(choices=[('HS256', 'HS256'), ('RS256', 'RS256')], default='RS256', help_text='Algorithm used to encode ID Tokens.', max_length=10, verbose_name='JWT Algorithm'), field=models.CharField(
choices=[('HS256', 'HS256'), ('RS256', 'RS256')],
default='RS256',
help_text='Algorithm used to encode ID Tokens.',
max_length=10,
verbose_name='JWT Algorithm'),
), ),
migrations.AlterField( migrations.AlterField(
model_name='client', model_name='client',
name='response_type', name='response_type',
field=models.CharField(choices=[('code', 'code (Authorization Code Flow)'), ('id_token', 'id_token (Implicit Flow)'), ('id_token token', 'id_token token (Implicit Flow)'), ('code token', 'code token (Hybrid Flow)'), ('code id_token', 'code id_token (Hybrid Flow)'), ('code id_token token', 'code id_token token (Hybrid Flow)')], max_length=30, verbose_name='Response Type'), field=models.CharField(
choices=[
('code', 'code (Authorization Code Flow)'), ('id_token', 'id_token (Implicit Flow)'),
('id_token token', 'id_token token (Implicit Flow)'), ('code token', 'code token (Hybrid Flow)'),
('code id_token', 'code id_token (Hybrid Flow)'),
('code id_token token', 'code id_token token (Hybrid Flow)')],
max_length=30,
verbose_name='Response Type'),
), ),
] ]

View file

@ -15,6 +15,10 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name='client', model_name='client',
name='_post_logout_redirect_uris', name='_post_logout_redirect_uris',
field=models.TextField(blank=True, default='', help_text='Enter each URI on a new line.', verbose_name='Post Logout Redirect URIs'), field=models.TextField(
blank=True,
default='',
help_text='Enter each URI on a new line.',
verbose_name='Post Logout Redirect URIs'),
), ),
] ]

View file

@ -15,11 +15,18 @@ class Migration(migrations.Migration):
migrations.AddField( migrations.AddField(
model_name='client', model_name='client',
name='require_consent', name='require_consent',
field=models.BooleanField(default=True, help_text='If disabled, the Server will NEVER ask the user for consent.', verbose_name='Require Consent?'), field=models.BooleanField(
default=True,
help_text='If disabled, the Server will NEVER ask the user for consent.',
verbose_name='Require Consent?'),
), ),
migrations.AddField( migrations.AddField(
model_name='client', model_name='client',
name='reuse_consent', name='reuse_consent',
field=models.BooleanField(default=True, help_text="If enabled, the Server will save the user consent given to a specific client, so that user won't be prompted for the same authorization multiple times.", verbose_name='Reuse Consent?'), field=models.BooleanField(
default=True,
help_text="If enabled, the Server will save the user consent given to a specific client,"
" so that user won't be prompted for the same authorization multiple times.",
verbose_name='Reuse Consent?'),
), ),
] ]

View file

@ -33,36 +33,66 @@ JWT_ALGS = [
class Client(models.Model): class Client(models.Model):
name = models.CharField(max_length=100, default='', verbose_name=_(u'Name')) name = models.CharField(max_length=100, default='', verbose_name=_(u'Name'))
client_type = models.CharField(max_length=30, choices=CLIENT_TYPE_CHOICES, default='confidential', verbose_name=_(u'Client Type'), help_text=_(u'<b>Confidential</b> clients are capable of maintaining the confidentiality of their credentials. <b>Public</b> clients are incapable.')) client_type = models.CharField(
max_length=30,
choices=CLIENT_TYPE_CHOICES,
default='confidential',
verbose_name=_(u'Client Type'),
help_text=_(u'<b>Confidential</b> clients are capable of maintaining the confidentiality of their credentials. '
u'<b>Public</b> clients are incapable.'))
client_id = models.CharField(max_length=255, unique=True, verbose_name=_(u'Client ID')) client_id = models.CharField(max_length=255, unique=True, verbose_name=_(u'Client ID'))
client_secret = models.CharField(max_length=255, blank=True, verbose_name=_(u'Client SECRET')) client_secret = models.CharField(max_length=255, blank=True, verbose_name=_(u'Client SECRET'))
response_type = models.CharField(max_length=30, choices=RESPONSE_TYPE_CHOICES, verbose_name=_(u'Response Type')) response_type = models.CharField(max_length=30, choices=RESPONSE_TYPE_CHOICES, verbose_name=_(u'Response Type'))
jwt_alg = models.CharField(max_length=10, choices=JWT_ALGS, default='RS256', verbose_name=_(u'JWT Algorithm'), help_text=_(u'Algorithm used to encode ID Tokens.')) jwt_alg = models.CharField(
max_length=10,
choices=JWT_ALGS,
default='RS256',
verbose_name=_(u'JWT Algorithm'),
help_text=_(u'Algorithm used to encode ID Tokens.'))
date_created = models.DateField(auto_now_add=True, verbose_name=_(u'Date Created')) date_created = models.DateField(auto_now_add=True, verbose_name=_(u'Date Created'))
website_url = models.CharField(max_length=255, blank=True, default='', verbose_name=_(u'Website URL')) website_url = models.CharField(max_length=255, blank=True, default='', verbose_name=_(u'Website URL'))
terms_url = models.CharField(max_length=255, blank=True, default='', verbose_name=_(u'Terms URL'), help_text=_(u'External reference to the privacy policy of the client.')) terms_url = models.CharField(
max_length=255,
blank=True,
default='',
verbose_name=_(u'Terms URL'),
help_text=_(u'External reference to the privacy policy of the client.'))
contact_email = models.CharField(max_length=255, blank=True, default='', verbose_name=_(u'Contact Email')) contact_email = models.CharField(max_length=255, blank=True, default='', verbose_name=_(u'Contact Email'))
logo = models.FileField(blank=True, default='', upload_to='oidc_provider/clients', verbose_name=_(u'Logo Image')) logo = models.FileField(blank=True, default='', upload_to='oidc_provider/clients', verbose_name=_(u'Logo Image'))
reuse_consent = models.BooleanField(default=True, verbose_name=_('Reuse Consent?'), help_text=_('If enabled, the Server will save the user consent given to a specific client, so that user won\'t be prompted for the same authorization multiple times.')) reuse_consent = models.BooleanField(
require_consent = models.BooleanField(default=True, verbose_name=_('Require Consent?'), help_text=_('If disabled, the Server will NEVER ask the user for consent.')) default=True,
verbose_name=_('Reuse Consent?'),
help_text=_('If enabled, the Server will save the user consent given to a specific client, so that'
' user won\'t be prompted for the same authorization multiple times.'))
require_consent = models.BooleanField(
default=True,
verbose_name=_('Require Consent?'),
help_text=_('If disabled, the Server will NEVER ask the user for consent.'))
_redirect_uris = models.TextField(default='', verbose_name=_(u'Redirect URIs'), help_text=_(u'Enter each URI on a new line.')) _redirect_uris = models.TextField(
def redirect_uris(): default='', verbose_name=_(u'Redirect URIs'), help_text=_(u'Enter each URI on a new line.'))
def fget(self):
@property
def redirect_uris(self):
return self._redirect_uris.splitlines() return self._redirect_uris.splitlines()
def fset(self, value):
self._redirect_uris = '\n'.join(value)
return locals()
redirect_uris = property(**redirect_uris())
_post_logout_redirect_uris = models.TextField(blank=True, default='', verbose_name=_(u'Post Logout Redirect URIs'), help_text=_(u'Enter each URI on a new line.')) @redirect_uris.setter
def post_logout_redirect_uris(): def redirect_uris(self, value):
def fget(self): self._redirect_uris = '\n'.join(value)
_post_logout_redirect_uris = models.TextField(
blank=True,
default='',
verbose_name=_(u'Post Logout Redirect URIs'),
help_text=_(u'Enter each URI on a new line.'))
@property
def post_logout_redirect_uris(self):
return self._post_logout_redirect_uris.splitlines() return self._post_logout_redirect_uris.splitlines()
def fset(self, value):
@post_logout_redirect_uris.setter
def post_logout_redirect_uris(self, value):
self._post_logout_redirect_uris = '\n'.join(value) self._post_logout_redirect_uris = '\n'.join(value)
return locals()
post_logout_redirect_uris = property(**post_logout_redirect_uris())
class Meta: class Meta:
verbose_name = _(u'Client') verbose_name = _(u'Client')
@ -74,8 +104,6 @@ class Client(models.Model):
def __unicode__(self): def __unicode__(self):
return self.__str__() return self.__str__()
@property @property
def default_redirect_uri(self): def default_redirect_uri(self):
return self.redirect_uris[0] if self.redirect_uris else '' return self.redirect_uris[0] if self.redirect_uris else ''
@ -88,16 +116,14 @@ class BaseCodeTokenModel(models.Model):
expires_at = models.DateTimeField(verbose_name=_(u'Expiration Date')) expires_at = models.DateTimeField(verbose_name=_(u'Expiration Date'))
_scope = models.TextField(default='', verbose_name=_(u'Scopes')) _scope = models.TextField(default='', verbose_name=_(u'Scopes'))
def scope(): @property
def fget(self): def scope(self):
return self._scope.split() return self._scope.split()
def fset(self, value): @scope.setter
def scope(self, value):
self._scope = ' '.join(value) self._scope = ' '.join(value)
return locals()
scope = property(**scope())
def has_expired(self): def has_expired(self):
return timezone.now() >= self.expires_at return timezone.now() >= self.expires_at
@ -130,17 +156,14 @@ class Token(BaseCodeTokenModel):
refresh_token = models.CharField(max_length=255, unique=True, verbose_name=_(u'Refresh Token')) refresh_token = models.CharField(max_length=255, unique=True, verbose_name=_(u'Refresh Token'))
_id_token = models.TextField(verbose_name=_(u'ID Token')) _id_token = models.TextField(verbose_name=_(u'ID Token'))
def id_token(): @property
def id_token(self):
def fget(self):
return json.loads(self._id_token) return json.loads(self._id_token)
def fset(self, value): @id_token.setter
def id_token(self, value):
self._id_token = json.dumps(value) self._id_token = json.dumps(value)
return locals()
id_token = property(**id_token())
class Meta: class Meta:
verbose_name = _(u'Token') verbose_name = _(u'Token')
verbose_name_plural = _(u'Tokens') verbose_name_plural = _(u'Tokens')

View file

@ -145,6 +145,7 @@ class DefaultSettings(object):
'error': 'oidc_provider/error.html' 'error': 'oidc_provider/error.html'
} }
default_settings = DefaultSettings() default_settings = DefaultSettings()

View file

@ -34,7 +34,9 @@ from oidc_provider.lib.endpoints.authorize import AuthorizeEndpoint
class AuthorizeEndpointMixin(object): class AuthorizeEndpointMixin(object):
def _auth_request(self, method, data={}, is_user_authenticated=False): def _auth_request(self, method, data=None, is_user_authenticated=False):
if data is None:
data = {}
url = reverse('oidc_provider:authorize') url = reverse('oidc_provider:authorize')
if method.lower() == 'get': if method.lower() == 'get':
@ -67,7 +69,8 @@ class AuthorizationCodeFlowTestCase(TestCase, AuthorizeEndpointMixin):
self.client = create_fake_client(response_type='code') self.client = create_fake_client(response_type='code')
self.client_with_no_consent = create_fake_client(response_type='code', require_consent=False) self.client_with_no_consent = create_fake_client(response_type='code', require_consent=False)
self.client_public = create_fake_client(response_type='code', is_public=True) self.client_public = create_fake_client(response_type='code', is_public=True)
self.client_public_with_no_consent = create_fake_client(response_type='code', is_public=True, require_consent=False) self.client_public_with_no_consent = create_fake_client(
response_type='code', is_public=True, require_consent=False)
self.state = uuid.uuid4().hex self.state = uuid.uuid4().hex
self.nonce = uuid.uuid4().hex self.nonce = uuid.uuid4().hex
@ -163,8 +166,7 @@ class AuthorizationCodeFlowTestCase(TestCase, AuthorizeEndpointMixin):
for key, value in iter(to_check.items()): for key, value in iter(to_check.items()):
is_input_ok = input_html.format(key, value) in response.content.decode('utf-8') is_input_ok = input_html.format(key, value) in response.content.decode('utf-8')
self.assertEqual(is_input_ok, True, self.assertEqual(is_input_ok, True, msg='Hidden input for "' + key + '" fails.')
msg='Hidden input for "' + key + '" fails.')
def test_user_consent_response(self): def test_user_consent_response(self):
""" """
@ -204,8 +206,7 @@ class AuthorizationCodeFlowTestCase(TestCase, AuthorizeEndpointMixin):
is_code_ok = is_code_valid(url=response['Location'], is_code_ok = is_code_valid(url=response['Location'],
user=self.user, user=self.user,
client=self.client) client=self.client)
self.assertEqual(is_code_ok, True, self.assertEqual(is_code_ok, True, msg='Code returned is invalid.')
msg='Code returned is invalid.')
# Check if the state is returned. # Check if the state is returned.
state = (response['Location'].split('state='))[1].split('&')[0] state = (response['Location'].split('state='))[1].split('&')[0]
@ -276,9 +277,10 @@ class AuthorizationCodeFlowTestCase(TestCase, AuthorizeEndpointMixin):
client=self.client) client=self.client)
self.assertTrue(is_code_ok, msg='Code returned is invalid or missing') self.assertTrue(is_code_ok, msg='Code returned is invalid or missing')
self.assertEquals(set(params.keys()), set(['state', 'code']), msg='More than state or code appended as query params') self.assertEquals(set(params.keys()), {'state', 'code'}, msg='More than state or code appended as query params')
self.assertTrue(response['Location'].startswith(self.client.default_redirect_uri), msg='Different redirect_uri returned') self.assertTrue(
response['Location'].startswith(self.client.default_redirect_uri), msg='Different redirect_uri returned')
def test_unknown_redirect_uris_are_rejected(self): def test_unknown_redirect_uris_are_rejected(self):
""" """
@ -372,7 +374,8 @@ class AuthorizationCodeFlowTestCase(TestCase, AuthorizeEndpointMixin):
self.assertNotIn( self.assertNotIn(
quote('prompt=login'), quote('prompt=login'),
response['Location'], response['Location'],
"Found prompt=login, this leads to infinite login loop. See https://github.com/juanifioren/django-oidc-provider/issues/197." "Found prompt=login, this leads to infinite login loop. See "
"https://github.com/juanifioren/django-oidc-provider/issues/197."
) )
response = self._auth_request('get', data, is_user_authenticated=True) response = self._auth_request('get', data, is_user_authenticated=True)
@ -381,7 +384,8 @@ class AuthorizationCodeFlowTestCase(TestCase, AuthorizeEndpointMixin):
self.assertNotIn( self.assertNotIn(
quote('prompt=login'), quote('prompt=login'),
response['Location'], response['Location'],
"Found prompt=login, this leads to infinite login loop. See https://github.com/juanifioren/django-oidc-provider/issues/197." "Found prompt=login, this leads to infinite login loop. See "
"https://github.com/juanifioren/django-oidc-provider/issues/197."
) )
def test_prompt_login_none_parameter(self): def test_prompt_login_none_parameter(self):
@ -447,7 +451,6 @@ class AuthorizationCodeFlowTestCase(TestCase, AuthorizeEndpointMixin):
self.assertIn('consent_required', response['Location']) self.assertIn('consent_required', response['Location'])
class AuthorizationImplicitFlowTestCase(TestCase, AuthorizeEndpointMixin): class AuthorizationImplicitFlowTestCase(TestCase, AuthorizeEndpointMixin):
""" """
Test cases for Authorization Endpoint using Implicit Flow. Test cases for Authorization Endpoint using Implicit Flow.

View file

@ -50,5 +50,6 @@ class EndSessionTestCase(TestCase):
def test_call_post_end_session_hook(self, hook_function): def test_call_post_end_session_hook(self, hook_function):
self.client.get(self.url) self.client.get(self.url)
self.assertTrue(hook_function.called, 'OIDC_AFTER_END_SESSION_HOOK should be called') self.assertTrue(hook_function.called, 'OIDC_AFTER_END_SESSION_HOOK should be called')
self.assertTrue(hook_function.call_count == 1, 'OIDC_AFTER_END_SESSION_HOOK should be called once but was {}'.format(hook_function.call_count)) self.assertTrue(
hook_function.call_count == 1,
'OIDC_AFTER_END_SESSION_HOOK should be called once but was {}'.format(hook_function.call_count))

View file

@ -10,6 +10,7 @@ class StubbedViews:
urlpatterns = [url('^test/', SampleView.as_view())] urlpatterns = [url('^test/', SampleView.as_view())]
MW_CLASSES = ('django.contrib.sessions.middleware.SessionMiddleware', MW_CLASSES = ('django.contrib.sessions.middleware.SessionMiddleware',
'oidc_provider.middleware.SessionManagementMiddleware') 'oidc_provider.middleware.SessionManagementMiddleware')

View file

@ -18,7 +18,7 @@ from django.test import TestCase
from jwkest.jwk import KEYS from jwkest.jwk import KEYS
from jwkest.jws import JWS from jwkest.jws import JWS
from jwkest.jwt import JWT from jwkest.jwt import JWT
from mock import patch, Mock from mock import patch
from oidc_provider.lib.utils.token import create_code from oidc_provider.lib.utils.token import create_code
from oidc_provider.models import Token from oidc_provider.models import Token
@ -101,7 +101,8 @@ class TokenTestCase(TestCase):
""" """
url = reverse('oidc_provider:token') url = reverse('oidc_provider:token')
request = self.factory.post(url, request = self.factory.post(
url,
data=urlencode(post_data), data=urlencode(post_data),
content_type='application/x-www-form-urlencoded', content_type='application/x-www-form-urlencoded',
**extras) **extras)
@ -427,7 +428,6 @@ class TokenTestCase(TestCase):
See http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest and See http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest and
http://openid.net/specs/openid-connect-core-1_0.html#HybridTokenRequest. http://openid.net/specs/openid-connect-core-1_0.html#HybridTokenRequest.
""" """
SIGKEYS = self._get_keys()
code = self._create_code() code = self._create_code()
post_data = self._auth_code_post_data(code=code.code) post_data = self._auth_code_post_data(code=code.code)
@ -465,15 +465,13 @@ class TokenTestCase(TestCase):
for request in requests: for request in requests:
response = TokenView.as_view()(request) response = TokenView.as_view()(request)
self.assertEqual(response.status_code == 405, True, self.assertEqual(response.status_code, 405, msg=request.method + ' request does not return a 405 status.')
msg=request.method + ' request does not return a 405 status.')
request = self.factory.post(url) request = self.factory.post(url)
response = TokenView.as_view()(request) response = TokenView.as_view()(request)
self.assertEqual(response.status_code == 400, True, self.assertEqual(response.status_code, 400, msg=request.method + ' request does not return a 400 status.')
msg=request.method + ' request does not return a 400 status.')
def test_client_authentication(self): def test_client_authentication(self):
""" """
@ -490,8 +488,9 @@ class TokenTestCase(TestCase):
response = self._post_request(post_data) response = self._post_request(post_data)
self.assertEqual('invalid_client' in response.content.decode('utf-8'), self.assertNotIn(
False, 'invalid_client',
response.content.decode('utf-8'),
msg='Client authentication fails using request-body credentials.') msg='Client authentication fails using request-body credentials.')
# Now, test with an invalid client_id. # Now, test with an invalid client_id.
@ -504,8 +503,9 @@ class TokenTestCase(TestCase):
response = self._post_request(invalid_data) response = self._post_request(invalid_data)
self.assertEqual('invalid_client' in response.content.decode('utf-8'), self.assertIn(
True, 'invalid_client',
response.content.decode('utf-8'),
msg='Client authentication success with an invalid "client_id".') msg='Client authentication success with an invalid "client_id".')
# Now, test using HTTP Basic Authentication method. # Now, test using HTTP Basic Authentication method.
@ -521,8 +521,9 @@ class TokenTestCase(TestCase):
response = self._post_request(basicauth_data, self._password_grant_auth_header()) response = self._post_request(basicauth_data, self._password_grant_auth_header())
response.content.decode('utf-8') response.content.decode('utf-8')
self.assertEqual('invalid_client' in response.content.decode('utf-8'), self.assertNotIn(
False, 'invalid_client',
response.content.decode('utf-8'),
msg='Client authentication fails using HTTP Basic Auth.') msg='Client authentication fails using HTTP Basic Auth.')
def test_access_token_contains_nonce(self): def test_access_token_contains_nonce(self):
@ -588,7 +589,7 @@ class TokenTestCase(TestCase):
response = self._post_request(post_data) response = self._post_request(post_data)
response_dic = json.loads(response.content.decode('utf-8')) response_dic = json.loads(response.content.decode('utf-8'))
id_token = JWS().verify_compact(response_dic['id_token'].encode('utf-8'), RSAKEYS) JWS().verify_compact(response_dic['id_token'].encode('utf-8'), RSAKEYS)
@override_settings(OIDC_IDTOKEN_SUB_GENERATOR='oidc_provider.tests.app.utils.fake_sub_generator') @override_settings(OIDC_IDTOKEN_SUB_GENERATOR='oidc_provider.tests.app.utils.fake_sub_generator')
def test_custom_sub_generator(self): def test_custom_sub_generator(self):
@ -732,4 +733,4 @@ class TokenTestCase(TestCase):
response = self._post_request(post_data) response = self._post_request(post_data)
response_dic = json.loads(response.content.decode('utf-8')) json.loads(response.content.decode('utf-8'))

View file

@ -30,10 +30,12 @@ class UserInfoTestCase(TestCase):
self.user = create_fake_user() self.user = create_fake_user()
self.client = create_fake_client(response_type='code') self.client = create_fake_client(response_type='code')
def _create_token(self, extra_scope=[]): def _create_token(self, extra_scope=None):
""" """
Generate a valid token. Generate a valid token.
""" """
if extra_scope is None:
extra_scope = []
scope = ['openid', 'email'] + extra_scope scope = ['openid', 'email'] + extra_scope
id_token_dic = create_id_token( id_token_dic = create_id_token(
@ -60,9 +62,7 @@ class UserInfoTestCase(TestCase):
""" """
url = reverse('oidc_provider:userinfo') url = reverse('oidc_provider:userinfo')
request = self.factory.post(url, request = self.factory.post(url, data={}, content_type='multipart/form-data')
data={},
content_type='multipart/form-data')
request.META['HTTP_AUTHORIZATION'] = 'Bearer ' + access_token request.META['HTTP_AUTHORIZATION'] = 'Bearer ' + access_token
@ -136,17 +136,13 @@ class UserInfoTestCase(TestCase):
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(bool(response.content), True) self.assertEqual(bool(response.content), True)
self.assertEqual('given_name' in response_dic, True, self.assertIn('given_name', response_dic, msg='"given_name" claim should be in response.')
msg='"given_name" claim should be in response.') self.assertNotIn('profile', response_dic, msg='"profile" claim should not be in response.')
self.assertEqual('profile' in response_dic, False,
msg='"profile" claim should not be in response.')
# Now adding `address` scope. # Now adding `address` scope.
token = self._create_token(extra_scope=['profile', 'address']) token = self._create_token(extra_scope=['profile', 'address'])
response = self._post_request(token.access_token) response = self._post_request(token.access_token)
response_dic = json.loads(response.content.decode('utf-8')) response_dic = json.loads(response.content.decode('utf-8'))
self.assertEqual('address' in response_dic, True, self.assertIn('address', response_dic, msg='"address" claim should be in response.')
msg='"address" claim should be in response.') self.assertIn('country', response_dic['address'], msg='"country" claim should be in response.')
self.assertEqual('country' in response_dic['address'], True,
msg='"country" claim should be in response.')

View file

@ -78,7 +78,8 @@ class AuthorizeView(View):
if 'select_account' in authorize.params['prompt']: if 'select_account' in authorize.params['prompt']:
# TODO: see how we can support multiple accounts for the end-user. # TODO: see how we can support multiple accounts for the end-user.
if 'none' in authorize.params['prompt']: if 'none' in authorize.params['prompt']:
raise AuthorizeError(authorize.params['redirect_uri'], 'account_selection_required', authorize.grant_type) raise AuthorizeError(
authorize.params['redirect_uri'], 'account_selection_required', authorize.grant_type)
else: else:
django_user_logout(request) django_user_logout(request)
return redirect_to_login(request.get_full_path(), settings.get('OIDC_LOGIN_URL')) return redirect_to_login(request.get_full_path(), settings.get('OIDC_LOGIN_URL'))
@ -86,7 +87,7 @@ class AuthorizeView(View):
if {'none', 'consent'}.issubset(authorize.params['prompt']): if {'none', 'consent'}.issubset(authorize.params['prompt']):
raise AuthorizeError(authorize.params['redirect_uri'], 'consent_required', authorize.grant_type) raise AuthorizeError(authorize.params['redirect_uri'], 'consent_required', authorize.grant_type)
implicit_flow_resp_types = set(['id_token', 'id_token token']) implicit_flow_resp_types = {'id_token', 'id_token token'}
allow_skipping_consent = ( allow_skipping_consent = (
authorize.client.client_type != 'public' or authorize.client.client_type != 'public' or
authorize.client.response_type in implicit_flow_resp_types) authorize.client.response_type in implicit_flow_resp_types)
@ -156,13 +157,15 @@ class AuthorizeView(View):
authorize.validate_params() authorize.validate_params()
if not request.POST.get('allow'): if not request.POST.get('allow'):
signals.user_decline_consent.send(self.__class__, user=request.user, client=authorize.client, scope=authorize.params['scope']) signals.user_decline_consent.send(
self.__class__, user=request.user, client=authorize.client, scope=authorize.params['scope'])
raise AuthorizeError(authorize.params['redirect_uri'], raise AuthorizeError(authorize.params['redirect_uri'],
'access_denied', 'access_denied',
authorize.grant_type) authorize.grant_type)
signals.user_accept_consent.send(self.__class__, user=request.user, client=authorize.client, scope=authorize.params['scope']) signals.user_accept_consent.send(
self.__class__, user=request.user, client=authorize.client, scope=authorize.params['scope'])
# Save the user consent given to the client. # Save the user consent given to the client.
authorize.set_client_user_consent() authorize.set_client_user_consent()
@ -171,7 +174,7 @@ class AuthorizeView(View):
return redirect(uri) return redirect(uri)
except (AuthorizeError) as error: except AuthorizeError as error:
uri = error.create_uri( uri = error.create_uri(
authorize.params['redirect_uri'], authorize.params['redirect_uri'],
authorize.params['state']) authorize.params['state'])

View file

@ -6,6 +6,7 @@ envlist=
py34-django{17,18,19,110,111}, py34-django{17,18,19,110,111},
py35-django{18,19,110,111}, py35-django{18,19,110,111},
py36-django{18,19,110,111}, py36-django{18,19,110,111},
flake8
[testenv] [testenv]
@ -30,3 +31,9 @@ commands=
commands= commands=
coverage report -m coverage report -m
[testenv:flake8]
basepython=python
deps=flake8
commands =
flake8 --max-line-length=120