diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 1c4d6a3..458db3d 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,10 @@ Added ----- * Support for Django 4.0 and 4.1 +Fixes +----- +* Fix unicode sandwich issue in cas_server.utils.update_url + Removed ------- * Drop support for Django 1.11 (now deprecated for more than 2 years) diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py index e216b97..99ef9ce 100644 --- a/cas_server/tests/test_view.py +++ b/cas_server/tests/test_view.py @@ -262,7 +262,7 @@ class LoginTestCase(TestCase, BaseServicePattern, CanLogin): # check that the service pattern registered on the ticket is the on we use for tests self.assertEqual(ticket.service_pattern, self.service_pattern) - def assert_service_ticket(self, client, response): + def assert_service_ticket(self, client, response, service="https://www.example.com"): """check that a ticket is well emited when requested on a allowed service""" # On ticket emission, we should be redirected to the service url, setting the ticket # GET parameter @@ -270,7 +270,7 @@ class LoginTestCase(TestCase, BaseServicePattern, CanLogin): self.assertTrue(response.has_header('Location')) self.assertTrue( response['Location'].startswith( - "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX + "%s?ticket=%s-" % (service, settings.CAS_SERVICE_TICKET_PREFIX) ) ) # check that the value of the ticket GET parameter match the value of the ticket @@ -337,14 +337,17 @@ class LoginTestCase(TestCase, BaseServicePattern, CanLogin): self.assertFalse(b"Service https://www.example.net not allowed" in response.content) def test_view_login_get_auth_allowed_service(self): - """Request a ticket for an allowed service by an authenticated client containing non ascii char in url""" + """ + Request a ticket for an allowed service by an authenticated client containing + non ascii char in url + """ # get a client that is already authenticated client = get_auth_client() # ask for a ticket for https://www.example.com response = client.get("/login?service=https://www.example.com/é") # as https://www.example.com/é is a valid service a ticket should be created and the # user redirected to the service url - self.assert_service_ticket(client, response) + self.assert_service_ticket(client, response, service="https://www.example.com/%C3%A9") def test_view_login_get_auth_allowed_service_non_ascii(self): """Request a ticket for an allowed service by an authenticated client""" diff --git a/cas_server/utils.py b/cas_server/utils.py index 4ec2333..31d923e 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -249,15 +249,25 @@ def update_url(url, params): :return: The URL with an updated querystring :rtype: unicode """ - if not isinstance(url, bytes): - url = url.encode('utf-8') - for key, value in list(params.items()): - if not isinstance(key, bytes): - del params[key] - key = key.encode('utf-8') - if not isinstance(value, bytes): - value = value.encode('utf-8') - params[key] = value + def to_unicode(data): + if isinstance(data, bytes): + return data.decode('utf-8') + else: + return data + + def to_bytes(data): + if not isinstance(data, bytes): + return data.encode('utf-8') + else: + return data + + if six.PY3: + url = to_unicode(url) + params = {to_unicode(key): to_unicode(value) for (key, value) in params.items()} + else: + url = to_bytes(url) + params = {to_bytes(key): to_bytes(value) for (key, value) in params.items()} + url_parts = list(urlparse(url)) query = dict(parse_qsl(url_parts[4], keep_blank_values=True)) query.update(params) @@ -265,10 +275,12 @@ def update_url(url, params): query = list(query.items()) query.sort() url_query = urlencode(query) - if not isinstance(url_query, bytes): # pragma: no cover in python3 urlencode return an unicode - url_query = url_query.encode("utf-8") url_parts[4] = url_query - return urlunparse(url_parts).decode('utf-8') + url = urlunparse(url_parts) + + if isinstance(url, bytes): + url = url.decode('utf-8') + return url def unpack_nested_exception(error):