Merge branch 'master' into federate

This commit is contained in:
Valentin Samir 2016-07-01 16:42:31 +02:00
commit 63f5b2cabf
30 changed files with 2788 additions and 1178 deletions

View file

@ -1,3 +1,11 @@
[run]
branch = True
source = cas_server
omit =
cas_server/migrations*
cas_server/management/*
cas_server/tests/*
[report] [report]
exclude_lines = exclude_lines =
pragma: no cover pragma: no cover

View file

@ -2,19 +2,20 @@ language: python
python: python:
- "2.7" - "2.7"
env: env:
global:
- PIP_DOWNLOAD_CACHE=$HOME/.pip_cache
matrix: matrix:
- TOX_ENV=coverage
- TOX_ENV=flake8
- TOX_ENV=check_rst
- TOX_ENV=py27-django17 - TOX_ENV=py27-django17
- TOX_ENV=py27-django18 - TOX_ENV=py27-django18
- TOX_ENV=py27-django19 - TOX_ENV=py27-django19
- TOX_ENV=py34-django17 - TOX_ENV=py34-django17
- TOX_ENV=py34-django18 - TOX_ENV=py34-django18
- TOX_ENV=py34-django19 - TOX_ENV=py34-django19
- TOX_ENV=flake8
cache: cache:
directories: directories:
- $HOME/.pip-cache/ - $HOME/.cache/pip/
- $HOME/build/nitmir/django-cas-server/.tox/
install: install:
- "travis_retry pip install setuptools --upgrade" - "travis_retry pip install setuptools --upgrade"
- "pip install tox" - "pip install tox"
@ -22,4 +23,3 @@ script:
- tox -e $TOX_ENV - tox -e $TOX_ENV
after_script: after_script:
- cat .tox/$TOX_ENV/log/*.log - cat .tox/$TOX_ENV/log/*.log

View file

@ -1,11 +1,15 @@
.PHONY: clean build install dist test_venv test_project .PHONY: build dist
VERSION=`python setup.py -V` VERSION=`python setup.py -V`
build: build:
python setup.py build python setup.py build
install: install: dist
python setup.py install pip -V
pip install --no-deps --upgrade --force-reinstall --find-links ./dist/django-cas-server-${VERSION}.tar.gz django-cas-server
uninstall:
pip uninstall django-cas-server || true
clean_pyc: clean_pyc:
find ./ -name '*.pyc' -delete find ./ -name '*.pyc' -delete
@ -16,18 +20,23 @@ clean_tox:
rm -rf .tox rm -rf .tox
clean_test_venv: clean_test_venv:
rm -rf test_venv rm -rf test_venv
clean: clean_pyc clean_build clean_coverage:
clean_all: clean_pyc clean_build clean_tox clean_test_venv rm -rf coverage.xml .coverage htmlcov
clean_tild_backup:
find ./ -name '*~' -delete
clean: clean_pyc clean_build clean_coverage clean_tild_backup
clean_all: clean clean_tox clean_test_venv
dist: dist:
python setup.py sdist python setup.py sdist
test_venv: test_venv/bin/python:
mkdir -p test_venv
virtualenv test_venv virtualenv test_venv
test_venv/bin/pip install -U --requirement requirements.txt test_venv/bin/pip install -U --requirement requirements-dev.txt Django
test_venv/cas/manage.py: test_venv/cas/manage.py: test_venv
mkdir -p test_venv/cas mkdir -p test_venv/cas
test_venv/bin/django-admin startproject cas test_venv/cas test_venv/bin/django-admin startproject cas test_venv/cas
ln -s ../../cas_server test_venv/cas/cas_server ln -s ../../cas_server test_venv/cas/cas_server
@ -38,20 +47,16 @@ test_venv/cas/manage.py:
test_venv/bin/python test_venv/cas/manage.py migrate test_venv/bin/python test_venv/cas/manage.py migrate
test_venv/bin/python test_venv/cas/manage.py createsuperuser test_venv/bin/python test_venv/cas/manage.py createsuperuser
test_project: test_venv test_venv/cas/manage.py test_venv: test_venv/bin/python
test_project: test_venv/cas/manage.py
@echo "##############################################################" @echo "##############################################################"
@echo "A test django project was created in $(realpath test_venv/cas)" @echo "A test django project was created in $(realpath test_venv/cas)"
run_test_server: test_project run_server: test_project
test_venv/bin/python test_venv/cas/manage.py runserver test_venv/bin/python test_venv/cas/manage.py runserver
coverage: test_venv run_tests: test_venv
test_venv/bin/pip install coverage python setup.py check --restructuredtext --stric
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests test_venv/bin/py.test --cov=cas_server --cov-report html
test_venv/bin/coverage html
rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts
coverage_codacy: coverage
test_venv/bin/coverage xml
test_venv/bin/pip install codacy-coverage
test_venv/bin/python-codacy-coverage -r coverage.xml

View file

@ -22,11 +22,11 @@ CAS Server is a Django application implementing the `CAS Protocol 3.0 Specificat
By defaut, the authentication process use django internal users but you can easily By defaut, the authentication process use django internal users but you can easily
use any sources (see auth classes in the auth.py file) use any sources (see auth classes in the auth.py file)
The defaut login/logout template use `django-bootstrap3 <https://github.com/dyve/django-bootstrap3>`_ The defaut login/logout template use `django-bootstrap3 <https://github.com/dyve/django-bootstrap3>`__
but you can use your own templates using settings variables. but you can use your own templates using settings variables.
Note that for Django 1.7 compatibility, you need a version of Note that for Django 1.7 compatibility, you need a version of
`django-bootstrap3 <https://github.com/dyve/django-bootstrap3>`_ < 7.0.0 `django-bootstrap3 <https://github.com/dyve/django-bootstrap3>`__ < 7.0.0
like the 6.2.2 version. like the 6.2.2 version.
Features Features
@ -43,7 +43,7 @@ Features
Quick start Quick start
----------- -----------
0. If you want to make a virtualenv for ``django-cas-server``, you will need the following 1. If you want to make a virtualenv for ``django-cas-server``, you will need the following
dependencies on a bare debian like system:: dependencies on a bare debian like system::
virtualenv build-essential python-dev libxml2-dev libxslt1-dev zlib1g-dev virtualenv build-essential python-dev libxml2-dev libxslt1-dev zlib1g-dev
@ -53,7 +53,7 @@ Quick start
If you intend to run the tox tests you will also need ``python3.4-dev`` depending of the current If you intend to run the tox tests you will also need ``python3.4-dev`` depending of the current
version of python3 on your system. version of python3 on your system.
1. Add "cas_server" to your INSTALLED_APPS setting like this:: 2. Add "cas_server" to your INSTALLED_APPS setting like this::
INSTALLED_APPS = ( INSTALLED_APPS = (
'django.contrib.admin', 'django.contrib.admin',
@ -71,7 +71,7 @@ Quick start
... ...
) )
2. Include the cas_server URLconf in your project urls.py like this:: 3. Include the cas_server URLconf in your project urls.py like this::
urlpatterns = [ urlpatterns = [
url(r'^admin/', admin.site.urls), url(r'^admin/', admin.site.urls),
@ -79,10 +79,10 @@ Quick start
url(r'^cas/', include('cas_server.urls', namespace="cas_server")), url(r'^cas/', include('cas_server.urls', namespace="cas_server")),
] ]
3. Run `python manage.py migrate` to create the cas_server models. 4. Run `python manage.py migrate` to create the cas_server models.
4. You should add some management commands to a crontab: ``clearsessions``, 5. You should add some management commands to a crontab: ``clearsessions``,
``cas_clean_tickets`` and ``cas_clean_sessions``. ``cas_clean_tickets`` and ``cas_clean_sessions``.
* ``clearsessions``: please see `Clearing the session store <https://docs.djangoproject.com/en/stable/topics/http/sessions/#clearing-the-session-store>`_. * ``clearsessions``: please see `Clearing the session store <https://docs.djangoproject.com/en/stable/topics/http/sessions/#clearing-the-session-store>`_.
@ -102,11 +102,11 @@ Quick start
*/5 * * * * cas-user /path/to/project/manage.py cas_clean_tickets */5 * * * * cas-user /path/to/project/manage.py cas_clean_tickets
5 0 * * * cas-user /path/to/project/manage.py cas_clean_sessions 5 0 * * * cas-user /path/to/project/manage.py cas_clean_sessions
5. Start the development server and visit http://127.0.0.1:8000/admin/ 6. Start the development server and visit http://127.0.0.1:8000/admin/
to add a first service allowed to authenticate user agains the CAS to add a first service allowed to authenticate user agains the CAS
(you'll need the Admin app enabled). (you'll need the Admin app enabled).
6. Visit http://127.0.0.1:8000/cas/ to login with your django users. 7. Visit http://127.0.0.1:8000/cas/ to login with your django users.

View file

@ -7,6 +7,6 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""A django CAS server application"""
default_app_config = 'cas_server.apps.CasAppConfig' default_app_config = 'cas_server.apps.CasAppConfig'

View file

@ -7,7 +7,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""module for the admin interface of the app""" """module for the admin interface of the app"""
from django.contrib import admin from django.contrib import admin
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern

View file

@ -1,7 +1,19 @@
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2015-2016 Valentin Samir
"""django config module"""
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.apps import AppConfig from django.apps import AppConfig
class CasAppConfig(AppConfig): class CasAppConfig(AppConfig):
"""django CAS application config class"""
name = 'cas_server' name = 'cas_server'
verbose_name = _('Central Authentication Service') verbose_name = _('Central Authentication Service')

View file

@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""Some authentication classes for the CAS""" """Some authentication classes for the CAS"""
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
@ -26,6 +26,7 @@ from .models import FederatedUser
class AuthUser(object): class AuthUser(object):
"""Authentication base class"""
def __init__(self, username): def __init__(self, username):
self.username = username self.username = username

View file

@ -90,6 +90,8 @@ setting_default(
} }
) )
setting_default('CAS_ENABLE_AJAX_AUTH', False)
setting_default('CAS_FEDERATE', False) setting_default('CAS_FEDERATE', False)
# A dict of "provider suffix" -> (provider CAS server url, CAS version, verbose name) # A dict of "provider suffix" -> (provider CAS server url, CAS version, verbose name)
setting_default('CAS_FEDERATE_PROVIDERS', {}) setting_default('CAS_FEDERATE_PROVIDERS', {})

View file

@ -7,7 +7,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""forms for the app""" """forms for the app"""
from .default_settings import settings from .default_settings import settings
@ -19,6 +19,7 @@ import cas_server.models as models
class WarnForm(forms.Form): class WarnForm(forms.Form):
"""Form used on warn page before emiting a ticket"""
service = forms.CharField(widget=forms.HiddenInput(), required=False) service = forms.CharField(widget=forms.HiddenInput(), required=False)
renew = forms.BooleanField(widget=forms.HiddenInput(), required=False) renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
gateway = forms.CharField(widget=forms.HiddenInput(), required=False) gateway = forms.CharField(widget=forms.HiddenInput(), required=False)

View file

@ -1,3 +1,4 @@
"""Clean deleted sessions management command"""
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -5,6 +6,7 @@ from ... import models
class Command(BaseCommand): class Command(BaseCommand):
"""Clean deleted sessions"""
args = '' args = ''
help = _(u"Clean deleted sessions") help = _(u"Clean deleted sessions")

View file

@ -1,3 +1,4 @@
"""Clean old trickets management command"""
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -5,6 +6,7 @@ from ... import models
class Command(BaseCommand): class Command(BaseCommand):
"""Clean old trickets"""
args = '' args = ''
help = _(u"Clean old trickets") help = _(u"Clean old trickets")

View file

@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""models for the app""" """models for the app"""
from .default_settings import settings from .default_settings import settings
@ -20,7 +20,6 @@ from django.utils import timezone
from picklefield.fields import PickledObjectField from picklefield.fields import PickledObjectField
import re import re
import os
import sys import sys
import logging import logging
from importlib import import_module from importlib import import_module
@ -79,6 +78,7 @@ class User(models.Model):
@classmethod @classmethod
def clean_old_entries(cls): def clean_old_entries(cls):
"""Remove users inactive since more that SESSION_COOKIE_AGE"""
users = cls.objects.filter( users = cls.objects.filter(
date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)) date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE))
) )
@ -88,6 +88,7 @@ class User(models.Model):
@classmethod @classmethod
def clean_deleted_sessions(cls): def clean_deleted_sessions(cls):
"""Remove user where the session do not exists anymore"""
for user in cls.objects.all(): for user in cls.objects.all():
if not SessionStore(session_key=user.session_key).get('authenticated'): if not SessionStore(session_key=user.session_key).get('authenticated'):
user.logout() user.logout()
@ -112,10 +113,10 @@ class User(models.Model):
for ticket_class in ticket_classes: for ticket_class in ticket_classes:
queryset = ticket_class.objects.filter(user=self) queryset = ticket_class.objects.filter(user=self)
for ticket in queryset: for ticket in queryset:
ticket.logout(request, session, async_list) ticket.logout(session, async_list)
queryset.delete() queryset.delete()
for future in async_list: for future in async_list:
if future: if future: # pragma: no branch (should always be true)
try: try:
future.result() future.result()
except Exception as error: except Exception as error:
@ -143,12 +144,20 @@ class User(models.Model):
(a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all() (a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
) )
replacements = dict( replacements = dict(
(a.name, (a.pattern, a.replace)) for a in service_pattern.replacements.all() (a.attribut, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
) )
service_attributs = {} service_attributs = {}
for (key, value) in self.attributs.items(): for (key, value) in self.attributs.items():
if key in attributs or '*' in attributs: if key in attributs or '*' in attributs:
if key in replacements: if key in replacements:
if isinstance(value, list):
for index, subval in enumerate(value):
value[index] = re.sub(
replacements[key][0],
replacements[key][1],
subval
)
else:
value = re.sub(replacements[key][0], replacements[key][1], value) value = re.sub(replacements[key][0], replacements[key][1], value)
service_attributs[attributs.get(key, key)] = value service_attributs[attributs.get(key, key)] = value
ticket = ticket_class.objects.create( ticket = ticket_class.objects.create(
@ -173,6 +182,7 @@ class User(models.Model):
class ServicePatternException(Exception): class ServicePatternException(Exception):
"""Base exception of exceptions raised in the ServicePattern model"""
pass pass
@ -426,7 +436,6 @@ class Ticket(models.Model):
).delete() ).delete()
# sending SLO to timed-out validated tickets # sending SLO to timed-out validated tickets
if cls.TIMEOUT and cls.TIMEOUT > 0:
async_list = [] async_list = []
session = FuturesSession( session = FuturesSession(
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS) executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
@ -435,36 +444,35 @@ class Ticket(models.Model):
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
) )
for ticket in queryset: for ticket in queryset:
ticket.logout(None, session, async_list) ticket.logout(session, async_list)
queryset.delete() queryset.delete()
for future in async_list: for future in async_list:
if future: if future: # pragma: no branch (should always be true)
try: try:
future.result() future.result()
except Exception as error: except Exception as error:
logger.warning("Error durring SLO %s" % error) logger.warning("Error durring SLO %s" % error)
sys.stderr.write("%r\n" % error) sys.stderr.write("%r\n" % error)
def logout(self, request, session, async_list=None): def logout(self, session, async_list=None):
"""Send a SLO request to the ticket service""" """Send a SLO request to the ticket service"""
# On logout invalidate the Ticket # On logout invalidate the Ticket
self.validate = True self.validate = True
self.save() self.save()
if self.validate and self.single_log_out: if self.validate and self.single_log_out: # pragma: no branch (should always be true)
logger.info( logger.info(
"Sending SLO requests to service %s for user %s" % ( "Sending SLO requests to service %s for user %s" % (
self.service, self.service,
self.user.username self.user.username
) )
) )
try:
xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s"> ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID> <saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex> <samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
</samlp:LogoutRequest>""" % \ </samlp:LogoutRequest>""" % \
{ {
'id': os.urandom(20).encode("hex"), 'id': utils.gen_saml_id(),
'datetime': timezone.now().isoformat(), 'datetime': timezone.now().isoformat(),
'ticket': self.value 'ticket': self.value
} }
@ -479,24 +487,6 @@ class Ticket(models.Model):
timeout=settings.CAS_SLO_TIMEOUT timeout=settings.CAS_SLO_TIMEOUT
) )
) )
except Exception as error:
error = utils.unpack_nested_exception(error)
logger.warning(
"Error durring SLO for user %s on service %s: %s" % (
self.user.username,
self.service,
error
)
)
if request is not None:
messages.add_message(
request,
messages.WARNING,
_(u'Error during service logout %(service)s:\n%(error)s') %
{'service': self.service, 'error': error}
)
else:
sys.stderr.write("%r\n" % error)
class ServiceTicket(Ticket): class ServiceTicket(Ticket):

View file

@ -1,964 +0,0 @@
from .default_settings import settings
import django
from django.test import TestCase
from django.test import Client
import re
import six
import random
from lxml import etree
from six.moves import range
from cas_server import models
from cas_server import utils
def copy_form(form):
"""Copy form value into a dict"""
params = {}
for field in form:
if field.value():
params[field.name] = field.value()
else:
params[field.name] = ""
return params
def get_login_page_params(client=None):
"""Return a client and the POST params for the client to login"""
if client is None:
client = Client()
response = client.get('/login')
params = copy_form(response.context["form"])
return client, params
def get_auth_client(**update):
"""return a authenticated client"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params.update(update)
client.post('/login', params)
return client
def get_user_ticket_request(service):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client = get_auth_client()
response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
return (user, ticket)
def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)
client = Client()
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
params = utils.PGTUrlHandler.PARAMS.copy()
params["service"] = service
params["user"] = user
return params
class CheckPasswordCase(TestCase):
"""Tests for the utils function `utils.check_password`"""
def setUp(self):
"""Generate random bytes string that will be used ass passwords"""
self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes):
self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8")
def test_setup(self):
"""check that generated password are bytes"""
self.assertIsInstance(self.password1, bytes)
self.assertIsInstance(self.password2, bytes)
def test_plain(self):
"""test the plain auth method"""
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
def test_crypt(self):
"""test the crypt auth method"""
if six.PY3:
hashed_password1 = utils.crypt.crypt(
self.password1.decode("utf8"),
"$6$UVVAQvrMyXMF3FF3"
).encode("utf8")
else:
hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
def test_ldap_ssha(self):
"""test the ldap auth method with a {SSHA} scheme"""
salt = b"UVVAQvrMyXMF3FF3"
hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
self.assertIsInstance(hashed_password1, bytes)
self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
def test_hex_md5(self):
"""test the hex_md5 auth method"""
hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
def test_hex_sha512(self):
"""test the hex_sha512 auth method"""
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
self.assertTrue(
utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
)
self.assertFalse(
utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
)
class LoginTestCase(TestCase):
"""Tests for the login view"""
def setUp(self):
"""
Prepare the test context:
* set the auth class to 'cas_server.auth.TestAuthUser'
* create a service pattern for https://www.example.com/**
* Set the service pattern to return all user attributes
"""
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
# For general purpose testing
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$",
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
# For testing the restrict_users attributes
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
name="restrict_user_fail",
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
restrict_users=True,
)
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
name="restrict_user_success",
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
restrict_users=True,
)
models.Username.objects.create(
value=settings.CAS_TEST_USER,
service_pattern=self.service_pattern_restrict_user_success
)
# For testing the user attributes filtering conditions
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
name="filter_fail",
pattern="^https://filter_fail\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="right",
pattern="^admin$",
service_pattern=self.service_pattern_filter_fail
)
self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success",
pattern="^https://filter_success\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="email",
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
service_pattern=self.service_pattern_filter_success
)
# For testing the user_field attributes
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
name="field_needed_fail",
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
user_field="uid"
)
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success",
pattern="^https://field_needed_success\.example\.com(/.*)?$",
user_field="nom"
)
def assert_logged(self, client, response, warn=False, code=200):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client.session["warn"] is warn)
self.assertTrue(client.session["authenticated"] is True)
self.assertTrue(
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
)
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
def test_login_view_post_goodpass_goodlt(self):
"""Test a successul login"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
self.assertTrue(params['lt'] in client.session['lt'])
response = client.post('/login', params)
self.assert_logged(client, response)
# LoginTicket conssumed
self.assertTrue(params['lt'] not in client.session['lt'])
def test_login_view_post_goodpass_goodlt_warn(self):
"""Test a successul login requesting to be warned before creating services tickets"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["warn"] = "on"
response = client.post('/login', params)
self.assert_logged(client, response, warn=True)
def test_lt_max(self):
"""Check we only keep the last 100 Login Ticket for a user"""
client, params = get_login_page_params()
current_lt = params["lt"]
i_in_test = random.randint(0, 100)
i_not_in_test = random.randint(100, 150)
for i in range(150):
if i == i_in_test:
self.assertTrue(current_lt in client.session['lt'])
if i == i_not_in_test:
self.assertTrue(current_lt not in client.session['lt'])
self.assertTrue(len(client.session['lt']) <= 100)
client, params = get_login_page_params(client)
self.assertTrue(len(client.session['lt']) <= 100)
def test_login_view_post_badlt(self):
"""Login attempt with a bad LoginTicket"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["lt"] = 'LT-random'
response = client.post('/login', params)
self.assert_login_failed(client, response)
self.assertTrue(b"Invalid login ticket" in response.content)
def test_login_view_post_badpass_good_lt(self):
"""Login attempt with a bad password"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = "test2"
response = client.post('/login', params)
self.assert_login_failed(client, response)
self.assertTrue(
(
b"The credentials you provided cannot be "
b"determined to be authentic"
) in response.content
)
def assert_ticket_attributes(self, client, ticket_value):
"""check the ticket attributes in the db"""
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
self.assertTrue(user)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
self.assertEqual(ticket.user, user)
self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
self.assertEqual(ticket.validate, False)
self.assertEqual(ticket.service_pattern, self.service_pattern)
def assert_service_ticket(self, client, response):
"""check that a ticket is well emited when requested on a allowed service"""
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header('Location'))
self.assertTrue(
response['Location'].startswith(
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
)
)
ticket_value = response['Location'].split('ticket=')[-1]
self.assert_ticket_attributes(client, ticket_value)
def test_view_login_get_allowed_service(self):
"""Request a ticket for an allowed service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication required by service "
b"example (https://www.example.com)"
) in response.content
)
def test_view_login_get_denied_service(self):
"""Request a ticket for an denied service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.net")
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.net non allowed" in response.content)
def test_view_login_get_auth_allowed_service(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client()
response = client.get("/login?service=https://www.example.com")
self.assert_service_ticket(client, response)
def test_view_login_get_auth_allowed_service_warn(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client(warn="on")
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication has been required by service "
b"example (https://www.example.com)"
) in response.content
)
params = copy_form(response.context["form"])
response = client.post("/login", params)
self.assert_service_ticket(client, response)
def test_view_login_get_auth_denied_service(self):
"""Request a ticket for a not allowed service by an authenticated client"""
client = get_auth_client()
response = client.get("/login?service=https://www.example.org")
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
def test_user_logged_not_in_db(self):
"""If the user is logged but has been delete from the database, it should be logged out"""
client = get_auth_client()
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).delete()
response = client.get("/login")
self.assert_login_failed(client, response, code=302)
if django.VERSION < (1, 9):
self.assertEqual(response["Location"], "http://testserver/login")
else:
self.assertEqual(response["Location"], "/login?")
def test_service_restrict_user(self):
"""Testing the restric user capability fro a service"""
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Username non allowed" in response.content)
service = "https://restrict_user_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_filter(self):
"""Test the filtering on user attributes"""
service = "https://filter_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"User charateristics non allowed" in response.content)
service = "https://filter_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_user_field(self):
"""Test using a user attribute as username: case on if the attribute exists or not"""
service = "https://field_needed_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"The attribut uid is needed to use that service" in response.content)
service = "https://field_needed_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_gateway(self):
"""test gateway parameter"""
# First with an authenticated client that fail to get a ticket for a service
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
# second for an user not yet authenticated on a valid service
client = Client()
response = client.get('/login', {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
def test_renew(self):
service = "https://www.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service, 'renew': 'on'})
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication renewal required by "
b"service example (https://www.example.com)"
) in response.content
)
params = copy_form(response.context["form"])
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
self.assertEqual(params["renew"], True)
response = client.post("/login", params)
self.assertEqual(response.status_code, 302)
ticket_value = response['Location'].split('ticket=')[-1]
ticket = models.ServiceTicket.objects.get(value=ticket_value)
self.assertEqual(ticket.renew, True)
class LogoutTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
def test_logout_view(self):
client = get_auth_client()
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
response = client.get("/logout")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged out from "
b"the Central Authentication Service"
) in response.content
)
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_view_url(self):
client = get_auth_client()
response = client.get('/logout?url=https://www.example.com')
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header("Location"))
self.assertEqual(response["Location"], "https://www.example.com")
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_view_service(self):
client = get_auth_client()
response = client.get('/logout?service=https://www.example.com')
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header("Location"))
self.assertEqual(response["Location"], "https://www.example.com")
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
class AuthTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'https://www.example.com'
models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
def test_auth_view_goodpass(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\n')
def test_auth_view_badpass(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': 'badpass',
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badservice(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': 'https://www.example.org',
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badsecret(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'badsecret'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badsettings(self):
settings.CAS_AUTH_SHARED_SECRET = None
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
class ValidateTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'https://www.example.com'
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_view_ok(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\ntest\n')
def test_validate_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/validate',
{'ticket': ticket.value, 'service': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_validate_view_badticket(self):
get_user_ticket_request(self.service)
client = Client()
response = client.get(
'/validate',
{'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
class ValidateServiceTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_service_view_ok(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
def test_validate_service_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
bad_service = "https://www.example.org"
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE")
self.assertEqual(error[0].text, bad_service)
def test_validate_service_view_badticket_goodprefix(self):
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(error[0].text, 'ticket not found')
def test_validate_service_view_badticket_badprefix(self):
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "RANDOM"
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(error[0].text, bad_ticket)
def test_validate_service_view_ok_pgturl(self):
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
)
pgt_params = utils.PGTUrlHandler.PARAMS.copy()
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
pgtiou = root.xpath(
"//cas:proxyGrantingTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(pgtiou), 1)
self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
self.assertTrue("pgtId" in pgt_params)
def test_validate_service_pgturl_bad_proxy_callback(self):
self.service_pattern.proxy_callback = False
self.service_pattern.save()
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
self.assertEqual(error[0].text, "callback url not allowed by configuration")
class ProxyTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'http://127.0.0.1'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy=True,
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_proxy_ok(self):
params = get_pgt()
# get a proxy ticket
client1 = Client()
response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertTrue(sucess)
proxy_ticket = root.xpath(
"//cas:proxyTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(proxy_ticket), 1)
proxy_ticket = proxy_ticket[0].text
# validate the proxy ticket
client2 = Client()
response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
# check that the proxy is send to the end service
proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxies), 1)
proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxy), 1)
self.assertEqual(proxy[0].text, params["service"])
# same tests than those for serviceValidate
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
def test_validate_proxy_bad(self):
params = get_pgt()
# bad PGT
client1 = Client()
response = client1.get(
'/proxy',
{
'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
'targetService': params['service']
}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(
error[0].text,
"PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
)
# bad targetService
client2 = Client()
response = client2.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
self.assertEqual(error[0].text, "https://www.example.org")
# service do not allow proxy ticket
self.service_pattern.proxy = False
self.service_pattern.save()
client3 = Client()
response = client3.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': params['service']}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
self.assertEqual(
error[0].text,
'the service %s do not allow proxy ticket' % params['service']
)

View file

193
cas_server/tests/mixin.py Normal file
View file

@ -0,0 +1,193 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Some mixin classes for tests"""
from cas_server.default_settings import settings
from django.utils import timezone
import re
from lxml import etree
from datetime import timedelta
from cas_server import models
from cas_server.tests.utils import get_auth_client
class BaseServicePattern(object):
"""Mixing for setting up service pattern for testing"""
def setup_service_patterns(self, proxy=False):
"""setting up service pattern"""
# For general purpose testing
self.service = "https://www.example.com"
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$",
proxy=proxy,
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
# For testing the restrict_users attributes
self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
name="restrict_user_fail",
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
restrict_users=True,
proxy=proxy,
)
self.service_restrict_user_success = "https://restrict_user_success.example.com"
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
name="restrict_user_success",
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
restrict_users=True,
proxy=proxy,
)
models.Username.objects.create(
value=settings.CAS_TEST_USER,
service_pattern=self.service_pattern_restrict_user_success
)
# For testing the user attributes filtering conditions
self.service_filter_fail = "https://filter_fail.example.com"
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
name="filter_fail",
pattern="^https://filter_fail\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="right",
pattern="^admin$",
service_pattern=self.service_pattern_filter_fail
)
self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
name="filter_fail_alt",
pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="nom",
pattern="^toto$",
service_pattern=self.service_pattern_filter_fail_alt
)
self.service_filter_success = "https://filter_success.example.com"
self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success",
pattern="^https://filter_success\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="email",
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
service_pattern=self.service_pattern_filter_success
)
# For testing the user_field attributes
self.service_field_needed_fail = "https://field_needed_fail.example.com"
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
name="field_needed_fail",
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
user_field="uid",
proxy=proxy,
)
self.service_field_needed_success = "https://field_needed_success.example.com"
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success",
pattern="^https://field_needed_success\.example\.com(/.*)?$",
user_field="alias",
proxy=proxy,
)
self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success_alt",
pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
user_field="nom",
proxy=proxy,
)
class XmlContent(object):
"""Mixin for test on CAS XML responses"""
def assert_error(self, response, code, text=None):
"""Assert a validation error"""
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], code)
if text is not None:
self.assertEqual(error[0].text, text)
def assert_success(self, response, username, original_attributes):
"""assert a ticket validation success"""
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, username)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in original_attributes.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
return root
class UserModels(object):
"""Mixin for test on CAS user models"""
@staticmethod
def expire_user():
"""return an expired user"""
client = get_auth_client()
new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
models.User.objects.filter(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).update(date=new_date)
return client
@staticmethod
def get_user(client):
"""return the user associated with an authenticated client"""
return models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)

View file

@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
'django.middleware.locale.LocaleMiddleware', 'django.middleware.locale.LocaleMiddleware',
] ]
ROOT_URLCONF = 'urls_tests' ROOT_URLCONF = 'cas_server.tests.urls'
# Database # Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases # https://docs.djangoproject.com/en/1.9/ref/settings/#databases
@ -60,6 +60,7 @@ ROOT_URLCONF = 'urls_tests'
DATABASES = { DATABASES = {
'default': { 'default': {
'ENGINE': 'django.db.backends.sqlite3', 'ENGINE': 'django.db.backends.sqlite3',
'NAME': ':memory:',
} }
} }

View file

@ -0,0 +1,166 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Tests module for models"""
from cas_server.default_settings import settings
from django.test import TestCase
from django.test.utils import override_settings
from django.utils import timezone
from datetime import timedelta
from importlib import import_module
from cas_server import models
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class UserTestCase(TestCase, UserModels):
"""tests for the user models"""
def setUp(self):
"""Prepare the test context"""
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_clean_old_entries(self):
"""test clean_old_entries"""
# get an authenticated client
client = self.expire_user()
# assert the user exists before being cleaned
self.assertEqual(len(models.User.objects.all()), 1)
# assert the last activity date is before the expiry date
self.assertTrue(
self.get_user(client).date < (
timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
)
)
# delete old inactive users
models.User.clean_old_entries()
# assert the user has being well delete
self.assertEqual(len(models.User.objects.all()), 0)
def test_clean_deleted_sessions(self):
"""test clean_deleted_sessions"""
# get an authenticated client
client1 = get_auth_client()
client2 = get_auth_client()
# generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
# on self.service)
ticket = self.get_user(client1).get_ticket(
models.ServiceTicket,
self.service,
self.service_pattern,
renew=False
)
ticket.validate = True
ticket.save()
# simulated expired session being garbage collected for client1
session = SessionStore(session_key=client1.session.session_key)
session.flush()
# assert the user exists before being cleaned
self.assertTrue(self.get_user(client1))
self.assertTrue(self.get_user(client2))
self.assertEqual(len(models.User.objects.all()), 2)
# session has being remove so the user of client1 is no longer authenticated
self.assertFalse(client1.session.get("authenticated"))
# the user a client2 should still be authenticated
self.assertTrue(client2.session.get("authenticated"))
# the user should be deleted
models.User.clean_deleted_sessions()
# assert the user with expired sessions has being well deleted but the other remain
self.assertEqual(len(models.User.objects.all()), 1)
self.assertFalse(models.ServiceTicket.objects.all())
self.assertTrue(client2.session.get("authenticated"))
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class TicketTestCase(TestCase, UserModels, BaseServicePattern):
"""tests for the tickets models"""
def setUp(self):
"""Prepare the test context"""
self.setup_service_patterns()
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
@staticmethod
def get_ticket(
user,
ticket_class,
service,
service_pattern,
renew=False,
validate=False,
validity_expired=False,
timeout_expired=False,
single_log_out=False,
):
"""Return a ticket"""
ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
ticket.validate = validate
ticket.single_log_out = single_log_out
if validity_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
)
if timeout_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
)
ticket.save()
return ticket
def test_clean_old_service_ticket(self):
"""test tickets clean_old_entries"""
# ge an authenticated client
client = get_auth_client()
# get the user associated to the client
user = self.get_user(client)
# generate a ticket for that client, waiting for validation
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
# generate another ticket for those validation time has expired
self.get_ticket(
user, models.ServiceTicket,
self.service, self.service_pattern, validity_expired=True
)
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
# generate a ticket with SLO having timeout reach
self.get_ticket(
user, models.ServiceTicket,
service, self.service_pattern, timeout_expired=True,
validate=True, single_log_out=True
)
# there should be 3 tickets in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
# we call the clean_old_entries method that should delete validated non SLO ticket and
# expired non validated ticket and send SLO for SLO expired ticket before deleting then
models.ServiceTicket.clean_old_entries()
params = httpd.PARAMS
# we successfully got a SLO request
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
# only 1 ticket remain in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)

View file

@ -0,0 +1,191 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Tests module for utils"""
from django.test import TestCase
import six
from cas_server import utils
class CheckPasswordCase(TestCase):
"""Tests for the utils function `utils.check_password`"""
def setUp(self):
"""Generate random bytes string that will be used ass passwords"""
self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8")
def test_setup(self):
"""check that generated password are bytes"""
self.assertIsInstance(self.password1, bytes)
self.assertIsInstance(self.password2, bytes)
def test_plain(self):
"""test the plain auth method"""
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
def test_plain_unicode(self):
"""test the plain auth method with unicode input"""
self.assertTrue(
utils.check_password(
"plain",
self.password1.decode("utf8"),
self.password1.decode("utf8"),
"utf8"
)
)
self.assertFalse(
utils.check_password(
"plain",
self.password1.decode("utf8"),
self.password2.decode("utf8"),
"utf8"
)
)
def test_crypt(self):
"""test the crypt auth method"""
salts = ["$6$UVVAQvrMyXMF3FF3", "aa"]
hashed_password1 = []
for salt in salts:
if six.PY3:
hashed_password1.append(
utils.crypt.crypt(
self.password1.decode("utf8"),
salt
).encode("utf8")
)
else:
hashed_password1.append(utils.crypt.crypt(self.password1, salt))
for hp1 in hashed_password1:
self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8"))
self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8"))
with self.assertRaises(ValueError):
utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8")
def test_ldap_password_valid(self):
"""test the ldap auth method with all the schemes"""
salt = b"UVVAQvrMyXMF3FF3"
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
hashed_password1 = []
for scheme in schemes_salt:
hashed_password1.append(
utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8")
)
for scheme in schemes_nosalt:
hashed_password1.append(
utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8")
)
hashed_password1.append(
utils.LdapHashUserPassword.hash(
b"{CRYPT}",
self.password1,
b"$6$UVVAQvrMyXMF3FF3",
charset="utf8"
)
)
for hp1 in hashed_password1:
self.assertIsInstance(hp1, bytes)
self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8"))
self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8"))
def test_ldap_password_fail(self):
"""test the ldap auth method with malformed hash or bad schemes"""
salt = b"UVVAQvrMyXMF3FF3"
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
# first try to hash with bad parameters
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
utils.LdapHashUserPassword.hash(b"TOTO", self.password1)
for scheme in schemes_nosalt:
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
utils.LdapHashUserPassword.hash(scheme, self.password1, salt)
for scheme in schemes_salt:
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
utils.LdapHashUserPassword.hash(scheme, self.password1)
with self.assertRaises(utils.LdapHashUserPassword.BadSalt):
utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, b"$truc$toto")
# then try to check hash with bad hashes
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8")
for scheme in schemes_salt:
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8")
def test_hex(self):
"""test all the hex_HASH method: the hashed password is a simple hash of the password"""
hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]
hashed_password1 = []
for hash in hashes:
hashed_password1.append(
("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest())
)
for (method, hp1) in hashed_password1:
self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8"))
self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8"))
def test_bad_method(self):
"""try to check password with a bad method, should raise a ValueError"""
with self.assertRaises(ValueError):
utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8")
class UtilsTestCase(TestCase):
"""tests for some little utils functions"""
def test_import_attr(self):
"""
test the import_attr function. Feeded with a dotted path string, it should
import the dotted module and return that last componend of the dotted path
(function, class or variable)
"""
with self.assertRaises(ImportError):
utils.import_attr('toto.titi.tutu')
with self.assertRaises(AttributeError):
utils.import_attr('cas_server.utils.toto')
with self.assertRaises(ValueError):
utils.import_attr('toto')
self.assertEqual(
utils.import_attr('cas_server.default_app_config'),
'cas_server.apps.CasAppConfig'
)
self.assertEqual(utils.import_attr(utils), utils)
def test_update_url(self):
"""
test the update_url function. Given an url with possible GET parameter and a dict
the function build a url with GET parameters updated by the dictionnary
"""
url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"})
url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"})
self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1")
self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1")
url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"})
self.assertEqual(url3, u"https://www.example.com?toto=2")
def test_crypt_salt_is_valid(self):
"""test the function crypt_salt_is_valid who test if a crypt salt is valid"""
self.assertFalse(utils.crypt_salt_is_valid("")) # len 0
self.assertFalse(utils.crypt_salt_is_valid("a")) # len 1
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known

File diff suppressed because it is too large Load diff

180
cas_server/tests/utils.py Normal file
View file

@ -0,0 +1,180 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Some utils functions for tests"""
from cas_server.default_settings import settings
from django.test import Client
import cgi
from threading import Thread
from lxml import etree
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, parse_qsl
from cas_server import models
def copy_form(form):
"""Copy form value into a dict"""
params = {}
for field in form:
if field.value():
params[field.name] = field.value()
else:
params[field.name] = ""
return params
def get_login_page_params(client=None):
"""Return a client and the POST params for the client to login"""
if client is None:
client = Client()
response = client.get('/login')
params = copy_form(response.context["form"])
return client, params
def get_auth_client(**update):
"""return a authenticated client"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params.update(update)
client.post('/login', params)
assert client.session.get("authenticated")
return client
def get_user_ticket_request(service):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client = get_auth_client()
response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
return (user, ticket, client)
def get_validated_ticket(service):
"""Return a tick that has being already validated. Used to test SLO"""
(ticket, auth_client) = get_user_ticket_request(service)[1:3]
client = Client()
response = client.get('/validate', {'ticket': ticket.value, 'service': service})
assert (response.status_code == 200)
assert (response.content == b'yes\ntest\n')
ticket = models.ServiceTicket.objects.get(value=ticket.value)
return (auth_client, ticket)
def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)[:2]
client = Client()
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
params = httpd.PARAMS
params["service"] = service
params["user"] = user
return params
def get_proxy_ticket(service):
"""Return a ProxyTicket waiting for validation"""
params = get_pgt()
# get a proxy ticket
client = Client()
response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
root = etree.fromstring(response.content)
proxy_ticket = root.xpath(
"//cas:proxyTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
proxy_ticket = proxy_ticket[0].text
ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
return ticket
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple http server that return 200 on GET or POST
and store GET or POST parameters. Used in unit tests
"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(200)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
url = urlparse(self.path)
params = dict(parse_qsl(url.query))
self.server.PARAMS = params
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
if ctype == 'multipart/form-data':
postvars = cgi.parse_multipart(self.rfile, pdict)
elif ctype == 'application/x-www-form-urlencoded':
length = int(self.headers.get('content-length'))
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
else:
postvars = {}
self.server.PARAMS = postvars
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
class Http404Handler(HttpParamsHandler):
"""A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"error 404 not found")
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()

View file

@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""urls for the app""" """urls for the app"""
from django.conf.urls import patterns, url from django.conf.urls import patterns, url
from django.views.generic import RedirectView from django.views.generic import RedirectView

View file

@ -1,4 +1,4 @@
# *- coding: utf-8 -*- # -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT # This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for # FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""Some util function for the app""" """Some util function for the app"""
from .default_settings import settings from .default_settings import settings
@ -23,19 +23,20 @@ import hashlib
import crypt import crypt
import base64 import base64
import six import six
from threading import Thread
from importlib import import_module from importlib import import_module
from datetime import datetime, timedelta from datetime import datetime, timedelta
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
def context(params): def context(params):
"""Function that add somes variable to the context before template rendering"""
params["settings"] = settings params["settings"] = settings
return params return params
def json_response(request, data): def json_response(request, data):
"""Wrapper dumping `data` to a json and sending it to the user with an HttpResponse"""
data["messages"] = [] data["messages"] = []
for msg in messages.get_messages(request): for msg in messages.get_messages(request):
data["messages"].append({'message': msg.message, 'level': msg.level_tag}) data["messages"].append({'message': msg.message, 'level': msg.level_tag})
@ -65,6 +66,7 @@ def redirect_params(url_name, params=None):
def reverse_params(url_name, params=None, **kwargs): def reverse_params(url_name, params=None, **kwargs):
"""compule the reverse url or `url_name` and add GET parameters from `params` to it"""
url = reverse(url_name, **kwargs) url = reverse(url_name, **kwargs)
params = urlencode(params if params else {}) params = urlencode(params if params else {})
if params: if params:
@ -124,10 +126,13 @@ def update_url(url, params):
url_parts = list(urlparse(url)) url_parts = list(urlparse(url))
query = dict(parse_qsl(url_parts[4])) query = dict(parse_qsl(url_parts[4]))
query.update(params) query.update(params)
url_parts[4] = urlencode(query) # make the params order deterministic
for i, url_part in enumerate(url_parts): query = list(query.items())
if not isinstance(url_part, bytes): query.sort()
url_parts[i] = url_part.encode('utf-8') url_query = urlencode(query)
if not isinstance(url_query, bytes): # pragma: no cover in python3 urlencode return an unicode
url_query = url_query.encode("utf-8")
url_parts[4] = url_query
return urlunparse(url_parts).decode('utf-8') return urlunparse(url_parts).decode('utf-8')
@ -197,35 +202,25 @@ def get_tuple(nuplet, index, default=None):
return default return default
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): def crypt_salt_is_valid(salt):
PARAMS = {} """Return True is salt is valid has a crypt salt, False otherwise"""
if len(salt) < 2:
def do_GET(self): return False
self.send_response(200) else:
self.send_header(b"Content-type", "text/plain") if salt[0] == '$':
self.end_headers() if salt[1] == '$':
self.wfile.write(b"ok") return False
url = urlparse(self.path) else:
params = dict(parse_qsl(url.query)) if '$' not in salt[1:]:
PGTUrlHandler.PARAMS.update(params) return False
else:
def log_message(self, *args): hashed = crypt.crypt("", salt)
return if not hashed or '$' not in hashed[1:]:
return False
@staticmethod else:
def run(): return True
server_class = BaseHTTPServer.HTTPServer else:
httpd = server_class(("127.0.0.1", 0), PGTUrlHandler) return True
(host, port) = httpd.socket.getsockname()
def lauch():
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd_thread, host, port)
class LdapHashUserPassword(object): class LdapHashUserPassword(object):
@ -318,7 +313,7 @@ class LdapHashUserPassword(object):
if salt is None or salt == b"": if salt is None or salt == b"":
salt = b"" salt = b""
cls._test_scheme_nosalt(scheme) cls._test_scheme_nosalt(scheme)
elif salt is not None: else:
cls._test_scheme_salt(scheme) cls._test_scheme_salt(scheme)
try: try:
return scheme + base64.b64encode( return scheme + base64.b64encode(
@ -328,9 +323,9 @@ class LdapHashUserPassword(object):
if six.PY3: if six.PY3:
password = password.decode(charset) password = password.decode(charset)
salt = salt.decode(charset) salt = salt.decode(charset)
hashed_password = crypt.crypt(password, salt) if not crypt_salt_is_valid(salt):
if hashed_password is None:
raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt) raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
hashed_password = crypt.crypt(password, salt)
if six.PY3: if six.PY3:
hashed_password = hashed_password.encode(charset) hashed_password = hashed_password.encode(charset)
return scheme + hashed_password return scheme + hashed_password
@ -352,7 +347,7 @@ class LdapHashUserPassword(object):
if scheme in cls.schemes_nosalt: if scheme in cls.schemes_nosalt:
return b"" return b""
elif scheme == b'{CRYPT}': elif scheme == b'{CRYPT}':
return b'$'.join(hashed_passord.split(b'$', 3)[:-1]) return b'$'.join(hashed_passord.split(b'$', 3)[:-1])[len(scheme):]
else: else:
hashed_passord = base64.b64decode(hashed_passord[len(scheme):]) hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
if len(hashed_passord) < cls._schemes_to_len[scheme]: if len(hashed_passord) < cls._schemes_to_len[scheme]:
@ -374,7 +369,7 @@ def check_password(method, password, hashed_password, charset):
elif method == "crypt": elif method == "crypt":
if hashed_password.startswith(b'$'): if hashed_password.startswith(b'$'):
salt = b'$'.join(hashed_password.split(b'$', 3)[:-1]) salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
elif hashed_password.startswith(b'_'): elif hashed_password.startswith(b'_'): # pragma: no cover old BSD format not supported
salt = hashed_password[:9] salt = hashed_password[:9]
else: else:
salt = hashed_password[:2] salt = hashed_password[:2]
@ -382,9 +377,9 @@ def check_password(method, password, hashed_password, charset):
password = password.decode(charset) password = password.decode(charset)
salt = salt.decode(charset) salt = salt.decode(charset)
hashed_password = hashed_password.decode(charset) hashed_password = hashed_password.decode(charset)
crypted_password = crypt.crypt(password, salt) if not crypt_salt_is_valid(salt):
if crypted_password is None:
raise ValueError("System crypt implementation do not support the salt %r" % salt) raise ValueError("System crypt implementation do not support the salt %r" % salt)
crypted_password = crypt.crypt(password, salt)
return crypted_password == hashed_password return crypted_password == hashed_password
elif method == "ldap": elif method == "ldap":
scheme = LdapHashUserPassword.get_scheme(hashed_password) scheme = LdapHashUserPassword.get_scheme(hashed_password)

View file

@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""views for the app""" """views for the app"""
from .default_settings import settings from .default_settings import settings
@ -115,7 +115,7 @@ class LogoutView(View, LogoutMixin):
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.url = request.GET.get('url') self.url = request.GET.get('url')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""methode called on GET request on this view""" """methode called on GET request on this view"""
@ -299,7 +299,7 @@ class LoginView(View, LogoutMixin):
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
self.gateway = request.POST.get('gateway') self.gateway = request.POST.get('gateway')
self.method = request.POST.get('method') self.method = request.POST.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
if request.POST.get('warned') and request.POST['warned'] != "False": if request.POST.get('warned') and request.POST['warned'] != "False":
self.warned = True self.warned = True
self.warn = request.POST.get('warn') self.warn = request.POST.get('warn')
@ -401,7 +401,7 @@ class LoginView(View, LogoutMixin):
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
self.gateway = request.GET.get('gateway') self.gateway = request.GET.get('gateway')
self.method = request.GET.get('method') self.method = request.GET.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
self.warn = request.GET.get('warn') self.warn = request.GET.get('warn')
if settings.CAS_FEDERATE: if settings.CAS_FEDERATE:
self.username = request.session.get("federate_username") self.username = request.session.get("federate_username")
@ -753,12 +753,9 @@ class Validate(View):
ticket.service_pattern.user_field ticket.service_pattern.user_field
) )
if isinstance(username, list): if isinstance(username, list):
try: # the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
username = username[0] username = username[0]
except IndexError:
username = None
if not username:
username = ""
else: else:
username = ticket.user.username username = ticket.user.username
return HttpResponse("yes\n%s\n" % username, content_type="text/plain") return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
@ -834,6 +831,10 @@ class ValidateService(View, AttributesMixin):
params['username'] = self.ticket.user.attributs.get( params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field self.ticket.service_pattern.user_field
) )
if isinstance(params['username'], list):
# the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
params['username'] = params['username'][0]
if self.pgt_url and ( if self.pgt_url and (
self.pgt_url.startswith("https://") or self.pgt_url.startswith("https://") or
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url) re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
@ -935,9 +936,12 @@ class ValidateService(View, AttributesMixin):
params, params,
content_type="text/xml; charset=utf-8" content_type="text/xml; charset=utf-8"
) )
except requests.exceptions.SSLError as error: except requests.exceptions.RequestException as error:
error = utils.unpack_nested_exception(error) error = utils.unpack_nested_exception(error)
raise ValidateError('INVALID_PROXY_CALLBACK', str(error)) raise ValidateError(
'INVALID_PROXY_CALLBACK',
"%s: %s" % (type(error), str(error))
)
else: else:
raise ValidateError( raise ValidateError(
'INVALID_PROXY_CALLBACK', 'INVALID_PROXY_CALLBACK',
@ -1017,7 +1021,7 @@ class Proxy(View):
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined): except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
raise ValidateError( raise ValidateError(
'UNAUTHORIZED_USER', 'UNAUTHORIZED_USER',
'%s not allowed on %s' % (ticket.user, self.target_service) 'User %s not allowed on %s' % (ticket.user.username, self.target_service)
) )
@ -1076,11 +1080,15 @@ class SamlValidate(View, AttributesMixin):
'username': self.ticket.user.username, 'username': self.ticket.user.username,
'attributes': attributes 'attributes': attributes
} }
if self.ticket.service_pattern.user_field and \ if (self.ticket.service_pattern.user_field and
self.ticket.user.attributs.get(self.ticket.service_pattern.user_field): self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
params['username'] = self.ticket.user.attributs.get( params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field self.ticket.service_pattern.user_field
) )
if isinstance(params['username'], list):
# the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
params['username'] = params['username'][0]
logger.info( logger.info(
"SamlValidate: ticket %s validated for user %s on service %s." % ( "SamlValidate: ticket %s validated for user %s on service %s." % (
self.ticket.value, self.ticket.value,

5
pytest.ini Normal file
View file

@ -0,0 +1,5 @@
[pytest]
testpaths = cas_server/tests/
DJANGO_SETTINGS_MODULE = cas_server.tests.settings
norecursedirs = .* build dist docs
python_paths = .

View file

@ -1,10 +1,12 @@
tox==1.8.1 setuptools>=5.5
pytest==2.6.4 tox>=1.8.1
pytest-django==2.7.0 pytest>=2.6.4
pytest-pythonpath==0.3 pytest-django>=2.8.0
pytest-pythonpath>=0.3
pytest-cov>=2.2.1
requests>=2.4 requests>=2.4
django-picklefield>=0.3.1
requests_futures>=0.9.5 requests_futures>=0.9.5
django-picklefield>=0.3.1
django-bootstrap3>=5.4 django-bootstrap3>=5.4
lxml>=3.4 lxml>=3.4
six>=1 six>=1

View file

@ -1,22 +0,0 @@
#!/usr/bin/env python
import os, sys
import django
from django.conf import settings
import settings_tests
settings.configure(**settings_tests.__dict__)
django.setup()
try:
# Django <= 1.8
from django.test.simple import DjangoTestSuiteRunner
test_runner = DjangoTestSuiteRunner(verbosity=1)
except ImportError:
# Django >= 1.8
from django.test.runner import DiscoverRunner
test_runner = DiscoverRunner(verbosity=1)
failures = test_runner.run_tests(['cas_server'])
if failures:
sys.exit(failures)

View file

@ -34,7 +34,8 @@ setup(
version='0.5.0', version='0.5.0',
packages=[ packages=[
'cas_server', 'cas_server.migrations', 'cas_server', 'cas_server.migrations',
'cas_server.management', 'cas_server.management.commands' 'cas_server.management', 'cas_server.management.commands',
'cas_server.tests'
], ],
include_package_data=True, include_package_data=True,
license='GPLv3', license='GPLv3',

22
tox.ini
View file

@ -1,12 +1,13 @@
[tox] [tox]
envlist= envlist=
flake8,
check_rst,
py27-django17, py27-django17,
py27-django18, py27-django18,
py27-django19, py27-django19,
py34-django17, py34-django17,
py34-django18, py34-django18,
py34-django19, py34-django19,
flake8,
[flake8] [flake8]
max-line-length=100 max-line-length=100
@ -17,7 +18,7 @@ deps =
-r{toxinidir}/requirements-dev.txt -r{toxinidir}/requirements-dev.txt
[testenv] [testenv]
commands=python run_tests {posargs:tests} commands=py.test {posargs:cas_server/tests/}
[testenv:py27-django17] [testenv:py27-django17]
basepython=python2.7 basepython=python2.7
@ -60,3 +61,20 @@ basepython=python
deps=flake8 deps=flake8
commands=flake8 {toxinidir}/cas_server commands=flake8 {toxinidir}/cas_server
[testenv:check_rst]
basepython=python
deps=
docutils
Pygments
commands=rst2html.py --strict {toxinidir}/README.rst /dev/null
[testenv:coverage]
basepython=python
passenv=CODACY_PROJECT_TOKEN
deps=
-r{toxinidir}/requirements-dev.txt
codacy-coverage
commands=
py.test --cov=cas_server --cov-report xml
python-codacy-coverage -r {toxinidir}/coverage.xml