diff --git a/cas_server/tests/test_utils.py b/cas_server/tests/test_utils.py index 4818ff3..3cbe5ab 100644 --- a/cas_server/tests/test_utils.py +++ b/cas_server/tests/test_utils.py @@ -27,42 +27,146 @@ class CheckPasswordCase(TestCase): self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8")) self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8")) - def test_crypt(self): - """test the crypt auth method""" - if six.PY3: - hashed_password1 = utils.crypt.crypt( - self.password1.decode("utf8"), - "$6$UVVAQvrMyXMF3FF3" - ).encode("utf8") - else: - hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3") - - self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8")) - self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8")) - - def test_ldap_ssha(self): - """test the ldap auth method with a {SSHA} scheme""" - salt = b"UVVAQvrMyXMF3FF3" - hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8") - - self.assertIsInstance(hashed_password1, bytes) - self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8")) - self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8")) - - def test_hex_md5(self): - """test the hex_md5 auth method""" - hashed_password1 = utils.hashlib.md5(self.password1).hexdigest() - - self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) - self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) - - def test_hex_sha512(self): - """test the hex_sha512 auth method""" - hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() - + def test_plain_unicode(self): + """test the plain auth method with unicode input""" self.assertTrue( - utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8") + utils.check_password( + "plain", + self.password1.decode("utf8"), + self.password1.decode("utf8"), + "utf8" + ) ) self.assertFalse( - utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8") + utils.check_password( + "plain", + self.password1.decode("utf8"), + self.password2.decode("utf8"), + "utf8" + ) ) + + def test_crypt(self): + """test the crypt auth method""" + salts = ["$6$UVVAQvrMyXMF3FF3", "aa"] + hashed_password1 = [] + for salt in salts: + if six.PY3: + hashed_password1.append( + utils.crypt.crypt( + self.password1.decode("utf8"), + salt + ).encode("utf8") + ) + else: + hashed_password1.append(utils.crypt.crypt(self.password1, salt)) + + for hp1 in hashed_password1: + self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8")) + self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8")) + + with self.assertRaises(ValueError): + utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8") + + def test_ldap_password_valid(self): + """test the ldap auth method with all the schemes""" + salt = b"UVVAQvrMyXMF3FF3" + schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"] + schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"] + hashed_password1 = [] + for scheme in schemes_salt: + hashed_password1.append( + utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8") + ) + for scheme in schemes_nosalt: + hashed_password1.append( + utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8") + ) + hashed_password1.append( + utils.LdapHashUserPassword.hash( + b"{CRYPT}", + self.password1, + "$6$UVVAQvrMyXMF3FF3", + charset="utf8" + ) + ) + for hp1 in hashed_password1: + self.assertIsInstance(hp1, bytes) + self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8")) + self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8")) + + def test_ldap_password_fail(self): + """test the ldap auth method with malformed hash or bad schemes""" + salt = b"UVVAQvrMyXMF3FF3" + schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"] + schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"] + + # first try to hash with bad parameters + with self.assertRaises(utils.LdapHashUserPassword.BadScheme): + utils.LdapHashUserPassword.hash(b"TOTO", self.password1) + for scheme in schemes_nosalt: + with self.assertRaises(utils.LdapHashUserPassword.BadScheme): + utils.LdapHashUserPassword.hash(scheme, self.password1, salt) + for scheme in schemes_salt: + with self.assertRaises(utils.LdapHashUserPassword.BadScheme): + utils.LdapHashUserPassword.hash(scheme, self.password1) + with self.assertRaises(utils.LdapHashUserPassword.BadSalt): + utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, "$truc$toto") + + # then try to check hash with bad hashes + with self.assertRaises(utils.LdapHashUserPassword.BadHash): + utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8") + for scheme in schemes_salt: + with self.assertRaises(utils.LdapHashUserPassword.BadHash): + utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8") + + def test_hex(self): + """test all the hex_HASH method: the hashed password is a simple hash of the password""" + hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"] + hashed_password1 = [] + for hash in hashes: + hashed_password1.append( + ("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest()) + ) + for (method, hp1) in hashed_password1: + self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8")) + self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8")) + + def test_bad_method(self): + """try to check password with a bad method, should raise a ValueError""" + with self.assertRaises(ValueError): + utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8") + + +class UtilsTestCase(TestCase): + """tests for some little utils functions""" + def test_import_attr(self): + """ + test the import_attr function. Feeded with a dotted path string, it should + import the dotted module and return that last componend of the dotted path + (function, class or variable) + """ + with self.assertRaises(ImportError): + utils.import_attr('toto.titi.tutu') + with self.assertRaises(AttributeError): + utils.import_attr('cas_server.utils.toto') + with self.assertRaises(ValueError): + utils.import_attr('toto') + self.assertEqual( + utils.import_attr('cas_server.default_app_config'), + 'cas_server.apps.CasAppConfig' + ) + self.assertEqual(utils.import_attr(utils), utils) + + def test_update_url(self): + """ + test the update_url function. Given an url with possible GET parameter and a dict + the function build a url with GET parameters updated by the dictionnary + """ + url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"}) + url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"}) + self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1") + self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1") + + url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"}) + self.assertEqual(url3, u"https://www.example.com?toto=2") diff --git a/cas_server/utils.py b/cas_server/utils.py index 3be2bad..a257c98 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -86,9 +86,6 @@ def update_url(url, params): query = dict(parse_qsl(url_parts[4])) query.update(params) url_parts[4] = urlencode(query) - for i, url_part in enumerate(url_parts): - if not isinstance(url_part, bytes): - url_parts[i] = url_part.encode('utf-8') return urlunparse(url_parts).decode('utf-8') @@ -239,7 +236,7 @@ class LdapHashUserPassword(object): if salt is None or salt == b"": salt = b"" cls._test_scheme_nosalt(scheme) - elif salt is not None: + else: cls._test_scheme_salt(scheme) try: return scheme + base64.b64encode( @@ -273,7 +270,7 @@ class LdapHashUserPassword(object): if scheme in cls.schemes_nosalt: return b"" elif scheme == b'{CRYPT}': - return b'$'.join(hashed_passord.split(b'$', 3)[:-1]) + return b'$'.join(hashed_passord.split(b'$', 3)[:-1])[len(scheme):] else: hashed_passord = base64.b64decode(hashed_passord[len(scheme):]) if len(hashed_passord) < cls._schemes_to_len[scheme]: @@ -295,7 +292,7 @@ def check_password(method, password, hashed_password, charset): elif method == "crypt": if hashed_password.startswith(b'$'): salt = b'$'.join(hashed_password.split(b'$', 3)[:-1]) - elif hashed_password.startswith(b'_'): + elif hashed_password.startswith(b'_'): # pragma: no cover old BSD format not supported salt = hashed_password[:9] else: salt = hashed_password[:2]