From 871baaac87e14db92cc8e3b8a3f4c935202c13d0 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Thu, 28 May 2015 15:08:57 +0200 Subject: [PATCH] some code refactoring and better error handling on ticket validation --- cas_server/migrations/0001_initial.py | 6 +- cas_server/models.py | 42 +---- .../cas_server/samlValidateError.xml | 3 +- .../cas_server/serviceValidateError.xml | 4 +- cas_server/utils.py | 49 +++++ cas_server/views.py | 178 ++++++++---------- 6 files changed, 141 insertions(+), 141 deletions(-) diff --git a/cas_server/migrations/0001_initial.py b/cas_server/migrations/0001_initial.py index dcde945..6a1294a 100644 --- a/cas_server/migrations/0001_initial.py +++ b/cas_server/migrations/0001_initial.py @@ -32,7 +32,7 @@ class Migration(migrations.Migration): ('service', models.TextField()), ('creation', models.DateTimeField(auto_now_add=True)), ('renew', models.BooleanField(default=False)), - ('value', models.CharField(default=cas_server.models._gen_pgt, unique=True, max_length=255)), + ('value', models.CharField(default=cas_server.utils.gen_pgt, unique=True, max_length=255)), ], options={ 'abstract': False, @@ -48,7 +48,7 @@ class Migration(migrations.Migration): ('service', models.TextField()), ('creation', models.DateTimeField(auto_now_add=True)), ('renew', models.BooleanField(default=False)), - ('value', models.CharField(default=cas_server.models._gen_pt, unique=True, max_length=255)), + ('value', models.CharField(default=cas_server.utils.gen_pt, unique=True, max_length=255)), ], options={ 'abstract': False, @@ -81,7 +81,7 @@ class Migration(migrations.Migration): ('service', models.TextField()), ('creation', models.DateTimeField(auto_now_add=True)), ('renew', models.BooleanField(default=False)), - ('value', models.CharField(default=cas_server.models._gen_st, unique=True, max_length=255)), + ('value', models.CharField(default=cas_server.utils.gen_st, unique=True, max_length=255)), ], options={ 'abstract': False, diff --git a/cas_server/models.py b/cas_server/models.py index 1b47d3c..7ae4ab7 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -10,9 +10,6 @@ # # (c) 2015 Valentin Samir """models for the app""" -from . import default_settings - -from django.conf import settings from django.db import models from django.contrib import messages from picklefield.fields import PickledObjectField @@ -21,41 +18,12 @@ from django.utils import timezone import re import os -import random -import string from concurrent.futures import ThreadPoolExecutor from requests_futures.sessions import FuturesSession from . import utils -def _gen_ticket(prefix): - """Generate a ticket with prefix `prefix`""" - return '%s-%s' % ( - prefix, - ''.join( - random.choice( - string.ascii_letters + string.digits - ) for _ in range(settings.CAS_ST_LEN) - ) - ) - -def _gen_st(): - """Generate a Service Ticket""" - return _gen_ticket('ST') - -def _gen_pt(): - """Generate a Proxy Ticket""" - return _gen_ticket('PT') - -def _gen_pgt(): - """Generate a Proxy Granting Ticket""" - return _gen_ticket('PGT') - -def gen_pgtiou(): - """Generate a Proxy Granting Ticket IOU""" - return _gen_ticket('PGTIOU') - class User(models.Model): """A user logged into the CAS""" username = models.CharField(max_length=30, unique=True) @@ -83,10 +51,11 @@ class User(models.Model): try: future.result() except Exception as error: + error = utils.unpack_nested_exception(error) messages.add_message( request, messages.WARNING, - _(u'Error during service logout %r') % error + _(u'Error during service logout %s') % error ) def get_ticket(self, ticket_class, service, service_pattern, renew): @@ -333,6 +302,7 @@ class Ticket(models.Model): headers=headers ) except Exception as error: + error = utils.unpack_nested_exception(error) messages.add_message( request, messages.WARNING, @@ -342,17 +312,17 @@ class Ticket(models.Model): class ServiceTicket(Ticket): """A Service Ticket""" - value = models.CharField(max_length=255, default=_gen_st, unique=True) + value = models.CharField(max_length=255, default=utils.gen_st, unique=True) def __unicode__(self): return u"ServiceTicket(%s, %s, %s)" % (self.user, self.value, self.service) class ProxyTicket(Ticket): """A Proxy Ticket""" - value = models.CharField(max_length=255, default=_gen_pt, unique=True) + value = models.CharField(max_length=255, default=utils.gen_pt, unique=True) def __unicode__(self): return u"ProxyTicket(%s, %s, %s)" % (self.user, self.value, self.service) class ProxyGrantingTicket(Ticket): """A Proxy Granting Ticket""" - value = models.CharField(max_length=255, default=_gen_pgt, unique=True) + value = models.CharField(max_length=255, default=utils.gen_pgt, unique=True) def __unicode__(self): return u"ProxyGrantingTicket(%s, %s, %s)" % (self.user, self.value, self.service) diff --git a/cas_server/templates/cas_server/samlValidateError.xml b/cas_server/templates/cas_server/samlValidateError.xml index 062337b..b1b4226 100644 --- a/cas_server/templates/cas_server/samlValidateError.xml +++ b/cas_server/templates/cas_server/samlValidateError.xml @@ -7,8 +7,7 @@ MajorVersion="1" MinorVersion="1" Recipient="{{Recipient}}" ResponseID="{{ResponseID}}"> - - + {{msg}} diff --git a/cas_server/templates/cas_server/serviceValidateError.xml b/cas_server/templates/cas_server/serviceValidateError.xml index 56e03c9..cab8d9b 100644 --- a/cas_server/templates/cas_server/serviceValidateError.xml +++ b/cas_server/templates/cas_server/serviceValidateError.xml @@ -1,5 +1,3 @@ - - {{msg}} - + {{msg}} diff --git a/cas_server/utils.py b/cas_server/utils.py index ca1223d..3e62a0f 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -9,8 +9,14 @@ # # (c) 2015 Valentin Samir """Some util function for the app""" +from . import default_settings + +from django.conf import settings + import urlparse import urllib +import random +import string def update_url(url, params): """update params in the `url` query string""" @@ -19,3 +25,46 @@ def update_url(url, params): query.update(params) url_parts[4] = urllib.urlencode(query) return urlparse.urlunparse(url_parts) + +def unpack_nested_exception(error): + """If exception are stacked, return the first one""" + i = 0 + while True: + if error.args[i:]: + if isinstance(error.args[i], Exception): + error = error.args[i] + i = 0 + else: + i += 1 + else: + break + return error + + +def _gen_ticket(prefix): + """Generate a ticket with prefix `prefix`""" + return '%s-%s' % ( + prefix, + ''.join( + random.choice( + string.ascii_letters + string.digits + ) for _ in range(settings.CAS_ST_LEN) + ) + ) + +def gen_st(): + """Generate a Service Ticket""" + return _gen_ticket('ST') + +def gen_pt(): + """Generate a Proxy Ticket""" + return _gen_ticket('PT') + +def gen_pgt(): + """Generate a Proxy Granting Ticket""" + return _gen_ticket('PGT') + +def gen_pgtiou(): + """Generate a Proxy Granting Ticket IOU""" + return _gen_ticket('PGTIOU') + diff --git a/cas_server/views.py b/cas_server/views.py index ef0570c..e00da96 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -229,6 +229,15 @@ def validate(request): return HttpResponse("no\n", content_type="text/plain") +def _validate_error(request, code, msg=""): + """render the serviceValidateError.xml template using `code` and `msg`""" + return render( + request, + "cas_server/serviceValidateError.xml", + {'code':code, 'msg':msg}, + content_type="text/xml; charset=utf-8" + ) + def ps_validate(request, ticket_type=None): """factorization for serviceValidate and proxyValidate""" if ticket_type is None: @@ -238,22 +247,20 @@ def ps_validate(request, ticket_type=None): pgt_url = request.GET.get('pgtUrl') renew = True if request.GET.get('renew') else False if service and ticket: - for typ in ticket_type: - if ticket.startswith(typ): + for elt in ticket_type: + if ticket.startswith(elt): break else: - return render( + return _validate_error( request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_TICKET'}, - content_type="text/xml; charset=utf-8" + 'INVALID_TICKET', + 'tickets should begin with %s' % ' or '.join(ticket_type) ) try: proxies = [] if ticket.startswith("ST"): ticket = models.ServiceTicket.objects.get( value=ticket, - service=service, validate=False, renew=renew, creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY)) @@ -261,7 +268,6 @@ def ps_validate(request, ticket_type=None): elif ticket.startswith("PT"): ticket = models.ProxyTicket.objects.get( value=ticket, - service=service, validate=False, renew=renew, creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY)) @@ -270,6 +276,8 @@ def ps_validate(request, ticket_type=None): proxies.append(prox.url) ticket.validate = True ticket.save() + if ticket.service != service: + return _validate_error(request, 'INVALID_SERVICE') attributes = [] for key, value in ticket.attributs.items(): if isinstance(value, list): @@ -284,7 +292,7 @@ def ps_validate(request, ticket_type=None): if pgt_url and pgt_url.startswith("https://"): pattern = models.ServicePattern.validate(pgt_url) if pattern.proxy: - proxyid = models.gen_pgtiou() + proxyid = utils.gen_pgtiou() pticket = models.ProxyGrantingTicket.objects.create( user=ticket.user, service=pgt_url, @@ -304,19 +312,14 @@ def ps_validate(request, ticket_type=None): params, content_type="text/xml; charset=utf-8" ) - except requests.exceptions.SSLError: - return render( - request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_PROXY_CALLBACK'}, - content_type="text/xml; charset=utf-8" - ) + except requests.exceptions.SSLError as error: + error = utils.unpack_nested_exception(error) + return _validate_error(request, 'INVALID_PROXY_CALLBACK', str(error)) else: - return render( + return _validate_error( request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_PROXY_CALLBACK'}, - content_type="text/xml; charset=utf-8" + 'INVALID_PROXY_CALLBACK', + "callback url not allowed by configuration" ) else: return render( @@ -326,25 +329,18 @@ def ps_validate(request, ticket_type=None): content_type="text/xml; charset=utf-8" ) except (models.ServiceTicket.DoesNotExist, models.ProxyTicket.DoesNotExist): - return render( - request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_TICKET'}, - content_type="text/xml; charset=utf-8" - ) + return _validate_error(request, 'INVALID_TICKET', 'ticket not found') except models.ServicePattern.DoesNotExist: - return render( + return _validate_error( request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_TICKET'}, - content_type="text/xml; charset=utf-8" + 'INVALID_PROXY_CALLBACK', + 'callback url not allowed by configuration' ) else: - return render( + return _validate_error( request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_REQUEST'}, - content_type="text/xml; charset=utf-8" + 'INVALID_REQUEST', + "you must specify a service and a ticket" ) def service_validate(request): @@ -378,46 +374,20 @@ def proxy(request): content_type="text/xml; charset=utf-8" ) except models.ProxyGrantingTicket.DoesNotExist: - return render( - request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_TICKET'}, - content_type="text/xml; charset=utf-8" - ) + return _validate_error(request, 'INVALID_TICKET', 'PGT not found') except models.ServicePattern.DoesNotExist: - return render( + return _validate_error(request, 'UNAUTHORIZED_SERVICE') + except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined): + return _validate_error( request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_TICKET'}, - content_type="text/xml; charset=utf-8" - ) - except models.BadUsername: - return render( - request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_TICKET'}, - content_type="text/xml; charset=utf-8" - ) - except models.BadFilter: - return render( - request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_TICKET'}, - content_type="text/xml; charset=utf-8" - ) - except models.UserFieldNotDefined: - return render( - request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_TICKET'}, - content_type="text/xml; charset=utf-8" + 'UNAUTHORIZED_USER', + '%s not allowed on %s' % (ticket.user, target_service) ) else: - return render( + return _validate_error( request, - "cas_server/serviceValidateError.xml", - {'code':'INVALID_REQUEST'}, - content_type="text/xml; charset=utf-8" + 'INVALID_REQUEST', + "you must specify and pgt and targetService" ) def p3_service_validate(request): @@ -428,6 +398,15 @@ def p3_proxy_validate(request): """service/proxy ticket validation CAS 3.0""" return proxy_validate(request) +def _saml_validate_error(request, code, msg=""): + """render the samlValidateError.xml templace using `code` and `msg`""" + return render( + request, + "cas_server/samlValidateError.xml", + {'code':code, 'msg':msg}, + content_type="text/xml; charset=utf-8" + ) + @csrf_exempt def saml_validate(request): """checks the validity of a Service Ticket by a SAML 1.1 request""" @@ -439,14 +418,32 @@ def saml_validate(request): issue_instant = auth_req.attrib['IssueInstant'] request_id = auth_req.attrib['RequestID'] ticket = auth_req.getchildren()[0].text - ticket = models.ServiceTicket.objects.get( - value=ticket, - service=target, - validate=False, - creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY)) - ) + if ticket.startswith("ST"): + ticket = models.ServiceTicket.objects.get( + value=ticket, + validate=False, + creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY)) + ) + elif ticket.startswith("PT"): + ticket = models.ProxyTicket.objects.get( + value=ticket, + validate=False, + creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY)) + ) + else: + return _saml_validate_error( + request, + 'AuthnFailed', + 'ticket should begin with PT- or ST-' + ) ticket.validate = True ticket.save() + if ticket.service != target: + return _saml_validate_error( + request, + 'AuthnFailed', + 'TARGET do not match ticket service' + ) expire_instant = (ticket.creation + \ timedelta(seconds=settings.CAS_TICKET_VALIDITY)).isoformat() attributes = [] @@ -473,26 +470,13 @@ def saml_validate(request): params, content_type="text/xml; charset=utf-8" ) - except IndexError: - return render( - request, - "cas_server/samlValidateError.xml", - {'code':'VersionMismatch'}, - content_type="text/xml; charset=utf-8" - ) - except KeyError: - return render( - request, - "cas_server/samlValidateError.xml", - {'code':'VersionMismatch'}, - content_type="text/xml; charset=utf-8" - ) - except models.ServiceTicket.DoesNotExist: - return render( - request, - "cas_server/samlValidateError.xml", - {'code':'AuthnFailed'}, - content_type="text/xml; charset=utf-8" - ) + except (IndexError, KeyError): + return _saml_validate_error(request, 'VersionMismatch') + except (models.ServiceTicket.DoesNotExist, models.ProxyTicket.DoesNotExist): + return _saml_validate_error(request, 'AuthnFailed', 'ticket not found') else: - return redirect("login") + return _saml_validate_error( + request, + 'VersionMismatch', + 'request should be send using POST' + )