More unit tests (essentially for the login view) and some docstrings
This commit is contained in:
parent
7db3157864
commit
bab79c4de5
8 changed files with 343 additions and 63 deletions
|
@ -5,3 +5,4 @@ exclude_lines =
|
||||||
def __unicode__
|
def __unicode__
|
||||||
raise AssertionError
|
raise AssertionError
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
if six.PY3:
|
||||||
|
|
3
Makefile
3
Makefile
|
@ -49,8 +49,9 @@ coverage: test_venv
|
||||||
test_venv/bin/pip install coverage
|
test_venv/bin/pip install coverage
|
||||||
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
|
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
|
||||||
test_venv/bin/coverage html
|
test_venv/bin/coverage html
|
||||||
test_venv/bin/coverage xml
|
rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts
|
||||||
|
|
||||||
coverage_codacy: coverage
|
coverage_codacy: coverage
|
||||||
|
test_venv/bin/coverage xml
|
||||||
test_venv/bin/pip install codacy-coverage
|
test_venv/bin/pip install codacy-coverage
|
||||||
test_venv/bin/python-codacy-coverage -r coverage.xml
|
test_venv/bin/python-codacy-coverage -r coverage.xml
|
||||||
|
|
|
@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
|
||||||
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
|
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
|
||||||
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
|
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
|
||||||
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
|
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
|
||||||
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``.
|
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
|
||||||
|
'alias': ['demo1', 'demo2']}``.
|
||||||
|
|
||||||
|
|
||||||
Authentication backend
|
Authentication backend
|
||||||
|
|
|
@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test')
|
||||||
setting_default('CAS_TEST_PASSWORD', 'test')
|
setting_default('CAS_TEST_PASSWORD', 'test')
|
||||||
setting_default(
|
setting_default(
|
||||||
'CAS_TEST_ATTRIBUTES',
|
'CAS_TEST_ATTRIBUTES',
|
||||||
{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}
|
{
|
||||||
|
'nom': 'Nymous',
|
||||||
|
'prenom': 'Ano',
|
||||||
|
'email': 'anonymous@example.net',
|
||||||
|
'alias': ['demo1', 'demo2']
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,36 +3,49 @@ from .default_settings import settings
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test import Client
|
from django.test import Client
|
||||||
|
|
||||||
|
import re
|
||||||
import six
|
import six
|
||||||
|
import random
|
||||||
from lxml import etree
|
from lxml import etree
|
||||||
|
from six.moves import range
|
||||||
|
|
||||||
from cas_server import models
|
from cas_server import models
|
||||||
from cas_server import utils
|
from cas_server import utils
|
||||||
|
|
||||||
|
|
||||||
def get_login_page_params():
|
def copy_form(form):
|
||||||
client = Client()
|
"""Copy form value into a dict"""
|
||||||
response = client.get('/login')
|
|
||||||
form = response.context["form"]
|
|
||||||
params = {}
|
params = {}
|
||||||
for field in form:
|
for field in form:
|
||||||
if field.value():
|
if field.value():
|
||||||
params[field.name] = field.value()
|
params[field.name] = field.value()
|
||||||
else:
|
else:
|
||||||
params[field.name] = ""
|
params[field.name] = ""
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def get_login_page_params(client=None):
|
||||||
|
"""Return a client and the POST params for the client to login"""
|
||||||
|
if client is None:
|
||||||
|
client = Client()
|
||||||
|
response = client.get('/login')
|
||||||
|
params = copy_form(response.context["form"])
|
||||||
return client, params
|
return client, params
|
||||||
|
|
||||||
|
|
||||||
def get_auth_client():
|
def get_auth_client(**update):
|
||||||
|
"""return a authenticated client"""
|
||||||
client, params = get_login_page_params()
|
client, params = get_login_page_params()
|
||||||
params["username"] = settings.CAS_TEST_USER
|
params["username"] = settings.CAS_TEST_USER
|
||||||
params["password"] = settings.CAS_TEST_PASSWORD
|
params["password"] = settings.CAS_TEST_PASSWORD
|
||||||
|
params.update(update)
|
||||||
|
|
||||||
client.post('/login', params)
|
client.post('/login', params)
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
def get_user_ticket_request(service):
|
def get_user_ticket_request(service):
|
||||||
|
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
|
||||||
client = get_auth_client()
|
client = get_auth_client()
|
||||||
response = client.get("/login", {"service": service})
|
response = client.get("/login", {"service": service})
|
||||||
ticket_value = response['Location'].split('ticket=')[-1]
|
ticket_value = response['Location'].split('ticket=')[-1]
|
||||||
|
@ -45,6 +58,7 @@ def get_user_ticket_request(service):
|
||||||
|
|
||||||
|
|
||||||
def get_pgt():
|
def get_pgt():
|
||||||
|
"""return a dict contening a service, user and PGT ticket for this service"""
|
||||||
(host, port) = utils.PGTUrlHandler.run()[1:3]
|
(host, port) = utils.PGTUrlHandler.run()[1:3]
|
||||||
service = "http://%s:%s" % (host, port)
|
service = "http://%s:%s" % (host, port)
|
||||||
|
|
||||||
|
@ -110,7 +124,7 @@ class CheckPasswordCase(TestCase):
|
||||||
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
|
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
|
||||||
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
|
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
|
||||||
|
|
||||||
def test_hox_sha512(self):
|
def test_hex_sha512(self):
|
||||||
"""test the hex_sha512 auth method"""
|
"""test the hex_sha512 auth method"""
|
||||||
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
|
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
|
||||||
|
|
||||||
|
@ -123,29 +137,83 @@ class CheckPasswordCase(TestCase):
|
||||||
|
|
||||||
|
|
||||||
class LoginTestCase(TestCase):
|
class LoginTestCase(TestCase):
|
||||||
|
"""Tests for the login view"""
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
Prepare the test context:
|
||||||
|
* set the auth class to 'cas_server.auth.TestAuthUser'
|
||||||
|
* create a service pattern for https://www.example.com/**
|
||||||
|
* Set the service pattern to return all user attributes
|
||||||
|
"""
|
||||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
||||||
|
|
||||||
|
# For general purpose testing
|
||||||
self.service_pattern = models.ServicePattern.objects.create(
|
self.service_pattern = models.ServicePattern.objects.create(
|
||||||
name="example",
|
name="example",
|
||||||
pattern="^https://www\.example\.com(/.*)?$",
|
pattern="^https://www\.example\.com(/.*)?$",
|
||||||
)
|
)
|
||||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||||
|
|
||||||
def test_login_view_post_goodpass_goodlt(self):
|
# For testing the restrict_users attributes
|
||||||
client, params = get_login_page_params()
|
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
|
||||||
params["username"] = settings.CAS_TEST_USER
|
name="restrict_user_fail",
|
||||||
params["password"] = settings.CAS_TEST_PASSWORD
|
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
|
||||||
|
restrict_users=True,
|
||||||
|
)
|
||||||
|
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
|
||||||
|
name="restrict_user_success",
|
||||||
|
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
|
||||||
|
restrict_users=True,
|
||||||
|
)
|
||||||
|
models.Username.objects.create(
|
||||||
|
value=settings.CAS_TEST_USER,
|
||||||
|
service_pattern=self.service_pattern_restrict_user_success
|
||||||
|
)
|
||||||
|
|
||||||
response = client.post('/login', params)
|
# For testing the user attributes filtering conditions
|
||||||
|
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
|
||||||
|
name="filter_fail",
|
||||||
|
pattern="^https://filter_fail\.example\.com(/.*)?$",
|
||||||
|
)
|
||||||
|
models.FilterAttributValue.objects.create(
|
||||||
|
attribut="right",
|
||||||
|
pattern="^admin$",
|
||||||
|
service_pattern=self.service_pattern_filter_fail
|
||||||
|
)
|
||||||
|
self.service_pattern_filter_success = models.ServicePattern.objects.create(
|
||||||
|
name="filter_success",
|
||||||
|
pattern="^https://filter_success\.example\.com(/.*)?$",
|
||||||
|
)
|
||||||
|
models.FilterAttributValue.objects.create(
|
||||||
|
attribut="email",
|
||||||
|
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
|
||||||
|
service_pattern=self.service_pattern_filter_success
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(response.status_code, 200)
|
# For testing the user_field attributes
|
||||||
|
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
|
||||||
|
name="field_needed_fail",
|
||||||
|
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
|
||||||
|
user_field="uid"
|
||||||
|
)
|
||||||
|
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
|
||||||
|
name="field_needed_success",
|
||||||
|
pattern="^https://field_needed_success\.example\.com(/.*)?$",
|
||||||
|
user_field="nom"
|
||||||
|
)
|
||||||
|
|
||||||
|
def assert_logged(self, client, response, warn=False, code=200):
|
||||||
|
"""Assertions testing that client is well authenticated"""
|
||||||
|
self.assertEqual(response.status_code, code)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
b"You have successfully logged into "
|
b"You have successfully logged into "
|
||||||
b"the Central Authentication Service"
|
b"the Central Authentication Service"
|
||||||
) in response.content
|
) in response.content
|
||||||
)
|
)
|
||||||
|
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
|
||||||
|
self.assertTrue(client.session["warn"] is warn)
|
||||||
|
self.assertTrue(client.session["authenticated"] is True)
|
||||||
|
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
models.User.objects.get(
|
models.User.objects.get(
|
||||||
|
@ -154,7 +222,59 @@ class LoginTestCase(TestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def assert_login_failed(self, client, response, code=200):
|
||||||
|
"""Assertions testing a failed login attempt"""
|
||||||
|
self.assertEqual(response.status_code, code)
|
||||||
|
self.assertFalse(
|
||||||
|
(
|
||||||
|
b"You have successfully logged into "
|
||||||
|
b"the Central Authentication Service"
|
||||||
|
) in response.content
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(client.session.get("username") is None)
|
||||||
|
self.assertTrue(client.session.get("warn") is None)
|
||||||
|
self.assertTrue(client.session.get("authenticated") is None)
|
||||||
|
|
||||||
|
def test_login_view_post_goodpass_goodlt(self):
|
||||||
|
"""Test a successul login"""
|
||||||
|
client, params = get_login_page_params()
|
||||||
|
params["username"] = settings.CAS_TEST_USER
|
||||||
|
params["password"] = settings.CAS_TEST_PASSWORD
|
||||||
|
self.assertTrue(params['lt'] in client.session['lt'])
|
||||||
|
|
||||||
|
response = client.post('/login', params)
|
||||||
|
self.assert_logged(client, response)
|
||||||
|
# LoginTicket conssumed
|
||||||
|
self.assertTrue(params['lt'] not in client.session['lt'])
|
||||||
|
|
||||||
|
def test_login_view_post_goodpass_goodlt_warn(self):
|
||||||
|
"""Test a successul login requesting to be warned before creating services tickets"""
|
||||||
|
client, params = get_login_page_params()
|
||||||
|
params["username"] = settings.CAS_TEST_USER
|
||||||
|
params["password"] = settings.CAS_TEST_PASSWORD
|
||||||
|
params["warn"] = "on"
|
||||||
|
|
||||||
|
response = client.post('/login', params)
|
||||||
|
self.assert_logged(client, response, warn=True)
|
||||||
|
|
||||||
|
def test_lt_max(self):
|
||||||
|
"""Check we only keep the last 100 Login Ticket for a user"""
|
||||||
|
client, params = get_login_page_params()
|
||||||
|
current_lt = params["lt"]
|
||||||
|
i_in_test = random.randint(0, 100)
|
||||||
|
i_not_in_test = random.randint(100, 150)
|
||||||
|
for i in range(150):
|
||||||
|
if i == i_in_test:
|
||||||
|
self.assertTrue(current_lt in client.session['lt'])
|
||||||
|
if i == i_not_in_test:
|
||||||
|
self.assertTrue(current_lt not in client.session['lt'])
|
||||||
|
self.assertTrue(len(client.session['lt']) <= 100)
|
||||||
|
client, params = get_login_page_params(client)
|
||||||
|
self.assertTrue(len(client.session['lt']) <= 100)
|
||||||
|
|
||||||
def test_login_view_post_badlt(self):
|
def test_login_view_post_badlt(self):
|
||||||
|
"""Login attempt with a bad LoginTicket"""
|
||||||
client, params = get_login_page_params()
|
client, params = get_login_page_params()
|
||||||
params["username"] = settings.CAS_TEST_USER
|
params["username"] = settings.CAS_TEST_USER
|
||||||
params["password"] = settings.CAS_TEST_PASSWORD
|
params["password"] = settings.CAS_TEST_PASSWORD
|
||||||
|
@ -162,47 +282,26 @@ class LoginTestCase(TestCase):
|
||||||
|
|
||||||
response = client.post('/login', params)
|
response = client.post('/login', params)
|
||||||
|
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assert_login_failed(client, response)
|
||||||
self.assertTrue(b"Invalid login ticket" in response.content)
|
self.assertTrue(b"Invalid login ticket" in response.content)
|
||||||
self.assertFalse(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_login_view_post_badpass_good_lt(self):
|
def test_login_view_post_badpass_good_lt(self):
|
||||||
|
"""Login attempt with a bad password"""
|
||||||
client, params = get_login_page_params()
|
client, params = get_login_page_params()
|
||||||
params["username"] = settings.CAS_TEST_USER
|
params["username"] = settings.CAS_TEST_USER
|
||||||
params["password"] = "test2"
|
params["password"] = "test2"
|
||||||
response = client.post('/login', params)
|
response = client.post('/login', params)
|
||||||
|
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assert_login_failed(client, response)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
(
|
(
|
||||||
b"The credentials you provided cannot be "
|
b"The credentials you provided cannot be "
|
||||||
b"determined to be authentic"
|
b"determined to be authentic"
|
||||||
) in response.content
|
) in response.content
|
||||||
)
|
)
|
||||||
self.assertFalse(
|
|
||||||
(
|
|
||||||
b"You have successfully logged into "
|
|
||||||
b"the Central Authentication Service"
|
|
||||||
) in response.content
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_view_login_get_auth_allowed_service(self):
|
def assert_ticket_attributes(self, client, ticket_value):
|
||||||
client = get_auth_client()
|
"""check the ticket attributes in the db"""
|
||||||
response = client.get("/login?service=https://www.example.com")
|
|
||||||
self.assertEqual(response.status_code, 302)
|
|
||||||
self.assertTrue(response.has_header('Location'))
|
|
||||||
self.assertTrue(
|
|
||||||
response['Location'].startswith(
|
|
||||||
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
ticket_value = response['Location'].split('ticket=')[-1]
|
|
||||||
user = models.User.objects.get(
|
user = models.User.objects.get(
|
||||||
username=settings.CAS_TEST_USER,
|
username=settings.CAS_TEST_USER,
|
||||||
session_key=client.session.session_key
|
session_key=client.session.session_key
|
||||||
|
@ -214,12 +313,136 @@ class LoginTestCase(TestCase):
|
||||||
self.assertEqual(ticket.validate, False)
|
self.assertEqual(ticket.validate, False)
|
||||||
self.assertEqual(ticket.service_pattern, self.service_pattern)
|
self.assertEqual(ticket.service_pattern, self.service_pattern)
|
||||||
|
|
||||||
|
def assert_service_ticket(self, client, response):
|
||||||
|
"""check that a ticket is well emited when requested on a allowed service"""
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertTrue(response.has_header('Location'))
|
||||||
|
self.assertTrue(
|
||||||
|
response['Location'].startswith(
|
||||||
|
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
ticket_value = response['Location'].split('ticket=')[-1]
|
||||||
|
self.assert_ticket_attributes(client, ticket_value)
|
||||||
|
|
||||||
|
def test_view_login_get_allowed_service(self):
|
||||||
|
"""Request a ticket for an allowed service by an unauthenticated client"""
|
||||||
|
client = Client()
|
||||||
|
response = client.get("/login?service=https://www.example.com")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertTrue(
|
||||||
|
(
|
||||||
|
"Authentication required by service "
|
||||||
|
"example (https://www.example.com)"
|
||||||
|
) in response.content
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_view_login_get_denied_service(self):
|
||||||
|
"""Request a ticket for an denied service by an unauthenticated client"""
|
||||||
|
client = Client()
|
||||||
|
response = client.get("/login?service=https://www.example.net")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertTrue("Service https://www.example.net non allowed" in response.content)
|
||||||
|
|
||||||
|
def test_view_login_get_auth_allowed_service(self):
|
||||||
|
"""Request a ticket for an allowed service by an authenticated client"""
|
||||||
|
# client is already authenticated
|
||||||
|
client = get_auth_client()
|
||||||
|
response = client.get("/login?service=https://www.example.com")
|
||||||
|
self.assert_service_ticket(client, response)
|
||||||
|
|
||||||
|
def test_view_login_get_auth_allowed_service_warn(self):
|
||||||
|
"""Request a ticket for an allowed service by an authenticated client"""
|
||||||
|
# client is already authenticated
|
||||||
|
client = get_auth_client(warn="on")
|
||||||
|
response = client.get("/login?service=https://www.example.com")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertTrue(
|
||||||
|
(
|
||||||
|
"Authentication has been required by service "
|
||||||
|
"example (https://www.example.com)"
|
||||||
|
) in response.content
|
||||||
|
)
|
||||||
|
|
||||||
|
params = copy_form(response.context["form"])
|
||||||
|
response = client.post("/login", params)
|
||||||
|
self.assert_service_ticket(client, response)
|
||||||
|
|
||||||
def test_view_login_get_auth_denied_service(self):
|
def test_view_login_get_auth_denied_service(self):
|
||||||
|
"""Request a ticket for a not allowed service by an authenticated client"""
|
||||||
client = get_auth_client()
|
client = get_auth_client()
|
||||||
response = client.get("/login?service=https://www.example.org")
|
response = client.get("/login?service=https://www.example.org")
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
|
self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
|
||||||
|
|
||||||
|
def test_user_logged_not_in_db(self):
|
||||||
|
"""If the user is logged but has been delete from the database, it should be logged out"""
|
||||||
|
client = get_auth_client()
|
||||||
|
models.User.objects.get(
|
||||||
|
username=settings.CAS_TEST_USER,
|
||||||
|
session_key=client.session.session_key
|
||||||
|
).delete()
|
||||||
|
response = client.get("/login")
|
||||||
|
|
||||||
|
self.assert_login_failed(client, response, code=302)
|
||||||
|
self.assertEqual(response["Location"], "/login?")
|
||||||
|
|
||||||
|
def test_service_restrict_user(self):
|
||||||
|
"""Testing the restric user capability fro a service"""
|
||||||
|
service = "https://restrict_user_fail.example.com"
|
||||||
|
client = get_auth_client()
|
||||||
|
response = client.get("/login", {'service': service})
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertTrue("Username non allowed" in response.content)
|
||||||
|
|
||||||
|
service = "https://restrict_user_success.example.com"
|
||||||
|
response = client.get("/login", {'service': service})
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
|
||||||
|
|
||||||
|
def test_service_filter(self):
|
||||||
|
"""Test the filtering on user attributes"""
|
||||||
|
service = "https://filter_fail.example.com"
|
||||||
|
client = get_auth_client()
|
||||||
|
response = client.get("/login", {'service': service})
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertTrue("User charateristics non allowed" in response.content)
|
||||||
|
|
||||||
|
service = "https://filter_success.example.com"
|
||||||
|
response = client.get("/login", {'service': service})
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
|
||||||
|
|
||||||
|
def test_service_user_field(self):
|
||||||
|
"""Test using a user attribute as username: case on if the attribute exists or not"""
|
||||||
|
service = "https://field_needed_fail.example.com"
|
||||||
|
client = get_auth_client()
|
||||||
|
response = client.get("/login", {'service': service})
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
self.assertTrue("The attribut uid is needed to use that service" in response.content)
|
||||||
|
|
||||||
|
service = "https://field_needed_success.example.com"
|
||||||
|
response = client.get("/login", {'service': service})
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
|
||||||
|
|
||||||
|
def test_gateway(self):
|
||||||
|
"""test gateway parameter"""
|
||||||
|
|
||||||
|
# First with an authenticated client that fail to get a ticket for a service
|
||||||
|
service = "https://restrict_user_fail.example.com"
|
||||||
|
client = get_auth_client()
|
||||||
|
response = client.get("/login", {'service': service, 'gateway': 'on'})
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(response["Location"], service)
|
||||||
|
|
||||||
|
# second for an user not yet authenticated on a valid service
|
||||||
|
client = Client()
|
||||||
|
response = client.get('/login', {'service': service, 'gateway': 'on'})
|
||||||
|
self.assertEqual(response.status_code, 302)
|
||||||
|
self.assertEqual(response["Location"], service)
|
||||||
|
|
||||||
|
|
||||||
class LogoutTestCase(TestCase):
|
class LogoutTestCase(TestCase):
|
||||||
|
|
||||||
|
@ -454,17 +677,24 @@ class ValidateServiceTestCase(TestCase):
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||||
)
|
)
|
||||||
self.assertEqual(len(attributes), 1)
|
self.assertEqual(len(attributes), 1)
|
||||||
attrs1 = {}
|
attrs1 = set()
|
||||||
for attr in attributes[0]:
|
for attr in attributes[0]:
|
||||||
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
|
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
|
||||||
|
|
||||||
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
self.assertEqual(len(attributes), len(attrs1))
|
self.assertEqual(len(attributes), len(attrs1))
|
||||||
attrs2 = {}
|
attrs2 = set()
|
||||||
for attr in attributes:
|
for attr in attributes:
|
||||||
attrs2[attr.attrib['name']] = attr.attrib['value']
|
attrs2.add((attr.attrib['name'], attr.attrib['value']))
|
||||||
|
original = set()
|
||||||
|
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
|
||||||
|
if isinstance(value, list):
|
||||||
|
for v in value:
|
||||||
|
original.add((key, v))
|
||||||
|
else:
|
||||||
|
original.add((key, value))
|
||||||
self.assertEqual(attrs1, attrs2)
|
self.assertEqual(attrs1, attrs2)
|
||||||
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
|
self.assertEqual(attrs1, original)
|
||||||
|
|
||||||
def test_validate_service_view_badservice(self):
|
def test_validate_service_view_badservice(self):
|
||||||
ticket = get_user_ticket_request(self.service)[1]
|
ticket = get_user_ticket_request(self.service)[1]
|
||||||
|
@ -623,17 +853,24 @@ class ProxyTestCase(TestCase):
|
||||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||||
)
|
)
|
||||||
self.assertEqual(len(attributes), 1)
|
self.assertEqual(len(attributes), 1)
|
||||||
attrs1 = {}
|
attrs1 = set()
|
||||||
for attr in attributes[0]:
|
for attr in attributes[0]:
|
||||||
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
|
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
|
||||||
|
|
||||||
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||||
self.assertEqual(len(attributes), len(attrs1))
|
self.assertEqual(len(attributes), len(attrs1))
|
||||||
attrs2 = {}
|
attrs2 = set()
|
||||||
for attr in attributes:
|
for attr in attributes:
|
||||||
attrs2[attr.attrib['name']] = attr.attrib['value']
|
attrs2.add((attr.attrib['name'], attr.attrib['value']))
|
||||||
|
original = set()
|
||||||
|
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
|
||||||
|
if isinstance(value, list):
|
||||||
|
for v in value:
|
||||||
|
original.add((key, v))
|
||||||
|
else:
|
||||||
|
original.add((key, value))
|
||||||
self.assertEqual(attrs1, attrs2)
|
self.assertEqual(attrs1, attrs2)
|
||||||
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
|
self.assertEqual(attrs1, original)
|
||||||
|
|
||||||
def test_validate_proxy_bad(self):
|
def test_validate_proxy_bad(self):
|
||||||
params = get_pgt()
|
params = get_pgt()
|
||||||
|
|
|
@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin):
|
||||||
service = None
|
service = None
|
||||||
|
|
||||||
def init_get(self, request):
|
def init_get(self, request):
|
||||||
|
"""Initialize GET received parameters"""
|
||||||
self.request = request
|
self.request = request
|
||||||
self.service = request.GET.get('service')
|
self.service = request.GET.get('service')
|
||||||
self.url = request.GET.get('url')
|
self.url = request.GET.get('url')
|
||||||
|
@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin):
|
||||||
USER_NOT_AUTHENTICATED = 6
|
USER_NOT_AUTHENTICATED = 6
|
||||||
|
|
||||||
def init_post(self, request):
|
def init_post(self, request):
|
||||||
|
"""Initialize POST received parameters"""
|
||||||
self.request = request
|
self.request = request
|
||||||
self.service = request.POST.get('service')
|
self.service = request.POST.get('service')
|
||||||
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
|
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
|
||||||
|
@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin):
|
||||||
if request.POST.get('warned') and request.POST['warned'] != "False":
|
if request.POST.get('warned') and request.POST['warned'] != "False":
|
||||||
self.warned = True
|
self.warned = True
|
||||||
|
|
||||||
def check_lt(self):
|
def gen_lt(self):
|
||||||
# save LT for later check
|
"""Generate a new LoginTicket and add it to the list of valid LT for the user"""
|
||||||
lt_valid = self.request.session.get('lt', [])
|
|
||||||
lt_send = self.request.POST.get('lt')
|
|
||||||
# generate a new LT (by posting the LT has been consumed)
|
|
||||||
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
|
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
|
||||||
if len(self.request.session['lt']) > 100:
|
if len(self.request.session['lt']) > 100:
|
||||||
self.request.session['lt'] = self.request.session['lt'][-100:]
|
self.request.session['lt'] = self.request.session['lt'][-100:]
|
||||||
|
|
||||||
|
def check_lt(self):
|
||||||
|
"""Check is the POSTed LoginTicket is valid, if yes invalide it"""
|
||||||
|
# save LT for later check
|
||||||
|
lt_valid = self.request.session.get('lt', [])
|
||||||
|
lt_send = self.request.POST.get('lt')
|
||||||
|
# generate a new LT (by posting the LT has been consumed)
|
||||||
|
self.gen_lt()
|
||||||
# check if send LT is valid
|
# check if send LT is valid
|
||||||
if lt_valid is None or lt_send not in lt_valid:
|
if lt_valid is None or lt_send not in lt_valid:
|
||||||
return False
|
return False
|
||||||
|
@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin):
|
||||||
username=self.request.session['username'],
|
username=self.request.session['username'],
|
||||||
session_key=self.request.session.session_key
|
session_key=self.request.session.session_key
|
||||||
)
|
)
|
||||||
self.user.save()
|
self.user.save() # pragma: no cover (should not happend)
|
||||||
except models.User.DoesNotExist:
|
except models.User.DoesNotExist:
|
||||||
self.user = models.User.objects.create(
|
self.user = models.User.objects.create(
|
||||||
username=self.request.session['username'],
|
username=self.request.session['username'],
|
||||||
|
@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin):
|
||||||
elif ret == self.USER_ALREADY_LOGGED:
|
elif ret == self.USER_ALREADY_LOGGED:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise EnvironmentError("invalid output for LoginView.process_post")
|
raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover
|
||||||
return self.common()
|
return self.common()
|
||||||
|
|
||||||
def process_post(self):
|
def process_post(self):
|
||||||
|
"""
|
||||||
|
Analyse the POST request:
|
||||||
|
* check that the LoginTicket is valid
|
||||||
|
* check that the user sumited credentials are valid
|
||||||
|
"""
|
||||||
if not self.check_lt():
|
if not self.check_lt():
|
||||||
values = self.request.POST.copy()
|
values = self.request.POST.copy()
|
||||||
# if not set a new LT and fail
|
# if not set a new LT and fail
|
||||||
|
@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin):
|
||||||
return self.USER_ALREADY_LOGGED
|
return self.USER_ALREADY_LOGGED
|
||||||
|
|
||||||
def init_get(self, request):
|
def init_get(self, request):
|
||||||
|
"""Initialize GET received parameters"""
|
||||||
self.request = request
|
self.request = request
|
||||||
self.service = request.GET.get('service')
|
self.service = request.GET.get('service')
|
||||||
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
|
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
|
||||||
|
@ -294,15 +306,16 @@ class LoginView(View, LogoutMixin):
|
||||||
return self.common()
|
return self.common()
|
||||||
|
|
||||||
def process_get(self):
|
def process_get(self):
|
||||||
# generate a new LT if none is present
|
"""Analyse the GET request"""
|
||||||
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
|
# generate a new LT
|
||||||
|
self.gen_lt()
|
||||||
if not self.request.session.get("authenticated") or self.renew:
|
if not self.request.session.get("authenticated") or self.renew:
|
||||||
self.init_form()
|
self.init_form()
|
||||||
return self.USER_NOT_AUTHENTICATED
|
return self.USER_NOT_AUTHENTICATED
|
||||||
return self.USER_AUTHENTICATED
|
return self.USER_AUTHENTICATED
|
||||||
|
|
||||||
def init_form(self, values=None):
|
def init_form(self, values=None):
|
||||||
|
"""Initialization of the good form depending of POST and GET parameters"""
|
||||||
self.form = forms.UserCredential(
|
self.form = forms.UserCredential(
|
||||||
values,
|
values,
|
||||||
initial={
|
initial={
|
||||||
|
|
|
@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
|
||||||
'django.middleware.locale.LocaleMiddleware',
|
'django.middleware.locale.LocaleMiddleware',
|
||||||
]
|
]
|
||||||
|
|
||||||
ROOT_URLCONF = 'cas_server.urls'
|
ROOT_URLCONF = 'urls_tests'
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
|
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
|
||||||
|
|
22
urls_tests.py
Normal file
22
urls_tests.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
"""cas URL Configuration
|
||||||
|
|
||||||
|
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||||
|
https://docs.djangoproject.com/en/1.9/topics/http/urls/
|
||||||
|
Examples:
|
||||||
|
Function views
|
||||||
|
1. Add an import: from my_app import views
|
||||||
|
2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
|
||||||
|
Class-based views
|
||||||
|
1. Add an import: from other_app.views import Home
|
||||||
|
2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
|
||||||
|
Including another URLconf
|
||||||
|
1. Import the include() function: from django.conf.urls import url, include, include
|
||||||
|
2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
|
||||||
|
"""
|
||||||
|
from django.conf.urls import url, include
|
||||||
|
from django.contrib import admin
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
url(r'^admin/', admin.site.urls),
|
||||||
|
url(r'^', include('cas_server.urls', namespace='cas_server')),
|
||||||
|
]
|
Loading…
Reference in a new issue