Factorize some code

This commit is contained in:
Valentin Samir 2016-07-23 18:43:27 +02:00
parent 4127af0db1
commit 232aafcace
2 changed files with 118 additions and 91 deletions

View file

@ -736,6 +736,9 @@ class Ticket(models.Model):
#: requests. #: requests.
TIMEOUT = settings.CAS_TICKET_TIMEOUT TIMEOUT = settings.CAS_TICKET_TIMEOUT
class DoesNotExist(Exception):
pass
def __str__(self): def __str__(self):
return u"Ticket-%s" % self.pk return u"Ticket-%s" % self.pk
@ -806,19 +809,108 @@ class Ticket(models.Model):
) )
@staticmethod @staticmethod
def get_class(ticket): def get_class(ticket, classes=None):
""" """
Return the ticket class of ``ticket`` Return the ticket class of ``ticket``
:param unicode ticket: A ticket :param unicode ticket: A ticket
:param list classes: Optinal arguement. A list of possible :class:`Ticket` subclasses
:return: The class corresponding to ``ticket`` (:class:`ServiceTicket` or :return: The class corresponding to ``ticket`` (:class:`ServiceTicket` or
:class:`ProxyTicket` or :class:`ProxyGrantingTicket`) if found, ``None`` otherwise. :class:`ProxyTicket` or :class:`ProxyGrantingTicket`) if found among ``classes,
``None`` otherwise.
:rtype: :obj:`type` or :obj:`NoneType<types.NoneType>` :rtype: :obj:`type` or :obj:`NoneType<types.NoneType>`
""" """
for ticket_class in [ServiceTicket, ProxyTicket, ProxyGrantingTicket]: if classes is None: # pragma: no cover (not used)
classes = [ServiceTicket, ProxyTicket, ProxyGrantingTicket]
for ticket_class in classes:
if ticket.startswith(ticket_class.PREFIX): if ticket.startswith(ticket_class.PREFIX):
return ticket_class return ticket_class
def username(self):
"""
The username to send on ticket validation
:return: The value of the corresponding user attribute if
:attr:`service_pattern`.user_field is set, the user username otherwise.
"""
if self.service_pattern.user_field and self.user.attributs.get(
self.service_pattern.user_field
):
username = self.user.attributs[self.service_pattern.user_field]
if isinstance(username, list):
# the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
username = username[0]
else:
username = self.user.username
return username
def attributs_flat(self):
"""
generate attributes list for template rendering
:return: An list of (attribute name, attribute value) of all user attributes flatened
(no nested list)
:rtype: :obj:`list` of :obj:`tuple` of :obj:`unicode`
"""
attributes = []
for key, value in self.attributs.items():
if isinstance(value, list):
for elt in value:
attributes.append((key, elt))
else:
attributes.append((key, value))
return attributes
@classmethod
def get(cls, ticket, renew=False, service=None):
"""
Search the database for a valid ticket with provided arguments
:param unicode ticket: A ticket value
:param bool renew: Is authentication renewal needed
:param unicode service: Optional argument. The ticket service
:raises Ticket.DoesNotExist: if no class is found for the ticket prefix
:raises cls.DoesNotExist: if ``ticket`` value is not found in th database
:return: a :class:`Ticket` instance
:rtype: Ticket
"""
# If the method class is the ticket abstract class, search for the submited ticket
# class using its prefix. Assuming ticket is a ProxyTicket or a ServiceTicket
if cls == Ticket:
ticket_class = cls.get_class(ticket, classes=[ServiceTicket, ProxyTicket])
# else use the method class
else:
ticket_class = cls
# If ticket prefix is wrong, raise DoesNotExist
if cls != Ticket and not ticket.startswith(cls.PREFIX):
raise Ticket.DoesNotExist()
if ticket_class:
# search for the ticket that is not yet validated and is still valid
ticket_queryset = ticket_class.objects.filter(
value=ticket,
validate=False,
creation__gt=(timezone.now() - timedelta(seconds=ticket_class.VALIDITY))
)
# if service is specified, add it the the queryset
if service is not None:
ticket_queryset = ticket_queryset.filter(service=service)
# only require renew if renew is True, otherwise it do not matter if renew is True
# or False.
if renew:
ticket_queryset = ticket_queryset.filter(renew=True)
# fetch the ticket ``MultipleObjectsReturned`` is never raised as the ticket value
# is unique across the database
ticket = ticket_queryset.get()
# For ServiceTicket and Proxyticket, mark it as validated before returning
if ticket_class != ProxyGrantingTicket:
ticket.validate = True
ticket.save()
return ticket
# If no class found for the ticket, raise DoesNotExist
else:
raise Ticket.DoesNotExist()
@python_2_unicode_compatible @python_2_unicode_compatible
class ServiceTicket(Ticket): class ServiceTicket(Ticket):

View file

@ -36,30 +36,13 @@ import cas_server.forms as forms
import cas_server.models as models import cas_server.models as models
from .utils import json_response from .utils import json_response
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket from .models import Ticket, ServiceTicket, ProxyTicket, ProxyGrantingTicket
from .models import ServicePattern, FederatedIendityProvider, FederatedUser from .models import ServicePattern, FederatedIendityProvider, FederatedUser
from .federate import CASFederateValidateUser from .federate import CASFederateValidateUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AttributesMixin(object):
"""mixin for the attributs methode"""
# pylint: disable=too-few-public-methods
def attributes(self):
"""regerate attributes list for template rendering"""
attributes = []
for key, value in self.ticket.attributs.items():
if isinstance(value, list):
for elt in value:
attributes.append((key, elt))
else:
attributes.append((key, value))
return attributes
class LogoutMixin(object): class LogoutMixin(object):
"""destroy CAS session utils""" """destroy CAS session utils"""
def logout(self, all_session=False): def logout(self, all_session=False):
@ -243,7 +226,8 @@ class FederateAuth(View):
:param django.http.HttpRequest request: The current request object :param django.http.HttpRequest request: The current request object
:param cas_server.models.FederatedIendityProvider provider: the user identity provider :param cas_server.models.FederatedIendityProvider provider: the user identity provider
:return: The user CAS client object :return: The user CAS client object
:rtype: :class:`federate.CASFederateValidateUser<cas_server.federate.CASFederateValidateUser>` :rtype: :class:`federate.CASFederateValidateUser
<cas_server.federate.CASFederateValidateUser>`
""" """
# compute the current url, ignoring ticket dans provider GET parameters # compute the current url, ignoring ticket dans provider GET parameters
service_url = utils.get_current_url(request, {"ticket", "provider"}) service_url = utils.get_current_url(request, {"ticket", "provider"})
@ -961,18 +945,9 @@ class Validate(View):
# service and ticket parameters are mandatory # service and ticket parameters are mandatory
if service and ticket: if service and ticket:
try: try:
ticket_queryset = ServiceTicket.objects.filter( # search for the ticket, associated at service that is not yet validated but is
value=ticket, # still valid
service=service, ticket = ServiceTicket.get(ticket, renew, service)
validate=False,
creation__gt=(timezone.now() - timedelta(seconds=ServiceTicket.VALIDITY))
)
if renew:
ticket = ticket_queryset.get(renew=True)
else:
ticket = ticket_queryset.get()
ticket.validate = True
ticket.save()
logger.info( logger.info(
"Validate: Service ticket %s validated, user %s authenticated on service %s" % ( "Validate: Service ticket %s validated, user %s authenticated on service %s" % (
ticket.value, ticket.value,
@ -980,19 +955,8 @@ class Validate(View):
ticket.service ticket.service
) )
) )
if (ticket.service_pattern.user_field and
ticket.user.attributs.get(ticket.service_pattern.user_field)):
username = ticket.user.attributs.get(
ticket.service_pattern.user_field
)
if isinstance(username, list):
# the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
username = username[0]
else:
username = ticket.user.username
return HttpResponse( return HttpResponse(
u"yes\n%s\n" % username, u"yes\n%s\n" % ticket.username(),
content_type="text/plain; charset=utf-8" content_type="text/plain; charset=utf-8"
) )
except ServiceTicket.DoesNotExist: except ServiceTicket.DoesNotExist:
@ -1018,6 +982,7 @@ class ValidateError(Exception):
code = None code = None
#: The error message #: The error message
msg = None msg = None
def __init__(self, code, msg=""): def __init__(self, code, msg=""):
self.code = code self.code = code
self.msg = msg self.msg = msg
@ -1087,19 +1052,11 @@ class ValidateService(View):
self.ticket, proxies = self.process_ticket() self.ticket, proxies = self.process_ticket()
# prepare template rendering context # prepare template rendering context
params = { params = {
'username': self.ticket.user.username, 'username': self.ticket.username(),
'attributes': self.attributes(), 'attributes': self.ticket.attributs_flat(),
'proxies': proxies 'proxies': proxies
} }
if (self.ticket.service_pattern.user_field and # if pgtUrl is set, require https or localhost
self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field
)
if isinstance(params['username'], list):
# the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
params['username'] = params['username'][0]
if self.pgt_url and ( if self.pgt_url and (
self.pgt_url.startswith("https://") or self.pgt_url.startswith("https://") or
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url) re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
@ -1278,11 +1235,7 @@ class Proxy(View):
u'the service %s do not allow proxy ticket' % self.target_service u'the service %s do not allow proxy ticket' % self.target_service
) )
# is the proxy granting ticket valid # is the proxy granting ticket valid
ticket = ProxyGrantingTicket.objects.get( ticket = ProxyGrantingTicket.get(self.pgt)
value=self.pgt,
creation__gt=(timezone.now() - timedelta(seconds=ProxyGrantingTicket.VALIDITY)),
validate=False
)
# is the pgt user allowed on the target service # is the pgt user allowed on the target service
pattern.check_user(ticket.user) pattern.check_user(ticket.user)
pticket = ticket.user.get_ticket( pticket = ticket.user.get_ticket(
@ -1304,7 +1257,7 @@ class Proxy(View):
{'ticket': pticket.value}, {'ticket': pticket.value},
content_type="text/xml; charset=utf-8" content_type="text/xml; charset=utf-8"
) )
except ProxyGrantingTicket.DoesNotExist: except (Ticket.DoesNotExist, ProxyGrantingTicket.DoesNotExist):
raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt) raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt)
except ServicePattern.DoesNotExist: except ServicePattern.DoesNotExist:
raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service) raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service)
@ -1322,6 +1275,7 @@ class SamlValidateError(Exception):
code = None code = None
#: The error message #: The error message
msg = None msg = None
def __init__(self, code, msg=""): def __init__(self, code, msg=""):
self.code = code self.code = code
self.msg = msg self.msg = msg
@ -1351,7 +1305,7 @@ class SamlValidateError(Exception):
) )
class SamlValidate(View, AttributesMixin): class SamlValidate(View):
"""SAML ticket validation""" """SAML ticket validation"""
request = None request = None
target = None target = None
@ -1375,7 +1329,6 @@ class SamlValidate(View, AttributesMixin):
:return: the rendering of ``cas_server/samlValidate.xml`` if no error is raised, :return: the rendering of ``cas_server/samlValidate.xml`` if no error is raised,
else the rendering of ``cas_server/samlValidateError.xml``. else the rendering of ``cas_server/samlValidateError.xml``.
:rtype: django.http.HttpResponse :rtype: django.http.HttpResponse
""" """
self.request = request self.request = request
self.target = request.GET.get('TARGET') self.target = request.GET.get('TARGET')
@ -1384,24 +1337,14 @@ class SamlValidate(View, AttributesMixin):
self.ticket = self.process_ticket() self.ticket = self.process_ticket()
expire_instant = (self.ticket.creation + expire_instant = (self.ticket.creation +
timedelta(seconds=self.ticket.VALIDITY)).isoformat() timedelta(seconds=self.ticket.VALIDITY)).isoformat()
attributes = self.attributes()
params = { params = {
'IssueInstant': timezone.now().isoformat(), 'IssueInstant': timezone.now().isoformat(),
'expireInstant': expire_instant, 'expireInstant': expire_instant,
'Recipient': self.target, 'Recipient': self.target,
'ResponseID': utils.gen_saml_id(), 'ResponseID': utils.gen_saml_id(),
'username': self.ticket.user.username, 'username': self.ticket.username(),
'attributes': attributes 'attributes': self.ticket.attributs_flat()
} }
if (self.ticket.service_pattern.user_field and
self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field
)
if isinstance(params['username'], list):
# the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
params['username'] = params['username'][0]
logger.info( logger.info(
"SamlValidate: ticket %s validated for user %s on service %s." % ( "SamlValidate: ticket %s validated for user %s on service %s." % (
self.ticket.value, self.ticket.value,
@ -1435,20 +1378,7 @@ class SamlValidate(View, AttributesMixin):
try: try:
auth_req = self.root.getchildren()[1].getchildren()[0] auth_req = self.root.getchildren()[1].getchildren()[0]
ticket = auth_req.getchildren()[0].text ticket = auth_req.getchildren()[0].text
ticket_class = models.Ticket.get_class(ticket) ticket = models.Ticket.get(ticket)
if ticket_class:
ticket = ticket_class.objects.get(
value=ticket,
validate=False,
creation__gt=(timezone.now() - timedelta(seconds=ServiceTicket.VALIDITY))
)
else:
raise SamlValidateError(
u'AuthnFailed',
u'ticket %s should begin with PT- or ST-' % ticket
)
ticket.validate = True
ticket.save()
if ticket.service != self.target: if ticket.service != self.target:
raise SamlValidateError( raise SamlValidateError(
u'AuthnFailed', u'AuthnFailed',
@ -1457,5 +1387,10 @@ class SamlValidate(View, AttributesMixin):
return ticket return ticket
except (IndexError, KeyError): except (IndexError, KeyError):
raise SamlValidateError(u'VersionMismatch') raise SamlValidateError(u'VersionMismatch')
except Ticket.DoesNotExist:
raise SamlValidateError(
u'AuthnFailed',
u'ticket %s should begin with PT- or ST-' % ticket
)
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist): except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket) raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket)