diff --git a/cas_server/models.py b/cas_server/models.py index 23b9587..5ca1296 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -736,6 +736,9 @@ class Ticket(models.Model): #: requests. TIMEOUT = settings.CAS_TICKET_TIMEOUT + class DoesNotExist(Exception): + pass + def __str__(self): return u"Ticket-%s" % self.pk @@ -806,19 +809,108 @@ class Ticket(models.Model): ) @staticmethod - def get_class(ticket): + def get_class(ticket, classes=None): """ Return the ticket class of ``ticket`` :param unicode ticket: A ticket + :param list classes: Optinal arguement. A list of possible :class:`Ticket` subclasses :return: The class corresponding to ``ticket`` (:class:`ServiceTicket` or - :class:`ProxyTicket` or :class:`ProxyGrantingTicket`) if found, ``None`` otherwise. + :class:`ProxyTicket` or :class:`ProxyGrantingTicket`) if found among ``classes, + ``None`` otherwise. :rtype: :obj:`type` or :obj:`NoneType` """ - for ticket_class in [ServiceTicket, ProxyTicket, ProxyGrantingTicket]: + if classes is None: # pragma: no cover (not used) + classes = [ServiceTicket, ProxyTicket, ProxyGrantingTicket] + for ticket_class in classes: if ticket.startswith(ticket_class.PREFIX): return ticket_class + def username(self): + """ + The username to send on ticket validation + + :return: The value of the corresponding user attribute if + :attr:`service_pattern`.user_field is set, the user username otherwise. + """ + if self.service_pattern.user_field and self.user.attributs.get( + self.service_pattern.user_field + ): + username = self.user.attributs[self.service_pattern.user_field] + if isinstance(username, list): + # the list is not empty because we wont generate a ticket with a user_field + # that evaluate to False + username = username[0] + else: + username = self.user.username + return username + + def attributs_flat(self): + """ + generate attributes list for template rendering + + :return: An list of (attribute name, attribute value) of all user attributes flatened + (no nested list) + :rtype: :obj:`list` of :obj:`tuple` of :obj:`unicode` + """ + attributes = [] + for key, value in self.attributs.items(): + if isinstance(value, list): + for elt in value: + attributes.append((key, elt)) + else: + attributes.append((key, value)) + return attributes + + @classmethod + def get(cls, ticket, renew=False, service=None): + """ + Search the database for a valid ticket with provided arguments + + :param unicode ticket: A ticket value + :param bool renew: Is authentication renewal needed + :param unicode service: Optional argument. The ticket service + :raises Ticket.DoesNotExist: if no class is found for the ticket prefix + :raises cls.DoesNotExist: if ``ticket`` value is not found in th database + :return: a :class:`Ticket` instance + :rtype: Ticket + """ + # If the method class is the ticket abstract class, search for the submited ticket + # class using its prefix. Assuming ticket is a ProxyTicket or a ServiceTicket + if cls == Ticket: + ticket_class = cls.get_class(ticket, classes=[ServiceTicket, ProxyTicket]) + # else use the method class + else: + ticket_class = cls + # If ticket prefix is wrong, raise DoesNotExist + if cls != Ticket and not ticket.startswith(cls.PREFIX): + raise Ticket.DoesNotExist() + if ticket_class: + # search for the ticket that is not yet validated and is still valid + ticket_queryset = ticket_class.objects.filter( + value=ticket, + validate=False, + creation__gt=(timezone.now() - timedelta(seconds=ticket_class.VALIDITY)) + ) + # if service is specified, add it the the queryset + if service is not None: + ticket_queryset = ticket_queryset.filter(service=service) + # only require renew if renew is True, otherwise it do not matter if renew is True + # or False. + if renew: + ticket_queryset = ticket_queryset.filter(renew=True) + # fetch the ticket ``MultipleObjectsReturned`` is never raised as the ticket value + # is unique across the database + ticket = ticket_queryset.get() + # For ServiceTicket and Proxyticket, mark it as validated before returning + if ticket_class != ProxyGrantingTicket: + ticket.validate = True + ticket.save() + return ticket + # If no class found for the ticket, raise DoesNotExist + else: + raise Ticket.DoesNotExist() + @python_2_unicode_compatible class ServiceTicket(Ticket): diff --git a/cas_server/views.py b/cas_server/views.py index 2bc3b51..ea2bbfd 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -36,30 +36,13 @@ import cas_server.forms as forms import cas_server.models as models from .utils import json_response -from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket +from .models import Ticket, ServiceTicket, ProxyTicket, ProxyGrantingTicket from .models import ServicePattern, FederatedIendityProvider, FederatedUser from .federate import CASFederateValidateUser logger = logging.getLogger(__name__) -class AttributesMixin(object): - """mixin for the attributs methode""" - - # pylint: disable=too-few-public-methods - - def attributes(self): - """regerate attributes list for template rendering""" - attributes = [] - for key, value in self.ticket.attributs.items(): - if isinstance(value, list): - for elt in value: - attributes.append((key, elt)) - else: - attributes.append((key, value)) - return attributes - - class LogoutMixin(object): """destroy CAS session utils""" def logout(self, all_session=False): @@ -243,7 +226,8 @@ class FederateAuth(View): :param django.http.HttpRequest request: The current request object :param cas_server.models.FederatedIendityProvider provider: the user identity provider :return: The user CAS client object - :rtype: :class:`federate.CASFederateValidateUser` + :rtype: :class:`federate.CASFederateValidateUser + ` """ # compute the current url, ignoring ticket dans provider GET parameters service_url = utils.get_current_url(request, {"ticket", "provider"}) @@ -961,18 +945,9 @@ class Validate(View): # service and ticket parameters are mandatory if service and ticket: try: - ticket_queryset = ServiceTicket.objects.filter( - value=ticket, - service=service, - validate=False, - creation__gt=(timezone.now() - timedelta(seconds=ServiceTicket.VALIDITY)) - ) - if renew: - ticket = ticket_queryset.get(renew=True) - else: - ticket = ticket_queryset.get() - ticket.validate = True - ticket.save() + # search for the ticket, associated at service that is not yet validated but is + # still valid + ticket = ServiceTicket.get(ticket, renew, service) logger.info( "Validate: Service ticket %s validated, user %s authenticated on service %s" % ( ticket.value, @@ -980,19 +955,8 @@ class Validate(View): ticket.service ) ) - if (ticket.service_pattern.user_field and - ticket.user.attributs.get(ticket.service_pattern.user_field)): - username = ticket.user.attributs.get( - ticket.service_pattern.user_field - ) - if isinstance(username, list): - # the list is not empty because we wont generate a ticket with a user_field - # that evaluate to False - username = username[0] - else: - username = ticket.user.username return HttpResponse( - u"yes\n%s\n" % username, + u"yes\n%s\n" % ticket.username(), content_type="text/plain; charset=utf-8" ) except ServiceTicket.DoesNotExist: @@ -1018,6 +982,7 @@ class ValidateError(Exception): code = None #: The error message msg = None + def __init__(self, code, msg=""): self.code = code self.msg = msg @@ -1087,19 +1052,11 @@ class ValidateService(View): self.ticket, proxies = self.process_ticket() # prepare template rendering context params = { - 'username': self.ticket.user.username, - 'attributes': self.attributes(), + 'username': self.ticket.username(), + 'attributes': self.ticket.attributs_flat(), 'proxies': proxies } - if (self.ticket.service_pattern.user_field and - self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)): - params['username'] = self.ticket.user.attributs.get( - self.ticket.service_pattern.user_field - ) - if isinstance(params['username'], list): - # the list is not empty because we wont generate a ticket with a user_field - # that evaluate to False - params['username'] = params['username'][0] + # if pgtUrl is set, require https or localhost if self.pgt_url and ( self.pgt_url.startswith("https://") or re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url) @@ -1278,11 +1235,7 @@ class Proxy(View): u'the service %s do not allow proxy ticket' % self.target_service ) # is the proxy granting ticket valid - ticket = ProxyGrantingTicket.objects.get( - value=self.pgt, - creation__gt=(timezone.now() - timedelta(seconds=ProxyGrantingTicket.VALIDITY)), - validate=False - ) + ticket = ProxyGrantingTicket.get(self.pgt) # is the pgt user allowed on the target service pattern.check_user(ticket.user) pticket = ticket.user.get_ticket( @@ -1304,7 +1257,7 @@ class Proxy(View): {'ticket': pticket.value}, content_type="text/xml; charset=utf-8" ) - except ProxyGrantingTicket.DoesNotExist: + except (Ticket.DoesNotExist, ProxyGrantingTicket.DoesNotExist): raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt) except ServicePattern.DoesNotExist: raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service) @@ -1322,6 +1275,7 @@ class SamlValidateError(Exception): code = None #: The error message msg = None + def __init__(self, code, msg=""): self.code = code self.msg = msg @@ -1351,7 +1305,7 @@ class SamlValidateError(Exception): ) -class SamlValidate(View, AttributesMixin): +class SamlValidate(View): """SAML ticket validation""" request = None target = None @@ -1375,7 +1329,6 @@ class SamlValidate(View, AttributesMixin): :return: the rendering of ``cas_server/samlValidate.xml`` if no error is raised, else the rendering of ``cas_server/samlValidateError.xml``. :rtype: django.http.HttpResponse - """ self.request = request self.target = request.GET.get('TARGET') @@ -1384,24 +1337,14 @@ class SamlValidate(View, AttributesMixin): self.ticket = self.process_ticket() expire_instant = (self.ticket.creation + timedelta(seconds=self.ticket.VALIDITY)).isoformat() - attributes = self.attributes() params = { 'IssueInstant': timezone.now().isoformat(), 'expireInstant': expire_instant, 'Recipient': self.target, 'ResponseID': utils.gen_saml_id(), - 'username': self.ticket.user.username, - 'attributes': attributes + 'username': self.ticket.username(), + 'attributes': self.ticket.attributs_flat() } - if (self.ticket.service_pattern.user_field and - self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)): - params['username'] = self.ticket.user.attributs.get( - self.ticket.service_pattern.user_field - ) - if isinstance(params['username'], list): - # the list is not empty because we wont generate a ticket with a user_field - # that evaluate to False - params['username'] = params['username'][0] logger.info( "SamlValidate: ticket %s validated for user %s on service %s." % ( self.ticket.value, @@ -1435,20 +1378,7 @@ class SamlValidate(View, AttributesMixin): try: auth_req = self.root.getchildren()[1].getchildren()[0] ticket = auth_req.getchildren()[0].text - ticket_class = models.Ticket.get_class(ticket) - if ticket_class: - ticket = ticket_class.objects.get( - value=ticket, - validate=False, - creation__gt=(timezone.now() - timedelta(seconds=ServiceTicket.VALIDITY)) - ) - else: - raise SamlValidateError( - u'AuthnFailed', - u'ticket %s should begin with PT- or ST-' % ticket - ) - ticket.validate = True - ticket.save() + ticket = models.Ticket.get(ticket) if ticket.service != self.target: raise SamlValidateError( u'AuthnFailed', @@ -1457,5 +1387,10 @@ class SamlValidate(View, AttributesMixin): return ticket except (IndexError, KeyError): raise SamlValidateError(u'VersionMismatch') + except Ticket.DoesNotExist: + raise SamlValidateError( + u'AuthnFailed', + u'ticket %s should begin with PT- or ST-' % ticket + ) except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist): raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket)