some code refactoring and better error handling on ticket validation

This commit is contained in:
Valentin Samir 2015-05-28 15:08:57 +02:00
parent 7e2917e977
commit 871baaac87
6 changed files with 141 additions and 141 deletions

View file

@ -32,7 +32,7 @@ class Migration(migrations.Migration):
('service', models.TextField()), ('service', models.TextField()),
('creation', models.DateTimeField(auto_now_add=True)), ('creation', models.DateTimeField(auto_now_add=True)),
('renew', models.BooleanField(default=False)), ('renew', models.BooleanField(default=False)),
('value', models.CharField(default=cas_server.models._gen_pgt, unique=True, max_length=255)), ('value', models.CharField(default=cas_server.utils.gen_pgt, unique=True, max_length=255)),
], ],
options={ options={
'abstract': False, 'abstract': False,
@ -48,7 +48,7 @@ class Migration(migrations.Migration):
('service', models.TextField()), ('service', models.TextField()),
('creation', models.DateTimeField(auto_now_add=True)), ('creation', models.DateTimeField(auto_now_add=True)),
('renew', models.BooleanField(default=False)), ('renew', models.BooleanField(default=False)),
('value', models.CharField(default=cas_server.models._gen_pt, unique=True, max_length=255)), ('value', models.CharField(default=cas_server.utils.gen_pt, unique=True, max_length=255)),
], ],
options={ options={
'abstract': False, 'abstract': False,
@ -81,7 +81,7 @@ class Migration(migrations.Migration):
('service', models.TextField()), ('service', models.TextField()),
('creation', models.DateTimeField(auto_now_add=True)), ('creation', models.DateTimeField(auto_now_add=True)),
('renew', models.BooleanField(default=False)), ('renew', models.BooleanField(default=False)),
('value', models.CharField(default=cas_server.models._gen_st, unique=True, max_length=255)), ('value', models.CharField(default=cas_server.utils.gen_st, unique=True, max_length=255)),
], ],
options={ options={
'abstract': False, 'abstract': False,

View file

@ -10,9 +10,6 @@
# #
# (c) 2015 Valentin Samir # (c) 2015 Valentin Samir
"""models for the app""" """models for the app"""
from . import default_settings
from django.conf import settings
from django.db import models from django.db import models
from django.contrib import messages from django.contrib import messages
from picklefield.fields import PickledObjectField from picklefield.fields import PickledObjectField
@ -21,41 +18,12 @@ from django.utils import timezone
import re import re
import os import os
import random
import string
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from requests_futures.sessions import FuturesSession from requests_futures.sessions import FuturesSession
from . import utils from . import utils
def _gen_ticket(prefix):
"""Generate a ticket with prefix `prefix`"""
return '%s-%s' % (
prefix,
''.join(
random.choice(
string.ascii_letters + string.digits
) for _ in range(settings.CAS_ST_LEN)
)
)
def _gen_st():
"""Generate a Service Ticket"""
return _gen_ticket('ST')
def _gen_pt():
"""Generate a Proxy Ticket"""
return _gen_ticket('PT')
def _gen_pgt():
"""Generate a Proxy Granting Ticket"""
return _gen_ticket('PGT')
def gen_pgtiou():
"""Generate a Proxy Granting Ticket IOU"""
return _gen_ticket('PGTIOU')
class User(models.Model): class User(models.Model):
"""A user logged into the CAS""" """A user logged into the CAS"""
username = models.CharField(max_length=30, unique=True) username = models.CharField(max_length=30, unique=True)
@ -83,10 +51,11 @@ class User(models.Model):
try: try:
future.result() future.result()
except Exception as error: except Exception as error:
error = utils.unpack_nested_exception(error)
messages.add_message( messages.add_message(
request, request,
messages.WARNING, messages.WARNING,
_(u'Error during service logout %r') % error _(u'Error during service logout %s') % error
) )
def get_ticket(self, ticket_class, service, service_pattern, renew): def get_ticket(self, ticket_class, service, service_pattern, renew):
@ -333,6 +302,7 @@ class Ticket(models.Model):
headers=headers headers=headers
) )
except Exception as error: except Exception as error:
error = utils.unpack_nested_exception(error)
messages.add_message( messages.add_message(
request, request,
messages.WARNING, messages.WARNING,
@ -342,17 +312,17 @@ class Ticket(models.Model):
class ServiceTicket(Ticket): class ServiceTicket(Ticket):
"""A Service Ticket""" """A Service Ticket"""
value = models.CharField(max_length=255, default=_gen_st, unique=True) value = models.CharField(max_length=255, default=utils.gen_st, unique=True)
def __unicode__(self): def __unicode__(self):
return u"ServiceTicket(%s, %s, %s)" % (self.user, self.value, self.service) return u"ServiceTicket(%s, %s, %s)" % (self.user, self.value, self.service)
class ProxyTicket(Ticket): class ProxyTicket(Ticket):
"""A Proxy Ticket""" """A Proxy Ticket"""
value = models.CharField(max_length=255, default=_gen_pt, unique=True) value = models.CharField(max_length=255, default=utils.gen_pt, unique=True)
def __unicode__(self): def __unicode__(self):
return u"ProxyTicket(%s, %s, %s)" % (self.user, self.value, self.service) return u"ProxyTicket(%s, %s, %s)" % (self.user, self.value, self.service)
class ProxyGrantingTicket(Ticket): class ProxyGrantingTicket(Ticket):
"""A Proxy Granting Ticket""" """A Proxy Granting Ticket"""
value = models.CharField(max_length=255, default=_gen_pgt, unique=True) value = models.CharField(max_length=255, default=utils.gen_pgt, unique=True)
def __unicode__(self): def __unicode__(self):
return u"ProxyGrantingTicket(%s, %s, %s)" % (self.user, self.value, self.service) return u"ProxyGrantingTicket(%s, %s, %s)" % (self.user, self.value, self.service)

View file

@ -7,8 +7,7 @@
MajorVersion="1" MinorVersion="1" Recipient="{{Recipient}}" MajorVersion="1" MinorVersion="1" Recipient="{{Recipient}}"
ResponseID="{{ResponseID}}"> ResponseID="{{ResponseID}}">
<Status> <Status>
<StatusCode Value="samlp:{{code}}"> <StatusCode Value="samlp:{{code}}">{{msg}}</StatusCode>
</StatusCode>
</Status> </Status>
</Response> </Response>
</SOAP-ENV:Body> </SOAP-ENV:Body>

View file

@ -1,5 +1,3 @@
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas"> <cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
<cas:authenticationFailure code="{{code}}"> <cas:authenticationFailure code="{{code}}">{{msg}}</cas:authenticationFailure>
{{msg}}
</cas:authenticationFailure>
</cas:serviceResponse> </cas:serviceResponse>

View file

@ -9,8 +9,14 @@
# #
# (c) 2015 Valentin Samir # (c) 2015 Valentin Samir
"""Some util function for the app""" """Some util function for the app"""
from . import default_settings
from django.conf import settings
import urlparse import urlparse
import urllib import urllib
import random
import string
def update_url(url, params): def update_url(url, params):
"""update params in the `url` query string""" """update params in the `url` query string"""
@ -19,3 +25,46 @@ def update_url(url, params):
query.update(params) query.update(params)
url_parts[4] = urllib.urlencode(query) url_parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(url_parts) return urlparse.urlunparse(url_parts)
def unpack_nested_exception(error):
"""If exception are stacked, return the first one"""
i = 0
while True:
if error.args[i:]:
if isinstance(error.args[i], Exception):
error = error.args[i]
i = 0
else:
i += 1
else:
break
return error
def _gen_ticket(prefix):
"""Generate a ticket with prefix `prefix`"""
return '%s-%s' % (
prefix,
''.join(
random.choice(
string.ascii_letters + string.digits
) for _ in range(settings.CAS_ST_LEN)
)
)
def gen_st():
"""Generate a Service Ticket"""
return _gen_ticket('ST')
def gen_pt():
"""Generate a Proxy Ticket"""
return _gen_ticket('PT')
def gen_pgt():
"""Generate a Proxy Granting Ticket"""
return _gen_ticket('PGT')
def gen_pgtiou():
"""Generate a Proxy Granting Ticket IOU"""
return _gen_ticket('PGTIOU')

View file

@ -229,6 +229,15 @@ def validate(request):
return HttpResponse("no\n", content_type="text/plain") return HttpResponse("no\n", content_type="text/plain")
def _validate_error(request, code, msg=""):
"""render the serviceValidateError.xml template using `code` and `msg`"""
return render(
request,
"cas_server/serviceValidateError.xml",
{'code':code, 'msg':msg},
content_type="text/xml; charset=utf-8"
)
def ps_validate(request, ticket_type=None): def ps_validate(request, ticket_type=None):
"""factorization for serviceValidate and proxyValidate""" """factorization for serviceValidate and proxyValidate"""
if ticket_type is None: if ticket_type is None:
@ -238,22 +247,20 @@ def ps_validate(request, ticket_type=None):
pgt_url = request.GET.get('pgtUrl') pgt_url = request.GET.get('pgtUrl')
renew = True if request.GET.get('renew') else False renew = True if request.GET.get('renew') else False
if service and ticket: if service and ticket:
for typ in ticket_type: for elt in ticket_type:
if ticket.startswith(typ): if ticket.startswith(elt):
break break
else: else:
return render( return _validate_error(
request, request,
"cas_server/serviceValidateError.xml", 'INVALID_TICKET',
{'code':'INVALID_TICKET'}, 'tickets should begin with %s' % ' or '.join(ticket_type)
content_type="text/xml; charset=utf-8"
) )
try: try:
proxies = [] proxies = []
if ticket.startswith("ST"): if ticket.startswith("ST"):
ticket = models.ServiceTicket.objects.get( ticket = models.ServiceTicket.objects.get(
value=ticket, value=ticket,
service=service,
validate=False, validate=False,
renew=renew, renew=renew,
creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY)) creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
@ -261,7 +268,6 @@ def ps_validate(request, ticket_type=None):
elif ticket.startswith("PT"): elif ticket.startswith("PT"):
ticket = models.ProxyTicket.objects.get( ticket = models.ProxyTicket.objects.get(
value=ticket, value=ticket,
service=service,
validate=False, validate=False,
renew=renew, renew=renew,
creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY)) creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
@ -270,6 +276,8 @@ def ps_validate(request, ticket_type=None):
proxies.append(prox.url) proxies.append(prox.url)
ticket.validate = True ticket.validate = True
ticket.save() ticket.save()
if ticket.service != service:
return _validate_error(request, 'INVALID_SERVICE')
attributes = [] attributes = []
for key, value in ticket.attributs.items(): for key, value in ticket.attributs.items():
if isinstance(value, list): if isinstance(value, list):
@ -284,7 +292,7 @@ def ps_validate(request, ticket_type=None):
if pgt_url and pgt_url.startswith("https://"): if pgt_url and pgt_url.startswith("https://"):
pattern = models.ServicePattern.validate(pgt_url) pattern = models.ServicePattern.validate(pgt_url)
if pattern.proxy: if pattern.proxy:
proxyid = models.gen_pgtiou() proxyid = utils.gen_pgtiou()
pticket = models.ProxyGrantingTicket.objects.create( pticket = models.ProxyGrantingTicket.objects.create(
user=ticket.user, user=ticket.user,
service=pgt_url, service=pgt_url,
@ -304,19 +312,14 @@ def ps_validate(request, ticket_type=None):
params, params,
content_type="text/xml; charset=utf-8" content_type="text/xml; charset=utf-8"
) )
except requests.exceptions.SSLError: except requests.exceptions.SSLError as error:
return render( error = utils.unpack_nested_exception(error)
request, return _validate_error(request, 'INVALID_PROXY_CALLBACK', str(error))
"cas_server/serviceValidateError.xml",
{'code':'INVALID_PROXY_CALLBACK'},
content_type="text/xml; charset=utf-8"
)
else: else:
return render( return _validate_error(
request, request,
"cas_server/serviceValidateError.xml", 'INVALID_PROXY_CALLBACK',
{'code':'INVALID_PROXY_CALLBACK'}, "callback url not allowed by configuration"
content_type="text/xml; charset=utf-8"
) )
else: else:
return render( return render(
@ -326,25 +329,18 @@ def ps_validate(request, ticket_type=None):
content_type="text/xml; charset=utf-8" content_type="text/xml; charset=utf-8"
) )
except (models.ServiceTicket.DoesNotExist, models.ProxyTicket.DoesNotExist): except (models.ServiceTicket.DoesNotExist, models.ProxyTicket.DoesNotExist):
return render( return _validate_error(request, 'INVALID_TICKET', 'ticket not found')
request,
"cas_server/serviceValidateError.xml",
{'code':'INVALID_TICKET'},
content_type="text/xml; charset=utf-8"
)
except models.ServicePattern.DoesNotExist: except models.ServicePattern.DoesNotExist:
return render( return _validate_error(
request, request,
"cas_server/serviceValidateError.xml", 'INVALID_PROXY_CALLBACK',
{'code':'INVALID_TICKET'}, 'callback url not allowed by configuration'
content_type="text/xml; charset=utf-8"
) )
else: else:
return render( return _validate_error(
request, request,
"cas_server/serviceValidateError.xml", 'INVALID_REQUEST',
{'code':'INVALID_REQUEST'}, "you must specify a service and a ticket"
content_type="text/xml; charset=utf-8"
) )
def service_validate(request): def service_validate(request):
@ -378,46 +374,20 @@ def proxy(request):
content_type="text/xml; charset=utf-8" content_type="text/xml; charset=utf-8"
) )
except models.ProxyGrantingTicket.DoesNotExist: except models.ProxyGrantingTicket.DoesNotExist:
return render( return _validate_error(request, 'INVALID_TICKET', 'PGT not found')
request,
"cas_server/serviceValidateError.xml",
{'code':'INVALID_TICKET'},
content_type="text/xml; charset=utf-8"
)
except models.ServicePattern.DoesNotExist: except models.ServicePattern.DoesNotExist:
return render( return _validate_error(request, 'UNAUTHORIZED_SERVICE')
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
return _validate_error(
request, request,
"cas_server/serviceValidateError.xml", 'UNAUTHORIZED_USER',
{'code':'INVALID_TICKET'}, '%s not allowed on %s' % (ticket.user, target_service)
content_type="text/xml; charset=utf-8"
)
except models.BadUsername:
return render(
request,
"cas_server/serviceValidateError.xml",
{'code':'INVALID_TICKET'},
content_type="text/xml; charset=utf-8"
)
except models.BadFilter:
return render(
request,
"cas_server/serviceValidateError.xml",
{'code':'INVALID_TICKET'},
content_type="text/xml; charset=utf-8"
)
except models.UserFieldNotDefined:
return render(
request,
"cas_server/serviceValidateError.xml",
{'code':'INVALID_TICKET'},
content_type="text/xml; charset=utf-8"
) )
else: else:
return render( return _validate_error(
request, request,
"cas_server/serviceValidateError.xml", 'INVALID_REQUEST',
{'code':'INVALID_REQUEST'}, "you must specify and pgt and targetService"
content_type="text/xml; charset=utf-8"
) )
def p3_service_validate(request): def p3_service_validate(request):
@ -428,6 +398,15 @@ def p3_proxy_validate(request):
"""service/proxy ticket validation CAS 3.0""" """service/proxy ticket validation CAS 3.0"""
return proxy_validate(request) return proxy_validate(request)
def _saml_validate_error(request, code, msg=""):
"""render the samlValidateError.xml templace using `code` and `msg`"""
return render(
request,
"cas_server/samlValidateError.xml",
{'code':code, 'msg':msg},
content_type="text/xml; charset=utf-8"
)
@csrf_exempt @csrf_exempt
def saml_validate(request): def saml_validate(request):
"""checks the validity of a Service Ticket by a SAML 1.1 request""" """checks the validity of a Service Ticket by a SAML 1.1 request"""
@ -439,14 +418,32 @@ def saml_validate(request):
issue_instant = auth_req.attrib['IssueInstant'] issue_instant = auth_req.attrib['IssueInstant']
request_id = auth_req.attrib['RequestID'] request_id = auth_req.attrib['RequestID']
ticket = auth_req.getchildren()[0].text ticket = auth_req.getchildren()[0].text
if ticket.startswith("ST"):
ticket = models.ServiceTicket.objects.get( ticket = models.ServiceTicket.objects.get(
value=ticket, value=ticket,
service=target,
validate=False, validate=False,
creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY)) creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
) )
elif ticket.startswith("PT"):
ticket = models.ProxyTicket.objects.get(
value=ticket,
validate=False,
creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
)
else:
return _saml_validate_error(
request,
'AuthnFailed',
'ticket should begin with PT- or ST-'
)
ticket.validate = True ticket.validate = True
ticket.save() ticket.save()
if ticket.service != target:
return _saml_validate_error(
request,
'AuthnFailed',
'TARGET do not match ticket service'
)
expire_instant = (ticket.creation + \ expire_instant = (ticket.creation + \
timedelta(seconds=settings.CAS_TICKET_VALIDITY)).isoformat() timedelta(seconds=settings.CAS_TICKET_VALIDITY)).isoformat()
attributes = [] attributes = []
@ -473,26 +470,13 @@ def saml_validate(request):
params, params,
content_type="text/xml; charset=utf-8" content_type="text/xml; charset=utf-8"
) )
except IndexError: except (IndexError, KeyError):
return render( return _saml_validate_error(request, 'VersionMismatch')
request, except (models.ServiceTicket.DoesNotExist, models.ProxyTicket.DoesNotExist):
"cas_server/samlValidateError.xml", return _saml_validate_error(request, 'AuthnFailed', 'ticket not found')
{'code':'VersionMismatch'},
content_type="text/xml; charset=utf-8"
)
except KeyError:
return render(
request,
"cas_server/samlValidateError.xml",
{'code':'VersionMismatch'},
content_type="text/xml; charset=utf-8"
)
except models.ServiceTicket.DoesNotExist:
return render(
request,
"cas_server/samlValidateError.xml",
{'code':'AuthnFailed'},
content_type="text/xml; charset=utf-8"
)
else: else:
return redirect("login") return _saml_validate_error(
request,
'VersionMismatch',
'request should be send using POST'
)