diff --git a/.coveragerc b/.coveragerc index f11c9de..8f6e752 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,11 @@ +[run] +branch = True +source = cas_server +omit = + cas_server/migrations* + cas_server/management/* + cas_server/tests/* + [report] exclude_lines = pragma: no cover @@ -5,3 +13,4 @@ exclude_lines = def __unicode__ raise AssertionError raise NotImplementedError + if six.PY3: diff --git a/.travis.yml b/.travis.yml index b0bdd46..943f5b5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,19 +2,19 @@ language: python python: - "2.7" env: - global: - - PIP_DOWNLOAD_CACHE=$HOME/.pip_cache matrix: + - TOX_ENV=coverage + - TOX_ENV=flake8 - TOX_ENV=py27-django17 - TOX_ENV=py27-django18 - TOX_ENV=py27-django19 - TOX_ENV=py34-django17 - TOX_ENV=py34-django18 - TOX_ENV=py34-django19 - - TOX_ENV=flake8 cache: directories: - - $HOME/.pip-cache/ + - $HOME/.cache/pip/ + - $HOME/build/nitmir/django-cas-server/.tox/ install: - "travis_retry pip install setuptools --upgrade" - "pip install tox" @@ -22,4 +22,3 @@ script: - tox -e $TOX_ENV after_script: - cat .tox/$TOX_ENV/log/*.log - diff --git a/Makefile b/Makefile index 9088fba..d0f9165 100644 --- a/Makefile +++ b/Makefile @@ -1,11 +1,15 @@ -.PHONY: clean build install dist test_venv test_project +.PHONY: build dist VERSION=`python setup.py -V` build: python setup.py build -install: - python setup.py install +install: dist + pip -V + pip install --no-deps --upgrade --force-reinstall --find-links ./dist/django-cas-server-${VERSION}.tar.gz django-cas-server + +uninstall: + pip uninstall django-cas-server || true clean_pyc: find ./ -name '*.pyc' -delete @@ -16,18 +20,23 @@ clean_tox: rm -rf .tox clean_test_venv: rm -rf test_venv -clean: clean_pyc clean_build -clean_all: clean_pyc clean_build clean_tox clean_test_venv +clean_coverage: + rm -rf coverage.xml .coverage htmlcov +clean_tild_backup: + find ./ -name '*~' -delete + +clean: clean_pyc clean_build clean_coverage clean_tild_backup + +clean_all: clean clean_tox clean_test_venv dist: python setup.py sdist -test_venv: - mkdir -p test_venv +test_venv/bin/python: virtualenv test_venv - test_venv/bin/pip install -U --requirement requirements.txt + test_venv/bin/pip install -U --requirement requirements-dev.txt Django -test_venv/cas/manage.py: +test_venv/cas/manage.py: test_venv mkdir -p test_venv/cas test_venv/bin/django-admin startproject cas test_venv/cas ln -s ../../cas_server test_venv/cas/cas_server @@ -38,19 +47,15 @@ test_venv/cas/manage.py: test_venv/bin/python test_venv/cas/manage.py migrate test_venv/bin/python test_venv/cas/manage.py createsuperuser -test_project: test_venv test_venv/cas/manage.py +test_venv: test_venv/bin/python + +test_project: test_venv/cas/manage.py @echo "##############################################################" @echo "A test django project was created in $(realpath test_venv/cas)" -run_test_server: test_project +run_server: test_project test_venv/bin/python test_venv/cas/manage.py runserver -coverage: test_venv - test_venv/bin/pip install coverage - test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests - test_venv/bin/coverage html - test_venv/bin/coverage xml - -coverage_codacy: coverage - test_venv/bin/pip install codacy-coverage - test_venv/bin/python-codacy-coverage -r coverage.xml +run_tests: test_venv + test_venv/bin/py.test --cov=cas_server --cov-report html + rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts diff --git a/README.rst b/README.rst index 29b8057..bb148a3 100644 --- a/README.rst +++ b/README.rst @@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac * ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``. * ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``. * ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is - ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``. + ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net', + 'alias': ['demo1', 'demo2']}``. Authentication backend diff --git a/cas_server/__init__.py b/cas_server/__init__.py index f830740..29f5de6 100644 --- a/cas_server/__init__.py +++ b/cas_server/__init__.py @@ -7,6 +7,6 @@ # along with this program; if not, write to the Free Software Foundation, Inc., 51 # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # -# (c) 2015 Valentin Samir - +# (c) 2015-2016 Valentin Samir +"""A django CAS server application""" default_app_config = 'cas_server.apps.CasAppConfig' diff --git a/cas_server/admin.py b/cas_server/admin.py index a6a9be4..472e1df 100644 --- a/cas_server/admin.py +++ b/cas_server/admin.py @@ -7,7 +7,7 @@ # along with this program; if not, write to the Free Software Foundation, Inc., 51 # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # -# (c) 2015 Valentin Samir +# (c) 2015-2016 Valentin Samir """module for the admin interface of the app""" from django.contrib import admin from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern diff --git a/cas_server/apps.py b/cas_server/apps.py index c34b6eb..ea15273 100644 --- a/cas_server/apps.py +++ b/cas_server/apps.py @@ -1,7 +1,19 @@ +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for +# more details. +# +# You should have received a copy of the GNU General Public License version 3 +# along with this program; if not, write to the Free Software Foundation, Inc., 51 +# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# (c) 2015-2016 Valentin Samir +"""django config module""" from django.utils.translation import ugettext_lazy as _ from django.apps import AppConfig class CasAppConfig(AppConfig): + """django CAS application config class""" name = 'cas_server' verbose_name = _('Central Authentication Service') diff --git a/cas_server/auth.py b/cas_server/auth.py index f84fb11..2826a85 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -8,7 +8,7 @@ # along with this program; if not, write to the Free Software Foundation, Inc., 51 # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # -# (c) 2015 Valentin Samir +# (c) 2015-2016 Valentin Samir """Some authentication classes for the CAS""" from django.conf import settings from django.contrib.auth import get_user_model @@ -21,6 +21,7 @@ except ImportError: class AuthUser(object): + """Authentication base class""" def __init__(self, username): self.username = username diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py index 2824991..1d2174c 100644 --- a/cas_server/default_settings.py +++ b/cas_server/default_settings.py @@ -78,5 +78,12 @@ setting_default('CAS_TEST_USER', 'test') setting_default('CAS_TEST_PASSWORD', 'test') setting_default( 'CAS_TEST_ATTRIBUTES', - {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} + { + 'nom': 'Nymous', + 'prenom': 'Ano', + 'email': 'anonymous@example.net', + 'alias': ['demo1', 'demo2'] + } ) + +setting_default('CAS_ENABLE_AJAX_AUTH', False) diff --git a/cas_server/forms.py b/cas_server/forms.py index f970ccd..83cfe8a 100644 --- a/cas_server/forms.py +++ b/cas_server/forms.py @@ -7,7 +7,7 @@ # along with this program; if not, write to the Free Software Foundation, Inc., 51 # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # -# (c) 2015 Valentin Samir +# (c) 2015-2016 Valentin Samir """forms for the app""" from .default_settings import settings @@ -19,6 +19,7 @@ import cas_server.models as models class WarnForm(forms.Form): + """Form used on warn page before emiting a ticket""" service = forms.CharField(widget=forms.HiddenInput(), required=False) renew = forms.BooleanField(widget=forms.HiddenInput(), required=False) gateway = forms.CharField(widget=forms.HiddenInput(), required=False) @@ -35,6 +36,7 @@ class UserCredential(forms.Form): lt = forms.CharField(widget=forms.HiddenInput(), required=False) method = forms.CharField(widget=forms.HiddenInput(), required=False) warn = forms.BooleanField(label=_('warn'), required=False) + renew = forms.BooleanField(widget=forms.HiddenInput(), required=False) def __init__(self, *args, **kwargs): super(UserCredential, self).__init__(*args, **kwargs) @@ -46,6 +48,7 @@ class UserCredential(forms.Form): cleaned_data["username"] = auth.username else: raise forms.ValidationError(_(u"Bad user")) + return cleaned_data class TicketForm(forms.ModelForm): diff --git a/cas_server/management/commands/cas_clean_sessions.py b/cas_server/management/commands/cas_clean_sessions.py index da60c1a..3d32090 100644 --- a/cas_server/management/commands/cas_clean_sessions.py +++ b/cas_server/management/commands/cas_clean_sessions.py @@ -1,3 +1,4 @@ +"""Clean deleted sessions management command""" from django.core.management.base import BaseCommand from django.utils.translation import ugettext_lazy as _ @@ -5,6 +6,7 @@ from ... import models class Command(BaseCommand): + """Clean deleted sessions""" args = '' help = _(u"Clean deleted sessions") diff --git a/cas_server/management/commands/cas_clean_tickets.py b/cas_server/management/commands/cas_clean_tickets.py index d18a7d4..dfbd4ec 100644 --- a/cas_server/management/commands/cas_clean_tickets.py +++ b/cas_server/management/commands/cas_clean_tickets.py @@ -1,3 +1,4 @@ +"""Clean old trickets management command""" from django.core.management.base import BaseCommand from django.utils.translation import ugettext_lazy as _ @@ -5,6 +6,7 @@ from ... import models class Command(BaseCommand): + """Clean old trickets""" args = '' help = _(u"Clean old trickets") diff --git a/cas_server/models.py b/cas_server/models.py index 9cb0ac5..d870a50 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -8,7 +8,7 @@ # along with this program; if not, write to the Free Software Foundation, Inc., 51 # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # -# (c) 2015 Valentin Samir +# (c) 2015-2016 Valentin Samir """models for the app""" from .default_settings import settings @@ -20,7 +20,6 @@ from django.utils import timezone from picklefield.fields import PickledObjectField import re -import os import sys import logging from importlib import import_module @@ -47,6 +46,7 @@ class User(models.Model): @classmethod def clean_old_entries(cls): + """Remove users inactive since more that SESSION_COOKIE_AGE""" users = cls.objects.filter( date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)) ) @@ -56,6 +56,7 @@ class User(models.Model): @classmethod def clean_deleted_sessions(cls): + """Remove user where the session do not exists anymore""" for user in cls.objects.all(): if not SessionStore(session_key=user.session_key).get('authenticated'): user.logout() @@ -80,10 +81,10 @@ class User(models.Model): for ticket_class in ticket_classes: queryset = ticket_class.objects.filter(user=self) for ticket in queryset: - ticket.logout(request, session, async_list) + ticket.logout(session, async_list) queryset.delete() for future in async_list: - if future: + if future: # pragma: no branch (should always be true) try: future.result() except Exception as error: @@ -111,13 +112,21 @@ class User(models.Model): (a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all() ) replacements = dict( - (a.name, (a.pattern, a.replace)) for a in service_pattern.replacements.all() + (a.attribut, (a.pattern, a.replace)) for a in service_pattern.replacements.all() ) service_attributs = {} for (key, value) in self.attributs.items(): if key in attributs or '*' in attributs: if key in replacements: - value = re.sub(replacements[key][0], replacements[key][1], value) + if isinstance(value, list): + for index, subval in enumerate(value): + value[index] = re.sub( + replacements[key][0], + replacements[key][1], + subval + ) + else: + value = re.sub(replacements[key][0], replacements[key][1], value) service_attributs[attributs.get(key, key)] = value ticket = ticket_class.objects.create( user=self, @@ -141,6 +150,7 @@ class User(models.Model): class ServicePatternException(Exception): + """Base exception of exceptions raised in the ServicePattern model""" pass @@ -394,77 +404,57 @@ class Ticket(models.Model): ).delete() # sending SLO to timed-out validated tickets - if cls.TIMEOUT and cls.TIMEOUT > 0: - async_list = [] - session = FuturesSession( - executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS) - ) - queryset = cls.objects.filter( - creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) - ) - for ticket in queryset: - ticket.logout(None, session, async_list) - queryset.delete() - for future in async_list: - if future: - try: - future.result() - except Exception as error: - logger.warning("Error durring SLO %s" % error) - sys.stderr.write("%r\n" % error) + async_list = [] + session = FuturesSession( + executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS) + ) + queryset = cls.objects.filter( + creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) + ) + for ticket in queryset: + ticket.logout(session, async_list) + queryset.delete() + for future in async_list: + if future: # pragma: no branch (should always be true) + try: + future.result() + except Exception as error: + logger.warning("Error durring SLO %s" % error) + sys.stderr.write("%r\n" % error) - def logout(self, request, session, async_list=None): + def logout(self, session, async_list=None): """Send a SLO request to the ticket service""" # On logout invalidate the Ticket self.validate = True self.save() - if self.validate and self.single_log_out: + if self.validate and self.single_log_out: # pragma: no branch (should always be true) logger.info( "Sending SLO requests to service %s for user %s" % ( self.service, self.user.username ) ) - try: - xml = u""" - - %(ticket)s - """ % \ - { - 'id': os.urandom(20).encode("hex"), - 'datetime': timezone.now().isoformat(), - 'ticket': self.value - } - if self.service_pattern.single_log_out_callback: - url = self.service_pattern.single_log_out_callback - else: - url = self.service - async_list.append( - session.post( - url.encode('utf-8'), - data={'logoutRequest': xml.encode('utf-8')}, - timeout=settings.CAS_SLO_TIMEOUT - ) + xml = u""" + +%(ticket)s +""" % \ + { + 'id': utils.gen_saml_id(), + 'datetime': timezone.now().isoformat(), + 'ticket': self.value + } + if self.service_pattern.single_log_out_callback: + url = self.service_pattern.single_log_out_callback + else: + url = self.service + async_list.append( + session.post( + url.encode('utf-8'), + data={'logoutRequest': xml.encode('utf-8')}, + timeout=settings.CAS_SLO_TIMEOUT ) - except Exception as error: - error = utils.unpack_nested_exception(error) - logger.warning( - "Error durring SLO for user %s on service %s: %s" % ( - self.user.username, - self.service, - error - ) - ) - if request is not None: - messages.add_message( - request, - messages.WARNING, - _(u'Error during service logout %(service)s:\n%(error)s') % - {'service': self.service, 'error': error} - ) - else: - sys.stderr.write("%r\n" % error) + ) class ServiceTicket(Ticket): diff --git a/cas_server/tests.py b/cas_server/tests.py deleted file mode 100644 index 7d355cb..0000000 --- a/cas_server/tests.py +++ /dev/null @@ -1,702 +0,0 @@ -from .default_settings import settings - -from django.test import TestCase -from django.test import Client - -import six -from lxml import etree - -from cas_server import models -from cas_server import utils - - -def get_login_page_params(): - client = Client() - response = client.get('/login') - form = response.context["form"] - params = {} - for field in form: - if field.value(): - params[field.name] = field.value() - else: - params[field.name] = "" - return client, params - - -def get_auth_client(): - client, params = get_login_page_params() - params["username"] = settings.CAS_TEST_USER - params["password"] = settings.CAS_TEST_PASSWORD - - client.post('/login', params) - return client - - -def get_user_ticket_request(service): - client = get_auth_client() - response = client.get("/login", {"service": service}) - ticket_value = response['Location'].split('ticket=')[-1] - user = models.User.objects.get( - username=settings.CAS_TEST_USER, - session_key=client.session.session_key - ) - ticket = models.ServiceTicket.objects.get(value=ticket_value) - return (user, ticket) - - -def get_pgt(): - (host, port) = utils.PGTUrlHandler.run()[1:3] - service = "http://%s:%s" % (host, port) - - (user, ticket) = get_user_ticket_request(service) - - client = Client() - client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) - params = utils.PGTUrlHandler.PARAMS.copy() - - params["service"] = service - params["user"] = user - - return params - - -class CheckPasswordCase(TestCase): - """Tests for the utils function `utils.check_password`""" - - def setUp(self): - """Generate random bytes string that will be used ass passwords""" - self.password1 = utils.gen_saml_id() - self.password2 = utils.gen_saml_id() - if not isinstance(self.password1, bytes): - self.password1 = self.password1.encode("utf8") - self.password2 = self.password2.encode("utf8") - - def test_setup(self): - """check that generated password are bytes""" - self.assertIsInstance(self.password1, bytes) - self.assertIsInstance(self.password2, bytes) - - def test_plain(self): - """test the plain auth method""" - self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8")) - self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8")) - - def test_crypt(self): - """test the crypt auth method""" - if six.PY3: - hashed_password1 = utils.crypt.crypt( - self.password1.decode("utf8"), - "$6$UVVAQvrMyXMF3FF3" - ).encode("utf8") - else: - hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3") - - self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8")) - self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8")) - - def test_ldap_ssha(self): - """test the ldap auth method with a {SSHA} scheme""" - salt = b"UVVAQvrMyXMF3FF3" - hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8") - - self.assertIsInstance(hashed_password1, bytes) - self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8")) - self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8")) - - def test_hex_md5(self): - """test the hex_md5 auth method""" - hashed_password1 = utils.hashlib.md5(self.password1).hexdigest() - - self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) - self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) - - def test_hox_sha512(self): - """test the hex_sha512 auth method""" - hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() - - self.assertTrue( - utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8") - ) - self.assertFalse( - utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8") - ) - - -class LoginTestCase(TestCase): - - def setUp(self): - settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' - self.service_pattern = models.ServicePattern.objects.create( - name="example", - pattern="^https://www\.example\.com(/.*)?$", - ) - models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) - - def test_login_view_post_goodpass_goodlt(self): - client, params = get_login_page_params() - params["username"] = settings.CAS_TEST_USER - params["password"] = settings.CAS_TEST_PASSWORD - - response = client.post('/login', params) - - self.assertEqual(response.status_code, 200) - self.assertTrue( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - - self.assertTrue( - models.User.objects.get( - username=settings.CAS_TEST_USER, - session_key=client.session.session_key - ) - ) - - def test_login_view_post_badlt(self): - client, params = get_login_page_params() - params["username"] = settings.CAS_TEST_USER - params["password"] = settings.CAS_TEST_PASSWORD - params["lt"] = 'LT-random' - - response = client.post('/login', params) - - self.assertEqual(response.status_code, 200) - self.assertTrue(b"Invalid login ticket" in response.content) - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - - def test_login_view_post_badpass_good_lt(self): - client, params = get_login_page_params() - params["username"] = settings.CAS_TEST_USER - params["password"] = "test2" - response = client.post('/login', params) - - self.assertEqual(response.status_code, 200) - self.assertTrue( - ( - b"The credentials you provided cannot be " - b"determined to be authentic" - ) in response.content - ) - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - - def test_view_login_get_auth_allowed_service(self): - client = get_auth_client() - response = client.get("/login?service=https://www.example.com") - self.assertEqual(response.status_code, 302) - self.assertTrue(response.has_header('Location')) - self.assertTrue( - response['Location'].startswith( - "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX - ) - ) - - ticket_value = response['Location'].split('ticket=')[-1] - user = models.User.objects.get( - username=settings.CAS_TEST_USER, - session_key=client.session.session_key - ) - self.assertTrue(user) - ticket = models.ServiceTicket.objects.get(value=ticket_value) - self.assertEqual(ticket.user, user) - self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES) - self.assertEqual(ticket.validate, False) - self.assertEqual(ticket.service_pattern, self.service_pattern) - - def test_view_login_get_auth_denied_service(self): - client = get_auth_client() - response = client.get("/login?service=https://www.example.org") - self.assertEqual(response.status_code, 200) - self.assertTrue(b"Service https://www.example.org non allowed" in response.content) - - -class LogoutTestCase(TestCase): - - def setUp(self): - settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' - - def test_logout_view(self): - client = get_auth_client() - - response = client.get("/login") - self.assertEqual(response.status_code, 200) - self.assertTrue( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - - response = client.get("/logout") - self.assertEqual(response.status_code, 200) - self.assertTrue( - ( - b"You have successfully logged out from " - b"the Central Authentication Service" - ) in response.content - ) - - response = client.get("/login") - self.assertEqual(response.status_code, 200) - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - - def test_logout_view_url(self): - client = get_auth_client() - - response = client.get('/logout?url=https://www.example.com') - self.assertEqual(response.status_code, 302) - self.assertTrue(response.has_header("Location")) - self.assertEqual(response["Location"], "https://www.example.com") - - response = client.get("/login") - self.assertEqual(response.status_code, 200) - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - - def test_logout_view_service(self): - client = get_auth_client() - - response = client.get('/logout?service=https://www.example.com') - self.assertEqual(response.status_code, 302) - self.assertTrue(response.has_header("Location")) - self.assertEqual(response["Location"], "https://www.example.com") - - response = client.get("/login") - self.assertEqual(response.status_code, 200) - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - - -class AuthTestCase(TestCase): - - def setUp(self): - settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' - self.service = 'https://www.example.com' - models.ServicePattern.objects.create( - name="example", - pattern="^https://www\.example\.com(/.*)?$" - ) - - def test_auth_view_goodpass(self): - settings.CAS_AUTH_SHARED_SECRET = 'test' - client = Client() - response = client.post( - '/auth', - { - 'username': settings.CAS_TEST_USER, - 'password': settings.CAS_TEST_PASSWORD, - 'service': self.service, - 'secret': 'test' - } - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, b'yes\n') - - def test_auth_view_badpass(self): - settings.CAS_AUTH_SHARED_SECRET = 'test' - client = Client() - response = client.post( - '/auth', - { - 'username': settings.CAS_TEST_USER, - 'password': 'badpass', - 'service': self.service, - 'secret': 'test' - } - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, b'no\n') - - def test_auth_view_badservice(self): - settings.CAS_AUTH_SHARED_SECRET = 'test' - client = Client() - response = client.post( - '/auth', - { - 'username': settings.CAS_TEST_USER, - 'password': settings.CAS_TEST_PASSWORD, - 'service': 'https://www.example.org', - 'secret': 'test' - } - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, b'no\n') - - def test_auth_view_badsecret(self): - settings.CAS_AUTH_SHARED_SECRET = 'test' - client = Client() - response = client.post( - '/auth', - { - 'username': settings.CAS_TEST_USER, - 'password': settings.CAS_TEST_PASSWORD, - 'service': self.service, - 'secret': 'badsecret' - } - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, b'no\n') - - def test_auth_view_badsettings(self): - settings.CAS_AUTH_SHARED_SECRET = None - client = Client() - response = client.post( - '/auth', - { - 'username': settings.CAS_TEST_USER, - 'password': settings.CAS_TEST_PASSWORD, - 'service': self.service, - 'secret': 'test' - } - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET") - - -class ValidateTestCase(TestCase): - - def setUp(self): - settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' - self.service = 'https://www.example.com' - self.service_pattern = models.ServicePattern.objects.create( - name="example", - pattern="^https://www\.example\.com(/.*)?$" - ) - models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) - - def test_validate_view_ok(self): - ticket = get_user_ticket_request(self.service)[1] - - client = Client() - response = client.get('/validate', {'ticket': ticket.value, 'service': self.service}) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, b'yes\ntest\n') - - def test_validate_view_badservice(self): - ticket = get_user_ticket_request(self.service)[1] - - client = Client() - response = client.get( - '/validate', - {'ticket': ticket.value, 'service': "https://www.example.org"} - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, b'no\n') - - def test_validate_view_badticket(self): - get_user_ticket_request(self.service) - - client = Client() - response = client.get( - '/validate', - {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service} - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, b'no\n') - - -class ValidateServiceTestCase(TestCase): - - def setUp(self): - settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' - self.service = 'http://127.0.0.1:45678' - self.service_pattern = models.ServicePattern.objects.create( - name="localhost", - pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", - proxy_callback=True - ) - models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) - - def test_validate_service_view_ok(self): - ticket = get_user_ticket_request(self.service)[1] - - client = Client() - response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service}) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - sucess = root.xpath( - "//cas:authenticationSuccess", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertTrue(sucess) - - users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - self.assertEqual(len(users), 1) - self.assertEqual(users[0].text, settings.CAS_TEST_USER) - - attributes = root.xpath( - "//cas:attributes", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(attributes), 1) - attrs1 = {} - for attr in attributes[0]: - attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text - - attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - self.assertEqual(len(attributes), len(attrs1)) - attrs2 = {} - for attr in attributes: - attrs2[attr.attrib['name']] = attr.attrib['value'] - self.assertEqual(attrs1, attrs2) - self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) - - def test_validate_service_view_badservice(self): - ticket = get_user_ticket_request(self.service)[1] - - client = Client() - bad_service = "https://www.example.org" - response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service}) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - error = root.xpath( - "//cas:authenticationFailure", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(error), 1) - self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE") - self.assertEqual(error[0].text, bad_service) - - def test_validate_service_view_badticket_goodprefix(self): - get_user_ticket_request(self.service) - - client = Client() - bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX - response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - error = root.xpath( - "//cas:authenticationFailure", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(error), 1) - self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") - self.assertEqual(error[0].text, 'ticket not found') - - def test_validate_service_view_badticket_badprefix(self): - get_user_ticket_request(self.service) - - client = Client() - bad_ticket = "RANDOM" - response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - error = root.xpath( - "//cas:authenticationFailure", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(error), 1) - self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") - self.assertEqual(error[0].text, bad_ticket) - - def test_validate_service_view_ok_pgturl(self): - (host, port) = utils.PGTUrlHandler.run()[1:3] - service = "http://%s:%s" % (host, port) - - ticket = get_user_ticket_request(service)[1] - - client = Client() - response = client.get( - '/serviceValidate', - {'ticket': ticket.value, 'service': service, 'pgtUrl': service} - ) - pgt_params = utils.PGTUrlHandler.PARAMS.copy() - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - pgtiou = root.xpath( - "//cas:proxyGrantingTicket", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(pgtiou), 1) - self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text) - self.assertTrue("pgtId" in pgt_params) - - def test_validate_service_pgturl_bad_proxy_callback(self): - self.service_pattern.proxy_callback = False - self.service_pattern.save() - ticket = get_user_ticket_request(self.service)[1] - - client = Client() - response = client.get( - '/serviceValidate', - {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service} - ) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - error = root.xpath( - "//cas:authenticationFailure", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(error), 1) - self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK") - self.assertEqual(error[0].text, "callback url not allowed by configuration") - - -class ProxyTestCase(TestCase): - - def setUp(self): - settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' - self.service = 'http://127.0.0.1' - self.service_pattern = models.ServicePattern.objects.create( - name="localhost", - pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", - proxy=True, - proxy_callback=True - ) - models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) - - def test_validate_proxy_ok(self): - params = get_pgt() - - # get a proxy ticket - client1 = Client() - response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service}) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - self.assertTrue(sucess) - - proxy_ticket = root.xpath( - "//cas:proxyTicket", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(proxy_ticket), 1) - proxy_ticket = proxy_ticket[0].text - - # validate the proxy ticket - client2 = Client() - response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service}) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - sucess = root.xpath( - "//cas:authenticationSuccess", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertTrue(sucess) - - # check that the proxy is send to the end service - proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - self.assertEqual(len(proxies), 1) - proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - self.assertEqual(len(proxy), 1) - self.assertEqual(proxy[0].text, params["service"]) - - # same tests than those for serviceValidate - users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - self.assertEqual(len(users), 1) - self.assertEqual(users[0].text, settings.CAS_TEST_USER) - - attributes = root.xpath( - "//cas:attributes", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(attributes), 1) - attrs1 = {} - for attr in attributes[0]: - attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text - - attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - self.assertEqual(len(attributes), len(attrs1)) - attrs2 = {} - for attr in attributes: - attrs2[attr.attrib['name']] = attr.attrib['value'] - self.assertEqual(attrs1, attrs2) - self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) - - def test_validate_proxy_bad(self): - params = get_pgt() - - # bad PGT - client1 = Client() - response = client1.get( - '/proxy', - { - 'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX, - 'targetService': params['service'] - } - ) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - error = root.xpath( - "//cas:authenticationFailure", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(error), 1) - self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") - self.assertEqual( - error[0].text, - "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX - ) - - # bad targetService - client2 = Client() - response = client2.get( - '/proxy', - {'pgt': params['pgtId'], 'targetService': "https://www.example.org"} - ) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - error = root.xpath( - "//cas:authenticationFailure", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(error), 1) - self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") - self.assertEqual(error[0].text, "https://www.example.org") - - # service do not allow proxy ticket - self.service_pattern.proxy = False - self.service_pattern.save() - - client3 = Client() - response = client3.get( - '/proxy', - {'pgt': params['pgtId'], 'targetService': params['service']} - ) - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - error = root.xpath( - "//cas:authenticationFailure", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(error), 1) - self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") - self.assertEqual( - error[0].text, - 'the service %s do not allow proxy ticket' % params['service'] - ) diff --git a/cas_server/tests/__init__.py b/cas_server/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cas_server/tests/mixin.py b/cas_server/tests/mixin.py new file mode 100644 index 0000000..ddbf2d2 --- /dev/null +++ b/cas_server/tests/mixin.py @@ -0,0 +1,193 @@ +# ⁻*- coding: utf-8 -*- +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for +# more details. +# +# You should have received a copy of the GNU General Public License version 3 +# along with this program; if not, write to the Free Software Foundation, Inc., 51 +# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# (c) 2016 Valentin Samir +"""Some mixin classes for tests""" +from cas_server.default_settings import settings +from django.utils import timezone + +import re +from lxml import etree +from datetime import timedelta + +from cas_server import models +from cas_server.tests.utils import get_auth_client + + +class BaseServicePattern(object): + """Mixing for setting up service pattern for testing""" + def setup_service_patterns(self, proxy=False): + """setting up service pattern""" + # For general purpose testing + self.service = "https://www.example.com" + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$", + proxy=proxy, + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + # For testing the restrict_users attributes + self.service_restrict_user_fail = "https://restrict_user_fail.example.com" + self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create( + name="restrict_user_fail", + pattern="^https://restrict_user_fail\.example\.com(/.*)?$", + restrict_users=True, + proxy=proxy, + ) + self.service_restrict_user_success = "https://restrict_user_success.example.com" + self.service_pattern_restrict_user_success = models.ServicePattern.objects.create( + name="restrict_user_success", + pattern="^https://restrict_user_success\.example\.com(/.*)?$", + restrict_users=True, + proxy=proxy, + ) + models.Username.objects.create( + value=settings.CAS_TEST_USER, + service_pattern=self.service_pattern_restrict_user_success + ) + + # For testing the user attributes filtering conditions + self.service_filter_fail = "https://filter_fail.example.com" + self.service_pattern_filter_fail = models.ServicePattern.objects.create( + name="filter_fail", + pattern="^https://filter_fail\.example\.com(/.*)?$", + proxy=proxy, + ) + models.FilterAttributValue.objects.create( + attribut="right", + pattern="^admin$", + service_pattern=self.service_pattern_filter_fail + ) + self.service_filter_fail_alt = "https://filter_fail_alt.example.com" + self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create( + name="filter_fail_alt", + pattern="^https://filter_fail_alt\.example\.com(/.*)?$", + proxy=proxy, + ) + models.FilterAttributValue.objects.create( + attribut="nom", + pattern="^toto$", + service_pattern=self.service_pattern_filter_fail_alt + ) + self.service_filter_success = "https://filter_success.example.com" + self.service_pattern_filter_success = models.ServicePattern.objects.create( + name="filter_success", + pattern="^https://filter_success\.example\.com(/.*)?$", + proxy=proxy, + ) + models.FilterAttributValue.objects.create( + attribut="email", + pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']), + service_pattern=self.service_pattern_filter_success + ) + + # For testing the user_field attributes + self.service_field_needed_fail = "https://field_needed_fail.example.com" + self.service_pattern_field_needed_fail = models.ServicePattern.objects.create( + name="field_needed_fail", + pattern="^https://field_needed_fail\.example\.com(/.*)?$", + user_field="uid", + proxy=proxy, + ) + self.service_field_needed_success = "https://field_needed_success.example.com" + self.service_pattern_field_needed_success = models.ServicePattern.objects.create( + name="field_needed_success", + pattern="^https://field_needed_success\.example\.com(/.*)?$", + user_field="alias", + proxy=proxy, + ) + self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com" + self.service_pattern_field_needed_success = models.ServicePattern.objects.create( + name="field_needed_success_alt", + pattern="^https://field_needed_success_alt\.example\.com(/.*)?$", + user_field="nom", + proxy=proxy, + ) + + +class XmlContent(object): + """Mixin for test on CAS XML responses""" + def assert_error(self, response, code, text=None): + """Assert a validation error""" + self.assertEqual(response.status_code, 200) + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], code) + if text is not None: + self.assertEqual(error[0].text, text) + + def assert_success(self, response, username, original_attributes): + """assert a ticket validation success""" + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath( + "//cas:authenticationSuccess", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertTrue(sucess) + + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, username) + + attributes = root.xpath( + "//cas:attributes", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(attributes), 1) + attrs1 = set() + for attr in attributes[0]: + attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text)) + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = set() + for attr in attributes: + attrs2.add((attr.attrib['name'], attr.attrib['value'])) + original = set() + for key, value in original_attributes.items(): + if isinstance(value, list): + for sub_value in value: + original.add((key, sub_value)) + else: + original.add((key, value)) + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, original) + + return root + + +class UserModels(object): + """Mixin for test on CAS user models""" + @staticmethod + def expire_user(): + """return an expired user""" + client = get_auth_client() + + new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600)) + models.User.objects.filter( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ).update(date=new_date) + return client + + @staticmethod + def get_user(client): + """return the user associated with an authenticated client""" + return models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) diff --git a/settings_tests.py b/cas_server/tests/settings.py similarity index 96% rename from settings_tests.py rename to cas_server/tests/settings.py index 4588c2c..1402c64 100644 --- a/settings_tests.py +++ b/cas_server/tests/settings.py @@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [ 'django.middleware.locale.LocaleMiddleware', ] -ROOT_URLCONF = 'cas_server.urls' +ROOT_URLCONF = 'cas_server.tests.urls' # Database # https://docs.djangoproject.com/en/1.9/ref/settings/#databases @@ -60,6 +60,7 @@ ROOT_URLCONF = 'cas_server.urls' DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': ':memory:', } } diff --git a/cas_server/tests/test_models.py b/cas_server/tests/test_models.py new file mode 100644 index 0000000..e75f54f --- /dev/null +++ b/cas_server/tests/test_models.py @@ -0,0 +1,166 @@ +# ⁻*- coding: utf-8 -*- +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for +# more details. +# +# You should have received a copy of the GNU General Public License version 3 +# along with this program; if not, write to the Free Software Foundation, Inc., 51 +# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# (c) 2016 Valentin Samir +"""Tests module for models""" +from cas_server.default_settings import settings + +from django.test import TestCase +from django.test.utils import override_settings +from django.utils import timezone + +from datetime import timedelta +from importlib import import_module + +from cas_server import models +from cas_server.tests.utils import get_auth_client, HttpParamsHandler +from cas_server.tests.mixin import UserModels, BaseServicePattern + +SessionStore = import_module(settings.SESSION_ENGINE).SessionStore + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class UserTestCase(TestCase, UserModels): + """tests for the user models""" + def setUp(self): + """Prepare the test context""" + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + single_log_out=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_clean_old_entries(self): + """test clean_old_entries""" + # get an authenticated client + client = self.expire_user() + # assert the user exists before being cleaned + self.assertEqual(len(models.User.objects.all()), 1) + # assert the last activity date is before the expiry date + self.assertTrue( + self.get_user(client).date < ( + timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE) + ) + ) + # delete old inactive users + models.User.clean_old_entries() + # assert the user has being well delete + self.assertEqual(len(models.User.objects.all()), 0) + + def test_clean_deleted_sessions(self): + """test clean_deleted_sessions""" + # get an authenticated client + client1 = get_auth_client() + client2 = get_auth_client() + # generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen + # on self.service) + ticket = self.get_user(client1).get_ticket( + models.ServiceTicket, + self.service, + self.service_pattern, + renew=False + ) + ticket.validate = True + ticket.save() + # simulated expired session being garbage collected for client1 + session = SessionStore(session_key=client1.session.session_key) + session.flush() + # assert the user exists before being cleaned + self.assertTrue(self.get_user(client1)) + self.assertTrue(self.get_user(client2)) + self.assertEqual(len(models.User.objects.all()), 2) + # session has being remove so the user of client1 is no longer authenticated + self.assertFalse(client1.session.get("authenticated")) + # the user a client2 should still be authenticated + self.assertTrue(client2.session.get("authenticated")) + # the user should be deleted + models.User.clean_deleted_sessions() + # assert the user with expired sessions has being well deleted but the other remain + self.assertEqual(len(models.User.objects.all()), 1) + self.assertFalse(models.ServiceTicket.objects.all()) + self.assertTrue(client2.session.get("authenticated")) + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class TicketTestCase(TestCase, UserModels, BaseServicePattern): + """tests for the tickets models""" + def setUp(self): + """Prepare the test context""" + self.setup_service_patterns() + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + single_log_out=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + @staticmethod + def get_ticket( + user, + ticket_class, + service, + service_pattern, + renew=False, + validate=False, + validity_expired=False, + timeout_expired=False, + single_log_out=False, + ): + """Return a ticket""" + ticket = user.get_ticket(ticket_class, service, service_pattern, renew) + ticket.validate = validate + ticket.single_log_out = single_log_out + if validity_expired: + ticket.creation = min( + ticket.creation, + (timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10))) + ) + if timeout_expired: + ticket.creation = min( + ticket.creation, + (timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10))) + ) + ticket.save() + return ticket + + def test_clean_old_service_ticket(self): + """test tickets clean_old_entries""" + # ge an authenticated client + client = get_auth_client() + # get the user associated to the client + user = self.get_user(client) + # generate a ticket for that client, waiting for validation + self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern) + # generate another ticket for those validation time has expired + self.get_ticket( + user, models.ServiceTicket, + self.service, self.service_pattern, validity_expired=True + ) + (httpd, host, port) = HttpParamsHandler.run()[0:3] + service = "http://%s:%s" % (host, port) + # generate a ticket with SLO having timeout reach + self.get_ticket( + user, models.ServiceTicket, + service, self.service_pattern, timeout_expired=True, + validate=True, single_log_out=True + ) + # there should be 3 tickets in the db + self.assertEqual(len(models.ServiceTicket.objects.all()), 3) + # we call the clean_old_entries method that should delete validated non SLO ticket and + # expired non validated ticket and send SLO for SLO expired ticket before deleting then + models.ServiceTicket.clean_old_entries() + params = httpd.PARAMS + # we successfully got a SLO request + self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest']) + # only 1 ticket remain in the db + self.assertEqual(len(models.ServiceTicket.objects.all()), 1) diff --git a/cas_server/tests/test_utils.py b/cas_server/tests/test_utils.py new file mode 100644 index 0000000..76fa2cc --- /dev/null +++ b/cas_server/tests/test_utils.py @@ -0,0 +1,191 @@ +# ⁻*- coding: utf-8 -*- +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for +# more details. +# +# You should have received a copy of the GNU General Public License version 3 +# along with this program; if not, write to the Free Software Foundation, Inc., 51 +# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# (c) 2016 Valentin Samir +"""Tests module for utils""" +from django.test import TestCase + +import six + +from cas_server import utils + + +class CheckPasswordCase(TestCase): + """Tests for the utils function `utils.check_password`""" + + def setUp(self): + """Generate random bytes string that will be used ass passwords""" + self.password1 = utils.gen_saml_id() + self.password2 = utils.gen_saml_id() + if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3 + self.password1 = self.password1.encode("utf8") + self.password2 = self.password2.encode("utf8") + + def test_setup(self): + """check that generated password are bytes""" + self.assertIsInstance(self.password1, bytes) + self.assertIsInstance(self.password2, bytes) + + def test_plain(self): + """test the plain auth method""" + self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8")) + self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8")) + + def test_plain_unicode(self): + """test the plain auth method with unicode input""" + self.assertTrue( + utils.check_password( + "plain", + self.password1.decode("utf8"), + self.password1.decode("utf8"), + "utf8" + ) + ) + self.assertFalse( + utils.check_password( + "plain", + self.password1.decode("utf8"), + self.password2.decode("utf8"), + "utf8" + ) + ) + + def test_crypt(self): + """test the crypt auth method""" + salts = ["$6$UVVAQvrMyXMF3FF3", "aa"] + hashed_password1 = [] + for salt in salts: + if six.PY3: + hashed_password1.append( + utils.crypt.crypt( + self.password1.decode("utf8"), + salt + ).encode("utf8") + ) + else: + hashed_password1.append(utils.crypt.crypt(self.password1, salt)) + + for hp1 in hashed_password1: + self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8")) + self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8")) + + with self.assertRaises(ValueError): + utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8") + + def test_ldap_password_valid(self): + """test the ldap auth method with all the schemes""" + salt = b"UVVAQvrMyXMF3FF3" + schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"] + schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"] + hashed_password1 = [] + for scheme in schemes_salt: + hashed_password1.append( + utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8") + ) + for scheme in schemes_nosalt: + hashed_password1.append( + utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8") + ) + hashed_password1.append( + utils.LdapHashUserPassword.hash( + b"{CRYPT}", + self.password1, + b"$6$UVVAQvrMyXMF3FF3", + charset="utf8" + ) + ) + for hp1 in hashed_password1: + self.assertIsInstance(hp1, bytes) + self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8")) + self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8")) + + def test_ldap_password_fail(self): + """test the ldap auth method with malformed hash or bad schemes""" + salt = b"UVVAQvrMyXMF3FF3" + schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"] + schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"] + + # first try to hash with bad parameters + with self.assertRaises(utils.LdapHashUserPassword.BadScheme): + utils.LdapHashUserPassword.hash(b"TOTO", self.password1) + for scheme in schemes_nosalt: + with self.assertRaises(utils.LdapHashUserPassword.BadScheme): + utils.LdapHashUserPassword.hash(scheme, self.password1, salt) + for scheme in schemes_salt: + with self.assertRaises(utils.LdapHashUserPassword.BadScheme): + utils.LdapHashUserPassword.hash(scheme, self.password1) + with self.assertRaises(utils.LdapHashUserPassword.BadSalt): + utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, b"$truc$toto") + + # then try to check hash with bad hashes + with self.assertRaises(utils.LdapHashUserPassword.BadHash): + utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8") + for scheme in schemes_salt: + with self.assertRaises(utils.LdapHashUserPassword.BadHash): + utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8") + + def test_hex(self): + """test all the hex_HASH method: the hashed password is a simple hash of the password""" + hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"] + hashed_password1 = [] + for hash in hashes: + hashed_password1.append( + ("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest()) + ) + for (method, hp1) in hashed_password1: + self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8")) + self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8")) + + def test_bad_method(self): + """try to check password with a bad method, should raise a ValueError""" + with self.assertRaises(ValueError): + utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8") + + +class UtilsTestCase(TestCase): + """tests for some little utils functions""" + def test_import_attr(self): + """ + test the import_attr function. Feeded with a dotted path string, it should + import the dotted module and return that last componend of the dotted path + (function, class or variable) + """ + with self.assertRaises(ImportError): + utils.import_attr('toto.titi.tutu') + with self.assertRaises(AttributeError): + utils.import_attr('cas_server.utils.toto') + with self.assertRaises(ValueError): + utils.import_attr('toto') + self.assertEqual( + utils.import_attr('cas_server.default_app_config'), + 'cas_server.apps.CasAppConfig' + ) + self.assertEqual(utils.import_attr(utils), utils) + + def test_update_url(self): + """ + test the update_url function. Given an url with possible GET parameter and a dict + the function build a url with GET parameters updated by the dictionnary + """ + url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"}) + url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"}) + self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1") + self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1") + + url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"}) + self.assertEqual(url3, u"https://www.example.com?toto=2") + + def test_crypt_salt_is_valid(self): + """test the function crypt_salt_is_valid who test if a crypt salt is valid""" + self.assertFalse(utils.crypt_salt_is_valid("")) # len 0 + self.assertFalse(utils.crypt_salt_is_valid("a")) # len 1 + self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $ + self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $ + self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py new file mode 100644 index 0000000..95720c4 --- /dev/null +++ b/cas_server/tests/test_view.py @@ -0,0 +1,1813 @@ +# ⁻*- coding: utf-8 -*- +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for +# more details. +# +# You should have received a copy of the GNU General Public License version 3 +# along with this program; if not, write to the Free Software Foundation, Inc., 51 +# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# (c) 2016 Valentin Samir +"""Tests module for views""" +from cas_server.default_settings import settings + +import django +from django.test import TestCase, Client +from django.test.utils import override_settings +from django.utils import timezone + + +import random +import json +from lxml import etree +from six.moves import range + +from cas_server import models +from cas_server import utils +from cas_server.tests.utils import ( + copy_form, + get_login_page_params, + get_auth_client, + get_user_ticket_request, + get_pgt, + get_proxy_ticket, + get_validated_ticket, + HttpParamsHandler, + Http404Handler +) +from cas_server.tests.mixin import BaseServicePattern, XmlContent + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class LoginTestCase(TestCase, BaseServicePattern): + """Tests for the login view""" + def setUp(self): + """Prepare the test context:""" + # we prepare a bunch a service url and service patterns for tests + self.setup_service_patterns() + + def assert_logged(self, client, response, warn=False, code=200): + """Assertions testing that client is well authenticated""" + self.assertEqual(response.status_code, code) + # this message is displayed to the user upon successful authentication + self.assertTrue( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + # these session variables a set if usccessfully authenticated + self.assertTrue(client.session["username"] == settings.CAS_TEST_USER) + self.assertTrue(client.session["warn"] is warn) + self.assertTrue(client.session["authenticated"] is True) + + # on successfull authentication, a corresponding user object is created + self.assertTrue( + models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + ) + + def assert_login_failed(self, client, response, code=200): + """Assertions testing a failed login attempt""" + self.assertEqual(response.status_code, code) + # this message is displayed to the user upon successful authentication, so it should not + # appear + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + # if authentication has failed, these session variables should not be set + self.assertTrue(client.session.get("username") is None) + self.assertTrue(client.session.get("warn") is None) + self.assertTrue(client.session.get("authenticated") is None) + + def test_login_view_post_goodpass_goodlt(self): + """Test a successul login""" + # we get a client who fetch a frist time the login page and the login form default + # parameters + client, params = get_login_page_params() + # we set username/password in the form + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + # the LoginTicket in the form should match a valid LT in the user session + self.assertTrue(params['lt'] in client.session['lt']) + + # we post a login attempt + response = client.post('/login', params) + # as username/password/lt are all valid, the login should succed + self.assert_logged(client, response) + # The LoginTicket is conssumed and should no longer be valid + self.assertTrue(params['lt'] not in client.session['lt']) + + def test_login_view_post_goodpass_goodlt_warn(self): + """Test a successul login requesting to be warned before creating services tickets""" + # get a client and initial login params + client, params = get_login_page_params() + # set valids usernames/passswords + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + # this time, we check the warn checkbox + params["warn"] = "on" + + # postings login request + response = client.post('/login', params) + # as username/password/lt are all valid, the login should succed and warn be enabled + self.assert_logged(client, response, warn=True) + + def test_lt_max(self): + """Check we only keep the last 100 Login Ticket for a user""" + # get a client and initial login params + client, params = get_login_page_params() + # get a first LT that should be valid + current_lt = params["lt"] + # we keep the last 100 generated LT by user, so after having generated `i_in_test` we + # test if `current_lt` is still valid + i_in_test = random.randint(0, 99) + # after `i_not_in_test` `current_lt` should be valid not more + i_not_in_test = random.randint(101, 150) + # start generating 150 LT + for i in range(150): + if i == i_in_test: + # before more than 100 LT generated, the first TL should be valid + self.assertTrue(current_lt in client.session['lt']) + if i == i_not_in_test: + # after more than 100 LT generated, the first LT should be valid no more + self.assertTrue(current_lt not in client.session['lt']) + # assert that we do not keep more that 100 valid LT + self.assertTrue(len(client.session['lt']) <= 100) + # generate a new LT by getting the login page + client, params = get_login_page_params(client) + # in the end, we still have less that 100 valid LT + self.assertTrue(len(client.session['lt']) <= 100) + + def test_login_view_post_badlt(self): + """Login attempt with a bad LoginTicket, login should fail""" + # get a client and initial login params + client, params = get_login_page_params() + # set valid username/password + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + # set a bad LT + params["lt"] = 'LT-random' + + # posting the login request + response = client.post('/login', params) + + # as the LT is not valid, login should fail + self.assert_login_failed(client, response) + # the reason why login has failed is displayed to the user + self.assertTrue(b"Invalid login ticket" in response.content) + + def test_login_view_post_badpass_good_lt(self): + """Login attempt with a bad password""" + # get a client and initial login params + client, params = get_login_page_params() + # set valid username but invalid password + params["username"] = settings.CAS_TEST_USER + params["password"] = "test2" + # posting the login request + response = client.post('/login', params) + + # as the password is wrong, login should fail + self.assert_login_failed(client, response) + # the reason why login has failed is displayed to the user + self.assertTrue( + ( + b"The credentials you provided cannot be " + b"determined to be authentic" + ) in response.content + ) + + def assert_ticket_attributes(self, client, ticket_value): + """check the ticket attributes in the db""" + # Get get current session user in the db + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + # we should find exactly one user + self.assertTrue(user) + # get the ticker object corresponting to `ticket_value` + ticket = models.ServiceTicket.objects.get(value=ticket_value) + # chek that the ticket is well attributed to the user + self.assertEqual(ticket.user, user) + # check that the user attributes match the attributes registered on the ticket + self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES) + # check that the ticket has not being validated yet + self.assertEqual(ticket.validate, False) + # check that the service pattern registered on the ticket is the on we use for tests + self.assertEqual(ticket.service_pattern, self.service_pattern) + + def assert_service_ticket(self, client, response): + """check that a ticket is well emited when requested on a allowed service""" + # On ticket emission, we should be redirected to the service url, setting the ticket + # GET parameter + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header('Location')) + self.assertTrue( + response['Location'].startswith( + "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX + ) + ) + # check that the value of the ticket GET parameter match the value of the ticket + # created in the db + ticket_value = response['Location'].split('ticket=')[-1] + self.assert_ticket_attributes(client, ticket_value) + + def test_view_login_get_allowed_service(self): + """Request a ticket for an allowed service by an unauthenticated client""" + # get a bare new http client + client = Client() + # we are not authenticated and are asking for a ticket for https://www.example.com + # which is a valid service matched by self.service_pattern + response = client.get("/login?service=https://www.example.com") + # the login page should be displayed + self.assertEqual(response.status_code, 200) + # we warn the user why it need to authenticated + self.assertTrue( + ( + b"Authentication required by service " + b"example (https://www.example.com)" + ) in response.content + ) + + def test_view_login_get_denied_service(self): + """Request a ticket for an denied service by an unauthenticated client""" + # get a bare new http client + client = Client() + # we are not authenticated and are asking for a ticket for https://www.example.net + # which is NOT a valid service + response = client.get("/login?service=https://www.example.net") + self.assertEqual(response.status_code, 200) + # we warn the user that https://www.example.net is not an allowed service url + self.assertTrue(b"Service https://www.example.net non allowed" in response.content) + + def test_view_login_get_auth_allowed_service(self): + """Request a ticket for an allowed service by an authenticated client""" + # get a client that is already authenticated + client = get_auth_client() + # ask for a ticket for https://www.example.com + response = client.get("/login?service=https://www.example.com") + # as https://www.example.com is a valid service a ticket should be created and the + # user redirected to the service url + self.assert_service_ticket(client, response) + + def test_view_login_get_auth_allowed_service_warn(self): + """Request a ticket for an allowed service by an authenticated client""" + # get a client that is already authenticated and has ask to be warned befor we + # generated a ticket + client = get_auth_client(warn="on") + # ask for a ticket for https://www.example.com + response = client.get("/login?service=https://www.example.com") + # we display a warning to the user, asking him to validate the ticket creation (insted + # a generating and redirecting directly to the service url) + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"Authentication has been required by service " + b"example (https://www.example.com)" + ) in response.content + ) + # get the displayed form parameters + params = copy_form(response.context["form"]) + # we post, confirming we want a ticket + response = client.post("/login", params) + # as https://www.example.com is a valid service a ticket should be created and the + # user redirected to the service url + self.assert_service_ticket(client, response) + + def test_view_login_get_auth_denied_service(self): + """Request a ticket for a not allowed service by an authenticated client""" + # get a client that is already authenticated + client = get_auth_client() + # we are authenticated and are asking for a ticket for https://www.example.org + # which is NOT a valid service + response = client.get("/login?service=https://www.example.org") + self.assertEqual(response.status_code, 200) + # we warn the user that https://www.example.net is not an allowed service url + # NO ticket are created + self.assertTrue(b"Service https://www.example.org non allowed" in response.content) + + def test_user_logged_not_in_db(self): + """If the user is logged but has been delete from the database, it should be logged out""" + # get a client that is already authenticated + client = get_auth_client() + # delete the user in the db + models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ).delete() + # fetch the login page + response = client.get("/login") + + # The user should be logged out + self.assert_login_failed(client, response, code=302) + # and redirected to the login page. We branch depending on the version a django as + # the test client behaviour changed after django 1.9 + if django.VERSION < (1, 9): # pragma: no cover coverage is computed with dango 1.9 + self.assertEqual(response["Location"], "http://testserver/login") + else: + self.assertEqual(response["Location"], "/login?") + + def test_service_restrict_user(self): + """Testing the restric user capability from a service""" + # get a client that is already authenticated + client = get_auth_client() + + # trying to get a ticket from a service url matched by a service pattern having a + # restriction on the usernames allowed to get tickets. the test user username is not one + # of this username. + response = client.get("/login", {'service': self.service_restrict_user_fail}) + self.assertEqual(response.status_code, 200) + # the ticket is not created and a warning is displayed to the user + self.assertTrue(b"Username non allowed" in response.content) + + # same but with the tes user username being one of the allowed usernames + response = client.get("/login", {'service': self.service_restrict_user_success}) + # the ticket is created and we are redirected to the service url + self.assertEqual(response.status_code, 302) + self.assertTrue( + response["Location"].startswith("%s?ticket=" % self.service_restrict_user_success) + ) + + def test_service_filter(self): + """Test the filtering on user attributes""" + # get a client that is already authenticated + client = get_auth_client() + + # trying to get a ticket from a service url matched by a service pattern having + # a restriction on the user attributes. The test user if ailing these restrictions + # We try first with a single value attribut (aka a string) and then with + # a multi values attributs (aka a list of strings) + for service in [self.service_filter_fail, self.service_filter_fail_alt]: + response = client.get("/login", {'service': service}) + # the ticket is not created and a warning is displayed to the user + self.assertEqual(response.status_code, 200) + self.assertTrue(b"User charateristics non allowed" in response.content) + + # same but with rectriction that a valid upon the test user attributes + response = client.get("/login", {'service': self.service_filter_success}) + # the ticket us created and the user redirected to the service url + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % self.service_filter_success)) + + def test_service_user_field(self): + """Test using a user attribute as username: case on if the attribute exists or not""" + # get a client that is already authenticated + client = get_auth_client() + + # trying to get a ticket from a service url matched by a service pattern that use + # a particular attribute has username. The test user do NOT have this attribute + response = client.get("/login", {'service': self.service_field_needed_fail}) + # the ticket is not created and a warning is displayed to the user + self.assertEqual(response.status_code, 200) + self.assertTrue(b"The attribut uid is needed to use that service" in response.content) + + # same but with a attribute that the test user has + response = client.get("/login", {'service': self.service_field_needed_success}) + # the ticket us created and the user redirected to the service url + self.assertEqual(response.status_code, 302) + self.assertTrue( + response["Location"].startswith("%s?ticket=" % self.service_field_needed_success) + ) + + @override_settings(CAS_TEST_ATTRIBUTES={'alias': []}) + def test_service_user_field_evaluate_to_false(self): + """ + Test using a user attribute as username: + case the attribute exists but evaluate to False + """ + # get a client that is already authenticated + client = get_auth_client() + # trying to get a ticket from a service url matched by a service pattern that use + # a particular attribute has username. The test user have this attribute, but it is + # evaluated to False (eg an empty string "" or an empty list []) + response = client.get("/login", {"service": self.service_field_needed_success}) + # the ticket is not created and a warning is displayed to the user + self.assertEqual(response.status_code, 200) + self.assertTrue(b"The attribut alias is needed to use that service" in response.content) + + def test_gateway(self): + """test gateway parameter""" + + # First with an authenticated client that fail to get a ticket for a service + service = "https://restrict_user_fail.example.com" + # get a client that is already authenticated + client = get_auth_client() + # the authenticated client fail to get a ticket for some reason + response = client.get("/login", {'service': service, 'gateway': 'on'}) + # as gateway is set, he is redirected to the service url without any ticket + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], service) + + # second for an user not yet authenticated on a valid service + client = Client() + # the client fail to get a ticket since he is not yep authenticated + response = client.get('/login', {'service': service, 'gateway': 'on'}) + # as gateway is set, he is redirected to the service url without any ticket + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], service) + + def test_renew(self): + """test the authentication renewal request from a service""" + # use the default test service + service = "https://www.example.com" + # get a client that is already authenticated + client = get_auth_client() + # ask for a ticket for the service but aks for authentication renewal + response = client.get("/login", {'service': service, 'renew': 'on'}) + # we are ask to reauthenticate and tell the user why + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"Authentication renewal required by " + b"service example (https://www.example.com)" + ) in response.content + ) + # get the form default parameter + params = copy_form(response.context["form"]) + # set valid username/password + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + # the renew parameter from the form should be True + self.assertEqual(params["renew"], True) + # post the authentication request + response = client.post("/login", params) + # the request succed, a ticket is created and we are redirected to the service url + self.assertEqual(response.status_code, 302) + ticket_value = response['Location'].split('ticket=')[-1] + ticket = models.ServiceTicket.objects.get(value=ticket_value) + # the created ticket is marked has being gottent after a renew. Futher testing about + # renewing authentication is done in the validate and serviceValidate views tests + self.assertEqual(ticket.renew, True) + + @override_settings(CAS_ENABLE_AJAX_AUTH=True) + def test_ajax_login_required(self): + """ + test ajax, login required. + The ajax methods allow the log a user in using javascript. + For doing so, every 302 redirection a replaced by a 200 returning a json with the + url to redirect to. + By default, ajax login is disabled. + If CAS_ENABLE_AJAX_AUTH is True, ajax login is enable and only page on the same domain + as the CAS can do ajax request. To allow pages on other domains, you need to use CORS. + You can use the django app corsheaders for that. Be carefull to only allow domains + you completly trust as any javascript on these domaine will be able to authenticate + as the user. + """ + # get a bare client + client = Client() + # fetch the login page setting up the custom header HTTP_X_AJAX to tell we wish to de + # ajax requests + response = client.get("/login", HTTP_X_AJAX='on') + # we get a json as response telling us the user need to be authenticated + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode("utf8")) + self.assertEqual(data["status"], "error") + self.assertEqual(data["detail"], "login required") + self.assertEqual(data["url"], "/login?") + + @override_settings(CAS_ENABLE_AJAX_AUTH=True) + def test_ajax_logged_user_deleted(self): + """test ajax user logged deleted: login required""" + # get a client that is already authenticated + client = get_auth_client() + # delete the user in the db + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + user.delete() + # fetch the login page with ajax on + response = client.get("/login", HTTP_X_AJAX='on') + # we get a json telling us that the user need to authenticate + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode("utf8")) + self.assertEqual(data["status"], "error") + self.assertEqual(data["detail"], "login required") + self.assertEqual(data["url"], "/login?") + + @override_settings(CAS_ENABLE_AJAX_AUTH=True) + def test_ajax_logged(self): + """test ajax user is successfully logged""" + # get a client that is already authenticated + client = get_auth_client() + # fetch the login page with ajax on + response = client.get("/login", HTTP_X_AJAX='on') + # we get a json telling us that the user is well authenticated + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode("utf8")) + self.assertEqual(data["status"], "success") + self.assertEqual(data["detail"], "logged") + + @override_settings(CAS_ENABLE_AJAX_AUTH=True) + def test_ajax_get_ticket_success(self): + """test ajax retrieve a ticket for an allowed service""" + # using the default test service + service = "https://www.example.com" + # get a client that is already authenticated + client = get_auth_client() + # fetch the login page with ajax on + response = client.get("/login", {'service': service}, HTTP_X_AJAX='on') + # we get a json telling us that the ticket has being created + # and we get the url to fetch to authenticate the user to the service + # contening the ticket has GET parameter + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode("utf8")) + self.assertEqual(data["status"], "success") + self.assertEqual(data["detail"], "auth") + self.assertTrue(data["url"].startswith('%s?ticket=' % service)) + + def test_ajax_get_ticket_success_alt(self): + """ + test ajax retrieve a ticket for an allowed service. + Same as above but with CAS_ENABLE_AJAX_AUTH=False + """ + # using the default test service + service = "https://www.example.com" + # get a client that is already authenticated + client = get_auth_client() + # fetch the login page with ajax on + response = client.get("/login", {'service': service}, HTTP_X_AJAX='on') + # as CAS_ENABLE_AJAX_AUTH is False the ajax request is ignored and word normally: + # 302 redirect to the service url with ticket as GET parameter. javascript + # cannot retieve the ticket info and try follow the redirect to an other domain and fail + # silently + self.assertEqual(response.status_code, 302) + + @override_settings(CAS_ENABLE_AJAX_AUTH=True) + def test_ajax_get_ticket_fail(self): + """test ajax retrieve a ticket for a denied service""" + # using a denied service url + service = "https://www.example.org" + # get a client that is already authenticated + client = get_auth_client() + # fetch the login page with ajax on + response = client.get("/login", {'service': service}, HTTP_X_AJAX='on') + # we get a json telling us that the service is not allowed + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode("utf8")) + self.assertEqual(data["status"], "error") + self.assertEqual(data["detail"], "auth") + self.assertEqual(data["messages"][0]["level"], "error") + self.assertEqual( + data["messages"][0]["message"], + "Service https://www.example.org non allowed." + ) + + @override_settings(CAS_ENABLE_AJAX_AUTH=True) + def test_ajax_get_ticket_warn(self): + """test get a ticket but user asked to be warned""" + # using the default test service + service = "https://www.example.com" + # get a client that is already authenticated wth warn on + client = get_auth_client(warn="on") + # fetch the login page with ajax on + response = client.get("/login", {'service': service}, HTTP_X_AJAX='on') + # we get a json telling us that we cannot get a ticket transparently and that the + # user has asked to be warned + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode("utf8")) + self.assertEqual(data["status"], "error") + self.assertEqual(data["detail"], "confirmation needed") + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class LogoutTestCase(TestCase): + """test fot the logout view""" + def setUp(self): + """Prepare the test context""" + # for testing SingleLogOut we need to use a service on localhost were we lanch + # a simple one request http server + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + single_log_out=True + ) + # return all user attributes + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_logout(self): + """logout is idempotent""" + # get a bare client + client = Client() + + # call logout + client.get("/logout") + + # we are still not logged + self.assertFalse(client.session.get("username")) + self.assertFalse(client.session.get("authenticated")) + + def test_logout_view(self): + """test simple logout, logout only an user from one and only one sessions""" + # get two authenticated client with the same test user (but two different sessions) + client = get_auth_client() + client2 = get_auth_client() + + # fetch login, the first client is well authenticated + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + # and session variable are well + self.assertTrue(client.session["username"] == settings.CAS_TEST_USER) + self.assertTrue(client.session["authenticated"] is True) + + # call logout with the first client + response = client.get("/logout") + # the client is logged out + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"You have successfully logged out from " + b"the Central Authentication Service" + ) in response.content + ) + # and session variable a well cleaned + self.assertFalse(client.session.get("username")) + self.assertFalse(client.session.get("authenticated")) + # client2 is still logged + self.assertTrue(client2.session["username"] == settings.CAS_TEST_USER) + self.assertTrue(client2.session["authenticated"] is True) + + response = client.get("/login") + # fetch login, the second client is well authenticated + self.assertEqual(response.status_code, 200) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + def test_logout_from_all_session(self): + """test logout from all my session""" + # get two authenticated client with the same test user (but two different sessions) + client = get_auth_client() + client2 = get_auth_client() + + # call logout with the first client and ask to be logged out from all of this user sessions + client.get("/logout?all=1") + + # both client are logged out + self.assertFalse(client.session.get("username")) + self.assertFalse(client.session.get("authenticated")) + self.assertFalse(client2.session.get("username")) + self.assertFalse(client2.session.get("authenticated")) + + def assert_redirect_to_service(self, client, response): + """assert logout redirect to parameter""" + # assert a redirection with a service + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header("Location")) + self.assertEqual(response["Location"], "https://www.example.com") + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + # assert we are not longer logged in + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + def test_logout_view_url(self): + """test logout redirect to url parameter""" + # get a client that is authenticated + client = get_auth_client() + + # logout with an url paramer + response = client.get('/logout?url=https://www.example.com') + # we are redirected to the addresse of the url parameter + self.assert_redirect_to_service(client, response) + + def test_logout_view_service(self): + """test logout redirect to service parameter""" + # get a client that is authenticated + client = get_auth_client() + + # logout with a service parameter + response = client.get('/logout?service=https://www.example.com') + # we are redirected to the addresse of the service parameter + self.assert_redirect_to_service(client, response) + + def test_logout_slo(self): + """test logout from a service with SLO support""" + parameters = [] + + # test normal SLO + # setup a simple one request http server + (httpd, host, port) = HttpParamsHandler.run()[0:3] + # build a service url depending on which port the http server has binded + service = "http://%s:%s" % (host, port) + # get a ticket requested by client and being validated by the service + (client, ticket) = get_validated_ticket(service)[:2] + # the client logout triggering the send of the SLO requests + client.get('/logout') + # we store the POST parameters send for this ticket for furthur analisys + parameters.append((httpd.PARAMS, ticket)) + + # text SLO with a single_log_out_callback + # setup a simple one request http server + (httpd, host, port) = HttpParamsHandler.run()[0:3] + # set the default test service pattern to use the http server port for SLO requests. + # in fact, this single_log_out_callback parametter is usefull to implement SLO + # for non http service like imap or ftp + self.service_pattern.single_log_out_callback = "http://%s:%s" % (host, port) + self.service_pattern.save() + # get a ticket requested by client and being validated by the service + (client, ticket) = get_validated_ticket(self.service)[:2] + # the client logout triggering the send of the SLO requests + client.get('/logout') + # we store the POST parameters send for this ticket for furthur analisys + parameters.append((httpd.PARAMS, ticket)) + + # for earch POST parameters and corresponding ticket + for (params, ticket) in parameters: + # there is a POST parameter 'logoutRequest' + self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest']) + + # it is a valid xml + root = etree.fromstring(params[b'logoutRequest'][0]) + # contening a tag + self.assertTrue( + root.xpath( + "//samlp:LogoutRequest", + namespaces={"samlp": "urn:oasis:names:tc:SAML:2.0:protocol"} + ) + ) + # with a tag enclosing the value of the ticket + session_index = root.xpath( + "//samlp:SessionIndex", + namespaces={"samlp": "urn:oasis:names:tc:SAML:2.0:protocol"} + ) + self.assertEqual(len(session_index), 1) + self.assertEqual(session_index[0].text, ticket.value) + + # SLO error are displayed on logout page + (client, ticket) = get_validated_ticket(self.service)[:2] + # the client logout triggering the send of the SLO requests but + # not http server are listening + response = client.get('/logout') + self.assertTrue(b"Error during service logout" in response.content) + + @override_settings(CAS_ENABLE_AJAX_AUTH=True) + def test_ajax_logout(self): + """ + test ajax logout. These methode are here, but I do not really see an use case for + javascript logout + """ + # get a client that is authenticated + client = get_auth_client() + + # fetch the logout page with ajax on + response = client.get('/logout', HTTP_X_AJAX='on') + # we get a json telling us the user is well logged out + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode("utf8")) + self.assertEqual(data["status"], "success") + self.assertEqual(data["detail"], "logout") + self.assertEqual(data['session_nb'], 1) + + @override_settings(CAS_ENABLE_AJAX_AUTH=True) + def test_ajax_logout_all_session(self): + """test ajax logout from a random number a sessions""" + # fire a random int in [2, 10[ + nb_client = random.randint(2, 10) + # get this much of logged clients all for the test user + clients = [get_auth_client() for i in range(nb_client)] + # fetch the logout page with ajax on, requesting to logout from all sessions + response = clients[0].get('/logout?all=1', HTTP_X_AJAX='on') + # we get a json telling us the user is well logged out and the number of session + # the user has being logged out + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode("utf8")) + self.assertEqual(data["status"], "success") + self.assertEqual(data["detail"], "logout") + self.assertEqual(data['session_nb'], nb_client) + + @override_settings(CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT=True) + def test_redirect_after_logout(self): + """Test redirect to login after logout parameter""" + # get a client that is authenticated + client = get_auth_client() + + # fetch the logout page + response = client.get('/logout') + # as CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT is True, we are redirected to the login page + self.assertEqual(response.status_code, 302) + if django.VERSION < (1, 9): # pragma: no cover coverage is computed with dango 1.9 + self.assertEqual(response["Location"], "http://testserver/login") + else: + self.assertEqual(response["Location"], "/login") + self.assertFalse(client.session.get("username")) + self.assertFalse(client.session.get("authenticated")) + + @override_settings(CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT=True) + def test_redirect_after_logout_to_service(self): + """test prevalence of redirect url/service parameter over redirect to login after logout""" + # get a client that is authenticated + client = get_auth_client() + + # fetch the logout page with an url parameter + response = client.get('/logout?url=https://www.example.com') + # we are redirected to the url parameter and not to the login page + self.assert_redirect_to_service(client, response) + + # fetch the logout page with an service parameter + response = client.get('/logout?service=https://www.example.com') + # we are redirected to the service parameter and not to the login page + self.assert_redirect_to_service(client, response) + + @override_settings(CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT=True, CAS_ENABLE_AJAX_AUTH=True) + def test_ajax_redirect_after_logout(self): + """Test ajax redirect to login after logout parameter""" + # get a client that is authenticated + client = get_auth_client() + + # fetch the logout page with ajax on + response = client.get('/logout', HTTP_X_AJAX='on') + # we get a json telling us the user is well logged out. And url key is added to aks for + # redirection to the login page + self.assertEqual(response.status_code, 200) + data = json.loads(response.content.decode("utf8")) + self.assertEqual(data["status"], "success") + self.assertEqual(data["detail"], "logout") + self.assertEqual(data['session_nb'], 1) + self.assertEqual(data['url'], '/login') + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class AuthTestCase(TestCase): + """ + Test for the auth view, used for external services + to validate (user, pass, service) tuples. + """ + def setUp(self): + """preparing test context""" + # setting up a default test service url and pattern + self.service = 'https://www.example.com' + models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + + @override_settings(CAS_AUTH_SHARED_SECRET='test') + def test_auth_view_goodpass(self): + """successful request are awsered by yes""" + # get a bare client + client = Client() + # post the the auth view a valid (username, password, service) and the shared secret + # to test the user again the service, a user is created in the database for the + # current session and is then deleted as the user is not authenticated + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) + # as (username, password, service) and the hared secret are valid, we get yes as a response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\n') + + @override_settings(CAS_AUTH_SHARED_SECRET='test') + def test_auth_view_goodpass_logged(self): + """successful request are awsered by yes, using a logged sessions""" + # same as above + client = get_auth_client() + # to test the user again the service, a user is fetch in the database for the + # current session and is NOT deleted as the user is currently logged. + # Deleting the user from the database would cause the user to be logged out as + # showed in the login tests + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) + # as (username, password, service) and the hared secret are valid, we get yes as a response + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\n') + + @override_settings(CAS_AUTH_SHARED_SECRET='test') + def test_auth_view_badpass(self): + """ bag user password => no""" + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': 'badpass', + 'service': self.service, + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + @override_settings(CAS_AUTH_SHARED_SECRET='test') + def test_auth_view_badservice(self): + """bad service => no""" + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': 'https://www.example.org', + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + @override_settings(CAS_AUTH_SHARED_SECRET='test') + def test_auth_view_badsecret(self): + """bad api key => no""" + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'badsecret' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_auth_view_badsettings(self): + """api not set => error""" + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET") + + @override_settings(CAS_AUTH_SHARED_SECRET='test') + def test_auth_view_missing_parameter(self): + """missing parameter in request => no""" + client = Client() + params = { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + for key in ['username', 'password', 'service']: + send_params = params.copy() + del send_params[key] + response = client.post('/auth', send_params) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class ValidateTestCase(TestCase): + """tests for the validate view""" + def setUp(self): + """preparing test context""" + # setting up a default test service url and pattern + self.service = 'https://www.example.com' + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + # setting up a test service and pattern using a multi valued user attribut as username + # the first value of the list should be used as username + self.service_user_field = "https://user_field.example.com" + self.service_pattern_user_field = models.ServicePattern.objects.create( + name="user field", + pattern="^https://user_field\.example\.com(/.*)?$", + user_field="alias" + ) + # setting up a test service and pattern using a single valued user attribut as username + self.service_user_field_alt = "https://user_field_alt.example.com" + self.service_pattern_user_field_alt = models.ServicePattern.objects.create( + name="user field alt", + pattern="^https://user_field_alt\.example\.com(/.*)?$", + user_field="nom" + ) + + def test_validate_view_ok(self): + """test for a valid (ticket, service)""" + # get a ticket waiting to be validated for self.service + ticket = get_user_ticket_request(self.service)[1] + + # get a bare client + client = Client() + # calling the validate view with this ticket value and service + response = client.get('/validate', {'ticket': ticket.value, 'service': self.service}) + # get yes as a response and the test user username + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\ntest\n') + + def test_validate_view_badservice(self): + """test for a valid ticket but bad service""" + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + # calling the validate view with this ticket value and another service + response = client.get( + '/validate', + {'ticket': ticket.value, 'service': "https://www.example.org"} + ) + # the ticket service and validation service do not match, validation should fail + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_validate_view_badticket(self): + """test for a bad ticket but valid service""" + get_user_ticket_request(self.service) + + client = Client() + # calling the validate view with another ticket value and this service + response = client.get( + '/validate', + {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service} + ) + # as the ticket is bad, validation should fail + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_validate_user_field_ok(self): + """ + test with a good user_field. A bad user_field (that evaluate to False) + wont happed cause it is filtered in the login view + """ + for (service, username) in [ + (self.service_user_field, b"demo1"), + (self.service_user_field_alt, b"Nymous") + ]: + ticket = get_user_ticket_request(service)[1] + client = Client() + response = client.get( + '/validate', + {'ticket': ticket.value, 'service': service} + ) + self.assertEqual(response.status_code, 200) + # the user attribute is well used as username + self.assertEqual(response.content, b'yes\n' + username + b'\n') + + def test_validate_missing_parameter(self): + """test with a missing GET parameter among [service, ticket]""" + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + params = {'ticket': ticket.value, 'service': self.service} + for key in ['ticket', 'service']: + send_params = params.copy() + del send_params[key] + response = client.get('/validate', send_params) + # if the GET request is missing the ticket or + # service GET parameter, validation should fail + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class ValidateServiceTestCase(TestCase, XmlContent): + """tests for the serviceValidate view""" + def setUp(self): + """preparing test context""" + # for testing SingleLogOut and Proxy GrantingTicket transmission + # we need to use a service on localhost were we launch + # a simple one request http server + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + # allow to request PGT by the service + proxy_callback=True + ) + # tell the service pattern to transmit all the user attributes (* is a joker) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + # test service pattern using the attribute alias as username + self.service_user_field = "https://user_field.example.com" + self.service_pattern_user_field = models.ServicePattern.objects.create( + name="user field", + pattern="^https://user_field\.example\.com(/.*)?$", + user_field="alias" + ) + # test service pattern using the attribute nom as username + self.service_user_field_alt = "https://user_field_alt.example.com" + self.service_pattern_user_field_alt = models.ServicePattern.objects.create( + name="user field alt", + pattern="^https://user_field_alt\.example\.com(/.*)?$", + user_field="nom" + ) + + # test service pattern only transmiting one single attributes + self.service_one_attribute = "https://one_attribute.example.com" + self.service_pattern_one_attribute = models.ServicePattern.objects.create( + name="one_attribute", + pattern="^https://one_attribute\.example\.com(/.*)?$" + ) + models.ReplaceAttributName.objects.create( + name="nom", + service_pattern=self.service_pattern_one_attribute + ) + + # test service pattern testing attribute name and value replacement + self.service_replace_attribute_list = "https://replace_attribute_list.example.com" + self.service_pattern_replace_attribute_list = models.ServicePattern.objects.create( + name="replace_attribute_list", + pattern="^https://replace_attribute_list\.example\.com(/.*)?$", + ) + models.ReplaceAttributValue.objects.create( + attribut="alias", + pattern="^demo", + replace="truc", + service_pattern=self.service_pattern_replace_attribute_list + ) + models.ReplaceAttributName.objects.create( + name="alias", + replace="ALIAS", + service_pattern=self.service_pattern_replace_attribute_list + ) + self.service_replace_attribute = "https://replace_attribute.example.com" + self.service_pattern_replace_attribute = models.ServicePattern.objects.create( + name="replace_attribute", + pattern="^https://replace_attribute\.example\.com(/.*)?$", + ) + models.ReplaceAttributValue.objects.create( + attribut="nom", + pattern="N", + replace="P", + service_pattern=self.service_pattern_replace_attribute + ) + models.ReplaceAttributName.objects.create( + name="nom", + replace="NOM", + service_pattern=self.service_pattern_replace_attribute + ) + + def test_validate_service_view_ok(self): + """test with a valid (ticket, service), the username and all attributes are transmited""" + # get a ticket from an authenticated user waiting for validation + ticket = get_user_ticket_request(self.service)[1] + + # get a bare client + client = Client() + # requesting validation with a good (ticket, service) + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service}) + # the validation should succes with username settings.CAS_TEST_USER and transmit + # the attributes settings.CAS_TEST_ATTRIBUTES + self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES) + + def test_validate_service_view_ok_one_attribute(self): + """ + test with a valid (ticket, service), the username and + the 'nom' only attribute are transmited + """ + # get a ticket for a service that transmit only one attribute + ticket = get_user_ticket_request(self.service_one_attribute)[1] + + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service_one_attribute} + ) + # the validation should succed, returning settings.CAS_TEST_USER as username and a single + # attribute 'nom' + self.assert_success( + response, + settings.CAS_TEST_USER, + {'nom': settings.CAS_TEST_ATTRIBUTES['nom']} + ) + + def test_validate_replace_attributes(self): + """test with a valid (ticket, service), attributes name and value replacement""" + # get a ticket for a service pattern replacing attributes names + # nom -> NOM and value nom -> s/^N/P/ for a single valued attribute + ticket = get_user_ticket_request(self.service_replace_attribute)[1] + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service_replace_attribute} + ) + self.assert_success( + response, + settings.CAS_TEST_USER, + {'NOM': 'Pymous'} + ) + + # get a ticket for a service pattern replacing attributes names + # alias -> ALIAS and value alias -> s/demo/truc/ for a multi valued attribute + ticket = get_user_ticket_request(self.service_replace_attribute_list)[1] + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service_replace_attribute_list} + ) + self.assert_success( + response, + settings.CAS_TEST_USER, + {'ALIAS': ['truc1', 'truc2']} + ) + + def test_validate_service_view_badservice(self): + """test with a valid ticket but a bad service, the validatin should fail""" + # get a ticket for service A + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + bad_service = "https://www.example.org" + # try to validate it for service B + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service}) + # the validation should fail with error code "INVALID_SERVICE" + self.assert_error( + response, + "INVALID_SERVICE", + bad_service + ) + + def test_validate_service_view_badticket_goodprefix(self): + """ + test with a good service but a bad ticket begining with ST-, + the validation should fail with the error (INVALID_TICKET, ticket not found) + """ + get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assert_error( + response, + "INVALID_TICKET", + 'ticket not found' + ) + + def test_validate_service_view_badticket_badprefix(self): + """ + test with a good service bud a bad ticket not begining with ST-, + the validation should fail with the error (INVALID_TICKET, `the ticket`) + """ + get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "RANDOM" + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assert_error( + response, + "INVALID_TICKET", + bad_ticket + ) + + def test_validate_service_view_ok_pgturl(self): + """test the retrieval of a ProxyGrantingTicket""" + # start a simple on request http server + (httpd, host, port) = HttpParamsHandler.run()[0:3] + # construct the service from it + service = "http://%s:%s" % (host, port) + + # get a ticket to be validated + ticket = get_user_ticket_request(service)[1] + + client = Client() + # request a PGT ticket then validating the ticket by setting the pgtUrl parameter + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': service, 'pgtUrl': service} + ) + # We should have recieved the PGT via a GET request parameter on the simple http server + pgt_params = httpd.PARAMS + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + # the validation response should return a id to match again the request transmitting the PGT + pgtiou = root.xpath( + "//cas:proxyGrantingTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(pgtiou), 1) + # the matching id for making corresponde one PGT to a validatin response should match + self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text) + # the PGT is present in the receive GET requests parameters + self.assertTrue("pgtId" in pgt_params) + + def test_validate_service_pgturl_sslerror(self): + """test the retrieval of a ProxyGrantingTicket with a SSL error on the pgtUrl""" + (host, port) = HttpParamsHandler.run()[1:3] + # is fact the service listen on http and not https raisin a SSL Protocol Error + # but other SSL/TLS error should behave the same + service = "https://%s:%s" % (host, port) + + ticket = get_user_ticket_request(service)[1] + + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': service, 'pgtUrl': service} + ) + # The pgtUrl is validated: it must be localhost or have valid x509 certificat and + # certificat validation should succed. Moreother, pgtUrl should match a service pattern + # with proxy_callback set to True + self.assert_error( + response, + "INVALID_PROXY_CALLBACK", + ) + + def test_validate_service_pgturl_404(self): + """ + test the retrieval on a ProxyGrantingTicket then to pgtUrl return a http error. + PGT creation should be aborted but the ticket still be valid + """ + (host, port) = Http404Handler.run()[1:3] + service = "http://%s:%s" % (host, port) + + ticket = get_user_ticket_request(service)[1] + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': service, 'pgtUrl': service} + ) + # The ticket is successfully validated + root = self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES) + # but no PGT is transmitted + pgtiou = root.xpath( + "//cas:proxyGrantingTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertFalse(pgtiou) + + def test_validate_service_pgturl_bad_proxy_callback(self): + """test the retrieval of a ProxyGrantingTicket, not allowed pgtUrl should be denied""" + self.service_pattern.proxy_callback = False + self.service_pattern.save() + ticket = get_user_ticket_request(self.service)[1] + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service} + ) + self.assert_error( + response, + "INVALID_PROXY_CALLBACK", + "callback url not allowed by configuration" + ) + + self.service_pattern.proxy_callback = True + + ticket = get_user_ticket_request(self.service)[1] + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service, 'pgtUrl': "https://www.example.org"} + ) + self.assert_error( + response, + "INVALID_PROXY_CALLBACK", + "callback url not allowed by configuration" + ) + + def test_validate_user_field_ok(self): + """ + test with a good user_field. A bad user_field (that evaluate to False) + wont happed cause it is filtered in the login view + """ + for (service, username) in [ + (self.service_user_field, settings.CAS_TEST_ATTRIBUTES["alias"][0]), + (self.service_user_field_alt, settings.CAS_TEST_ATTRIBUTES["nom"]) + ]: + # requesting a ticket for a service url matched by a service pattern using a user + # attribute as username + ticket = get_user_ticket_request(service)[1] + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': service} + ) + # The validate shoudl be successful with specified username and no attributes transmited + self.assert_success( + response, + username, + {} + ) + + def test_validate_missing_parameter(self): + """test with a missing GET parameter among [service, ticket]""" + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + params = {'ticket': ticket.value, 'service': self.service} + for key in ['ticket', 'service']: + send_params = params.copy() + del send_params[key] + response = client.get('/serviceValidate', send_params) + # a validation request with a missing GET parameter should fail + self.assert_error( + response, + "INVALID_REQUEST", + "you must specify a service and a ticket" + ) + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): + """tests for the proxy view""" + def setUp(self): + """preparing test context""" + # we prepare a bunch a service url and service patterns for tests + self.setup_service_patterns(proxy=True) + + # set the default service pattern to localhost to be able to retrieve PGT + self.service = 'http://127.0.0.1' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy=True, + proxy_callback=True + ) + # transmit all attributes + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_proxy_ok(self): + """ + Get a PGT, get a proxy ticket, validate it. Validation should succeed and + show the proxy service URL. + """ + # we directrly get a ProxyGrantingTicket + params = get_pgt() + + # We try get a proxy ticket with our PGT + client1 = Client() + # for what we send a GET request to /proxy with ge PGT and the target service for which + # we want a ProxyTicket to. + response = client1.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': "https://www.example.com"} + ) + self.assertEqual(response.status_code, 200) + + # we should sucessfully reteive a PT + root = etree.fromstring(response.content) + sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertTrue(sucess) + + proxy_ticket = root.xpath( + "//cas:proxyTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(proxy_ticket), 1) + proxy_ticket = proxy_ticket[0].text + + # validate the proxy ticket with the service for which is was emitted + client2 = Client() + response = client2.get( + '/proxyValidate', + {'ticket': proxy_ticket, 'service': "https://www.example.com"} + ) + # validation should succeed and return settings.CAS_TEST_USER as username + # and settings.CAS_TEST_ATTRIBUTES as attributes + root = self.assert_success( + response, + settings.CAS_TEST_USER, + settings.CAS_TEST_ATTRIBUTES + ) + + # in the PT validation response, it should have the service url of the PGY + proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxies), 1) + proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxy), 1) + self.assertEqual(proxy[0].text, params["service"]) + + def test_validate_proxy_bad_pgt(self): + """Try to get a ProxyTicket with a bad PGT. The PT generation should fail""" + # we directrly get a ProxyGrantingTicket + params = get_pgt() + client = Client() + response = client.get( + '/proxy', + { + 'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX, + 'targetService': params['service'] + } + ) + self.assert_error( + response, + "INVALID_TICKET", + "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX + ) + + def test_validate_proxy_bad_service(self): + """ + Try to get a ProxyTicket for a denied service and + a service that do not allow PT. The PT generation should fail. + """ + # we directrly get a ProxyGrantingTicket + params = get_pgt() + + # try to get a PT for a denied service + client1 = Client() + response = client1.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': "https://www.example.org"} + ) + self.assert_error( + response, + "UNAUTHORIZED_SERVICE", + "https://www.example.org" + ) + + # try to get a PT for a service that do not allow PT + self.service_pattern.proxy = False + self.service_pattern.save() + + client2 = Client() + response = client2.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': params['service']} + ) + + self.assert_error( + response, + "UNAUTHORIZED_SERVICE", + 'the service %s do not allow proxy ticket' % params['service'] + ) + + self.service_pattern.proxy = True + self.service_pattern.save() + + def test_proxy_unauthorized_user(self): + """ + Try to get a PT for services that do not allow the current user: + * first with a service that restrict allowed username + * second with a service requiring somes conditions on the user attributes + * third with a service using a particular user attribute as username + All this tests should fail + """ + # we directrly get a ProxyGrantingTicket + params = get_pgt() + + for service in [ + # do ot allow the test username + self.service_restrict_user_fail, + # require the 'nom' attribute to be 'toto' + self.service_filter_fail, + # want to use the non-exitant 'uid' attribute as username + self.service_field_needed_fail + ]: + client = Client() + response = client.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': service} + ) + # PT generation should fail + self.assert_error( + response, + "UNAUTHORIZED_USER", + 'User %s not allowed on %s' % (settings.CAS_TEST_USER, service) + ) + + def test_proxy_missing_parameter(self): + """Try to get a PGT with some missing GET parameters. The PT should not be emited""" + params = get_pgt() + base_params = {'pgt': params['pgtId'], 'targetService': "https://www.example.org"} + for key in ["pgt", 'targetService']: + send_params = base_params.copy() + del send_params[key] + client = Client() + response = client.get("/proxy", send_params) + self.assert_error( + response, + "INVALID_REQUEST", + 'you must specify and pgt and targetService' + ) + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): + """tests for the proxy view""" + def setUp(self): + """preparing test context""" + # we prepare a bunch a service url and service patterns for tests + self.setup_service_patterns(proxy=True) + + # special service pattern for retrieving a PGT + self.service_pgt = 'http://127.0.0.1' + self.service_pattern_pgt = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy=True, + proxy_callback=True + ) + models.ReplaceAttributName.objects.create( + name="*", + service_pattern=self.service_pattern_pgt + ) + + # template for the XML POST need to be send to validate a ticket using SAML 1.1 + xml_template = """ + + + + + %(ticket)s + + +""" + + def assert_success(self, response, username, original_attributes): + """assert ticket validation success""" + self.assertEqual(response.status_code, 200) + # on validation success, the response should have a StatusCode set to Success + root = etree.fromstring(response.content) + success = root.xpath( + "//samlp:StatusCode", + namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"} + ) + self.assertEqual(len(success), 1) + self.assertTrue(success[0].attrib['Value'].endswith(":Success")) + + # the user username should be return whithin tags + user = root.xpath( + "//samla:NameIdentifier", + namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"} + ) + self.assertTrue(user) + self.assertEqual(user[0].text, username) + + # the returned attributes should match original_attributes + attributes = root.xpath( + "//samla:AttributeStatement/samla:Attribute", + namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"} + ) + attrs = set() + for attr in attributes: + attrs.add((attr.attrib['AttributeName'], attr.getchildren()[0].text)) + original = set() + for key, value in original_attributes.items(): + if isinstance(value, list): + for subval in value: + original.add((key, subval)) + else: + original.add((key, value)) + self.assertEqual(original, attrs) + + def assert_error(self, response, code, msg=None): + """assert ticket validation error""" + self.assertEqual(response.status_code, 200) + # on error the status code value should be the one provider in `code` + root = etree.fromstring(response.content) + error = root.xpath( + "//samlp:StatusCode", + namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"} + ) + self.assertEqual(len(error), 1) + self.assertTrue(error[0].attrib['Value'].endswith(":%s" % code)) + # it may have an error message + if msg is not None: + self.assertEqual(error[0].text, msg) + + def test_saml_ok(self): + """ + test with a valid (ticket, service), with a ST and a PT, + the username and all attributes are transmited""" + tickets = [ + # return a ServiceTicket (standard ticket) waiting for validation + get_user_ticket_request(self.service)[1], + # return a PT waiting for validation + get_proxy_ticket(self.service) + ] + + for ticket in tickets: + client = Client() + # we send the POST validation requests + response = client.post( + '/samlValidate?TARGET=%s' % self.service, + self.xml_template % { + 'ticket': ticket.value, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + # and it should succeed + self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES) + + def test_saml_ok_user_field(self): + """test with a valid(ticket, service), use a attributes as transmitted username""" + for (service, username) in [ + (self.service_field_needed_success, settings.CAS_TEST_ATTRIBUTES['alias'][0]), + (self.service_field_needed_success_alt, settings.CAS_TEST_ATTRIBUTES['nom']) + ]: + ticket = get_user_ticket_request(service)[1] + + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % service, + self.xml_template % { + 'ticket': ticket.value, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + self.assert_success(response, username, {}) + + def test_saml_bad_ticket(self): + """test validation with a bad ST and a bad PT, validation should fail""" + tickets = [utils.gen_st(), utils.gen_pt()] + + for ticket in tickets: + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % self.service, + self.xml_template % { + 'ticket': ticket, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + self.assert_error( + response, + "AuthnFailed", + 'ticket %s not found' % ticket + ) + + def test_saml_bad_ticket_prefix(self): + """test validation with a bad ticket prefix. Validation should fail with 'AuthnFailed'""" + bad_ticket = "RANDOM-NOT-BEGINING-WITH-ST-OR-ST" + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % self.service, + self.xml_template % { + 'ticket': bad_ticket, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + self.assert_error( + response, + "AuthnFailed", + 'ticket %s should begin with PT- or ST-' % bad_ticket + ) + + def test_saml_bad_target(self): + """test with a valid ticket, but using a bad target, validation should fail""" + bad_target = "https://www.example.org" + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % bad_target, + self.xml_template % { + 'ticket': ticket.value, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + self.assert_error( + response, + "AuthnFailed", + 'TARGET %s do not match ticket service' % bad_target + ) + + def test_saml_bad_xml(self): + """test validation with a bad xml request, validation should fail""" + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % self.service, + "", + content_type="text/xml; encoding='utf-8'" + ) + self.assert_error(response, 'VersionMismatch') diff --git a/cas_server/tests/urls.py b/cas_server/tests/urls.py new file mode 100644 index 0000000..a9ed25c --- /dev/null +++ b/cas_server/tests/urls.py @@ -0,0 +1,22 @@ +"""cas URL Configuration + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/1.9/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.conf.urls import url, include, include + 2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls')) +""" +from django.conf.urls import url, include +from django.contrib import admin + +urlpatterns = [ + url(r'^admin/', admin.site.urls), + url(r'^', include('cas_server.urls', namespace='cas_server')), +] diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py new file mode 100644 index 0000000..bd692e9 --- /dev/null +++ b/cas_server/tests/utils.py @@ -0,0 +1,180 @@ +# ⁻*- coding: utf-8 -*- +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for +# more details. +# +# You should have received a copy of the GNU General Public License version 3 +# along with this program; if not, write to the Free Software Foundation, Inc., 51 +# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# (c) 2016 Valentin Samir +"""Some utils functions for tests""" +from cas_server.default_settings import settings + +from django.test import Client + +import cgi +from threading import Thread +from lxml import etree +from six.moves import BaseHTTPServer +from six.moves.urllib.parse import urlparse, parse_qsl + +from cas_server import models + + +def copy_form(form): + """Copy form value into a dict""" + params = {} + for field in form: + if field.value(): + params[field.name] = field.value() + else: + params[field.name] = "" + return params + + +def get_login_page_params(client=None): + """Return a client and the POST params for the client to login""" + if client is None: + client = Client() + response = client.get('/login') + params = copy_form(response.context["form"]) + return client, params + + +def get_auth_client(**update): + """return a authenticated client""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params.update(update) + + client.post('/login', params) + assert client.session.get("authenticated") + + return client + + +def get_user_ticket_request(service): + """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)""" + client = get_auth_client() + response = client.get("/login", {"service": service}) + ticket_value = response['Location'].split('ticket=')[-1] + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + return (user, ticket, client) + + +def get_validated_ticket(service): + """Return a tick that has being already validated. Used to test SLO""" + (ticket, auth_client) = get_user_ticket_request(service)[1:3] + + client = Client() + response = client.get('/validate', {'ticket': ticket.value, 'service': service}) + assert (response.status_code == 200) + assert (response.content == b'yes\ntest\n') + + ticket = models.ServiceTicket.objects.get(value=ticket.value) + return (auth_client, ticket) + + +def get_pgt(): + """return a dict contening a service, user and PGT ticket for this service""" + (httpd, host, port) = HttpParamsHandler.run()[0:3] + service = "http://%s:%s" % (host, port) + + (user, ticket) = get_user_ticket_request(service)[:2] + + client = Client() + client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + params = httpd.PARAMS + + params["service"] = service + params["user"] = user + + return params + + +def get_proxy_ticket(service): + """Return a ProxyTicket waiting for validation""" + params = get_pgt() + + # get a proxy ticket + client = Client() + response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service}) + root = etree.fromstring(response.content) + proxy_ticket = root.xpath( + "//cas:proxyTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + proxy_ticket = proxy_ticket[0].text + ticket = models.ProxyTicket.objects.get(value=proxy_ticket) + return ticket + + +class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler): + """ + A simple http server that return 200 on GET or POST + and store GET or POST parameters. Used in unit tests + """ + + def do_GET(self): + """Called on a GET request on the BaseHTTPServer""" + self.send_response(200) + self.send_header(b"Content-type", "text/plain") + self.end_headers() + self.wfile.write(b"ok") + url = urlparse(self.path) + params = dict(parse_qsl(url.query)) + self.server.PARAMS = params + + def do_POST(self): + """Called on a POST request on the BaseHTTPServer""" + ctype, pdict = cgi.parse_header(self.headers.get('content-type')) + if ctype == 'multipart/form-data': + postvars = cgi.parse_multipart(self.rfile, pdict) + elif ctype == 'application/x-www-form-urlencoded': + length = int(self.headers.get('content-length')) + postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1) + else: + postvars = {} + self.server.PARAMS = postvars + + def log_message(self, *args): + """silent any log message""" + return + + @classmethod + def run(cls): + """Run a BaseHTTPServer using this class as handler""" + server_class = BaseHTTPServer.HTTPServer + httpd = server_class(("127.0.0.1", 0), cls) + (host, port) = httpd.socket.getsockname() + + def lauch(): + """routine to lauch in a background thread""" + httpd.handle_request() + httpd.server_close() + + httpd_thread = Thread(target=lauch) + httpd_thread.daemon = True + httpd_thread.start() + return (httpd, host, port) + + +class Http404Handler(HttpParamsHandler): + """A simple http server that always return 404 not found. Used in unit tests""" + def do_GET(self): + """Called on a GET request on the BaseHTTPServer""" + self.send_response(404) + self.send_header(b"Content-type", "text/plain") + self.end_headers() + self.wfile.write(b"error 404 not found") + + def do_POST(self): + """Called on a POST request on the BaseHTTPServer""" + return self.do_GET() diff --git a/cas_server/urls.py b/cas_server/urls.py index 982ef9d..8b7f762 100644 --- a/cas_server/urls.py +++ b/cas_server/urls.py @@ -8,7 +8,7 @@ # along with this program; if not, write to the Free Software Foundation, Inc., 51 # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # -# (c) 2015 Valentin Samir +# (c) 2015-2016 Valentin Samir """urls for the app""" from django.conf.urls import patterns, url from django.views.generic import RedirectView diff --git a/cas_server/utils.py b/cas_server/utils.py index c8b345b..ee7b5e5 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -1,4 +1,4 @@ -# ⁻*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # This program is distributed in the hope that it will be useful, but WITHOUT # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS # FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for @@ -8,7 +8,7 @@ # along with this program; if not, write to the Free Software Foundation, Inc., 51 # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # -# (c) 2015 Valentin Samir +# (c) 2015-2016 Valentin Samir """Some util function for the app""" from .default_settings import settings @@ -23,18 +23,19 @@ import hashlib import crypt import base64 import six -from threading import Thread + from importlib import import_module -from six.moves import BaseHTTPServer from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode def context(params): + """Function that add somes variable to the context before template rendering""" params["settings"] = settings return params def json_response(request, data): + """Wrapper dumping `data` to a json and sending it to the user with an HttpResponse""" data["messages"] = [] for msg in messages.get_messages(request): data["messages"].append({'message': msg.message, 'level': msg.level_tag}) @@ -64,6 +65,7 @@ def redirect_params(url_name, params=None): def reverse_params(url_name, params=None, **kwargs): + """compule the reverse url or `url_name` and add GET parameters from `params` to it""" url = reverse(url_name, **kwargs) params = urlencode(params if params else {}) return url + "?%s" % params @@ -83,10 +85,13 @@ def update_url(url, params): url_parts = list(urlparse(url)) query = dict(parse_qsl(url_parts[4])) query.update(params) - url_parts[4] = urlencode(query) - for i, url_part in enumerate(url_parts): - if not isinstance(url_part, bytes): - url_parts[i] = url_part.encode('utf-8') + # make the params order deterministic + query = list(query.items()) + query.sort() + url_query = urlencode(query) + if not isinstance(url_query, bytes): # pragma: no cover in python3 urlencode return an unicode + url_query = url_query.encode("utf-8") + url_parts[4] = url_query return urlunparse(url_parts).decode('utf-8') @@ -147,35 +152,25 @@ def gen_saml_id(): return _gen_ticket('_') -class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): - PARAMS = {} - - def do_GET(self): - self.send_response(200) - self.send_header(b"Content-type", "text/plain") - self.end_headers() - self.wfile.write(b"ok") - url = urlparse(self.path) - params = dict(parse_qsl(url.query)) - PGTUrlHandler.PARAMS.update(params) - - def log_message(self, *args): - return - - @staticmethod - def run(): - server_class = BaseHTTPServer.HTTPServer - httpd = server_class(("127.0.0.1", 0), PGTUrlHandler) - (host, port) = httpd.socket.getsockname() - - def lauch(): - httpd.handle_request() - httpd.server_close() - - httpd_thread = Thread(target=lauch) - httpd_thread.daemon = True - httpd_thread.start() - return (httpd_thread, host, port) +def crypt_salt_is_valid(salt): + """Return True is salt is valid has a crypt salt, False otherwise""" + if len(salt) < 2: + return False + else: + if salt[0] == '$': + if salt[1] == '$': + return False + else: + if '$' not in salt[1:]: + return False + else: + hashed = crypt.crypt("", salt) + if not hashed or '$' not in hashed[1:]: + return False + else: + return True + else: + return True class LdapHashUserPassword(object): @@ -268,7 +263,7 @@ class LdapHashUserPassword(object): if salt is None or salt == b"": salt = b"" cls._test_scheme_nosalt(scheme) - elif salt is not None: + else: cls._test_scheme_salt(scheme) try: return scheme + base64.b64encode( @@ -278,9 +273,9 @@ class LdapHashUserPassword(object): if six.PY3: password = password.decode(charset) salt = salt.decode(charset) - hashed_password = crypt.crypt(password, salt) - if hashed_password is None: + if not crypt_salt_is_valid(salt): raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt) + hashed_password = crypt.crypt(password, salt) if six.PY3: hashed_password = hashed_password.encode(charset) return scheme + hashed_password @@ -302,7 +297,7 @@ class LdapHashUserPassword(object): if scheme in cls.schemes_nosalt: return b"" elif scheme == b'{CRYPT}': - return b'$'.join(hashed_passord.split(b'$', 3)[:-1]) + return b'$'.join(hashed_passord.split(b'$', 3)[:-1])[len(scheme):] else: hashed_passord = base64.b64decode(hashed_passord[len(scheme):]) if len(hashed_passord) < cls._schemes_to_len[scheme]: @@ -324,7 +319,7 @@ def check_password(method, password, hashed_password, charset): elif method == "crypt": if hashed_password.startswith(b'$'): salt = b'$'.join(hashed_password.split(b'$', 3)[:-1]) - elif hashed_password.startswith(b'_'): + elif hashed_password.startswith(b'_'): # pragma: no cover old BSD format not supported salt = hashed_password[:9] else: salt = hashed_password[:2] @@ -332,9 +327,9 @@ def check_password(method, password, hashed_password, charset): password = password.decode(charset) salt = salt.decode(charset) hashed_password = hashed_password.decode(charset) - crypted_password = crypt.crypt(password, salt) - if crypted_password is None: + if not crypt_salt_is_valid(salt): raise ValueError("System crypt implementation do not support the salt %r" % salt) + crypted_password = crypt.crypt(password, salt) return crypted_password == hashed_password elif method == "ldap": scheme = LdapHashUserPassword.get_scheme(hashed_password) diff --git a/cas_server/views.py b/cas_server/views.py index 2b33a6c..94ee0f0 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -8,7 +8,7 @@ # along with this program; if not, write to the Free Software Foundation, Inc., 51 # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # -# (c) 2015 Valentin Samir +# (c) 2015-2016 Valentin Samir """views for the app""" from .default_settings import settings @@ -105,10 +105,11 @@ class LogoutView(View, LogoutMixin): service = None def init_get(self, request): + """Initialize GET received parameters""" self.request = request self.service = request.GET.get('service') self.url = request.GET.get('url') - self.ajax = 'HTTP_X_AJAX' in request.META + self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META def get(self, request, *args, **kwargs): """methode called on GET request on this view""" @@ -196,24 +197,30 @@ class LoginView(View, LogoutMixin): USER_NOT_AUTHENTICATED = 6 def init_post(self, request): + """Initialize POST received parameters""" self.request = request self.service = request.POST.get('service') self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") self.gateway = request.POST.get('gateway') self.method = request.POST.get('method') - self.ajax = 'HTTP_X_AJAX' in request.META + self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META if request.POST.get('warned') and request.POST['warned'] != "False": self.warned = True + self.warn = request.POST.get('warn') - def check_lt(self): - # save LT for later check - lt_valid = self.request.session.get('lt', []) - lt_send = self.request.POST.get('lt') - # generate a new LT (by posting the LT has been consumed) + def gen_lt(self): + """Generate a new LoginTicket and add it to the list of valid LT for the user""" self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] if len(self.request.session['lt']) > 100: self.request.session['lt'] = self.request.session['lt'][-100:] + def check_lt(self): + """Check is the POSTed LoginTicket is valid, if yes invalide it""" + # save LT for later check + lt_valid = self.request.session.get('lt', []) + lt_send = self.request.POST.get('lt') + # generate a new LT (by posting the LT has been consumed) + self.gen_lt() # check if send LT is valid if lt_valid is None or lt_send not in lt_valid: return False @@ -238,7 +245,7 @@ class LoginView(View, LogoutMixin): username=self.request.session['username'], session_key=self.request.session.session_key ) - self.user.save() + self.user.save() # pragma: no cover (should not happend) except models.User.DoesNotExist: self.user = models.User.objects.create( username=self.request.session['username'], @@ -250,10 +257,15 @@ class LoginView(View, LogoutMixin): elif ret == self.USER_ALREADY_LOGGED: pass else: - raise EnvironmentError("invalid output for LoginView.process_post") + raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover return self.common() def process_post(self): + """ + Analyse the POST request: + * check that the LoginTicket is valid + * check that the user sumited credentials are valid + """ if not self.check_lt(): values = self.request.POST.copy() # if not set a new LT and fail @@ -280,12 +292,14 @@ class LoginView(View, LogoutMixin): return self.USER_ALREADY_LOGGED def init_get(self, request): + """Initialize GET received parameters""" self.request = request self.service = request.GET.get('service') self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") self.gateway = request.GET.get('gateway') self.method = request.GET.get('method') - self.ajax = 'HTTP_X_AJAX' in request.META + self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META + self.warn = request.GET.get('warn') def get(self, request, *args, **kwargs): """methode called on GET request on this view""" @@ -294,22 +308,24 @@ class LoginView(View, LogoutMixin): return self.common() def process_get(self): - # generate a new LT if none is present - self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] - + """Analyse the GET request""" + # generate a new LT + self.gen_lt() if not self.request.session.get("authenticated") or self.renew: self.init_form() return self.USER_NOT_AUTHENTICATED return self.USER_AUTHENTICATED def init_form(self, values=None): + """Initialization of the good form depending of POST and GET parameters""" self.form = forms.UserCredential( values, initial={ 'service': self.service, 'method': self.method, 'warn': self.request.session.get("warn"), - 'lt': self.request.session['lt'][-1] + 'lt': self.request.session['lt'][-1], + 'renew': self.renew } ) @@ -351,7 +367,7 @@ class LoginView(View, LogoutMixin): redirect_url = self.user.get_service_url( self.service, service_pattern, - renew=self.renew + renew=self.renewed ) if not self.ajax: return HttpResponseRedirect(redirect_url) @@ -580,12 +596,9 @@ class Validate(View): ticket.service_pattern.user_field ) if isinstance(username, list): - try: - username = username[0] - except IndexError: - username = None - if not username: - username = "" + # the list is not empty because we wont generate a ticket with a user_field + # that evaluate to False + username = username[0] else: username = ticket.user.username return HttpResponse("yes\n%s\n" % username, content_type="text/plain") @@ -661,6 +674,10 @@ class ValidateService(View, AttributesMixin): params['username'] = self.ticket.user.attributs.get( self.ticket.service_pattern.user_field ) + if isinstance(params['username'], list): + # the list is not empty because we wont generate a ticket with a user_field + # that evaluate to False + params['username'] = params['username'][0] if self.pgt_url and ( self.pgt_url.startswith("https://") or re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url) @@ -762,9 +779,12 @@ class ValidateService(View, AttributesMixin): params, content_type="text/xml; charset=utf-8" ) - except requests.exceptions.SSLError as error: + except requests.exceptions.RequestException as error: error = utils.unpack_nested_exception(error) - raise ValidateError('INVALID_PROXY_CALLBACK', str(error)) + raise ValidateError( + 'INVALID_PROXY_CALLBACK', + "%s: %s" % (type(error), str(error)) + ) else: raise ValidateError( 'INVALID_PROXY_CALLBACK', @@ -844,7 +864,7 @@ class Proxy(View): except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined): raise ValidateError( 'UNAUTHORIZED_USER', - '%s not allowed on %s' % (ticket.user, self.target_service) + 'User %s not allowed on %s' % (ticket.user.username, self.target_service) ) @@ -903,11 +923,15 @@ class SamlValidate(View, AttributesMixin): 'username': self.ticket.user.username, 'attributes': attributes } - if self.ticket.service_pattern.user_field and \ - self.ticket.user.attributs.get(self.ticket.service_pattern.user_field): + if (self.ticket.service_pattern.user_field and + self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)): params['username'] = self.ticket.user.attributs.get( self.ticket.service_pattern.user_field ) + if isinstance(params['username'], list): + # the list is not empty because we wont generate a ticket with a user_field + # that evaluate to False + params['username'] = params['username'][0] logger.info( "SamlValidate: ticket %s validated for user %s on service %s." % ( self.ticket.value, diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..74acafe --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +testpaths = cas_server/tests/ +DJANGO_SETTINGS_MODULE = cas_server.tests.settings +norecursedirs = .* build dist docs +python_paths = . diff --git a/requirements-dev.txt b/requirements-dev.txt index e6ef993..3cf4247 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,10 +1,12 @@ -tox==1.8.1 -pytest==2.6.4 -pytest-django==2.7.0 -pytest-pythonpath==0.3 +setuptools>=5.5 +tox>=1.8.1 +pytest>=2.6.4 +pytest-django>=2.8.0 +pytest-pythonpath>=0.3 +pytest-cov>=2.2.1 requests>=2.4 -django-picklefield>=0.3.1 requests_futures>=0.9.5 +django-picklefield>=0.3.1 django-bootstrap3>=5.4 lxml>=3.4 six>=1 diff --git a/run_tests b/run_tests deleted file mode 100755 index 4ea21ee..0000000 --- a/run_tests +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python -import os, sys -import django -from django.conf import settings - -import settings_tests - -settings.configure(**settings_tests.__dict__) -django.setup() - -try: - # Django <= 1.8 - from django.test.simple import DjangoTestSuiteRunner - test_runner = DjangoTestSuiteRunner(verbosity=1) -except ImportError: - # Django >= 1.8 - from django.test.runner import DiscoverRunner - test_runner = DiscoverRunner(verbosity=1) - -failures = test_runner.run_tests(['cas_server']) -if failures: - sys.exit(failures) diff --git a/setup.py b/setup.py index a9826c7..0548d66 100644 --- a/setup.py +++ b/setup.py @@ -34,7 +34,8 @@ setup( version='0.4.4', packages=[ 'cas_server', 'cas_server.migrations', - 'cas_server.management', 'cas_server.management.commands' + 'cas_server.management', 'cas_server.management.commands', + 'cas_server.tests' ], include_package_data=True, license='GPLv3', diff --git a/tox.ini b/tox.ini index 0b65c56..ad0fec7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,12 +1,12 @@ [tox] envlist= + flake8, py27-django17, py27-django18, py27-django19, py34-django17, py34-django18, py34-django19, - flake8, [flake8] max-line-length=100 @@ -17,7 +17,7 @@ deps = -r{toxinidir}/requirements-dev.txt [testenv] -commands=python run_tests {posargs:tests} +commands=py.test {posargs:cas_server/tests/} [testenv:py27-django17] basepython=python2.7 @@ -60,3 +60,13 @@ basepython=python deps=flake8 commands=flake8 {toxinidir}/cas_server +[testenv:coverage] +basepython=python +passenv=CODACY_PROJECT_TOKEN +deps= + -r{toxinidir}/requirements-dev.txt + codacy-coverage +commands= + py.test --cov=cas_server --cov-report xml + python-codacy-coverage -r {toxinidir}/coverage.xml +