Tests comments and move http server handlers from cas_server.utils to cas_server.tests.utils

This commit is contained in:
Valentin Samir 2016-06-30 23:13:53 +02:00
parent 3ada10b3c5
commit c7c5151acf
4 changed files with 133 additions and 86 deletions

View file

@ -9,8 +9,7 @@ from datetime import timedelta
from importlib import import_module
from cas_server import models
from cas_server import utils
from cas_server.tests.utils import get_auth_client
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
@ -125,22 +124,32 @@ class TicketTestCase(TestCase, UserModels, BaseServicePattern):
def test_clean_old_service_ticket(self):
"""test tickets clean_old_entries"""
# ge an authenticated client
client = get_auth_client()
# get the user associated to the client
user = self.get_user(client)
# generate a ticket for that client, waiting for validation
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
# generate another ticket for those validation time has expired
self.get_ticket(
user, models.ServiceTicket,
self.service, self.service_pattern, validity_expired=True
)
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
# generate a ticket with SLO having timeout reach
self.get_ticket(
user, models.ServiceTicket,
service, self.service_pattern, timeout_expired=True,
validate=True, single_log_out=True
)
# there should be 3 tickets in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
# we call the clean_old_entries method that should delete validated non SLO ticket and
# expired non validated ticket and send SLO for SLO expired ticket before deleting then
models.ServiceTicket.clean_old_entries()
params = httpd.PARAMS
# we successfully got a SLO request
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
# only 1 ticket remain in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)

View file

@ -21,7 +21,9 @@ from cas_server.tests.utils import (
get_user_ticket_request,
get_pgt,
get_proxy_ticket,
get_validated_ticket
get_validated_ticket,
HttpParamsHandler,
Http404Handler
)
from cas_server.tests.mixin import BaseServicePattern, XmlContent
@ -697,7 +699,7 @@ class LogoutTestCase(TestCase):
# test normal SLO
# setup a simple one request http server
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
(httpd, host, port) = HttpParamsHandler.run()[0:3]
# build a service url depending on which port the http server has binded
service = "http://%s:%s" % (host, port)
# get a ticket requested by client and being validated by the service
@ -709,7 +711,7 @@ class LogoutTestCase(TestCase):
# text SLO with a single_log_out_callback
# setup a simple one request http server
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
(httpd, host, port) = HttpParamsHandler.run()[0:3]
# set the default test service pattern to use the http server port for SLO requests.
# in fact, this single_log_out_callback parametter is usefull to implement SLO
# for non http service like imap or ftp
@ -1273,7 +1275,7 @@ class ValidateServiceTestCase(TestCase, XmlContent):
def test_validate_service_view_ok_pgturl(self):
"""test the retrieval of a ProxyGrantingTicket"""
# start a simple on request http server
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
(httpd, host, port) = HttpParamsHandler.run()[0:3]
# construct the service from it
service = "http://%s:%s" % (host, port)
@ -1304,7 +1306,7 @@ class ValidateServiceTestCase(TestCase, XmlContent):
def test_validate_service_pgturl_sslerror(self):
"""test the retrieval of a ProxyGrantingTicket with a SSL error on the pgtUrl"""
(host, port) = utils.HttpParamsHandler.run()[1:3]
(host, port) = HttpParamsHandler.run()[1:3]
# is fact the service listen on http and not https raisin a SSL Protocol Error
# but other SSL/TLS error should behave the same
service = "https://%s:%s" % (host, port)
@ -1329,7 +1331,7 @@ class ValidateServiceTestCase(TestCase, XmlContent):
test the retrieval on a ProxyGrantingTicket then to pgtUrl return a http error.
PGT creation should be aborted but the ticket still be valid
"""
(host, port) = utils.Http404Handler.run()[1:3]
(host, port) = Http404Handler.run()[1:3]
service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
@ -1424,8 +1426,10 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
"""tests for the proxy view"""
def setUp(self):
"""preparing test context"""
# we prepare a bunch a service url and service patterns for tests
self.setup_service_patterns(proxy=True)
# set the default service pattern to localhost to be able to retrieve PGT
self.service = 'http://127.0.0.1'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
@ -1433,6 +1437,7 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
proxy=True,
proxy_callback=True
)
# transmit all attributes
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_proxy_ok(self):
@ -1440,13 +1445,20 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
Get a PGT, get a proxy ticket, validate it. Validation should succeed and
show the proxy service URL.
"""
# we directrly get a ProxyGrantingTicket
params = get_pgt()
# get a proxy ticket
# We try get a proxy ticket with our PGT
client1 = Client()
response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
# for what we send a GET request to /proxy with ge PGT and the target service for which
# we want a ProxyTicket to.
response = client1.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': "https://www.example.com"}
)
self.assertEqual(response.status_code, 200)
# we should sucessfully reteive a PT
root = etree.fromstring(response.content)
sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertTrue(sucess)
@ -1458,16 +1470,21 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
self.assertEqual(len(proxy_ticket), 1)
proxy_ticket = proxy_ticket[0].text
# validate the proxy ticket
# validate the proxy ticket with the service for which is was emitted
client2 = Client()
response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
response = client2.get(
'/proxyValidate',
{'ticket': proxy_ticket, 'service': "https://www.example.com"}
)
# validation should succeed and return settings.CAS_TEST_USER as username
# and settings.CAS_TEST_ATTRIBUTES as attributes
root = self.assert_success(
response,
settings.CAS_TEST_USER,
settings.CAS_TEST_ATTRIBUTES
)
# check that the proxy is send to the end service
# in the PT validation response, it should have the service url of the PGY
proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxies), 1)
proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
@ -1476,6 +1493,7 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
def test_validate_proxy_bad_pgt(self):
"""Try to get a ProxyTicket with a bad PGT. The PT generation should fail"""
# we directrly get a ProxyGrantingTicket
params = get_pgt()
client = Client()
response = client.get(
@ -1496,8 +1514,10 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
Try to get a ProxyTicket for a denied service and
a service that do not allow PT. The PT generation should fail.
"""
# we directrly get a ProxyGrantingTicket
params = get_pgt()
# try to get a PT for a denied service
client1 = Client()
response = client1.get(
'/proxy',
@ -1509,7 +1529,7 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
"https://www.example.org"
)
# service do not allow proxy ticket
# try to get a PT for a service that do not allow PT
self.service_pattern.proxy = False
self.service_pattern.save()
@ -1531,16 +1551,20 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
def test_proxy_unauthorized_user(self):
"""
Try to get a PT for services that do not allow the current user:
* first with a service that restrict allower username
* first with a service that restrict allowed username
* second with a service requiring somes conditions on the user attributes
* third with a service using a particular user attribute as username
All this tests should fail
"""
# we directrly get a ProxyGrantingTicket
params = get_pgt()
for service in [
# do ot allow the test username
self.service_restrict_user_fail,
# require the 'nom' attribute to be 'toto'
self.service_filter_fail,
# want to use the non-exitant 'uid' attribute as username
self.service_field_needed_fail
]:
client = Client()
@ -1548,6 +1572,7 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
'/proxy',
{'pgt': params['pgtId'], 'targetService': service}
)
# PT generation should fail
self.assert_error(
response,
"UNAUTHORIZED_USER",
@ -1575,8 +1600,10 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
"""tests for the proxy view"""
def setUp(self):
"""preparing test context"""
# we prepare a bunch a service url and service patterns for tests
self.setup_service_patterns(proxy=True)
# special service pattern for retrieving a PGT
self.service_pgt = 'http://127.0.0.1'
self.service_pattern_pgt = models.ServicePattern.objects.create(
name="localhost",
@ -1589,6 +1616,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
service_pattern=self.service_pattern_pgt
)
# template for the XML POST need to be send to validate a ticket using SAML 1.1
xml_template = """
<SOAP-ENV:Envelope xmlns:SOAP-ENV="http://schemas.xmlsoap.org/soap/envelope/">
<SOAP-ENV:Header/>
@ -1607,6 +1635,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
def assert_success(self, response, username, original_attributes):
"""assert ticket validation success"""
self.assertEqual(response.status_code, 200)
# on validation success, the response should have a StatusCode set to Success
root = etree.fromstring(response.content)
success = root.xpath(
"//samlp:StatusCode",
@ -1615,6 +1644,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
self.assertEqual(len(success), 1)
self.assertTrue(success[0].attrib['Value'].endswith(":Success"))
# the user username should be return whithin <NameIdentifier> tags
user = root.xpath(
"//samla:NameIdentifier",
namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"}
@ -1622,6 +1652,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
self.assertTrue(user)
self.assertEqual(user[0].text, username)
# the returned attributes should match original_attributes
attributes = root.xpath(
"//samla:AttributeStatement/samla:Attribute",
namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"}
@ -1641,6 +1672,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
def assert_error(self, response, code, msg=None):
"""assert ticket validation error"""
self.assertEqual(response.status_code, 200)
# on error the status code value should be the one provider in `code`
root = etree.fromstring(response.content)
error = root.xpath(
"//samlp:StatusCode",
@ -1648,6 +1680,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
)
self.assertEqual(len(error), 1)
self.assertTrue(error[0].attrib['Value'].endswith(":%s" % code))
# it may have an error message
if msg is not None:
self.assertEqual(error[0].text, msg)
@ -1656,12 +1689,15 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
test with a valid (ticket, service), with a ST and a PT,
the username and all attributes are transmited"""
tickets = [
# return a ServiceTicket (standard ticket) waiting for validation
get_user_ticket_request(self.service)[1],
# return a PT waiting for validation
get_proxy_ticket(self.service)
]
for ticket in tickets:
client = Client()
# we send the POST validation requests
response = client.post(
'/samlValidate?TARGET=%s' % self.service,
self.xml_template % {
@ -1671,6 +1707,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
},
content_type="text/xml; encoding='utf-8'"
)
# and it should succeed
self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES)
def test_saml_ok_user_field(self):
@ -1734,7 +1771,7 @@ class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
)
def test_saml_bad_target(self):
"""test with a valid(ticket, service), but using a bad target"""
"""test with a valid ticket, but using a bad target, validation should fail"""
bad_target = "https://www.example.org"
ticket = get_user_ticket_request(self.service)[1]

View file

@ -3,10 +3,13 @@ from cas_server.default_settings import settings
from django.test import Client
import cgi
from threading import Thread
from lxml import etree
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, parse_qsl
from cas_server import models
from cas_server import utils
def copy_form(form):
@ -70,7 +73,7 @@ def get_validated_ticket(service):
def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)[:2]
@ -100,3 +103,67 @@ def get_proxy_ticket(service):
proxy_ticket = proxy_ticket[0].text
ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
return ticket
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple http server that return 200 on GET or POST
and store GET or POST parameters. Used in unit tests
"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(200)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
url = urlparse(self.path)
params = dict(parse_qsl(url.query))
self.server.PARAMS = params
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
if ctype == 'multipart/form-data':
postvars = cgi.parse_multipart(self.rfile, pdict)
elif ctype == 'application/x-www-form-urlencoded':
length = int(self.headers.get('content-length'))
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
else:
postvars = {}
self.server.PARAMS = postvars
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
class Http404Handler(HttpParamsHandler):
"""A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"error 404 not found")
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()

View file

@ -23,10 +23,8 @@ import hashlib
import crypt
import base64
import six
import cgi
from threading import Thread
from importlib import import_module
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
@ -151,70 +149,6 @@ def gen_saml_id():
return _gen_ticket('_')
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple http server that return 200 on GET or POST
and store GET or POST parameters. Used in unit tests
"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(200)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
url = urlparse(self.path)
params = dict(parse_qsl(url.query))
self.server.PARAMS = params
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
if ctype == 'multipart/form-data':
postvars = cgi.parse_multipart(self.rfile, pdict)
elif ctype == 'application/x-www-form-urlencoded':
length = int(self.headers.get('content-length'))
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
else:
postvars = {}
self.server.PARAMS = postvars
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
class Http404Handler(HttpParamsHandler):
"""A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"error 404 not found")
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()
class LdapHashUserPassword(object):
"""Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""