Code factorisation in models.py
This commit is contained in:
parent
ee003b6b65
commit
d46428520f
4 changed files with 102 additions and 111 deletions
|
@ -28,13 +28,38 @@ from datetime import timedelta
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from requests_futures.sessions import FuturesSession
|
from requests_futures.sessions import FuturesSession
|
||||||
|
|
||||||
import cas_server.utils as utils
|
from cas_server import utils
|
||||||
from . import VERSION
|
from . import VERSION
|
||||||
|
|
||||||
#: logger facility
|
#: logger facility
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class JsonAttributes(models.Model):
|
||||||
|
"""
|
||||||
|
Bases: :class:`django.db.models.Model`
|
||||||
|
|
||||||
|
A base class for models storing attributes as a json
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
#: The attributes json encoded
|
||||||
|
_attributs = models.TextField(default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attributs(self):
|
||||||
|
"""The attributes"""
|
||||||
|
if self._attributs is not None:
|
||||||
|
return utils.json.loads(self._attributs)
|
||||||
|
|
||||||
|
@attributs.setter
|
||||||
|
def attributs(self, value):
|
||||||
|
"""attributs property setter"""
|
||||||
|
self._attributs = utils.json_encode(value)
|
||||||
|
|
||||||
|
|
||||||
@python_2_unicode_compatible
|
@python_2_unicode_compatible
|
||||||
class FederatedIendityProvider(models.Model):
|
class FederatedIendityProvider(models.Model):
|
||||||
"""
|
"""
|
||||||
|
@ -130,9 +155,9 @@ class FederatedIendityProvider(models.Model):
|
||||||
|
|
||||||
|
|
||||||
@python_2_unicode_compatible
|
@python_2_unicode_compatible
|
||||||
class FederatedUser(models.Model):
|
class FederatedUser(JsonAttributes):
|
||||||
"""
|
"""
|
||||||
Bases: :class:`django.db.models.Model`
|
Bases: :class:`JsonAttributes`
|
||||||
|
|
||||||
A federated user as returner by a CAS provider (username and attributes)
|
A federated user as returner by a CAS provider (username and attributes)
|
||||||
"""
|
"""
|
||||||
|
@ -142,8 +167,6 @@ class FederatedUser(models.Model):
|
||||||
username = models.CharField(max_length=124)
|
username = models.CharField(max_length=124)
|
||||||
#: A foreign key to :class:`FederatedIendityProvider`
|
#: A foreign key to :class:`FederatedIendityProvider`
|
||||||
provider = models.ForeignKey(FederatedIendityProvider, on_delete=models.CASCADE)
|
provider = models.ForeignKey(FederatedIendityProvider, on_delete=models.CASCADE)
|
||||||
#: The user attributes json encoded
|
|
||||||
_attributs = models.TextField(default=None, null=True, blank=True)
|
|
||||||
#: The last ticket used to authenticate :attr:`username` against :attr:`provider`
|
#: The last ticket used to authenticate :attr:`username` against :attr:`provider`
|
||||||
ticket = models.CharField(max_length=255)
|
ticket = models.CharField(max_length=255)
|
||||||
#: Last update timespampt. Usually, the last time :attr:`ticket` has been set.
|
#: Last update timespampt. Usually, the last time :attr:`ticket` has been set.
|
||||||
|
@ -152,17 +175,6 @@ class FederatedUser(models.Model):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.federated_username
|
return self.federated_username
|
||||||
|
|
||||||
@property
|
|
||||||
def attributs(self):
|
|
||||||
"""The user attributes returned by the CAS backend on successful ticket validation"""
|
|
||||||
if self._attributs is not None:
|
|
||||||
return utils.json.loads(self._attributs)
|
|
||||||
|
|
||||||
@attributs.setter
|
|
||||||
def attributs(self, value):
|
|
||||||
"""attributs property setter"""
|
|
||||||
self._attributs = utils.json_encode(value)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def federated_username(self):
|
def federated_username(self):
|
||||||
"""The federated username with a suffix for the current :class:`FederatedUser`."""
|
"""The federated username with a suffix for the current :class:`FederatedUser`."""
|
||||||
|
@ -290,35 +302,23 @@ class User(models.Model):
|
||||||
:param request: The current django HttpRequest to display possible failure to the user.
|
:param request: The current django HttpRequest to display possible failure to the user.
|
||||||
:type request: :class:`django.http.HttpRequest` or :obj:`NoneType<types.NoneType>`
|
:type request: :class:`django.http.HttpRequest` or :obj:`NoneType<types.NoneType>`
|
||||||
"""
|
"""
|
||||||
async_list = []
|
|
||||||
session = FuturesSession(
|
|
||||||
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
|
||||||
)
|
|
||||||
# first invalidate all Tickets
|
|
||||||
ticket_classes = [ProxyGrantingTicket, ServiceTicket, ProxyTicket]
|
ticket_classes = [ProxyGrantingTicket, ServiceTicket, ProxyTicket]
|
||||||
for ticket_class in ticket_classes:
|
for error in Ticket.send_slos(
|
||||||
queryset = ticket_class.objects.filter(user=self)
|
[ticket_class.objects.filter(user=self) for ticket_class in ticket_classes]
|
||||||
for ticket in queryset:
|
):
|
||||||
ticket.logout(session, async_list)
|
logger.warning(
|
||||||
queryset.delete()
|
"Error during SLO for user %s: %s" % (
|
||||||
for future in async_list:
|
self.username,
|
||||||
if future: # pragma: no branch (should always be true)
|
error
|
||||||
try:
|
)
|
||||||
future.result()
|
)
|
||||||
except Exception as error:
|
if request is not None:
|
||||||
logger.warning(
|
error = utils.unpack_nested_exception(error)
|
||||||
"Error during SLO for user %s: %s" % (
|
messages.add_message(
|
||||||
self.username,
|
request,
|
||||||
error
|
messages.WARNING,
|
||||||
)
|
_(u'Error during service logout %s') % error
|
||||||
)
|
)
|
||||||
if request is not None:
|
|
||||||
error = utils.unpack_nested_exception(error)
|
|
||||||
messages.add_message(
|
|
||||||
request,
|
|
||||||
messages.WARNING,
|
|
||||||
_(u'Error during service logout %s') % error
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_ticket(self, ticket_class, service, service_pattern, renew):
|
def get_ticket(self, ticket_class, service, service_pattern, renew):
|
||||||
"""
|
"""
|
||||||
|
@ -544,20 +544,13 @@ class ServicePattern(models.Model):
|
||||||
if re.match(filtre.pattern, str(value)):
|
if re.match(filtre.pattern, str(value)):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
bad_filter = (filtre.pattern, filtre.attribut, user.attributs.get(filtre.attribut))
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"User constraint failed for %s, service %s: %s do not match %s %s." % (
|
"User constraint failed for %s, service %s: %s do not match %s %s." % (
|
||||||
user.username,
|
(user.username, self.name) + bad_filter
|
||||||
self.name,
|
|
||||||
filtre.pattern,
|
|
||||||
filtre.attribut,
|
|
||||||
user.attributs.get(filtre.attribut)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
raise BadFilter('%s do not match %s %s' % (
|
raise BadFilter('%s do not match %s %s' % bad_filter)
|
||||||
filtre.pattern,
|
|
||||||
filtre.attribut,
|
|
||||||
user.attributs.get(filtre.attribut)
|
|
||||||
))
|
|
||||||
if self.user_field and not user.attributs.get(self.user_field):
|
if self.user_field and not user.attributs.get(self.user_field):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Cannot use %s a loggin for user %s on service %s because it is absent" % (
|
"Cannot use %s a loggin for user %s on service %s because it is absent" % (
|
||||||
|
@ -715,9 +708,9 @@ class ReplaceAttributValue(models.Model):
|
||||||
|
|
||||||
|
|
||||||
@python_2_unicode_compatible
|
@python_2_unicode_compatible
|
||||||
class Ticket(models.Model):
|
class Ticket(JsonAttributes):
|
||||||
"""
|
"""
|
||||||
Bases: :class:`django.db.models.Model`
|
Bases: :class:`JsonAttributes`
|
||||||
|
|
||||||
Generic class for a Ticket
|
Generic class for a Ticket
|
||||||
"""
|
"""
|
||||||
|
@ -725,8 +718,6 @@ class Ticket(models.Model):
|
||||||
abstract = True
|
abstract = True
|
||||||
#: ForeignKey to a :class:`User`.
|
#: ForeignKey to a :class:`User`.
|
||||||
user = models.ForeignKey(User, related_name="%(class)s")
|
user = models.ForeignKey(User, related_name="%(class)s")
|
||||||
#: The user attributes to transmit to the service json encoded
|
|
||||||
_attributs = models.TextField(default=None, null=True, blank=True)
|
|
||||||
#: A boolean. ``True`` if the ticket has been validated
|
#: A boolean. ``True`` if the ticket has been validated
|
||||||
validate = models.BooleanField(default=False)
|
validate = models.BooleanField(default=False)
|
||||||
#: The service url for the ticket
|
#: The service url for the ticket
|
||||||
|
@ -749,17 +740,6 @@ class Ticket(models.Model):
|
||||||
#: requests.
|
#: requests.
|
||||||
TIMEOUT = settings.CAS_TICKET_TIMEOUT
|
TIMEOUT = settings.CAS_TICKET_TIMEOUT
|
||||||
|
|
||||||
@property
|
|
||||||
def attributs(self):
|
|
||||||
"""The user attributes to be transmited to the service on successful validation"""
|
|
||||||
if self._attributs is not None:
|
|
||||||
return utils.json.loads(self._attributs)
|
|
||||||
|
|
||||||
@attributs.setter
|
|
||||||
def attributs(self, value):
|
|
||||||
"""attributs property setter"""
|
|
||||||
self._attributs = utils.json_encode(value)
|
|
||||||
|
|
||||||
class DoesNotExist(Exception):
|
class DoesNotExist(Exception):
|
||||||
"""raised in :meth:`Ticket.get` then ticket prefix and ticket classes mismatch"""
|
"""raised in :meth:`Ticket.get` then ticket prefix and ticket classes mismatch"""
|
||||||
pass
|
pass
|
||||||
|
@ -767,6 +747,33 @@ class Ticket(models.Model):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return u"Ticket-%s" % self.pk
|
return u"Ticket-%s" % self.pk
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def send_slos(queryset_list):
|
||||||
|
"""
|
||||||
|
Send SLO requests to each ticket of each queryset of ``queryset_list``
|
||||||
|
|
||||||
|
:param list queryset_list: A list a :class:`Ticket` queryset
|
||||||
|
:return: A list of possibly encoutered :class:`Exception`
|
||||||
|
:rtype: list
|
||||||
|
"""
|
||||||
|
# sending SLO to timed-out validated tickets
|
||||||
|
async_list = []
|
||||||
|
session = FuturesSession(
|
||||||
|
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
||||||
|
)
|
||||||
|
errors = []
|
||||||
|
for queryset in queryset_list:
|
||||||
|
for ticket in queryset:
|
||||||
|
ticket.logout(session, async_list)
|
||||||
|
queryset.delete()
|
||||||
|
for future in async_list:
|
||||||
|
if future: # pragma: no branch (should always be true)
|
||||||
|
try:
|
||||||
|
future.result()
|
||||||
|
except Exception as error:
|
||||||
|
errors.append(error)
|
||||||
|
return errors
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clean_old_entries(cls):
|
def clean_old_entries(cls):
|
||||||
"""Remove old ticket and send SLO to timed-out services"""
|
"""Remove old ticket and send SLO to timed-out services"""
|
||||||
|
@ -779,25 +786,12 @@ class Ticket(models.Model):
|
||||||
Q(creation__lt=(timezone.now() - timedelta(seconds=cls.VALIDITY)))
|
Q(creation__lt=(timezone.now() - timedelta(seconds=cls.VALIDITY)))
|
||||||
)
|
)
|
||||||
).delete()
|
).delete()
|
||||||
|
|
||||||
# sending SLO to timed-out validated tickets
|
|
||||||
async_list = []
|
|
||||||
session = FuturesSession(
|
|
||||||
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
|
||||||
)
|
|
||||||
queryset = cls.objects.filter(
|
queryset = cls.objects.filter(
|
||||||
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
|
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
|
||||||
)
|
)
|
||||||
for ticket in queryset:
|
for error in cls.send_slos([queryset]):
|
||||||
ticket.logout(session, async_list)
|
logger.warning("Error durring SLO %s" % error)
|
||||||
queryset.delete()
|
sys.stderr.write("%r\n" % error)
|
||||||
for future in async_list:
|
|
||||||
if future: # pragma: no branch (should always be true)
|
|
||||||
try:
|
|
||||||
future.result()
|
|
||||||
except Exception as error:
|
|
||||||
logger.warning("Error durring SLO %s" % error)
|
|
||||||
sys.stderr.write("%r\n" % error)
|
|
||||||
|
|
||||||
def logout(self, session, async_list=None):
|
def logout(self, session, async_list=None):
|
||||||
"""Send a SLO request to the ticket service"""
|
"""Send a SLO request to the ticket service"""
|
||||||
|
@ -811,16 +805,7 @@ class Ticket(models.Model):
|
||||||
self.user.username
|
self.user.username
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
xml = utils.logout_request(self.value)
|
||||||
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
|
||||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
|
||||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
|
||||||
</samlp:LogoutRequest>""" % \
|
|
||||||
{
|
|
||||||
'id': utils.gen_saml_id(),
|
|
||||||
'datetime': timezone.now().isoformat(),
|
|
||||||
'ticket': self.value
|
|
||||||
}
|
|
||||||
if self.service_pattern.single_log_out_callback:
|
if self.service_pattern.single_log_out_callback:
|
||||||
url = self.service_pattern.single_log_out_callback
|
url = self.service_pattern.single_log_out_callback
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -261,7 +261,7 @@ class FederateAuthLoginLogoutTestCase(
|
||||||
# SLO for an unkown ticket should do nothing
|
# SLO for an unkown ticket should do nothing
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/federate/%s" % provider.suffix,
|
"/federate/%s" % provider.suffix,
|
||||||
{'logoutRequest': tests_utils.logout_request(utils.gen_st())}
|
{'logoutRequest': utils.logout_request(utils.gen_st())}
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertEqual(response.content, b"ok")
|
self.assertEqual(response.content, b"ok")
|
||||||
|
@ -288,7 +288,7 @@ class FederateAuthLoginLogoutTestCase(
|
||||||
# 3 or 'CAS_2_SAML_1_0'
|
# 3 or 'CAS_2_SAML_1_0'
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/federate/%s" % provider.suffix,
|
"/federate/%s" % provider.suffix,
|
||||||
{'logoutRequest': tests_utils.logout_request(ticket)}
|
{'logoutRequest': utils.logout_request(ticket)}
|
||||||
)
|
)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertEqual(response.content, b"ok")
|
self.assertEqual(response.content, b"ok")
|
||||||
|
|
|
@ -340,17 +340,3 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||||
httpd_thread.daemon = True
|
httpd_thread.daemon = True
|
||||||
httpd_thread.start()
|
httpd_thread.start()
|
||||||
return (httpd, host, port)
|
return (httpd, host, port)
|
||||||
|
|
||||||
|
|
||||||
def logout_request(ticket):
|
|
||||||
"""build a SLO request XML, ready to be send"""
|
|
||||||
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
|
||||||
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
|
||||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
|
||||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
|
||||||
</samlp:LogoutRequest>""" % \
|
|
||||||
{
|
|
||||||
'id': utils.gen_saml_id(),
|
|
||||||
'datetime': timezone.now().isoformat(),
|
|
||||||
'ticket': ticket
|
|
||||||
}
|
|
||||||
|
|
|
@ -17,6 +17,7 @@ from django.http import HttpResponseRedirect, HttpResponse
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.contrib.messages import constants as DEFAULT_MESSAGE_LEVELS
|
from django.contrib.messages import constants as DEFAULT_MESSAGE_LEVELS
|
||||||
from django.core.serializers.json import DjangoJSONEncoder
|
from django.core.serializers.json import DjangoJSONEncoder
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
@ -680,3 +681,22 @@ def dictfetchall(cursor):
|
||||||
dict(zip(columns, row))
|
dict(zip(columns, row))
|
||||||
for row in cursor.fetchall()
|
for row in cursor.fetchall()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def logout_request(ticket):
|
||||||
|
"""
|
||||||
|
Forge a SLO logout request
|
||||||
|
|
||||||
|
:param unicode ticket: A ticket value
|
||||||
|
:return: A SLO XML body request
|
||||||
|
:rtype: unicode
|
||||||
|
"""
|
||||||
|
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||||
|
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
||||||
|
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
||||||
|
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
||||||
|
</samlp:LogoutRequest>""" % {
|
||||||
|
'id': gen_saml_id(),
|
||||||
|
'datetime': timezone.now().isoformat(),
|
||||||
|
'ticket': ticket
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue