Code factorisation in models.py
This commit is contained in:
parent
ee003b6b65
commit
d46428520f
4 changed files with 102 additions and 111 deletions
|
@ -28,13 +28,38 @@ from datetime import timedelta
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
from requests_futures.sessions import FuturesSession
|
||||
|
||||
import cas_server.utils as utils
|
||||
from cas_server import utils
|
||||
from . import VERSION
|
||||
|
||||
#: logger facility
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JsonAttributes(models.Model):
|
||||
"""
|
||||
Bases: :class:`django.db.models.Model`
|
||||
|
||||
A base class for models storing attributes as a json
|
||||
"""
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
#: The attributes json encoded
|
||||
_attributs = models.TextField(default=None, null=True, blank=True)
|
||||
|
||||
@property
|
||||
def attributs(self):
|
||||
"""The attributes"""
|
||||
if self._attributs is not None:
|
||||
return utils.json.loads(self._attributs)
|
||||
|
||||
@attributs.setter
|
||||
def attributs(self, value):
|
||||
"""attributs property setter"""
|
||||
self._attributs = utils.json_encode(value)
|
||||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class FederatedIendityProvider(models.Model):
|
||||
"""
|
||||
|
@ -130,9 +155,9 @@ class FederatedIendityProvider(models.Model):
|
|||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class FederatedUser(models.Model):
|
||||
class FederatedUser(JsonAttributes):
|
||||
"""
|
||||
Bases: :class:`django.db.models.Model`
|
||||
Bases: :class:`JsonAttributes`
|
||||
|
||||
A federated user as returner by a CAS provider (username and attributes)
|
||||
"""
|
||||
|
@ -142,8 +167,6 @@ class FederatedUser(models.Model):
|
|||
username = models.CharField(max_length=124)
|
||||
#: A foreign key to :class:`FederatedIendityProvider`
|
||||
provider = models.ForeignKey(FederatedIendityProvider, on_delete=models.CASCADE)
|
||||
#: The user attributes json encoded
|
||||
_attributs = models.TextField(default=None, null=True, blank=True)
|
||||
#: The last ticket used to authenticate :attr:`username` against :attr:`provider`
|
||||
ticket = models.CharField(max_length=255)
|
||||
#: Last update timespampt. Usually, the last time :attr:`ticket` has been set.
|
||||
|
@ -152,17 +175,6 @@ class FederatedUser(models.Model):
|
|||
def __str__(self):
|
||||
return self.federated_username
|
||||
|
||||
@property
|
||||
def attributs(self):
|
||||
"""The user attributes returned by the CAS backend on successful ticket validation"""
|
||||
if self._attributs is not None:
|
||||
return utils.json.loads(self._attributs)
|
||||
|
||||
@attributs.setter
|
||||
def attributs(self, value):
|
||||
"""attributs property setter"""
|
||||
self._attributs = utils.json_encode(value)
|
||||
|
||||
@property
|
||||
def federated_username(self):
|
||||
"""The federated username with a suffix for the current :class:`FederatedUser`."""
|
||||
|
@ -290,22 +302,10 @@ class User(models.Model):
|
|||
:param request: The current django HttpRequest to display possible failure to the user.
|
||||
:type request: :class:`django.http.HttpRequest` or :obj:`NoneType<types.NoneType>`
|
||||
"""
|
||||
async_list = []
|
||||
session = FuturesSession(
|
||||
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
||||
)
|
||||
# first invalidate all Tickets
|
||||
ticket_classes = [ProxyGrantingTicket, ServiceTicket, ProxyTicket]
|
||||
for ticket_class in ticket_classes:
|
||||
queryset = ticket_class.objects.filter(user=self)
|
||||
for ticket in queryset:
|
||||
ticket.logout(session, async_list)
|
||||
queryset.delete()
|
||||
for future in async_list:
|
||||
if future: # pragma: no branch (should always be true)
|
||||
try:
|
||||
future.result()
|
||||
except Exception as error:
|
||||
for error in Ticket.send_slos(
|
||||
[ticket_class.objects.filter(user=self) for ticket_class in ticket_classes]
|
||||
):
|
||||
logger.warning(
|
||||
"Error during SLO for user %s: %s" % (
|
||||
self.username,
|
||||
|
@ -544,20 +544,13 @@ class ServicePattern(models.Model):
|
|||
if re.match(filtre.pattern, str(value)):
|
||||
break
|
||||
else:
|
||||
bad_filter = (filtre.pattern, filtre.attribut, user.attributs.get(filtre.attribut))
|
||||
logger.warning(
|
||||
"User constraint failed for %s, service %s: %s do not match %s %s." % (
|
||||
user.username,
|
||||
self.name,
|
||||
filtre.pattern,
|
||||
filtre.attribut,
|
||||
user.attributs.get(filtre.attribut)
|
||||
(user.username, self.name) + bad_filter
|
||||
)
|
||||
)
|
||||
raise BadFilter('%s do not match %s %s' % (
|
||||
filtre.pattern,
|
||||
filtre.attribut,
|
||||
user.attributs.get(filtre.attribut)
|
||||
))
|
||||
raise BadFilter('%s do not match %s %s' % bad_filter)
|
||||
if self.user_field and not user.attributs.get(self.user_field):
|
||||
logger.warning(
|
||||
"Cannot use %s a loggin for user %s on service %s because it is absent" % (
|
||||
|
@ -715,9 +708,9 @@ class ReplaceAttributValue(models.Model):
|
|||
|
||||
|
||||
@python_2_unicode_compatible
|
||||
class Ticket(models.Model):
|
||||
class Ticket(JsonAttributes):
|
||||
"""
|
||||
Bases: :class:`django.db.models.Model`
|
||||
Bases: :class:`JsonAttributes`
|
||||
|
||||
Generic class for a Ticket
|
||||
"""
|
||||
|
@ -725,8 +718,6 @@ class Ticket(models.Model):
|
|||
abstract = True
|
||||
#: ForeignKey to a :class:`User`.
|
||||
user = models.ForeignKey(User, related_name="%(class)s")
|
||||
#: The user attributes to transmit to the service json encoded
|
||||
_attributs = models.TextField(default=None, null=True, blank=True)
|
||||
#: A boolean. ``True`` if the ticket has been validated
|
||||
validate = models.BooleanField(default=False)
|
||||
#: The service url for the ticket
|
||||
|
@ -749,17 +740,6 @@ class Ticket(models.Model):
|
|||
#: requests.
|
||||
TIMEOUT = settings.CAS_TICKET_TIMEOUT
|
||||
|
||||
@property
|
||||
def attributs(self):
|
||||
"""The user attributes to be transmited to the service on successful validation"""
|
||||
if self._attributs is not None:
|
||||
return utils.json.loads(self._attributs)
|
||||
|
||||
@attributs.setter
|
||||
def attributs(self, value):
|
||||
"""attributs property setter"""
|
||||
self._attributs = utils.json_encode(value)
|
||||
|
||||
class DoesNotExist(Exception):
|
||||
"""raised in :meth:`Ticket.get` then ticket prefix and ticket classes mismatch"""
|
||||
pass
|
||||
|
@ -767,6 +747,33 @@ class Ticket(models.Model):
|
|||
def __str__(self):
|
||||
return u"Ticket-%s" % self.pk
|
||||
|
||||
@staticmethod
|
||||
def send_slos(queryset_list):
|
||||
"""
|
||||
Send SLO requests to each ticket of each queryset of ``queryset_list``
|
||||
|
||||
:param list queryset_list: A list a :class:`Ticket` queryset
|
||||
:return: A list of possibly encoutered :class:`Exception`
|
||||
:rtype: list
|
||||
"""
|
||||
# sending SLO to timed-out validated tickets
|
||||
async_list = []
|
||||
session = FuturesSession(
|
||||
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
||||
)
|
||||
errors = []
|
||||
for queryset in queryset_list:
|
||||
for ticket in queryset:
|
||||
ticket.logout(session, async_list)
|
||||
queryset.delete()
|
||||
for future in async_list:
|
||||
if future: # pragma: no branch (should always be true)
|
||||
try:
|
||||
future.result()
|
||||
except Exception as error:
|
||||
errors.append(error)
|
||||
return errors
|
||||
|
||||
@classmethod
|
||||
def clean_old_entries(cls):
|
||||
"""Remove old ticket and send SLO to timed-out services"""
|
||||
|
@ -779,23 +786,10 @@ class Ticket(models.Model):
|
|||
Q(creation__lt=(timezone.now() - timedelta(seconds=cls.VALIDITY)))
|
||||
)
|
||||
).delete()
|
||||
|
||||
# sending SLO to timed-out validated tickets
|
||||
async_list = []
|
||||
session = FuturesSession(
|
||||
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
||||
)
|
||||
queryset = cls.objects.filter(
|
||||
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
|
||||
)
|
||||
for ticket in queryset:
|
||||
ticket.logout(session, async_list)
|
||||
queryset.delete()
|
||||
for future in async_list:
|
||||
if future: # pragma: no branch (should always be true)
|
||||
try:
|
||||
future.result()
|
||||
except Exception as error:
|
||||
for error in cls.send_slos([queryset]):
|
||||
logger.warning("Error durring SLO %s" % error)
|
||||
sys.stderr.write("%r\n" % error)
|
||||
|
||||
|
@ -811,16 +805,7 @@ class Ticket(models.Model):
|
|||
self.user.username
|
||||
)
|
||||
)
|
||||
xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
||||
</samlp:LogoutRequest>""" % \
|
||||
{
|
||||
'id': utils.gen_saml_id(),
|
||||
'datetime': timezone.now().isoformat(),
|
||||
'ticket': self.value
|
||||
}
|
||||
xml = utils.logout_request(self.value)
|
||||
if self.service_pattern.single_log_out_callback:
|
||||
url = self.service_pattern.single_log_out_callback
|
||||
else:
|
||||
|
|
|
@ -261,7 +261,7 @@ class FederateAuthLoginLogoutTestCase(
|
|||
# SLO for an unkown ticket should do nothing
|
||||
response = client.post(
|
||||
"/federate/%s" % provider.suffix,
|
||||
{'logoutRequest': tests_utils.logout_request(utils.gen_st())}
|
||||
{'logoutRequest': utils.logout_request(utils.gen_st())}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b"ok")
|
||||
|
@ -288,7 +288,7 @@ class FederateAuthLoginLogoutTestCase(
|
|||
# 3 or 'CAS_2_SAML_1_0'
|
||||
response = client.post(
|
||||
"/federate/%s" % provider.suffix,
|
||||
{'logoutRequest': tests_utils.logout_request(ticket)}
|
||||
{'logoutRequest': utils.logout_request(ticket)}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b"ok")
|
||||
|
|
|
@ -340,17 +340,3 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||
httpd_thread.daemon = True
|
||||
httpd_thread.start()
|
||||
return (httpd, host, port)
|
||||
|
||||
|
||||
def logout_request(ticket):
|
||||
"""build a SLO request XML, ready to be send"""
|
||||
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
||||
</samlp:LogoutRequest>""" % \
|
||||
{
|
||||
'id': utils.gen_saml_id(),
|
||||
'datetime': timezone.now().isoformat(),
|
||||
'ticket': ticket
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ from django.http import HttpResponseRedirect, HttpResponse
|
|||
from django.contrib import messages
|
||||
from django.contrib.messages import constants as DEFAULT_MESSAGE_LEVELS
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.utils import timezone
|
||||
|
||||
import random
|
||||
import string
|
||||
|
@ -680,3 +681,22 @@ def dictfetchall(cursor):
|
|||
dict(zip(columns, row))
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
|
||||
|
||||
def logout_request(ticket):
|
||||
"""
|
||||
Forge a SLO logout request
|
||||
|
||||
:param unicode ticket: A ticket value
|
||||
:return: A SLO XML body request
|
||||
:rtype: unicode
|
||||
"""
|
||||
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
||||
</samlp:LogoutRequest>""" % {
|
||||
'id': gen_saml_id(),
|
||||
'datetime': timezone.now().isoformat(),
|
||||
'ticket': ticket
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue