diff --git a/cas_server/cas.py b/cas_server/cas.py index c190d48..fdc22fb 100644 --- a/cas_server/cas.py +++ b/cas_server/cas.py @@ -21,6 +21,7 @@ # This file is originated from https://github.com/python-cas/python-cas # at commit ec1f2d4779625229398547b9234d0e9e874a2c9a +import six from six.moves.urllib import parse as urllib_parse from six.moves.urllib import request as urllib_request from six.moves.urllib.request import Request @@ -32,6 +33,15 @@ class CASError(ValueError): pass +class ReturnUnicode(object): + @staticmethod + def unicode(string, charset): + if not isinstance(string, six.text_type): + return string.decode(charset) + else: + return string + + class SingleLogoutMixin(object): @classmethod def get_saml_slos(cls, logout_request): @@ -124,7 +134,7 @@ class CASClientBase(object): raise CASError("Bad http code %s" % response.code) -class CASClientV1(CASClientBase): +class CASClientV1(CASClientBase, ReturnUnicode): """CAS Client Version 1""" logout_redirect_param_name = 'url' @@ -140,15 +150,21 @@ class CASClientV1(CASClientBase): page = urllib_request.urlopen(url) try: verified = page.readline().strip() - if verified == 'yes': - return page.readline().strip(), None, None + if verified == b'yes': + content_type = page.info().get('Content-type') + if "charset=" in content_type: + charset = content_type.split("charset=")[-1] + else: + charset = "ascii" + user = self.unicode(page.readline().strip(), charset) + return user, None, None else: return None, None, None finally: page.close() -class CASClientV2(CASClientBase): +class CASClientV2(CASClientBase, ReturnUnicode): """CAS Client Version 2""" url_suffix = 'serviceValidate' @@ -161,8 +177,8 @@ class CASClientV2(CASClientBase): def verify_ticket(self, ticket): """Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes""" - response = self.get_verification_response(ticket) - return self.verify_response(response) + (response, charset) = self.get_verification_response(ticket) + return self.verify_response(response, charset) def get_verification_response(self, ticket): params = [('ticket', ticket), ('service', self.service_url)] @@ -172,37 +188,42 @@ class CASClientV2(CASClientBase): url = base_url + '?' + urllib_parse.urlencode(params) page = urllib_request.urlopen(url) try: - return page.read() + content_type = page.info().get('Content-type') + if "charset=" in content_type: + charset = content_type.split("charset=")[-1] + else: + charset = "ascii" + return (page.read(), charset) finally: page.close() @classmethod - def parse_attributes_xml_element(cls, element): + def parse_attributes_xml_element(cls, element, charset): attributes = dict() for attribute in element: - tag = attribute.tag.split("}").pop() + tag = cls.self.unicode(attribute.tag, charset).split(u"}").pop() if tag in attributes: if isinstance(attributes[tag], list): - attributes[tag].append(attribute.text) + attributes[tag].append(cls.unicode(attribute.text, charset)) else: attributes[tag] = [attributes[tag]] - attributes[tag].append(attribute.text) + attributes[tag].append(cls.unicode(attribute.text, charset)) else: - if tag == 'attraStyle': + if tag == u'attraStyle': pass else: - attributes[tag] = attribute.text + attributes[tag] = cls.unicode(attribute.text, charset) return attributes @classmethod - def verify_response(cls, response): - user, attributes, pgtiou = cls.parse_response_xml(response) + def verify_response(cls, response, charset): + user, attributes, pgtiou = cls.parse_response_xml(response, charset) if len(attributes) == 0: attributes = None return user, attributes, pgtiou @classmethod - def parse_response_xml(cls, response): + def parse_response_xml(cls, response, charset): try: from xml.etree import ElementTree except ImportError: @@ -216,11 +237,11 @@ class CASClientV2(CASClientBase): if tree[0].tag.endswith('authenticationSuccess'): for element in tree[0]: if element.tag.endswith('user'): - user = element.text + user = cls.unicode(element.text, charset) elif element.tag.endswith('proxyGrantingTicket'): - pgtiou = element.text + pgtiou = cls.unicode(element.text, charset) elif element.tag.endswith('attributes'): - attributes = cls.parse_attributes_xml_element(element) + attributes = cls.parse_attributes_xml_element(element, charset) return user, attributes, pgtiou @@ -230,23 +251,23 @@ class CASClientV3(CASClientV2, SingleLogoutMixin): logout_redirect_param_name = 'service' @classmethod - def parse_attributes_xml_element(cls, element): + def parse_attributes_xml_element(cls, element, charset): attributes = dict() for attribute in element: - tag = attribute.tag.split("}").pop() + tag = cls.unicode(attribute.tag, charset).split(u"}").pop() if tag in attributes: if isinstance(attributes[tag], list): - attributes[tag].append(attribute.text) + attributes[tag].append(cls.unicode(attribute.text, charset)) else: attributes[tag] = [attributes[tag]] - attributes[tag].append(attribute.text) + attributes[tag].append(cls.unicode(attribute.text, charset)) else: - attributes[tag] = attribute.text + attributes[tag] = cls.unicode(attribute.text, charset) return attributes @classmethod - def verify_response(cls, response): - return cls.parse_response_xml(response) + def verify_response(cls, response, charset): + return cls.parse_response_xml(response, charset) SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:' @@ -284,6 +305,11 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin): from elementtree import ElementTree page = self.fetch_saml_validation(ticket) + content_type = page.info().get('Content-type') + if "charset=" in content_type: + charset = content_type.split("charset=")[-1] + else: + charset = "ascii" try: user = None @@ -296,21 +322,25 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin): # User is validated name_identifier = tree.find('.//' + SAML_1_0_ASSERTION_NS + 'NameIdentifier') if name_identifier is not None: - user = name_identifier.text + user = self.unicode(name_identifier.text, charset) attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute') for at in attrs: if self.username_attribute in list(at.attrib.values()): - user = at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text - attributes['uid'] = user + user = self.unicode( + at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text, + charset + ) + attributes[u'uid'] = user values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue') + key = self.unicode(at.attrib['AttributeName'], charset) if len(values) > 1: values_array = [] for v in values: - values_array.append(v.text) - attributes[at.attrib['AttributeName']] = values_array + values_array.append(self.unicode(v.text, charset)) + attributes[key] = values_array else: - attributes[at.attrib['AttributeName']] = values[0].text + attributes[key] = self.unicode(values[0].text, charset) return user, attributes, None finally: page.close()