From 90daf3d2a0d1639906d0da0ea13ddefddcc967da Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 3 Jul 2016 13:51:00 +0200 Subject: [PATCH] Add unit tests for when CAS_FEDERATE is True Also fix some unicode related bugs --- cas_server/auth.py | 2 +- cas_server/default_settings.py | 12 +- cas_server/federate.py | 43 +-- cas_server/forms.py | 10 +- .../management/commands/cas_clean_federate.py | 12 +- cas_server/models.py | 23 +- cas_server/tests/mixin.py | 47 +++ cas_server/tests/test_federate.py | 344 ++++++++++++++++++ cas_server/tests/test_models.py | 65 +++- cas_server/tests/test_utils.py | 21 +- cas_server/tests/test_view.py | 46 +-- cas_server/tests/utils.py | 163 ++++++++- cas_server/views.py | 105 +++--- 13 files changed, 749 insertions(+), 144 deletions(-) create mode 100644 cas_server/tests/test_federate.py diff --git a/cas_server/auth.py b/cas_server/auth.py index afcb722..d666ec5 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -171,7 +171,7 @@ class CASFederateAuth(AuthUser): def attributs(self): """return a dict of user attributes""" - if not self.user: + if not self.user: # pragma: no cover (should not happen) return {} else: return self.user.attributs diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py index 07b420e..be3f064 100644 --- a/cas_server/default_settings.py +++ b/cas_server/default_settings.py @@ -14,7 +14,6 @@ from django.conf import settings from django.contrib.staticfiles.templatetags.staticfiles import static import re -import six def setting_default(name, default_value): @@ -112,13 +111,10 @@ except AttributeError: key = settings.CAS_FEDERATE_PROVIDERS[key][2].lower() else: key = key.lower() - if isinstance(key, six.string_types) or isinstance(key, six.text_type): - return tuple( - int(num) if num else alpha - for num, alpha in __cas_federate_providers_list_sort.tokenize(key) - ) - else: - return key + return tuple( + int(num) if num else alpha + for num, alpha in __cas_federate_providers_list_sort.tokenize(key) + ) __cas_federate_providers_list_sort.tokenize = re.compile(r'(\d+)|(\D+)').findall __CAS_FEDERATE_PROVIDERS_LIST.sort(key=__cas_federate_providers_list_sort) diff --git a/cas_server/federate.py b/cas_server/federate.py index 453a778..2f6489a 100644 --- a/cas_server/federate.py +++ b/cas_server/federate.py @@ -15,6 +15,7 @@ from .cas import CASClient from .models import FederatedUser, FederateSLO, User from importlib import import_module +from six.moves import urllib SessionStore = import_module(settings.SESSION_ENGINE).SessionStore @@ -27,7 +28,7 @@ class CASFederateValidateUser(object): def __init__(self, provider, service_url): self.provider = provider - if provider in settings.CAS_FEDERATE_PROVIDERS: + if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be True) (server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider][:2] self.client = CASClient( service_url=service_url, @@ -44,9 +45,12 @@ class CASFederateValidateUser(object): def verify_ticket(self, ticket): """test `password` agains the user""" - if self.client is None: + if self.client is None: # pragma: no cover (should not happen) + return False + try: + username, attributs = self.client.verify_ticket(ticket)[:2] + except urllib.error.URLError: return False - username, attributs = self.client.verify_ticket(ticket)[:2] if username is not None: if attributs is None: attributs = {} @@ -83,23 +87,20 @@ class CASFederateValidateUser(object): def clean_sessions(self, logout_request): try: - slos = self.client.get_saml_slos(logout_request) - except NameError: + slos = self.client.get_saml_slos(logout_request) or [] + except NameError: # pragma: no cover (should not happen) slos = [] for slo in slos: - try: - for federate_slo in FederateSLO.objects.filter(ticket=slo.text): - session = SessionStore(session_key=federate_slo.session_key) - session.flush() - try: - user = User.objects.get( - username=federate_slo.username, - session_key=federate_slo.session_key - ) - user.logout() - user.delete() - except User.DoesNotExist: - pass - federate_slo.delete() - except FederateSLO.DoesNotExist: - pass + for federate_slo in FederateSLO.objects.filter(ticket=slo.text): + session = SessionStore(session_key=federate_slo.session_key) + session.flush() + try: + user = User.objects.get( + username=federate_slo.username, + session_key=federate_slo.session_key + ) + user.logout() + user.delete() + except User.DoesNotExist: # pragma: no cover (should not happen) + pass + federate_slo.delete() diff --git a/cas_server/forms.py b/cas_server/forms.py index b5cf4d0..dc0e866 100644 --- a/cas_server/forms.py +++ b/cas_server/forms.py @@ -31,6 +31,8 @@ class WarnForm(forms.Form): class FederateSelect(forms.Form): provider = forms.ChoiceField( label=_('Identity provider'), + # with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS + # this is usefull to use the override_settings decorator in tests choices=[ ( p, @@ -88,8 +90,12 @@ class FederateUserCredential(UserCredential): user = models.FederatedUser.objects.get(username=username, provider=provider) user.ticket = "" user.save() - except models.FederatedUser.DoesNotExist: - raise + # should not happed as is the FederatedUser do not exists, super should + # raise before a ValidationError("bad user") + except models.FederatedUser.DoesNotExist: # pragma: no cover (should not happend) + raise forms.ValidationError( + _(u"User not found in the temporary database, please try to reconnect") + ) return cleaned_data diff --git a/cas_server/management/commands/cas_clean_federate.py b/cas_server/management/commands/cas_clean_federate.py index 04e0608..8d91935 100644 --- a/cas_server/management/commands/cas_clean_federate.py +++ b/cas_server/management/commands/cas_clean_federate.py @@ -1,11 +1,7 @@ from django.core.management.base import BaseCommand from django.utils.translation import ugettext_lazy as _ -from django.utils import timezone - -from datetime import timedelta from ... import models -from ...default_settings import settings class Command(BaseCommand): @@ -13,11 +9,5 @@ class Command(BaseCommand): help = _(u"Clean old federated users") def handle(self, *args, **options): - federated_users = models.FederatedUser.objects.filter( - last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT)) - ) - known_users = {user.username for user in models.User.objects.all()} - for user in federated_users: - if not ('%s@%s' % (user.username, user.provider)) in known_users: - user.delete() + models.FederatedUser.clean_old_entries() models.FederateSLO.clean_deleted_sessions() diff --git a/cas_server/models.py b/cas_server/models.py index aea270b..3d1f17f 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -46,6 +46,16 @@ class FederatedUser(models.Model): def __unicode__(self): return u"%s@%s" % (self.username, self.provider) + @classmethod + def clean_old_entries(cls): + federated_users = cls.objects.filter( + last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT)) + ) + known_users = {user.username for user in User.objects.all()} + for user in federated_users: + if not ('%s@%s' % (user.username, user.provider)) in known_users: + user.delete() + class FederateSLO(models.Model): class Meta: @@ -54,11 +64,6 @@ class FederateSLO(models.Model): session_key = models.CharField(max_length=40, blank=True, null=True) ticket = models.CharField(max_length=255) - @property - def provider(self): - component = self.username.split("@") - return component[-1] - @classmethod def clean_deleted_sessions(cls): for federate_slo in cls.objects.all(): @@ -76,6 +81,14 @@ class User(models.Model): username = models.CharField(max_length=30) date = models.DateTimeField(auto_now=True) + def delete(self, *args, **kwargs): + if settings.CAS_FEDERATE: + FederateSLO.objects.filter( + username=self.username, + session_key=self.session_key + ).delete() + super(User, self).delete(*args, **kwargs) + @classmethod def clean_old_entries(cls): """Remove users inactive since more that SESSION_COOKIE_AGE""" diff --git a/cas_server/tests/mixin.py b/cas_server/tests/mixin.py index ddbf2d2..4612fd2 100644 --- a/cas_server/tests/mixin.py +++ b/cas_server/tests/mixin.py @@ -191,3 +191,50 @@ class UserModels(object): username=settings.CAS_TEST_USER, session_key=client.session.session_key ) + + +class CanLogin(object): + """Assertion about login""" + def assert_logged( + self, client, response, warn=False, + code=200, username=settings.CAS_TEST_USER + ): + """Assertions testing that client is well authenticated""" + self.assertEqual(response.status_code, code) + # this message is displayed to the user upon successful authentication + self.assertIn( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ), + response.content + ) + # these session variables a set if usccessfully authenticated + self.assertEqual(client.session["username"], username) + self.assertIs(client.session["warn"], warn) + self.assertIs(client.session["authenticated"], True) + + # on successfull authentication, a corresponding user object is created + self.assertTrue( + models.User.objects.get( + username=username, + session_key=client.session.session_key + ) + ) + + def assert_login_failed(self, client, response, code=200): + """Assertions testing a failed login attempt""" + self.assertEqual(response.status_code, code) + # this message is displayed to the user upon successful authentication, so it should not + # appear + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + # if authentication has failed, these session variables should not be set + self.assertTrue(client.session.get("username") is None) + self.assertTrue(client.session.get("warn") is None) + self.assertTrue(client.session.get("authenticated") is None) diff --git a/cas_server/tests/test_federate.py b/cas_server/tests/test_federate.py new file mode 100644 index 0000000..b4e76b2 --- /dev/null +++ b/cas_server/tests/test_federate.py @@ -0,0 +1,344 @@ +# -*- coding: utf-8 -*- +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for +# more details. +# +# You should have received a copy of the GNU General Public License version 3 +# along with this program; if not, write to the Free Software Foundation, Inc., 51 +# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# (c) 2016 Valentin Samir +"""tests for the CAS federate mode""" +from cas_server import default_settings +from cas_server.default_settings import settings + +import django +from django.test import TestCase, Client +from django.test.utils import override_settings + +from six.moves import reload_module + +from cas_server import utils, forms +from cas_server.tests.mixin import BaseServicePattern, CanLogin +from cas_server.tests import utils as tests_utils + +PROVIDERS = { + "example.com": ("http://127.0.0.1:8080", 1, "Example dot com"), + "example.org": ("http://127.0.0.1:8081", 2, "Example dot org"), + "example.net": ("http://127.0.0.1:8082", 3, "Example dot net"), + "example.test": ("http://127.0.0.1:8083", 'CAS_2_SAML_1_0'), +} + +PROVIDERS_LIST = list(PROVIDERS.keys()) +PROVIDERS_LIST.sort() + + +def getaddrinfo_mock(name, port, *args, **kwargs): + return [(2, 1, 6, '', ('127.0.0.1', 80))] + + +@override_settings( + CAS_FEDERATE=True, + CAS_FEDERATE_PROVIDERS=PROVIDERS, + CAS_FEDERATE_PROVIDERS_LIST=PROVIDERS_LIST, + CAS_AUTH_CLASS="cas_server.auth.CASFederateAuth", + # test with a non ascii username + CAS_TEST_USER=u"dédé" +) +class FederateAuthLoginLogoutTestCase(TestCase, BaseServicePattern, CanLogin): + """tests for the views login logout and federate then the federated mode is enabled""" + def setUp(self): + """Prepare the test context""" + self.setup_service_patterns() + reload_module(forms) + + def test_default_settings(self): + """default settings should populated some default variable then CAS_FEDERATE is True""" + provider_list = settings.CAS_FEDERATE_PROVIDERS_LIST + del settings.CAS_FEDERATE_PROVIDERS_LIST + del settings.CAS_AUTH_CLASS + reload_module(default_settings) + self.assertEqual(settings.CAS_FEDERATE_PROVIDERS_LIST, provider_list) + self.assertEqual(settings.CAS_AUTH_CLASS, "cas_server.auth.CASFederateAuth") + + def test_login_get_provider(self): + """some assertion about the login page in federated mode""" + client = Client() + response = client.get("/login") + self.assertEqual(response.status_code, 200) + for key, value in settings.CAS_FEDERATE_PROVIDERS.items(): + self.assertTrue('' % ( + key, + utils.get_tuple(value, 2, key) + ) in response.content.decode("utf-8")) + self.assertEqual(response.context['post_url'], '/federate') + + def test_login_post_provider(self, remember=False): + """test a successful login wrokflow""" + tickets = [] + # choose the example.com provider + for (provider, cas_port) in [ + ("example.com", 8080), ("example.org", 8081), + ("example.net", 8082), ("example.test", 8083) + ]: + # get a bare client + client = Client() + # fetch the login page + response = client.get("/login") + # in federated mode, we shoudl POST do /federate on the login page + self.assertEqual(response.context['post_url'], '/federate') + # get current form parameter + params = tests_utils.copy_form(response.context["form"]) + params['provider'] = provider + if remember: + params['remember'] = 'on' + # post the choosed provider + response = client.post('/federate', params) + # we are redirected to the provider CAS client url + self.assertEqual(response.status_code, 302) + if remember: + self.assertEqual(response["Location"], '%s/federate/%s?remember=on' % ( + 'http://testserver' if django.VERSION < (1, 9) else "", + provider + )) + else: + self.assertEqual(response["Location"], '%s/federate/%s' % ( + 'http://testserver' if django.VERSION < (1, 9) else "", + provider + )) + # let's follow the redirect + response = client.get('/federate/%s' % provider) + # we are redirected to the provider CAS for authentication + self.assertEqual(response.status_code, 302) + self.assertEqual( + response["Location"], + "%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % ( + settings.CAS_FEDERATE_PROVIDERS[provider][0], + provider + ) + ) + # let's generate a ticket + ticket = utils.gen_st() + # we lauch a dummy CAS server that only validate once for the service + # http://testserver/federate/example.com with `ticket` + tests_utils.DummyCAS.run( + ("http://testserver/federate/%s" % provider).encode("ascii"), + ticket.encode("ascii"), + settings.CAS_TEST_USER.encode("utf8"), + [], + cas_port + ) + # we normally provide a good ticket and should be redirected to /login as the ticket + # get successfully validated again the dummy CAS + response = client.get('/federate/%s' % provider, {'ticket': ticket}) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], "%s/login" % ( + 'http://testserver' if django.VERSION < (1, 9) else "" + )) + # follow the redirect + response = client.get("/login") + # we should get a page with a from with all widget hidden that auto POST to /login using + # javascript. If javascript is disabled, a "connect" button is showed + self.assertTrue(response.context['auto_submit']) + self.assertEqual(response.context['post_url'], '/login') + params = tests_utils.copy_form(response.context["form"]) + # POST ge prefiled from parameters + response = client.post("/login", params) + # the user should now being authenticated using username test@`provider` + self.assert_logged( + client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider) + ) + tickets.append((provider, ticket, client)) + + # try to get a ticket + response = client.get("/login", {'service': self.service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % self.service)) + return tickets + + def test_login_twice(self): + """Test that user id db is used for the second login (cf coverage)""" + self.test_login_post_provider() + self.test_login_post_provider() + + @override_settings(CAS_FEDERATE=False) + def test_auth_federate_false(self): + """federated view should redirect to /login then CAS_FEDERATE is False""" + provider = "example.com" + client = Client() + response = client.get("/federate/%s" % provider) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], "%s/login" % ( + 'http://testserver' if django.VERSION < (1, 9) else "" + )) + response = client.post("%s/federate/%s" % ( + 'http://testserver' if django.VERSION < (1, 9) else "", + provider + )) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], "%s/login" % ( + 'http://testserver' if django.VERSION < (1, 9) else "" + )) + + def test_auth_federate_errors(self): + """ + The federated view should redirect to /login if the provider is unknown or not provided, + try to fetch a new ticket if the provided ticket validation fail + (network error or bad ticket) + """ + return + good_provider = "example.com" + bad_provider = "exemple.fr" + client = Client() + response = client.get("/federate/%s" % bad_provider) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], "%s/login" % ( + 'http://testserver' if django.VERSION < (1, 9) else "" + )) + + # test CAS not avaible + response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()}) + self.assertEqual(response.status_code, 302) + self.assertEqual( + response["Location"], + "%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % ( + settings.CAS_FEDERATE_PROVIDERS[good_provider][0], + good_provider + ) + ) + + # test CAS avaible but bad ticket + tests_utils.DummyCAS.run( + ("http://testserver/federate/%s" % good_provider).encode("ascii"), + utils.gen_st().encode("ascii"), + settings.CAS_TEST_USER.encode("utf-8"), + [], + 8080 + ) + response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()}) + self.assertEqual(response.status_code, 302) + self.assertEqual( + response["Location"], + "%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % ( + settings.CAS_FEDERATE_PROVIDERS[good_provider][0], + good_provider + ) + ) + + response = client.post("/federate") + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], "%s/login" % ( + 'http://testserver' if django.VERSION < (1, 9) else "" + )) + + def test_auth_federate_slo(self): + """test that SLO receive from backend CAS log out the users""" + # get tickets and connected clients + tickets = self.test_login_post_provider() + for (provider, ticket, client) in tickets: + # SLO for an unkown ticket should do nothing + response = client.post( + "/federate/%s" % provider, + {'logoutRequest': tests_utils.logout_request(utils.gen_st())} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"ok") + # Bad SLO format should do nothing + response = client.post( + "/federate/%s" % provider, + {'logoutRequest': ""} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"ok") + # Bad SLO format should do nothing + response = client.post( + "/federate/%s" % provider, + {'logoutRequest': ""} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"ok") + response = client.get("/login") + self.assert_logged( + client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider) + ) + + # SLO for a previously logged ticket should log out the user if CAS version is + # 3 or 'CAS_2_SAML_1_0' + response = client.post( + "/federate/%s" % provider, + {'logoutRequest': tests_utils.logout_request(ticket)} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"ok") + + response = client.get("/login") + if settings.CAS_FEDERATE_PROVIDERS[provider][1] in {3, 'CAS_2_SAML_1_0'}: # support SLO + self.assert_login_failed(client, response) + else: + self.assert_logged( + client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider) + ) + + def test_federate_logout(self): + """ + test the logout function: the user should be log out + and redirected to his CAS logout page + """ + # get tickets and connected clients + tickets = self.test_login_post_provider() + for (provider, _, client) in tickets: + response = client.get("/logout") + self.assertEqual(response.status_code, 302) + self.assertEqual( + response["Location"], + "%s/logout" % settings.CAS_FEDERATE_PROVIDERS[provider][0] + ) + response = client.get("/login") + self.assert_login_failed(client, response) + + def test_remember_provider(self): + """ + If the user check remember, next login should not offer the chose of the backend CAS + and use the one store in the cookie + """ + tickets = self.test_login_post_provider(remember=True) + for (provider, _, client) in tickets: + client.get("/logout") + response = client.get("/login") + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], "%s/federate/%s" % ( + 'http://testserver' if django.VERSION < (1, 9) else "", + provider + )) + + def test_login_bad_ticket(self): + """ + Try login with a bad ticket: + login should fail and the main login page should be displayed to the user + """ + provider = "example.com" + # get a bare client + client = Client() + session = client.session + session["federate_username"] = '%s@%s' % (settings.CAS_TEST_USER, provider) + session["federate_ticket"] = utils.gen_st() + try: + session.save() + response = client.get("/login") + # we should get a page with a from with all widget hidden that auto POST to /login using + # javascript. If javascript is disabled, a "connect" button is showed + self.assertTrue(response.context['auto_submit']) + self.assertEqual(response.context['post_url'], '/login') + params = tests_utils.copy_form(response.context["form"]) + # POST, as (username, ticket) are not valid, we should get the federate login page + response = client.post("/login", params) + self.assertEqual(response.status_code, 200) + for key, value in settings.CAS_FEDERATE_PROVIDERS.items(): + self.assertTrue('' % ( + key, + utils.get_tuple(value, 2, key) + ) in response.content.decode("utf-8")) + self.assertEqual(response.context['post_url'], '/federate') + except AttributeError: + pass diff --git a/cas_server/tests/test_models.py b/cas_server/tests/test_models.py index e75f54f..cdaece8 100644 --- a/cas_server/tests/test_models.py +++ b/cas_server/tests/test_models.py @@ -12,20 +12,81 @@ """Tests module for models""" from cas_server.default_settings import settings -from django.test import TestCase +from django.test import TestCase, Client from django.test.utils import override_settings from django.utils import timezone from datetime import timedelta from importlib import import_module -from cas_server import models +from cas_server import models, utils from cas_server.tests.utils import get_auth_client, HttpParamsHandler from cas_server.tests.mixin import UserModels, BaseServicePattern SessionStore = import_module(settings.SESSION_ENGINE).SessionStore +class FederatedUserTestCase(TestCase, UserModels): + """test for the federated user model""" + def test_clean_old_entries(self): + """tests for clean_old_entries that should delete federated user no longer used""" + client = Client() + client.get("/login") + models.FederatedUser.objects.create( + username="test1", provider="example.com", attributs={}, ticket="" + ) + models.FederatedUser.objects.create( + username="test2", provider="example.com", attributs={}, ticket="" + ) + models.FederatedUser.objects.all().update( + last_update=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT + 10)) + ) + models.FederatedUser.objects.create( + username="test3", provider="example.com", attributs={}, ticket="" + ) + models.User.objects.create( + username="test1@example.com", session_key=client.session.session_key + ) + models.FederatedUser.clean_old_entries() + self.assertEqual(len(models.FederatedUser.objects.all()), 2) + with self.assertRaises(models.FederatedUser.DoesNotExist): + models.FederatedUser.objects.get(username="test2") + + +class FederateSLOTestCase(TestCase, UserModels): + """test for the federated SLO model""" + def test_clean_deleted_sessions(self): + """ + tests for clean_deleted_sessions that should delete object for which matching session + do not exists anymore + """ + client1 = Client() + client2 = Client() + client1.get("/login") + client2.get("/login") + session = client2.session + session['authenticated'] = True + try: + session.save() + except AttributeError: + pass + models.FederateSLO.objects.create( + username="test1@example.com", + session_key=client1.session.session_key, + ticket=utils.gen_st() + ) + models.FederateSLO.objects.create( + username="test2@example.com", + session_key=client2.session.session_key, + ticket=utils.gen_st() + ) + self.assertEqual(len(models.FederateSLO.objects.all()), 2) + models.FederateSLO.clean_deleted_sessions() + self.assertEqual(len(models.FederateSLO.objects.all()), 1) + with self.assertRaises(models.FederateSLO.DoesNotExist): + models.FederateSLO.objects.get(username="test1@example.com") + + @override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') class UserTestCase(TestCase, UserModels): """tests for the user models""" diff --git a/cas_server/tests/test_utils.py b/cas_server/tests/test_utils.py index 76fa2cc..411848a 100644 --- a/cas_server/tests/test_utils.py +++ b/cas_server/tests/test_utils.py @@ -10,7 +10,7 @@ # # (c) 2016 Valentin Samir """Tests module for utils""" -from django.test import TestCase +from django.test import TestCase, RequestFactory import six @@ -189,3 +189,22 @@ class UtilsTestCase(TestCase): self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $ self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $ self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known + + def test_get_current_url(self): + """test the function get_current_url""" + factory = RequestFactory() + request = factory.get('/truc/muche?test=1') + self.assertEqual(utils.get_current_url(request), 'http://testserver/truc/muche?test=1') + self.assertEqual( + utils.get_current_url(request, ignore_params={'test'}), + 'http://testserver/truc/muche' + ) + + def test_get_tuple(self): + """test the function get_tuple""" + test_tuple = (1, 2, 3) + for index, value in enumerate(test_tuple): + self.assertEqual(utils.get_tuple(test_tuple, index), value) + self.assertEqual(utils.get_tuple(test_tuple, 3), None) + self.assertEqual(utils.get_tuple(test_tuple, 3, 'toto'), 'toto') + self.assertEqual(utils.get_tuple(None, 3), None) diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py index 0acd52f..49fa2d2 100644 --- a/cas_server/tests/test_view.py +++ b/cas_server/tests/test_view.py @@ -1,4 +1,4 @@ -# ⁻*- coding: utf-8 -*- +# -*- coding: utf-8 -*- # This program is distributed in the hope that it will be useful, but WITHOUT # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS # FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for @@ -36,57 +36,17 @@ from cas_server.tests.utils import ( HttpParamsHandler, Http404Handler ) -from cas_server.tests.mixin import BaseServicePattern, XmlContent +from cas_server.tests.mixin import BaseServicePattern, XmlContent, CanLogin @override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') -class LoginTestCase(TestCase, BaseServicePattern): +class LoginTestCase(TestCase, BaseServicePattern, CanLogin): """Tests for the login view""" def setUp(self): """Prepare the test context:""" # we prepare a bunch a service url and service patterns for tests self.setup_service_patterns() - def assert_logged(self, client, response, warn=False, code=200): - """Assertions testing that client is well authenticated""" - self.assertEqual(response.status_code, code) - # this message is displayed to the user upon successful authentication - self.assertTrue( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - # these session variables a set if usccessfully authenticated - self.assertTrue(client.session["username"] == settings.CAS_TEST_USER) - self.assertTrue(client.session["warn"] is warn) - self.assertTrue(client.session["authenticated"] is True) - - # on successfull authentication, a corresponding user object is created - self.assertTrue( - models.User.objects.get( - username=settings.CAS_TEST_USER, - session_key=client.session.session_key - ) - ) - - def assert_login_failed(self, client, response, code=200): - """Assertions testing a failed login attempt""" - self.assertEqual(response.status_code, code) - # this message is displayed to the user upon successful authentication, so it should not - # appear - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - - # if authentication has failed, these session variables should not be set - self.assertTrue(client.session.get("username") is None) - self.assertTrue(client.session.get("warn") is None) - self.assertTrue(client.session.get("authenticated") is None) - def test_login_view_post_goodpass_goodlt(self): """Test a successul login""" # we get a client who fetch a frist time the login page and the login form default diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py index bd692e9..b8419c6 100644 --- a/cas_server/tests/utils.py +++ b/cas_server/tests/utils.py @@ -13,14 +13,33 @@ from cas_server.default_settings import settings from django.test import Client +from django.template import loader, Context +from django.utils import timezone import cgi +import six from threading import Thread from lxml import etree from six.moves import BaseHTTPServer from six.moves.urllib.parse import urlparse, parse_qsl +from datetime import timedelta from cas_server import models +from cas_server import utils + + +def return_unicode(string, charset): + if not isinstance(string, six.text_type): + return string.decode(charset) + else: + return string + + +def return_bytes(string, charset): + if isinstance(string, six.text_type): + return string.encode(charset) + else: + return string def copy_form(form): @@ -149,10 +168,10 @@ class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler): return @classmethod - def run(cls): + def run(cls, port=0): """Run a BaseHTTPServer using this class as handler""" server_class = BaseHTTPServer.HTTPServer - httpd = server_class(("127.0.0.1", 0), cls) + httpd = server_class(("127.0.0.1", port), cls) (host, port) = httpd.socket.getsockname() def lauch(): @@ -178,3 +197,143 @@ class Http404Handler(HttpParamsHandler): def do_POST(self): """Called on a POST request on the BaseHTTPServer""" return self.do_GET() + + +class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler): + + def test_params(self): + if ( + self.server.ticket is not None and + self.params.get("service").encode("ascii") == self.server.service and + self.params.get("ticket").encode("ascii") == self.server.ticket + ): + self.server.ticket = None + print("good") + return True + else: + print("bad (%r, %r) != (%r, %r)" % ( + self.params.get("service").encode("ascii"), + self.params.get("ticket").encode("ascii"), + self.server.service, + self.server.ticket + )) + + return False + + def send_headers(self, code, content_type): + self.send_response(200) + self.send_header("Content-type", content_type) + self.end_headers() + + def do_GET(self): + url = urlparse(self.path) + self.params = dict(parse_qsl(url.query)) + if url.path == "/validate": + self.send_headers(200, "text/plain; charset=utf-8") + if self.test_params(): + self.wfile.write(b"yes\n" + self.server.username + b"\n") + self.server.ticket = None + else: + self.wfile.write(b"no\n") + elif url.path in { + '/serviceValidate', '/serviceValidate', + '/p3/serviceValidate', '/p3/proxyValidate' + }: + self.send_headers(200, "text/xml; charset=utf-8") + if self.test_params(): + t = loader.get_template('cas_server/serviceValidate.xml') + c = Context({ + 'username': self.server.username, + 'attributes': self.server.attributes + }) + self.wfile.write(return_bytes(t.render(c), "utf8")) + else: + t = loader.get_template('cas_server/serviceValidateError.xml') + c = Context({ + 'code': 'BAD_SERVICE_TICKET', + 'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket) + }) + self.wfile.write(return_bytes(t.render(c), "utf8")) + else: + self.return_404() + + def do_POST(self): + url = urlparse(self.path) + self.params = dict(parse_qsl(url.query)) + if url.path == "/samlValidate": + self.send_headers(200, "text/xml; charset=utf-8") + length = int(self.headers.get('content-length')) + root = etree.fromstring(self.rfile.read(length)) + auth_req = root.getchildren()[1].getchildren()[0] + ticket = auth_req.getchildren()[0].text.encode("ascii") + if ( + self.server.ticket is not None and + self.params.get("TARGET").encode("ascii") == self.server.service and + ticket == self.server.ticket + ): + self.server.ticket = None + t = loader.get_template('cas_server/samlValidate.xml') + c = Context({ + 'IssueInstant': timezone.now().isoformat(), + 'expireInstant': (timezone.now() + timedelta(seconds=60)).isoformat(), + 'Recipient': self.server.service, + 'ResponseID': utils.gen_saml_id(), + 'username': self.server.username, + 'attributes': self.server.attributes, + }) + self.wfile.write(return_bytes(t.render(c), "utf8")) + else: + t = loader.get_template('cas_server/samlValidateError.xml') + c = Context({ + 'IssueInstant': timezone.now().isoformat(), + 'ResponseID': utils.gen_saml_id(), + 'code': 'BAD_SERVICE_TICKET', + 'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket) + }) + self.wfile.write(return_bytes(t.render(c), "utf8")) + else: + self.return_404() + + def return_404(self): + self.send_response(404) + self.send_header(b"Content-type", "text/plain") + self.end_headers() + self.wfile.write("not found") + + def log_message(self, *args): + """silent any log message""" + return + + @classmethod + def run(cls, service, ticket, username, attributes, port=0): + """Run a BaseHTTPServer using this class as handler""" + server_class = BaseHTTPServer.HTTPServer + httpd = server_class(("127.0.0.1", port), cls) + httpd.service = service + httpd.ticket = ticket + httpd.username = username + httpd.attributes = attributes + (host, port) = httpd.socket.getsockname() + + def lauch(): + """routine to lauch in a background thread""" + httpd.handle_request() + httpd.server_close() + + httpd_thread = Thread(target=lauch) + httpd_thread.daemon = True + httpd_thread.start() + return (httpd, host, port) + + +def logout_request(ticket): + return u""" + +%(ticket)s +""" % \ + { + 'id': utils.gen_saml_id(), + 'datetime': timezone.now().isoformat(), + 'ticket': ticket + } diff --git a/cas_server/views.py b/cas_server/views.py index 9543c6f..05ce47d 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -123,15 +123,18 @@ class LogoutView(View, LogoutMixin): self.init_get(request) # if CAS federation mode is enable, bakup the provider before flushing the sessions if settings.CAS_FEDERATE: - component = self.request.session.get("username").split('@') - provider = component[-1] - auth = CASFederateValidateUser(provider, service_url="") + if "username" in self.request.session: + component = self.request.session["username"].split('@') + provider = component[-1] + auth = CASFederateValidateUser(provider, service_url="") + else: + auth = None session_nb = self.logout(self.request.GET.get("all")) # if CAS federation mode is enable, redirect to user CAS logout page if settings.CAS_FEDERATE: - params = utils.copy_params(request.GET) - url = utils.update_url(auth.get_logout_url(), params) - if url: + if auth is not None: + params = utils.copy_params(request.GET) + url = utils.update_url(auth.get_logout_url(), params) return HttpResponseRedirect(url) # if service is set, redirect to service after logout if self.service: @@ -195,7 +198,7 @@ class FederateAuth(View): @staticmethod def get_cas_client(request, provider): - if provider in settings.CAS_FEDERATE_PROVIDERS: + if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true) service_url = utils.get_current_url(request, {"ticket", "provider"}) return CASFederateValidateUser(provider, service_url) @@ -207,14 +210,14 @@ class FederateAuth(View): auth = self.get_cas_client(request, provider) try: auth.clean_sessions(request.POST['logoutRequest']) - except KeyError: + except (KeyError, AttributeError): pass return HttpResponse("ok") # else, a User is trying to log in using an identity provider else: # Manually checking for csrf to protect the code below reason = CsrfViewMiddleware().process_view(request, None, (), {}) - if reason is not None: + if reason is not None: # pragma: no cover (csrf checks are disabled during tests) return reason # Failed the test, stop here. form = forms.FederateSelect(request.POST) if form.is_valid(): @@ -252,7 +255,7 @@ class FederateAuth(View): ticket = request.GET['ticket'] if auth.verify_ticket(ticket): params = utils.copy_params(request.GET, ignore={"ticket"}) - username = "%s@%s" % (auth.username, auth.provider) + username = u"%s@%s" % (auth.username, auth.provider) request.session["federate_username"] = username request.session["federate_ticket"] = ticket auth.register_slo(username, request.session.session_key, ticket) @@ -281,9 +284,9 @@ class LoginView(View, LogoutMixin): renewed = False warned = False - if settings.CAS_FEDERATE: - username = None - ticket = None + # used if CAS_FEDERATE is True + username = None + ticket = None INVALID_LOGIN_TICKET = 1 USER_LOGIN_OK = 2 @@ -354,7 +357,7 @@ class LoginView(View, LogoutMixin): elif ret == self.USER_LOGIN_FAILURE: # bad user login if settings.CAS_FEDERATE: self.ticket = None - self.usernalme = None + self.username = None self.init_form() self.logout() elif ret == self.USER_ALREADY_LOGGED: @@ -682,11 +685,14 @@ class Auth(View): secret = request.POST.get('secret') if not settings.CAS_AUTH_SHARED_SECRET: - return HttpResponse("no\nplease set CAS_AUTH_SHARED_SECRET", content_type="text/plain") + return HttpResponse( + "no\nplease set CAS_AUTH_SHARED_SECRET", + content_type="text/plain; charset=utf-8" + ) if secret != settings.CAS_AUTH_SHARED_SECRET: - return HttpResponse("no\n", content_type="text/plain") + return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8") if not username or not password or not service: - return HttpResponse("no\n", content_type="text/plain") + return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8") form = forms.UserCredential( request.POST, initial={ @@ -714,11 +720,11 @@ class Auth(View): service_pattern.check_user(user) if not request.session.get("authenticated"): user.delete() - return HttpResponse("yes\n", content_type="text/plain") + return HttpResponse(u"yes\n", content_type="text/plain; charset=utf-8") except (ServicePattern.DoesNotExist, models.ServicePatternException): - return HttpResponse("no\n", content_type="text/plain") + return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8") else: - return HttpResponse("no\n", content_type="text/plain") + return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8") class Validate(View): @@ -758,7 +764,10 @@ class Validate(View): username = username[0] else: username = ticket.user.username - return HttpResponse("yes\n%s\n" % username, content_type="text/plain") + return HttpResponse( + u"yes\n%s\n" % username, + content_type="text/plain; charset=utf-8" + ) except ServiceTicket.DoesNotExist: logger.warning( ( @@ -769,10 +778,10 @@ class Validate(View): service ) ) - return HttpResponse("no\n", content_type="text/plain") + return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8") else: logger.warning("Validate: service or ticket missing") - return HttpResponse("no\n", content_type="text/plain") + return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8") class ValidateError(Exception): @@ -815,8 +824,8 @@ class ValidateService(View, AttributesMixin): if not self.service or not self.ticket: logger.warning("ValidateService: missing ticket or service") return ValidateError( - 'INVALID_REQUEST', - "you must specify a service and a ticket" + u'INVALID_REQUEST', + u"you must specify a service and a ticket" ).render(request) else: try: @@ -886,14 +895,14 @@ class ValidateService(View, AttributesMixin): for prox in ticket.proxies.all(): proxies.append(prox.url) else: - raise ValidateError('INVALID_TICKET', self.ticket) + raise ValidateError(u'INVALID_TICKET', self.ticket) ticket.validate = True ticket.save() if ticket.service != self.service: - raise ValidateError('INVALID_SERVICE', self.service) + raise ValidateError(u'INVALID_SERVICE', self.service) return ticket, proxies except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist): - raise ValidateError('INVALID_TICKET', 'ticket not found') + raise ValidateError(u'INVALID_TICKET', 'ticket not found') def process_pgturl(self, params): """Handle PGT request""" @@ -939,18 +948,18 @@ class ValidateService(View, AttributesMixin): except requests.exceptions.RequestException as error: error = utils.unpack_nested_exception(error) raise ValidateError( - 'INVALID_PROXY_CALLBACK', - "%s: %s" % (type(error), str(error)) + u'INVALID_PROXY_CALLBACK', + u"%s: %s" % (type(error), str(error)) ) else: raise ValidateError( - 'INVALID_PROXY_CALLBACK', - "callback url not allowed by configuration" + u'INVALID_PROXY_CALLBACK', + u"callback url not allowed by configuration" ) except ServicePattern.DoesNotExist: raise ValidateError( - 'INVALID_PROXY_CALLBACK', - 'callback url not allowed by configuration' + u'INVALID_PROXY_CALLBACK', + u'callback url not allowed by configuration' ) @@ -971,8 +980,8 @@ class Proxy(View): return self.process_proxy() else: raise ValidateError( - 'INVALID_REQUEST', - "you must specify and pgt and targetService" + u'INVALID_REQUEST', + u"you must specify and pgt and targetService" ) except ValidateError as error: logger.warning("Proxy: validation error: %s %s" % (error.code, error.msg)) @@ -985,8 +994,8 @@ class Proxy(View): pattern = ServicePattern.validate(self.target_service) if not pattern.proxy: raise ValidateError( - 'UNAUTHORIZED_SERVICE', - 'the service %s do not allow proxy ticket' % self.target_service + u'UNAUTHORIZED_SERVICE', + u'the service %s do not allow proxy ticket' % self.target_service ) # is the proxy granting ticket valid ticket = ProxyGrantingTicket.objects.get( @@ -1015,13 +1024,13 @@ class Proxy(View): content_type="text/xml; charset=utf-8" ) except ProxyGrantingTicket.DoesNotExist: - raise ValidateError('INVALID_TICKET', 'PGT %s not found' % self.pgt) + raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt) except ServicePattern.DoesNotExist: - raise ValidateError('UNAUTHORIZED_SERVICE', self.target_service) + raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service) except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined): raise ValidateError( - 'UNAUTHORIZED_USER', - 'User %s not allowed on %s' % (ticket.user.username, self.target_service) + u'UNAUTHORIZED_USER', + u'User %s not allowed on %s' % (ticket.user.username, self.target_service) ) @@ -1129,18 +1138,18 @@ class SamlValidate(View, AttributesMixin): ) else: raise SamlValidateError( - 'AuthnFailed', - 'ticket %s should begin with PT- or ST-' % ticket + u'AuthnFailed', + u'ticket %s should begin with PT- or ST-' % ticket ) ticket.validate = True ticket.save() if ticket.service != self.target: raise SamlValidateError( - 'AuthnFailed', - 'TARGET %s do not match ticket service' % self.target + u'AuthnFailed', + u'TARGET %s do not match ticket service' % self.target ) return ticket except (IndexError, KeyError): - raise SamlValidateError('VersionMismatch') + raise SamlValidateError(u'VersionMismatch') except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist): - raise SamlValidateError('AuthnFailed', 'ticket %s not found' % ticket) + raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket)