Factorize some code
This commit is contained in:
parent
4127af0db1
commit
232aafcace
2 changed files with 118 additions and 91 deletions
|
@ -736,6 +736,9 @@ class Ticket(models.Model):
|
||||||
#: requests.
|
#: requests.
|
||||||
TIMEOUT = settings.CAS_TICKET_TIMEOUT
|
TIMEOUT = settings.CAS_TICKET_TIMEOUT
|
||||||
|
|
||||||
|
class DoesNotExist(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return u"Ticket-%s" % self.pk
|
return u"Ticket-%s" % self.pk
|
||||||
|
|
||||||
|
@ -806,19 +809,108 @@ class Ticket(models.Model):
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_class(ticket):
|
def get_class(ticket, classes=None):
|
||||||
"""
|
"""
|
||||||
Return the ticket class of ``ticket``
|
Return the ticket class of ``ticket``
|
||||||
|
|
||||||
:param unicode ticket: A ticket
|
:param unicode ticket: A ticket
|
||||||
|
:param list classes: Optinal arguement. A list of possible :class:`Ticket` subclasses
|
||||||
:return: The class corresponding to ``ticket`` (:class:`ServiceTicket` or
|
:return: The class corresponding to ``ticket`` (:class:`ServiceTicket` or
|
||||||
:class:`ProxyTicket` or :class:`ProxyGrantingTicket`) if found, ``None`` otherwise.
|
:class:`ProxyTicket` or :class:`ProxyGrantingTicket`) if found among ``classes,
|
||||||
|
``None`` otherwise.
|
||||||
:rtype: :obj:`type` or :obj:`NoneType<types.NoneType>`
|
:rtype: :obj:`type` or :obj:`NoneType<types.NoneType>`
|
||||||
"""
|
"""
|
||||||
for ticket_class in [ServiceTicket, ProxyTicket, ProxyGrantingTicket]:
|
if classes is None: # pragma: no cover (not used)
|
||||||
|
classes = [ServiceTicket, ProxyTicket, ProxyGrantingTicket]
|
||||||
|
for ticket_class in classes:
|
||||||
if ticket.startswith(ticket_class.PREFIX):
|
if ticket.startswith(ticket_class.PREFIX):
|
||||||
return ticket_class
|
return ticket_class
|
||||||
|
|
||||||
|
def username(self):
|
||||||
|
"""
|
||||||
|
The username to send on ticket validation
|
||||||
|
|
||||||
|
:return: The value of the corresponding user attribute if
|
||||||
|
:attr:`service_pattern`.user_field is set, the user username otherwise.
|
||||||
|
"""
|
||||||
|
if self.service_pattern.user_field and self.user.attributs.get(
|
||||||
|
self.service_pattern.user_field
|
||||||
|
):
|
||||||
|
username = self.user.attributs[self.service_pattern.user_field]
|
||||||
|
if isinstance(username, list):
|
||||||
|
# the list is not empty because we wont generate a ticket with a user_field
|
||||||
|
# that evaluate to False
|
||||||
|
username = username[0]
|
||||||
|
else:
|
||||||
|
username = self.user.username
|
||||||
|
return username
|
||||||
|
|
||||||
|
def attributs_flat(self):
|
||||||
|
"""
|
||||||
|
generate attributes list for template rendering
|
||||||
|
|
||||||
|
:return: An list of (attribute name, attribute value) of all user attributes flatened
|
||||||
|
(no nested list)
|
||||||
|
:rtype: :obj:`list` of :obj:`tuple` of :obj:`unicode`
|
||||||
|
"""
|
||||||
|
attributes = []
|
||||||
|
for key, value in self.attributs.items():
|
||||||
|
if isinstance(value, list):
|
||||||
|
for elt in value:
|
||||||
|
attributes.append((key, elt))
|
||||||
|
else:
|
||||||
|
attributes.append((key, value))
|
||||||
|
return attributes
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get(cls, ticket, renew=False, service=None):
|
||||||
|
"""
|
||||||
|
Search the database for a valid ticket with provided arguments
|
||||||
|
|
||||||
|
:param unicode ticket: A ticket value
|
||||||
|
:param bool renew: Is authentication renewal needed
|
||||||
|
:param unicode service: Optional argument. The ticket service
|
||||||
|
:raises Ticket.DoesNotExist: if no class is found for the ticket prefix
|
||||||
|
:raises cls.DoesNotExist: if ``ticket`` value is not found in th database
|
||||||
|
:return: a :class:`Ticket` instance
|
||||||
|
:rtype: Ticket
|
||||||
|
"""
|
||||||
|
# If the method class is the ticket abstract class, search for the submited ticket
|
||||||
|
# class using its prefix. Assuming ticket is a ProxyTicket or a ServiceTicket
|
||||||
|
if cls == Ticket:
|
||||||
|
ticket_class = cls.get_class(ticket, classes=[ServiceTicket, ProxyTicket])
|
||||||
|
# else use the method class
|
||||||
|
else:
|
||||||
|
ticket_class = cls
|
||||||
|
# If ticket prefix is wrong, raise DoesNotExist
|
||||||
|
if cls != Ticket and not ticket.startswith(cls.PREFIX):
|
||||||
|
raise Ticket.DoesNotExist()
|
||||||
|
if ticket_class:
|
||||||
|
# search for the ticket that is not yet validated and is still valid
|
||||||
|
ticket_queryset = ticket_class.objects.filter(
|
||||||
|
value=ticket,
|
||||||
|
validate=False,
|
||||||
|
creation__gt=(timezone.now() - timedelta(seconds=ticket_class.VALIDITY))
|
||||||
|
)
|
||||||
|
# if service is specified, add it the the queryset
|
||||||
|
if service is not None:
|
||||||
|
ticket_queryset = ticket_queryset.filter(service=service)
|
||||||
|
# only require renew if renew is True, otherwise it do not matter if renew is True
|
||||||
|
# or False.
|
||||||
|
if renew:
|
||||||
|
ticket_queryset = ticket_queryset.filter(renew=True)
|
||||||
|
# fetch the ticket ``MultipleObjectsReturned`` is never raised as the ticket value
|
||||||
|
# is unique across the database
|
||||||
|
ticket = ticket_queryset.get()
|
||||||
|
# For ServiceTicket and Proxyticket, mark it as validated before returning
|
||||||
|
if ticket_class != ProxyGrantingTicket:
|
||||||
|
ticket.validate = True
|
||||||
|
ticket.save()
|
||||||
|
return ticket
|
||||||
|
# If no class found for the ticket, raise DoesNotExist
|
||||||
|
else:
|
||||||
|
raise Ticket.DoesNotExist()
|
||||||
|
|
||||||
|
|
||||||
@python_2_unicode_compatible
|
@python_2_unicode_compatible
|
||||||
class ServiceTicket(Ticket):
|
class ServiceTicket(Ticket):
|
||||||
|
|
|
@ -36,30 +36,13 @@ import cas_server.forms as forms
|
||||||
import cas_server.models as models
|
import cas_server.models as models
|
||||||
|
|
||||||
from .utils import json_response
|
from .utils import json_response
|
||||||
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket
|
from .models import Ticket, ServiceTicket, ProxyTicket, ProxyGrantingTicket
|
||||||
from .models import ServicePattern, FederatedIendityProvider, FederatedUser
|
from .models import ServicePattern, FederatedIendityProvider, FederatedUser
|
||||||
from .federate import CASFederateValidateUser
|
from .federate import CASFederateValidateUser
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AttributesMixin(object):
|
|
||||||
"""mixin for the attributs methode"""
|
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
|
||||||
|
|
||||||
def attributes(self):
|
|
||||||
"""regerate attributes list for template rendering"""
|
|
||||||
attributes = []
|
|
||||||
for key, value in self.ticket.attributs.items():
|
|
||||||
if isinstance(value, list):
|
|
||||||
for elt in value:
|
|
||||||
attributes.append((key, elt))
|
|
||||||
else:
|
|
||||||
attributes.append((key, value))
|
|
||||||
return attributes
|
|
||||||
|
|
||||||
|
|
||||||
class LogoutMixin(object):
|
class LogoutMixin(object):
|
||||||
"""destroy CAS session utils"""
|
"""destroy CAS session utils"""
|
||||||
def logout(self, all_session=False):
|
def logout(self, all_session=False):
|
||||||
|
@ -243,7 +226,8 @@ class FederateAuth(View):
|
||||||
:param django.http.HttpRequest request: The current request object
|
:param django.http.HttpRequest request: The current request object
|
||||||
:param cas_server.models.FederatedIendityProvider provider: the user identity provider
|
:param cas_server.models.FederatedIendityProvider provider: the user identity provider
|
||||||
:return: The user CAS client object
|
:return: The user CAS client object
|
||||||
:rtype: :class:`federate.CASFederateValidateUser<cas_server.federate.CASFederateValidateUser>`
|
:rtype: :class:`federate.CASFederateValidateUser
|
||||||
|
<cas_server.federate.CASFederateValidateUser>`
|
||||||
"""
|
"""
|
||||||
# compute the current url, ignoring ticket dans provider GET parameters
|
# compute the current url, ignoring ticket dans provider GET parameters
|
||||||
service_url = utils.get_current_url(request, {"ticket", "provider"})
|
service_url = utils.get_current_url(request, {"ticket", "provider"})
|
||||||
|
@ -961,18 +945,9 @@ class Validate(View):
|
||||||
# service and ticket parameters are mandatory
|
# service and ticket parameters are mandatory
|
||||||
if service and ticket:
|
if service and ticket:
|
||||||
try:
|
try:
|
||||||
ticket_queryset = ServiceTicket.objects.filter(
|
# search for the ticket, associated at service that is not yet validated but is
|
||||||
value=ticket,
|
# still valid
|
||||||
service=service,
|
ticket = ServiceTicket.get(ticket, renew, service)
|
||||||
validate=False,
|
|
||||||
creation__gt=(timezone.now() - timedelta(seconds=ServiceTicket.VALIDITY))
|
|
||||||
)
|
|
||||||
if renew:
|
|
||||||
ticket = ticket_queryset.get(renew=True)
|
|
||||||
else:
|
|
||||||
ticket = ticket_queryset.get()
|
|
||||||
ticket.validate = True
|
|
||||||
ticket.save()
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Validate: Service ticket %s validated, user %s authenticated on service %s" % (
|
"Validate: Service ticket %s validated, user %s authenticated on service %s" % (
|
||||||
ticket.value,
|
ticket.value,
|
||||||
|
@ -980,19 +955,8 @@ class Validate(View):
|
||||||
ticket.service
|
ticket.service
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if (ticket.service_pattern.user_field and
|
|
||||||
ticket.user.attributs.get(ticket.service_pattern.user_field)):
|
|
||||||
username = ticket.user.attributs.get(
|
|
||||||
ticket.service_pattern.user_field
|
|
||||||
)
|
|
||||||
if isinstance(username, list):
|
|
||||||
# the list is not empty because we wont generate a ticket with a user_field
|
|
||||||
# that evaluate to False
|
|
||||||
username = username[0]
|
|
||||||
else:
|
|
||||||
username = ticket.user.username
|
|
||||||
return HttpResponse(
|
return HttpResponse(
|
||||||
u"yes\n%s\n" % username,
|
u"yes\n%s\n" % ticket.username(),
|
||||||
content_type="text/plain; charset=utf-8"
|
content_type="text/plain; charset=utf-8"
|
||||||
)
|
)
|
||||||
except ServiceTicket.DoesNotExist:
|
except ServiceTicket.DoesNotExist:
|
||||||
|
@ -1018,6 +982,7 @@ class ValidateError(Exception):
|
||||||
code = None
|
code = None
|
||||||
#: The error message
|
#: The error message
|
||||||
msg = None
|
msg = None
|
||||||
|
|
||||||
def __init__(self, code, msg=""):
|
def __init__(self, code, msg=""):
|
||||||
self.code = code
|
self.code = code
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
@ -1087,19 +1052,11 @@ class ValidateService(View):
|
||||||
self.ticket, proxies = self.process_ticket()
|
self.ticket, proxies = self.process_ticket()
|
||||||
# prepare template rendering context
|
# prepare template rendering context
|
||||||
params = {
|
params = {
|
||||||
'username': self.ticket.user.username,
|
'username': self.ticket.username(),
|
||||||
'attributes': self.attributes(),
|
'attributes': self.ticket.attributs_flat(),
|
||||||
'proxies': proxies
|
'proxies': proxies
|
||||||
}
|
}
|
||||||
if (self.ticket.service_pattern.user_field and
|
# if pgtUrl is set, require https or localhost
|
||||||
self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
|
|
||||||
params['username'] = self.ticket.user.attributs.get(
|
|
||||||
self.ticket.service_pattern.user_field
|
|
||||||
)
|
|
||||||
if isinstance(params['username'], list):
|
|
||||||
# the list is not empty because we wont generate a ticket with a user_field
|
|
||||||
# that evaluate to False
|
|
||||||
params['username'] = params['username'][0]
|
|
||||||
if self.pgt_url and (
|
if self.pgt_url and (
|
||||||
self.pgt_url.startswith("https://") or
|
self.pgt_url.startswith("https://") or
|
||||||
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
|
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
|
||||||
|
@ -1278,11 +1235,7 @@ class Proxy(View):
|
||||||
u'the service %s do not allow proxy ticket' % self.target_service
|
u'the service %s do not allow proxy ticket' % self.target_service
|
||||||
)
|
)
|
||||||
# is the proxy granting ticket valid
|
# is the proxy granting ticket valid
|
||||||
ticket = ProxyGrantingTicket.objects.get(
|
ticket = ProxyGrantingTicket.get(self.pgt)
|
||||||
value=self.pgt,
|
|
||||||
creation__gt=(timezone.now() - timedelta(seconds=ProxyGrantingTicket.VALIDITY)),
|
|
||||||
validate=False
|
|
||||||
)
|
|
||||||
# is the pgt user allowed on the target service
|
# is the pgt user allowed on the target service
|
||||||
pattern.check_user(ticket.user)
|
pattern.check_user(ticket.user)
|
||||||
pticket = ticket.user.get_ticket(
|
pticket = ticket.user.get_ticket(
|
||||||
|
@ -1304,7 +1257,7 @@ class Proxy(View):
|
||||||
{'ticket': pticket.value},
|
{'ticket': pticket.value},
|
||||||
content_type="text/xml; charset=utf-8"
|
content_type="text/xml; charset=utf-8"
|
||||||
)
|
)
|
||||||
except ProxyGrantingTicket.DoesNotExist:
|
except (Ticket.DoesNotExist, ProxyGrantingTicket.DoesNotExist):
|
||||||
raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt)
|
raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt)
|
||||||
except ServicePattern.DoesNotExist:
|
except ServicePattern.DoesNotExist:
|
||||||
raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service)
|
raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service)
|
||||||
|
@ -1322,6 +1275,7 @@ class SamlValidateError(Exception):
|
||||||
code = None
|
code = None
|
||||||
#: The error message
|
#: The error message
|
||||||
msg = None
|
msg = None
|
||||||
|
|
||||||
def __init__(self, code, msg=""):
|
def __init__(self, code, msg=""):
|
||||||
self.code = code
|
self.code = code
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
|
@ -1351,7 +1305,7 @@ class SamlValidateError(Exception):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SamlValidate(View, AttributesMixin):
|
class SamlValidate(View):
|
||||||
"""SAML ticket validation"""
|
"""SAML ticket validation"""
|
||||||
request = None
|
request = None
|
||||||
target = None
|
target = None
|
||||||
|
@ -1375,7 +1329,6 @@ class SamlValidate(View, AttributesMixin):
|
||||||
:return: the rendering of ``cas_server/samlValidate.xml`` if no error is raised,
|
:return: the rendering of ``cas_server/samlValidate.xml`` if no error is raised,
|
||||||
else the rendering of ``cas_server/samlValidateError.xml``.
|
else the rendering of ``cas_server/samlValidateError.xml``.
|
||||||
:rtype: django.http.HttpResponse
|
:rtype: django.http.HttpResponse
|
||||||
|
|
||||||
"""
|
"""
|
||||||
self.request = request
|
self.request = request
|
||||||
self.target = request.GET.get('TARGET')
|
self.target = request.GET.get('TARGET')
|
||||||
|
@ -1384,24 +1337,14 @@ class SamlValidate(View, AttributesMixin):
|
||||||
self.ticket = self.process_ticket()
|
self.ticket = self.process_ticket()
|
||||||
expire_instant = (self.ticket.creation +
|
expire_instant = (self.ticket.creation +
|
||||||
timedelta(seconds=self.ticket.VALIDITY)).isoformat()
|
timedelta(seconds=self.ticket.VALIDITY)).isoformat()
|
||||||
attributes = self.attributes()
|
|
||||||
params = {
|
params = {
|
||||||
'IssueInstant': timezone.now().isoformat(),
|
'IssueInstant': timezone.now().isoformat(),
|
||||||
'expireInstant': expire_instant,
|
'expireInstant': expire_instant,
|
||||||
'Recipient': self.target,
|
'Recipient': self.target,
|
||||||
'ResponseID': utils.gen_saml_id(),
|
'ResponseID': utils.gen_saml_id(),
|
||||||
'username': self.ticket.user.username,
|
'username': self.ticket.username(),
|
||||||
'attributes': attributes
|
'attributes': self.ticket.attributs_flat()
|
||||||
}
|
}
|
||||||
if (self.ticket.service_pattern.user_field and
|
|
||||||
self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
|
|
||||||
params['username'] = self.ticket.user.attributs.get(
|
|
||||||
self.ticket.service_pattern.user_field
|
|
||||||
)
|
|
||||||
if isinstance(params['username'], list):
|
|
||||||
# the list is not empty because we wont generate a ticket with a user_field
|
|
||||||
# that evaluate to False
|
|
||||||
params['username'] = params['username'][0]
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"SamlValidate: ticket %s validated for user %s on service %s." % (
|
"SamlValidate: ticket %s validated for user %s on service %s." % (
|
||||||
self.ticket.value,
|
self.ticket.value,
|
||||||
|
@ -1435,20 +1378,7 @@ class SamlValidate(View, AttributesMixin):
|
||||||
try:
|
try:
|
||||||
auth_req = self.root.getchildren()[1].getchildren()[0]
|
auth_req = self.root.getchildren()[1].getchildren()[0]
|
||||||
ticket = auth_req.getchildren()[0].text
|
ticket = auth_req.getchildren()[0].text
|
||||||
ticket_class = models.Ticket.get_class(ticket)
|
ticket = models.Ticket.get(ticket)
|
||||||
if ticket_class:
|
|
||||||
ticket = ticket_class.objects.get(
|
|
||||||
value=ticket,
|
|
||||||
validate=False,
|
|
||||||
creation__gt=(timezone.now() - timedelta(seconds=ServiceTicket.VALIDITY))
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise SamlValidateError(
|
|
||||||
u'AuthnFailed',
|
|
||||||
u'ticket %s should begin with PT- or ST-' % ticket
|
|
||||||
)
|
|
||||||
ticket.validate = True
|
|
||||||
ticket.save()
|
|
||||||
if ticket.service != self.target:
|
if ticket.service != self.target:
|
||||||
raise SamlValidateError(
|
raise SamlValidateError(
|
||||||
u'AuthnFailed',
|
u'AuthnFailed',
|
||||||
|
@ -1457,5 +1387,10 @@ class SamlValidate(View, AttributesMixin):
|
||||||
return ticket
|
return ticket
|
||||||
except (IndexError, KeyError):
|
except (IndexError, KeyError):
|
||||||
raise SamlValidateError(u'VersionMismatch')
|
raise SamlValidateError(u'VersionMismatch')
|
||||||
|
except Ticket.DoesNotExist:
|
||||||
|
raise SamlValidateError(
|
||||||
|
u'AuthnFailed',
|
||||||
|
u'ticket %s should begin with PT- or ST-' % ticket
|
||||||
|
)
|
||||||
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
|
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
|
||||||
raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket)
|
raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket)
|
||||||
|
|
Loading…
Reference in a new issue