commit
eb64412612
30 changed files with 2812 additions and 902 deletions
|
@ -1,3 +1,11 @@
|
||||||
|
[run]
|
||||||
|
branch = True
|
||||||
|
source = cas_server
|
||||||
|
omit =
|
||||||
|
cas_server/migrations*
|
||||||
|
cas_server/management/*
|
||||||
|
cas_server/tests/*
|
||||||
|
|
||||||
[report]
|
[report]
|
||||||
exclude_lines =
|
exclude_lines =
|
||||||
pragma: no cover
|
pragma: no cover
|
||||||
|
@ -5,3 +13,4 @@ exclude_lines =
|
||||||
def __unicode__
|
def __unicode__
|
||||||
raise AssertionError
|
raise AssertionError
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
if six.PY3:
|
||||||
|
|
|
@ -2,19 +2,19 @@ language: python
|
||||||
python:
|
python:
|
||||||
- "2.7"
|
- "2.7"
|
||||||
env:
|
env:
|
||||||
global:
|
|
||||||
- PIP_DOWNLOAD_CACHE=$HOME/.pip_cache
|
|
||||||
matrix:
|
matrix:
|
||||||
|
- TOX_ENV=coverage
|
||||||
|
- TOX_ENV=flake8
|
||||||
- TOX_ENV=py27-django17
|
- TOX_ENV=py27-django17
|
||||||
- TOX_ENV=py27-django18
|
- TOX_ENV=py27-django18
|
||||||
- TOX_ENV=py27-django19
|
- TOX_ENV=py27-django19
|
||||||
- TOX_ENV=py34-django17
|
- TOX_ENV=py34-django17
|
||||||
- TOX_ENV=py34-django18
|
- TOX_ENV=py34-django18
|
||||||
- TOX_ENV=py34-django19
|
- TOX_ENV=py34-django19
|
||||||
- TOX_ENV=flake8
|
|
||||||
cache:
|
cache:
|
||||||
directories:
|
directories:
|
||||||
- $HOME/.pip-cache/
|
- $HOME/.cache/pip/
|
||||||
|
- $HOME/build/nitmir/django-cas-server/.tox/
|
||||||
install:
|
install:
|
||||||
- "travis_retry pip install setuptools --upgrade"
|
- "travis_retry pip install setuptools --upgrade"
|
||||||
- "pip install tox"
|
- "pip install tox"
|
||||||
|
@ -22,4 +22,3 @@ script:
|
||||||
- tox -e $TOX_ENV
|
- tox -e $TOX_ENV
|
||||||
after_script:
|
after_script:
|
||||||
- cat .tox/$TOX_ENV/log/*.log
|
- cat .tox/$TOX_ENV/log/*.log
|
||||||
|
|
||||||
|
|
45
Makefile
45
Makefile
|
@ -1,11 +1,15 @@
|
||||||
.PHONY: clean build install dist test_venv test_project
|
.PHONY: build dist
|
||||||
VERSION=`python setup.py -V`
|
VERSION=`python setup.py -V`
|
||||||
|
|
||||||
build:
|
build:
|
||||||
python setup.py build
|
python setup.py build
|
||||||
|
|
||||||
install:
|
install: dist
|
||||||
python setup.py install
|
pip -V
|
||||||
|
pip install --no-deps --upgrade --force-reinstall --find-links ./dist/django-cas-server-${VERSION}.tar.gz django-cas-server
|
||||||
|
|
||||||
|
uninstall:
|
||||||
|
pip uninstall django-cas-server || true
|
||||||
|
|
||||||
clean_pyc:
|
clean_pyc:
|
||||||
find ./ -name '*.pyc' -delete
|
find ./ -name '*.pyc' -delete
|
||||||
|
@ -16,18 +20,23 @@ clean_tox:
|
||||||
rm -rf .tox
|
rm -rf .tox
|
||||||
clean_test_venv:
|
clean_test_venv:
|
||||||
rm -rf test_venv
|
rm -rf test_venv
|
||||||
clean: clean_pyc clean_build
|
clean_coverage:
|
||||||
clean_all: clean_pyc clean_build clean_tox clean_test_venv
|
rm -rf coverage.xml .coverage htmlcov
|
||||||
|
clean_tild_backup:
|
||||||
|
find ./ -name '*~' -delete
|
||||||
|
|
||||||
|
clean: clean_pyc clean_build clean_coverage clean_tild_backup
|
||||||
|
|
||||||
|
clean_all: clean clean_tox clean_test_venv
|
||||||
|
|
||||||
dist:
|
dist:
|
||||||
python setup.py sdist
|
python setup.py sdist
|
||||||
|
|
||||||
test_venv:
|
test_venv/bin/python:
|
||||||
mkdir -p test_venv
|
|
||||||
virtualenv test_venv
|
virtualenv test_venv
|
||||||
test_venv/bin/pip install -U --requirement requirements.txt
|
test_venv/bin/pip install -U --requirement requirements-dev.txt Django
|
||||||
|
|
||||||
test_venv/cas/manage.py:
|
test_venv/cas/manage.py: test_venv
|
||||||
mkdir -p test_venv/cas
|
mkdir -p test_venv/cas
|
||||||
test_venv/bin/django-admin startproject cas test_venv/cas
|
test_venv/bin/django-admin startproject cas test_venv/cas
|
||||||
ln -s ../../cas_server test_venv/cas/cas_server
|
ln -s ../../cas_server test_venv/cas/cas_server
|
||||||
|
@ -38,19 +47,15 @@ test_venv/cas/manage.py:
|
||||||
test_venv/bin/python test_venv/cas/manage.py migrate
|
test_venv/bin/python test_venv/cas/manage.py migrate
|
||||||
test_venv/bin/python test_venv/cas/manage.py createsuperuser
|
test_venv/bin/python test_venv/cas/manage.py createsuperuser
|
||||||
|
|
||||||
test_project: test_venv test_venv/cas/manage.py
|
test_venv: test_venv/bin/python
|
||||||
|
|
||||||
|
test_project: test_venv/cas/manage.py
|
||||||
@echo "##############################################################"
|
@echo "##############################################################"
|
||||||
@echo "A test django project was created in $(realpath test_venv/cas)"
|
@echo "A test django project was created in $(realpath test_venv/cas)"
|
||||||
|
|
||||||
run_test_server: test_project
|
run_server: test_project
|
||||||
test_venv/bin/python test_venv/cas/manage.py runserver
|
test_venv/bin/python test_venv/cas/manage.py runserver
|
||||||
|
|
||||||
coverage: test_venv
|
run_tests: test_venv
|
||||||
test_venv/bin/pip install coverage
|
test_venv/bin/py.test --cov=cas_server --cov-report html
|
||||||
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
|
rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts
|
||||||
test_venv/bin/coverage html
|
|
||||||
test_venv/bin/coverage xml
|
|
||||||
|
|
||||||
coverage_codacy: coverage
|
|
||||||
test_venv/bin/pip install codacy-coverage
|
|
||||||
test_venv/bin/python-codacy-coverage -r coverage.xml
|
|
||||||
|
|
|
@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
|
||||||
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
|
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
|
||||||
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
|
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
|
||||||
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
|
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
|
||||||
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``.
|
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
|
||||||
|
'alias': ['demo1', 'demo2']}``.
|
||||||
|
|
||||||
|
|
||||||
Authentication backend
|
Authentication backend
|
||||||
|
|
|
@ -7,6 +7,6 @@
|
||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
#
|
#
|
||||||
# (c) 2015 Valentin Samir
|
# (c) 2015-2016 Valentin Samir
|
||||||
|
"""A django CAS server application"""
|
||||||
default_app_config = 'cas_server.apps.CasAppConfig'
|
default_app_config = 'cas_server.apps.CasAppConfig'
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
#
|
#
|
||||||
# (c) 2015 Valentin Samir
|
# (c) 2015-2016 Valentin Samir
|
||||||
"""module for the admin interface of the app"""
|
"""module for the admin interface of the app"""
|
||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern
|
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern
|
||||||
|
|
|
@ -1,7 +1,19 @@
|
||||||
|
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||||
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
|
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||||
|
# more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License version 3
|
||||||
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
|
#
|
||||||
|
# (c) 2015-2016 Valentin Samir
|
||||||
|
"""django config module"""
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
from django.apps import AppConfig
|
from django.apps import AppConfig
|
||||||
|
|
||||||
|
|
||||||
class CasAppConfig(AppConfig):
|
class CasAppConfig(AppConfig):
|
||||||
|
"""django CAS application config class"""
|
||||||
name = 'cas_server'
|
name = 'cas_server'
|
||||||
verbose_name = _('Central Authentication Service')
|
verbose_name = _('Central Authentication Service')
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
#
|
#
|
||||||
# (c) 2015 Valentin Samir
|
# (c) 2015-2016 Valentin Samir
|
||||||
"""Some authentication classes for the CAS"""
|
"""Some authentication classes for the CAS"""
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
|
@ -21,6 +21,7 @@ except ImportError:
|
||||||
|
|
||||||
|
|
||||||
class AuthUser(object):
|
class AuthUser(object):
|
||||||
|
"""Authentication base class"""
|
||||||
def __init__(self, username):
|
def __init__(self, username):
|
||||||
self.username = username
|
self.username = username
|
||||||
|
|
||||||
|
|
|
@ -78,5 +78,12 @@ setting_default('CAS_TEST_USER', 'test')
|
||||||
setting_default('CAS_TEST_PASSWORD', 'test')
|
setting_default('CAS_TEST_PASSWORD', 'test')
|
||||||
setting_default(
|
setting_default(
|
||||||
'CAS_TEST_ATTRIBUTES',
|
'CAS_TEST_ATTRIBUTES',
|
||||||
{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}
|
{
|
||||||
|
'nom': 'Nymous',
|
||||||
|
'prenom': 'Ano',
|
||||||
|
'email': 'anonymous@example.net',
|
||||||
|
'alias': ['demo1', 'demo2']
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
setting_default('CAS_ENABLE_AJAX_AUTH', False)
|
||||||
|
|
|
@ -7,7 +7,7 @@
|
||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
#
|
#
|
||||||
# (c) 2015 Valentin Samir
|
# (c) 2015-2016 Valentin Samir
|
||||||
"""forms for the app"""
|
"""forms for the app"""
|
||||||
from .default_settings import settings
|
from .default_settings import settings
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ import cas_server.models as models
|
||||||
|
|
||||||
|
|
||||||
class WarnForm(forms.Form):
|
class WarnForm(forms.Form):
|
||||||
|
"""Form used on warn page before emiting a ticket"""
|
||||||
service = forms.CharField(widget=forms.HiddenInput(), required=False)
|
service = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||||
renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
|
renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
|
||||||
gateway = forms.CharField(widget=forms.HiddenInput(), required=False)
|
gateway = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||||
|
@ -35,6 +36,7 @@ class UserCredential(forms.Form):
|
||||||
lt = forms.CharField(widget=forms.HiddenInput(), required=False)
|
lt = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||||
method = forms.CharField(widget=forms.HiddenInput(), required=False)
|
method = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||||
warn = forms.BooleanField(label=_('warn'), required=False)
|
warn = forms.BooleanField(label=_('warn'), required=False)
|
||||||
|
renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(UserCredential, self).__init__(*args, **kwargs)
|
super(UserCredential, self).__init__(*args, **kwargs)
|
||||||
|
@ -46,6 +48,7 @@ class UserCredential(forms.Form):
|
||||||
cleaned_data["username"] = auth.username
|
cleaned_data["username"] = auth.username
|
||||||
else:
|
else:
|
||||||
raise forms.ValidationError(_(u"Bad user"))
|
raise forms.ValidationError(_(u"Bad user"))
|
||||||
|
return cleaned_data
|
||||||
|
|
||||||
|
|
||||||
class TicketForm(forms.ModelForm):
|
class TicketForm(forms.ModelForm):
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
"""Clean deleted sessions management command"""
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
|
@ -5,6 +6,7 @@ from ... import models
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseCommand):
|
class Command(BaseCommand):
|
||||||
|
"""Clean deleted sessions"""
|
||||||
args = ''
|
args = ''
|
||||||
help = _(u"Clean deleted sessions")
|
help = _(u"Clean deleted sessions")
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
"""Clean old trickets management command"""
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
|
@ -5,6 +6,7 @@ from ... import models
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseCommand):
|
class Command(BaseCommand):
|
||||||
|
"""Clean old trickets"""
|
||||||
args = ''
|
args = ''
|
||||||
help = _(u"Clean old trickets")
|
help = _(u"Clean old trickets")
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
#
|
#
|
||||||
# (c) 2015 Valentin Samir
|
# (c) 2015-2016 Valentin Samir
|
||||||
"""models for the app"""
|
"""models for the app"""
|
||||||
from .default_settings import settings
|
from .default_settings import settings
|
||||||
|
|
||||||
|
@ -20,7 +20,6 @@ from django.utils import timezone
|
||||||
from picklefield.fields import PickledObjectField
|
from picklefield.fields import PickledObjectField
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import os
|
|
||||||
import sys
|
import sys
|
||||||
import logging
|
import logging
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
@ -47,6 +46,7 @@ class User(models.Model):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clean_old_entries(cls):
|
def clean_old_entries(cls):
|
||||||
|
"""Remove users inactive since more that SESSION_COOKIE_AGE"""
|
||||||
users = cls.objects.filter(
|
users = cls.objects.filter(
|
||||||
date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE))
|
date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE))
|
||||||
)
|
)
|
||||||
|
@ -56,6 +56,7 @@ class User(models.Model):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clean_deleted_sessions(cls):
|
def clean_deleted_sessions(cls):
|
||||||
|
"""Remove user where the session do not exists anymore"""
|
||||||
for user in cls.objects.all():
|
for user in cls.objects.all():
|
||||||
if not SessionStore(session_key=user.session_key).get('authenticated'):
|
if not SessionStore(session_key=user.session_key).get('authenticated'):
|
||||||
user.logout()
|
user.logout()
|
||||||
|
@ -80,10 +81,10 @@ class User(models.Model):
|
||||||
for ticket_class in ticket_classes:
|
for ticket_class in ticket_classes:
|
||||||
queryset = ticket_class.objects.filter(user=self)
|
queryset = ticket_class.objects.filter(user=self)
|
||||||
for ticket in queryset:
|
for ticket in queryset:
|
||||||
ticket.logout(request, session, async_list)
|
ticket.logout(session, async_list)
|
||||||
queryset.delete()
|
queryset.delete()
|
||||||
for future in async_list:
|
for future in async_list:
|
||||||
if future:
|
if future: # pragma: no branch (should always be true)
|
||||||
try:
|
try:
|
||||||
future.result()
|
future.result()
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
|
@ -111,13 +112,21 @@ class User(models.Model):
|
||||||
(a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
|
(a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
|
||||||
)
|
)
|
||||||
replacements = dict(
|
replacements = dict(
|
||||||
(a.name, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
|
(a.attribut, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
|
||||||
)
|
)
|
||||||
service_attributs = {}
|
service_attributs = {}
|
||||||
for (key, value) in self.attributs.items():
|
for (key, value) in self.attributs.items():
|
||||||
if key in attributs or '*' in attributs:
|
if key in attributs or '*' in attributs:
|
||||||
if key in replacements:
|
if key in replacements:
|
||||||
value = re.sub(replacements[key][0], replacements[key][1], value)
|
if isinstance(value, list):
|
||||||
|
for index, subval in enumerate(value):
|
||||||
|
value[index] = re.sub(
|
||||||
|
replacements[key][0],
|
||||||
|
replacements[key][1],
|
||||||
|
subval
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
value = re.sub(replacements[key][0], replacements[key][1], value)
|
||||||
service_attributs[attributs.get(key, key)] = value
|
service_attributs[attributs.get(key, key)] = value
|
||||||
ticket = ticket_class.objects.create(
|
ticket = ticket_class.objects.create(
|
||||||
user=self,
|
user=self,
|
||||||
|
@ -141,6 +150,7 @@ class User(models.Model):
|
||||||
|
|
||||||
|
|
||||||
class ServicePatternException(Exception):
|
class ServicePatternException(Exception):
|
||||||
|
"""Base exception of exceptions raised in the ServicePattern model"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -394,77 +404,57 @@ class Ticket(models.Model):
|
||||||
).delete()
|
).delete()
|
||||||
|
|
||||||
# sending SLO to timed-out validated tickets
|
# sending SLO to timed-out validated tickets
|
||||||
if cls.TIMEOUT and cls.TIMEOUT > 0:
|
async_list = []
|
||||||
async_list = []
|
session = FuturesSession(
|
||||||
session = FuturesSession(
|
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
||||||
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
)
|
||||||
)
|
queryset = cls.objects.filter(
|
||||||
queryset = cls.objects.filter(
|
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
|
||||||
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
|
)
|
||||||
)
|
for ticket in queryset:
|
||||||
for ticket in queryset:
|
ticket.logout(session, async_list)
|
||||||
ticket.logout(None, session, async_list)
|
queryset.delete()
|
||||||
queryset.delete()
|
for future in async_list:
|
||||||
for future in async_list:
|
if future: # pragma: no branch (should always be true)
|
||||||
if future:
|
try:
|
||||||
try:
|
future.result()
|
||||||
future.result()
|
except Exception as error:
|
||||||
except Exception as error:
|
logger.warning("Error durring SLO %s" % error)
|
||||||
logger.warning("Error durring SLO %s" % error)
|
sys.stderr.write("%r\n" % error)
|
||||||
sys.stderr.write("%r\n" % error)
|
|
||||||
|
|
||||||
def logout(self, request, session, async_list=None):
|
def logout(self, session, async_list=None):
|
||||||
"""Send a SLO request to the ticket service"""
|
"""Send a SLO request to the ticket service"""
|
||||||
# On logout invalidate the Ticket
|
# On logout invalidate the Ticket
|
||||||
self.validate = True
|
self.validate = True
|
||||||
self.save()
|
self.save()
|
||||||
if self.validate and self.single_log_out:
|
if self.validate and self.single_log_out: # pragma: no branch (should always be true)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Sending SLO requests to service %s for user %s" % (
|
"Sending SLO requests to service %s for user %s" % (
|
||||||
self.service,
|
self.service,
|
||||||
self.user.username
|
self.user.username
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
try:
|
xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||||
xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
||||||
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
||||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
||||||
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
|
</samlp:LogoutRequest>""" % \
|
||||||
</samlp:LogoutRequest>""" % \
|
{
|
||||||
{
|
'id': utils.gen_saml_id(),
|
||||||
'id': os.urandom(20).encode("hex"),
|
'datetime': timezone.now().isoformat(),
|
||||||
'datetime': timezone.now().isoformat(),
|
'ticket': self.value
|
||||||
'ticket': self.value
|
}
|
||||||
}
|
if self.service_pattern.single_log_out_callback:
|
||||||
if self.service_pattern.single_log_out_callback:
|
url = self.service_pattern.single_log_out_callback
|
||||||
url = self.service_pattern.single_log_out_callback
|
else:
|
||||||
else:
|
url = self.service
|
||||||
url = self.service
|
async_list.append(
|
||||||
async_list.append(
|
session.post(
|
||||||
session.post(
|
url.encode('utf-8'),
|
||||||
url.encode('utf-8'),
|
data={'logoutRequest': xml.encode('utf-8')},
|
||||||
data={'logoutRequest': xml.encode('utf-8')},
|
timeout=settings.CAS_SLO_TIMEOUT
|
||||||
timeout=settings.CAS_SLO_TIMEOUT
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
except Exception as error:
|
)
|
||||||
error = utils.unpack_nested_exception(error)
|
|
||||||
logger.warning(
|
|
||||||
"Error durring SLO for user %s on service %s: %s" % (
|
|
||||||
self.user.username,
|
|
||||||
self.service,
|
|
||||||
error
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if request is not None:
|
|
||||||
messages.add_message(
|
|
||||||
request,
|
|
||||||
messages.WARNING,
|
|
||||||
_(u'Error during service logout %(service)s:\n%(error)s') %
|
|
||||||
{'service': self.service, 'error': error}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sys.stderr.write("%r\n" % error)
|
|
||||||
|
|
||||||
|
|
||||||
class ServiceTicket(Ticket):
|
class ServiceTicket(Ticket):
|
||||||
|
|
|
@ -1,702 +0,0 @@
|
||||||
from .default_settings import settings
|
|
||||||
|
|
||||||
from django.test import TestCase
|
|
||||||
from django.test import Client
|
|
||||||
|
|
||||||
import six
|
|
||||||
from lxml import etree
|
|
||||||
|
|
||||||
from cas_server import models
|
|
||||||
from cas_server import utils
|
|
||||||
|
|
||||||
|
|
||||||
def get_login_page_params():
|
|
||||||
client = Client()
|
|
||||||
response = client.get('/login')
|
|
||||||
form = response.context["form"]
|
|
||||||
params = {}
|
|
||||||
for field in form:
|
|
||||||
if field.value():
|
|
||||||
params[field.name] = field.value()
|
|
||||||
else:
|
|
||||||
params[field.name] = ""
|
|
||||||
return client, params
|
|
||||||
|
|
||||||
|
|
||||||
def get_auth_client():
|
|
||||||
client, params = get_login_page_params()
|
|
||||||
params["username"] = settings.CAS_TEST_USER
|
|
||||||
params["password"] = settings.CAS_TEST_PASSWORD
|
|
||||||
|
|
||||||
client.post('/login', params)
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
def get_user_ticket_request(service):
|
|
||||||
client = get_auth_client()
|
|
||||||
response = client.get("/login", {"service": service})
|
|
||||||
ticket_value = response['Location'].split('ticket=')[-1]
|
|
||||||
user = models.User.objects.get(
|
|
||||||
username=settings.CAS_TEST_USER,
|
|
||||||
session_key=client.session.session_key
|
|
||||||
)
|
|
||||||
ticket = models.ServiceTicket.objects.get(value=ticket_value)
|
|
||||||
return (user, ticket)
|
|
||||||
|
|
||||||
|
|
||||||
def get_pgt():
|
|
||||||
(host, port) = utils.PGTUrlHandler.run()[1:3]
|
|
||||||
service = "http://%s:%s" % (host, port)
|
|
||||||
|
|
||||||
(user, ticket) = get_user_ticket_request(service)
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
|
|
||||||
params = utils.PGTUrlHandler.PARAMS.copy()
|
|
||||||
|
|
||||||
params["service"] = service
|
|
||||||
params["user"] = user
|
|
||||||
|
|
||||||
return params
|
|
||||||
|
|
||||||
|
|
||||||
class CheckPasswordCase(TestCase):
|
|
||||||
"""Tests for the utils function `utils.check_password`"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"""Generate random bytes string that will be used ass passwords"""
|
|
||||||
self.password1 = utils.gen_saml_id()
|
|
||||||
self.password2 = utils.gen_saml_id()
|
|
||||||
if not isinstance(self.password1, bytes):
|
|
||||||
self.password1 = self.password1.encode("utf8")
|
|
||||||
self.password2 = self.password2.encode("utf8")
|
|
||||||
|
|
||||||
def test_setup(self):
|
|
||||||
"""check that generated password are bytes"""
|
|
||||||
self.assertIsInstance(self.password1, bytes)
|
|
||||||
self.assertIsInstance(self.password2, bytes)
|
|
||||||
|
|
||||||
def test_plain(self):
|
|
||||||
"""test the plain auth method"""
|
|
||||||
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
|
|
||||||
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
|
|
||||||
|
|
||||||
def test_crypt(self):
|
|
||||||
"""test the crypt auth method"""
|
|
||||||
if six.PY3:
|
|
||||||
hashed_password1 = utils.crypt.crypt(
|
|
||||||
self.password1.decode("utf8"),
|
|
||||||
"$6$UVVAQvrMyXMF3FF3"
|
|
||||||
).encode("utf8")
|
|
||||||
else:
|
|
||||||
hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
|
|
||||||
|
|
||||||
self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
|
|
||||||
self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
|
|
||||||
|
|
||||||
def test_ldap_ssha(self):
|
|
||||||
"""test the ldap auth method with a {SSHA} scheme"""
|
|
||||||
salt = b"UVVAQvrMyXMF3FF3"
|
|
||||||
hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
|
|
||||||
|
|
||||||
self.assertIsInstance(hashed_password1, bytes)
|
|
||||||
self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
|
|
||||||
self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
|
|
||||||
|
|
||||||
def test_hex_md5(self):
|
|
||||||
"""test the hex_md5 auth method"""
|
|
||||||
hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
|
|
||||||
|
|
||||||
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
|
|
||||||
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
|
|
||||||
|
|
||||||
def test_hox_sha512(self):
|
|
||||||
"""test the hex_sha512 auth method"""
|
|
||||||
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
|
|
||||||
|
|
||||||
self.assertTrue(
|
|
||||||
utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoginTestCase(TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
|
||||||
self.service_pattern = models.ServicePattern.objects.create(
|
|
||||||
name="example",
|
|
||||||
pattern="^https://www\.example\.com(/.*)?$",
|
|
||||||
)
|
|
||||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
|
||||||
|
|
||||||
def test_login_view_post_goodpass_goodlt(self):
|
|
||||||
client, params = get_login_page_params()
|
|
||||||
params["username"] = settings.CAS_TEST_USER
|
|
||||||
params["password"] = settings.CAS_TEST_PASSWORD
|
|
||||||
|
|
||||||
response = client.post('/login', params)
|
|
||||||
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertTrue(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(
|
|
||||||
models.User.objects.get(
|
|
||||||
username=settings.CAS_TEST_USER,
|
|
||||||
session_key=client.session.session_key
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_login_view_post_badlt(self):
|
|
||||||
client, params = get_login_page_params()
|
|
||||||
params["username"] = settings.CAS_TEST_USER
|
|
||||||
params["password"] = settings.CAS_TEST_PASSWORD
|
|
||||||
params["lt"] = 'LT-random'
|
|
||||||
|
|
||||||
response = client.post('/login', params)
|
|
||||||
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertTrue(b"Invalid login ticket" in response.content)
|
|
||||||
self.assertFalse(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_login_view_post_badpass_good_lt(self):
|
|
||||||
client, params = get_login_page_params()
|
|
||||||
params["username"] = settings.CAS_TEST_USER
|
|
||||||
params["password"] = "test2"
|
|
||||||
response = client.post('/login', params)
|
|
||||||
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertTrue(
|
|
||||||
(
|
|
||||||
b"The credentials you provided cannot be "
|
|
||||||
b"determined to be authentic"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
self.assertFalse(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_view_login_get_auth_allowed_service(self):
|
|
||||||
client = get_auth_client()
|
|
||||||
response = client.get("/login?service=https://www.example.com")
|
|
||||||
self.assertEqual(response.status_code, 302)
|
|
||||||
self.assertTrue(response.has_header('Location'))
|
|
||||||
self.assertTrue(
|
|
||||||
response['Location'].startswith(
|
|
||||||
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
ticket_value = response['Location'].split('ticket=')[-1]
|
|
||||||
user = models.User.objects.get(
|
|
||||||
username=settings.CAS_TEST_USER,
|
|
||||||
session_key=client.session.session_key
|
|
||||||
)
|
|
||||||
self.assertTrue(user)
|
|
||||||
ticket = models.ServiceTicket.objects.get(value=ticket_value)
|
|
||||||
self.assertEqual(ticket.user, user)
|
|
||||||
self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
|
|
||||||
self.assertEqual(ticket.validate, False)
|
|
||||||
self.assertEqual(ticket.service_pattern, self.service_pattern)
|
|
||||||
|
|
||||||
def test_view_login_get_auth_denied_service(self):
|
|
||||||
client = get_auth_client()
|
|
||||||
response = client.get("/login?service=https://www.example.org")
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
|
|
||||||
|
|
||||||
|
|
||||||
class LogoutTestCase(TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
|
||||||
|
|
||||||
def test_logout_view(self):
|
|
||||||
client = get_auth_client()
|
|
||||||
|
|
||||||
response = client.get("/login")
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertTrue(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.get("/logout")
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertTrue(
|
|
||||||
(
|
|
||||||
b"You have successfully logged out from "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
response = client.get("/login")
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertFalse(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_logout_view_url(self):
|
|
||||||
client = get_auth_client()
|
|
||||||
|
|
||||||
response = client.get('/logout?url=https://www.example.com')
|
|
||||||
self.assertEqual(response.status_code, 302)
|
|
||||||
self.assertTrue(response.has_header("Location"))
|
|
||||||
self.assertEqual(response["Location"], "https://www.example.com")
|
|
||||||
|
|
||||||
response = client.get("/login")
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertFalse(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_logout_view_service(self):
|
|
||||||
client = get_auth_client()
|
|
||||||
|
|
||||||
response = client.get('/logout?service=https://www.example.com')
|
|
||||||
self.assertEqual(response.status_code, 302)
|
|
||||||
self.assertTrue(response.has_header("Location"))
|
|
||||||
self.assertEqual(response["Location"], "https://www.example.com")
|
|
||||||
|
|
||||||
response = client.get("/login")
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertFalse(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthTestCase(TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
|
||||||
self.service = 'https://www.example.com'
|
|
||||||
models.ServicePattern.objects.create(
|
|
||||||
name="example",
|
|
||||||
pattern="^https://www\.example\.com(/.*)?$"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_auth_view_goodpass(self):
|
|
||||||
settings.CAS_AUTH_SHARED_SECRET = 'test'
|
|
||||||
client = Client()
|
|
||||||
response = client.post(
|
|
||||||
'/auth',
|
|
||||||
{
|
|
||||||
'username': settings.CAS_TEST_USER,
|
|
||||||
'password': settings.CAS_TEST_PASSWORD,
|
|
||||||
'service': self.service,
|
|
||||||
'secret': 'test'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertEqual(response.content, b'yes\n')
|
|
||||||
|
|
||||||
def test_auth_view_badpass(self):
|
|
||||||
settings.CAS_AUTH_SHARED_SECRET = 'test'
|
|
||||||
client = Client()
|
|
||||||
response = client.post(
|
|
||||||
'/auth',
|
|
||||||
{
|
|
||||||
'username': settings.CAS_TEST_USER,
|
|
||||||
'password': 'badpass',
|
|
||||||
'service': self.service,
|
|
||||||
'secret': 'test'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertEqual(response.content, b'no\n')
|
|
||||||
|
|
||||||
def test_auth_view_badservice(self):
|
|
||||||
settings.CAS_AUTH_SHARED_SECRET = 'test'
|
|
||||||
client = Client()
|
|
||||||
response = client.post(
|
|
||||||
'/auth',
|
|
||||||
{
|
|
||||||
'username': settings.CAS_TEST_USER,
|
|
||||||
'password': settings.CAS_TEST_PASSWORD,
|
|
||||||
'service': 'https://www.example.org',
|
|
||||||
'secret': 'test'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertEqual(response.content, b'no\n')
|
|
||||||
|
|
||||||
def test_auth_view_badsecret(self):
|
|
||||||
settings.CAS_AUTH_SHARED_SECRET = 'test'
|
|
||||||
client = Client()
|
|
||||||
response = client.post(
|
|
||||||
'/auth',
|
|
||||||
{
|
|
||||||
'username': settings.CAS_TEST_USER,
|
|
||||||
'password': settings.CAS_TEST_PASSWORD,
|
|
||||||
'service': self.service,
|
|
||||||
'secret': 'badsecret'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertEqual(response.content, b'no\n')
|
|
||||||
|
|
||||||
def test_auth_view_badsettings(self):
|
|
||||||
settings.CAS_AUTH_SHARED_SECRET = None
|
|
||||||
client = Client()
|
|
||||||
response = client.post(
|
|
||||||
'/auth',
|
|
||||||
{
|
|
||||||
'username': settings.CAS_TEST_USER,
|
|
||||||
'password': settings.CAS_TEST_PASSWORD,
|
|
||||||
'service': self.service,
|
|
||||||
'secret': 'test'
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
|
|
||||||
|
|
||||||
|
|
||||||
class ValidateTestCase(TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
|
||||||
self.service = 'https://www.example.com'
|
|
||||||
self.service_pattern = models.ServicePattern.objects.create(
|
|
||||||
name="example",
|
|
||||||
pattern="^https://www\.example\.com(/.*)?$"
|
|
||||||
)
|
|
||||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
|
||||||
|
|
||||||
def test_validate_view_ok(self):
|
|
||||||
ticket = get_user_ticket_request(self.service)[1]
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertEqual(response.content, b'yes\ntest\n')
|
|
||||||
|
|
||||||
def test_validate_view_badservice(self):
|
|
||||||
ticket = get_user_ticket_request(self.service)[1]
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
response = client.get(
|
|
||||||
'/validate',
|
|
||||||
{'ticket': ticket.value, 'service': "https://www.example.org"}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertEqual(response.content, b'no\n')
|
|
||||||
|
|
||||||
def test_validate_view_badticket(self):
|
|
||||||
get_user_ticket_request(self.service)
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
response = client.get(
|
|
||||||
'/validate',
|
|
||||||
{'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
self.assertEqual(response.content, b'no\n')
|
|
||||||
|
|
||||||
|
|
||||||
class ValidateServiceTestCase(TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
|
||||||
self.service = 'http://127.0.0.1:45678'
|
|
||||||
self.service_pattern = models.ServicePattern.objects.create(
|
|
||||||
name="localhost",
|
|
||||||
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
|
||||||
proxy_callback=True
|
|
||||||
)
|
|
||||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
|
||||||
|
|
||||||
def test_validate_service_view_ok(self):
|
|
||||||
ticket = get_user_ticket_request(self.service)[1]
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
sucess = root.xpath(
|
|
||||||
"//cas:authenticationSuccess",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertTrue(sucess)
|
|
||||||
|
|
||||||
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
|
||||||
self.assertEqual(len(users), 1)
|
|
||||||
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
|
|
||||||
|
|
||||||
attributes = root.xpath(
|
|
||||||
"//cas:attributes",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(attributes), 1)
|
|
||||||
attrs1 = {}
|
|
||||||
for attr in attributes[0]:
|
|
||||||
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
|
|
||||||
|
|
||||||
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
|
||||||
self.assertEqual(len(attributes), len(attrs1))
|
|
||||||
attrs2 = {}
|
|
||||||
for attr in attributes:
|
|
||||||
attrs2[attr.attrib['name']] = attr.attrib['value']
|
|
||||||
self.assertEqual(attrs1, attrs2)
|
|
||||||
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
|
|
||||||
|
|
||||||
def test_validate_service_view_badservice(self):
|
|
||||||
ticket = get_user_ticket_request(self.service)[1]
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
bad_service = "https://www.example.org"
|
|
||||||
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
error = root.xpath(
|
|
||||||
"//cas:authenticationFailure",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(error), 1)
|
|
||||||
self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE")
|
|
||||||
self.assertEqual(error[0].text, bad_service)
|
|
||||||
|
|
||||||
def test_validate_service_view_badticket_goodprefix(self):
|
|
||||||
get_user_ticket_request(self.service)
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
|
|
||||||
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
error = root.xpath(
|
|
||||||
"//cas:authenticationFailure",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(error), 1)
|
|
||||||
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
|
|
||||||
self.assertEqual(error[0].text, 'ticket not found')
|
|
||||||
|
|
||||||
def test_validate_service_view_badticket_badprefix(self):
|
|
||||||
get_user_ticket_request(self.service)
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
bad_ticket = "RANDOM"
|
|
||||||
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
error = root.xpath(
|
|
||||||
"//cas:authenticationFailure",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(error), 1)
|
|
||||||
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
|
|
||||||
self.assertEqual(error[0].text, bad_ticket)
|
|
||||||
|
|
||||||
def test_validate_service_view_ok_pgturl(self):
|
|
||||||
(host, port) = utils.PGTUrlHandler.run()[1:3]
|
|
||||||
service = "http://%s:%s" % (host, port)
|
|
||||||
|
|
||||||
ticket = get_user_ticket_request(service)[1]
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
response = client.get(
|
|
||||||
'/serviceValidate',
|
|
||||||
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
|
|
||||||
)
|
|
||||||
pgt_params = utils.PGTUrlHandler.PARAMS.copy()
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
pgtiou = root.xpath(
|
|
||||||
"//cas:proxyGrantingTicket",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(pgtiou), 1)
|
|
||||||
self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
|
|
||||||
self.assertTrue("pgtId" in pgt_params)
|
|
||||||
|
|
||||||
def test_validate_service_pgturl_bad_proxy_callback(self):
|
|
||||||
self.service_pattern.proxy_callback = False
|
|
||||||
self.service_pattern.save()
|
|
||||||
ticket = get_user_ticket_request(self.service)[1]
|
|
||||||
|
|
||||||
client = Client()
|
|
||||||
response = client.get(
|
|
||||||
'/serviceValidate',
|
|
||||||
{'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
error = root.xpath(
|
|
||||||
"//cas:authenticationFailure",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(error), 1)
|
|
||||||
self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
|
|
||||||
self.assertEqual(error[0].text, "callback url not allowed by configuration")
|
|
||||||
|
|
||||||
|
|
||||||
class ProxyTestCase(TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
|
||||||
self.service = 'http://127.0.0.1'
|
|
||||||
self.service_pattern = models.ServicePattern.objects.create(
|
|
||||||
name="localhost",
|
|
||||||
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
|
||||||
proxy=True,
|
|
||||||
proxy_callback=True
|
|
||||||
)
|
|
||||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
|
||||||
|
|
||||||
def test_validate_proxy_ok(self):
|
|
||||||
params = get_pgt()
|
|
||||||
|
|
||||||
# get a proxy ticket
|
|
||||||
client1 = Client()
|
|
||||||
response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
|
||||||
self.assertTrue(sucess)
|
|
||||||
|
|
||||||
proxy_ticket = root.xpath(
|
|
||||||
"//cas:proxyTicket",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(proxy_ticket), 1)
|
|
||||||
proxy_ticket = proxy_ticket[0].text
|
|
||||||
|
|
||||||
# validate the proxy ticket
|
|
||||||
client2 = Client()
|
|
||||||
response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
sucess = root.xpath(
|
|
||||||
"//cas:authenticationSuccess",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertTrue(sucess)
|
|
||||||
|
|
||||||
# check that the proxy is send to the end service
|
|
||||||
proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
|
||||||
self.assertEqual(len(proxies), 1)
|
|
||||||
proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
|
||||||
self.assertEqual(len(proxy), 1)
|
|
||||||
self.assertEqual(proxy[0].text, params["service"])
|
|
||||||
|
|
||||||
# same tests than those for serviceValidate
|
|
||||||
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
|
||||||
self.assertEqual(len(users), 1)
|
|
||||||
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
|
|
||||||
|
|
||||||
attributes = root.xpath(
|
|
||||||
"//cas:attributes",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(attributes), 1)
|
|
||||||
attrs1 = {}
|
|
||||||
for attr in attributes[0]:
|
|
||||||
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
|
|
||||||
|
|
||||||
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
|
||||||
self.assertEqual(len(attributes), len(attrs1))
|
|
||||||
attrs2 = {}
|
|
||||||
for attr in attributes:
|
|
||||||
attrs2[attr.attrib['name']] = attr.attrib['value']
|
|
||||||
self.assertEqual(attrs1, attrs2)
|
|
||||||
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
|
|
||||||
|
|
||||||
def test_validate_proxy_bad(self):
|
|
||||||
params = get_pgt()
|
|
||||||
|
|
||||||
# bad PGT
|
|
||||||
client1 = Client()
|
|
||||||
response = client1.get(
|
|
||||||
'/proxy',
|
|
||||||
{
|
|
||||||
'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
|
|
||||||
'targetService': params['service']
|
|
||||||
}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
error = root.xpath(
|
|
||||||
"//cas:authenticationFailure",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(error), 1)
|
|
||||||
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
|
|
||||||
self.assertEqual(
|
|
||||||
error[0].text,
|
|
||||||
"PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
|
|
||||||
)
|
|
||||||
|
|
||||||
# bad targetService
|
|
||||||
client2 = Client()
|
|
||||||
response = client2.get(
|
|
||||||
'/proxy',
|
|
||||||
{'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
error = root.xpath(
|
|
||||||
"//cas:authenticationFailure",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(error), 1)
|
|
||||||
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
|
|
||||||
self.assertEqual(error[0].text, "https://www.example.org")
|
|
||||||
|
|
||||||
# service do not allow proxy ticket
|
|
||||||
self.service_pattern.proxy = False
|
|
||||||
self.service_pattern.save()
|
|
||||||
|
|
||||||
client3 = Client()
|
|
||||||
response = client3.get(
|
|
||||||
'/proxy',
|
|
||||||
{'pgt': params['pgtId'], 'targetService': params['service']}
|
|
||||||
)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
|
||||||
|
|
||||||
root = etree.fromstring(response.content)
|
|
||||||
error = root.xpath(
|
|
||||||
"//cas:authenticationFailure",
|
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
||||||
)
|
|
||||||
self.assertEqual(len(error), 1)
|
|
||||||
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
|
|
||||||
self.assertEqual(
|
|
||||||
error[0].text,
|
|
||||||
'the service %s do not allow proxy ticket' % params['service']
|
|
||||||
)
|
|
0
cas_server/tests/__init__.py
Normal file
0
cas_server/tests/__init__.py
Normal file
193
cas_server/tests/mixin.py
Normal file
193
cas_server/tests/mixin.py
Normal file
|
@ -0,0 +1,193 @@
|
||||||
|
# ⁻*- coding: utf-8 -*-
|
||||||
|
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||||
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
|
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||||
|
# more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License version 3
|
||||||
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
|
#
|
||||||
|
# (c) 2016 Valentin Samir
|
||||||
|
"""Some mixin classes for tests"""
|
||||||
|
from cas_server.default_settings import settings
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
import re
|
||||||
|
from lxml import etree
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from cas_server import models
|
||||||
|
from cas_server.tests.utils import get_auth_client
|
||||||
|
|
||||||
|
|
||||||
|
class BaseServicePattern(object):
|
||||||
|
"""Mixing for setting up service pattern for testing"""
|
||||||
|
def setup_service_patterns(self, proxy=False):
|
||||||
|
"""setting up service pattern"""
|
||||||
|
# For general purpose testing
|
||||||
|
self.service = "https://www.example.com"
|
||||||
|
self.service_pattern = models.ServicePattern.objects.create(
|
||||||
|
name="example",
|
||||||
|
pattern="^https://www\.example\.com(/.*)?$",
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||||
|
|
||||||
|
# For testing the restrict_users attributes
|
||||||
|
self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
|
||||||
|
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
|
||||||
|
name="restrict_user_fail",
|
||||||
|
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
|
||||||
|
restrict_users=True,
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
self.service_restrict_user_success = "https://restrict_user_success.example.com"
|
||||||
|
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
|
||||||
|
name="restrict_user_success",
|
||||||
|
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
|
||||||
|
restrict_users=True,
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
models.Username.objects.create(
|
||||||
|
value=settings.CAS_TEST_USER,
|
||||||
|
service_pattern=self.service_pattern_restrict_user_success
|
||||||
|
)
|
||||||
|
|
||||||
|
# For testing the user attributes filtering conditions
|
||||||
|
self.service_filter_fail = "https://filter_fail.example.com"
|
||||||
|
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
|
||||||
|
name="filter_fail",
|
||||||
|
pattern="^https://filter_fail\.example\.com(/.*)?$",
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
models.FilterAttributValue.objects.create(
|
||||||
|
attribut="right",
|
||||||
|
pattern="^admin$",
|
||||||
|
service_pattern=self.service_pattern_filter_fail
|
||||||
|
)
|
||||||
|
self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
|
||||||
|
self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
|
||||||
|
name="filter_fail_alt",
|
||||||
|
pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
models.FilterAttributValue.objects.create(
|
||||||
|
attribut="nom",
|
||||||
|
pattern="^toto$",
|
||||||
|
service_pattern=self.service_pattern_filter_fail_alt
|
||||||
|
)
|
||||||
|
self.service_filter_success = "https://filter_success.example.com"
|
||||||
|
self.service_pattern_filter_success = models.ServicePattern.objects.create(
|
||||||
|
name="filter_success",
|
||||||
|
pattern="^https://filter_success\.example\.com(/.*)?$",
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
models.FilterAttributValue.objects.create(
|
||||||
|
attribut="email",
|
||||||
|
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
|
||||||
|
service_pattern=self.service_pattern_filter_success
|
||||||
|
)
|
||||||
|
|
||||||
|
# For testing the user_field attributes
|
||||||
|
self.service_field_needed_fail = "https://field_needed_fail.example.com"
|
||||||
|
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
|
||||||
|
name="field_needed_fail",
|
||||||
|
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
|
||||||
|
user_field="uid",
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
self.service_field_needed_success = "https://field_needed_success.example.com"
|
||||||
|
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
|
||||||
|
name="field_needed_success",
|
||||||
|
pattern="^https://field_needed_success\.example\.com(/.*)?$",
|
||||||
|
user_field="alias",
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
|
||||||
|
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
|
||||||
|
name="field_needed_success_alt",
|
||||||
|
pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
|
||||||
|
user_field="nom",
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class XmlContent(object):
|
||||||
|
"""Mixin for test on CAS XML responses"""
|
||||||
|
def assert_error(self, response, code, text=None):
|
||||||
|
"""Assert a validation error"""
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
root = etree.fromstring(response.content)
|
||||||
|
error = root.xpath(
|
||||||
|
"//cas:authenticationFailure",
|
||||||
|
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||||
|
)
|
||||||
|
self.assertEqual(len(error), 1)
|
||||||
|
self.assertEqual(error[0].attrib['code'], code)
|
||||||
|
if text is not None:
|
||||||
|
self.assertEqual(error[0].text, text)
|
||||||
|
|
||||||
|
def assert_success(self, response, username, original_attributes):
|
||||||
|
"""assert a ticket validation success"""
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
root = etree.fromstring(response.content)
|
||||||
|
sucess = root.xpath(
|
||||||
|
"//cas:authenticationSuccess",
|
||||||
|
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||||
|
)
|
||||||
|
self.assertTrue(sucess)
|
||||||
|
|
||||||
|
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
|
self.assertEqual(len(users), 1)
|
||||||
|
self.assertEqual(users[0].text, username)
|
||||||
|
|
||||||
|
attributes = root.xpath(
|
||||||
|
"//cas:attributes",
|
||||||
|
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||||
|
)
|
||||||
|
self.assertEqual(len(attributes), 1)
|
||||||
|
attrs1 = set()
|
||||||
|
for attr in attributes[0]:
|
||||||
|
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
|
||||||
|
|
||||||
|
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
|
self.assertEqual(len(attributes), len(attrs1))
|
||||||
|
attrs2 = set()
|
||||||
|
for attr in attributes:
|
||||||
|
attrs2.add((attr.attrib['name'], attr.attrib['value']))
|
||||||
|
original = set()
|
||||||
|
for key, value in original_attributes.items():
|
||||||
|
if isinstance(value, list):
|
||||||
|
for sub_value in value:
|
||||||
|
original.add((key, sub_value))
|
||||||
|
else:
|
||||||
|
original.add((key, value))
|
||||||
|
self.assertEqual(attrs1, attrs2)
|
||||||
|
self.assertEqual(attrs1, original)
|
||||||
|
|
||||||
|
return root
|
||||||
|
|
||||||
|
|
||||||
|
class UserModels(object):
|
||||||
|
"""Mixin for test on CAS user models"""
|
||||||
|
@staticmethod
|
||||||
|
def expire_user():
|
||||||
|
"""return an expired user"""
|
||||||
|
client = get_auth_client()
|
||||||
|
|
||||||
|
new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
|
||||||
|
models.User.objects.filter(
|
||||||
|
username=settings.CAS_TEST_USER,
|
||||||
|
session_key=client.session.session_key
|
||||||
|
).update(date=new_date)
|
||||||
|
return client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_user(client):
|
||||||
|
"""return the user associated with an authenticated client"""
|
||||||
|
return models.User.objects.get(
|
||||||
|
username=settings.CAS_TEST_USER,
|
||||||
|
session_key=client.session.session_key
|
||||||
|
)
|
|
@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
|
||||||
'django.middleware.locale.LocaleMiddleware',
|
'django.middleware.locale.LocaleMiddleware',
|
||||||
]
|
]
|
||||||
|
|
||||||
ROOT_URLCONF = 'cas_server.urls'
|
ROOT_URLCONF = 'cas_server.tests.urls'
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
|
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
|
||||||
|
@ -60,6 +60,7 @@ ROOT_URLCONF = 'cas_server.urls'
|
||||||
DATABASES = {
|
DATABASES = {
|
||||||
'default': {
|
'default': {
|
||||||
'ENGINE': 'django.db.backends.sqlite3',
|
'ENGINE': 'django.db.backends.sqlite3',
|
||||||
|
'NAME': ':memory:',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
166
cas_server/tests/test_models.py
Normal file
166
cas_server/tests/test_models.py
Normal file
|
@ -0,0 +1,166 @@
|
||||||
|
# ⁻*- coding: utf-8 -*-
|
||||||
|
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||||
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
|
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||||
|
# more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License version 3
|
||||||
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
|
#
|
||||||
|
# (c) 2016 Valentin Samir
|
||||||
|
"""Tests module for models"""
|
||||||
|
from cas_server.default_settings import settings
|
||||||
|
|
||||||
|
from django.test import TestCase
|
||||||
|
from django.test.utils import override_settings
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
from cas_server import models
|
||||||
|
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
|
||||||
|
from cas_server.tests.mixin import UserModels, BaseServicePattern
|
||||||
|
|
||||||
|
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
||||||
|
class UserTestCase(TestCase, UserModels):
|
||||||
|
"""tests for the user models"""
|
||||||
|
def setUp(self):
|
||||||
|
"""Prepare the test context"""
|
||||||
|
self.service = 'http://127.0.0.1:45678'
|
||||||
|
self.service_pattern = models.ServicePattern.objects.create(
|
||||||
|
name="localhost",
|
||||||
|
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
||||||
|
single_log_out=True
|
||||||
|
)
|
||||||
|
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||||
|
|
||||||
|
def test_clean_old_entries(self):
|
||||||
|
"""test clean_old_entries"""
|
||||||
|
# get an authenticated client
|
||||||
|
client = self.expire_user()
|
||||||
|
# assert the user exists before being cleaned
|
||||||
|
self.assertEqual(len(models.User.objects.all()), 1)
|
||||||
|
# assert the last activity date is before the expiry date
|
||||||
|
self.assertTrue(
|
||||||
|
self.get_user(client).date < (
|
||||||
|
timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# delete old inactive users
|
||||||
|
models.User.clean_old_entries()
|
||||||
|
# assert the user has being well delete
|
||||||
|
self.assertEqual(len(models.User.objects.all()), 0)
|
||||||
|
|
||||||
|
def test_clean_deleted_sessions(self):
|
||||||
|
"""test clean_deleted_sessions"""
|
||||||
|
# get an authenticated client
|
||||||
|
client1 = get_auth_client()
|
||||||
|
client2 = get_auth_client()
|
||||||
|
# generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
|
||||||
|
# on self.service)
|
||||||
|
ticket = self.get_user(client1).get_ticket(
|
||||||
|
models.ServiceTicket,
|
||||||
|
self.service,
|
||||||
|
self.service_pattern,
|
||||||
|
renew=False
|
||||||
|
)
|
||||||
|
ticket.validate = True
|
||||||
|
ticket.save()
|
||||||
|
# simulated expired session being garbage collected for client1
|
||||||
|
session = SessionStore(session_key=client1.session.session_key)
|
||||||
|
session.flush()
|
||||||
|
# assert the user exists before being cleaned
|
||||||
|
self.assertTrue(self.get_user(client1))
|
||||||
|
self.assertTrue(self.get_user(client2))
|
||||||
|
self.assertEqual(len(models.User.objects.all()), 2)
|
||||||
|
# session has being remove so the user of client1 is no longer authenticated
|
||||||
|
self.assertFalse(client1.session.get("authenticated"))
|
||||||
|
# the user a client2 should still be authenticated
|
||||||
|
self.assertTrue(client2.session.get("authenticated"))
|
||||||
|
# the user should be deleted
|
||||||
|
models.User.clean_deleted_sessions()
|
||||||
|
# assert the user with expired sessions has being well deleted but the other remain
|
||||||
|
self.assertEqual(len(models.User.objects.all()), 1)
|
||||||
|
self.assertFalse(models.ServiceTicket.objects.all())
|
||||||
|
self.assertTrue(client2.session.get("authenticated"))
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
||||||
|
class TicketTestCase(TestCase, UserModels, BaseServicePattern):
|
||||||
|
"""tests for the tickets models"""
|
||||||
|
def setUp(self):
|
||||||
|
"""Prepare the test context"""
|
||||||
|
self.setup_service_patterns()
|
||||||
|
self.service = 'http://127.0.0.1:45678'
|
||||||
|
self.service_pattern = models.ServicePattern.objects.create(
|
||||||
|
name="localhost",
|
||||||
|
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
||||||
|
single_log_out=True
|
||||||
|
)
|
||||||
|
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_ticket(
|
||||||
|
user,
|
||||||
|
ticket_class,
|
||||||
|
service,
|
||||||
|
service_pattern,
|
||||||
|
renew=False,
|
||||||
|
validate=False,
|
||||||
|
validity_expired=False,
|
||||||
|
timeout_expired=False,
|
||||||
|
single_log_out=False,
|
||||||
|
):
|
||||||
|
"""Return a ticket"""
|
||||||
|
ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
|
||||||
|
ticket.validate = validate
|
||||||
|
ticket.single_log_out = single_log_out
|
||||||
|
if validity_expired:
|
||||||
|
ticket.creation = min(
|
||||||
|
ticket.creation,
|
||||||
|
(timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
|
||||||
|
)
|
||||||
|
if timeout_expired:
|
||||||
|
ticket.creation = min(
|
||||||
|
ticket.creation,
|
||||||
|
(timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
|
||||||
|
)
|
||||||
|
ticket.save()
|
||||||
|
return ticket
|
||||||
|
|
||||||
|
def test_clean_old_service_ticket(self):
|
||||||
|
"""test tickets clean_old_entries"""
|
||||||
|
# ge an authenticated client
|
||||||
|
client = get_auth_client()
|
||||||
|
# get the user associated to the client
|
||||||
|
user = self.get_user(client)
|
||||||
|
# generate a ticket for that client, waiting for validation
|
||||||
|
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
|
||||||
|
# generate another ticket for those validation time has expired
|
||||||
|
self.get_ticket(
|
||||||
|
user, models.ServiceTicket,
|
||||||
|
self.service, self.service_pattern, validity_expired=True
|
||||||
|
)
|
||||||
|
(httpd, host, port) = HttpParamsHandler.run()[0:3]
|
||||||
|
service = "http://%s:%s" % (host, port)
|
||||||
|
# generate a ticket with SLO having timeout reach
|
||||||
|
self.get_ticket(
|
||||||
|
user, models.ServiceTicket,
|
||||||
|
service, self.service_pattern, timeout_expired=True,
|
||||||
|
validate=True, single_log_out=True
|
||||||
|
)
|
||||||
|
# there should be 3 tickets in the db
|
||||||
|
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
|
||||||
|
# we call the clean_old_entries method that should delete validated non SLO ticket and
|
||||||
|
# expired non validated ticket and send SLO for SLO expired ticket before deleting then
|
||||||
|
models.ServiceTicket.clean_old_entries()
|
||||||
|
params = httpd.PARAMS
|
||||||
|
# we successfully got a SLO request
|
||||||
|
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
|
||||||
|
# only 1 ticket remain in the db
|
||||||
|
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)
|
191
cas_server/tests/test_utils.py
Normal file
191
cas_server/tests/test_utils.py
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
# ⁻*- coding: utf-8 -*-
|
||||||
|
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||||
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
|
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||||
|
# more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License version 3
|
||||||
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
|
#
|
||||||
|
# (c) 2016 Valentin Samir
|
||||||
|
"""Tests module for utils"""
|
||||||
|
from django.test import TestCase
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
|
from cas_server import utils
|
||||||
|
|
||||||
|
|
||||||
|
class CheckPasswordCase(TestCase):
|
||||||
|
"""Tests for the utils function `utils.check_password`"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Generate random bytes string that will be used ass passwords"""
|
||||||
|
self.password1 = utils.gen_saml_id()
|
||||||
|
self.password2 = utils.gen_saml_id()
|
||||||
|
if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
|
||||||
|
self.password1 = self.password1.encode("utf8")
|
||||||
|
self.password2 = self.password2.encode("utf8")
|
||||||
|
|
||||||
|
def test_setup(self):
|
||||||
|
"""check that generated password are bytes"""
|
||||||
|
self.assertIsInstance(self.password1, bytes)
|
||||||
|
self.assertIsInstance(self.password2, bytes)
|
||||||
|
|
||||||
|
def test_plain(self):
|
||||||
|
"""test the plain auth method"""
|
||||||
|
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
|
||||||
|
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
|
||||||
|
|
||||||
|
def test_plain_unicode(self):
|
||||||
|
"""test the plain auth method with unicode input"""
|
||||||
|
self.assertTrue(
|
||||||
|
utils.check_password(
|
||||||
|
"plain",
|
||||||
|
self.password1.decode("utf8"),
|
||||||
|
self.password1.decode("utf8"),
|
||||||
|
"utf8"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertFalse(
|
||||||
|
utils.check_password(
|
||||||
|
"plain",
|
||||||
|
self.password1.decode("utf8"),
|
||||||
|
self.password2.decode("utf8"),
|
||||||
|
"utf8"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_crypt(self):
|
||||||
|
"""test the crypt auth method"""
|
||||||
|
salts = ["$6$UVVAQvrMyXMF3FF3", "aa"]
|
||||||
|
hashed_password1 = []
|
||||||
|
for salt in salts:
|
||||||
|
if six.PY3:
|
||||||
|
hashed_password1.append(
|
||||||
|
utils.crypt.crypt(
|
||||||
|
self.password1.decode("utf8"),
|
||||||
|
salt
|
||||||
|
).encode("utf8")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hashed_password1.append(utils.crypt.crypt(self.password1, salt))
|
||||||
|
|
||||||
|
for hp1 in hashed_password1:
|
||||||
|
self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8"))
|
||||||
|
self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8"))
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8")
|
||||||
|
|
||||||
|
def test_ldap_password_valid(self):
|
||||||
|
"""test the ldap auth method with all the schemes"""
|
||||||
|
salt = b"UVVAQvrMyXMF3FF3"
|
||||||
|
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
|
||||||
|
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
|
||||||
|
hashed_password1 = []
|
||||||
|
for scheme in schemes_salt:
|
||||||
|
hashed_password1.append(
|
||||||
|
utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8")
|
||||||
|
)
|
||||||
|
for scheme in schemes_nosalt:
|
||||||
|
hashed_password1.append(
|
||||||
|
utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8")
|
||||||
|
)
|
||||||
|
hashed_password1.append(
|
||||||
|
utils.LdapHashUserPassword.hash(
|
||||||
|
b"{CRYPT}",
|
||||||
|
self.password1,
|
||||||
|
b"$6$UVVAQvrMyXMF3FF3",
|
||||||
|
charset="utf8"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for hp1 in hashed_password1:
|
||||||
|
self.assertIsInstance(hp1, bytes)
|
||||||
|
self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8"))
|
||||||
|
self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8"))
|
||||||
|
|
||||||
|
def test_ldap_password_fail(self):
|
||||||
|
"""test the ldap auth method with malformed hash or bad schemes"""
|
||||||
|
salt = b"UVVAQvrMyXMF3FF3"
|
||||||
|
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
|
||||||
|
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
|
||||||
|
|
||||||
|
# first try to hash with bad parameters
|
||||||
|
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
|
||||||
|
utils.LdapHashUserPassword.hash(b"TOTO", self.password1)
|
||||||
|
for scheme in schemes_nosalt:
|
||||||
|
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
|
||||||
|
utils.LdapHashUserPassword.hash(scheme, self.password1, salt)
|
||||||
|
for scheme in schemes_salt:
|
||||||
|
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
|
||||||
|
utils.LdapHashUserPassword.hash(scheme, self.password1)
|
||||||
|
with self.assertRaises(utils.LdapHashUserPassword.BadSalt):
|
||||||
|
utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, b"$truc$toto")
|
||||||
|
|
||||||
|
# then try to check hash with bad hashes
|
||||||
|
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
|
||||||
|
utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8")
|
||||||
|
for scheme in schemes_salt:
|
||||||
|
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
|
||||||
|
utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8")
|
||||||
|
|
||||||
|
def test_hex(self):
|
||||||
|
"""test all the hex_HASH method: the hashed password is a simple hash of the password"""
|
||||||
|
hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]
|
||||||
|
hashed_password1 = []
|
||||||
|
for hash in hashes:
|
||||||
|
hashed_password1.append(
|
||||||
|
("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest())
|
||||||
|
)
|
||||||
|
for (method, hp1) in hashed_password1:
|
||||||
|
self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8"))
|
||||||
|
self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8"))
|
||||||
|
|
||||||
|
def test_bad_method(self):
|
||||||
|
"""try to check password with a bad method, should raise a ValueError"""
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8")
|
||||||
|
|
||||||
|
|
||||||
|
class UtilsTestCase(TestCase):
|
||||||
|
"""tests for some little utils functions"""
|
||||||
|
def test_import_attr(self):
|
||||||
|
"""
|
||||||
|
test the import_attr function. Feeded with a dotted path string, it should
|
||||||
|
import the dotted module and return that last componend of the dotted path
|
||||||
|
(function, class or variable)
|
||||||
|
"""
|
||||||
|
with self.assertRaises(ImportError):
|
||||||
|
utils.import_attr('toto.titi.tutu')
|
||||||
|
with self.assertRaises(AttributeError):
|
||||||
|
utils.import_attr('cas_server.utils.toto')
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
utils.import_attr('toto')
|
||||||
|
self.assertEqual(
|
||||||
|
utils.import_attr('cas_server.default_app_config'),
|
||||||
|
'cas_server.apps.CasAppConfig'
|
||||||
|
)
|
||||||
|
self.assertEqual(utils.import_attr(utils), utils)
|
||||||
|
|
||||||
|
def test_update_url(self):
|
||||||
|
"""
|
||||||
|
test the update_url function. Given an url with possible GET parameter and a dict
|
||||||
|
the function build a url with GET parameters updated by the dictionnary
|
||||||
|
"""
|
||||||
|
url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"})
|
||||||
|
url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"})
|
||||||
|
self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1")
|
||||||
|
self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1")
|
||||||
|
|
||||||
|
url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"})
|
||||||
|
self.assertEqual(url3, u"https://www.example.com?toto=2")
|
||||||
|
|
||||||
|
def test_crypt_salt_is_valid(self):
|
||||||
|
"""test the function crypt_salt_is_valid who test if a crypt salt is valid"""
|
||||||
|
self.assertFalse(utils.crypt_salt_is_valid("")) # len 0
|
||||||
|
self.assertFalse(utils.crypt_salt_is_valid("a")) # len 1
|
||||||
|
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
|
||||||
|
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
|
||||||
|
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known
|
1813
cas_server/tests/test_view.py
Normal file
1813
cas_server/tests/test_view.py
Normal file
File diff suppressed because it is too large
Load diff
22
cas_server/tests/urls.py
Normal file
22
cas_server/tests/urls.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
"""cas URL Configuration
|
||||||
|
|
||||||
|
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||||
|
https://docs.djangoproject.com/en/1.9/topics/http/urls/
|
||||||
|
Examples:
|
||||||
|
Function views
|
||||||
|
1. Add an import: from my_app import views
|
||||||
|
2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
|
||||||
|
Class-based views
|
||||||
|
1. Add an import: from other_app.views import Home
|
||||||
|
2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
|
||||||
|
Including another URLconf
|
||||||
|
1. Import the include() function: from django.conf.urls import url, include, include
|
||||||
|
2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
|
||||||
|
"""
|
||||||
|
from django.conf.urls import url, include
|
||||||
|
from django.contrib import admin
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
url(r'^admin/', admin.site.urls),
|
||||||
|
url(r'^', include('cas_server.urls', namespace='cas_server')),
|
||||||
|
]
|
180
cas_server/tests/utils.py
Normal file
180
cas_server/tests/utils.py
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
# ⁻*- coding: utf-8 -*-
|
||||||
|
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||||
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
|
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||||
|
# more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU General Public License version 3
|
||||||
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
|
#
|
||||||
|
# (c) 2016 Valentin Samir
|
||||||
|
"""Some utils functions for tests"""
|
||||||
|
from cas_server.default_settings import settings
|
||||||
|
|
||||||
|
from django.test import Client
|
||||||
|
|
||||||
|
import cgi
|
||||||
|
from threading import Thread
|
||||||
|
from lxml import etree
|
||||||
|
from six.moves import BaseHTTPServer
|
||||||
|
from six.moves.urllib.parse import urlparse, parse_qsl
|
||||||
|
|
||||||
|
from cas_server import models
|
||||||
|
|
||||||
|
|
||||||
|
def copy_form(form):
|
||||||
|
"""Copy form value into a dict"""
|
||||||
|
params = {}
|
||||||
|
for field in form:
|
||||||
|
if field.value():
|
||||||
|
params[field.name] = field.value()
|
||||||
|
else:
|
||||||
|
params[field.name] = ""
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_login_page_params(client=None):
|
||||||
|
"""Return a client and the POST params for the client to login"""
|
||||||
|
if client is None:
|
||||||
|
client = Client()
|
||||||
|
response = client.get('/login')
|
||||||
|
params = copy_form(response.context["form"])
|
||||||
|
return client, params
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_client(**update):
|
||||||
|
"""return a authenticated client"""
|
||||||
|
client, params = get_login_page_params()
|
||||||
|
params["username"] = settings.CAS_TEST_USER
|
||||||
|
params["password"] = settings.CAS_TEST_PASSWORD
|
||||||
|
params.update(update)
|
||||||
|
|
||||||
|
client.post('/login', params)
|
||||||
|
assert client.session.get("authenticated")
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_ticket_request(service):
|
||||||
|
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
|
||||||
|
client = get_auth_client()
|
||||||
|
response = client.get("/login", {"service": service})
|
||||||
|
ticket_value = response['Location'].split('ticket=')[-1]
|
||||||
|
user = models.User.objects.get(
|
||||||
|
username=settings.CAS_TEST_USER,
|
||||||
|
session_key=client.session.session_key
|
||||||
|
)
|
||||||
|
ticket = models.ServiceTicket.objects.get(value=ticket_value)
|
||||||
|
return (user, ticket, client)
|
||||||
|
|
||||||
|
|
||||||
|
def get_validated_ticket(service):
|
||||||
|
"""Return a tick that has being already validated. Used to test SLO"""
|
||||||
|
(ticket, auth_client) = get_user_ticket_request(service)[1:3]
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
response = client.get('/validate', {'ticket': ticket.value, 'service': service})
|
||||||
|
assert (response.status_code == 200)
|
||||||
|
assert (response.content == b'yes\ntest\n')
|
||||||
|
|
||||||
|
ticket = models.ServiceTicket.objects.get(value=ticket.value)
|
||||||
|
return (auth_client, ticket)
|
||||||
|
|
||||||
|
|
||||||
|
def get_pgt():
|
||||||
|
"""return a dict contening a service, user and PGT ticket for this service"""
|
||||||
|
(httpd, host, port) = HttpParamsHandler.run()[0:3]
|
||||||
|
service = "http://%s:%s" % (host, port)
|
||||||
|
|
||||||
|
(user, ticket) = get_user_ticket_request(service)[:2]
|
||||||
|
|
||||||
|
client = Client()
|
||||||
|
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
|
||||||
|
params = httpd.PARAMS
|
||||||
|
|
||||||
|
params["service"] = service
|
||||||
|
params["user"] = user
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_proxy_ticket(service):
|
||||||
|
"""Return a ProxyTicket waiting for validation"""
|
||||||
|
params = get_pgt()
|
||||||
|
|
||||||
|
# get a proxy ticket
|
||||||
|
client = Client()
|
||||||
|
response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
|
||||||
|
root = etree.fromstring(response.content)
|
||||||
|
proxy_ticket = root.xpath(
|
||||||
|
"//cas:proxyTicket",
|
||||||
|
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||||
|
)
|
||||||
|
proxy_ticket = proxy_ticket[0].text
|
||||||
|
ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
|
||||||
|
return ticket
|
||||||
|
|
||||||
|
|
||||||
|
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||||
|
"""
|
||||||
|
A simple http server that return 200 on GET or POST
|
||||||
|
and store GET or POST parameters. Used in unit tests
|
||||||
|
"""
|
||||||
|
|
||||||
|
def do_GET(self):
|
||||||
|
"""Called on a GET request on the BaseHTTPServer"""
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header(b"Content-type", "text/plain")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(b"ok")
|
||||||
|
url = urlparse(self.path)
|
||||||
|
params = dict(parse_qsl(url.query))
|
||||||
|
self.server.PARAMS = params
|
||||||
|
|
||||||
|
def do_POST(self):
|
||||||
|
"""Called on a POST request on the BaseHTTPServer"""
|
||||||
|
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
|
||||||
|
if ctype == 'multipart/form-data':
|
||||||
|
postvars = cgi.parse_multipart(self.rfile, pdict)
|
||||||
|
elif ctype == 'application/x-www-form-urlencoded':
|
||||||
|
length = int(self.headers.get('content-length'))
|
||||||
|
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
|
||||||
|
else:
|
||||||
|
postvars = {}
|
||||||
|
self.server.PARAMS = postvars
|
||||||
|
|
||||||
|
def log_message(self, *args):
|
||||||
|
"""silent any log message"""
|
||||||
|
return
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def run(cls):
|
||||||
|
"""Run a BaseHTTPServer using this class as handler"""
|
||||||
|
server_class = BaseHTTPServer.HTTPServer
|
||||||
|
httpd = server_class(("127.0.0.1", 0), cls)
|
||||||
|
(host, port) = httpd.socket.getsockname()
|
||||||
|
|
||||||
|
def lauch():
|
||||||
|
"""routine to lauch in a background thread"""
|
||||||
|
httpd.handle_request()
|
||||||
|
httpd.server_close()
|
||||||
|
|
||||||
|
httpd_thread = Thread(target=lauch)
|
||||||
|
httpd_thread.daemon = True
|
||||||
|
httpd_thread.start()
|
||||||
|
return (httpd, host, port)
|
||||||
|
|
||||||
|
|
||||||
|
class Http404Handler(HttpParamsHandler):
|
||||||
|
"""A simple http server that always return 404 not found. Used in unit tests"""
|
||||||
|
def do_GET(self):
|
||||||
|
"""Called on a GET request on the BaseHTTPServer"""
|
||||||
|
self.send_response(404)
|
||||||
|
self.send_header(b"Content-type", "text/plain")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(b"error 404 not found")
|
||||||
|
|
||||||
|
def do_POST(self):
|
||||||
|
"""Called on a POST request on the BaseHTTPServer"""
|
||||||
|
return self.do_GET()
|
|
@ -8,7 +8,7 @@
|
||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
#
|
#
|
||||||
# (c) 2015 Valentin Samir
|
# (c) 2015-2016 Valentin Samir
|
||||||
"""urls for the app"""
|
"""urls for the app"""
|
||||||
from django.conf.urls import patterns, url
|
from django.conf.urls import patterns, url
|
||||||
from django.views.generic import RedirectView
|
from django.views.generic import RedirectView
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# ⁻*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# This program is distributed in the hope that it will be useful, but WITHOUT
|
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||||
|
@ -8,7 +8,7 @@
|
||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
#
|
#
|
||||||
# (c) 2015 Valentin Samir
|
# (c) 2015-2016 Valentin Samir
|
||||||
"""Some util function for the app"""
|
"""Some util function for the app"""
|
||||||
from .default_settings import settings
|
from .default_settings import settings
|
||||||
|
|
||||||
|
@ -23,18 +23,19 @@ import hashlib
|
||||||
import crypt
|
import crypt
|
||||||
import base64
|
import base64
|
||||||
import six
|
import six
|
||||||
from threading import Thread
|
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from six.moves import BaseHTTPServer
|
|
||||||
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
|
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
|
||||||
|
|
||||||
|
|
||||||
def context(params):
|
def context(params):
|
||||||
|
"""Function that add somes variable to the context before template rendering"""
|
||||||
params["settings"] = settings
|
params["settings"] = settings
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
def json_response(request, data):
|
def json_response(request, data):
|
||||||
|
"""Wrapper dumping `data` to a json and sending it to the user with an HttpResponse"""
|
||||||
data["messages"] = []
|
data["messages"] = []
|
||||||
for msg in messages.get_messages(request):
|
for msg in messages.get_messages(request):
|
||||||
data["messages"].append({'message': msg.message, 'level': msg.level_tag})
|
data["messages"].append({'message': msg.message, 'level': msg.level_tag})
|
||||||
|
@ -64,6 +65,7 @@ def redirect_params(url_name, params=None):
|
||||||
|
|
||||||
|
|
||||||
def reverse_params(url_name, params=None, **kwargs):
|
def reverse_params(url_name, params=None, **kwargs):
|
||||||
|
"""compule the reverse url or `url_name` and add GET parameters from `params` to it"""
|
||||||
url = reverse(url_name, **kwargs)
|
url = reverse(url_name, **kwargs)
|
||||||
params = urlencode(params if params else {})
|
params = urlencode(params if params else {})
|
||||||
return url + "?%s" % params
|
return url + "?%s" % params
|
||||||
|
@ -83,10 +85,13 @@ def update_url(url, params):
|
||||||
url_parts = list(urlparse(url))
|
url_parts = list(urlparse(url))
|
||||||
query = dict(parse_qsl(url_parts[4]))
|
query = dict(parse_qsl(url_parts[4]))
|
||||||
query.update(params)
|
query.update(params)
|
||||||
url_parts[4] = urlencode(query)
|
# make the params order deterministic
|
||||||
for i, url_part in enumerate(url_parts):
|
query = list(query.items())
|
||||||
if not isinstance(url_part, bytes):
|
query.sort()
|
||||||
url_parts[i] = url_part.encode('utf-8')
|
url_query = urlencode(query)
|
||||||
|
if not isinstance(url_query, bytes): # pragma: no cover in python3 urlencode return an unicode
|
||||||
|
url_query = url_query.encode("utf-8")
|
||||||
|
url_parts[4] = url_query
|
||||||
return urlunparse(url_parts).decode('utf-8')
|
return urlunparse(url_parts).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
@ -147,35 +152,25 @@ def gen_saml_id():
|
||||||
return _gen_ticket('_')
|
return _gen_ticket('_')
|
||||||
|
|
||||||
|
|
||||||
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
def crypt_salt_is_valid(salt):
|
||||||
PARAMS = {}
|
"""Return True is salt is valid has a crypt salt, False otherwise"""
|
||||||
|
if len(salt) < 2:
|
||||||
def do_GET(self):
|
return False
|
||||||
self.send_response(200)
|
else:
|
||||||
self.send_header(b"Content-type", "text/plain")
|
if salt[0] == '$':
|
||||||
self.end_headers()
|
if salt[1] == '$':
|
||||||
self.wfile.write(b"ok")
|
return False
|
||||||
url = urlparse(self.path)
|
else:
|
||||||
params = dict(parse_qsl(url.query))
|
if '$' not in salt[1:]:
|
||||||
PGTUrlHandler.PARAMS.update(params)
|
return False
|
||||||
|
else:
|
||||||
def log_message(self, *args):
|
hashed = crypt.crypt("", salt)
|
||||||
return
|
if not hashed or '$' not in hashed[1:]:
|
||||||
|
return False
|
||||||
@staticmethod
|
else:
|
||||||
def run():
|
return True
|
||||||
server_class = BaseHTTPServer.HTTPServer
|
else:
|
||||||
httpd = server_class(("127.0.0.1", 0), PGTUrlHandler)
|
return True
|
||||||
(host, port) = httpd.socket.getsockname()
|
|
||||||
|
|
||||||
def lauch():
|
|
||||||
httpd.handle_request()
|
|
||||||
httpd.server_close()
|
|
||||||
|
|
||||||
httpd_thread = Thread(target=lauch)
|
|
||||||
httpd_thread.daemon = True
|
|
||||||
httpd_thread.start()
|
|
||||||
return (httpd_thread, host, port)
|
|
||||||
|
|
||||||
|
|
||||||
class LdapHashUserPassword(object):
|
class LdapHashUserPassword(object):
|
||||||
|
@ -268,7 +263,7 @@ class LdapHashUserPassword(object):
|
||||||
if salt is None or salt == b"":
|
if salt is None or salt == b"":
|
||||||
salt = b""
|
salt = b""
|
||||||
cls._test_scheme_nosalt(scheme)
|
cls._test_scheme_nosalt(scheme)
|
||||||
elif salt is not None:
|
else:
|
||||||
cls._test_scheme_salt(scheme)
|
cls._test_scheme_salt(scheme)
|
||||||
try:
|
try:
|
||||||
return scheme + base64.b64encode(
|
return scheme + base64.b64encode(
|
||||||
|
@ -278,9 +273,9 @@ class LdapHashUserPassword(object):
|
||||||
if six.PY3:
|
if six.PY3:
|
||||||
password = password.decode(charset)
|
password = password.decode(charset)
|
||||||
salt = salt.decode(charset)
|
salt = salt.decode(charset)
|
||||||
hashed_password = crypt.crypt(password, salt)
|
if not crypt_salt_is_valid(salt):
|
||||||
if hashed_password is None:
|
|
||||||
raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
|
raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
|
||||||
|
hashed_password = crypt.crypt(password, salt)
|
||||||
if six.PY3:
|
if six.PY3:
|
||||||
hashed_password = hashed_password.encode(charset)
|
hashed_password = hashed_password.encode(charset)
|
||||||
return scheme + hashed_password
|
return scheme + hashed_password
|
||||||
|
@ -302,7 +297,7 @@ class LdapHashUserPassword(object):
|
||||||
if scheme in cls.schemes_nosalt:
|
if scheme in cls.schemes_nosalt:
|
||||||
return b""
|
return b""
|
||||||
elif scheme == b'{CRYPT}':
|
elif scheme == b'{CRYPT}':
|
||||||
return b'$'.join(hashed_passord.split(b'$', 3)[:-1])
|
return b'$'.join(hashed_passord.split(b'$', 3)[:-1])[len(scheme):]
|
||||||
else:
|
else:
|
||||||
hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
|
hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
|
||||||
if len(hashed_passord) < cls._schemes_to_len[scheme]:
|
if len(hashed_passord) < cls._schemes_to_len[scheme]:
|
||||||
|
@ -324,7 +319,7 @@ def check_password(method, password, hashed_password, charset):
|
||||||
elif method == "crypt":
|
elif method == "crypt":
|
||||||
if hashed_password.startswith(b'$'):
|
if hashed_password.startswith(b'$'):
|
||||||
salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
|
salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
|
||||||
elif hashed_password.startswith(b'_'):
|
elif hashed_password.startswith(b'_'): # pragma: no cover old BSD format not supported
|
||||||
salt = hashed_password[:9]
|
salt = hashed_password[:9]
|
||||||
else:
|
else:
|
||||||
salt = hashed_password[:2]
|
salt = hashed_password[:2]
|
||||||
|
@ -332,9 +327,9 @@ def check_password(method, password, hashed_password, charset):
|
||||||
password = password.decode(charset)
|
password = password.decode(charset)
|
||||||
salt = salt.decode(charset)
|
salt = salt.decode(charset)
|
||||||
hashed_password = hashed_password.decode(charset)
|
hashed_password = hashed_password.decode(charset)
|
||||||
crypted_password = crypt.crypt(password, salt)
|
if not crypt_salt_is_valid(salt):
|
||||||
if crypted_password is None:
|
|
||||||
raise ValueError("System crypt implementation do not support the salt %r" % salt)
|
raise ValueError("System crypt implementation do not support the salt %r" % salt)
|
||||||
|
crypted_password = crypt.crypt(password, salt)
|
||||||
return crypted_password == hashed_password
|
return crypted_password == hashed_password
|
||||||
elif method == "ldap":
|
elif method == "ldap":
|
||||||
scheme = LdapHashUserPassword.get_scheme(hashed_password)
|
scheme = LdapHashUserPassword.get_scheme(hashed_password)
|
||||||
|
|
|
@ -8,7 +8,7 @@
|
||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
# along with this program; if not, write to the Free Software Foundation, Inc., 51
|
||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
#
|
#
|
||||||
# (c) 2015 Valentin Samir
|
# (c) 2015-2016 Valentin Samir
|
||||||
"""views for the app"""
|
"""views for the app"""
|
||||||
from .default_settings import settings
|
from .default_settings import settings
|
||||||
|
|
||||||
|
@ -105,10 +105,11 @@ class LogoutView(View, LogoutMixin):
|
||||||
service = None
|
service = None
|
||||||
|
|
||||||
def init_get(self, request):
|
def init_get(self, request):
|
||||||
|
"""Initialize GET received parameters"""
|
||||||
self.request = request
|
self.request = request
|
||||||
self.service = request.GET.get('service')
|
self.service = request.GET.get('service')
|
||||||
self.url = request.GET.get('url')
|
self.url = request.GET.get('url')
|
||||||
self.ajax = 'HTTP_X_AJAX' in request.META
|
self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
"""methode called on GET request on this view"""
|
"""methode called on GET request on this view"""
|
||||||
|
@ -196,24 +197,30 @@ class LoginView(View, LogoutMixin):
|
||||||
USER_NOT_AUTHENTICATED = 6
|
USER_NOT_AUTHENTICATED = 6
|
||||||
|
|
||||||
def init_post(self, request):
|
def init_post(self, request):
|
||||||
|
"""Initialize POST received parameters"""
|
||||||
self.request = request
|
self.request = request
|
||||||
self.service = request.POST.get('service')
|
self.service = request.POST.get('service')
|
||||||
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
|
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
|
||||||
self.gateway = request.POST.get('gateway')
|
self.gateway = request.POST.get('gateway')
|
||||||
self.method = request.POST.get('method')
|
self.method = request.POST.get('method')
|
||||||
self.ajax = 'HTTP_X_AJAX' in request.META
|
self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
|
||||||
if request.POST.get('warned') and request.POST['warned'] != "False":
|
if request.POST.get('warned') and request.POST['warned'] != "False":
|
||||||
self.warned = True
|
self.warned = True
|
||||||
|
self.warn = request.POST.get('warn')
|
||||||
|
|
||||||
def check_lt(self):
|
def gen_lt(self):
|
||||||
# save LT for later check
|
"""Generate a new LoginTicket and add it to the list of valid LT for the user"""
|
||||||
lt_valid = self.request.session.get('lt', [])
|
|
||||||
lt_send = self.request.POST.get('lt')
|
|
||||||
# generate a new LT (by posting the LT has been consumed)
|
|
||||||
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
|
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
|
||||||
if len(self.request.session['lt']) > 100:
|
if len(self.request.session['lt']) > 100:
|
||||||
self.request.session['lt'] = self.request.session['lt'][-100:]
|
self.request.session['lt'] = self.request.session['lt'][-100:]
|
||||||
|
|
||||||
|
def check_lt(self):
|
||||||
|
"""Check is the POSTed LoginTicket is valid, if yes invalide it"""
|
||||||
|
# save LT for later check
|
||||||
|
lt_valid = self.request.session.get('lt', [])
|
||||||
|
lt_send = self.request.POST.get('lt')
|
||||||
|
# generate a new LT (by posting the LT has been consumed)
|
||||||
|
self.gen_lt()
|
||||||
# check if send LT is valid
|
# check if send LT is valid
|
||||||
if lt_valid is None or lt_send not in lt_valid:
|
if lt_valid is None or lt_send not in lt_valid:
|
||||||
return False
|
return False
|
||||||
|
@ -238,7 +245,7 @@ class LoginView(View, LogoutMixin):
|
||||||
username=self.request.session['username'],
|
username=self.request.session['username'],
|
||||||
session_key=self.request.session.session_key
|
session_key=self.request.session.session_key
|
||||||
)
|
)
|
||||||
self.user.save()
|
self.user.save() # pragma: no cover (should not happend)
|
||||||
except models.User.DoesNotExist:
|
except models.User.DoesNotExist:
|
||||||
self.user = models.User.objects.create(
|
self.user = models.User.objects.create(
|
||||||
username=self.request.session['username'],
|
username=self.request.session['username'],
|
||||||
|
@ -250,10 +257,15 @@ class LoginView(View, LogoutMixin):
|
||||||
elif ret == self.USER_ALREADY_LOGGED:
|
elif ret == self.USER_ALREADY_LOGGED:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError("invalid output for LoginView.process_post")
|
raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover
|
||||||
return self.common()
|
return self.common()
|
||||||
|
|
||||||
def process_post(self):
|
def process_post(self):
|
||||||
|
"""
|
||||||
|
Analyse the POST request:
|
||||||
|
* check that the LoginTicket is valid
|
||||||
|
* check that the user sumited credentials are valid
|
||||||
|
"""
|
||||||
if not self.check_lt():
|
if not self.check_lt():
|
||||||
values = self.request.POST.copy()
|
values = self.request.POST.copy()
|
||||||
# if not set a new LT and fail
|
# if not set a new LT and fail
|
||||||
|
@ -280,12 +292,14 @@ class LoginView(View, LogoutMixin):
|
||||||
return self.USER_ALREADY_LOGGED
|
return self.USER_ALREADY_LOGGED
|
||||||
|
|
||||||
def init_get(self, request):
|
def init_get(self, request):
|
||||||
|
"""Initialize GET received parameters"""
|
||||||
self.request = request
|
self.request = request
|
||||||
self.service = request.GET.get('service')
|
self.service = request.GET.get('service')
|
||||||
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
|
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
|
||||||
self.gateway = request.GET.get('gateway')
|
self.gateway = request.GET.get('gateway')
|
||||||
self.method = request.GET.get('method')
|
self.method = request.GET.get('method')
|
||||||
self.ajax = 'HTTP_X_AJAX' in request.META
|
self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
|
||||||
|
self.warn = request.GET.get('warn')
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
"""methode called on GET request on this view"""
|
"""methode called on GET request on this view"""
|
||||||
|
@ -294,22 +308,24 @@ class LoginView(View, LogoutMixin):
|
||||||
return self.common()
|
return self.common()
|
||||||
|
|
||||||
def process_get(self):
|
def process_get(self):
|
||||||
# generate a new LT if none is present
|
"""Analyse the GET request"""
|
||||||
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
|
# generate a new LT
|
||||||
|
self.gen_lt()
|
||||||
if not self.request.session.get("authenticated") or self.renew:
|
if not self.request.session.get("authenticated") or self.renew:
|
||||||
self.init_form()
|
self.init_form()
|
||||||
return self.USER_NOT_AUTHENTICATED
|
return self.USER_NOT_AUTHENTICATED
|
||||||
return self.USER_AUTHENTICATED
|
return self.USER_AUTHENTICATED
|
||||||
|
|
||||||
def init_form(self, values=None):
|
def init_form(self, values=None):
|
||||||
|
"""Initialization of the good form depending of POST and GET parameters"""
|
||||||
self.form = forms.UserCredential(
|
self.form = forms.UserCredential(
|
||||||
values,
|
values,
|
||||||
initial={
|
initial={
|
||||||
'service': self.service,
|
'service': self.service,
|
||||||
'method': self.method,
|
'method': self.method,
|
||||||
'warn': self.request.session.get("warn"),
|
'warn': self.request.session.get("warn"),
|
||||||
'lt': self.request.session['lt'][-1]
|
'lt': self.request.session['lt'][-1],
|
||||||
|
'renew': self.renew
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -351,7 +367,7 @@ class LoginView(View, LogoutMixin):
|
||||||
redirect_url = self.user.get_service_url(
|
redirect_url = self.user.get_service_url(
|
||||||
self.service,
|
self.service,
|
||||||
service_pattern,
|
service_pattern,
|
||||||
renew=self.renew
|
renew=self.renewed
|
||||||
)
|
)
|
||||||
if not self.ajax:
|
if not self.ajax:
|
||||||
return HttpResponseRedirect(redirect_url)
|
return HttpResponseRedirect(redirect_url)
|
||||||
|
@ -580,12 +596,9 @@ class Validate(View):
|
||||||
ticket.service_pattern.user_field
|
ticket.service_pattern.user_field
|
||||||
)
|
)
|
||||||
if isinstance(username, list):
|
if isinstance(username, list):
|
||||||
try:
|
# the list is not empty because we wont generate a ticket with a user_field
|
||||||
username = username[0]
|
# that evaluate to False
|
||||||
except IndexError:
|
username = username[0]
|
||||||
username = None
|
|
||||||
if not username:
|
|
||||||
username = ""
|
|
||||||
else:
|
else:
|
||||||
username = ticket.user.username
|
username = ticket.user.username
|
||||||
return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
|
return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
|
||||||
|
@ -661,6 +674,10 @@ class ValidateService(View, AttributesMixin):
|
||||||
params['username'] = self.ticket.user.attributs.get(
|
params['username'] = self.ticket.user.attributs.get(
|
||||||
self.ticket.service_pattern.user_field
|
self.ticket.service_pattern.user_field
|
||||||
)
|
)
|
||||||
|
if isinstance(params['username'], list):
|
||||||
|
# the list is not empty because we wont generate a ticket with a user_field
|
||||||
|
# that evaluate to False
|
||||||
|
params['username'] = params['username'][0]
|
||||||
if self.pgt_url and (
|
if self.pgt_url and (
|
||||||
self.pgt_url.startswith("https://") or
|
self.pgt_url.startswith("https://") or
|
||||||
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
|
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
|
||||||
|
@ -762,9 +779,12 @@ class ValidateService(View, AttributesMixin):
|
||||||
params,
|
params,
|
||||||
content_type="text/xml; charset=utf-8"
|
content_type="text/xml; charset=utf-8"
|
||||||
)
|
)
|
||||||
except requests.exceptions.SSLError as error:
|
except requests.exceptions.RequestException as error:
|
||||||
error = utils.unpack_nested_exception(error)
|
error = utils.unpack_nested_exception(error)
|
||||||
raise ValidateError('INVALID_PROXY_CALLBACK', str(error))
|
raise ValidateError(
|
||||||
|
'INVALID_PROXY_CALLBACK',
|
||||||
|
"%s: %s" % (type(error), str(error))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValidateError(
|
raise ValidateError(
|
||||||
'INVALID_PROXY_CALLBACK',
|
'INVALID_PROXY_CALLBACK',
|
||||||
|
@ -844,7 +864,7 @@ class Proxy(View):
|
||||||
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
|
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
|
||||||
raise ValidateError(
|
raise ValidateError(
|
||||||
'UNAUTHORIZED_USER',
|
'UNAUTHORIZED_USER',
|
||||||
'%s not allowed on %s' % (ticket.user, self.target_service)
|
'User %s not allowed on %s' % (ticket.user.username, self.target_service)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -903,11 +923,15 @@ class SamlValidate(View, AttributesMixin):
|
||||||
'username': self.ticket.user.username,
|
'username': self.ticket.user.username,
|
||||||
'attributes': attributes
|
'attributes': attributes
|
||||||
}
|
}
|
||||||
if self.ticket.service_pattern.user_field and \
|
if (self.ticket.service_pattern.user_field and
|
||||||
self.ticket.user.attributs.get(self.ticket.service_pattern.user_field):
|
self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
|
||||||
params['username'] = self.ticket.user.attributs.get(
|
params['username'] = self.ticket.user.attributs.get(
|
||||||
self.ticket.service_pattern.user_field
|
self.ticket.service_pattern.user_field
|
||||||
)
|
)
|
||||||
|
if isinstance(params['username'], list):
|
||||||
|
# the list is not empty because we wont generate a ticket with a user_field
|
||||||
|
# that evaluate to False
|
||||||
|
params['username'] = params['username'][0]
|
||||||
logger.info(
|
logger.info(
|
||||||
"SamlValidate: ticket %s validated for user %s on service %s." % (
|
"SamlValidate: ticket %s validated for user %s on service %s." % (
|
||||||
self.ticket.value,
|
self.ticket.value,
|
||||||
|
|
5
pytest.ini
Normal file
5
pytest.ini
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
[pytest]
|
||||||
|
testpaths = cas_server/tests/
|
||||||
|
DJANGO_SETTINGS_MODULE = cas_server.tests.settings
|
||||||
|
norecursedirs = .* build dist docs
|
||||||
|
python_paths = .
|
|
@ -1,10 +1,12 @@
|
||||||
tox==1.8.1
|
setuptools>=5.5
|
||||||
pytest==2.6.4
|
tox>=1.8.1
|
||||||
pytest-django==2.7.0
|
pytest>=2.6.4
|
||||||
pytest-pythonpath==0.3
|
pytest-django>=2.8.0
|
||||||
|
pytest-pythonpath>=0.3
|
||||||
|
pytest-cov>=2.2.1
|
||||||
requests>=2.4
|
requests>=2.4
|
||||||
django-picklefield>=0.3.1
|
|
||||||
requests_futures>=0.9.5
|
requests_futures>=0.9.5
|
||||||
|
django-picklefield>=0.3.1
|
||||||
django-bootstrap3>=5.4
|
django-bootstrap3>=5.4
|
||||||
lxml>=3.4
|
lxml>=3.4
|
||||||
six>=1
|
six>=1
|
||||||
|
|
22
run_tests
22
run_tests
|
@ -1,22 +0,0 @@
|
||||||
#!/usr/bin/env python
|
|
||||||
import os, sys
|
|
||||||
import django
|
|
||||||
from django.conf import settings
|
|
||||||
|
|
||||||
import settings_tests
|
|
||||||
|
|
||||||
settings.configure(**settings_tests.__dict__)
|
|
||||||
django.setup()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Django <= 1.8
|
|
||||||
from django.test.simple import DjangoTestSuiteRunner
|
|
||||||
test_runner = DjangoTestSuiteRunner(verbosity=1)
|
|
||||||
except ImportError:
|
|
||||||
# Django >= 1.8
|
|
||||||
from django.test.runner import DiscoverRunner
|
|
||||||
test_runner = DiscoverRunner(verbosity=1)
|
|
||||||
|
|
||||||
failures = test_runner.run_tests(['cas_server'])
|
|
||||||
if failures:
|
|
||||||
sys.exit(failures)
|
|
3
setup.py
3
setup.py
|
@ -34,7 +34,8 @@ setup(
|
||||||
version='0.4.4',
|
version='0.4.4',
|
||||||
packages=[
|
packages=[
|
||||||
'cas_server', 'cas_server.migrations',
|
'cas_server', 'cas_server.migrations',
|
||||||
'cas_server.management', 'cas_server.management.commands'
|
'cas_server.management', 'cas_server.management.commands',
|
||||||
|
'cas_server.tests'
|
||||||
],
|
],
|
||||||
include_package_data=True,
|
include_package_data=True,
|
||||||
license='GPLv3',
|
license='GPLv3',
|
||||||
|
|
14
tox.ini
14
tox.ini
|
@ -1,12 +1,12 @@
|
||||||
[tox]
|
[tox]
|
||||||
envlist=
|
envlist=
|
||||||
|
flake8,
|
||||||
py27-django17,
|
py27-django17,
|
||||||
py27-django18,
|
py27-django18,
|
||||||
py27-django19,
|
py27-django19,
|
||||||
py34-django17,
|
py34-django17,
|
||||||
py34-django18,
|
py34-django18,
|
||||||
py34-django19,
|
py34-django19,
|
||||||
flake8,
|
|
||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
max-line-length=100
|
max-line-length=100
|
||||||
|
@ -17,7 +17,7 @@ deps =
|
||||||
-r{toxinidir}/requirements-dev.txt
|
-r{toxinidir}/requirements-dev.txt
|
||||||
|
|
||||||
[testenv]
|
[testenv]
|
||||||
commands=python run_tests {posargs:tests}
|
commands=py.test {posargs:cas_server/tests/}
|
||||||
|
|
||||||
[testenv:py27-django17]
|
[testenv:py27-django17]
|
||||||
basepython=python2.7
|
basepython=python2.7
|
||||||
|
@ -60,3 +60,13 @@ basepython=python
|
||||||
deps=flake8
|
deps=flake8
|
||||||
commands=flake8 {toxinidir}/cas_server
|
commands=flake8 {toxinidir}/cas_server
|
||||||
|
|
||||||
|
[testenv:coverage]
|
||||||
|
basepython=python
|
||||||
|
passenv=CODACY_PROJECT_TOKEN
|
||||||
|
deps=
|
||||||
|
-r{toxinidir}/requirements-dev.txt
|
||||||
|
codacy-coverage
|
||||||
|
commands=
|
||||||
|
py.test --cov=cas_server --cov-report xml
|
||||||
|
python-codacy-coverage -r {toxinidir}/coverage.xml
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue