Tweak the cas client lib to always return unicode
hence, the behaviour is consistent between python2 and python3
This commit is contained in:
parent
2f1b3862ff
commit
fcd906ca78
1 changed files with 62 additions and 32 deletions
|
@ -21,6 +21,7 @@
|
||||||
# This file is originated from https://github.com/python-cas/python-cas
|
# This file is originated from https://github.com/python-cas/python-cas
|
||||||
# at commit ec1f2d4779625229398547b9234d0e9e874a2c9a
|
# at commit ec1f2d4779625229398547b9234d0e9e874a2c9a
|
||||||
|
|
||||||
|
import six
|
||||||
from six.moves.urllib import parse as urllib_parse
|
from six.moves.urllib import parse as urllib_parse
|
||||||
from six.moves.urllib import request as urllib_request
|
from six.moves.urllib import request as urllib_request
|
||||||
from six.moves.urllib.request import Request
|
from six.moves.urllib.request import Request
|
||||||
|
@ -32,6 +33,15 @@ class CASError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ReturnUnicode(object):
|
||||||
|
@staticmethod
|
||||||
|
def unicode(string, charset):
|
||||||
|
if not isinstance(string, six.text_type):
|
||||||
|
return string.decode(charset)
|
||||||
|
else:
|
||||||
|
return string
|
||||||
|
|
||||||
|
|
||||||
class SingleLogoutMixin(object):
|
class SingleLogoutMixin(object):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_saml_slos(cls, logout_request):
|
def get_saml_slos(cls, logout_request):
|
||||||
|
@ -124,7 +134,7 @@ class CASClientBase(object):
|
||||||
raise CASError("Bad http code %s" % response.code)
|
raise CASError("Bad http code %s" % response.code)
|
||||||
|
|
||||||
|
|
||||||
class CASClientV1(CASClientBase):
|
class CASClientV1(CASClientBase, ReturnUnicode):
|
||||||
"""CAS Client Version 1"""
|
"""CAS Client Version 1"""
|
||||||
|
|
||||||
logout_redirect_param_name = 'url'
|
logout_redirect_param_name = 'url'
|
||||||
|
@ -140,15 +150,21 @@ class CASClientV1(CASClientBase):
|
||||||
page = urllib_request.urlopen(url)
|
page = urllib_request.urlopen(url)
|
||||||
try:
|
try:
|
||||||
verified = page.readline().strip()
|
verified = page.readline().strip()
|
||||||
if verified == 'yes':
|
if verified == b'yes':
|
||||||
return page.readline().strip(), None, None
|
content_type = page.info().get('Content-type')
|
||||||
|
if "charset=" in content_type:
|
||||||
|
charset = content_type.split("charset=")[-1]
|
||||||
|
else:
|
||||||
|
charset = "ascii"
|
||||||
|
user = self.unicode(page.readline().strip(), charset)
|
||||||
|
return user, None, None
|
||||||
else:
|
else:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
finally:
|
finally:
|
||||||
page.close()
|
page.close()
|
||||||
|
|
||||||
|
|
||||||
class CASClientV2(CASClientBase):
|
class CASClientV2(CASClientBase, ReturnUnicode):
|
||||||
"""CAS Client Version 2"""
|
"""CAS Client Version 2"""
|
||||||
|
|
||||||
url_suffix = 'serviceValidate'
|
url_suffix = 'serviceValidate'
|
||||||
|
@ -161,8 +177,8 @@ class CASClientV2(CASClientBase):
|
||||||
|
|
||||||
def verify_ticket(self, ticket):
|
def verify_ticket(self, ticket):
|
||||||
"""Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes"""
|
"""Verifies CAS 2.0+/3.0+ XML-based authentication ticket and returns extended attributes"""
|
||||||
response = self.get_verification_response(ticket)
|
(response, charset) = self.get_verification_response(ticket)
|
||||||
return self.verify_response(response)
|
return self.verify_response(response, charset)
|
||||||
|
|
||||||
def get_verification_response(self, ticket):
|
def get_verification_response(self, ticket):
|
||||||
params = [('ticket', ticket), ('service', self.service_url)]
|
params = [('ticket', ticket), ('service', self.service_url)]
|
||||||
|
@ -172,37 +188,42 @@ class CASClientV2(CASClientBase):
|
||||||
url = base_url + '?' + urllib_parse.urlencode(params)
|
url = base_url + '?' + urllib_parse.urlencode(params)
|
||||||
page = urllib_request.urlopen(url)
|
page = urllib_request.urlopen(url)
|
||||||
try:
|
try:
|
||||||
return page.read()
|
content_type = page.info().get('Content-type')
|
||||||
|
if "charset=" in content_type:
|
||||||
|
charset = content_type.split("charset=")[-1]
|
||||||
|
else:
|
||||||
|
charset = "ascii"
|
||||||
|
return (page.read(), charset)
|
||||||
finally:
|
finally:
|
||||||
page.close()
|
page.close()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_attributes_xml_element(cls, element):
|
def parse_attributes_xml_element(cls, element, charset):
|
||||||
attributes = dict()
|
attributes = dict()
|
||||||
for attribute in element:
|
for attribute in element:
|
||||||
tag = attribute.tag.split("}").pop()
|
tag = cls.self.unicode(attribute.tag, charset).split(u"}").pop()
|
||||||
if tag in attributes:
|
if tag in attributes:
|
||||||
if isinstance(attributes[tag], list):
|
if isinstance(attributes[tag], list):
|
||||||
attributes[tag].append(attribute.text)
|
attributes[tag].append(cls.unicode(attribute.text, charset))
|
||||||
else:
|
else:
|
||||||
attributes[tag] = [attributes[tag]]
|
attributes[tag] = [attributes[tag]]
|
||||||
attributes[tag].append(attribute.text)
|
attributes[tag].append(cls.unicode(attribute.text, charset))
|
||||||
else:
|
else:
|
||||||
if tag == 'attraStyle':
|
if tag == u'attraStyle':
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
attributes[tag] = attribute.text
|
attributes[tag] = cls.unicode(attribute.text, charset)
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def verify_response(cls, response):
|
def verify_response(cls, response, charset):
|
||||||
user, attributes, pgtiou = cls.parse_response_xml(response)
|
user, attributes, pgtiou = cls.parse_response_xml(response, charset)
|
||||||
if len(attributes) == 0:
|
if len(attributes) == 0:
|
||||||
attributes = None
|
attributes = None
|
||||||
return user, attributes, pgtiou
|
return user, attributes, pgtiou
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_response_xml(cls, response):
|
def parse_response_xml(cls, response, charset):
|
||||||
try:
|
try:
|
||||||
from xml.etree import ElementTree
|
from xml.etree import ElementTree
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -216,11 +237,11 @@ class CASClientV2(CASClientBase):
|
||||||
if tree[0].tag.endswith('authenticationSuccess'):
|
if tree[0].tag.endswith('authenticationSuccess'):
|
||||||
for element in tree[0]:
|
for element in tree[0]:
|
||||||
if element.tag.endswith('user'):
|
if element.tag.endswith('user'):
|
||||||
user = element.text
|
user = cls.unicode(element.text, charset)
|
||||||
elif element.tag.endswith('proxyGrantingTicket'):
|
elif element.tag.endswith('proxyGrantingTicket'):
|
||||||
pgtiou = element.text
|
pgtiou = cls.unicode(element.text, charset)
|
||||||
elif element.tag.endswith('attributes'):
|
elif element.tag.endswith('attributes'):
|
||||||
attributes = cls.parse_attributes_xml_element(element)
|
attributes = cls.parse_attributes_xml_element(element, charset)
|
||||||
return user, attributes, pgtiou
|
return user, attributes, pgtiou
|
||||||
|
|
||||||
|
|
||||||
|
@ -230,23 +251,23 @@ class CASClientV3(CASClientV2, SingleLogoutMixin):
|
||||||
logout_redirect_param_name = 'service'
|
logout_redirect_param_name = 'service'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_attributes_xml_element(cls, element):
|
def parse_attributes_xml_element(cls, element, charset):
|
||||||
attributes = dict()
|
attributes = dict()
|
||||||
for attribute in element:
|
for attribute in element:
|
||||||
tag = attribute.tag.split("}").pop()
|
tag = cls.unicode(attribute.tag, charset).split(u"}").pop()
|
||||||
if tag in attributes:
|
if tag in attributes:
|
||||||
if isinstance(attributes[tag], list):
|
if isinstance(attributes[tag], list):
|
||||||
attributes[tag].append(attribute.text)
|
attributes[tag].append(cls.unicode(attribute.text, charset))
|
||||||
else:
|
else:
|
||||||
attributes[tag] = [attributes[tag]]
|
attributes[tag] = [attributes[tag]]
|
||||||
attributes[tag].append(attribute.text)
|
attributes[tag].append(cls.unicode(attribute.text, charset))
|
||||||
else:
|
else:
|
||||||
attributes[tag] = attribute.text
|
attributes[tag] = cls.unicode(attribute.text, charset)
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def verify_response(cls, response):
|
def verify_response(cls, response, charset):
|
||||||
return cls.parse_response_xml(response)
|
return cls.parse_response_xml(response, charset)
|
||||||
|
|
||||||
|
|
||||||
SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:'
|
SAML_1_0_NS = 'urn:oasis:names:tc:SAML:1.0:'
|
||||||
|
@ -284,6 +305,11 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
|
||||||
from elementtree import ElementTree
|
from elementtree import ElementTree
|
||||||
|
|
||||||
page = self.fetch_saml_validation(ticket)
|
page = self.fetch_saml_validation(ticket)
|
||||||
|
content_type = page.info().get('Content-type')
|
||||||
|
if "charset=" in content_type:
|
||||||
|
charset = content_type.split("charset=")[-1]
|
||||||
|
else:
|
||||||
|
charset = "ascii"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
user = None
|
user = None
|
||||||
|
@ -296,21 +322,25 @@ class CASClientWithSAMLV1(CASClientV2, SingleLogoutMixin):
|
||||||
# User is validated
|
# User is validated
|
||||||
name_identifier = tree.find('.//' + SAML_1_0_ASSERTION_NS + 'NameIdentifier')
|
name_identifier = tree.find('.//' + SAML_1_0_ASSERTION_NS + 'NameIdentifier')
|
||||||
if name_identifier is not None:
|
if name_identifier is not None:
|
||||||
user = name_identifier.text
|
user = self.unicode(name_identifier.text, charset)
|
||||||
attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute')
|
attrs = tree.findall('.//' + SAML_1_0_ASSERTION_NS + 'Attribute')
|
||||||
for at in attrs:
|
for at in attrs:
|
||||||
if self.username_attribute in list(at.attrib.values()):
|
if self.username_attribute in list(at.attrib.values()):
|
||||||
user = at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text
|
user = self.unicode(
|
||||||
attributes['uid'] = user
|
at.find(SAML_1_0_ASSERTION_NS + 'AttributeValue').text,
|
||||||
|
charset
|
||||||
|
)
|
||||||
|
attributes[u'uid'] = user
|
||||||
|
|
||||||
values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue')
|
values = at.findall(SAML_1_0_ASSERTION_NS + 'AttributeValue')
|
||||||
|
key = self.unicode(at.attrib['AttributeName'], charset)
|
||||||
if len(values) > 1:
|
if len(values) > 1:
|
||||||
values_array = []
|
values_array = []
|
||||||
for v in values:
|
for v in values:
|
||||||
values_array.append(v.text)
|
values_array.append(self.unicode(v.text, charset))
|
||||||
attributes[at.attrib['AttributeName']] = values_array
|
attributes[key] = values_array
|
||||||
else:
|
else:
|
||||||
attributes[at.attrib['AttributeName']] = values[0].text
|
attributes[key] = self.unicode(values[0].text, charset)
|
||||||
return user, attributes, None
|
return user, attributes, None
|
||||||
finally:
|
finally:
|
||||||
page.close()
|
page.close()
|
||||||
|
|
Loading…
Reference in a new issue