Add unit tests for when CAS_FEDERATE is True
Also fix some unicode related bugs
This commit is contained in:
parent
fcd906ca78
commit
90daf3d2a0
13 changed files with 749 additions and 144 deletions
|
@ -171,7 +171,7 @@ class CASFederateAuth(AuthUser):
|
||||||
|
|
||||||
def attributs(self):
|
def attributs(self):
|
||||||
"""return a dict of user attributes"""
|
"""return a dict of user attributes"""
|
||||||
if not self.user:
|
if not self.user: # pragma: no cover (should not happen)
|
||||||
return {}
|
return {}
|
||||||
else:
|
else:
|
||||||
return self.user.attributs
|
return self.user.attributs
|
||||||
|
|
|
@ -14,7 +14,6 @@ from django.conf import settings
|
||||||
from django.contrib.staticfiles.templatetags.staticfiles import static
|
from django.contrib.staticfiles.templatetags.staticfiles import static
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import six
|
|
||||||
|
|
||||||
|
|
||||||
def setting_default(name, default_value):
|
def setting_default(name, default_value):
|
||||||
|
@ -112,13 +111,10 @@ except AttributeError:
|
||||||
key = settings.CAS_FEDERATE_PROVIDERS[key][2].lower()
|
key = settings.CAS_FEDERATE_PROVIDERS[key][2].lower()
|
||||||
else:
|
else:
|
||||||
key = key.lower()
|
key = key.lower()
|
||||||
if isinstance(key, six.string_types) or isinstance(key, six.text_type):
|
return tuple(
|
||||||
return tuple(
|
int(num) if num else alpha
|
||||||
int(num) if num else alpha
|
for num, alpha in __cas_federate_providers_list_sort.tokenize(key)
|
||||||
for num, alpha in __cas_federate_providers_list_sort.tokenize(key)
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
return key
|
|
||||||
__cas_federate_providers_list_sort.tokenize = re.compile(r'(\d+)|(\D+)').findall
|
__cas_federate_providers_list_sort.tokenize = re.compile(r'(\d+)|(\D+)').findall
|
||||||
__CAS_FEDERATE_PROVIDERS_LIST.sort(key=__cas_federate_providers_list_sort)
|
__CAS_FEDERATE_PROVIDERS_LIST.sort(key=__cas_federate_providers_list_sort)
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ from .cas import CASClient
|
||||||
from .models import FederatedUser, FederateSLO, User
|
from .models import FederatedUser, FederateSLO, User
|
||||||
|
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
from six.moves import urllib
|
||||||
|
|
||||||
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
||||||
|
|
||||||
|
@ -27,7 +28,7 @@ class CASFederateValidateUser(object):
|
||||||
def __init__(self, provider, service_url):
|
def __init__(self, provider, service_url):
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
|
||||||
if provider in settings.CAS_FEDERATE_PROVIDERS:
|
if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be True)
|
||||||
(server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider][:2]
|
(server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider][:2]
|
||||||
self.client = CASClient(
|
self.client = CASClient(
|
||||||
service_url=service_url,
|
service_url=service_url,
|
||||||
|
@ -44,9 +45,12 @@ class CASFederateValidateUser(object):
|
||||||
|
|
||||||
def verify_ticket(self, ticket):
|
def verify_ticket(self, ticket):
|
||||||
"""test `password` agains the user"""
|
"""test `password` agains the user"""
|
||||||
if self.client is None:
|
if self.client is None: # pragma: no cover (should not happen)
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
username, attributs = self.client.verify_ticket(ticket)[:2]
|
||||||
|
except urllib.error.URLError:
|
||||||
return False
|
return False
|
||||||
username, attributs = self.client.verify_ticket(ticket)[:2]
|
|
||||||
if username is not None:
|
if username is not None:
|
||||||
if attributs is None:
|
if attributs is None:
|
||||||
attributs = {}
|
attributs = {}
|
||||||
|
@ -83,23 +87,20 @@ class CASFederateValidateUser(object):
|
||||||
|
|
||||||
def clean_sessions(self, logout_request):
|
def clean_sessions(self, logout_request):
|
||||||
try:
|
try:
|
||||||
slos = self.client.get_saml_slos(logout_request)
|
slos = self.client.get_saml_slos(logout_request) or []
|
||||||
except NameError:
|
except NameError: # pragma: no cover (should not happen)
|
||||||
slos = []
|
slos = []
|
||||||
for slo in slos:
|
for slo in slos:
|
||||||
try:
|
for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
|
||||||
for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
|
session = SessionStore(session_key=federate_slo.session_key)
|
||||||
session = SessionStore(session_key=federate_slo.session_key)
|
session.flush()
|
||||||
session.flush()
|
try:
|
||||||
try:
|
user = User.objects.get(
|
||||||
user = User.objects.get(
|
username=federate_slo.username,
|
||||||
username=federate_slo.username,
|
session_key=federate_slo.session_key
|
||||||
session_key=federate_slo.session_key
|
)
|
||||||
)
|
user.logout()
|
||||||
user.logout()
|
user.delete()
|
||||||
user.delete()
|
except User.DoesNotExist: # pragma: no cover (should not happen)
|
||||||
except User.DoesNotExist:
|
pass
|
||||||
pass
|
federate_slo.delete()
|
||||||
federate_slo.delete()
|
|
||||||
except FederateSLO.DoesNotExist:
|
|
||||||
pass
|
|
||||||
|
|
|
@ -31,6 +31,8 @@ class WarnForm(forms.Form):
|
||||||
class FederateSelect(forms.Form):
|
class FederateSelect(forms.Form):
|
||||||
provider = forms.ChoiceField(
|
provider = forms.ChoiceField(
|
||||||
label=_('Identity provider'),
|
label=_('Identity provider'),
|
||||||
|
# with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS
|
||||||
|
# this is usefull to use the override_settings decorator in tests
|
||||||
choices=[
|
choices=[
|
||||||
(
|
(
|
||||||
p,
|
p,
|
||||||
|
@ -88,8 +90,12 @@ class FederateUserCredential(UserCredential):
|
||||||
user = models.FederatedUser.objects.get(username=username, provider=provider)
|
user = models.FederatedUser.objects.get(username=username, provider=provider)
|
||||||
user.ticket = ""
|
user.ticket = ""
|
||||||
user.save()
|
user.save()
|
||||||
except models.FederatedUser.DoesNotExist:
|
# should not happed as is the FederatedUser do not exists, super should
|
||||||
raise
|
# raise before a ValidationError("bad user")
|
||||||
|
except models.FederatedUser.DoesNotExist: # pragma: no cover (should not happend)
|
||||||
|
raise forms.ValidationError(
|
||||||
|
_(u"User not found in the temporary database, please try to reconnect")
|
||||||
|
)
|
||||||
return cleaned_data
|
return cleaned_data
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,7 @@
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
from django.utils import timezone
|
|
||||||
|
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
from ... import models
|
from ... import models
|
||||||
from ...default_settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseCommand):
|
class Command(BaseCommand):
|
||||||
|
@ -13,11 +9,5 @@ class Command(BaseCommand):
|
||||||
help = _(u"Clean old federated users")
|
help = _(u"Clean old federated users")
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
federated_users = models.FederatedUser.objects.filter(
|
models.FederatedUser.clean_old_entries()
|
||||||
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
|
|
||||||
)
|
|
||||||
known_users = {user.username for user in models.User.objects.all()}
|
|
||||||
for user in federated_users:
|
|
||||||
if not ('%s@%s' % (user.username, user.provider)) in known_users:
|
|
||||||
user.delete()
|
|
||||||
models.FederateSLO.clean_deleted_sessions()
|
models.FederateSLO.clean_deleted_sessions()
|
||||||
|
|
|
@ -46,6 +46,16 @@ class FederatedUser(models.Model):
|
||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
return u"%s@%s" % (self.username, self.provider)
|
return u"%s@%s" % (self.username, self.provider)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clean_old_entries(cls):
|
||||||
|
federated_users = cls.objects.filter(
|
||||||
|
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
|
||||||
|
)
|
||||||
|
known_users = {user.username for user in User.objects.all()}
|
||||||
|
for user in federated_users:
|
||||||
|
if not ('%s@%s' % (user.username, user.provider)) in known_users:
|
||||||
|
user.delete()
|
||||||
|
|
||||||
|
|
||||||
class FederateSLO(models.Model):
|
class FederateSLO(models.Model):
|
||||||
class Meta:
|
class Meta:
|
||||||
|
@ -54,11 +64,6 @@ class FederateSLO(models.Model):
|
||||||
session_key = models.CharField(max_length=40, blank=True, null=True)
|
session_key = models.CharField(max_length=40, blank=True, null=True)
|
||||||
ticket = models.CharField(max_length=255)
|
ticket = models.CharField(max_length=255)
|
||||||
|
|
||||||
@property
|
|
||||||
def provider(self):
|
|
||||||
component = self.username.split("@")
|
|
||||||
return component[-1]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clean_deleted_sessions(cls):
|
def clean_deleted_sessions(cls):
|
||||||
for federate_slo in cls.objects.all():
|
for federate_slo in cls.objects.all():
|
||||||
|
@ -76,6 +81,14 @@ class User(models.Model):
|
||||||
username = models.CharField(max_length=30)
|
username = models.CharField(max_length=30)
|
||||||
date = models.DateTimeField(auto_now=True)
|
date = models.DateTimeField(auto_now=True)
|
||||||
|
|
||||||
|
def delete(self, *args, **kwargs):
|
||||||
|
if settings.CAS_FEDERATE:
|
||||||
|
FederateSLO.objects.filter(
|
||||||
|
username=self.username,
|
||||||
|
session_key=self.session_key
|
||||||
|
).delete()
|
||||||
|
super(User, self).delete(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clean_old_entries(cls):
|
def clean_old_entries(cls):
|
||||||
"""Remove users inactive since more that SESSION_COOKIE_AGE"""
|
"""Remove users inactive since more that SESSION_COOKIE_AGE"""
|
||||||
|
|
|
@ -191,3 +191,50 @@ class UserModels(object):
|
||||||
username=settings.CAS_TEST_USER,
|
username=settings.CAS_TEST_USER,
|
||||||
session_key=client.session.session_key
|
session_key=client.session.session_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CanLogin(object):
|
||||||
|
"""Assertion about login"""
|
||||||
|
def assert_logged(
|
||||||
|
self, client, response, warn=False,
|
||||||
|
code=200, username=settings.CAS_TEST_USER
|
||||||
|
):
|
||||||
|
"""Assertions testing that client is well authenticated"""
|
||||||
|
self.assertEqual(response.status_code, code)
|
||||||
|
# this message is displayed to the user upon successful authentication
|
||||||
|
self.assertIn(
|
||||||
|
(
|
||||||
|
b"You have successfully logged into "
|
||||||
|
b"the Central Authentication Service"
|
||||||
|
),
|
||||||
|
response.content
|
||||||
|
)
|
||||||
|
# these session variables a set if usccessfully authenticated
|
||||||
|
self.assertEqual(client.session["username"], username)
|
||||||
|
self.assertIs(client.session["warn"], warn)
|
||||||
|
self.assertIs(client.session["authenticated"], True)
|
||||||
|
|
||||||
|
# on successfull authentication, a corresponding user object is created
|
||||||
|
self.assertTrue(
|
||||||
|
models.User.objects.get(
|
||||||
|
username=username,
|
||||||
|
session_key=client.session.session_key
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def assert_login_failed(self, client, response, code=200):
|
||||||
|
"""Assertions testing a failed login attempt"""
|
||||||
|
self.assertEqual(response.status_code, code)
|
||||||
|
# this message is displayed to the user upon successful authentication, so it should not
|
||||||
|
# appear
|
||||||
|
self.assertFalse(
|
||||||
|
(
|
||||||
|
b"You have successfully logged into "
|
||||||
|
b"the Central Authentication Service"
|
||||||
|
) in response.content
|
||||||
|
)
|
||||||
|
|
||||||
|
# if authentication has failed, these session variables should not be set
|
||||||
|
self.assertTrue(client.session.get("username") is None)
|
||||||
|
self.assertTrue(client.session.get("warn") is None)
|
||||||
|
self.assertTrue(client.session.get("authenticated") is None)
|
||||||
|
|
344
cas_server/tests/test_federate.py
Normal file
344
cas_server/tests/test_federate.py
Normal file
|
@ -0,0 +1,344 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||||
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
|
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||||
|
# more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License version 3
|
||||||
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
|
#
|
||||||
|
# (c) 2016 Valentin Samir
|
||||||
|
"""tests for the CAS federate mode"""
|
||||||
|
from cas_server import default_settings
|
||||||
|
from cas_server.default_settings import settings
|
||||||
|
|
||||||
|
import django
|
||||||
|
from django.test import TestCase, Client
|
||||||
|
from django.test.utils import override_settings
|
||||||
|
|
||||||
|
from six.moves import reload_module
|
||||||
|
|
||||||
|
from cas_server import utils, forms
|
||||||
|
from cas_server.tests.mixin import BaseServicePattern, CanLogin
|
||||||
|
from cas_server.tests import utils as tests_utils
|
||||||
|
|
||||||
|
PROVIDERS = {
|
||||||
|
"example.com": ("http://127.0.0.1:8080", 1, "Example dot com"),
|
||||||
|
"example.org": ("http://127.0.0.1:8081", 2, "Example dot org"),
|
||||||
|
"example.net": ("http://127.0.0.1:8082", 3, "Example dot net"),
|
||||||
|
"example.test": ("http://127.0.0.1:8083", 'CAS_2_SAML_1_0'),
|
||||||
|
}
|
||||||
|
|
||||||
|
PROVIDERS_LIST = list(PROVIDERS.keys())
|
||||||
|
PROVIDERS_LIST.sort()
|
||||||
|
|
||||||
|
|
||||||
|
def getaddrinfo_mock(name, port, *args, **kwargs):
|
||||||
|
return [(2, 1, 6, '', ('127.0.0.1', 80))]
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(
|
||||||
|
CAS_FEDERATE=True,
|
||||||
|
CAS_FEDERATE_PROVIDERS=PROVIDERS,
|
||||||
|
CAS_FEDERATE_PROVIDERS_LIST=PROVIDERS_LIST,
|
||||||
|
CAS_AUTH_CLASS="cas_server.auth.CASFederateAuth",
|
||||||
|
# test with a non ascii username
|
||||||
|
CAS_TEST_USER=u"dédé"
|
||||||
|
)
|
||||||
|
class FederateAuthLoginLogoutTestCase(TestCase, BaseServicePattern, CanLogin):
|
||||||
|
"""tests for the views login logout and federate then the federated mode is enabled"""
|
||||||
|
def setUp(self):
|
||||||
|
"""Prepare the test context"""
|
||||||
|
self.setup_service_patterns()
|
||||||
|
reload_module(forms)
|
||||||
|
|
||||||
|
def test_default_settings(self):
|
||||||
|
"""default settings should populated some default variable then CAS_FEDERATE is True"""
|
||||||
|
provider_list = settings.CAS_FEDERATE_PROVIDERS_LIST
|
||||||
|
del settings.CAS_FEDERATE_PROVIDERS_LIST
|
||||||
|
del settings.CAS_AUTH_CLASS
|
||||||
|
reload_module(default_settings)
|
||||||
|
self.assertEqual(settings.CAS_FEDERATE_PROVIDERS_LIST, provider_list)
|
||||||
|
self.assertEqual(settings.CAS_AUTH_CLASS, "cas_server.auth.CASFederateAuth")
|
||||||
|
|
||||||
|
def test_login_get_provider(self):
|
||||||
|
"""some assertion about the login page in federated mode"""
|
||||||
|
client = Client()
|
||||||
|
response = client.get("/login")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
for key, value in settings.CAS_FEDERATE_PROVIDERS.items():
|
||||||
|
self.assertTrue('<option value="%s">%s</option>' % (
|
||||||
|
key,
|
||||||
|
utils.get_tuple(value, 2, key)
|
||||||
|
) in response.content.decode("utf-8"))
|
||||||
|
self.assertEqual(response.context['post_url'], '/federate')
|
||||||
|
|
||||||
|
def test_login_post_provider(self, remember=False):
|
||||||
|
"""test a successful login wrokflow"""
|
||||||
|
tickets = []
|
||||||
|
# choose the example.com provider
|
||||||
|
for (provider, cas_port) in [
|
||||||
|
("example.com", 8080), ("example.org", 8081),
|
||||||
|
("example.net", 8082), ("example.test", 8083)
|
||||||
|
]:
|
||||||
|
# get a bare client
|
||||||
|
client = Client()
|
||||||
|
# fetch the login page
|
||||||
|
response = client.get("/login")
|
||||||
|
# in federated mode, we shoudl POST do /federate on the login page
|
||||||
|
self.assertEqual(response.context['post_url'], '/federate')
|
||||||
|
# get current form parameter
|
||||||
|
params = tests_utils.copy_form(response.context["form"])
|
||||||
|
params['provider'] = provider
|
||||||
|
if remember:
|
||||||
|
params['remember'] = 'on'
|
||||||
|
# post the choosed provider
|
||||||
|
response = client.post('/federate', params)
|
||||||
|
# we are redirected to the provider CAS client url
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
if remember:
|
||||||
|
self.assertEqual(response["Location"], '%s/federate/%s?remember=on' % (
|
||||||
|
'http://testserver' if django.VERSION < (1, 9) else "",
|
||||||
|
provider
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
self.assertEqual(response["Location"], '%s/federate/%s' % (
|
||||||
|
'http://testserver' if django.VERSION < (1, 9) else "",
|
||||||
|
provider
|
||||||
|
))
|
||||||
|
# let's follow the redirect
|
||||||
|
response = client.get('/federate/%s' % provider)
|
||||||
|
# we are redirected to the provider CAS for authentication
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(
|
||||||
|
response["Location"],
|
||||||
|
"%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
|
||||||
|
settings.CAS_FEDERATE_PROVIDERS[provider][0],
|
||||||
|
provider
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# let's generate a ticket
|
||||||
|
ticket = utils.gen_st()
|
||||||
|
# we lauch a dummy CAS server that only validate once for the service
|
||||||
|
# http://testserver/federate/example.com with `ticket`
|
||||||
|
tests_utils.DummyCAS.run(
|
||||||
|
("http://testserver/federate/%s" % provider).encode("ascii"),
|
||||||
|
ticket.encode("ascii"),
|
||||||
|
settings.CAS_TEST_USER.encode("utf8"),
|
||||||
|
[],
|
||||||
|
cas_port
|
||||||
|
)
|
||||||
|
# we normally provide a good ticket and should be redirected to /login as the ticket
|
||||||
|
# get successfully validated again the dummy CAS
|
||||||
|
response = client.get('/federate/%s' % provider, {'ticket': ticket})
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(response["Location"], "%s/login" % (
|
||||||
|
'http://testserver' if django.VERSION < (1, 9) else ""
|
||||||
|
))
|
||||||
|
# follow the redirect
|
||||||
|
response = client.get("/login")
|
||||||
|
# we should get a page with a from with all widget hidden that auto POST to /login using
|
||||||
|
# javascript. If javascript is disabled, a "connect" button is showed
|
||||||
|
self.assertTrue(response.context['auto_submit'])
|
||||||
|
self.assertEqual(response.context['post_url'], '/login')
|
||||||
|
params = tests_utils.copy_form(response.context["form"])
|
||||||
|
# POST ge prefiled from parameters
|
||||||
|
response = client.post("/login", params)
|
||||||
|
# the user should now being authenticated using username test@`provider`
|
||||||
|
self.assert_logged(
|
||||||
|
client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
|
||||||
|
)
|
||||||
|
tickets.append((provider, ticket, client))
|
||||||
|
|
||||||
|
# try to get a ticket
|
||||||
|
response = client.get("/login", {'service': self.service})
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertTrue(response["Location"].startswith("%s?ticket=" % self.service))
|
||||||
|
return tickets
|
||||||
|
|
||||||
|
def test_login_twice(self):
|
||||||
|
"""Test that user id db is used for the second login (cf coverage)"""
|
||||||
|
self.test_login_post_provider()
|
||||||
|
self.test_login_post_provider()
|
||||||
|
|
||||||
|
@override_settings(CAS_FEDERATE=False)
|
||||||
|
def test_auth_federate_false(self):
|
||||||
|
"""federated view should redirect to /login then CAS_FEDERATE is False"""
|
||||||
|
provider = "example.com"
|
||||||
|
client = Client()
|
||||||
|
response = client.get("/federate/%s" % provider)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(response["Location"], "%s/login" % (
|
||||||
|
'http://testserver' if django.VERSION < (1, 9) else ""
|
||||||
|
))
|
||||||
|
response = client.post("%s/federate/%s" % (
|
||||||
|
'http://testserver' if django.VERSION < (1, 9) else "",
|
||||||
|
provider
|
||||||
|
))
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(response["Location"], "%s/login" % (
|
||||||
|
'http://testserver' if django.VERSION < (1, 9) else ""
|
||||||
|
))
|
||||||
|
|
||||||
|
def test_auth_federate_errors(self):
|
||||||
|
"""
|
||||||
|
The federated view should redirect to /login if the provider is unknown or not provided,
|
||||||
|
try to fetch a new ticket if the provided ticket validation fail
|
||||||
|
(network error or bad ticket)
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
good_provider = "example.com"
|
||||||
|
bad_provider = "exemple.fr"
|
||||||
|
client = Client()
|
||||||
|
response = client.get("/federate/%s" % bad_provider)
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(response["Location"], "%s/login" % (
|
||||||
|
'http://testserver' if django.VERSION < (1, 9) else ""
|
||||||
|
))
|
||||||
|
|
||||||
|
# test CAS not avaible
|
||||||
|
response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()})
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(
|
||||||
|
response["Location"],
|
||||||
|
"%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
|
||||||
|
settings.CAS_FEDERATE_PROVIDERS[good_provider][0],
|
||||||
|
good_provider
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# test CAS avaible but bad ticket
|
||||||
|
tests_utils.DummyCAS.run(
|
||||||
|
("http://testserver/federate/%s" % good_provider).encode("ascii"),
|
||||||
|
utils.gen_st().encode("ascii"),
|
||||||
|
settings.CAS_TEST_USER.encode("utf-8"),
|
||||||
|
[],
|
||||||
|
8080
|
||||||
|
)
|
||||||
|
response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()})
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(
|
||||||
|
response["Location"],
|
||||||
|
"%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
|
||||||
|
settings.CAS_FEDERATE_PROVIDERS[good_provider][0],
|
||||||
|
good_provider
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.post("/federate")
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(response["Location"], "%s/login" % (
|
||||||
|
'http://testserver' if django.VERSION < (1, 9) else ""
|
||||||
|
))
|
||||||
|
|
||||||
|
def test_auth_federate_slo(self):
|
||||||
|
"""test that SLO receive from backend CAS log out the users"""
|
||||||
|
# get tickets and connected clients
|
||||||
|
tickets = self.test_login_post_provider()
|
||||||
|
for (provider, ticket, client) in tickets:
|
||||||
|
# SLO for an unkown ticket should do nothing
|
||||||
|
response = client.post(
|
||||||
|
"/federate/%s" % provider,
|
||||||
|
{'logoutRequest': tests_utils.logout_request(utils.gen_st())}
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(response.content, b"ok")
|
||||||
|
# Bad SLO format should do nothing
|
||||||
|
response = client.post(
|
||||||
|
"/federate/%s" % provider,
|
||||||
|
{'logoutRequest': ""}
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(response.content, b"ok")
|
||||||
|
# Bad SLO format should do nothing
|
||||||
|
response = client.post(
|
||||||
|
"/federate/%s" % provider,
|
||||||
|
{'logoutRequest': "<root></root>"}
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(response.content, b"ok")
|
||||||
|
response = client.get("/login")
|
||||||
|
self.assert_logged(
|
||||||
|
client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
|
||||||
|
)
|
||||||
|
|
||||||
|
# SLO for a previously logged ticket should log out the user if CAS version is
|
||||||
|
# 3 or 'CAS_2_SAML_1_0'
|
||||||
|
response = client.post(
|
||||||
|
"/federate/%s" % provider,
|
||||||
|
{'logoutRequest': tests_utils.logout_request(ticket)}
|
||||||
|
)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertEqual(response.content, b"ok")
|
||||||
|
|
||||||
|
response = client.get("/login")
|
||||||
|
if settings.CAS_FEDERATE_PROVIDERS[provider][1] in {3, 'CAS_2_SAML_1_0'}: # support SLO
|
||||||
|
self.assert_login_failed(client, response)
|
||||||
|
else:
|
||||||
|
self.assert_logged(
|
||||||
|
client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_federate_logout(self):
|
||||||
|
"""
|
||||||
|
test the logout function: the user should be log out
|
||||||
|
and redirected to his CAS logout page
|
||||||
|
"""
|
||||||
|
# get tickets and connected clients
|
||||||
|
tickets = self.test_login_post_provider()
|
||||||
|
for (provider, _, client) in tickets:
|
||||||
|
response = client.get("/logout")
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(
|
||||||
|
response["Location"],
|
||||||
|
"%s/logout" % settings.CAS_FEDERATE_PROVIDERS[provider][0]
|
||||||
|
)
|
||||||
|
response = client.get("/login")
|
||||||
|
self.assert_login_failed(client, response)
|
||||||
|
|
||||||
|
def test_remember_provider(self):
|
||||||
|
"""
|
||||||
|
If the user check remember, next login should not offer the chose of the backend CAS
|
||||||
|
and use the one store in the cookie
|
||||||
|
"""
|
||||||
|
tickets = self.test_login_post_provider(remember=True)
|
||||||
|
for (provider, _, client) in tickets:
|
||||||
|
client.get("/logout")
|
||||||
|
response = client.get("/login")
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(response["Location"], "%s/federate/%s" % (
|
||||||
|
'http://testserver' if django.VERSION < (1, 9) else "",
|
||||||
|
provider
|
||||||
|
))
|
||||||
|
|
||||||
|
def test_login_bad_ticket(self):
|
||||||
|
"""
|
||||||
|
Try login with a bad ticket:
|
||||||
|
login should fail and the main login page should be displayed to the user
|
||||||
|
"""
|
||||||
|
provider = "example.com"
|
||||||
|
# get a bare client
|
||||||
|
client = Client()
|
||||||
|
session = client.session
|
||||||
|
session["federate_username"] = '%s@%s' % (settings.CAS_TEST_USER, provider)
|
||||||
|
session["federate_ticket"] = utils.gen_st()
|
||||||
|
try:
|
||||||
|
session.save()
|
||||||
|
response = client.get("/login")
|
||||||
|
# we should get a page with a from with all widget hidden that auto POST to /login using
|
||||||
|
# javascript. If javascript is disabled, a "connect" button is showed
|
||||||
|
self.assertTrue(response.context['auto_submit'])
|
||||||
|
self.assertEqual(response.context['post_url'], '/login')
|
||||||
|
params = tests_utils.copy_form(response.context["form"])
|
||||||
|
# POST, as (username, ticket) are not valid, we should get the federate login page
|
||||||
|
response = client.post("/login", params)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
for key, value in settings.CAS_FEDERATE_PROVIDERS.items():
|
||||||
|
self.assertTrue('<option value="%s">%s</option>' % (
|
||||||
|
key,
|
||||||
|
utils.get_tuple(value, 2, key)
|
||||||
|
) in response.content.decode("utf-8"))
|
||||||
|
self.assertEqual(response.context['post_url'], '/federate')
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
|
@ -12,20 +12,81 @@
|
||||||
"""Tests module for models"""
|
"""Tests module for models"""
|
||||||
from cas_server.default_settings import settings
|
from cas_server.default_settings import settings
|
||||||
|
|
||||||
from django.test import TestCase
|
from django.test import TestCase, Client
|
||||||
from django.test.utils import override_settings
|
from django.test.utils import override_settings
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
from cas_server import models
|
from cas_server import models, utils
|
||||||
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
|
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
|
||||||
from cas_server.tests.mixin import UserModels, BaseServicePattern
|
from cas_server.tests.mixin import UserModels, BaseServicePattern
|
||||||
|
|
||||||
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
||||||
|
|
||||||
|
|
||||||
|
class FederatedUserTestCase(TestCase, UserModels):
|
||||||
|
"""test for the federated user model"""
|
||||||
|
def test_clean_old_entries(self):
|
||||||
|
"""tests for clean_old_entries that should delete federated user no longer used"""
|
||||||
|
client = Client()
|
||||||
|
client.get("/login")
|
||||||
|
models.FederatedUser.objects.create(
|
||||||
|
username="test1", provider="example.com", attributs={}, ticket=""
|
||||||
|
)
|
||||||
|
models.FederatedUser.objects.create(
|
||||||
|
username="test2", provider="example.com", attributs={}, ticket=""
|
||||||
|
)
|
||||||
|
models.FederatedUser.objects.all().update(
|
||||||
|
last_update=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT + 10))
|
||||||
|
)
|
||||||
|
models.FederatedUser.objects.create(
|
||||||
|
username="test3", provider="example.com", attributs={}, ticket=""
|
||||||
|
)
|
||||||
|
models.User.objects.create(
|
||||||
|
username="test1@example.com", session_key=client.session.session_key
|
||||||
|
)
|
||||||
|
models.FederatedUser.clean_old_entries()
|
||||||
|
self.assertEqual(len(models.FederatedUser.objects.all()), 2)
|
||||||
|
with self.assertRaises(models.FederatedUser.DoesNotExist):
|
||||||
|
models.FederatedUser.objects.get(username="test2")
|
||||||
|
|
||||||
|
|
||||||
|
class FederateSLOTestCase(TestCase, UserModels):
|
||||||
|
"""test for the federated SLO model"""
|
||||||
|
def test_clean_deleted_sessions(self):
|
||||||
|
"""
|
||||||
|
tests for clean_deleted_sessions that should delete object for which matching session
|
||||||
|
do not exists anymore
|
||||||
|
"""
|
||||||
|
client1 = Client()
|
||||||
|
client2 = Client()
|
||||||
|
client1.get("/login")
|
||||||
|
client2.get("/login")
|
||||||
|
session = client2.session
|
||||||
|
session['authenticated'] = True
|
||||||
|
try:
|
||||||
|
session.save()
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
models.FederateSLO.objects.create(
|
||||||
|
username="test1@example.com",
|
||||||
|
session_key=client1.session.session_key,
|
||||||
|
ticket=utils.gen_st()
|
||||||
|
)
|
||||||
|
models.FederateSLO.objects.create(
|
||||||
|
username="test2@example.com",
|
||||||
|
session_key=client2.session.session_key,
|
||||||
|
ticket=utils.gen_st()
|
||||||
|
)
|
||||||
|
self.assertEqual(len(models.FederateSLO.objects.all()), 2)
|
||||||
|
models.FederateSLO.clean_deleted_sessions()
|
||||||
|
self.assertEqual(len(models.FederateSLO.objects.all()), 1)
|
||||||
|
with self.assertRaises(models.FederateSLO.DoesNotExist):
|
||||||
|
models.FederateSLO.objects.get(username="test1@example.com")
|
||||||
|
|
||||||
|
|
||||||
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
||||||
class UserTestCase(TestCase, UserModels):
|
class UserTestCase(TestCase, UserModels):
|
||||||
"""tests for the user models"""
|
"""tests for the user models"""
|
||||||
|
|
|
@ -10,7 +10,7 @@
|
||||||
#
|
#
|
||||||
# (c) 2016 Valentin Samir
|
# (c) 2016 Valentin Samir
|
||||||
"""Tests module for utils"""
|
"""Tests module for utils"""
|
||||||
from django.test import TestCase
|
from django.test import TestCase, RequestFactory
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
@ -189,3 +189,22 @@ class UtilsTestCase(TestCase):
|
||||||
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
|
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
|
||||||
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
|
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
|
||||||
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known
|
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known
|
||||||
|
|
||||||
|
def test_get_current_url(self):
|
||||||
|
"""test the function get_current_url"""
|
||||||
|
factory = RequestFactory()
|
||||||
|
request = factory.get('/truc/muche?test=1')
|
||||||
|
self.assertEqual(utils.get_current_url(request), 'http://testserver/truc/muche?test=1')
|
||||||
|
self.assertEqual(
|
||||||
|
utils.get_current_url(request, ignore_params={'test'}),
|
||||||
|
'http://testserver/truc/muche'
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_tuple(self):
|
||||||
|
"""test the function get_tuple"""
|
||||||
|
test_tuple = (1, 2, 3)
|
||||||
|
for index, value in enumerate(test_tuple):
|
||||||
|
self.assertEqual(utils.get_tuple(test_tuple, index), value)
|
||||||
|
self.assertEqual(utils.get_tuple(test_tuple, 3), None)
|
||||||
|
self.assertEqual(utils.get_tuple(test_tuple, 3, 'toto'), 'toto')
|
||||||
|
self.assertEqual(utils.get_tuple(None, 3), None)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# ⁻*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# This program is distributed in the hope that it will be useful, but WITHOUT
|
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||||
|
@ -36,57 +36,17 @@ from cas_server.tests.utils import (
|
||||||
HttpParamsHandler,
|
HttpParamsHandler,
|
||||||
Http404Handler
|
Http404Handler
|
||||||
)
|
)
|
||||||
from cas_server.tests.mixin import BaseServicePattern, XmlContent
|
from cas_server.tests.mixin import BaseServicePattern, XmlContent, CanLogin
|
||||||
|
|
||||||
|
|
||||||
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
||||||
class LoginTestCase(TestCase, BaseServicePattern):
|
class LoginTestCase(TestCase, BaseServicePattern, CanLogin):
|
||||||
"""Tests for the login view"""
|
"""Tests for the login view"""
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""Prepare the test context:"""
|
"""Prepare the test context:"""
|
||||||
# we prepare a bunch a service url and service patterns for tests
|
# we prepare a bunch a service url and service patterns for tests
|
||||||
self.setup_service_patterns()
|
self.setup_service_patterns()
|
||||||
|
|
||||||
def assert_logged(self, client, response, warn=False, code=200):
|
|
||||||
"""Assertions testing that client is well authenticated"""
|
|
||||||
self.assertEqual(response.status_code, code)
|
|
||||||
# this message is displayed to the user upon successful authentication
|
|
||||||
self.assertTrue(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
# these session variables a set if usccessfully authenticated
|
|
||||||
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
|
|
||||||
self.assertTrue(client.session["warn"] is warn)
|
|
||||||
self.assertTrue(client.session["authenticated"] is True)
|
|
||||||
|
|
||||||
# on successfull authentication, a corresponding user object is created
|
|
||||||
self.assertTrue(
|
|
||||||
models.User.objects.get(
|
|
||||||
username=settings.CAS_TEST_USER,
|
|
||||||
session_key=client.session.session_key
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def assert_login_failed(self, client, response, code=200):
|
|
||||||
"""Assertions testing a failed login attempt"""
|
|
||||||
self.assertEqual(response.status_code, code)
|
|
||||||
# this message is displayed to the user upon successful authentication, so it should not
|
|
||||||
# appear
|
|
||||||
self.assertFalse(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
# if authentication has failed, these session variables should not be set
|
|
||||||
self.assertTrue(client.session.get("username") is None)
|
|
||||||
self.assertTrue(client.session.get("warn") is None)
|
|
||||||
self.assertTrue(client.session.get("authenticated") is None)
|
|
||||||
|
|
||||||
def test_login_view_post_goodpass_goodlt(self):
|
def test_login_view_post_goodpass_goodlt(self):
|
||||||
"""Test a successul login"""
|
"""Test a successul login"""
|
||||||
# we get a client who fetch a frist time the login page and the login form default
|
# we get a client who fetch a frist time the login page and the login form default
|
||||||
|
|
|
@ -13,14 +13,33 @@
|
||||||
from cas_server.default_settings import settings
|
from cas_server.default_settings import settings
|
||||||
|
|
||||||
from django.test import Client
|
from django.test import Client
|
||||||
|
from django.template import loader, Context
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
import cgi
|
import cgi
|
||||||
|
import six
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from lxml import etree
|
from lxml import etree
|
||||||
from six.moves import BaseHTTPServer
|
from six.moves import BaseHTTPServer
|
||||||
from six.moves.urllib.parse import urlparse, parse_qsl
|
from six.moves.urllib.parse import urlparse, parse_qsl
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
from cas_server import models
|
from cas_server import models
|
||||||
|
from cas_server import utils
|
||||||
|
|
||||||
|
|
||||||
|
def return_unicode(string, charset):
|
||||||
|
if not isinstance(string, six.text_type):
|
||||||
|
return string.decode(charset)
|
||||||
|
else:
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
def return_bytes(string, charset):
|
||||||
|
if isinstance(string, six.text_type):
|
||||||
|
return string.encode(charset)
|
||||||
|
else:
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
def copy_form(form):
|
def copy_form(form):
|
||||||
|
@ -149,10 +168,10 @@ class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||||
return
|
return
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def run(cls):
|
def run(cls, port=0):
|
||||||
"""Run a BaseHTTPServer using this class as handler"""
|
"""Run a BaseHTTPServer using this class as handler"""
|
||||||
server_class = BaseHTTPServer.HTTPServer
|
server_class = BaseHTTPServer.HTTPServer
|
||||||
httpd = server_class(("127.0.0.1", 0), cls)
|
httpd = server_class(("127.0.0.1", port), cls)
|
||||||
(host, port) = httpd.socket.getsockname()
|
(host, port) = httpd.socket.getsockname()
|
||||||
|
|
||||||
def lauch():
|
def lauch():
|
||||||
|
@ -178,3 +197,143 @@ class Http404Handler(HttpParamsHandler):
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
"""Called on a POST request on the BaseHTTPServer"""
|
"""Called on a POST request on the BaseHTTPServer"""
|
||||||
return self.do_GET()
|
return self.do_GET()
|
||||||
|
|
||||||
|
|
||||||
|
class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||||
|
|
||||||
|
def test_params(self):
|
||||||
|
if (
|
||||||
|
self.server.ticket is not None and
|
||||||
|
self.params.get("service").encode("ascii") == self.server.service and
|
||||||
|
self.params.get("ticket").encode("ascii") == self.server.ticket
|
||||||
|
):
|
||||||
|
self.server.ticket = None
|
||||||
|
print("good")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
print("bad (%r, %r) != (%r, %r)" % (
|
||||||
|
self.params.get("service").encode("ascii"),
|
||||||
|
self.params.get("ticket").encode("ascii"),
|
||||||
|
self.server.service,
|
||||||
|
self.server.ticket
|
||||||
|
))
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def send_headers(self, code, content_type):
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header("Content-type", content_type)
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
def do_GET(self):
|
||||||
|
url = urlparse(self.path)
|
||||||
|
self.params = dict(parse_qsl(url.query))
|
||||||
|
if url.path == "/validate":
|
||||||
|
self.send_headers(200, "text/plain; charset=utf-8")
|
||||||
|
if self.test_params():
|
||||||
|
self.wfile.write(b"yes\n" + self.server.username + b"\n")
|
||||||
|
self.server.ticket = None
|
||||||
|
else:
|
||||||
|
self.wfile.write(b"no\n")
|
||||||
|
elif url.path in {
|
||||||
|
'/serviceValidate', '/serviceValidate',
|
||||||
|
'/p3/serviceValidate', '/p3/proxyValidate'
|
||||||
|
}:
|
||||||
|
self.send_headers(200, "text/xml; charset=utf-8")
|
||||||
|
if self.test_params():
|
||||||
|
t = loader.get_template('cas_server/serviceValidate.xml')
|
||||||
|
c = Context({
|
||||||
|
'username': self.server.username,
|
||||||
|
'attributes': self.server.attributes
|
||||||
|
})
|
||||||
|
self.wfile.write(return_bytes(t.render(c), "utf8"))
|
||||||
|
else:
|
||||||
|
t = loader.get_template('cas_server/serviceValidateError.xml')
|
||||||
|
c = Context({
|
||||||
|
'code': 'BAD_SERVICE_TICKET',
|
||||||
|
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
|
||||||
|
})
|
||||||
|
self.wfile.write(return_bytes(t.render(c), "utf8"))
|
||||||
|
else:
|
||||||
|
self.return_404()
|
||||||
|
|
||||||
|
def do_POST(self):
|
||||||
|
url = urlparse(self.path)
|
||||||
|
self.params = dict(parse_qsl(url.query))
|
||||||
|
if url.path == "/samlValidate":
|
||||||
|
self.send_headers(200, "text/xml; charset=utf-8")
|
||||||
|
length = int(self.headers.get('content-length'))
|
||||||
|
root = etree.fromstring(self.rfile.read(length))
|
||||||
|
auth_req = root.getchildren()[1].getchildren()[0]
|
||||||
|
ticket = auth_req.getchildren()[0].text.encode("ascii")
|
||||||
|
if (
|
||||||
|
self.server.ticket is not None and
|
||||||
|
self.params.get("TARGET").encode("ascii") == self.server.service and
|
||||||
|
ticket == self.server.ticket
|
||||||
|
):
|
||||||
|
self.server.ticket = None
|
||||||
|
t = loader.get_template('cas_server/samlValidate.xml')
|
||||||
|
c = Context({
|
||||||
|
'IssueInstant': timezone.now().isoformat(),
|
||||||
|
'expireInstant': (timezone.now() + timedelta(seconds=60)).isoformat(),
|
||||||
|
'Recipient': self.server.service,
|
||||||
|
'ResponseID': utils.gen_saml_id(),
|
||||||
|
'username': self.server.username,
|
||||||
|
'attributes': self.server.attributes,
|
||||||
|
})
|
||||||
|
self.wfile.write(return_bytes(t.render(c), "utf8"))
|
||||||
|
else:
|
||||||
|
t = loader.get_template('cas_server/samlValidateError.xml')
|
||||||
|
c = Context({
|
||||||
|
'IssueInstant': timezone.now().isoformat(),
|
||||||
|
'ResponseID': utils.gen_saml_id(),
|
||||||
|
'code': 'BAD_SERVICE_TICKET',
|
||||||
|
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
|
||||||
|
})
|
||||||
|
self.wfile.write(return_bytes(t.render(c), "utf8"))
|
||||||
|
else:
|
||||||
|
self.return_404()
|
||||||
|
|
||||||
|
def return_404(self):
|
||||||
|
self.send_response(404)
|
||||||
|
self.send_header(b"Content-type", "text/plain")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write("not found")
|
||||||
|
|
||||||
|
def log_message(self, *args):
|
||||||
|
"""silent any log message"""
|
||||||
|
return
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run(cls, service, ticket, username, attributes, port=0):
|
||||||
|
"""Run a BaseHTTPServer using this class as handler"""
|
||||||
|
server_class = BaseHTTPServer.HTTPServer
|
||||||
|
httpd = server_class(("127.0.0.1", port), cls)
|
||||||
|
httpd.service = service
|
||||||
|
httpd.ticket = ticket
|
||||||
|
httpd.username = username
|
||||||
|
httpd.attributes = attributes
|
||||||
|
(host, port) = httpd.socket.getsockname()
|
||||||
|
|
||||||
|
def lauch():
|
||||||
|
"""routine to lauch in a background thread"""
|
||||||
|
httpd.handle_request()
|
||||||
|
httpd.server_close()
|
||||||
|
|
||||||
|
httpd_thread = Thread(target=lauch)
|
||||||
|
httpd_thread.daemon = True
|
||||||
|
httpd_thread.start()
|
||||||
|
return (httpd, host, port)
|
||||||
|
|
||||||
|
|
||||||
|
def logout_request(ticket):
|
||||||
|
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||||
|
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
||||||
|
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
||||||
|
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
||||||
|
</samlp:LogoutRequest>""" % \
|
||||||
|
{
|
||||||
|
'id': utils.gen_saml_id(),
|
||||||
|
'datetime': timezone.now().isoformat(),
|
||||||
|
'ticket': ticket
|
||||||
|
}
|
||||||
|
|
|
@ -123,15 +123,18 @@ class LogoutView(View, LogoutMixin):
|
||||||
self.init_get(request)
|
self.init_get(request)
|
||||||
# if CAS federation mode is enable, bakup the provider before flushing the sessions
|
# if CAS federation mode is enable, bakup the provider before flushing the sessions
|
||||||
if settings.CAS_FEDERATE:
|
if settings.CAS_FEDERATE:
|
||||||
component = self.request.session.get("username").split('@')
|
if "username" in self.request.session:
|
||||||
provider = component[-1]
|
component = self.request.session["username"].split('@')
|
||||||
auth = CASFederateValidateUser(provider, service_url="")
|
provider = component[-1]
|
||||||
|
auth = CASFederateValidateUser(provider, service_url="")
|
||||||
|
else:
|
||||||
|
auth = None
|
||||||
session_nb = self.logout(self.request.GET.get("all"))
|
session_nb = self.logout(self.request.GET.get("all"))
|
||||||
# if CAS federation mode is enable, redirect to user CAS logout page
|
# if CAS federation mode is enable, redirect to user CAS logout page
|
||||||
if settings.CAS_FEDERATE:
|
if settings.CAS_FEDERATE:
|
||||||
params = utils.copy_params(request.GET)
|
if auth is not None:
|
||||||
url = utils.update_url(auth.get_logout_url(), params)
|
params = utils.copy_params(request.GET)
|
||||||
if url:
|
url = utils.update_url(auth.get_logout_url(), params)
|
||||||
return HttpResponseRedirect(url)
|
return HttpResponseRedirect(url)
|
||||||
# if service is set, redirect to service after logout
|
# if service is set, redirect to service after logout
|
||||||
if self.service:
|
if self.service:
|
||||||
|
@ -195,7 +198,7 @@ class FederateAuth(View):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_cas_client(request, provider):
|
def get_cas_client(request, provider):
|
||||||
if provider in settings.CAS_FEDERATE_PROVIDERS:
|
if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true)
|
||||||
service_url = utils.get_current_url(request, {"ticket", "provider"})
|
service_url = utils.get_current_url(request, {"ticket", "provider"})
|
||||||
return CASFederateValidateUser(provider, service_url)
|
return CASFederateValidateUser(provider, service_url)
|
||||||
|
|
||||||
|
@ -207,14 +210,14 @@ class FederateAuth(View):
|
||||||
auth = self.get_cas_client(request, provider)
|
auth = self.get_cas_client(request, provider)
|
||||||
try:
|
try:
|
||||||
auth.clean_sessions(request.POST['logoutRequest'])
|
auth.clean_sessions(request.POST['logoutRequest'])
|
||||||
except KeyError:
|
except (KeyError, AttributeError):
|
||||||
pass
|
pass
|
||||||
return HttpResponse("ok")
|
return HttpResponse("ok")
|
||||||
# else, a User is trying to log in using an identity provider
|
# else, a User is trying to log in using an identity provider
|
||||||
else:
|
else:
|
||||||
# Manually checking for csrf to protect the code below
|
# Manually checking for csrf to protect the code below
|
||||||
reason = CsrfViewMiddleware().process_view(request, None, (), {})
|
reason = CsrfViewMiddleware().process_view(request, None, (), {})
|
||||||
if reason is not None:
|
if reason is not None: # pragma: no cover (csrf checks are disabled during tests)
|
||||||
return reason # Failed the test, stop here.
|
return reason # Failed the test, stop here.
|
||||||
form = forms.FederateSelect(request.POST)
|
form = forms.FederateSelect(request.POST)
|
||||||
if form.is_valid():
|
if form.is_valid():
|
||||||
|
@ -252,7 +255,7 @@ class FederateAuth(View):
|
||||||
ticket = request.GET['ticket']
|
ticket = request.GET['ticket']
|
||||||
if auth.verify_ticket(ticket):
|
if auth.verify_ticket(ticket):
|
||||||
params = utils.copy_params(request.GET, ignore={"ticket"})
|
params = utils.copy_params(request.GET, ignore={"ticket"})
|
||||||
username = "%s@%s" % (auth.username, auth.provider)
|
username = u"%s@%s" % (auth.username, auth.provider)
|
||||||
request.session["federate_username"] = username
|
request.session["federate_username"] = username
|
||||||
request.session["federate_ticket"] = ticket
|
request.session["federate_ticket"] = ticket
|
||||||
auth.register_slo(username, request.session.session_key, ticket)
|
auth.register_slo(username, request.session.session_key, ticket)
|
||||||
|
@ -281,9 +284,9 @@ class LoginView(View, LogoutMixin):
|
||||||
renewed = False
|
renewed = False
|
||||||
warned = False
|
warned = False
|
||||||
|
|
||||||
if settings.CAS_FEDERATE:
|
# used if CAS_FEDERATE is True
|
||||||
username = None
|
username = None
|
||||||
ticket = None
|
ticket = None
|
||||||
|
|
||||||
INVALID_LOGIN_TICKET = 1
|
INVALID_LOGIN_TICKET = 1
|
||||||
USER_LOGIN_OK = 2
|
USER_LOGIN_OK = 2
|
||||||
|
@ -354,7 +357,7 @@ class LoginView(View, LogoutMixin):
|
||||||
elif ret == self.USER_LOGIN_FAILURE: # bad user login
|
elif ret == self.USER_LOGIN_FAILURE: # bad user login
|
||||||
if settings.CAS_FEDERATE:
|
if settings.CAS_FEDERATE:
|
||||||
self.ticket = None
|
self.ticket = None
|
||||||
self.usernalme = None
|
self.username = None
|
||||||
self.init_form()
|
self.init_form()
|
||||||
self.logout()
|
self.logout()
|
||||||
elif ret == self.USER_ALREADY_LOGGED:
|
elif ret == self.USER_ALREADY_LOGGED:
|
||||||
|
@ -682,11 +685,14 @@ class Auth(View):
|
||||||
secret = request.POST.get('secret')
|
secret = request.POST.get('secret')
|
||||||
|
|
||||||
if not settings.CAS_AUTH_SHARED_SECRET:
|
if not settings.CAS_AUTH_SHARED_SECRET:
|
||||||
return HttpResponse("no\nplease set CAS_AUTH_SHARED_SECRET", content_type="text/plain")
|
return HttpResponse(
|
||||||
|
"no\nplease set CAS_AUTH_SHARED_SECRET",
|
||||||
|
content_type="text/plain; charset=utf-8"
|
||||||
|
)
|
||||||
if secret != settings.CAS_AUTH_SHARED_SECRET:
|
if secret != settings.CAS_AUTH_SHARED_SECRET:
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||||
if not username or not password or not service:
|
if not username or not password or not service:
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||||
form = forms.UserCredential(
|
form = forms.UserCredential(
|
||||||
request.POST,
|
request.POST,
|
||||||
initial={
|
initial={
|
||||||
|
@ -714,11 +720,11 @@ class Auth(View):
|
||||||
service_pattern.check_user(user)
|
service_pattern.check_user(user)
|
||||||
if not request.session.get("authenticated"):
|
if not request.session.get("authenticated"):
|
||||||
user.delete()
|
user.delete()
|
||||||
return HttpResponse("yes\n", content_type="text/plain")
|
return HttpResponse(u"yes\n", content_type="text/plain; charset=utf-8")
|
||||||
except (ServicePattern.DoesNotExist, models.ServicePatternException):
|
except (ServicePattern.DoesNotExist, models.ServicePatternException):
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||||
else:
|
else:
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||||
|
|
||||||
|
|
||||||
class Validate(View):
|
class Validate(View):
|
||||||
|
@ -758,7 +764,10 @@ class Validate(View):
|
||||||
username = username[0]
|
username = username[0]
|
||||||
else:
|
else:
|
||||||
username = ticket.user.username
|
username = ticket.user.username
|
||||||
return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
|
return HttpResponse(
|
||||||
|
u"yes\n%s\n" % username,
|
||||||
|
content_type="text/plain; charset=utf-8"
|
||||||
|
)
|
||||||
except ServiceTicket.DoesNotExist:
|
except ServiceTicket.DoesNotExist:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
(
|
(
|
||||||
|
@ -769,10 +778,10 @@ class Validate(View):
|
||||||
service
|
service
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||||
else:
|
else:
|
||||||
logger.warning("Validate: service or ticket missing")
|
logger.warning("Validate: service or ticket missing")
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
|
||||||
|
|
||||||
|
|
||||||
class ValidateError(Exception):
|
class ValidateError(Exception):
|
||||||
|
@ -815,8 +824,8 @@ class ValidateService(View, AttributesMixin):
|
||||||
if not self.service or not self.ticket:
|
if not self.service or not self.ticket:
|
||||||
logger.warning("ValidateService: missing ticket or service")
|
logger.warning("ValidateService: missing ticket or service")
|
||||||
return ValidateError(
|
return ValidateError(
|
||||||
'INVALID_REQUEST',
|
u'INVALID_REQUEST',
|
||||||
"you must specify a service and a ticket"
|
u"you must specify a service and a ticket"
|
||||||
).render(request)
|
).render(request)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
|
@ -886,14 +895,14 @@ class ValidateService(View, AttributesMixin):
|
||||||
for prox in ticket.proxies.all():
|
for prox in ticket.proxies.all():
|
||||||
proxies.append(prox.url)
|
proxies.append(prox.url)
|
||||||
else:
|
else:
|
||||||
raise ValidateError('INVALID_TICKET', self.ticket)
|
raise ValidateError(u'INVALID_TICKET', self.ticket)
|
||||||
ticket.validate = True
|
ticket.validate = True
|
||||||
ticket.save()
|
ticket.save()
|
||||||
if ticket.service != self.service:
|
if ticket.service != self.service:
|
||||||
raise ValidateError('INVALID_SERVICE', self.service)
|
raise ValidateError(u'INVALID_SERVICE', self.service)
|
||||||
return ticket, proxies
|
return ticket, proxies
|
||||||
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
|
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
|
||||||
raise ValidateError('INVALID_TICKET', 'ticket not found')
|
raise ValidateError(u'INVALID_TICKET', 'ticket not found')
|
||||||
|
|
||||||
def process_pgturl(self, params):
|
def process_pgturl(self, params):
|
||||||
"""Handle PGT request"""
|
"""Handle PGT request"""
|
||||||
|
@ -939,18 +948,18 @@ class ValidateService(View, AttributesMixin):
|
||||||
except requests.exceptions.RequestException as error:
|
except requests.exceptions.RequestException as error:
|
||||||
error = utils.unpack_nested_exception(error)
|
error = utils.unpack_nested_exception(error)
|
||||||
raise ValidateError(
|
raise ValidateError(
|
||||||
'INVALID_PROXY_CALLBACK',
|
u'INVALID_PROXY_CALLBACK',
|
||||||
"%s: %s" % (type(error), str(error))
|
u"%s: %s" % (type(error), str(error))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValidateError(
|
raise ValidateError(
|
||||||
'INVALID_PROXY_CALLBACK',
|
u'INVALID_PROXY_CALLBACK',
|
||||||
"callback url not allowed by configuration"
|
u"callback url not allowed by configuration"
|
||||||
)
|
)
|
||||||
except ServicePattern.DoesNotExist:
|
except ServicePattern.DoesNotExist:
|
||||||
raise ValidateError(
|
raise ValidateError(
|
||||||
'INVALID_PROXY_CALLBACK',
|
u'INVALID_PROXY_CALLBACK',
|
||||||
'callback url not allowed by configuration'
|
u'callback url not allowed by configuration'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -971,8 +980,8 @@ class Proxy(View):
|
||||||
return self.process_proxy()
|
return self.process_proxy()
|
||||||
else:
|
else:
|
||||||
raise ValidateError(
|
raise ValidateError(
|
||||||
'INVALID_REQUEST',
|
u'INVALID_REQUEST',
|
||||||
"you must specify and pgt and targetService"
|
u"you must specify and pgt and targetService"
|
||||||
)
|
)
|
||||||
except ValidateError as error:
|
except ValidateError as error:
|
||||||
logger.warning("Proxy: validation error: %s %s" % (error.code, error.msg))
|
logger.warning("Proxy: validation error: %s %s" % (error.code, error.msg))
|
||||||
|
@ -985,8 +994,8 @@ class Proxy(View):
|
||||||
pattern = ServicePattern.validate(self.target_service)
|
pattern = ServicePattern.validate(self.target_service)
|
||||||
if not pattern.proxy:
|
if not pattern.proxy:
|
||||||
raise ValidateError(
|
raise ValidateError(
|
||||||
'UNAUTHORIZED_SERVICE',
|
u'UNAUTHORIZED_SERVICE',
|
||||||
'the service %s do not allow proxy ticket' % self.target_service
|
u'the service %s do not allow proxy ticket' % self.target_service
|
||||||
)
|
)
|
||||||
# is the proxy granting ticket valid
|
# is the proxy granting ticket valid
|
||||||
ticket = ProxyGrantingTicket.objects.get(
|
ticket = ProxyGrantingTicket.objects.get(
|
||||||
|
@ -1015,13 +1024,13 @@ class Proxy(View):
|
||||||
content_type="text/xml; charset=utf-8"
|
content_type="text/xml; charset=utf-8"
|
||||||
)
|
)
|
||||||
except ProxyGrantingTicket.DoesNotExist:
|
except ProxyGrantingTicket.DoesNotExist:
|
||||||
raise ValidateError('INVALID_TICKET', 'PGT %s not found' % self.pgt)
|
raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt)
|
||||||
except ServicePattern.DoesNotExist:
|
except ServicePattern.DoesNotExist:
|
||||||
raise ValidateError('UNAUTHORIZED_SERVICE', self.target_service)
|
raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service)
|
||||||
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
|
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
|
||||||
raise ValidateError(
|
raise ValidateError(
|
||||||
'UNAUTHORIZED_USER',
|
u'UNAUTHORIZED_USER',
|
||||||
'User %s not allowed on %s' % (ticket.user.username, self.target_service)
|
u'User %s not allowed on %s' % (ticket.user.username, self.target_service)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1129,18 +1138,18 @@ class SamlValidate(View, AttributesMixin):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise SamlValidateError(
|
raise SamlValidateError(
|
||||||
'AuthnFailed',
|
u'AuthnFailed',
|
||||||
'ticket %s should begin with PT- or ST-' % ticket
|
u'ticket %s should begin with PT- or ST-' % ticket
|
||||||
)
|
)
|
||||||
ticket.validate = True
|
ticket.validate = True
|
||||||
ticket.save()
|
ticket.save()
|
||||||
if ticket.service != self.target:
|
if ticket.service != self.target:
|
||||||
raise SamlValidateError(
|
raise SamlValidateError(
|
||||||
'AuthnFailed',
|
u'AuthnFailed',
|
||||||
'TARGET %s do not match ticket service' % self.target
|
u'TARGET %s do not match ticket service' % self.target
|
||||||
)
|
)
|
||||||
return ticket
|
return ticket
|
||||||
except (IndexError, KeyError):
|
except (IndexError, KeyError):
|
||||||
raise SamlValidateError('VersionMismatch')
|
raise SamlValidateError(u'VersionMismatch')
|
||||||
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
|
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
|
||||||
raise SamlValidateError('AuthnFailed', 'ticket %s not found' % ticket)
|
raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket)
|
||||||
|
|
Loading…
Reference in a new issue