Tweak the cas client lib to always return unicode

hence, the behaviour is consistent between python2 and python3
This commit is contained in:
Valentin Samir 2016-07-03 13:49:31 +02:00
parent 2f1b3862ff
commit fcd906ca78

View file

@ -21,6 +21,7 @@
# This file is originated from https://github.com/python-cas/python-cas # This file is originated from https://github.com/python-cas/python-cas
# at commit ec1f2d4779625229398547b9234d0e9e874a2c9a # at commit ec1f2d4779625229398547b9234d0e9e874a2c9a
import six
from six.moves.urllib import parse as urllib_parse from six.moves.urllib import parse as urllib_parse
from six.moves.urllib import request as urllib_request from six.moves.urllib import request as urllib_request
from six.moves.urllib.request import Request from six.moves.urllib.request import Request
@ -32,6 +33,15 @@ class CASError(ValueError):
pass pass
class ReturnUnicode(object):
@staticmethod
def unicode(string, charset):
if not isinstance(string, six.text_type):
return string.decode(charset)
else:
return string
class SingleLogoutMixin(object): class SingleLogoutMixin(object):
@classmethod @classmethod
def get_saml_slos(cls, logout_request): def get_saml_slos(cls, logout_request):
@ -124,7 +134,7 @@ class CASClientBase(object):
raise CASError("Bad http code %s" % response.code) raise CASError("Bad http code %s" % response.code)
class CASClientV1(CASClientBase): class CASClientV1(CASClientBase, ReturnUnicode):
"""CAS Client Version 1""" """CAS Client Version 1"""
logout_redirect_param_name = 'url' logout_redirect_param_name = 'url'
@ -140,15 +150,21 @@ class CASClientV1(CASClientBase):
page = urllib_request.urlopen(url) page = urllib_request.urlopen(url)
try: try:
verified = page.readline().strip() verified = page.readline().strip()
if verified == 'yes': if verified == b'yes':
return page.readline().strip(), None, None content_type = page.info().get('Content-type')
if "charset=" in content_type:
charset = content_type.split("charset=")[-1]
else:
charset = "ascii"
user = self.unicode(page.readline().strip(), charset)
return user, None, None
else: else:
return None, None, None return None, None, None
finally: finally:
page.close() page.close()
class CASClientV2(CASClientBase): class CASClientV2(CASClientBase, ReturnUnicode):
"""CAS Client Version 2""" """CAS Client Version 2"""
url_suffix = 'serviceValidate' url_suffix = 'serviceValidate'
@ -161,8 +177,8 @@ class CASClientV2(CASClientBase):
def verify_ticket(self, ticket): def verify_ticket(self, ticket):
"""Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes""" """Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes"""
response = self.get_verification_response(ticket) (response, charset) = self.get_verification_response(ticket)
return self.verify_response(response) return self.verify_response(response, charset)
def get_verification_response(self, ticket): def get_verification_response(self, ticket):
params = [('ticket', ticket), ('service', self.service_url)] params = [('ticket', ticket), ('service', self.service_url)]
@ -172,37 +188,42 @@ class CASClientV2(CASClientBase):
url = base_url + '?' + urllib_parse.urlencode(params) url = base_url + '?' + urllib_parse.urlencode(params)
page = urllib_request.urlopen(url) page = urllib_request.urlopen(url)
try: try:
return page.read() content_type = page.info().get('Content-type')
if "charset=" in content_type:
charset = content_type.split("charset=")[-1]
else:
charset = "ascii"
return (page.read(), charset)
finally: finally:
page.close() page.close()
@classmethod @classmethod
def parse_attributes_xml_element(cls, element): def parse_attributes_xml_element(cls, element, charset):
attributes = dict() attributes = dict()
for attribute in element: for attribute in element:
tag = attribute.tag.split("}").pop() tag = cls.self.unicode(attribute.tag, charset).split(u"}").pop()
if tag in attributes: if tag in attributes:
if isinstance(attributes[tag], list): if isinstance(attributes[tag], list):
attributes[tag].append(attribute.text) attributes[tag].append(cls.unicode(attribute.text, charset))
else: else:
attributes[tag] = [attributes[tag]] attributes[tag] = [attributes[tag]]
attributes[tag].append(attribute.text) attributes[tag].append(cls.unicode(attribute.text, charset))
else: else:
if tag == 'attraStyle': if tag == u'attraStyle':
pass pass
else: else:
attributes[tag] = attribute.text attributes[tag] = cls.unicode(attribute.text, charset)
return attributes return attributes
@classmethod @classmethod
def verify_response(cls, response): def verify_response(cls, response, charset):
user, attributes, pgtiou = cls.parse_response_xml(response) user, attributes, pgtiou = cls.parse_response_xml(response, charset)
if len(attributes) == 0: if len(attributes) == 0:
attributes = None attributes = None
return user, attributes, pgtiou return user, attributes, pgtiou
@classmethod @classmethod
def parse_response_xml(cls, response): def parse_response_xml(cls, response, charset):
try: try:
from xml.etree import ElementTree from xml.etree import ElementTree
except ImportError: except ImportError:
@ -216,11 +237,11 @@ class CASClientV2(CASClientBase):
if tree[0].tag.endswith('authenticationSuccess'): if tree[0].tag.endswith('authenticationSuccess'):
for element in tree[0]: for element in tree[0]:
if element.tag.endswith('user'): if element.tag.endswith('user'):
user = element.text user = cls.unicode(element.text, charset)
elif element.tag.endswith('proxyGrantingTicket'): elif element.tag.endswith('proxyGrantingTicket'):
pgtiou = element.text pgtiou = cls.unicode(element.text, charset)
elif element.tag.endswith('attributes'): elif element.tag.endswith('attributes'):
attributes = cls.parse_attributes_xml_element(element) attributes = cls.parse_attributes_xml_element(element, charset)
return user, attributes, pgtiou return user, attributes, pgtiou
@ -230,23 +251,23 @@ class CASClientV3(CASClientV2, SingleLogoutMixin):
logout_redirect_param_name = 'service' logout_redirect_param_name = 'service'
@classmethod @classmethod
def parse_attributes_xml_element(cls, element): def parse_attributes_xml_element(cls, element, charset):
attributes = dict() attributes = dict()
for attribute in element: for attribute in element:
tag = attribute.tag.split("}").pop() tag = cls.unicode(attribute.tag, charset).split(u"}").pop()
if tag in attributes: if tag in attributes:
if isinstance(attributes[tag], list): if isinstance(attributes[tag], list):
attributes[tag].append(attribute.text) attributes[tag].append(cls.unicode(attribute.text, charset))
else: else:
attributes[tag] = [attributes[tag]] attributes[tag] = [attributes[tag]]
attributes[tag].append(attribute.text) attributes[tag].append(cls.unicode(attribute.text, charset))
else: else:
attributes[tag] = attribute.text attributes[tag] = cls.unicode(attribute.text, charset)
return attributes return attributes
@classmethod @classmethod
def verify_response(cls, response): def verify_response(cls, response, charset):
return cls.parse_response_xml(response) return cls.parse_response_xml(response, charset)
SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:' SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:'
@ -284,6 +305,11 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
from elementtree import ElementTree from elementtree import ElementTree
page = self.fetch_saml_validation(ticket) page = self.fetch_saml_validation(ticket)
content_type = page.info().get('Content-type')
if "charset=" in content_type:
charset = content_type.split("charset=")[-1]
else:
charset = "ascii"
try: try:
user = None user = None
@ -296,21 +322,25 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
# User is validated # User is validated
name_identifier = tree.find('.//' + SAML_1_0_ASSERTION_NS + 'NameIdentifier') name_identifier = tree.find('.//' + SAML_1_0_ASSERTION_NS + 'NameIdentifier')
if name_identifier is not None: if name_identifier is not None:
user = name_identifier.text user = self.unicode(name_identifier.text, charset)
attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute') attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute')
for at in attrs: for at in attrs:
if self.username_attribute in list(at.attrib.values()): if self.username_attribute in list(at.attrib.values()):
user = at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text user = self.unicode(
attributes['uid'] = user at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text,
charset
)
attributes[u'uid'] = user
values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue') values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue')
key = self.unicode(at.attrib['AttributeName'], charset)
if len(values) > 1: if len(values) > 1:
values_array = [] values_array = []
for v in values: for v in values:
values_array.append(v.text) values_array.append(self.unicode(v.text, charset))
attributes[at.attrib['AttributeName']] = values_array attributes[key] = values_array
else: else:
attributes[at.attrib['AttributeName']] = values[0].text attributes[key] = self.unicode(values[0].text, charset)
return user, attributes, None return user, attributes, None
finally: finally:
page.close() page.close()