diff --git a/.coveragerc b/.coveragerc
index f11c9de..8f6e752 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,3 +1,11 @@
+[run]
+branch = True
+source = cas_server
+omit =
+ cas_server/migrations*
+ cas_server/management/*
+ cas_server/tests/*
+
[report]
exclude_lines =
pragma: no cover
@@ -5,3 +13,4 @@ exclude_lines =
def __unicode__
raise AssertionError
raise NotImplementedError
+ if six.PY3:
diff --git a/.travis.yml b/.travis.yml
index b0bdd46..943f5b5 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,19 +2,19 @@ language: python
python:
- "2.7"
env:
- global:
- - PIP_DOWNLOAD_CACHE=$HOME/.pip_cache
matrix:
+ - TOX_ENV=coverage
+ - TOX_ENV=flake8
- TOX_ENV=py27-django17
- TOX_ENV=py27-django18
- TOX_ENV=py27-django19
- TOX_ENV=py34-django17
- TOX_ENV=py34-django18
- TOX_ENV=py34-django19
- - TOX_ENV=flake8
cache:
directories:
- - $HOME/.pip-cache/
+ - $HOME/.cache/pip/
+ - $HOME/build/nitmir/django-cas-server/.tox/
install:
- "travis_retry pip install setuptools --upgrade"
- "pip install tox"
@@ -22,4 +22,3 @@ script:
- tox -e $TOX_ENV
after_script:
- cat .tox/$TOX_ENV/log/*.log
-
diff --git a/Makefile b/Makefile
index 9088fba..d0f9165 100644
--- a/Makefile
+++ b/Makefile
@@ -1,11 +1,15 @@
-.PHONY: clean build install dist test_venv test_project
+.PHONY: build dist
VERSION=`python setup.py -V`
build:
python setup.py build
-install:
- python setup.py install
+install: dist
+ pip -V
+ pip install --no-deps --upgrade --force-reinstall --find-links ./dist/django-cas-server-${VERSION}.tar.gz django-cas-server
+
+uninstall:
+ pip uninstall django-cas-server || true
clean_pyc:
find ./ -name '*.pyc' -delete
@@ -16,18 +20,23 @@ clean_tox:
rm -rf .tox
clean_test_venv:
rm -rf test_venv
-clean: clean_pyc clean_build
-clean_all: clean_pyc clean_build clean_tox clean_test_venv
+clean_coverage:
+ rm -rf coverage.xml .coverage htmlcov
+clean_tild_backup:
+ find ./ -name '*~' -delete
+
+clean: clean_pyc clean_build clean_coverage clean_tild_backup
+
+clean_all: clean clean_tox clean_test_venv
dist:
python setup.py sdist
-test_venv:
- mkdir -p test_venv
+test_venv/bin/python:
virtualenv test_venv
- test_venv/bin/pip install -U --requirement requirements.txt
+ test_venv/bin/pip install -U --requirement requirements-dev.txt Django
-test_venv/cas/manage.py:
+test_venv/cas/manage.py: test_venv
mkdir -p test_venv/cas
test_venv/bin/django-admin startproject cas test_venv/cas
ln -s ../../cas_server test_venv/cas/cas_server
@@ -38,19 +47,15 @@ test_venv/cas/manage.py:
test_venv/bin/python test_venv/cas/manage.py migrate
test_venv/bin/python test_venv/cas/manage.py createsuperuser
-test_project: test_venv test_venv/cas/manage.py
+test_venv: test_venv/bin/python
+
+test_project: test_venv/cas/manage.py
@echo "##############################################################"
@echo "A test django project was created in $(realpath test_venv/cas)"
-run_test_server: test_project
+run_server: test_project
test_venv/bin/python test_venv/cas/manage.py runserver
-coverage: test_venv
- test_venv/bin/pip install coverage
- test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
- test_venv/bin/coverage html
- test_venv/bin/coverage xml
-
-coverage_codacy: coverage
- test_venv/bin/pip install codacy-coverage
- test_venv/bin/python-codacy-coverage -r coverage.xml
+run_tests: test_venv
+ test_venv/bin/py.test --cov=cas_server --cov-report html
+ rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts
diff --git a/README.rst b/README.rst
index 29b8057..bb148a3 100644
--- a/README.rst
+++ b/README.rst
@@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
- ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``.
+ ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
+ 'alias': ['demo1', 'demo2']}``.
Authentication backend
diff --git a/cas_server/__init__.py b/cas_server/__init__.py
index f830740..29f5de6 100644
--- a/cas_server/__init__.py
+++ b/cas_server/__init__.py
@@ -7,6 +7,6 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
-# (c) 2015 Valentin Samir
-
+# (c) 2015-2016 Valentin Samir
+"""A django CAS server application"""
default_app_config = 'cas_server.apps.CasAppConfig'
diff --git a/cas_server/admin.py b/cas_server/admin.py
index a6a9be4..472e1df 100644
--- a/cas_server/admin.py
+++ b/cas_server/admin.py
@@ -7,7 +7,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
-# (c) 2015 Valentin Samir
+# (c) 2015-2016 Valentin Samir
"""module for the admin interface of the app"""
from django.contrib import admin
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern
diff --git a/cas_server/apps.py b/cas_server/apps.py
index c34b6eb..ea15273 100644
--- a/cas_server/apps.py
+++ b/cas_server/apps.py
@@ -1,7 +1,19 @@
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
+# more details.
+#
+# You should have received a copy of the GNU General Public License version 3
+# along with this program; if not, write to the Free Software Foundation, Inc., 51
+# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+# (c) 2015-2016 Valentin Samir
+"""django config module"""
from django.utils.translation import ugettext_lazy as _
from django.apps import AppConfig
class CasAppConfig(AppConfig):
+ """django CAS application config class"""
name = 'cas_server'
verbose_name = _('Central Authentication Service')
diff --git a/cas_server/auth.py b/cas_server/auth.py
index f84fb11..2826a85 100644
--- a/cas_server/auth.py
+++ b/cas_server/auth.py
@@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
-# (c) 2015 Valentin Samir
+# (c) 2015-2016 Valentin Samir
"""Some authentication classes for the CAS"""
from django.conf import settings
from django.contrib.auth import get_user_model
@@ -21,6 +21,7 @@ except ImportError:
class AuthUser(object):
+ """Authentication base class"""
def __init__(self, username):
self.username = username
diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py
index 2824991..1d2174c 100644
--- a/cas_server/default_settings.py
+++ b/cas_server/default_settings.py
@@ -78,5 +78,12 @@ setting_default('CAS_TEST_USER', 'test')
setting_default('CAS_TEST_PASSWORD', 'test')
setting_default(
'CAS_TEST_ATTRIBUTES',
- {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}
+ {
+ 'nom': 'Nymous',
+ 'prenom': 'Ano',
+ 'email': 'anonymous@example.net',
+ 'alias': ['demo1', 'demo2']
+ }
)
+
+setting_default('CAS_ENABLE_AJAX_AUTH', False)
diff --git a/cas_server/forms.py b/cas_server/forms.py
index f970ccd..83cfe8a 100644
--- a/cas_server/forms.py
+++ b/cas_server/forms.py
@@ -7,7 +7,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
-# (c) 2015 Valentin Samir
+# (c) 2015-2016 Valentin Samir
"""forms for the app"""
from .default_settings import settings
@@ -19,6 +19,7 @@ import cas_server.models as models
class WarnForm(forms.Form):
+ """Form used on warn page before emiting a ticket"""
service = forms.CharField(widget=forms.HiddenInput(), required=False)
renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
gateway = forms.CharField(widget=forms.HiddenInput(), required=False)
@@ -35,6 +36,7 @@ class UserCredential(forms.Form):
lt = forms.CharField(widget=forms.HiddenInput(), required=False)
method = forms.CharField(widget=forms.HiddenInput(), required=False)
warn = forms.BooleanField(label=_('warn'), required=False)
+ renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
def __init__(self, *args, **kwargs):
super(UserCredential, self).__init__(*args, **kwargs)
@@ -46,6 +48,7 @@ class UserCredential(forms.Form):
cleaned_data["username"] = auth.username
else:
raise forms.ValidationError(_(u"Bad user"))
+ return cleaned_data
class TicketForm(forms.ModelForm):
diff --git a/cas_server/management/commands/cas_clean_sessions.py b/cas_server/management/commands/cas_clean_sessions.py
index da60c1a..3d32090 100644
--- a/cas_server/management/commands/cas_clean_sessions.py
+++ b/cas_server/management/commands/cas_clean_sessions.py
@@ -1,3 +1,4 @@
+"""Clean deleted sessions management command"""
from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _
@@ -5,6 +6,7 @@ from ... import models
class Command(BaseCommand):
+ """Clean deleted sessions"""
args = ''
help = _(u"Clean deleted sessions")
diff --git a/cas_server/management/commands/cas_clean_tickets.py b/cas_server/management/commands/cas_clean_tickets.py
index d18a7d4..dfbd4ec 100644
--- a/cas_server/management/commands/cas_clean_tickets.py
+++ b/cas_server/management/commands/cas_clean_tickets.py
@@ -1,3 +1,4 @@
+"""Clean old trickets management command"""
from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _
@@ -5,6 +6,7 @@ from ... import models
class Command(BaseCommand):
+ """Clean old trickets"""
args = ''
help = _(u"Clean old trickets")
diff --git a/cas_server/models.py b/cas_server/models.py
index 9cb0ac5..d870a50 100644
--- a/cas_server/models.py
+++ b/cas_server/models.py
@@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
-# (c) 2015 Valentin Samir
+# (c) 2015-2016 Valentin Samir
"""models for the app"""
from .default_settings import settings
@@ -20,7 +20,6 @@ from django.utils import timezone
from picklefield.fields import PickledObjectField
import re
-import os
import sys
import logging
from importlib import import_module
@@ -47,6 +46,7 @@ class User(models.Model):
@classmethod
def clean_old_entries(cls):
+ """Remove users inactive since more that SESSION_COOKIE_AGE"""
users = cls.objects.filter(
date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE))
)
@@ -56,6 +56,7 @@ class User(models.Model):
@classmethod
def clean_deleted_sessions(cls):
+ """Remove user where the session do not exists anymore"""
for user in cls.objects.all():
if not SessionStore(session_key=user.session_key).get('authenticated'):
user.logout()
@@ -80,10 +81,10 @@ class User(models.Model):
for ticket_class in ticket_classes:
queryset = ticket_class.objects.filter(user=self)
for ticket in queryset:
- ticket.logout(request, session, async_list)
+ ticket.logout(session, async_list)
queryset.delete()
for future in async_list:
- if future:
+ if future: # pragma: no branch (should always be true)
try:
future.result()
except Exception as error:
@@ -111,13 +112,21 @@ class User(models.Model):
(a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
)
replacements = dict(
- (a.name, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
+ (a.attribut, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
)
service_attributs = {}
for (key, value) in self.attributs.items():
if key in attributs or '*' in attributs:
if key in replacements:
- value = re.sub(replacements[key][0], replacements[key][1], value)
+ if isinstance(value, list):
+ for index, subval in enumerate(value):
+ value[index] = re.sub(
+ replacements[key][0],
+ replacements[key][1],
+ subval
+ )
+ else:
+ value = re.sub(replacements[key][0], replacements[key][1], value)
service_attributs[attributs.get(key, key)] = value
ticket = ticket_class.objects.create(
user=self,
@@ -141,6 +150,7 @@ class User(models.Model):
class ServicePatternException(Exception):
+ """Base exception of exceptions raised in the ServicePattern model"""
pass
@@ -394,77 +404,57 @@ class Ticket(models.Model):
).delete()
# sending SLO to timed-out validated tickets
- if cls.TIMEOUT and cls.TIMEOUT > 0:
- async_list = []
- session = FuturesSession(
- executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
- )
- queryset = cls.objects.filter(
- creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
- )
- for ticket in queryset:
- ticket.logout(None, session, async_list)
- queryset.delete()
- for future in async_list:
- if future:
- try:
- future.result()
- except Exception as error:
- logger.warning("Error durring SLO %s" % error)
- sys.stderr.write("%r\n" % error)
+ async_list = []
+ session = FuturesSession(
+ executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
+ )
+ queryset = cls.objects.filter(
+ creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
+ )
+ for ticket in queryset:
+ ticket.logout(session, async_list)
+ queryset.delete()
+ for future in async_list:
+ if future: # pragma: no branch (should always be true)
+ try:
+ future.result()
+ except Exception as error:
+ logger.warning("Error durring SLO %s" % error)
+ sys.stderr.write("%r\n" % error)
- def logout(self, request, session, async_list=None):
+ def logout(self, session, async_list=None):
"""Send a SLO request to the ticket service"""
# On logout invalidate the Ticket
self.validate = True
self.save()
- if self.validate and self.single_log_out:
+ if self.validate and self.single_log_out: # pragma: no branch (should always be true)
logger.info(
"Sending SLO requests to service %s for user %s" % (
self.service,
self.user.username
)
)
- try:
- xml = u"""
-
- %(ticket)s
- """ % \
- {
- 'id': os.urandom(20).encode("hex"),
- 'datetime': timezone.now().isoformat(),
- 'ticket': self.value
- }
- if self.service_pattern.single_log_out_callback:
- url = self.service_pattern.single_log_out_callback
- else:
- url = self.service
- async_list.append(
- session.post(
- url.encode('utf-8'),
- data={'logoutRequest': xml.encode('utf-8')},
- timeout=settings.CAS_SLO_TIMEOUT
- )
+ xml = u"""
+
+%(ticket)s
+""" % \
+ {
+ 'id': utils.gen_saml_id(),
+ 'datetime': timezone.now().isoformat(),
+ 'ticket': self.value
+ }
+ if self.service_pattern.single_log_out_callback:
+ url = self.service_pattern.single_log_out_callback
+ else:
+ url = self.service
+ async_list.append(
+ session.post(
+ url.encode('utf-8'),
+ data={'logoutRequest': xml.encode('utf-8')},
+ timeout=settings.CAS_SLO_TIMEOUT
)
- except Exception as error:
- error = utils.unpack_nested_exception(error)
- logger.warning(
- "Error durring SLO for user %s on service %s: %s" % (
- self.user.username,
- self.service,
- error
- )
- )
- if request is not None:
- messages.add_message(
- request,
- messages.WARNING,
- _(u'Error during service logout %(service)s:\n%(error)s') %
- {'service': self.service, 'error': error}
- )
- else:
- sys.stderr.write("%r\n" % error)
+ )
class ServiceTicket(Ticket):
diff --git a/cas_server/tests.py b/cas_server/tests.py
deleted file mode 100644
index 7d355cb..0000000
--- a/cas_server/tests.py
+++ /dev/null
@@ -1,702 +0,0 @@
-from .default_settings import settings
-
-from django.test import TestCase
-from django.test import Client
-
-import six
-from lxml import etree
-
-from cas_server import models
-from cas_server import utils
-
-
-def get_login_page_params():
- client = Client()
- response = client.get('/login')
- form = response.context["form"]
- params = {}
- for field in form:
- if field.value():
- params[field.name] = field.value()
- else:
- params[field.name] = ""
- return client, params
-
-
-def get_auth_client():
- client, params = get_login_page_params()
- params["username"] = settings.CAS_TEST_USER
- params["password"] = settings.CAS_TEST_PASSWORD
-
- client.post('/login', params)
- return client
-
-
-def get_user_ticket_request(service):
- client = get_auth_client()
- response = client.get("/login", {"service": service})
- ticket_value = response['Location'].split('ticket=')[-1]
- user = models.User.objects.get(
- username=settings.CAS_TEST_USER,
- session_key=client.session.session_key
- )
- ticket = models.ServiceTicket.objects.get(value=ticket_value)
- return (user, ticket)
-
-
-def get_pgt():
- (host, port) = utils.PGTUrlHandler.run()[1:3]
- service = "http://%s:%s" % (host, port)
-
- (user, ticket) = get_user_ticket_request(service)
-
- client = Client()
- client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
- params = utils.PGTUrlHandler.PARAMS.copy()
-
- params["service"] = service
- params["user"] = user
-
- return params
-
-
-class CheckPasswordCase(TestCase):
- """Tests for the utils function `utils.check_password`"""
-
- def setUp(self):
- """Generate random bytes string that will be used ass passwords"""
- self.password1 = utils.gen_saml_id()
- self.password2 = utils.gen_saml_id()
- if not isinstance(self.password1, bytes):
- self.password1 = self.password1.encode("utf8")
- self.password2 = self.password2.encode("utf8")
-
- def test_setup(self):
- """check that generated password are bytes"""
- self.assertIsInstance(self.password1, bytes)
- self.assertIsInstance(self.password2, bytes)
-
- def test_plain(self):
- """test the plain auth method"""
- self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
- self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
-
- def test_crypt(self):
- """test the crypt auth method"""
- if six.PY3:
- hashed_password1 = utils.crypt.crypt(
- self.password1.decode("utf8"),
- "$6$UVVAQvrMyXMF3FF3"
- ).encode("utf8")
- else:
- hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
-
- self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
- self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
-
- def test_ldap_ssha(self):
- """test the ldap auth method with a {SSHA} scheme"""
- salt = b"UVVAQvrMyXMF3FF3"
- hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
-
- self.assertIsInstance(hashed_password1, bytes)
- self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
- self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
-
- def test_hex_md5(self):
- """test the hex_md5 auth method"""
- hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
-
- self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
- self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
-
- def test_hox_sha512(self):
- """test the hex_sha512 auth method"""
- hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
-
- self.assertTrue(
- utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
- )
- self.assertFalse(
- utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
- )
-
-
-class LoginTestCase(TestCase):
-
- def setUp(self):
- settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
- self.service_pattern = models.ServicePattern.objects.create(
- name="example",
- pattern="^https://www\.example\.com(/.*)?$",
- )
- models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
-
- def test_login_view_post_goodpass_goodlt(self):
- client, params = get_login_page_params()
- params["username"] = settings.CAS_TEST_USER
- params["password"] = settings.CAS_TEST_PASSWORD
-
- response = client.post('/login', params)
-
- self.assertEqual(response.status_code, 200)
- self.assertTrue(
- (
- b"You have successfully logged into "
- b"the Central Authentication Service"
- ) in response.content
- )
-
- self.assertTrue(
- models.User.objects.get(
- username=settings.CAS_TEST_USER,
- session_key=client.session.session_key
- )
- )
-
- def test_login_view_post_badlt(self):
- client, params = get_login_page_params()
- params["username"] = settings.CAS_TEST_USER
- params["password"] = settings.CAS_TEST_PASSWORD
- params["lt"] = 'LT-random'
-
- response = client.post('/login', params)
-
- self.assertEqual(response.status_code, 200)
- self.assertTrue(b"Invalid login ticket" in response.content)
- self.assertFalse(
- (
- b"You have successfully logged into "
- b"the Central Authentication Service"
- ) in response.content
- )
-
- def test_login_view_post_badpass_good_lt(self):
- client, params = get_login_page_params()
- params["username"] = settings.CAS_TEST_USER
- params["password"] = "test2"
- response = client.post('/login', params)
-
- self.assertEqual(response.status_code, 200)
- self.assertTrue(
- (
- b"The credentials you provided cannot be "
- b"determined to be authentic"
- ) in response.content
- )
- self.assertFalse(
- (
- b"You have successfully logged into "
- b"the Central Authentication Service"
- ) in response.content
- )
-
- def test_view_login_get_auth_allowed_service(self):
- client = get_auth_client()
- response = client.get("/login?service=https://www.example.com")
- self.assertEqual(response.status_code, 302)
- self.assertTrue(response.has_header('Location'))
- self.assertTrue(
- response['Location'].startswith(
- "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
- )
- )
-
- ticket_value = response['Location'].split('ticket=')[-1]
- user = models.User.objects.get(
- username=settings.CAS_TEST_USER,
- session_key=client.session.session_key
- )
- self.assertTrue(user)
- ticket = models.ServiceTicket.objects.get(value=ticket_value)
- self.assertEqual(ticket.user, user)
- self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
- self.assertEqual(ticket.validate, False)
- self.assertEqual(ticket.service_pattern, self.service_pattern)
-
- def test_view_login_get_auth_denied_service(self):
- client = get_auth_client()
- response = client.get("/login?service=https://www.example.org")
- self.assertEqual(response.status_code, 200)
- self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
-
-
-class LogoutTestCase(TestCase):
-
- def setUp(self):
- settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
-
- def test_logout_view(self):
- client = get_auth_client()
-
- response = client.get("/login")
- self.assertEqual(response.status_code, 200)
- self.assertTrue(
- (
- b"You have successfully logged into "
- b"the Central Authentication Service"
- ) in response.content
- )
-
- response = client.get("/logout")
- self.assertEqual(response.status_code, 200)
- self.assertTrue(
- (
- b"You have successfully logged out from "
- b"the Central Authentication Service"
- ) in response.content
- )
-
- response = client.get("/login")
- self.assertEqual(response.status_code, 200)
- self.assertFalse(
- (
- b"You have successfully logged into "
- b"the Central Authentication Service"
- ) in response.content
- )
-
- def test_logout_view_url(self):
- client = get_auth_client()
-
- response = client.get('/logout?url=https://www.example.com')
- self.assertEqual(response.status_code, 302)
- self.assertTrue(response.has_header("Location"))
- self.assertEqual(response["Location"], "https://www.example.com")
-
- response = client.get("/login")
- self.assertEqual(response.status_code, 200)
- self.assertFalse(
- (
- b"You have successfully logged into "
- b"the Central Authentication Service"
- ) in response.content
- )
-
- def test_logout_view_service(self):
- client = get_auth_client()
-
- response = client.get('/logout?service=https://www.example.com')
- self.assertEqual(response.status_code, 302)
- self.assertTrue(response.has_header("Location"))
- self.assertEqual(response["Location"], "https://www.example.com")
-
- response = client.get("/login")
- self.assertEqual(response.status_code, 200)
- self.assertFalse(
- (
- b"You have successfully logged into "
- b"the Central Authentication Service"
- ) in response.content
- )
-
-
-class AuthTestCase(TestCase):
-
- def setUp(self):
- settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
- self.service = 'https://www.example.com'
- models.ServicePattern.objects.create(
- name="example",
- pattern="^https://www\.example\.com(/.*)?$"
- )
-
- def test_auth_view_goodpass(self):
- settings.CAS_AUTH_SHARED_SECRET = 'test'
- client = Client()
- response = client.post(
- '/auth',
- {
- 'username': settings.CAS_TEST_USER,
- 'password': settings.CAS_TEST_PASSWORD,
- 'service': self.service,
- 'secret': 'test'
- }
- )
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.content, b'yes\n')
-
- def test_auth_view_badpass(self):
- settings.CAS_AUTH_SHARED_SECRET = 'test'
- client = Client()
- response = client.post(
- '/auth',
- {
- 'username': settings.CAS_TEST_USER,
- 'password': 'badpass',
- 'service': self.service,
- 'secret': 'test'
- }
- )
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.content, b'no\n')
-
- def test_auth_view_badservice(self):
- settings.CAS_AUTH_SHARED_SECRET = 'test'
- client = Client()
- response = client.post(
- '/auth',
- {
- 'username': settings.CAS_TEST_USER,
- 'password': settings.CAS_TEST_PASSWORD,
- 'service': 'https://www.example.org',
- 'secret': 'test'
- }
- )
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.content, b'no\n')
-
- def test_auth_view_badsecret(self):
- settings.CAS_AUTH_SHARED_SECRET = 'test'
- client = Client()
- response = client.post(
- '/auth',
- {
- 'username': settings.CAS_TEST_USER,
- 'password': settings.CAS_TEST_PASSWORD,
- 'service': self.service,
- 'secret': 'badsecret'
- }
- )
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.content, b'no\n')
-
- def test_auth_view_badsettings(self):
- settings.CAS_AUTH_SHARED_SECRET = None
- client = Client()
- response = client.post(
- '/auth',
- {
- 'username': settings.CAS_TEST_USER,
- 'password': settings.CAS_TEST_PASSWORD,
- 'service': self.service,
- 'secret': 'test'
- }
- )
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
-
-
-class ValidateTestCase(TestCase):
-
- def setUp(self):
- settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
- self.service = 'https://www.example.com'
- self.service_pattern = models.ServicePattern.objects.create(
- name="example",
- pattern="^https://www\.example\.com(/.*)?$"
- )
- models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
-
- def test_validate_view_ok(self):
- ticket = get_user_ticket_request(self.service)[1]
-
- client = Client()
- response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.content, b'yes\ntest\n')
-
- def test_validate_view_badservice(self):
- ticket = get_user_ticket_request(self.service)[1]
-
- client = Client()
- response = client.get(
- '/validate',
- {'ticket': ticket.value, 'service': "https://www.example.org"}
- )
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.content, b'no\n')
-
- def test_validate_view_badticket(self):
- get_user_ticket_request(self.service)
-
- client = Client()
- response = client.get(
- '/validate',
- {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
- )
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.content, b'no\n')
-
-
-class ValidateServiceTestCase(TestCase):
-
- def setUp(self):
- settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
- self.service = 'http://127.0.0.1:45678'
- self.service_pattern = models.ServicePattern.objects.create(
- name="localhost",
- pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
- proxy_callback=True
- )
- models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
-
- def test_validate_service_view_ok(self):
- ticket = get_user_ticket_request(self.service)[1]
-
- client = Client()
- response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- sucess = root.xpath(
- "//cas:authenticationSuccess",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertTrue(sucess)
-
- users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
- self.assertEqual(len(users), 1)
- self.assertEqual(users[0].text, settings.CAS_TEST_USER)
-
- attributes = root.xpath(
- "//cas:attributes",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(attributes), 1)
- attrs1 = {}
- for attr in attributes[0]:
- attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
-
- attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
- self.assertEqual(len(attributes), len(attrs1))
- attrs2 = {}
- for attr in attributes:
- attrs2[attr.attrib['name']] = attr.attrib['value']
- self.assertEqual(attrs1, attrs2)
- self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
-
- def test_validate_service_view_badservice(self):
- ticket = get_user_ticket_request(self.service)[1]
-
- client = Client()
- bad_service = "https://www.example.org"
- response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- error = root.xpath(
- "//cas:authenticationFailure",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(error), 1)
- self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE")
- self.assertEqual(error[0].text, bad_service)
-
- def test_validate_service_view_badticket_goodprefix(self):
- get_user_ticket_request(self.service)
-
- client = Client()
- bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
- response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- error = root.xpath(
- "//cas:authenticationFailure",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(error), 1)
- self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
- self.assertEqual(error[0].text, 'ticket not found')
-
- def test_validate_service_view_badticket_badprefix(self):
- get_user_ticket_request(self.service)
-
- client = Client()
- bad_ticket = "RANDOM"
- response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- error = root.xpath(
- "//cas:authenticationFailure",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(error), 1)
- self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
- self.assertEqual(error[0].text, bad_ticket)
-
- def test_validate_service_view_ok_pgturl(self):
- (host, port) = utils.PGTUrlHandler.run()[1:3]
- service = "http://%s:%s" % (host, port)
-
- ticket = get_user_ticket_request(service)[1]
-
- client = Client()
- response = client.get(
- '/serviceValidate',
- {'ticket': ticket.value, 'service': service, 'pgtUrl': service}
- )
- pgt_params = utils.PGTUrlHandler.PARAMS.copy()
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- pgtiou = root.xpath(
- "//cas:proxyGrantingTicket",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(pgtiou), 1)
- self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
- self.assertTrue("pgtId" in pgt_params)
-
- def test_validate_service_pgturl_bad_proxy_callback(self):
- self.service_pattern.proxy_callback = False
- self.service_pattern.save()
- ticket = get_user_ticket_request(self.service)[1]
-
- client = Client()
- response = client.get(
- '/serviceValidate',
- {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
- )
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- error = root.xpath(
- "//cas:authenticationFailure",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(error), 1)
- self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
- self.assertEqual(error[0].text, "callback url not allowed by configuration")
-
-
-class ProxyTestCase(TestCase):
-
- def setUp(self):
- settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
- self.service = 'http://127.0.0.1'
- self.service_pattern = models.ServicePattern.objects.create(
- name="localhost",
- pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
- proxy=True,
- proxy_callback=True
- )
- models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
-
- def test_validate_proxy_ok(self):
- params = get_pgt()
-
- # get a proxy ticket
- client1 = Client()
- response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
- self.assertTrue(sucess)
-
- proxy_ticket = root.xpath(
- "//cas:proxyTicket",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(proxy_ticket), 1)
- proxy_ticket = proxy_ticket[0].text
-
- # validate the proxy ticket
- client2 = Client()
- response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- sucess = root.xpath(
- "//cas:authenticationSuccess",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertTrue(sucess)
-
- # check that the proxy is send to the end service
- proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
- self.assertEqual(len(proxies), 1)
- proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
- self.assertEqual(len(proxy), 1)
- self.assertEqual(proxy[0].text, params["service"])
-
- # same tests than those for serviceValidate
- users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
- self.assertEqual(len(users), 1)
- self.assertEqual(users[0].text, settings.CAS_TEST_USER)
-
- attributes = root.xpath(
- "//cas:attributes",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(attributes), 1)
- attrs1 = {}
- for attr in attributes[0]:
- attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
-
- attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
- self.assertEqual(len(attributes), len(attrs1))
- attrs2 = {}
- for attr in attributes:
- attrs2[attr.attrib['name']] = attr.attrib['value']
- self.assertEqual(attrs1, attrs2)
- self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
-
- def test_validate_proxy_bad(self):
- params = get_pgt()
-
- # bad PGT
- client1 = Client()
- response = client1.get(
- '/proxy',
- {
- 'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
- 'targetService': params['service']
- }
- )
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- error = root.xpath(
- "//cas:authenticationFailure",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(error), 1)
- self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
- self.assertEqual(
- error[0].text,
- "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
- )
-
- # bad targetService
- client2 = Client()
- response = client2.get(
- '/proxy',
- {'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
- )
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- error = root.xpath(
- "//cas:authenticationFailure",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(error), 1)
- self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
- self.assertEqual(error[0].text, "https://www.example.org")
-
- # service do not allow proxy ticket
- self.service_pattern.proxy = False
- self.service_pattern.save()
-
- client3 = Client()
- response = client3.get(
- '/proxy',
- {'pgt': params['pgtId'], 'targetService': params['service']}
- )
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- error = root.xpath(
- "//cas:authenticationFailure",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(error), 1)
- self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
- self.assertEqual(
- error[0].text,
- 'the service %s do not allow proxy ticket' % params['service']
- )
diff --git a/cas_server/tests/__init__.py b/cas_server/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cas_server/tests/mixin.py b/cas_server/tests/mixin.py
new file mode 100644
index 0000000..ddbf2d2
--- /dev/null
+++ b/cas_server/tests/mixin.py
@@ -0,0 +1,193 @@
+# ⁻*- coding: utf-8 -*-
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
+# more details.
+#
+# You should have received a copy of the GNU General Public License version 3
+# along with this program; if not, write to the Free Software Foundation, Inc., 51
+# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+# (c) 2016 Valentin Samir
+"""Some mixin classes for tests"""
+from cas_server.default_settings import settings
+from django.utils import timezone
+
+import re
+from lxml import etree
+from datetime import timedelta
+
+from cas_server import models
+from cas_server.tests.utils import get_auth_client
+
+
+class BaseServicePattern(object):
+ """Mixing for setting up service pattern for testing"""
+ def setup_service_patterns(self, proxy=False):
+ """setting up service pattern"""
+ # For general purpose testing
+ self.service = "https://www.example.com"
+ self.service_pattern = models.ServicePattern.objects.create(
+ name="example",
+ pattern="^https://www\.example\.com(/.*)?$",
+ proxy=proxy,
+ )
+ models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
+
+ # For testing the restrict_users attributes
+ self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
+ self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
+ name="restrict_user_fail",
+ pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
+ restrict_users=True,
+ proxy=proxy,
+ )
+ self.service_restrict_user_success = "https://restrict_user_success.example.com"
+ self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
+ name="restrict_user_success",
+ pattern="^https://restrict_user_success\.example\.com(/.*)?$",
+ restrict_users=True,
+ proxy=proxy,
+ )
+ models.Username.objects.create(
+ value=settings.CAS_TEST_USER,
+ service_pattern=self.service_pattern_restrict_user_success
+ )
+
+ # For testing the user attributes filtering conditions
+ self.service_filter_fail = "https://filter_fail.example.com"
+ self.service_pattern_filter_fail = models.ServicePattern.objects.create(
+ name="filter_fail",
+ pattern="^https://filter_fail\.example\.com(/.*)?$",
+ proxy=proxy,
+ )
+ models.FilterAttributValue.objects.create(
+ attribut="right",
+ pattern="^admin$",
+ service_pattern=self.service_pattern_filter_fail
+ )
+ self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
+ self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
+ name="filter_fail_alt",
+ pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
+ proxy=proxy,
+ )
+ models.FilterAttributValue.objects.create(
+ attribut="nom",
+ pattern="^toto$",
+ service_pattern=self.service_pattern_filter_fail_alt
+ )
+ self.service_filter_success = "https://filter_success.example.com"
+ self.service_pattern_filter_success = models.ServicePattern.objects.create(
+ name="filter_success",
+ pattern="^https://filter_success\.example\.com(/.*)?$",
+ proxy=proxy,
+ )
+ models.FilterAttributValue.objects.create(
+ attribut="email",
+ pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
+ service_pattern=self.service_pattern_filter_success
+ )
+
+ # For testing the user_field attributes
+ self.service_field_needed_fail = "https://field_needed_fail.example.com"
+ self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
+ name="field_needed_fail",
+ pattern="^https://field_needed_fail\.example\.com(/.*)?$",
+ user_field="uid",
+ proxy=proxy,
+ )
+ self.service_field_needed_success = "https://field_needed_success.example.com"
+ self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
+ name="field_needed_success",
+ pattern="^https://field_needed_success\.example\.com(/.*)?$",
+ user_field="alias",
+ proxy=proxy,
+ )
+ self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
+ self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
+ name="field_needed_success_alt",
+ pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
+ user_field="nom",
+ proxy=proxy,
+ )
+
+
+class XmlContent(object):
+ """Mixin for test on CAS XML responses"""
+ def assert_error(self, response, code, text=None):
+ """Assert a validation error"""
+ self.assertEqual(response.status_code, 200)
+ root = etree.fromstring(response.content)
+ error = root.xpath(
+ "//cas:authenticationFailure",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ self.assertEqual(len(error), 1)
+ self.assertEqual(error[0].attrib['code'], code)
+ if text is not None:
+ self.assertEqual(error[0].text, text)
+
+ def assert_success(self, response, username, original_attributes):
+ """assert a ticket validation success"""
+ self.assertEqual(response.status_code, 200)
+
+ root = etree.fromstring(response.content)
+ sucess = root.xpath(
+ "//cas:authenticationSuccess",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ self.assertTrue(sucess)
+
+ users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
+ self.assertEqual(len(users), 1)
+ self.assertEqual(users[0].text, username)
+
+ attributes = root.xpath(
+ "//cas:attributes",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ self.assertEqual(len(attributes), 1)
+ attrs1 = set()
+ for attr in attributes[0]:
+ attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
+
+ attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
+ self.assertEqual(len(attributes), len(attrs1))
+ attrs2 = set()
+ for attr in attributes:
+ attrs2.add((attr.attrib['name'], attr.attrib['value']))
+ original = set()
+ for key, value in original_attributes.items():
+ if isinstance(value, list):
+ for sub_value in value:
+ original.add((key, sub_value))
+ else:
+ original.add((key, value))
+ self.assertEqual(attrs1, attrs2)
+ self.assertEqual(attrs1, original)
+
+ return root
+
+
+class UserModels(object):
+ """Mixin for test on CAS user models"""
+ @staticmethod
+ def expire_user():
+ """return an expired user"""
+ client = get_auth_client()
+
+ new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
+ models.User.objects.filter(
+ username=settings.CAS_TEST_USER,
+ session_key=client.session.session_key
+ ).update(date=new_date)
+ return client
+
+ @staticmethod
+ def get_user(client):
+ """return the user associated with an authenticated client"""
+ return models.User.objects.get(
+ username=settings.CAS_TEST_USER,
+ session_key=client.session.session_key
+ )
diff --git a/settings_tests.py b/cas_server/tests/settings.py
similarity index 96%
rename from settings_tests.py
rename to cas_server/tests/settings.py
index 4588c2c..1402c64 100644
--- a/settings_tests.py
+++ b/cas_server/tests/settings.py
@@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
'django.middleware.locale.LocaleMiddleware',
]
-ROOT_URLCONF = 'cas_server.urls'
+ROOT_URLCONF = 'cas_server.tests.urls'
# Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
@@ -60,6 +60,7 @@ ROOT_URLCONF = 'cas_server.urls'
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
+ 'NAME': ':memory:',
}
}
diff --git a/cas_server/tests/test_models.py b/cas_server/tests/test_models.py
new file mode 100644
index 0000000..e75f54f
--- /dev/null
+++ b/cas_server/tests/test_models.py
@@ -0,0 +1,166 @@
+# ⁻*- coding: utf-8 -*-
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
+# more details.
+#
+# You should have received a copy of the GNU General Public License version 3
+# along with this program; if not, write to the Free Software Foundation, Inc., 51
+# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+# (c) 2016 Valentin Samir
+"""Tests module for models"""
+from cas_server.default_settings import settings
+
+from django.test import TestCase
+from django.test.utils import override_settings
+from django.utils import timezone
+
+from datetime import timedelta
+from importlib import import_module
+
+from cas_server import models
+from cas_server.tests.utils import get_auth_client, HttpParamsHandler
+from cas_server.tests.mixin import UserModels, BaseServicePattern
+
+SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class UserTestCase(TestCase, UserModels):
+ """tests for the user models"""
+ def setUp(self):
+ """Prepare the test context"""
+ self.service = 'http://127.0.0.1:45678'
+ self.service_pattern = models.ServicePattern.objects.create(
+ name="localhost",
+ pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
+ single_log_out=True
+ )
+ models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
+
+ def test_clean_old_entries(self):
+ """test clean_old_entries"""
+ # get an authenticated client
+ client = self.expire_user()
+ # assert the user exists before being cleaned
+ self.assertEqual(len(models.User.objects.all()), 1)
+ # assert the last activity date is before the expiry date
+ self.assertTrue(
+ self.get_user(client).date < (
+ timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
+ )
+ )
+ # delete old inactive users
+ models.User.clean_old_entries()
+ # assert the user has being well delete
+ self.assertEqual(len(models.User.objects.all()), 0)
+
+ def test_clean_deleted_sessions(self):
+ """test clean_deleted_sessions"""
+ # get an authenticated client
+ client1 = get_auth_client()
+ client2 = get_auth_client()
+ # generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
+ # on self.service)
+ ticket = self.get_user(client1).get_ticket(
+ models.ServiceTicket,
+ self.service,
+ self.service_pattern,
+ renew=False
+ )
+ ticket.validate = True
+ ticket.save()
+ # simulated expired session being garbage collected for client1
+ session = SessionStore(session_key=client1.session.session_key)
+ session.flush()
+ # assert the user exists before being cleaned
+ self.assertTrue(self.get_user(client1))
+ self.assertTrue(self.get_user(client2))
+ self.assertEqual(len(models.User.objects.all()), 2)
+ # session has being remove so the user of client1 is no longer authenticated
+ self.assertFalse(client1.session.get("authenticated"))
+ # the user a client2 should still be authenticated
+ self.assertTrue(client2.session.get("authenticated"))
+ # the user should be deleted
+ models.User.clean_deleted_sessions()
+ # assert the user with expired sessions has being well deleted but the other remain
+ self.assertEqual(len(models.User.objects.all()), 1)
+ self.assertFalse(models.ServiceTicket.objects.all())
+ self.assertTrue(client2.session.get("authenticated"))
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class TicketTestCase(TestCase, UserModels, BaseServicePattern):
+ """tests for the tickets models"""
+ def setUp(self):
+ """Prepare the test context"""
+ self.setup_service_patterns()
+ self.service = 'http://127.0.0.1:45678'
+ self.service_pattern = models.ServicePattern.objects.create(
+ name="localhost",
+ pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
+ single_log_out=True
+ )
+ models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
+
+ @staticmethod
+ def get_ticket(
+ user,
+ ticket_class,
+ service,
+ service_pattern,
+ renew=False,
+ validate=False,
+ validity_expired=False,
+ timeout_expired=False,
+ single_log_out=False,
+ ):
+ """Return a ticket"""
+ ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
+ ticket.validate = validate
+ ticket.single_log_out = single_log_out
+ if validity_expired:
+ ticket.creation = min(
+ ticket.creation,
+ (timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
+ )
+ if timeout_expired:
+ ticket.creation = min(
+ ticket.creation,
+ (timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
+ )
+ ticket.save()
+ return ticket
+
+ def test_clean_old_service_ticket(self):
+ """test tickets clean_old_entries"""
+ # ge an authenticated client
+ client = get_auth_client()
+ # get the user associated to the client
+ user = self.get_user(client)
+ # generate a ticket for that client, waiting for validation
+ self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
+ # generate another ticket for those validation time has expired
+ self.get_ticket(
+ user, models.ServiceTicket,
+ self.service, self.service_pattern, validity_expired=True
+ )
+ (httpd, host, port) = HttpParamsHandler.run()[0:3]
+ service = "http://%s:%s" % (host, port)
+ # generate a ticket with SLO having timeout reach
+ self.get_ticket(
+ user, models.ServiceTicket,
+ service, self.service_pattern, timeout_expired=True,
+ validate=True, single_log_out=True
+ )
+ # there should be 3 tickets in the db
+ self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
+ # we call the clean_old_entries method that should delete validated non SLO ticket and
+ # expired non validated ticket and send SLO for SLO expired ticket before deleting then
+ models.ServiceTicket.clean_old_entries()
+ params = httpd.PARAMS
+ # we successfully got a SLO request
+ self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
+ # only 1 ticket remain in the db
+ self.assertEqual(len(models.ServiceTicket.objects.all()), 1)
diff --git a/cas_server/tests/test_utils.py b/cas_server/tests/test_utils.py
new file mode 100644
index 0000000..76fa2cc
--- /dev/null
+++ b/cas_server/tests/test_utils.py
@@ -0,0 +1,191 @@
+# ⁻*- coding: utf-8 -*-
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
+# more details.
+#
+# You should have received a copy of the GNU General Public License version 3
+# along with this program; if not, write to the Free Software Foundation, Inc., 51
+# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+# (c) 2016 Valentin Samir
+"""Tests module for utils"""
+from django.test import TestCase
+
+import six
+
+from cas_server import utils
+
+
+class CheckPasswordCase(TestCase):
+ """Tests for the utils function `utils.check_password`"""
+
+ def setUp(self):
+ """Generate random bytes string that will be used ass passwords"""
+ self.password1 = utils.gen_saml_id()
+ self.password2 = utils.gen_saml_id()
+ if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
+ self.password1 = self.password1.encode("utf8")
+ self.password2 = self.password2.encode("utf8")
+
+ def test_setup(self):
+ """check that generated password are bytes"""
+ self.assertIsInstance(self.password1, bytes)
+ self.assertIsInstance(self.password2, bytes)
+
+ def test_plain(self):
+ """test the plain auth method"""
+ self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
+ self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
+
+ def test_plain_unicode(self):
+ """test the plain auth method with unicode input"""
+ self.assertTrue(
+ utils.check_password(
+ "plain",
+ self.password1.decode("utf8"),
+ self.password1.decode("utf8"),
+ "utf8"
+ )
+ )
+ self.assertFalse(
+ utils.check_password(
+ "plain",
+ self.password1.decode("utf8"),
+ self.password2.decode("utf8"),
+ "utf8"
+ )
+ )
+
+ def test_crypt(self):
+ """test the crypt auth method"""
+ salts = ["$6$UVVAQvrMyXMF3FF3", "aa"]
+ hashed_password1 = []
+ for salt in salts:
+ if six.PY3:
+ hashed_password1.append(
+ utils.crypt.crypt(
+ self.password1.decode("utf8"),
+ salt
+ ).encode("utf8")
+ )
+ else:
+ hashed_password1.append(utils.crypt.crypt(self.password1, salt))
+
+ for hp1 in hashed_password1:
+ self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8"))
+ self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8"))
+
+ with self.assertRaises(ValueError):
+ utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8")
+
+ def test_ldap_password_valid(self):
+ """test the ldap auth method with all the schemes"""
+ salt = b"UVVAQvrMyXMF3FF3"
+ schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
+ schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
+ hashed_password1 = []
+ for scheme in schemes_salt:
+ hashed_password1.append(
+ utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8")
+ )
+ for scheme in schemes_nosalt:
+ hashed_password1.append(
+ utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8")
+ )
+ hashed_password1.append(
+ utils.LdapHashUserPassword.hash(
+ b"{CRYPT}",
+ self.password1,
+ b"$6$UVVAQvrMyXMF3FF3",
+ charset="utf8"
+ )
+ )
+ for hp1 in hashed_password1:
+ self.assertIsInstance(hp1, bytes)
+ self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8"))
+ self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8"))
+
+ def test_ldap_password_fail(self):
+ """test the ldap auth method with malformed hash or bad schemes"""
+ salt = b"UVVAQvrMyXMF3FF3"
+ schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
+ schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
+
+ # first try to hash with bad parameters
+ with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
+ utils.LdapHashUserPassword.hash(b"TOTO", self.password1)
+ for scheme in schemes_nosalt:
+ with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
+ utils.LdapHashUserPassword.hash(scheme, self.password1, salt)
+ for scheme in schemes_salt:
+ with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
+ utils.LdapHashUserPassword.hash(scheme, self.password1)
+ with self.assertRaises(utils.LdapHashUserPassword.BadSalt):
+ utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, b"$truc$toto")
+
+ # then try to check hash with bad hashes
+ with self.assertRaises(utils.LdapHashUserPassword.BadHash):
+ utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8")
+ for scheme in schemes_salt:
+ with self.assertRaises(utils.LdapHashUserPassword.BadHash):
+ utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8")
+
+ def test_hex(self):
+ """test all the hex_HASH method: the hashed password is a simple hash of the password"""
+ hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]
+ hashed_password1 = []
+ for hash in hashes:
+ hashed_password1.append(
+ ("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest())
+ )
+ for (method, hp1) in hashed_password1:
+ self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8"))
+ self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8"))
+
+ def test_bad_method(self):
+ """try to check password with a bad method, should raise a ValueError"""
+ with self.assertRaises(ValueError):
+ utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8")
+
+
+class UtilsTestCase(TestCase):
+ """tests for some little utils functions"""
+ def test_import_attr(self):
+ """
+ test the import_attr function. Feeded with a dotted path string, it should
+ import the dotted module and return that last componend of the dotted path
+ (function, class or variable)
+ """
+ with self.assertRaises(ImportError):
+ utils.import_attr('toto.titi.tutu')
+ with self.assertRaises(AttributeError):
+ utils.import_attr('cas_server.utils.toto')
+ with self.assertRaises(ValueError):
+ utils.import_attr('toto')
+ self.assertEqual(
+ utils.import_attr('cas_server.default_app_config'),
+ 'cas_server.apps.CasAppConfig'
+ )
+ self.assertEqual(utils.import_attr(utils), utils)
+
+ def test_update_url(self):
+ """
+ test the update_url function. Given an url with possible GET parameter and a dict
+ the function build a url with GET parameters updated by the dictionnary
+ """
+ url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"})
+ url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"})
+ self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1")
+ self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1")
+
+ url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"})
+ self.assertEqual(url3, u"https://www.example.com?toto=2")
+
+ def test_crypt_salt_is_valid(self):
+ """test the function crypt_salt_is_valid who test if a crypt salt is valid"""
+ self.assertFalse(utils.crypt_salt_is_valid("")) # len 0
+ self.assertFalse(utils.crypt_salt_is_valid("a")) # len 1
+ self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
+ self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
+ self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known
diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py
new file mode 100644
index 0000000..95720c4
--- /dev/null
+++ b/cas_server/tests/test_view.py
@@ -0,0 +1,1813 @@
+# ⁻*- coding: utf-8 -*-
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
+# more details.
+#
+# You should have received a copy of the GNU General Public License version 3
+# along with this program; if not, write to the Free Software Foundation, Inc., 51
+# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+# (c) 2016 Valentin Samir
+"""Tests module for views"""
+from cas_server.default_settings import settings
+
+import django
+from django.test import TestCase, Client
+from django.test.utils import override_settings
+from django.utils import timezone
+
+
+import random
+import json
+from lxml import etree
+from six.moves import range
+
+from cas_server import models
+from cas_server import utils
+from cas_server.tests.utils import (
+ copy_form,
+ get_login_page_params,
+ get_auth_client,
+ get_user_ticket_request,
+ get_pgt,
+ get_proxy_ticket,
+ get_validated_ticket,
+ HttpParamsHandler,
+ Http404Handler
+)
+from cas_server.tests.mixin import BaseServicePattern, XmlContent
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class LoginTestCase(TestCase, BaseServicePattern):
+ """Tests for the login view"""
+ def setUp(self):
+ """Prepare the test context:"""
+ # we prepare a bunch a service url and service patterns for tests
+ self.setup_service_patterns()
+
+ def assert_logged(self, client, response, warn=False, code=200):
+ """Assertions testing that client is well authenticated"""
+ self.assertEqual(response.status_code, code)
+ # this message is displayed to the user upon successful authentication
+ self.assertTrue(
+ (
+ b"You have successfully logged into "
+ b"the Central Authentication Service"
+ ) in response.content
+ )
+ # these session variables a set if usccessfully authenticated
+ self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
+ self.assertTrue(client.session["warn"] is warn)
+ self.assertTrue(client.session["authenticated"] is True)
+
+ # on successfull authentication, a corresponding user object is created
+ self.assertTrue(
+ models.User.objects.get(
+ username=settings.CAS_TEST_USER,
+ session_key=client.session.session_key
+ )
+ )
+
+ def assert_login_failed(self, client, response, code=200):
+ """Assertions testing a failed login attempt"""
+ self.assertEqual(response.status_code, code)
+ # this message is displayed to the user upon successful authentication, so it should not
+ # appear
+ self.assertFalse(
+ (
+ b"You have successfully logged into "
+ b"the Central Authentication Service"
+ ) in response.content
+ )
+
+ # if authentication has failed, these session variables should not be set
+ self.assertTrue(client.session.get("username") is None)
+ self.assertTrue(client.session.get("warn") is None)
+ self.assertTrue(client.session.get("authenticated") is None)
+
+ def test_login_view_post_goodpass_goodlt(self):
+ """Test a successul login"""
+ # we get a client who fetch a frist time the login page and the login form default
+ # parameters
+ client, params = get_login_page_params()
+ # we set username/password in the form
+ params["username"] = settings.CAS_TEST_USER
+ params["password"] = settings.CAS_TEST_PASSWORD
+ # the LoginTicket in the form should match a valid LT in the user session
+ self.assertTrue(params['lt'] in client.session['lt'])
+
+ # we post a login attempt
+ response = client.post('/login', params)
+ # as username/password/lt are all valid, the login should succed
+ self.assert_logged(client, response)
+ # The LoginTicket is conssumed and should no longer be valid
+ self.assertTrue(params['lt'] not in client.session['lt'])
+
+ def test_login_view_post_goodpass_goodlt_warn(self):
+ """Test a successul login requesting to be warned before creating services tickets"""
+ # get a client and initial login params
+ client, params = get_login_page_params()
+ # set valids usernames/passswords
+ params["username"] = settings.CAS_TEST_USER
+ params["password"] = settings.CAS_TEST_PASSWORD
+ # this time, we check the warn checkbox
+ params["warn"] = "on"
+
+ # postings login request
+ response = client.post('/login', params)
+ # as username/password/lt are all valid, the login should succed and warn be enabled
+ self.assert_logged(client, response, warn=True)
+
+ def test_lt_max(self):
+ """Check we only keep the last 100 Login Ticket for a user"""
+ # get a client and initial login params
+ client, params = get_login_page_params()
+ # get a first LT that should be valid
+ current_lt = params["lt"]
+ # we keep the last 100 generated LT by user, so after having generated `i_in_test` we
+ # test if `current_lt` is still valid
+ i_in_test = random.randint(0, 99)
+ # after `i_not_in_test` `current_lt` should be valid not more
+ i_not_in_test = random.randint(101, 150)
+ # start generating 150 LT
+ for i in range(150):
+ if i == i_in_test:
+ # before more than 100 LT generated, the first TL should be valid
+ self.assertTrue(current_lt in client.session['lt'])
+ if i == i_not_in_test:
+ # after more than 100 LT generated, the first LT should be valid no more
+ self.assertTrue(current_lt not in client.session['lt'])
+ # assert that we do not keep more that 100 valid LT
+ self.assertTrue(len(client.session['lt']) <= 100)
+ # generate a new LT by getting the login page
+ client, params = get_login_page_params(client)
+ # in the end, we still have less that 100 valid LT
+ self.assertTrue(len(client.session['lt']) <= 100)
+
+ def test_login_view_post_badlt(self):
+ """Login attempt with a bad LoginTicket, login should fail"""
+ # get a client and initial login params
+ client, params = get_login_page_params()
+ # set valid username/password
+ params["username"] = settings.CAS_TEST_USER
+ params["password"] = settings.CAS_TEST_PASSWORD
+ # set a bad LT
+ params["lt"] = 'LT-random'
+
+ # posting the login request
+ response = client.post('/login', params)
+
+ # as the LT is not valid, login should fail
+ self.assert_login_failed(client, response)
+ # the reason why login has failed is displayed to the user
+ self.assertTrue(b"Invalid login ticket" in response.content)
+
+ def test_login_view_post_badpass_good_lt(self):
+ """Login attempt with a bad password"""
+ # get a client and initial login params
+ client, params = get_login_page_params()
+ # set valid username but invalid password
+ params["username"] = settings.CAS_TEST_USER
+ params["password"] = "test2"
+ # posting the login request
+ response = client.post('/login', params)
+
+ # as the password is wrong, login should fail
+ self.assert_login_failed(client, response)
+ # the reason why login has failed is displayed to the user
+ self.assertTrue(
+ (
+ b"The credentials you provided cannot be "
+ b"determined to be authentic"
+ ) in response.content
+ )
+
+ def assert_ticket_attributes(self, client, ticket_value):
+ """check the ticket attributes in the db"""
+ # Get get current session user in the db
+ user = models.User.objects.get(
+ username=settings.CAS_TEST_USER,
+ session_key=client.session.session_key
+ )
+ # we should find exactly one user
+ self.assertTrue(user)
+ # get the ticker object corresponting to `ticket_value`
+ ticket = models.ServiceTicket.objects.get(value=ticket_value)
+ # chek that the ticket is well attributed to the user
+ self.assertEqual(ticket.user, user)
+ # check that the user attributes match the attributes registered on the ticket
+ self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
+ # check that the ticket has not being validated yet
+ self.assertEqual(ticket.validate, False)
+ # check that the service pattern registered on the ticket is the on we use for tests
+ self.assertEqual(ticket.service_pattern, self.service_pattern)
+
+ def assert_service_ticket(self, client, response):
+ """check that a ticket is well emited when requested on a allowed service"""
+ # On ticket emission, we should be redirected to the service url, setting the ticket
+ # GET parameter
+ self.assertEqual(response.status_code, 302)
+ self.assertTrue(response.has_header('Location'))
+ self.assertTrue(
+ response['Location'].startswith(
+ "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
+ )
+ )
+ # check that the value of the ticket GET parameter match the value of the ticket
+ # created in the db
+ ticket_value = response['Location'].split('ticket=')[-1]
+ self.assert_ticket_attributes(client, ticket_value)
+
+ def test_view_login_get_allowed_service(self):
+ """Request a ticket for an allowed service by an unauthenticated client"""
+ # get a bare new http client
+ client = Client()
+ # we are not authenticated and are asking for a ticket for https://www.example.com
+ # which is a valid service matched by self.service_pattern
+ response = client.get("/login?service=https://www.example.com")
+ # the login page should be displayed
+ self.assertEqual(response.status_code, 200)
+ # we warn the user why it need to authenticated
+ self.assertTrue(
+ (
+ b"Authentication required by service "
+ b"example (https://www.example.com)"
+ ) in response.content
+ )
+
+ def test_view_login_get_denied_service(self):
+ """Request a ticket for an denied service by an unauthenticated client"""
+ # get a bare new http client
+ client = Client()
+ # we are not authenticated and are asking for a ticket for https://www.example.net
+ # which is NOT a valid service
+ response = client.get("/login?service=https://www.example.net")
+ self.assertEqual(response.status_code, 200)
+ # we warn the user that https://www.example.net is not an allowed service url
+ self.assertTrue(b"Service https://www.example.net non allowed" in response.content)
+
+ def test_view_login_get_auth_allowed_service(self):
+ """Request a ticket for an allowed service by an authenticated client"""
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # ask for a ticket for https://www.example.com
+ response = client.get("/login?service=https://www.example.com")
+ # as https://www.example.com is a valid service a ticket should be created and the
+ # user redirected to the service url
+ self.assert_service_ticket(client, response)
+
+ def test_view_login_get_auth_allowed_service_warn(self):
+ """Request a ticket for an allowed service by an authenticated client"""
+ # get a client that is already authenticated and has ask to be warned befor we
+ # generated a ticket
+ client = get_auth_client(warn="on")
+ # ask for a ticket for https://www.example.com
+ response = client.get("/login?service=https://www.example.com")
+ # we display a warning to the user, asking him to validate the ticket creation (insted
+ # a generating and redirecting directly to the service url)
+ self.assertEqual(response.status_code, 200)
+ self.assertTrue(
+ (
+ b"Authentication has been required by service "
+ b"example (https://www.example.com)"
+ ) in response.content
+ )
+ # get the displayed form parameters
+ params = copy_form(response.context["form"])
+ # we post, confirming we want a ticket
+ response = client.post("/login", params)
+ # as https://www.example.com is a valid service a ticket should be created and the
+ # user redirected to the service url
+ self.assert_service_ticket(client, response)
+
+ def test_view_login_get_auth_denied_service(self):
+ """Request a ticket for a not allowed service by an authenticated client"""
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # we are authenticated and are asking for a ticket for https://www.example.org
+ # which is NOT a valid service
+ response = client.get("/login?service=https://www.example.org")
+ self.assertEqual(response.status_code, 200)
+ # we warn the user that https://www.example.net is not an allowed service url
+ # NO ticket are created
+ self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
+
+ def test_user_logged_not_in_db(self):
+ """If the user is logged but has been delete from the database, it should be logged out"""
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # delete the user in the db
+ models.User.objects.get(
+ username=settings.CAS_TEST_USER,
+ session_key=client.session.session_key
+ ).delete()
+ # fetch the login page
+ response = client.get("/login")
+
+ # The user should be logged out
+ self.assert_login_failed(client, response, code=302)
+ # and redirected to the login page. We branch depending on the version a django as
+ # the test client behaviour changed after django 1.9
+ if django.VERSION < (1, 9): # pragma: no cover coverage is computed with dango 1.9
+ self.assertEqual(response["Location"], "http://testserver/login")
+ else:
+ self.assertEqual(response["Location"], "/login?")
+
+ def test_service_restrict_user(self):
+ """Testing the restric user capability from a service"""
+ # get a client that is already authenticated
+ client = get_auth_client()
+
+ # trying to get a ticket from a service url matched by a service pattern having a
+ # restriction on the usernames allowed to get tickets. the test user username is not one
+ # of this username.
+ response = client.get("/login", {'service': self.service_restrict_user_fail})
+ self.assertEqual(response.status_code, 200)
+ # the ticket is not created and a warning is displayed to the user
+ self.assertTrue(b"Username non allowed" in response.content)
+
+ # same but with the tes user username being one of the allowed usernames
+ response = client.get("/login", {'service': self.service_restrict_user_success})
+ # the ticket is created and we are redirected to the service url
+ self.assertEqual(response.status_code, 302)
+ self.assertTrue(
+ response["Location"].startswith("%s?ticket=" % self.service_restrict_user_success)
+ )
+
+ def test_service_filter(self):
+ """Test the filtering on user attributes"""
+ # get a client that is already authenticated
+ client = get_auth_client()
+
+ # trying to get a ticket from a service url matched by a service pattern having
+ # a restriction on the user attributes. The test user if ailing these restrictions
+ # We try first with a single value attribut (aka a string) and then with
+ # a multi values attributs (aka a list of strings)
+ for service in [self.service_filter_fail, self.service_filter_fail_alt]:
+ response = client.get("/login", {'service': service})
+ # the ticket is not created and a warning is displayed to the user
+ self.assertEqual(response.status_code, 200)
+ self.assertTrue(b"User charateristics non allowed" in response.content)
+
+ # same but with rectriction that a valid upon the test user attributes
+ response = client.get("/login", {'service': self.service_filter_success})
+ # the ticket us created and the user redirected to the service url
+ self.assertEqual(response.status_code, 302)
+ self.assertTrue(response["Location"].startswith("%s?ticket=" % self.service_filter_success))
+
+ def test_service_user_field(self):
+ """Test using a user attribute as username: case on if the attribute exists or not"""
+ # get a client that is already authenticated
+ client = get_auth_client()
+
+ # trying to get a ticket from a service url matched by a service pattern that use
+ # a particular attribute has username. The test user do NOT have this attribute
+ response = client.get("/login", {'service': self.service_field_needed_fail})
+ # the ticket is not created and a warning is displayed to the user
+ self.assertEqual(response.status_code, 200)
+ self.assertTrue(b"The attribut uid is needed to use that service" in response.content)
+
+ # same but with a attribute that the test user has
+ response = client.get("/login", {'service': self.service_field_needed_success})
+ # the ticket us created and the user redirected to the service url
+ self.assertEqual(response.status_code, 302)
+ self.assertTrue(
+ response["Location"].startswith("%s?ticket=" % self.service_field_needed_success)
+ )
+
+ @override_settings(CAS_TEST_ATTRIBUTES={'alias': []})
+ def test_service_user_field_evaluate_to_false(self):
+ """
+ Test using a user attribute as username:
+ case the attribute exists but evaluate to False
+ """
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # trying to get a ticket from a service url matched by a service pattern that use
+ # a particular attribute has username. The test user have this attribute, but it is
+ # evaluated to False (eg an empty string "" or an empty list [])
+ response = client.get("/login", {"service": self.service_field_needed_success})
+ # the ticket is not created and a warning is displayed to the user
+ self.assertEqual(response.status_code, 200)
+ self.assertTrue(b"The attribut alias is needed to use that service" in response.content)
+
+ def test_gateway(self):
+ """test gateway parameter"""
+
+ # First with an authenticated client that fail to get a ticket for a service
+ service = "https://restrict_user_fail.example.com"
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # the authenticated client fail to get a ticket for some reason
+ response = client.get("/login", {'service': service, 'gateway': 'on'})
+ # as gateway is set, he is redirected to the service url without any ticket
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(response["Location"], service)
+
+ # second for an user not yet authenticated on a valid service
+ client = Client()
+ # the client fail to get a ticket since he is not yep authenticated
+ response = client.get('/login', {'service': service, 'gateway': 'on'})
+ # as gateway is set, he is redirected to the service url without any ticket
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(response["Location"], service)
+
+ def test_renew(self):
+ """test the authentication renewal request from a service"""
+ # use the default test service
+ service = "https://www.example.com"
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # ask for a ticket for the service but aks for authentication renewal
+ response = client.get("/login", {'service': service, 'renew': 'on'})
+ # we are ask to reauthenticate and tell the user why
+ self.assertEqual(response.status_code, 200)
+ self.assertTrue(
+ (
+ b"Authentication renewal required by "
+ b"service example (https://www.example.com)"
+ ) in response.content
+ )
+ # get the form default parameter
+ params = copy_form(response.context["form"])
+ # set valid username/password
+ params["username"] = settings.CAS_TEST_USER
+ params["password"] = settings.CAS_TEST_PASSWORD
+ # the renew parameter from the form should be True
+ self.assertEqual(params["renew"], True)
+ # post the authentication request
+ response = client.post("/login", params)
+ # the request succed, a ticket is created and we are redirected to the service url
+ self.assertEqual(response.status_code, 302)
+ ticket_value = response['Location'].split('ticket=')[-1]
+ ticket = models.ServiceTicket.objects.get(value=ticket_value)
+ # the created ticket is marked has being gottent after a renew. Futher testing about
+ # renewing authentication is done in the validate and serviceValidate views tests
+ self.assertEqual(ticket.renew, True)
+
+ @override_settings(CAS_ENABLE_AJAX_AUTH=True)
+ def test_ajax_login_required(self):
+ """
+ test ajax, login required.
+ The ajax methods allow the log a user in using javascript.
+ For doing so, every 302 redirection a replaced by a 200 returning a json with the
+ url to redirect to.
+ By default, ajax login is disabled.
+ If CAS_ENABLE_AJAX_AUTH is True, ajax login is enable and only page on the same domain
+ as the CAS can do ajax request. To allow pages on other domains, you need to use CORS.
+ You can use the django app corsheaders for that. Be carefull to only allow domains
+ you completly trust as any javascript on these domaine will be able to authenticate
+ as the user.
+ """
+ # get a bare client
+ client = Client()
+ # fetch the login page setting up the custom header HTTP_X_AJAX to tell we wish to de
+ # ajax requests
+ response = client.get("/login", HTTP_X_AJAX='on')
+ # we get a json as response telling us the user need to be authenticated
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.content.decode("utf8"))
+ self.assertEqual(data["status"], "error")
+ self.assertEqual(data["detail"], "login required")
+ self.assertEqual(data["url"], "/login?")
+
+ @override_settings(CAS_ENABLE_AJAX_AUTH=True)
+ def test_ajax_logged_user_deleted(self):
+ """test ajax user logged deleted: login required"""
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # delete the user in the db
+ user = models.User.objects.get(
+ username=settings.CAS_TEST_USER,
+ session_key=client.session.session_key
+ )
+ user.delete()
+ # fetch the login page with ajax on
+ response = client.get("/login", HTTP_X_AJAX='on')
+ # we get a json telling us that the user need to authenticate
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.content.decode("utf8"))
+ self.assertEqual(data["status"], "error")
+ self.assertEqual(data["detail"], "login required")
+ self.assertEqual(data["url"], "/login?")
+
+ @override_settings(CAS_ENABLE_AJAX_AUTH=True)
+ def test_ajax_logged(self):
+ """test ajax user is successfully logged"""
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # fetch the login page with ajax on
+ response = client.get("/login", HTTP_X_AJAX='on')
+ # we get a json telling us that the user is well authenticated
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.content.decode("utf8"))
+ self.assertEqual(data["status"], "success")
+ self.assertEqual(data["detail"], "logged")
+
+ @override_settings(CAS_ENABLE_AJAX_AUTH=True)
+ def test_ajax_get_ticket_success(self):
+ """test ajax retrieve a ticket for an allowed service"""
+ # using the default test service
+ service = "https://www.example.com"
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # fetch the login page with ajax on
+ response = client.get("/login", {'service': service}, HTTP_X_AJAX='on')
+ # we get a json telling us that the ticket has being created
+ # and we get the url to fetch to authenticate the user to the service
+ # contening the ticket has GET parameter
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.content.decode("utf8"))
+ self.assertEqual(data["status"], "success")
+ self.assertEqual(data["detail"], "auth")
+ self.assertTrue(data["url"].startswith('%s?ticket=' % service))
+
+ def test_ajax_get_ticket_success_alt(self):
+ """
+ test ajax retrieve a ticket for an allowed service.
+ Same as above but with CAS_ENABLE_AJAX_AUTH=False
+ """
+ # using the default test service
+ service = "https://www.example.com"
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # fetch the login page with ajax on
+ response = client.get("/login", {'service': service}, HTTP_X_AJAX='on')
+ # as CAS_ENABLE_AJAX_AUTH is False the ajax request is ignored and word normally:
+ # 302 redirect to the service url with ticket as GET parameter. javascript
+ # cannot retieve the ticket info and try follow the redirect to an other domain and fail
+ # silently
+ self.assertEqual(response.status_code, 302)
+
+ @override_settings(CAS_ENABLE_AJAX_AUTH=True)
+ def test_ajax_get_ticket_fail(self):
+ """test ajax retrieve a ticket for a denied service"""
+ # using a denied service url
+ service = "https://www.example.org"
+ # get a client that is already authenticated
+ client = get_auth_client()
+ # fetch the login page with ajax on
+ response = client.get("/login", {'service': service}, HTTP_X_AJAX='on')
+ # we get a json telling us that the service is not allowed
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.content.decode("utf8"))
+ self.assertEqual(data["status"], "error")
+ self.assertEqual(data["detail"], "auth")
+ self.assertEqual(data["messages"][0]["level"], "error")
+ self.assertEqual(
+ data["messages"][0]["message"],
+ "Service https://www.example.org non allowed."
+ )
+
+ @override_settings(CAS_ENABLE_AJAX_AUTH=True)
+ def test_ajax_get_ticket_warn(self):
+ """test get a ticket but user asked to be warned"""
+ # using the default test service
+ service = "https://www.example.com"
+ # get a client that is already authenticated wth warn on
+ client = get_auth_client(warn="on")
+ # fetch the login page with ajax on
+ response = client.get("/login", {'service': service}, HTTP_X_AJAX='on')
+ # we get a json telling us that we cannot get a ticket transparently and that the
+ # user has asked to be warned
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.content.decode("utf8"))
+ self.assertEqual(data["status"], "error")
+ self.assertEqual(data["detail"], "confirmation needed")
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class LogoutTestCase(TestCase):
+ """test fot the logout view"""
+ def setUp(self):
+ """Prepare the test context"""
+ # for testing SingleLogOut we need to use a service on localhost were we lanch
+ # a simple one request http server
+ self.service = 'http://127.0.0.1:45678'
+ self.service_pattern = models.ServicePattern.objects.create(
+ name="localhost",
+ pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
+ single_log_out=True
+ )
+ # return all user attributes
+ models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
+
+ def test_logout(self):
+ """logout is idempotent"""
+ # get a bare client
+ client = Client()
+
+ # call logout
+ client.get("/logout")
+
+ # we are still not logged
+ self.assertFalse(client.session.get("username"))
+ self.assertFalse(client.session.get("authenticated"))
+
+ def test_logout_view(self):
+ """test simple logout, logout only an user from one and only one sessions"""
+ # get two authenticated client with the same test user (but two different sessions)
+ client = get_auth_client()
+ client2 = get_auth_client()
+
+ # fetch login, the first client is well authenticated
+ response = client.get("/login")
+ self.assertEqual(response.status_code, 200)
+ self.assertTrue(
+ (
+ b"You have successfully logged into "
+ b"the Central Authentication Service"
+ ) in response.content
+ )
+ # and session variable are well
+ self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
+ self.assertTrue(client.session["authenticated"] is True)
+
+ # call logout with the first client
+ response = client.get("/logout")
+ # the client is logged out
+ self.assertEqual(response.status_code, 200)
+ self.assertTrue(
+ (
+ b"You have successfully logged out from "
+ b"the Central Authentication Service"
+ ) in response.content
+ )
+ # and session variable a well cleaned
+ self.assertFalse(client.session.get("username"))
+ self.assertFalse(client.session.get("authenticated"))
+ # client2 is still logged
+ self.assertTrue(client2.session["username"] == settings.CAS_TEST_USER)
+ self.assertTrue(client2.session["authenticated"] is True)
+
+ response = client.get("/login")
+ # fetch login, the second client is well authenticated
+ self.assertEqual(response.status_code, 200)
+ self.assertFalse(
+ (
+ b"You have successfully logged into "
+ b"the Central Authentication Service"
+ ) in response.content
+ )
+
+ def test_logout_from_all_session(self):
+ """test logout from all my session"""
+ # get two authenticated client with the same test user (but two different sessions)
+ client = get_auth_client()
+ client2 = get_auth_client()
+
+ # call logout with the first client and ask to be logged out from all of this user sessions
+ client.get("/logout?all=1")
+
+ # both client are logged out
+ self.assertFalse(client.session.get("username"))
+ self.assertFalse(client.session.get("authenticated"))
+ self.assertFalse(client2.session.get("username"))
+ self.assertFalse(client2.session.get("authenticated"))
+
+ def assert_redirect_to_service(self, client, response):
+ """assert logout redirect to parameter"""
+ # assert a redirection with a service
+ self.assertEqual(response.status_code, 302)
+ self.assertTrue(response.has_header("Location"))
+ self.assertEqual(response["Location"], "https://www.example.com")
+
+ response = client.get("/login")
+ self.assertEqual(response.status_code, 200)
+ # assert we are not longer logged in
+ self.assertFalse(
+ (
+ b"You have successfully logged into "
+ b"the Central Authentication Service"
+ ) in response.content
+ )
+
+ def test_logout_view_url(self):
+ """test logout redirect to url parameter"""
+ # get a client that is authenticated
+ client = get_auth_client()
+
+ # logout with an url paramer
+ response = client.get('/logout?url=https://www.example.com')
+ # we are redirected to the addresse of the url parameter
+ self.assert_redirect_to_service(client, response)
+
+ def test_logout_view_service(self):
+ """test logout redirect to service parameter"""
+ # get a client that is authenticated
+ client = get_auth_client()
+
+ # logout with a service parameter
+ response = client.get('/logout?service=https://www.example.com')
+ # we are redirected to the addresse of the service parameter
+ self.assert_redirect_to_service(client, response)
+
+ def test_logout_slo(self):
+ """test logout from a service with SLO support"""
+ parameters = []
+
+ # test normal SLO
+ # setup a simple one request http server
+ (httpd, host, port) = HttpParamsHandler.run()[0:3]
+ # build a service url depending on which port the http server has binded
+ service = "http://%s:%s" % (host, port)
+ # get a ticket requested by client and being validated by the service
+ (client, ticket) = get_validated_ticket(service)[:2]
+ # the client logout triggering the send of the SLO requests
+ client.get('/logout')
+ # we store the POST parameters send for this ticket for furthur analisys
+ parameters.append((httpd.PARAMS, ticket))
+
+ # text SLO with a single_log_out_callback
+ # setup a simple one request http server
+ (httpd, host, port) = HttpParamsHandler.run()[0:3]
+ # set the default test service pattern to use the http server port for SLO requests.
+ # in fact, this single_log_out_callback parametter is usefull to implement SLO
+ # for non http service like imap or ftp
+ self.service_pattern.single_log_out_callback = "http://%s:%s" % (host, port)
+ self.service_pattern.save()
+ # get a ticket requested by client and being validated by the service
+ (client, ticket) = get_validated_ticket(self.service)[:2]
+ # the client logout triggering the send of the SLO requests
+ client.get('/logout')
+ # we store the POST parameters send for this ticket for furthur analisys
+ parameters.append((httpd.PARAMS, ticket))
+
+ # for earch POST parameters and corresponding ticket
+ for (params, ticket) in parameters:
+ # there is a POST parameter 'logoutRequest'
+ self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
+
+ # it is a valid xml
+ root = etree.fromstring(params[b'logoutRequest'][0])
+ # contening a tag
+ self.assertTrue(
+ root.xpath(
+ "//samlp:LogoutRequest",
+ namespaces={"samlp": "urn:oasis:names:tc:SAML:2.0:protocol"}
+ )
+ )
+ # with a tag enclosing the value of the ticket
+ session_index = root.xpath(
+ "//samlp:SessionIndex",
+ namespaces={"samlp": "urn:oasis:names:tc:SAML:2.0:protocol"}
+ )
+ self.assertEqual(len(session_index), 1)
+ self.assertEqual(session_index[0].text, ticket.value)
+
+ # SLO error are displayed on logout page
+ (client, ticket) = get_validated_ticket(self.service)[:2]
+ # the client logout triggering the send of the SLO requests but
+ # not http server are listening
+ response = client.get('/logout')
+ self.assertTrue(b"Error during service logout" in response.content)
+
+ @override_settings(CAS_ENABLE_AJAX_AUTH=True)
+ def test_ajax_logout(self):
+ """
+ test ajax logout. These methode are here, but I do not really see an use case for
+ javascript logout
+ """
+ # get a client that is authenticated
+ client = get_auth_client()
+
+ # fetch the logout page with ajax on
+ response = client.get('/logout', HTTP_X_AJAX='on')
+ # we get a json telling us the user is well logged out
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.content.decode("utf8"))
+ self.assertEqual(data["status"], "success")
+ self.assertEqual(data["detail"], "logout")
+ self.assertEqual(data['session_nb'], 1)
+
+ @override_settings(CAS_ENABLE_AJAX_AUTH=True)
+ def test_ajax_logout_all_session(self):
+ """test ajax logout from a random number a sessions"""
+ # fire a random int in [2, 10[
+ nb_client = random.randint(2, 10)
+ # get this much of logged clients all for the test user
+ clients = [get_auth_client() for i in range(nb_client)]
+ # fetch the logout page with ajax on, requesting to logout from all sessions
+ response = clients[0].get('/logout?all=1', HTTP_X_AJAX='on')
+ # we get a json telling us the user is well logged out and the number of session
+ # the user has being logged out
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.content.decode("utf8"))
+ self.assertEqual(data["status"], "success")
+ self.assertEqual(data["detail"], "logout")
+ self.assertEqual(data['session_nb'], nb_client)
+
+ @override_settings(CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT=True)
+ def test_redirect_after_logout(self):
+ """Test redirect to login after logout parameter"""
+ # get a client that is authenticated
+ client = get_auth_client()
+
+ # fetch the logout page
+ response = client.get('/logout')
+ # as CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT is True, we are redirected to the login page
+ self.assertEqual(response.status_code, 302)
+ if django.VERSION < (1, 9): # pragma: no cover coverage is computed with dango 1.9
+ self.assertEqual(response["Location"], "http://testserver/login")
+ else:
+ self.assertEqual(response["Location"], "/login")
+ self.assertFalse(client.session.get("username"))
+ self.assertFalse(client.session.get("authenticated"))
+
+ @override_settings(CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT=True)
+ def test_redirect_after_logout_to_service(self):
+ """test prevalence of redirect url/service parameter over redirect to login after logout"""
+ # get a client that is authenticated
+ client = get_auth_client()
+
+ # fetch the logout page with an url parameter
+ response = client.get('/logout?url=https://www.example.com')
+ # we are redirected to the url parameter and not to the login page
+ self.assert_redirect_to_service(client, response)
+
+ # fetch the logout page with an service parameter
+ response = client.get('/logout?service=https://www.example.com')
+ # we are redirected to the service parameter and not to the login page
+ self.assert_redirect_to_service(client, response)
+
+ @override_settings(CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT=True, CAS_ENABLE_AJAX_AUTH=True)
+ def test_ajax_redirect_after_logout(self):
+ """Test ajax redirect to login after logout parameter"""
+ # get a client that is authenticated
+ client = get_auth_client()
+
+ # fetch the logout page with ajax on
+ response = client.get('/logout', HTTP_X_AJAX='on')
+ # we get a json telling us the user is well logged out. And url key is added to aks for
+ # redirection to the login page
+ self.assertEqual(response.status_code, 200)
+ data = json.loads(response.content.decode("utf8"))
+ self.assertEqual(data["status"], "success")
+ self.assertEqual(data["detail"], "logout")
+ self.assertEqual(data['session_nb'], 1)
+ self.assertEqual(data['url'], '/login')
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class AuthTestCase(TestCase):
+ """
+ Test for the auth view, used for external services
+ to validate (user, pass, service) tuples.
+ """
+ def setUp(self):
+ """preparing test context"""
+ # setting up a default test service url and pattern
+ self.service = 'https://www.example.com'
+ models.ServicePattern.objects.create(
+ name="example",
+ pattern="^https://www\.example\.com(/.*)?$"
+ )
+
+ @override_settings(CAS_AUTH_SHARED_SECRET='test')
+ def test_auth_view_goodpass(self):
+ """successful request are awsered by yes"""
+ # get a bare client
+ client = Client()
+ # post the the auth view a valid (username, password, service) and the shared secret
+ # to test the user again the service, a user is created in the database for the
+ # current session and is then deleted as the user is not authenticated
+ response = client.post(
+ '/auth',
+ {
+ 'username': settings.CAS_TEST_USER,
+ 'password': settings.CAS_TEST_PASSWORD,
+ 'service': self.service,
+ 'secret': 'test'
+ }
+ )
+ # as (username, password, service) and the hared secret are valid, we get yes as a response
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'yes\n')
+
+ @override_settings(CAS_AUTH_SHARED_SECRET='test')
+ def test_auth_view_goodpass_logged(self):
+ """successful request are awsered by yes, using a logged sessions"""
+ # same as above
+ client = get_auth_client()
+ # to test the user again the service, a user is fetch in the database for the
+ # current session and is NOT deleted as the user is currently logged.
+ # Deleting the user from the database would cause the user to be logged out as
+ # showed in the login tests
+ response = client.post(
+ '/auth',
+ {
+ 'username': settings.CAS_TEST_USER,
+ 'password': settings.CAS_TEST_PASSWORD,
+ 'service': self.service,
+ 'secret': 'test'
+ }
+ )
+ # as (username, password, service) and the hared secret are valid, we get yes as a response
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'yes\n')
+
+ @override_settings(CAS_AUTH_SHARED_SECRET='test')
+ def test_auth_view_badpass(self):
+ """ bag user password => no"""
+ client = Client()
+ response = client.post(
+ '/auth',
+ {
+ 'username': settings.CAS_TEST_USER,
+ 'password': 'badpass',
+ 'service': self.service,
+ 'secret': 'test'
+ }
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'no\n')
+
+ @override_settings(CAS_AUTH_SHARED_SECRET='test')
+ def test_auth_view_badservice(self):
+ """bad service => no"""
+ client = Client()
+ response = client.post(
+ '/auth',
+ {
+ 'username': settings.CAS_TEST_USER,
+ 'password': settings.CAS_TEST_PASSWORD,
+ 'service': 'https://www.example.org',
+ 'secret': 'test'
+ }
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'no\n')
+
+ @override_settings(CAS_AUTH_SHARED_SECRET='test')
+ def test_auth_view_badsecret(self):
+ """bad api key => no"""
+ client = Client()
+ response = client.post(
+ '/auth',
+ {
+ 'username': settings.CAS_TEST_USER,
+ 'password': settings.CAS_TEST_PASSWORD,
+ 'service': self.service,
+ 'secret': 'badsecret'
+ }
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'no\n')
+
+ def test_auth_view_badsettings(self):
+ """api not set => error"""
+ client = Client()
+ response = client.post(
+ '/auth',
+ {
+ 'username': settings.CAS_TEST_USER,
+ 'password': settings.CAS_TEST_PASSWORD,
+ 'service': self.service,
+ 'secret': 'test'
+ }
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
+
+ @override_settings(CAS_AUTH_SHARED_SECRET='test')
+ def test_auth_view_missing_parameter(self):
+ """missing parameter in request => no"""
+ client = Client()
+ params = {
+ 'username': settings.CAS_TEST_USER,
+ 'password': settings.CAS_TEST_PASSWORD,
+ 'service': self.service,
+ 'secret': 'test'
+ }
+ for key in ['username', 'password', 'service']:
+ send_params = params.copy()
+ del send_params[key]
+ response = client.post('/auth', send_params)
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'no\n')
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class ValidateTestCase(TestCase):
+ """tests for the validate view"""
+ def setUp(self):
+ """preparing test context"""
+ # setting up a default test service url and pattern
+ self.service = 'https://www.example.com'
+ self.service_pattern = models.ServicePattern.objects.create(
+ name="example",
+ pattern="^https://www\.example\.com(/.*)?$"
+ )
+ models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
+ # setting up a test service and pattern using a multi valued user attribut as username
+ # the first value of the list should be used as username
+ self.service_user_field = "https://user_field.example.com"
+ self.service_pattern_user_field = models.ServicePattern.objects.create(
+ name="user field",
+ pattern="^https://user_field\.example\.com(/.*)?$",
+ user_field="alias"
+ )
+ # setting up a test service and pattern using a single valued user attribut as username
+ self.service_user_field_alt = "https://user_field_alt.example.com"
+ self.service_pattern_user_field_alt = models.ServicePattern.objects.create(
+ name="user field alt",
+ pattern="^https://user_field_alt\.example\.com(/.*)?$",
+ user_field="nom"
+ )
+
+ def test_validate_view_ok(self):
+ """test for a valid (ticket, service)"""
+ # get a ticket waiting to be validated for self.service
+ ticket = get_user_ticket_request(self.service)[1]
+
+ # get a bare client
+ client = Client()
+ # calling the validate view with this ticket value and service
+ response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
+ # get yes as a response and the test user username
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'yes\ntest\n')
+
+ def test_validate_view_badservice(self):
+ """test for a valid ticket but bad service"""
+ ticket = get_user_ticket_request(self.service)[1]
+
+ client = Client()
+ # calling the validate view with this ticket value and another service
+ response = client.get(
+ '/validate',
+ {'ticket': ticket.value, 'service': "https://www.example.org"}
+ )
+ # the ticket service and validation service do not match, validation should fail
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'no\n')
+
+ def test_validate_view_badticket(self):
+ """test for a bad ticket but valid service"""
+ get_user_ticket_request(self.service)
+
+ client = Client()
+ # calling the validate view with another ticket value and this service
+ response = client.get(
+ '/validate',
+ {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
+ )
+ # as the ticket is bad, validation should fail
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'no\n')
+
+ def test_validate_user_field_ok(self):
+ """
+ test with a good user_field. A bad user_field (that evaluate to False)
+ wont happed cause it is filtered in the login view
+ """
+ for (service, username) in [
+ (self.service_user_field, b"demo1"),
+ (self.service_user_field_alt, b"Nymous")
+ ]:
+ ticket = get_user_ticket_request(service)[1]
+ client = Client()
+ response = client.get(
+ '/validate',
+ {'ticket': ticket.value, 'service': service}
+ )
+ self.assertEqual(response.status_code, 200)
+ # the user attribute is well used as username
+ self.assertEqual(response.content, b'yes\n' + username + b'\n')
+
+ def test_validate_missing_parameter(self):
+ """test with a missing GET parameter among [service, ticket]"""
+ ticket = get_user_ticket_request(self.service)[1]
+
+ client = Client()
+ params = {'ticket': ticket.value, 'service': self.service}
+ for key in ['ticket', 'service']:
+ send_params = params.copy()
+ del send_params[key]
+ response = client.get('/validate', send_params)
+ # if the GET request is missing the ticket or
+ # service GET parameter, validation should fail
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'no\n')
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class ValidateServiceTestCase(TestCase, XmlContent):
+ """tests for the serviceValidate view"""
+ def setUp(self):
+ """preparing test context"""
+ # for testing SingleLogOut and Proxy GrantingTicket transmission
+ # we need to use a service on localhost were we launch
+ # a simple one request http server
+ self.service = 'http://127.0.0.1:45678'
+ self.service_pattern = models.ServicePattern.objects.create(
+ name="localhost",
+ pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
+ # allow to request PGT by the service
+ proxy_callback=True
+ )
+ # tell the service pattern to transmit all the user attributes (* is a joker)
+ models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
+
+ # test service pattern using the attribute alias as username
+ self.service_user_field = "https://user_field.example.com"
+ self.service_pattern_user_field = models.ServicePattern.objects.create(
+ name="user field",
+ pattern="^https://user_field\.example\.com(/.*)?$",
+ user_field="alias"
+ )
+ # test service pattern using the attribute nom as username
+ self.service_user_field_alt = "https://user_field_alt.example.com"
+ self.service_pattern_user_field_alt = models.ServicePattern.objects.create(
+ name="user field alt",
+ pattern="^https://user_field_alt\.example\.com(/.*)?$",
+ user_field="nom"
+ )
+
+ # test service pattern only transmiting one single attributes
+ self.service_one_attribute = "https://one_attribute.example.com"
+ self.service_pattern_one_attribute = models.ServicePattern.objects.create(
+ name="one_attribute",
+ pattern="^https://one_attribute\.example\.com(/.*)?$"
+ )
+ models.ReplaceAttributName.objects.create(
+ name="nom",
+ service_pattern=self.service_pattern_one_attribute
+ )
+
+ # test service pattern testing attribute name and value replacement
+ self.service_replace_attribute_list = "https://replace_attribute_list.example.com"
+ self.service_pattern_replace_attribute_list = models.ServicePattern.objects.create(
+ name="replace_attribute_list",
+ pattern="^https://replace_attribute_list\.example\.com(/.*)?$",
+ )
+ models.ReplaceAttributValue.objects.create(
+ attribut="alias",
+ pattern="^demo",
+ replace="truc",
+ service_pattern=self.service_pattern_replace_attribute_list
+ )
+ models.ReplaceAttributName.objects.create(
+ name="alias",
+ replace="ALIAS",
+ service_pattern=self.service_pattern_replace_attribute_list
+ )
+ self.service_replace_attribute = "https://replace_attribute.example.com"
+ self.service_pattern_replace_attribute = models.ServicePattern.objects.create(
+ name="replace_attribute",
+ pattern="^https://replace_attribute\.example\.com(/.*)?$",
+ )
+ models.ReplaceAttributValue.objects.create(
+ attribut="nom",
+ pattern="N",
+ replace="P",
+ service_pattern=self.service_pattern_replace_attribute
+ )
+ models.ReplaceAttributName.objects.create(
+ name="nom",
+ replace="NOM",
+ service_pattern=self.service_pattern_replace_attribute
+ )
+
+ def test_validate_service_view_ok(self):
+ """test with a valid (ticket, service), the username and all attributes are transmited"""
+ # get a ticket from an authenticated user waiting for validation
+ ticket = get_user_ticket_request(self.service)[1]
+
+ # get a bare client
+ client = Client()
+ # requesting validation with a good (ticket, service)
+ response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
+ # the validation should succes with username settings.CAS_TEST_USER and transmit
+ # the attributes settings.CAS_TEST_ATTRIBUTES
+ self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES)
+
+ def test_validate_service_view_ok_one_attribute(self):
+ """
+ test with a valid (ticket, service), the username and
+ the 'nom' only attribute are transmited
+ """
+ # get a ticket for a service that transmit only one attribute
+ ticket = get_user_ticket_request(self.service_one_attribute)[1]
+
+ client = Client()
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': self.service_one_attribute}
+ )
+ # the validation should succed, returning settings.CAS_TEST_USER as username and a single
+ # attribute 'nom'
+ self.assert_success(
+ response,
+ settings.CAS_TEST_USER,
+ {'nom': settings.CAS_TEST_ATTRIBUTES['nom']}
+ )
+
+ def test_validate_replace_attributes(self):
+ """test with a valid (ticket, service), attributes name and value replacement"""
+ # get a ticket for a service pattern replacing attributes names
+ # nom -> NOM and value nom -> s/^N/P/ for a single valued attribute
+ ticket = get_user_ticket_request(self.service_replace_attribute)[1]
+ client = Client()
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': self.service_replace_attribute}
+ )
+ self.assert_success(
+ response,
+ settings.CAS_TEST_USER,
+ {'NOM': 'Pymous'}
+ )
+
+ # get a ticket for a service pattern replacing attributes names
+ # alias -> ALIAS and value alias -> s/demo/truc/ for a multi valued attribute
+ ticket = get_user_ticket_request(self.service_replace_attribute_list)[1]
+ client = Client()
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': self.service_replace_attribute_list}
+ )
+ self.assert_success(
+ response,
+ settings.CAS_TEST_USER,
+ {'ALIAS': ['truc1', 'truc2']}
+ )
+
+ def test_validate_service_view_badservice(self):
+ """test with a valid ticket but a bad service, the validatin should fail"""
+ # get a ticket for service A
+ ticket = get_user_ticket_request(self.service)[1]
+
+ client = Client()
+ bad_service = "https://www.example.org"
+ # try to validate it for service B
+ response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
+ # the validation should fail with error code "INVALID_SERVICE"
+ self.assert_error(
+ response,
+ "INVALID_SERVICE",
+ bad_service
+ )
+
+ def test_validate_service_view_badticket_goodprefix(self):
+ """
+ test with a good service but a bad ticket begining with ST-,
+ the validation should fail with the error (INVALID_TICKET, ticket not found)
+ """
+ get_user_ticket_request(self.service)
+
+ client = Client()
+ bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
+ response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
+ self.assert_error(
+ response,
+ "INVALID_TICKET",
+ 'ticket not found'
+ )
+
+ def test_validate_service_view_badticket_badprefix(self):
+ """
+ test with a good service bud a bad ticket not begining with ST-,
+ the validation should fail with the error (INVALID_TICKET, `the ticket`)
+ """
+ get_user_ticket_request(self.service)
+
+ client = Client()
+ bad_ticket = "RANDOM"
+ response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
+ self.assert_error(
+ response,
+ "INVALID_TICKET",
+ bad_ticket
+ )
+
+ def test_validate_service_view_ok_pgturl(self):
+ """test the retrieval of a ProxyGrantingTicket"""
+ # start a simple on request http server
+ (httpd, host, port) = HttpParamsHandler.run()[0:3]
+ # construct the service from it
+ service = "http://%s:%s" % (host, port)
+
+ # get a ticket to be validated
+ ticket = get_user_ticket_request(service)[1]
+
+ client = Client()
+ # request a PGT ticket then validating the ticket by setting the pgtUrl parameter
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': service, 'pgtUrl': service}
+ )
+ # We should have recieved the PGT via a GET request parameter on the simple http server
+ pgt_params = httpd.PARAMS
+ self.assertEqual(response.status_code, 200)
+
+ root = etree.fromstring(response.content)
+ # the validation response should return a id to match again the request transmitting the PGT
+ pgtiou = root.xpath(
+ "//cas:proxyGrantingTicket",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ self.assertEqual(len(pgtiou), 1)
+ # the matching id for making corresponde one PGT to a validatin response should match
+ self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
+ # the PGT is present in the receive GET requests parameters
+ self.assertTrue("pgtId" in pgt_params)
+
+ def test_validate_service_pgturl_sslerror(self):
+ """test the retrieval of a ProxyGrantingTicket with a SSL error on the pgtUrl"""
+ (host, port) = HttpParamsHandler.run()[1:3]
+ # is fact the service listen on http and not https raisin a SSL Protocol Error
+ # but other SSL/TLS error should behave the same
+ service = "https://%s:%s" % (host, port)
+
+ ticket = get_user_ticket_request(service)[1]
+
+ client = Client()
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': service, 'pgtUrl': service}
+ )
+ # The pgtUrl is validated: it must be localhost or have valid x509 certificat and
+ # certificat validation should succed. Moreother, pgtUrl should match a service pattern
+ # with proxy_callback set to True
+ self.assert_error(
+ response,
+ "INVALID_PROXY_CALLBACK",
+ )
+
+ def test_validate_service_pgturl_404(self):
+ """
+ test the retrieval on a ProxyGrantingTicket then to pgtUrl return a http error.
+ PGT creation should be aborted but the ticket still be valid
+ """
+ (host, port) = Http404Handler.run()[1:3]
+ service = "http://%s:%s" % (host, port)
+
+ ticket = get_user_ticket_request(service)[1]
+ client = Client()
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': service, 'pgtUrl': service}
+ )
+ # The ticket is successfully validated
+ root = self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES)
+ # but no PGT is transmitted
+ pgtiou = root.xpath(
+ "//cas:proxyGrantingTicket",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ self.assertFalse(pgtiou)
+
+ def test_validate_service_pgturl_bad_proxy_callback(self):
+ """test the retrieval of a ProxyGrantingTicket, not allowed pgtUrl should be denied"""
+ self.service_pattern.proxy_callback = False
+ self.service_pattern.save()
+ ticket = get_user_ticket_request(self.service)[1]
+ client = Client()
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
+ )
+ self.assert_error(
+ response,
+ "INVALID_PROXY_CALLBACK",
+ "callback url not allowed by configuration"
+ )
+
+ self.service_pattern.proxy_callback = True
+
+ ticket = get_user_ticket_request(self.service)[1]
+ client = Client()
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': self.service, 'pgtUrl': "https://www.example.org"}
+ )
+ self.assert_error(
+ response,
+ "INVALID_PROXY_CALLBACK",
+ "callback url not allowed by configuration"
+ )
+
+ def test_validate_user_field_ok(self):
+ """
+ test with a good user_field. A bad user_field (that evaluate to False)
+ wont happed cause it is filtered in the login view
+ """
+ for (service, username) in [
+ (self.service_user_field, settings.CAS_TEST_ATTRIBUTES["alias"][0]),
+ (self.service_user_field_alt, settings.CAS_TEST_ATTRIBUTES["nom"])
+ ]:
+ # requesting a ticket for a service url matched by a service pattern using a user
+ # attribute as username
+ ticket = get_user_ticket_request(service)[1]
+ client = Client()
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': service}
+ )
+ # The validate shoudl be successful with specified username and no attributes transmited
+ self.assert_success(
+ response,
+ username,
+ {}
+ )
+
+ def test_validate_missing_parameter(self):
+ """test with a missing GET parameter among [service, ticket]"""
+ ticket = get_user_ticket_request(self.service)[1]
+
+ client = Client()
+ params = {'ticket': ticket.value, 'service': self.service}
+ for key in ['ticket', 'service']:
+ send_params = params.copy()
+ del send_params[key]
+ response = client.get('/serviceValidate', send_params)
+ # a validation request with a missing GET parameter should fail
+ self.assert_error(
+ response,
+ "INVALID_REQUEST",
+ "you must specify a service and a ticket"
+ )
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
+ """tests for the proxy view"""
+ def setUp(self):
+ """preparing test context"""
+ # we prepare a bunch a service url and service patterns for tests
+ self.setup_service_patterns(proxy=True)
+
+ # set the default service pattern to localhost to be able to retrieve PGT
+ self.service = 'http://127.0.0.1'
+ self.service_pattern = models.ServicePattern.objects.create(
+ name="localhost",
+ pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
+ proxy=True,
+ proxy_callback=True
+ )
+ # transmit all attributes
+ models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
+
+ def test_validate_proxy_ok(self):
+ """
+ Get a PGT, get a proxy ticket, validate it. Validation should succeed and
+ show the proxy service URL.
+ """
+ # we directrly get a ProxyGrantingTicket
+ params = get_pgt()
+
+ # We try get a proxy ticket with our PGT
+ client1 = Client()
+ # for what we send a GET request to /proxy with ge PGT and the target service for which
+ # we want a ProxyTicket to.
+ response = client1.get(
+ '/proxy',
+ {'pgt': params['pgtId'], 'targetService': "https://www.example.com"}
+ )
+ self.assertEqual(response.status_code, 200)
+
+ # we should sucessfully reteive a PT
+ root = etree.fromstring(response.content)
+ sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
+ self.assertTrue(sucess)
+
+ proxy_ticket = root.xpath(
+ "//cas:proxyTicket",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ self.assertEqual(len(proxy_ticket), 1)
+ proxy_ticket = proxy_ticket[0].text
+
+ # validate the proxy ticket with the service for which is was emitted
+ client2 = Client()
+ response = client2.get(
+ '/proxyValidate',
+ {'ticket': proxy_ticket, 'service': "https://www.example.com"}
+ )
+ # validation should succeed and return settings.CAS_TEST_USER as username
+ # and settings.CAS_TEST_ATTRIBUTES as attributes
+ root = self.assert_success(
+ response,
+ settings.CAS_TEST_USER,
+ settings.CAS_TEST_ATTRIBUTES
+ )
+
+ # in the PT validation response, it should have the service url of the PGY
+ proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
+ self.assertEqual(len(proxies), 1)
+ proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
+ self.assertEqual(len(proxy), 1)
+ self.assertEqual(proxy[0].text, params["service"])
+
+ def test_validate_proxy_bad_pgt(self):
+ """Try to get a ProxyTicket with a bad PGT. The PT generation should fail"""
+ # we directrly get a ProxyGrantingTicket
+ params = get_pgt()
+ client = Client()
+ response = client.get(
+ '/proxy',
+ {
+ 'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
+ 'targetService': params['service']
+ }
+ )
+ self.assert_error(
+ response,
+ "INVALID_TICKET",
+ "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
+ )
+
+ def test_validate_proxy_bad_service(self):
+ """
+ Try to get a ProxyTicket for a denied service and
+ a service that do not allow PT. The PT generation should fail.
+ """
+ # we directrly get a ProxyGrantingTicket
+ params = get_pgt()
+
+ # try to get a PT for a denied service
+ client1 = Client()
+ response = client1.get(
+ '/proxy',
+ {'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
+ )
+ self.assert_error(
+ response,
+ "UNAUTHORIZED_SERVICE",
+ "https://www.example.org"
+ )
+
+ # try to get a PT for a service that do not allow PT
+ self.service_pattern.proxy = False
+ self.service_pattern.save()
+
+ client2 = Client()
+ response = client2.get(
+ '/proxy',
+ {'pgt': params['pgtId'], 'targetService': params['service']}
+ )
+
+ self.assert_error(
+ response,
+ "UNAUTHORIZED_SERVICE",
+ 'the service %s do not allow proxy ticket' % params['service']
+ )
+
+ self.service_pattern.proxy = True
+ self.service_pattern.save()
+
+ def test_proxy_unauthorized_user(self):
+ """
+ Try to get a PT for services that do not allow the current user:
+ * first with a service that restrict allowed username
+ * second with a service requiring somes conditions on the user attributes
+ * third with a service using a particular user attribute as username
+ All this tests should fail
+ """
+ # we directrly get a ProxyGrantingTicket
+ params = get_pgt()
+
+ for service in [
+ # do ot allow the test username
+ self.service_restrict_user_fail,
+ # require the 'nom' attribute to be 'toto'
+ self.service_filter_fail,
+ # want to use the non-exitant 'uid' attribute as username
+ self.service_field_needed_fail
+ ]:
+ client = Client()
+ response = client.get(
+ '/proxy',
+ {'pgt': params['pgtId'], 'targetService': service}
+ )
+ # PT generation should fail
+ self.assert_error(
+ response,
+ "UNAUTHORIZED_USER",
+ 'User %s not allowed on %s' % (settings.CAS_TEST_USER, service)
+ )
+
+ def test_proxy_missing_parameter(self):
+ """Try to get a PGT with some missing GET parameters. The PT should not be emited"""
+ params = get_pgt()
+ base_params = {'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
+ for key in ["pgt", 'targetService']:
+ send_params = base_params.copy()
+ del send_params[key]
+ client = Client()
+ response = client.get("/proxy", send_params)
+ self.assert_error(
+ response,
+ "INVALID_REQUEST",
+ 'you must specify and pgt and targetService'
+ )
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
+ """tests for the proxy view"""
+ def setUp(self):
+ """preparing test context"""
+ # we prepare a bunch a service url and service patterns for tests
+ self.setup_service_patterns(proxy=True)
+
+ # special service pattern for retrieving a PGT
+ self.service_pgt = 'http://127.0.0.1'
+ self.service_pattern_pgt = models.ServicePattern.objects.create(
+ name="localhost",
+ pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
+ proxy=True,
+ proxy_callback=True
+ )
+ models.ReplaceAttributName.objects.create(
+ name="*",
+ service_pattern=self.service_pattern_pgt
+ )
+
+ # template for the XML POST need to be send to validate a ticket using SAML 1.1
+ xml_template = """
+
+
+
+
+ %(ticket)s
+
+
+"""
+
+ def assert_success(self, response, username, original_attributes):
+ """assert ticket validation success"""
+ self.assertEqual(response.status_code, 200)
+ # on validation success, the response should have a StatusCode set to Success
+ root = etree.fromstring(response.content)
+ success = root.xpath(
+ "//samlp:StatusCode",
+ namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"}
+ )
+ self.assertEqual(len(success), 1)
+ self.assertTrue(success[0].attrib['Value'].endswith(":Success"))
+
+ # the user username should be return whithin tags
+ user = root.xpath(
+ "//samla:NameIdentifier",
+ namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"}
+ )
+ self.assertTrue(user)
+ self.assertEqual(user[0].text, username)
+
+ # the returned attributes should match original_attributes
+ attributes = root.xpath(
+ "//samla:AttributeStatement/samla:Attribute",
+ namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"}
+ )
+ attrs = set()
+ for attr in attributes:
+ attrs.add((attr.attrib['AttributeName'], attr.getchildren()[0].text))
+ original = set()
+ for key, value in original_attributes.items():
+ if isinstance(value, list):
+ for subval in value:
+ original.add((key, subval))
+ else:
+ original.add((key, value))
+ self.assertEqual(original, attrs)
+
+ def assert_error(self, response, code, msg=None):
+ """assert ticket validation error"""
+ self.assertEqual(response.status_code, 200)
+ # on error the status code value should be the one provider in `code`
+ root = etree.fromstring(response.content)
+ error = root.xpath(
+ "//samlp:StatusCode",
+ namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"}
+ )
+ self.assertEqual(len(error), 1)
+ self.assertTrue(error[0].attrib['Value'].endswith(":%s" % code))
+ # it may have an error message
+ if msg is not None:
+ self.assertEqual(error[0].text, msg)
+
+ def test_saml_ok(self):
+ """
+ test with a valid (ticket, service), with a ST and a PT,
+ the username and all attributes are transmited"""
+ tickets = [
+ # return a ServiceTicket (standard ticket) waiting for validation
+ get_user_ticket_request(self.service)[1],
+ # return a PT waiting for validation
+ get_proxy_ticket(self.service)
+ ]
+
+ for ticket in tickets:
+ client = Client()
+ # we send the POST validation requests
+ response = client.post(
+ '/samlValidate?TARGET=%s' % self.service,
+ self.xml_template % {
+ 'ticket': ticket.value,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ # and it should succeed
+ self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES)
+
+ def test_saml_ok_user_field(self):
+ """test with a valid(ticket, service), use a attributes as transmitted username"""
+ for (service, username) in [
+ (self.service_field_needed_success, settings.CAS_TEST_ATTRIBUTES['alias'][0]),
+ (self.service_field_needed_success_alt, settings.CAS_TEST_ATTRIBUTES['nom'])
+ ]:
+ ticket = get_user_ticket_request(service)[1]
+
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % service,
+ self.xml_template % {
+ 'ticket': ticket.value,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_success(response, username, {})
+
+ def test_saml_bad_ticket(self):
+ """test validation with a bad ST and a bad PT, validation should fail"""
+ tickets = [utils.gen_st(), utils.gen_pt()]
+
+ for ticket in tickets:
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % self.service,
+ self.xml_template % {
+ 'ticket': ticket,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_error(
+ response,
+ "AuthnFailed",
+ 'ticket %s not found' % ticket
+ )
+
+ def test_saml_bad_ticket_prefix(self):
+ """test validation with a bad ticket prefix. Validation should fail with 'AuthnFailed'"""
+ bad_ticket = "RANDOM-NOT-BEGINING-WITH-ST-OR-ST"
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % self.service,
+ self.xml_template % {
+ 'ticket': bad_ticket,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_error(
+ response,
+ "AuthnFailed",
+ 'ticket %s should begin with PT- or ST-' % bad_ticket
+ )
+
+ def test_saml_bad_target(self):
+ """test with a valid ticket, but using a bad target, validation should fail"""
+ bad_target = "https://www.example.org"
+ ticket = get_user_ticket_request(self.service)[1]
+
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % bad_target,
+ self.xml_template % {
+ 'ticket': ticket.value,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_error(
+ response,
+ "AuthnFailed",
+ 'TARGET %s do not match ticket service' % bad_target
+ )
+
+ def test_saml_bad_xml(self):
+ """test validation with a bad xml request, validation should fail"""
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % self.service,
+ "",
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_error(response, 'VersionMismatch')
diff --git a/cas_server/tests/urls.py b/cas_server/tests/urls.py
new file mode 100644
index 0000000..a9ed25c
--- /dev/null
+++ b/cas_server/tests/urls.py
@@ -0,0 +1,22 @@
+"""cas URL Configuration
+
+The `urlpatterns` list routes URLs to views. For more information please see:
+ https://docs.djangoproject.com/en/1.9/topics/http/urls/
+Examples:
+Function views
+ 1. Add an import: from my_app import views
+ 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
+Class-based views
+ 1. Add an import: from other_app.views import Home
+ 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
+Including another URLconf
+ 1. Import the include() function: from django.conf.urls import url, include, include
+ 2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
+"""
+from django.conf.urls import url, include
+from django.contrib import admin
+
+urlpatterns = [
+ url(r'^admin/', admin.site.urls),
+ url(r'^', include('cas_server.urls', namespace='cas_server')),
+]
diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py
new file mode 100644
index 0000000..bd692e9
--- /dev/null
+++ b/cas_server/tests/utils.py
@@ -0,0 +1,180 @@
+# ⁻*- coding: utf-8 -*-
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
+# more details.
+#
+# You should have received a copy of the GNU General Public License version 3
+# along with this program; if not, write to the Free Software Foundation, Inc., 51
+# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+# (c) 2016 Valentin Samir
+"""Some utils functions for tests"""
+from cas_server.default_settings import settings
+
+from django.test import Client
+
+import cgi
+from threading import Thread
+from lxml import etree
+from six.moves import BaseHTTPServer
+from six.moves.urllib.parse import urlparse, parse_qsl
+
+from cas_server import models
+
+
+def copy_form(form):
+ """Copy form value into a dict"""
+ params = {}
+ for field in form:
+ if field.value():
+ params[field.name] = field.value()
+ else:
+ params[field.name] = ""
+ return params
+
+
+def get_login_page_params(client=None):
+ """Return a client and the POST params for the client to login"""
+ if client is None:
+ client = Client()
+ response = client.get('/login')
+ params = copy_form(response.context["form"])
+ return client, params
+
+
+def get_auth_client(**update):
+ """return a authenticated client"""
+ client, params = get_login_page_params()
+ params["username"] = settings.CAS_TEST_USER
+ params["password"] = settings.CAS_TEST_PASSWORD
+ params.update(update)
+
+ client.post('/login', params)
+ assert client.session.get("authenticated")
+
+ return client
+
+
+def get_user_ticket_request(service):
+ """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
+ client = get_auth_client()
+ response = client.get("/login", {"service": service})
+ ticket_value = response['Location'].split('ticket=')[-1]
+ user = models.User.objects.get(
+ username=settings.CAS_TEST_USER,
+ session_key=client.session.session_key
+ )
+ ticket = models.ServiceTicket.objects.get(value=ticket_value)
+ return (user, ticket, client)
+
+
+def get_validated_ticket(service):
+ """Return a tick that has being already validated. Used to test SLO"""
+ (ticket, auth_client) = get_user_ticket_request(service)[1:3]
+
+ client = Client()
+ response = client.get('/validate', {'ticket': ticket.value, 'service': service})
+ assert (response.status_code == 200)
+ assert (response.content == b'yes\ntest\n')
+
+ ticket = models.ServiceTicket.objects.get(value=ticket.value)
+ return (auth_client, ticket)
+
+
+def get_pgt():
+ """return a dict contening a service, user and PGT ticket for this service"""
+ (httpd, host, port) = HttpParamsHandler.run()[0:3]
+ service = "http://%s:%s" % (host, port)
+
+ (user, ticket) = get_user_ticket_request(service)[:2]
+
+ client = Client()
+ client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
+ params = httpd.PARAMS
+
+ params["service"] = service
+ params["user"] = user
+
+ return params
+
+
+def get_proxy_ticket(service):
+ """Return a ProxyTicket waiting for validation"""
+ params = get_pgt()
+
+ # get a proxy ticket
+ client = Client()
+ response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
+ root = etree.fromstring(response.content)
+ proxy_ticket = root.xpath(
+ "//cas:proxyTicket",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ proxy_ticket = proxy_ticket[0].text
+ ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
+ return ticket
+
+
+class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
+ """
+ A simple http server that return 200 on GET or POST
+ and store GET or POST parameters. Used in unit tests
+ """
+
+ def do_GET(self):
+ """Called on a GET request on the BaseHTTPServer"""
+ self.send_response(200)
+ self.send_header(b"Content-type", "text/plain")
+ self.end_headers()
+ self.wfile.write(b"ok")
+ url = urlparse(self.path)
+ params = dict(parse_qsl(url.query))
+ self.server.PARAMS = params
+
+ def do_POST(self):
+ """Called on a POST request on the BaseHTTPServer"""
+ ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
+ if ctype == 'multipart/form-data':
+ postvars = cgi.parse_multipart(self.rfile, pdict)
+ elif ctype == 'application/x-www-form-urlencoded':
+ length = int(self.headers.get('content-length'))
+ postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
+ else:
+ postvars = {}
+ self.server.PARAMS = postvars
+
+ def log_message(self, *args):
+ """silent any log message"""
+ return
+
+ @classmethod
+ def run(cls):
+ """Run a BaseHTTPServer using this class as handler"""
+ server_class = BaseHTTPServer.HTTPServer
+ httpd = server_class(("127.0.0.1", 0), cls)
+ (host, port) = httpd.socket.getsockname()
+
+ def lauch():
+ """routine to lauch in a background thread"""
+ httpd.handle_request()
+ httpd.server_close()
+
+ httpd_thread = Thread(target=lauch)
+ httpd_thread.daemon = True
+ httpd_thread.start()
+ return (httpd, host, port)
+
+
+class Http404Handler(HttpParamsHandler):
+ """A simple http server that always return 404 not found. Used in unit tests"""
+ def do_GET(self):
+ """Called on a GET request on the BaseHTTPServer"""
+ self.send_response(404)
+ self.send_header(b"Content-type", "text/plain")
+ self.end_headers()
+ self.wfile.write(b"error 404 not found")
+
+ def do_POST(self):
+ """Called on a POST request on the BaseHTTPServer"""
+ return self.do_GET()
diff --git a/cas_server/urls.py b/cas_server/urls.py
index 982ef9d..8b7f762 100644
--- a/cas_server/urls.py
+++ b/cas_server/urls.py
@@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
-# (c) 2015 Valentin Samir
+# (c) 2015-2016 Valentin Samir
"""urls for the app"""
from django.conf.urls import patterns, url
from django.views.generic import RedirectView
diff --git a/cas_server/utils.py b/cas_server/utils.py
index c8b345b..ee7b5e5 100644
--- a/cas_server/utils.py
+++ b/cas_server/utils.py
@@ -1,4 +1,4 @@
-# ⁻*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
@@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
-# (c) 2015 Valentin Samir
+# (c) 2015-2016 Valentin Samir
"""Some util function for the app"""
from .default_settings import settings
@@ -23,18 +23,19 @@ import hashlib
import crypt
import base64
import six
-from threading import Thread
+
from importlib import import_module
-from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
def context(params):
+ """Function that add somes variable to the context before template rendering"""
params["settings"] = settings
return params
def json_response(request, data):
+ """Wrapper dumping `data` to a json and sending it to the user with an HttpResponse"""
data["messages"] = []
for msg in messages.get_messages(request):
data["messages"].append({'message': msg.message, 'level': msg.level_tag})
@@ -64,6 +65,7 @@ def redirect_params(url_name, params=None):
def reverse_params(url_name, params=None, **kwargs):
+ """compule the reverse url or `url_name` and add GET parameters from `params` to it"""
url = reverse(url_name, **kwargs)
params = urlencode(params if params else {})
return url + "?%s" % params
@@ -83,10 +85,13 @@ def update_url(url, params):
url_parts = list(urlparse(url))
query = dict(parse_qsl(url_parts[4]))
query.update(params)
- url_parts[4] = urlencode(query)
- for i, url_part in enumerate(url_parts):
- if not isinstance(url_part, bytes):
- url_parts[i] = url_part.encode('utf-8')
+ # make the params order deterministic
+ query = list(query.items())
+ query.sort()
+ url_query = urlencode(query)
+ if not isinstance(url_query, bytes): # pragma: no cover in python3 urlencode return an unicode
+ url_query = url_query.encode("utf-8")
+ url_parts[4] = url_query
return urlunparse(url_parts).decode('utf-8')
@@ -147,35 +152,25 @@ def gen_saml_id():
return _gen_ticket('_')
-class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
- PARAMS = {}
-
- def do_GET(self):
- self.send_response(200)
- self.send_header(b"Content-type", "text/plain")
- self.end_headers()
- self.wfile.write(b"ok")
- url = urlparse(self.path)
- params = dict(parse_qsl(url.query))
- PGTUrlHandler.PARAMS.update(params)
-
- def log_message(self, *args):
- return
-
- @staticmethod
- def run():
- server_class = BaseHTTPServer.HTTPServer
- httpd = server_class(("127.0.0.1", 0), PGTUrlHandler)
- (host, port) = httpd.socket.getsockname()
-
- def lauch():
- httpd.handle_request()
- httpd.server_close()
-
- httpd_thread = Thread(target=lauch)
- httpd_thread.daemon = True
- httpd_thread.start()
- return (httpd_thread, host, port)
+def crypt_salt_is_valid(salt):
+ """Return True is salt is valid has a crypt salt, False otherwise"""
+ if len(salt) < 2:
+ return False
+ else:
+ if salt[0] == '$':
+ if salt[1] == '$':
+ return False
+ else:
+ if '$' not in salt[1:]:
+ return False
+ else:
+ hashed = crypt.crypt("", salt)
+ if not hashed or '$' not in hashed[1:]:
+ return False
+ else:
+ return True
+ else:
+ return True
class LdapHashUserPassword(object):
@@ -268,7 +263,7 @@ class LdapHashUserPassword(object):
if salt is None or salt == b"":
salt = b""
cls._test_scheme_nosalt(scheme)
- elif salt is not None:
+ else:
cls._test_scheme_salt(scheme)
try:
return scheme + base64.b64encode(
@@ -278,9 +273,9 @@ class LdapHashUserPassword(object):
if six.PY3:
password = password.decode(charset)
salt = salt.decode(charset)
- hashed_password = crypt.crypt(password, salt)
- if hashed_password is None:
+ if not crypt_salt_is_valid(salt):
raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
+ hashed_password = crypt.crypt(password, salt)
if six.PY3:
hashed_password = hashed_password.encode(charset)
return scheme + hashed_password
@@ -302,7 +297,7 @@ class LdapHashUserPassword(object):
if scheme in cls.schemes_nosalt:
return b""
elif scheme == b'{CRYPT}':
- return b'$'.join(hashed_passord.split(b'$', 3)[:-1])
+ return b'$'.join(hashed_passord.split(b'$', 3)[:-1])[len(scheme):]
else:
hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
if len(hashed_passord) < cls._schemes_to_len[scheme]:
@@ -324,7 +319,7 @@ def check_password(method, password, hashed_password, charset):
elif method == "crypt":
if hashed_password.startswith(b'$'):
salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
- elif hashed_password.startswith(b'_'):
+ elif hashed_password.startswith(b'_'): # pragma: no cover old BSD format not supported
salt = hashed_password[:9]
else:
salt = hashed_password[:2]
@@ -332,9 +327,9 @@ def check_password(method, password, hashed_password, charset):
password = password.decode(charset)
salt = salt.decode(charset)
hashed_password = hashed_password.decode(charset)
- crypted_password = crypt.crypt(password, salt)
- if crypted_password is None:
+ if not crypt_salt_is_valid(salt):
raise ValueError("System crypt implementation do not support the salt %r" % salt)
+ crypted_password = crypt.crypt(password, salt)
return crypted_password == hashed_password
elif method == "ldap":
scheme = LdapHashUserPassword.get_scheme(hashed_password)
diff --git a/cas_server/views.py b/cas_server/views.py
index 2b33a6c..94ee0f0 100644
--- a/cas_server/views.py
+++ b/cas_server/views.py
@@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
-# (c) 2015 Valentin Samir
+# (c) 2015-2016 Valentin Samir
"""views for the app"""
from .default_settings import settings
@@ -105,10 +105,11 @@ class LogoutView(View, LogoutMixin):
service = None
def init_get(self, request):
+ """Initialize GET received parameters"""
self.request = request
self.service = request.GET.get('service')
self.url = request.GET.get('url')
- self.ajax = 'HTTP_X_AJAX' in request.META
+ self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
def get(self, request, *args, **kwargs):
"""methode called on GET request on this view"""
@@ -196,24 +197,30 @@ class LoginView(View, LogoutMixin):
USER_NOT_AUTHENTICATED = 6
def init_post(self, request):
+ """Initialize POST received parameters"""
self.request = request
self.service = request.POST.get('service')
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
self.gateway = request.POST.get('gateway')
self.method = request.POST.get('method')
- self.ajax = 'HTTP_X_AJAX' in request.META
+ self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
if request.POST.get('warned') and request.POST['warned'] != "False":
self.warned = True
+ self.warn = request.POST.get('warn')
- def check_lt(self):
- # save LT for later check
- lt_valid = self.request.session.get('lt', [])
- lt_send = self.request.POST.get('lt')
- # generate a new LT (by posting the LT has been consumed)
+ def gen_lt(self):
+ """Generate a new LoginTicket and add it to the list of valid LT for the user"""
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:]
+ def check_lt(self):
+ """Check is the POSTed LoginTicket is valid, if yes invalide it"""
+ # save LT for later check
+ lt_valid = self.request.session.get('lt', [])
+ lt_send = self.request.POST.get('lt')
+ # generate a new LT (by posting the LT has been consumed)
+ self.gen_lt()
# check if send LT is valid
if lt_valid is None or lt_send not in lt_valid:
return False
@@ -238,7 +245,7 @@ class LoginView(View, LogoutMixin):
username=self.request.session['username'],
session_key=self.request.session.session_key
)
- self.user.save()
+ self.user.save() # pragma: no cover (should not happend)
except models.User.DoesNotExist:
self.user = models.User.objects.create(
username=self.request.session['username'],
@@ -250,10 +257,15 @@ class LoginView(View, LogoutMixin):
elif ret == self.USER_ALREADY_LOGGED:
pass
else:
- raise EnvironmentError("invalid output for LoginView.process_post")
+ raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover
return self.common()
def process_post(self):
+ """
+ Analyse the POST request:
+ * check that the LoginTicket is valid
+ * check that the user sumited credentials are valid
+ """
if not self.check_lt():
values = self.request.POST.copy()
# if not set a new LT and fail
@@ -280,12 +292,14 @@ class LoginView(View, LogoutMixin):
return self.USER_ALREADY_LOGGED
def init_get(self, request):
+ """Initialize GET received parameters"""
self.request = request
self.service = request.GET.get('service')
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
self.gateway = request.GET.get('gateway')
self.method = request.GET.get('method')
- self.ajax = 'HTTP_X_AJAX' in request.META
+ self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
+ self.warn = request.GET.get('warn')
def get(self, request, *args, **kwargs):
"""methode called on GET request on this view"""
@@ -294,22 +308,24 @@ class LoginView(View, LogoutMixin):
return self.common()
def process_get(self):
- # generate a new LT if none is present
- self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
-
+ """Analyse the GET request"""
+ # generate a new LT
+ self.gen_lt()
if not self.request.session.get("authenticated") or self.renew:
self.init_form()
return self.USER_NOT_AUTHENTICATED
return self.USER_AUTHENTICATED
def init_form(self, values=None):
+ """Initialization of the good form depending of POST and GET parameters"""
self.form = forms.UserCredential(
values,
initial={
'service': self.service,
'method': self.method,
'warn': self.request.session.get("warn"),
- 'lt': self.request.session['lt'][-1]
+ 'lt': self.request.session['lt'][-1],
+ 'renew': self.renew
}
)
@@ -351,7 +367,7 @@ class LoginView(View, LogoutMixin):
redirect_url = self.user.get_service_url(
self.service,
service_pattern,
- renew=self.renew
+ renew=self.renewed
)
if not self.ajax:
return HttpResponseRedirect(redirect_url)
@@ -580,12 +596,9 @@ class Validate(View):
ticket.service_pattern.user_field
)
if isinstance(username, list):
- try:
- username = username[0]
- except IndexError:
- username = None
- if not username:
- username = ""
+ # the list is not empty because we wont generate a ticket with a user_field
+ # that evaluate to False
+ username = username[0]
else:
username = ticket.user.username
return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
@@ -661,6 +674,10 @@ class ValidateService(View, AttributesMixin):
params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field
)
+ if isinstance(params['username'], list):
+ # the list is not empty because we wont generate a ticket with a user_field
+ # that evaluate to False
+ params['username'] = params['username'][0]
if self.pgt_url and (
self.pgt_url.startswith("https://") or
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
@@ -762,9 +779,12 @@ class ValidateService(View, AttributesMixin):
params,
content_type="text/xml; charset=utf-8"
)
- except requests.exceptions.SSLError as error:
+ except requests.exceptions.RequestException as error:
error = utils.unpack_nested_exception(error)
- raise ValidateError('INVALID_PROXY_CALLBACK', str(error))
+ raise ValidateError(
+ 'INVALID_PROXY_CALLBACK',
+ "%s: %s" % (type(error), str(error))
+ )
else:
raise ValidateError(
'INVALID_PROXY_CALLBACK',
@@ -844,7 +864,7 @@ class Proxy(View):
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
raise ValidateError(
'UNAUTHORIZED_USER',
- '%s not allowed on %s' % (ticket.user, self.target_service)
+ 'User %s not allowed on %s' % (ticket.user.username, self.target_service)
)
@@ -903,11 +923,15 @@ class SamlValidate(View, AttributesMixin):
'username': self.ticket.user.username,
'attributes': attributes
}
- if self.ticket.service_pattern.user_field and \
- self.ticket.user.attributs.get(self.ticket.service_pattern.user_field):
+ if (self.ticket.service_pattern.user_field and
+ self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field
)
+ if isinstance(params['username'], list):
+ # the list is not empty because we wont generate a ticket with a user_field
+ # that evaluate to False
+ params['username'] = params['username'][0]
logger.info(
"SamlValidate: ticket %s validated for user %s on service %s." % (
self.ticket.value,
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..74acafe
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,5 @@
+[pytest]
+testpaths = cas_server/tests/
+DJANGO_SETTINGS_MODULE = cas_server.tests.settings
+norecursedirs = .* build dist docs
+python_paths = .
diff --git a/requirements-dev.txt b/requirements-dev.txt
index e6ef993..3cf4247 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,10 +1,12 @@
-tox==1.8.1
-pytest==2.6.4
-pytest-django==2.7.0
-pytest-pythonpath==0.3
+setuptools>=5.5
+tox>=1.8.1
+pytest>=2.6.4
+pytest-django>=2.8.0
+pytest-pythonpath>=0.3
+pytest-cov>=2.2.1
requests>=2.4
-django-picklefield>=0.3.1
requests_futures>=0.9.5
+django-picklefield>=0.3.1
django-bootstrap3>=5.4
lxml>=3.4
six>=1
diff --git a/run_tests b/run_tests
deleted file mode 100755
index 4ea21ee..0000000
--- a/run_tests
+++ /dev/null
@@ -1,22 +0,0 @@
-#!/usr/bin/env python
-import os, sys
-import django
-from django.conf import settings
-
-import settings_tests
-
-settings.configure(**settings_tests.__dict__)
-django.setup()
-
-try:
- # Django <= 1.8
- from django.test.simple import DjangoTestSuiteRunner
- test_runner = DjangoTestSuiteRunner(verbosity=1)
-except ImportError:
- # Django >= 1.8
- from django.test.runner import DiscoverRunner
- test_runner = DiscoverRunner(verbosity=1)
-
-failures = test_runner.run_tests(['cas_server'])
-if failures:
- sys.exit(failures)
diff --git a/setup.py b/setup.py
index a9826c7..0548d66 100644
--- a/setup.py
+++ b/setup.py
@@ -34,7 +34,8 @@ setup(
version='0.4.4',
packages=[
'cas_server', 'cas_server.migrations',
- 'cas_server.management', 'cas_server.management.commands'
+ 'cas_server.management', 'cas_server.management.commands',
+ 'cas_server.tests'
],
include_package_data=True,
license='GPLv3',
diff --git a/tox.ini b/tox.ini
index 0b65c56..ad0fec7 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,12 +1,12 @@
[tox]
envlist=
+ flake8,
py27-django17,
py27-django18,
py27-django19,
py34-django17,
py34-django18,
py34-django19,
- flake8,
[flake8]
max-line-length=100
@@ -17,7 +17,7 @@ deps =
-r{toxinidir}/requirements-dev.txt
[testenv]
-commands=python run_tests {posargs:tests}
+commands=py.test {posargs:cas_server/tests/}
[testenv:py27-django17]
basepython=python2.7
@@ -60,3 +60,13 @@ basepython=python
deps=flake8
commands=flake8 {toxinidir}/cas_server
+[testenv:coverage]
+basepython=python
+passenv=CODACY_PROJECT_TOKEN
+deps=
+ -r{toxinidir}/requirements-dev.txt
+ codacy-coverage
+commands=
+ py.test --cov=cas_server --cov-report xml
+ python-codacy-coverage -r {toxinidir}/coverage.xml
+