Full coverage for view validateService

This commit is contained in:
Valentin Samir 2016-06-28 18:57:56 +02:00
parent 6d610d5aa6
commit 44acd005ee
3 changed files with 207 additions and 9 deletions

View file

@ -915,18 +915,37 @@ class ValidateTestCase(TestCase):
class ValidateServiceTestCase(TestCase): class ValidateServiceTestCase(TestCase):
"""tests for the serviceValidate view"""
def setUp(self): def setUp(self):
"""preparing test context"""
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'http://127.0.0.1:45678' self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create( self.service_pattern = models.ServicePattern.objects.create(
name="localhost", name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy_callback=True proxy_callback=True
) )
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
self.service_user_field = "https://user_field.example.com"
self.service_pattern_user_field = models.ServicePattern.objects.create(
name="user field",
pattern="^https://user_field\.example\.com(/.*)?$",
user_field="alias"
)
self.service_one_attribute = "https://one_attribute.example.com"
self.service_pattern_one_attribute = models.ServicePattern.objects.create(
name="one_attribute",
pattern="^https://one_attribute\.example\.com(/.*)?$"
)
models.ReplaceAttributName.objects.create(
name="nom",
service_pattern=self.service_pattern_one_attribute
)
def test_validate_service_view_ok(self): def test_validate_service_view_ok(self):
"""test with a valid (ticket, service), the username and all attributes are transmited"""
ticket = get_user_ticket_request(self.service)[1] ticket = get_user_ticket_request(self.service)[1]
client = Client() client = Client()
@ -968,7 +987,51 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(attrs1, attrs2) self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original) self.assertEqual(attrs1, original)
def test_validate_service_view_ok_one_attribute(self):
"""
test with a valid (ticket, service), the username and
the 'nom' only attribute are transmited
"""
ticket = get_user_ticket_request(self.service_one_attribute)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service_one_attribute}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set([('nom', settings.CAS_TEST_ATTRIBUTES['nom'])])
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
def test_validate_service_view_badservice(self): def test_validate_service_view_badservice(self):
"""test with a valid ticket but a bad service, the validatin should fail"""
ticket = get_user_ticket_request(self.service)[1] ticket = get_user_ticket_request(self.service)[1]
client = Client() client = Client()
@ -986,6 +1049,10 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(error[0].text, bad_service) self.assertEqual(error[0].text, bad_service)
def test_validate_service_view_badticket_goodprefix(self): def test_validate_service_view_badticket_goodprefix(self):
"""
test with a good service bud a bad ticket begining with ST-,
the validation should fail with the error (INVALID_TICKET, ticket not found)
"""
get_user_ticket_request(self.service) get_user_ticket_request(self.service)
client = Client() client = Client()
@ -1003,6 +1070,10 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(error[0].text, 'ticket not found') self.assertEqual(error[0].text, 'ticket not found')
def test_validate_service_view_badticket_badprefix(self): def test_validate_service_view_badticket_badprefix(self):
"""
test with a good service bud a bad ticket not begining with ST-,
the validation should fail with the error (INVALID_TICKET, `the ticket`)
"""
get_user_ticket_request(self.service) get_user_ticket_request(self.service)
client = Client() client = Client()
@ -1020,6 +1091,7 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(error[0].text, bad_ticket) self.assertEqual(error[0].text, bad_ticket)
def test_validate_service_view_ok_pgturl(self): def test_validate_service_view_ok_pgturl(self):
"""test the retrieval of a ProxyGrantingTicket"""
(host, port) = utils.PGTUrlHandler.run()[1:3] (host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port) service = "http://%s:%s" % (host, port)
@ -1042,11 +1114,60 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text) self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
self.assertTrue("pgtId" in pgt_params) self.assertTrue("pgtId" in pgt_params)
def test_validate_service_pgturl_sslerror(self):
"""test the retrieval of a ProxyGrantingTicket with a SSL error on the pgtUrl"""
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "https://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
def test_validate_service_pgturl_404(self):
"""
test the retrieval on a ProxyGrantingTicket then to pgtUrl return a http error.
PGT creation should be aborted but the ticket still be valid
"""
(host, port) = utils.PGTUrlHandler404.run()[1:3]
service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
pgtiou = root.xpath(
"//cas:proxyGrantingTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertFalse(pgtiou)
def test_validate_service_pgturl_bad_proxy_callback(self): def test_validate_service_pgturl_bad_proxy_callback(self):
"""test the retrieval of a ProxyGrantingTicket, not allowed pgtUrl should be denied"""
self.service_pattern.proxy_callback = False self.service_pattern.proxy_callback = False
self.service_pattern.save() self.service_pattern.save()
ticket = get_user_ticket_request(self.service)[1] ticket = get_user_ticket_request(self.service)[1]
client = Client() client = Client()
response = client.get( response = client.get(
'/serviceValidate', '/serviceValidate',
@ -1063,6 +1184,66 @@ class ValidateServiceTestCase(TestCase):
self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK") self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
self.assertEqual(error[0].text, "callback url not allowed by configuration") self.assertEqual(error[0].text, "callback url not allowed by configuration")
self.service_pattern.proxy_callback = True
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service, 'pgtUrl': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
self.assertEqual(error[0].text, "callback url not allowed by configuration")
def test_validate_user_field_ok(self):
"""
test with a good user_field. A bad user_field (that evaluate to False)
wont happed cause it is filtered in the login view
"""
ticket = get_user_ticket_request(self.service_user_field)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service_user_field}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_ATTRIBUTES["alias"][0])
def test_validate_missing_parameter(self):
"""test with a missing GET parameter among [service, ticket]"""
ticket = get_user_ticket_request(self.service)[1]
client = Client()
params = {'ticket': ticket.value, 'service': self.service}
for key in ['ticket', 'service']:
send_params = params.copy()
del send_params[key]
response = client.get('/serviceValidate', send_params)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_REQUEST")
self.assertEqual(error[0].text, "you must specify a service and a ticket")
class ProxyTestCase(TestCase): class ProxyTestCase(TestCase):

View file

@ -148,6 +148,7 @@ def gen_saml_id():
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""A simple http server that return 200 on GET and store GET parameters. Used in unit tests"""
PARAMS = {} PARAMS = {}
def do_GET(self): def do_GET(self):
@ -162,10 +163,10 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
def log_message(self, *args): def log_message(self, *args):
return return
@staticmethod @classmethod
def run(): def run(cls):
server_class = BaseHTTPServer.HTTPServer server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), PGTUrlHandler) httpd = server_class(("127.0.0.1", 0), cls)
(host, port) = httpd.socket.getsockname() (host, port) = httpd.socket.getsockname()
def lauch(): def lauch():
@ -178,6 +179,15 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
return (httpd_thread, host, port) return (httpd_thread, host, port)
class PGTUrlHandler404(PGTUrlHandler):
"""A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self):
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"error 404 not found")
class LdapHashUserPassword(object): class LdapHashUserPassword(object):
"""Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html""" """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""

View file

@ -596,7 +596,7 @@ class Validate(View):
ticket.service_pattern.user_field ticket.service_pattern.user_field
) )
if isinstance(username, list): if isinstance(username, list):
# the list not empty because we wont generate a ticket with a user_field # the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False # that evaluate to False
username = username[0] username = username[0]
else: else:
@ -674,6 +674,10 @@ class ValidateService(View, AttributesMixin):
params['username'] = self.ticket.user.attributs.get( params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field self.ticket.service_pattern.user_field
) )
if isinstance(params['username'], list):
# the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
params['username'] = params['username'][0]
if self.pgt_url and ( if self.pgt_url and (
self.pgt_url.startswith("https://") or self.pgt_url.startswith("https://") or
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url) re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
@ -775,9 +779,12 @@ class ValidateService(View, AttributesMixin):
params, params,
content_type="text/xml; charset=utf-8" content_type="text/xml; charset=utf-8"
) )
except requests.exceptions.SSLError as error: except requests.exceptions.RequestException as error:
error = utils.unpack_nested_exception(error) error = utils.unpack_nested_exception(error)
raise ValidateError('INVALID_PROXY_CALLBACK', str(error)) raise ValidateError(
'INVALID_PROXY_CALLBACK',
"%s: %s" % (type(error), str(error))
)
else: else:
raise ValidateError( raise ValidateError(
'INVALID_PROXY_CALLBACK', 'INVALID_PROXY_CALLBACK',