Code factorisation in models.py

This commit is contained in:
Valentin Samir 2016-08-05 17:56:34 +02:00
parent ee003b6b65
commit d46428520f
4 changed files with 102 additions and 111 deletions

View file

@ -28,13 +28,38 @@ from datetime import timedelta
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from requests_futures.sessions import FuturesSession from requests_futures.sessions import FuturesSession
import cas_server.utils as utils from cas_server import utils
from . import VERSION from . import VERSION
#: logger facility #: logger facility
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class JsonAttributes(models.Model):
"""
Bases: :class:`django.db.models.Model`
A base class for models storing attributes as a json
"""
class Meta:
abstract = True
#: The attributes json encoded
_attributs = models.TextField(default=None, null=True, blank=True)
@property
def attributs(self):
"""The attributes"""
if self._attributs is not None:
return utils.json.loads(self._attributs)
@attributs.setter
def attributs(self, value):
"""attributs property setter"""
self._attributs = utils.json_encode(value)
@python_2_unicode_compatible @python_2_unicode_compatible
class FederatedIendityProvider(models.Model): class FederatedIendityProvider(models.Model):
""" """
@ -130,9 +155,9 @@ class FederatedIendityProvider(models.Model):
@python_2_unicode_compatible @python_2_unicode_compatible
class FederatedUser(models.Model): class FederatedUser(JsonAttributes):
""" """
Bases: :class:`django.db.models.Model` Bases: :class:`JsonAttributes`
A federated user as returner by a CAS provider (username and attributes) A federated user as returner by a CAS provider (username and attributes)
""" """
@ -142,8 +167,6 @@ class FederatedUser(models.Model):
username = models.CharField(max_length=124) username = models.CharField(max_length=124)
#: A foreign key to :class:`FederatedIendityProvider` #: A foreign key to :class:`FederatedIendityProvider`
provider = models.ForeignKey(FederatedIendityProvider, on_delete=models.CASCADE) provider = models.ForeignKey(FederatedIendityProvider, on_delete=models.CASCADE)
#: The user attributes json encoded
_attributs = models.TextField(default=None, null=True, blank=True)
#: The last ticket used to authenticate :attr:`username` against :attr:`provider` #: The last ticket used to authenticate :attr:`username` against :attr:`provider`
ticket = models.CharField(max_length=255) ticket = models.CharField(max_length=255)
#: Last update timespampt. Usually, the last time :attr:`ticket` has been set. #: Last update timespampt. Usually, the last time :attr:`ticket` has been set.
@ -152,17 +175,6 @@ class FederatedUser(models.Model):
def __str__(self): def __str__(self):
return self.federated_username return self.federated_username
@property
def attributs(self):
"""The user attributes returned by the CAS backend on successful ticket validation"""
if self._attributs is not None:
return utils.json.loads(self._attributs)
@attributs.setter
def attributs(self, value):
"""attributs property setter"""
self._attributs = utils.json_encode(value)
@property @property
def federated_username(self): def federated_username(self):
"""The federated username with a suffix for the current :class:`FederatedUser`.""" """The federated username with a suffix for the current :class:`FederatedUser`."""
@ -290,35 +302,23 @@ class User(models.Model):
:param request: The current django HttpRequest to display possible failure to the user. :param request: The current django HttpRequest to display possible failure to the user.
:type request: :class:`django.http.HttpRequest` or :obj:`NoneType<types.NoneType>` :type request: :class:`django.http.HttpRequest` or :obj:`NoneType<types.NoneType>`
""" """
async_list = []
session = FuturesSession(
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
)
# first invalidate all Tickets
ticket_classes = [ProxyGrantingTicket, ServiceTicket, ProxyTicket] ticket_classes = [ProxyGrantingTicket, ServiceTicket, ProxyTicket]
for ticket_class in ticket_classes: for error in Ticket.send_slos(
queryset = ticket_class.objects.filter(user=self) [ticket_class.objects.filter(user=self) for ticket_class in ticket_classes]
for ticket in queryset: ):
ticket.logout(session, async_list) logger.warning(
queryset.delete() "Error during SLO for user %s: %s" % (
for future in async_list: self.username,
if future: # pragma: no branch (should always be true) error
try: )
future.result() )
except Exception as error: if request is not None:
logger.warning( error = utils.unpack_nested_exception(error)
"Error during SLO for user %s: %s" % ( messages.add_message(
self.username, request,
error messages.WARNING,
) _(u'Error during service logout %s') % error
) )
if request is not None:
error = utils.unpack_nested_exception(error)
messages.add_message(
request,
messages.WARNING,
_(u'Error during service logout %s') % error
)
def get_ticket(self, ticket_class, service, service_pattern, renew): def get_ticket(self, ticket_class, service, service_pattern, renew):
""" """
@ -544,20 +544,13 @@ class ServicePattern(models.Model):
if re.match(filtre.pattern, str(value)): if re.match(filtre.pattern, str(value)):
break break
else: else:
bad_filter = (filtre.pattern, filtre.attribut, user.attributs.get(filtre.attribut))
logger.warning( logger.warning(
"User constraint failed for %s, service %s: %s do not match %s %s." % ( "User constraint failed for %s, service %s: %s do not match %s %s." % (
user.username, (user.username, self.name) + bad_filter
self.name,
filtre.pattern,
filtre.attribut,
user.attributs.get(filtre.attribut)
) )
) )
raise BadFilter('%s do not match %s %s' % ( raise BadFilter('%s do not match %s %s' % bad_filter)
filtre.pattern,
filtre.attribut,
user.attributs.get(filtre.attribut)
))
if self.user_field and not user.attributs.get(self.user_field): if self.user_field and not user.attributs.get(self.user_field):
logger.warning( logger.warning(
"Cannot use %s a loggin for user %s on service %s because it is absent" % ( "Cannot use %s a loggin for user %s on service %s because it is absent" % (
@ -715,9 +708,9 @@ class ReplaceAttributValue(models.Model):
@python_2_unicode_compatible @python_2_unicode_compatible
class Ticket(models.Model): class Ticket(JsonAttributes):
""" """
Bases: :class:`django.db.models.Model` Bases: :class:`JsonAttributes`
Generic class for a Ticket Generic class for a Ticket
""" """
@ -725,8 +718,6 @@ class Ticket(models.Model):
abstract = True abstract = True
#: ForeignKey to a :class:`User`. #: ForeignKey to a :class:`User`.
user = models.ForeignKey(User, related_name="%(class)s") user = models.ForeignKey(User, related_name="%(class)s")
#: The user attributes to transmit to the service json encoded
_attributs = models.TextField(default=None, null=True, blank=True)
#: A boolean. ``True`` if the ticket has been validated #: A boolean. ``True`` if the ticket has been validated
validate = models.BooleanField(default=False) validate = models.BooleanField(default=False)
#: The service url for the ticket #: The service url for the ticket
@ -749,17 +740,6 @@ class Ticket(models.Model):
#: requests. #: requests.
TIMEOUT = settings.CAS_TICKET_TIMEOUT TIMEOUT = settings.CAS_TICKET_TIMEOUT
@property
def attributs(self):
"""The user attributes to be transmited to the service on successful validation"""
if self._attributs is not None:
return utils.json.loads(self._attributs)
@attributs.setter
def attributs(self, value):
"""attributs property setter"""
self._attributs = utils.json_encode(value)
class DoesNotExist(Exception): class DoesNotExist(Exception):
"""raised in :meth:`Ticket.get` then ticket prefix and ticket classes mismatch""" """raised in :meth:`Ticket.get` then ticket prefix and ticket classes mismatch"""
pass pass
@ -767,6 +747,33 @@ class Ticket(models.Model):
def __str__(self): def __str__(self):
return u"Ticket-%s" % self.pk return u"Ticket-%s" % self.pk
@staticmethod
def send_slos(queryset_list):
"""
Send SLO requests to each ticket of each queryset of ``queryset_list``
:param list queryset_list: A list a :class:`Ticket` queryset
:return: A list of possibly encoutered :class:`Exception`
:rtype: list
"""
# sending SLO to timed-out validated tickets
async_list = []
session = FuturesSession(
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
)
errors = []
for queryset in queryset_list:
for ticket in queryset:
ticket.logout(session, async_list)
queryset.delete()
for future in async_list:
if future: # pragma: no branch (should always be true)
try:
future.result()
except Exception as error:
errors.append(error)
return errors
@classmethod @classmethod
def clean_old_entries(cls): def clean_old_entries(cls):
"""Remove old ticket and send SLO to timed-out services""" """Remove old ticket and send SLO to timed-out services"""
@ -779,25 +786,12 @@ class Ticket(models.Model):
Q(creation__lt=(timezone.now() - timedelta(seconds=cls.VALIDITY))) Q(creation__lt=(timezone.now() - timedelta(seconds=cls.VALIDITY)))
) )
).delete() ).delete()
# sending SLO to timed-out validated tickets
async_list = []
session = FuturesSession(
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
)
queryset = cls.objects.filter( queryset = cls.objects.filter(
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
) )
for ticket in queryset: for error in cls.send_slos([queryset]):
ticket.logout(session, async_list) logger.warning("Error durring SLO %s" % error)
queryset.delete() sys.stderr.write("%r\n" % error)
for future in async_list:
if future: # pragma: no branch (should always be true)
try:
future.result()
except Exception as error:
logger.warning("Error durring SLO %s" % error)
sys.stderr.write("%r\n" % error)
def logout(self, session, async_list=None): def logout(self, session, async_list=None):
"""Send a SLO request to the ticket service""" """Send a SLO request to the ticket service"""
@ -811,16 +805,7 @@ class Ticket(models.Model):
self.user.username self.user.username
) )
) )
xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xml = utils.logout_request(self.value)
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
</samlp:LogoutRequest>""" % \
{
'id': utils.gen_saml_id(),
'datetime': timezone.now().isoformat(),
'ticket': self.value
}
if self.service_pattern.single_log_out_callback: if self.service_pattern.single_log_out_callback:
url = self.service_pattern.single_log_out_callback url = self.service_pattern.single_log_out_callback
else: else:

View file

@ -261,7 +261,7 @@ class FederateAuthLoginLogoutTestCase(
# SLO for an unkown ticket should do nothing # SLO for an unkown ticket should do nothing
response = client.post( response = client.post(
"/federate/%s" % provider.suffix, "/federate/%s" % provider.suffix,
{'logoutRequest': tests_utils.logout_request(utils.gen_st())} {'logoutRequest': utils.logout_request(utils.gen_st())}
) )
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok") self.assertEqual(response.content, b"ok")
@ -288,7 +288,7 @@ class FederateAuthLoginLogoutTestCase(
# 3 or 'CAS_2_SAML_1_0' # 3 or 'CAS_2_SAML_1_0'
response = client.post( response = client.post(
"/federate/%s" % provider.suffix, "/federate/%s" % provider.suffix,
{'logoutRequest': tests_utils.logout_request(ticket)} {'logoutRequest': utils.logout_request(ticket)}
) )
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok") self.assertEqual(response.content, b"ok")

View file

@ -340,17 +340,3 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
httpd_thread.daemon = True httpd_thread.daemon = True
httpd_thread.start() httpd_thread.start()
return (httpd, host, port) return (httpd, host, port)
def logout_request(ticket):
"""build a SLO request XML, ready to be send"""
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
</samlp:LogoutRequest>""" % \
{
'id': utils.gen_saml_id(),
'datetime': timezone.now().isoformat(),
'ticket': ticket
}

View file

@ -17,6 +17,7 @@ from django.http import HttpResponseRedirect, HttpResponse
from django.contrib import messages from django.contrib import messages
from django.contrib.messages import constants as DEFAULT_MESSAGE_LEVELS from django.contrib.messages import constants as DEFAULT_MESSAGE_LEVELS
from django.core.serializers.json import DjangoJSONEncoder from django.core.serializers.json import DjangoJSONEncoder
from django.utils import timezone
import random import random
import string import string
@ -680,3 +681,22 @@ def dictfetchall(cursor):
dict(zip(columns, row)) dict(zip(columns, row))
for row in cursor.fetchall() for row in cursor.fetchall()
] ]
def logout_request(ticket):
"""
Forge a SLO logout request
:param unicode ticket: A ticket value
:return: A SLO XML body request
:rtype: unicode
"""
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
</samlp:LogoutRequest>""" % {
'id': gen_saml_id(),
'datetime': timezone.now().isoformat(),
'ticket': ticket
}