diff --git a/cas_server/migrations/0001_initial.py b/cas_server/migrations/0001_initial.py
index dcde945..6a1294a 100644
--- a/cas_server/migrations/0001_initial.py
+++ b/cas_server/migrations/0001_initial.py
@@ -32,7 +32,7 @@ class Migration(migrations.Migration):
('service', models.TextField()),
('creation', models.DateTimeField(auto_now_add=True)),
('renew', models.BooleanField(default=False)),
- ('value', models.CharField(default=cas_server.models._gen_pgt, unique=True, max_length=255)),
+ ('value', models.CharField(default=cas_server.utils.gen_pgt, unique=True, max_length=255)),
],
options={
'abstract': False,
@@ -48,7 +48,7 @@ class Migration(migrations.Migration):
('service', models.TextField()),
('creation', models.DateTimeField(auto_now_add=True)),
('renew', models.BooleanField(default=False)),
- ('value', models.CharField(default=cas_server.models._gen_pt, unique=True, max_length=255)),
+ ('value', models.CharField(default=cas_server.utils.gen_pt, unique=True, max_length=255)),
],
options={
'abstract': False,
@@ -81,7 +81,7 @@ class Migration(migrations.Migration):
('service', models.TextField()),
('creation', models.DateTimeField(auto_now_add=True)),
('renew', models.BooleanField(default=False)),
- ('value', models.CharField(default=cas_server.models._gen_st, unique=True, max_length=255)),
+ ('value', models.CharField(default=cas_server.utils.gen_st, unique=True, max_length=255)),
],
options={
'abstract': False,
diff --git a/cas_server/models.py b/cas_server/models.py
index 1b47d3c..7ae4ab7 100644
--- a/cas_server/models.py
+++ b/cas_server/models.py
@@ -10,9 +10,6 @@
#
# (c) 2015 Valentin Samir
"""models for the app"""
-from . import default_settings
-
-from django.conf import settings
from django.db import models
from django.contrib import messages
from picklefield.fields import PickledObjectField
@@ -21,41 +18,12 @@ from django.utils import timezone
import re
import os
-import random
-import string
from concurrent.futures import ThreadPoolExecutor
from requests_futures.sessions import FuturesSession
from . import utils
-def _gen_ticket(prefix):
- """Generate a ticket with prefix `prefix`"""
- return '%s-%s' % (
- prefix,
- ''.join(
- random.choice(
- string.ascii_letters + string.digits
- ) for _ in range(settings.CAS_ST_LEN)
- )
- )
-
-def _gen_st():
- """Generate a Service Ticket"""
- return _gen_ticket('ST')
-
-def _gen_pt():
- """Generate a Proxy Ticket"""
- return _gen_ticket('PT')
-
-def _gen_pgt():
- """Generate a Proxy Granting Ticket"""
- return _gen_ticket('PGT')
-
-def gen_pgtiou():
- """Generate a Proxy Granting Ticket IOU"""
- return _gen_ticket('PGTIOU')
-
class User(models.Model):
"""A user logged into the CAS"""
username = models.CharField(max_length=30, unique=True)
@@ -83,10 +51,11 @@ class User(models.Model):
try:
future.result()
except Exception as error:
+ error = utils.unpack_nested_exception(error)
messages.add_message(
request,
messages.WARNING,
- _(u'Error during service logout %r') % error
+ _(u'Error during service logout %s') % error
)
def get_ticket(self, ticket_class, service, service_pattern, renew):
@@ -333,6 +302,7 @@ class Ticket(models.Model):
headers=headers
)
except Exception as error:
+ error = utils.unpack_nested_exception(error)
messages.add_message(
request,
messages.WARNING,
@@ -342,17 +312,17 @@ class Ticket(models.Model):
class ServiceTicket(Ticket):
"""A Service Ticket"""
- value = models.CharField(max_length=255, default=_gen_st, unique=True)
+ value = models.CharField(max_length=255, default=utils.gen_st, unique=True)
def __unicode__(self):
return u"ServiceTicket(%s, %s, %s)" % (self.user, self.value, self.service)
class ProxyTicket(Ticket):
"""A Proxy Ticket"""
- value = models.CharField(max_length=255, default=_gen_pt, unique=True)
+ value = models.CharField(max_length=255, default=utils.gen_pt, unique=True)
def __unicode__(self):
return u"ProxyTicket(%s, %s, %s)" % (self.user, self.value, self.service)
class ProxyGrantingTicket(Ticket):
"""A Proxy Granting Ticket"""
- value = models.CharField(max_length=255, default=_gen_pgt, unique=True)
+ value = models.CharField(max_length=255, default=utils.gen_pgt, unique=True)
def __unicode__(self):
return u"ProxyGrantingTicket(%s, %s, %s)" % (self.user, self.value, self.service)
diff --git a/cas_server/templates/cas_server/samlValidateError.xml b/cas_server/templates/cas_server/samlValidateError.xml
index 062337b..b1b4226 100644
--- a/cas_server/templates/cas_server/samlValidateError.xml
+++ b/cas_server/templates/cas_server/samlValidateError.xml
@@ -7,8 +7,7 @@
MajorVersion="1" MinorVersion="1" Recipient="{{Recipient}}"
ResponseID="{{ResponseID}}">
-
-
+ {{msg}}
diff --git a/cas_server/templates/cas_server/serviceValidateError.xml b/cas_server/templates/cas_server/serviceValidateError.xml
index 56e03c9..cab8d9b 100644
--- a/cas_server/templates/cas_server/serviceValidateError.xml
+++ b/cas_server/templates/cas_server/serviceValidateError.xml
@@ -1,5 +1,3 @@
-
- {{msg}}
-
+ {{msg}}
diff --git a/cas_server/utils.py b/cas_server/utils.py
index ca1223d..3e62a0f 100644
--- a/cas_server/utils.py
+++ b/cas_server/utils.py
@@ -9,8 +9,14 @@
#
# (c) 2015 Valentin Samir
"""Some util function for the app"""
+from . import default_settings
+
+from django.conf import settings
+
import urlparse
import urllib
+import random
+import string
def update_url(url, params):
"""update params in the `url` query string"""
@@ -19,3 +25,46 @@ def update_url(url, params):
query.update(params)
url_parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(url_parts)
+
+def unpack_nested_exception(error):
+ """If exception are stacked, return the first one"""
+ i = 0
+ while True:
+ if error.args[i:]:
+ if isinstance(error.args[i], Exception):
+ error = error.args[i]
+ i = 0
+ else:
+ i += 1
+ else:
+ break
+ return error
+
+
+def _gen_ticket(prefix):
+ """Generate a ticket with prefix `prefix`"""
+ return '%s-%s' % (
+ prefix,
+ ''.join(
+ random.choice(
+ string.ascii_letters + string.digits
+ ) for _ in range(settings.CAS_ST_LEN)
+ )
+ )
+
+def gen_st():
+ """Generate a Service Ticket"""
+ return _gen_ticket('ST')
+
+def gen_pt():
+ """Generate a Proxy Ticket"""
+ return _gen_ticket('PT')
+
+def gen_pgt():
+ """Generate a Proxy Granting Ticket"""
+ return _gen_ticket('PGT')
+
+def gen_pgtiou():
+ """Generate a Proxy Granting Ticket IOU"""
+ return _gen_ticket('PGTIOU')
+
diff --git a/cas_server/views.py b/cas_server/views.py
index ef0570c..e00da96 100644
--- a/cas_server/views.py
+++ b/cas_server/views.py
@@ -229,6 +229,15 @@ def validate(request):
return HttpResponse("no\n", content_type="text/plain")
+def _validate_error(request, code, msg=""):
+ """render the serviceValidateError.xml template using `code` and `msg`"""
+ return render(
+ request,
+ "cas_server/serviceValidateError.xml",
+ {'code':code, 'msg':msg},
+ content_type="text/xml; charset=utf-8"
+ )
+
def ps_validate(request, ticket_type=None):
"""factorization for serviceValidate and proxyValidate"""
if ticket_type is None:
@@ -238,22 +247,20 @@ def ps_validate(request, ticket_type=None):
pgt_url = request.GET.get('pgtUrl')
renew = True if request.GET.get('renew') else False
if service and ticket:
- for typ in ticket_type:
- if ticket.startswith(typ):
+ for elt in ticket_type:
+ if ticket.startswith(elt):
break
else:
- return render(
+ return _validate_error(
request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_TICKET'},
- content_type="text/xml; charset=utf-8"
+ 'INVALID_TICKET',
+ 'tickets should begin with %s' % ' or '.join(ticket_type)
)
try:
proxies = []
if ticket.startswith("ST"):
ticket = models.ServiceTicket.objects.get(
value=ticket,
- service=service,
validate=False,
renew=renew,
creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
@@ -261,7 +268,6 @@ def ps_validate(request, ticket_type=None):
elif ticket.startswith("PT"):
ticket = models.ProxyTicket.objects.get(
value=ticket,
- service=service,
validate=False,
renew=renew,
creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
@@ -270,6 +276,8 @@ def ps_validate(request, ticket_type=None):
proxies.append(prox.url)
ticket.validate = True
ticket.save()
+ if ticket.service != service:
+ return _validate_error(request, 'INVALID_SERVICE')
attributes = []
for key, value in ticket.attributs.items():
if isinstance(value, list):
@@ -284,7 +292,7 @@ def ps_validate(request, ticket_type=None):
if pgt_url and pgt_url.startswith("https://"):
pattern = models.ServicePattern.validate(pgt_url)
if pattern.proxy:
- proxyid = models.gen_pgtiou()
+ proxyid = utils.gen_pgtiou()
pticket = models.ProxyGrantingTicket.objects.create(
user=ticket.user,
service=pgt_url,
@@ -304,19 +312,14 @@ def ps_validate(request, ticket_type=None):
params,
content_type="text/xml; charset=utf-8"
)
- except requests.exceptions.SSLError:
- return render(
- request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_PROXY_CALLBACK'},
- content_type="text/xml; charset=utf-8"
- )
+ except requests.exceptions.SSLError as error:
+ error = utils.unpack_nested_exception(error)
+ return _validate_error(request, 'INVALID_PROXY_CALLBACK', str(error))
else:
- return render(
+ return _validate_error(
request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_PROXY_CALLBACK'},
- content_type="text/xml; charset=utf-8"
+ 'INVALID_PROXY_CALLBACK',
+ "callback url not allowed by configuration"
)
else:
return render(
@@ -326,25 +329,18 @@ def ps_validate(request, ticket_type=None):
content_type="text/xml; charset=utf-8"
)
except (models.ServiceTicket.DoesNotExist, models.ProxyTicket.DoesNotExist):
- return render(
- request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_TICKET'},
- content_type="text/xml; charset=utf-8"
- )
+ return _validate_error(request, 'INVALID_TICKET', 'ticket not found')
except models.ServicePattern.DoesNotExist:
- return render(
+ return _validate_error(
request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_TICKET'},
- content_type="text/xml; charset=utf-8"
+ 'INVALID_PROXY_CALLBACK',
+ 'callback url not allowed by configuration'
)
else:
- return render(
+ return _validate_error(
request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_REQUEST'},
- content_type="text/xml; charset=utf-8"
+ 'INVALID_REQUEST',
+ "you must specify a service and a ticket"
)
def service_validate(request):
@@ -378,46 +374,20 @@ def proxy(request):
content_type="text/xml; charset=utf-8"
)
except models.ProxyGrantingTicket.DoesNotExist:
- return render(
- request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_TICKET'},
- content_type="text/xml; charset=utf-8"
- )
+ return _validate_error(request, 'INVALID_TICKET', 'PGT not found')
except models.ServicePattern.DoesNotExist:
- return render(
+ return _validate_error(request, 'UNAUTHORIZED_SERVICE')
+ except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
+ return _validate_error(
request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_TICKET'},
- content_type="text/xml; charset=utf-8"
- )
- except models.BadUsername:
- return render(
- request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_TICKET'},
- content_type="text/xml; charset=utf-8"
- )
- except models.BadFilter:
- return render(
- request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_TICKET'},
- content_type="text/xml; charset=utf-8"
- )
- except models.UserFieldNotDefined:
- return render(
- request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_TICKET'},
- content_type="text/xml; charset=utf-8"
+ 'UNAUTHORIZED_USER',
+ '%s not allowed on %s' % (ticket.user, target_service)
)
else:
- return render(
+ return _validate_error(
request,
- "cas_server/serviceValidateError.xml",
- {'code':'INVALID_REQUEST'},
- content_type="text/xml; charset=utf-8"
+ 'INVALID_REQUEST',
+ "you must specify and pgt and targetService"
)
def p3_service_validate(request):
@@ -428,6 +398,15 @@ def p3_proxy_validate(request):
"""service/proxy ticket validation CAS 3.0"""
return proxy_validate(request)
+def _saml_validate_error(request, code, msg=""):
+ """render the samlValidateError.xml templace using `code` and `msg`"""
+ return render(
+ request,
+ "cas_server/samlValidateError.xml",
+ {'code':code, 'msg':msg},
+ content_type="text/xml; charset=utf-8"
+ )
+
@csrf_exempt
def saml_validate(request):
"""checks the validity of a Service Ticket by a SAML 1.1 request"""
@@ -439,14 +418,32 @@ def saml_validate(request):
issue_instant = auth_req.attrib['IssueInstant']
request_id = auth_req.attrib['RequestID']
ticket = auth_req.getchildren()[0].text
- ticket = models.ServiceTicket.objects.get(
- value=ticket,
- service=target,
- validate=False,
- creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
- )
+ if ticket.startswith("ST"):
+ ticket = models.ServiceTicket.objects.get(
+ value=ticket,
+ validate=False,
+ creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
+ )
+ elif ticket.startswith("PT"):
+ ticket = models.ProxyTicket.objects.get(
+ value=ticket,
+ validate=False,
+ creation__gt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_VALIDITY))
+ )
+ else:
+ return _saml_validate_error(
+ request,
+ 'AuthnFailed',
+ 'ticket should begin with PT- or ST-'
+ )
ticket.validate = True
ticket.save()
+ if ticket.service != target:
+ return _saml_validate_error(
+ request,
+ 'AuthnFailed',
+ 'TARGET do not match ticket service'
+ )
expire_instant = (ticket.creation + \
timedelta(seconds=settings.CAS_TICKET_VALIDITY)).isoformat()
attributes = []
@@ -473,26 +470,13 @@ def saml_validate(request):
params,
content_type="text/xml; charset=utf-8"
)
- except IndexError:
- return render(
- request,
- "cas_server/samlValidateError.xml",
- {'code':'VersionMismatch'},
- content_type="text/xml; charset=utf-8"
- )
- except KeyError:
- return render(
- request,
- "cas_server/samlValidateError.xml",
- {'code':'VersionMismatch'},
- content_type="text/xml; charset=utf-8"
- )
- except models.ServiceTicket.DoesNotExist:
- return render(
- request,
- "cas_server/samlValidateError.xml",
- {'code':'AuthnFailed'},
- content_type="text/xml; charset=utf-8"
- )
+ except (IndexError, KeyError):
+ return _saml_validate_error(request, 'VersionMismatch')
+ except (models.ServiceTicket.DoesNotExist, models.ProxyTicket.DoesNotExist):
+ return _saml_validate_error(request, 'AuthnFailed', 'ticket not found')
else:
- return redirect("login")
+ return _saml_validate_error(
+ request,
+ 'VersionMismatch',
+ 'request should be send using POST'
+ )