diff --git a/openid_provider/lib/endpoints/authorize.py b/openid_provider/lib/endpoints/authorize.py index 03c9a68..f12ac97 100644 --- a/openid_provider/lib/endpoints/authorize.py +++ b/openid_provider/lib/endpoints/authorize.py @@ -41,7 +41,7 @@ class AuthorizeEndpoint(object): self.params.client_id = self.query_dict.get('client_id', '') self.params.redirect_uri = self.query_dict.get('redirect_uri', '') self.params.response_type = self.query_dict.get('response_type', '') - self.params.scope = self.query_dict.get('scope', '') + self.params.scope = self.query_dict.get('scope', '').split() self.params.state = self.query_dict.get('state', '') def _extract_implicit_params(self): @@ -57,8 +57,11 @@ class AuthorizeEndpoint(object): if not self.params.redirect_uri: raise RedirectUriError() - if not ('openid' in self.params.scope.split()): - raise AuthorizeError(self.params.redirect_uri, 'invalid_scope', self.grant_type) + if not ('openid' in self.params.scope): + raise AuthorizeError( + self.params.redirect_uri, + 'invalid_scope', + self.grant_type) try: self.client = Client.objects.get(client_id=self.params.client_id) @@ -66,8 +69,13 @@ class AuthorizeEndpoint(object): if not (self.params.redirect_uri in self.client.redirect_uris): raise RedirectUriError() - if not (self.grant_type) or not (self.params.response_type == self.client.response_type): - raise AuthorizeError(self.params.redirect_uri, 'unsupported_response_type', self.grant_type) + if not (self.grant_type) or \ + not (self.params.response_type == self.client.response_type): + + raise AuthorizeError( + self.params.redirect_uri, + 'unsupported_response_type', + self.grant_type) except Client.DoesNotExist: raise ClientIdError() @@ -75,7 +83,10 @@ class AuthorizeEndpoint(object): def create_response_uri(self, allow): if not allow: - raise AuthorizeError(self.params.redirect_uri, 'access_denied', self.grant_type) + raise AuthorizeError( + self.params.redirect_uri, + 'access_denied', + self.grant_type) try: self.validate_params() @@ -110,7 +121,7 @@ class AuthorizeEndpoint(object): id_token = encode_id_token(id_token_dic, self.client.client_secret) - # TODO: Check if response_type is 'id_token token' and + # TODO: Check if response_type is 'id_token token' then # add access_token to the fragment. uri = self.params.redirect_uri + \ '#token_type={0}&id_token={1}&expires_in={2}'.format( @@ -118,9 +129,13 @@ class AuthorizeEndpoint(object): id_token, 60*10) except: - raise AuthorizeError(self.params.redirect_uri, 'server_error', self.grant_type) + raise AuthorizeError( + self.params.redirect_uri, + 'server_error', + self.grant_type) # Add state if present. - uri = uri + ('&state={0}'.format(self.params.state) if self.params.state else '') + uri = uri + ('&state={0}'.format(self.params.state) + if self.params.state else '') return uri \ No newline at end of file diff --git a/openid_provider/lib/utils/decorators.py b/openid_provider/lib/utils/decorators.py new file mode 100644 index 0000000..fb91ecb --- /dev/null +++ b/openid_provider/lib/utils/decorators.py @@ -0,0 +1,18 @@ +from django.contrib.auth import REDIRECT_FIELD_NAME +from django.contrib.auth.decorators import user_passes_test + + +def staff_required(function=None, redirect_field_name=REDIRECT_FIELD_NAME, login_url=None): + """ + Decorator for views that checks that the user is logged in and is staff, + redirecting to the log-in page if necessary. + """ + actual_decorator = user_passes_test( + lambda u: u.is_authenticated() and u.is_staff, + login_url=login_url, + redirect_field_name=redirect_field_name + ) + if function: + return actual_decorator(function) + + return actual_decorator \ No newline at end of file diff --git a/openid_provider/lib/utils/params.py b/openid_provider/lib/utils/params.py index 927d507..94f8661 100644 --- a/openid_provider/lib/utils/params.py +++ b/openid_provider/lib/utils/params.py @@ -1,2 +1,4 @@ + + class Params(object): pass \ No newline at end of file diff --git a/openid_provider/migrations/0001_initial.py b/openid_provider/migrations/0001_initial.py new file mode 100644 index 0000000..67d9e56 --- /dev/null +++ b/openid_provider/migrations/0001_initial.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals + +from django.db import models, migrations +from django.conf import settings + + +class Migration(migrations.Migration): + + dependencies = [ + ('auth', '0001_initial'), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='Client', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('name', models.CharField(default=b'', max_length=100)), + ('client_id', models.CharField(unique=True, max_length=255)), + ('client_secret', models.CharField(unique=True, max_length=255)), + ('client_type', models.CharField(max_length=20, choices=[(b'confidential', b'Confidential'), (b'public', b'Public')])), + ('response_type', models.CharField(max_length=30, choices=[(b'code', b'code (Authorization Code Flow)'), (b'id_token', b'id_token (Implicit Flow)'), (b'id_token token', b'id_token token (Implicit Flow)')])), + ('_scope', models.TextField(default=b'')), + ('_redirect_uris', models.TextField(default=b'')), + ], + options={ + }, + bases=(models.Model,), + ), + migrations.CreateModel( + name='Code', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('code', models.CharField(unique=True, max_length=255)), + ('expires_at', models.DateTimeField()), + ('_scope', models.TextField(default=b'')), + ('client', models.ForeignKey(to='openid_provider.Client')), + ], + options={ + }, + bases=(models.Model,), + ), + migrations.CreateModel( + name='Token', + fields=[ + ('id', models.AutoField(verbose_name='ID', serialize=False, auto_created=True, primary_key=True)), + ('access_token', models.CharField(unique=True, max_length=255)), + ('expires_at', models.DateTimeField()), + ('_scope', models.TextField(default=b'')), + ('_id_token', models.TextField()), + ('client', models.ForeignKey(to='openid_provider.Client')), + ], + options={ + }, + bases=(models.Model,), + ), + migrations.CreateModel( + name='UserInfo', + fields=[ + ('user', models.OneToOneField(primary_key=True, serialize=False, to=settings.AUTH_USER_MODEL)), + ('given_name', models.CharField(default=b'', max_length=255)), + ('family_name', models.CharField(default=b'', max_length=255)), + ('middle_name', models.CharField(default=b'', max_length=255)), + ('nickname', models.CharField(default=b'', max_length=255)), + ('preferred_username', models.CharField(default=b'', max_length=255)), + ('profile', models.URLField(default=b'')), + ('picture', models.URLField(default=b'')), + ('website', models.URLField(default=b'')), + ('email_verified', models.BooleanField(default=False)), + ('gender', models.CharField(default=b'', max_length=100)), + ('birthdate', models.DateField()), + ('zoneinfo', models.CharField(default=b'', max_length=100)), + ('locale', models.CharField(default=b'', max_length=100)), + ('phone_number', models.CharField(default=b'', max_length=255)), + ('phone_number_verified', models.BooleanField(default=False)), + ('address_formatted', models.CharField(default=b'', max_length=255)), + ('address_street_address', models.CharField(default=b'', max_length=255)), + ('address_locality', models.CharField(default=b'', max_length=255)), + ('address_region', models.CharField(default=b'', max_length=255)), + ('address_postal_code', models.CharField(default=b'', max_length=255)), + ('address_country', models.CharField(default=b'', max_length=255)), + ('updated_at', models.DateTimeField()), + ], + options={ + }, + bases=(models.Model,), + ), + migrations.AddField( + model_name='token', + name='user', + field=models.ForeignKey(to=settings.AUTH_USER_MODEL), + preserve_default=True, + ), + migrations.AddField( + model_name='code', + name='user', + field=models.ForeignKey(to=settings.AUTH_USER_MODEL), + preserve_default=True, + ), + migrations.AddField( + model_name='client', + name='user', + field=models.ForeignKey(to=settings.AUTH_USER_MODEL), + preserve_default=True, + ), + ] diff --git a/openid_provider/models.py b/openid_provider/models.py index a45d7ea..c4769b2 100644 --- a/openid_provider/models.py +++ b/openid_provider/models.py @@ -54,7 +54,15 @@ class Code(models.Model): client = models.ForeignKey(Client) code = models.CharField(max_length=255, unique=True) expires_at = models.DateTimeField() - scope = models.TextField() # TODO: add getter and setter for this. + + _scope = models.TextField(default='') + def scope(): + def fget(self): + return self._scope.split() + def fset(self, value): + self._scope = ' '.join(value) + return locals() + scope = property(**scope()) def has_expired(self): return timezone.now() >= self.expires_at @@ -64,9 +72,16 @@ class Token(models.Model): user = models.ForeignKey(User) client = models.ForeignKey(Client) access_token = models.CharField(max_length=255, unique=True) - refresh_token = models.CharField(max_length=255, unique=True) expires_at = models.DateTimeField() - scope = models.TextField() # TODO: add getter and setter for this. + + _scope = models.TextField(default='') + def scope(): + def fget(self): + return self._scope.split() + def fset(self, value): + self._scope = ' '.join(value) + return locals() + scope = property(**scope()) _id_token = models.TextField() def id_token(): diff --git a/openid_provider/templates/openid_provider/authorize.html b/openid_provider/templates/openid_provider/authorize.html index 115dc90..a6f3363 100644 --- a/openid_provider/templates/openid_provider/authorize.html +++ b/openid_provider/templates/openid_provider/authorize.html @@ -11,18 +11,18 @@
Client {{ client.name }} would like to access this information of you ...
+Client {{ client.name }} would like to access this information of you ...