django-cas-server/cas_server/utils.py

398 lines
13 KiB
Python
Raw Normal View History

2015-06-03 15:42:25 +00:00
# -*- coding: utf-8 -*-
2015-05-27 20:10:06 +00:00
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2015 Valentin Samir
2015-05-27 19:56:39 +00:00
"""Some util function for the app"""
from .default_settings import settings
2015-05-29 14:11:10 +00:00
from django.core.urlresolvers import reverse
from django.http import HttpResponseRedirect, HttpResponse
from django.contrib import messages
import random
import string
import json
import hashlib
import crypt
import base64
import six
2016-06-24 19:07:19 +00:00
from threading import Thread
2015-12-12 12:51:59 +00:00
from importlib import import_module
2016-06-17 17:28:49 +00:00
from datetime import datetime, timedelta
2016-06-24 21:37:24 +00:00
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
2015-06-21 16:56:16 +00:00
def context(params):
    """
    Expose the app settings to a template context.

    :param params: the template context; it is modified in place
    :return: `params` itself, with its ``settings`` key set to the app settings
    """
    settings_key = "settings"
    params[settings_key] = settings
    return params
2016-06-26 14:02:25 +00:00
def json_response(request, data):
    """
    Build a JSON :class:`HttpResponse` from `data`, attaching the django messages of `request`.

    The messages returned by ``messages.get_messages(request)`` are stored, as a list of
    ``{'message': ..., 'level': ...}`` dicts, under the ``messages`` key of `data`
    (`data` is modified in place) before serialization.

    :param request: the current request
    :param dict data: the JSON-serializable payload
    :return: an ``application/json`` response
    """
    collected = []
    data["messages"] = collected
    for queued in messages.get_messages(request):
        entry = {'message': queued.message, 'level': queued.level_tag}
        collected.append(entry)
    payload = json.dumps(data)
    return HttpResponse(payload, content_type="application/json")
def import_attr(path):
    """
    Transform a python ``module.attr`` dotted path into the designated attribute.

    :param path: a dotted path ``module.attr``. Any value that is not a ``str`` is returned
        unchanged, so an already imported object may be passed instead of a path.
    :return: the attribute ``attr`` of the module ``module``
    :raises ValueError: if `path` contains no dot
    :raises ImportError: if the module cannot be imported
    :raises AttributeError: if the module has no such attribute
    """
    if not isinstance(path, str):
        return path
    if "." not in path:
        # The error used to be built but never raised, so a dotless path then failed
        # on the tuple unpacking below with an unrelated message.
        raise ValueError("%r should be of the form `module.attr` and we just got `attr`" % path)
    module, attr = path.rsplit('.', 1)
    try:
        return getattr(import_module(module), attr)
    except ImportError:
        raise ImportError("Module %r not found" % module)
    except AttributeError:
        raise AttributeError("Module %r has not attribut %r" % (module, attr))
2015-06-12 16:10:52 +00:00
2015-05-29 14:11:10 +00:00
def redirect_params(url_name, params=None):
    """
    Redirect to `url_name` with `params` as querystring.

    :param str url_name: the url name to reverse
    :param dict params: query string parameters, may be ``None`` or empty
    :return: a redirection to the reversed url. As with :func:`reverse_params`, ``?`` and
        the query string are only appended when `params` is not empty (a bare trailing
        ``?`` was appended before).
    :rtype: HttpResponseRedirect
    """
    return HttpResponseRedirect(reverse_params(url_name, params))
2015-06-12 16:10:52 +00:00
def reverse_params(url_name, params=None, **kwargs):
    """
    Reverse `url_name` and append `params` to it as a query string.

    :param str url_name: the url name to reverse
    :param dict params: query string parameters, may be ``None`` or empty
    :param kwargs: extra keyword arguments forwarded to ``reverse``
    :return: the reversed url, followed by ``?<query string>`` only when the encoded
        query string is not empty
    :rtype: str
    """
    base_url = reverse(url_name, **kwargs)
    query_string = urlencode(params or {})
    return "%s?%s" % (base_url, query_string) if query_string else base_url
def copy_params(get_or_post_params, ignore=frozenset()):
    """
    Copy the non-empty parameters of `get_or_post_params`, skipping the keys in `ignore`.

    :param get_or_post_params: a mapping of parameters (e.g. ``request.GET``). For each key,
        the value obtained with ``get_or_post_params[key]`` is copied.
    :param ignore: a container of keys to leave out. The default is an immutable
        ``frozenset`` rather than a shared mutable ``set``.
    :return: a new dict of the copied parameters; falsy values are dropped
    :rtype: dict
    """
    return {
        key: get_or_post_params[key]
        for key in get_or_post_params
        if key not in ignore and get_or_post_params[key]
    }
def set_cookie(response, key, value, max_age):
    """
    Set the cookie `key` to `value` on `response` for `max_age` seconds.

    Both ``max_age`` and an explicit ``expires`` date (computed from the current UTC time)
    are sent. The cookie domain and the secure flag come from the
    ``SESSION_COOKIE_DOMAIN`` and ``SESSION_COOKIE_SECURE`` settings.

    :param response: the response to set the cookie on
    :param str key: the cookie name
    :param value: the cookie value
    :param int max_age: the cookie lifetime in seconds
    """
    expiration = datetime.utcnow() + timedelta(seconds=max_age)
    response.set_cookie(
        key,
        value,
        max_age=max_age,
        expires=expiration.strftime("%a, %d-%b-%Y %H:%M:%S GMT"),
        domain=settings.SESSION_COOKIE_DOMAIN,
        secure=settings.SESSION_COOKIE_SECURE or None,
    )
def get_current_url(request, ignore_params=frozenset()):
    """
    Return the absolute url of `request`, query string included.

    Query string parameters that are empty or whose key is in `ignore_params` are
    dropped (see :func:`copy_params`).

    :param request: the current request
    :param ignore_params: a container of parameter names to drop. The default is an
        immutable ``frozenset`` rather than a shared mutable ``set``.
    :return: ``<http|https>://<host><path>[?<query string>]``
    :rtype: str
    """
    protocol = 'https' if request.is_secure() else "http"
    service_url = "%s://%s%s" % (protocol, request.get_host(), request.path)
    if request.GET:
        params = copy_params(request.GET, ignore_params)
        if params:
            service_url += "?%s" % urlencode(params)
    return service_url
2015-05-16 21:43:46 +00:00
def update_url(url, params):
    """
    Add or replace `params` in the query string of `url`.

    Existing query parameters of `url` are parsed with ``parse_qsl`` defaults and stored
    in a dict. Parameters with an empty value are therefore dropped, and only the last
    value of a repeated key is kept.

    :param url: the url to update, as text or utf-8 encoded bytes
    :param dict params: the parameters to set. Keys and values may be text or utf-8
        encoded bytes. This dict is not modified (it used to be rewritten in place with
        encoded keys and values).
    :return: the updated url
    :rtype: unicode
    """
    def to_bytes(value):
        """Return `value` utf-8 encoded, unless it is already bytes."""
        return value if isinstance(value, bytes) else value.encode('utf-8')

    url = to_bytes(url)
    # Work on an encoded copy so the caller's dict is left untouched.
    encoded_params = {to_bytes(key): to_bytes(value) for key, value in params.items()}

    url_parts = list(urlparse(url))
    query = dict(parse_qsl(url_parts[4]))
    query.update(encoded_params)
    url_parts[4] = urlencode(query)
    # urlencode returns text while the other parts are bytes: urlunparse needs one type.
    url_parts = [to_bytes(part) for part in url_parts]
    return urlunparse(url_parts).decode('utf-8')
2015-06-12 16:10:52 +00:00
def unpack_nested_exception(error):
    """
    Return the innermost exception of a chain of exceptions nested in ``args``.

    While one of the ``args`` of the current exception is itself an exception, descend
    into the first such argument. Stop on an exception with no exception argument.

    :param error: the outermost exception
    :return: the innermost nested exception (`error` itself if nothing is nested)
    """
    while True:
        nested = next((arg for arg in error.args if isinstance(arg, Exception)), None)
        if nested is None:
            return error
        error = nested
def _gen_ticket(prefix, lg=settings.CAS_TICKET_LEN):
"""Generate a ticket with prefix `prefix`"""
return '%s-%s' % (
prefix,
''.join(
random.choice(
string.ascii_letters + string.digits
) for _ in range(lg - len(prefix) - 1)
)
)
2015-06-12 16:10:52 +00:00
def gen_lt():
    """
    Generate a Login Ticket.

    :return: a ticket prefixed by ``settings.CAS_LOGIN_TICKET_PREFIX`` whose total length
        is ``settings.CAS_LT_LEN``
    :rtype: str
    """
    return _gen_ticket(settings.CAS_LOGIN_TICKET_PREFIX, settings.CAS_LT_LEN)
2015-06-12 16:10:52 +00:00
def gen_st():
    """
    Generate a Service Ticket.

    :return: a ticket prefixed by ``settings.CAS_SERVICE_TICKET_PREFIX`` whose total length
        is ``settings.CAS_ST_LEN``
    """
    prefix, length = settings.CAS_SERVICE_TICKET_PREFIX, settings.CAS_ST_LEN
    return _gen_ticket(prefix, length)
2015-06-12 16:10:52 +00:00
def gen_pt():
    """
    Generate a Proxy Ticket.

    :return: a ticket prefixed by ``settings.CAS_PROXY_TICKET_PREFIX`` whose total length
        is ``settings.CAS_PT_LEN``
    """
    prefix, length = settings.CAS_PROXY_TICKET_PREFIX, settings.CAS_PT_LEN
    return _gen_ticket(prefix, length)
2015-06-12 16:10:52 +00:00
def gen_pgt():
    """
    Generate a Proxy Granting Ticket.

    :return: a ticket prefixed by ``settings.CAS_PROXY_GRANTING_TICKET_PREFIX`` whose total
        length is ``settings.CAS_PGT_LEN``
    """
    prefix, length = settings.CAS_PROXY_GRANTING_TICKET_PREFIX, settings.CAS_PGT_LEN
    return _gen_ticket(prefix, length)
2015-06-12 16:10:52 +00:00
def gen_pgtiou():
    """
    Generate a Proxy Granting Ticket IOU.

    :return: a ticket prefixed by ``settings.CAS_PROXY_GRANTING_TICKET_IOU_PREFIX`` whose
        total length is ``settings.CAS_PGTIOU_LEN``
    """
    prefix = settings.CAS_PROXY_GRANTING_TICKET_IOU_PREFIX
    length = settings.CAS_PGTIOU_LEN
    return _gen_ticket(prefix, length)
def gen_saml_id():
    """
    Generate a SAML id.

    :return: a ticket with the ``_`` prefix and the default length of ``_gen_ticket``
    """
    return _gen_ticket(prefix='_')
def get_tuple(tuple, index, default=None):
    """
    Return ``tuple[index]``, or `default` when it cannot be read.

    :param tuple: an indexable sequence, or ``None``
    :param int index: the index to read
    :param default: returned when `tuple` is ``None`` or `index` is out of range
    """
    if tuple is not None:
        try:
            return tuple[index]
        except IndexError:
            pass
    return default
2016-06-27 22:34:31 +00:00
2016-06-24 19:07:19 +00:00
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
    """
    Minimal HTTP request handler recording the query string of the GET requests it serves.

    The parameters of each handled GET request are merged into the class attribute
    :attr:`PARAMS`, and the request is answered with a ``200`` plain text ``ok``.
    """
    # Query string parameters received so far. This is a class attribute, updated through
    # ``PGTUrlHandler.PARAMS``, so it is shared by every handler instance.
    PARAMS = {}

    def do_GET(self):
        """Record the query string parameters, then answer ``200 ok`` in plain text."""
        # Record the parameters before answering, so a client that received the response
        # is guaranteed to find them in PARAMS.
        url = urlparse(self.path)
        params = dict(parse_qsl(url.query))
        PGTUrlHandler.PARAMS.update(params)

        self.send_response(200)
        # The header name must be a native str: on Python 3 a bytes name is formatted
        # with "%s" as "b'Content-type'", producing a bogus header.
        self.send_header("Content-type", "text/plain")
        self.end_headers()
        self.wfile.write(b"ok")

    def log_message(self, *args):
        """Silence the default per-request logging."""
        return

    @staticmethod
    def run():
        """
        Serve exactly one request with this handler in a background daemon thread.

        The server listens on 127.0.0.1 with port 0, letting the OS choose a free port.

        :return: a tuple ``(thread, host, port)``: the serving thread and the address
            the server listens on
        """
        server_class = BaseHTTPServer.HTTPServer
        httpd = server_class(("127.0.0.1", 0), PGTUrlHandler)
        (host, port) = httpd.socket.getsockname()

        def launch():
            """Handle a single request, then close the listening socket."""
            try:
                httpd.handle_request()
            finally:
                httpd.server_close()

        httpd_thread = Thread(target=launch)
        httpd_thread.daemon = True
        httpd_thread.start()
        return (httpd_thread, host, port)
class LdapHashUserPassword(object):
    """
    Generate and parse LDAP ``userPassword`` hashed values, of the form ``{SCHEME}payload``.

    Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html

    Schemes, passwords, salts and hashed values are expected as bytes strings.
    """
    # Schemes taking a salt.
    schemes_salt = {b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}", b"{CRYPT}"}
    # Schemes taking no salt.
    schemes_nosalt = {b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"}

    # Hash function of each scheme. {CRYPT} has none: it relies on the system crypt.
    _schemes_to_hash = {
        b"{SMD5}": hashlib.md5,
        b"{MD5}": hashlib.md5,
        b"{SSHA}": hashlib.sha1,
        b"{SHA}": hashlib.sha1,
        b"{SSHA256}": hashlib.sha256,
        b"{SHA256}": hashlib.sha256,
        b"{SSHA384}": hashlib.sha384,
        b"{SHA384}": hashlib.sha384,
        b"{SSHA512}": hashlib.sha512,
        b"{SHA512}": hashlib.sha512
    }

    # Digest size, in bytes, of the salted hash schemes. The base64 payload of a salted
    # hash is the digest directly followed by the salt.
    _schemes_to_len = {
        b"{SMD5}": 16,
        b"{SSHA}": 20,
        b"{SSHA256}": 32,
        b"{SSHA384}": 48,
        b"{SSHA512}": 64,
    }

    class BadScheme(ValueError):
        """Error raised when the hash scheme is not in schemes_salt | schemes_nosalt"""
        pass

    class BadHash(ValueError):
        """Error raised when the hashed value is malformed or too short"""
        pass

    class BadSalt(ValueError):
        """Error raised when, with the scheme {CRYPT}, the salt is invalid"""
        pass

    @classmethod
    def _raise_bad_scheme(cls, scheme, valid, msg):
        """
        Raise BadScheme error for `scheme`.

        :param bytes scheme: the rejected scheme
        :param valid: the accepted schemes, listed sorted in the error message
        :param str msg: a format string taking ``(scheme, comma separated valid schemes)``
        :raises BadScheme: always
        """
        valid_schemes = [s.decode() for s in valid]
        valid_schemes.sort()
        raise cls.BadScheme(msg % (scheme, u", ".join(valid_schemes)))

    @classmethod
    def _test_scheme(cls, scheme):
        """Raise BadScheme unless `scheme` is a known scheme (salted or not)."""
        if scheme not in cls.schemes_salt and scheme not in cls.schemes_nosalt:
            cls._raise_bad_scheme(
                scheme,
                cls.schemes_salt | cls.schemes_nosalt,
                "The scheme %r is not valid. Valide schemes are %s."
            )

    @classmethod
    def _test_scheme_salt(cls, scheme):
        """Raise BadScheme unless `scheme` takes a salt."""
        if scheme not in cls.schemes_salt:
            cls._raise_bad_scheme(
                scheme,
                cls.schemes_salt,
                "The scheme %r is only valid without a salt. Valide schemes with salt are %s."
            )

    @classmethod
    def _test_scheme_nosalt(cls, scheme):
        """Raise BadScheme unless `scheme` takes no salt."""
        if scheme not in cls.schemes_nosalt:
            cls._raise_bad_scheme(
                scheme,
                cls.schemes_nosalt,
                "The scheme %r is only valid with a salt. Valide schemes without salt are %s."
            )

    @classmethod
    def hash(cls, scheme, password, salt=None, charset="utf8"):
        """
        Hash `password` with `scheme` using `salt`.

        :param bytes scheme: the scheme, e.g. ``b"{SSHA}"`` (case insensitive)
        :param bytes password: the clear text password
        :param bytes salt: the salt; ``None`` or ``b""`` for the unsalted schemes
        :param str charset: encoding used to decode/encode text for the system crypt
            (scheme {CRYPT} only)
        :return: the hashed value ``scheme + payload``
        :rtype: bytes
        :raises BadScheme: if the scheme is unknown or does not match the salt presence
        :raises BadSalt: if the system crypt rejects the salt ({CRYPT} only)
        """
        scheme = scheme.upper()
        cls._test_scheme(scheme)
        if salt is None or salt == b"":
            salt = b""
            cls._test_scheme_nosalt(scheme)
        else:
            cls._test_scheme_salt(scheme)

        hash_function = cls._schemes_to_hash.get(scheme)
        if hash_function is not None:
            return scheme + base64.b64encode(hash_function(password + salt).digest() + salt)

        # The only valid scheme without a hash function is {CRYPT}: use the system crypt.
        if six.PY3:
            password = password.decode(charset)
            salt = salt.decode(charset)
        hashed_password = crypt.crypt(password, salt)
        if hashed_password is None:
            raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
        if six.PY3:
            hashed_password = hashed_password.encode(charset)
        return scheme + hashed_password

    @classmethod
    def get_scheme(cls, hashed_passord):
        """
        Return the scheme of `hashed_passord`, upper cased.

        :param bytes hashed_passord: a value of the form ``{SCHEME}payload``
        :rtype: bytes
        :raises BadHash: if `hashed_passord` does not start with a scheme enclosed in ``{ }``
            (including when it is empty)
        """
        # startswith also copes with an empty value, where indexing [0] raised IndexError.
        if not hashed_passord.startswith(b'{') or b'}' not in hashed_passord:
            raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord)
        scheme = hashed_passord.split(b'}', 1)[0]
        scheme = scheme.upper() + b"}"
        return scheme

    @classmethod
    def get_salt(cls, hashed_passord):
        """
        Return the salt of `hashed_passord`, possibly empty.

        :param bytes hashed_passord: a value of the form ``{SCHEME}payload``
        :return: ``b""`` for unsalted schemes. For {CRYPT}, the salt part of the crypt
            value (without the scheme). Otherwise, the bytes following the digest in the
            decoded payload.
        :rtype: bytes
        :raises BadHash: if the value is malformed or too short for its scheme
        :raises BadScheme: if the scheme is unknown
        """
        scheme = cls.get_scheme(hashed_passord)
        cls._test_scheme(scheme)
        if scheme in cls.schemes_nosalt:
            return b""
        elif scheme == b'{CRYPT}':
            # Strip the scheme first: it used to be kept in the returned salt.
            # The salt extraction mirrors check_password's "crypt" method.
            crypted = hashed_passord[len(scheme):]
            if crypted.startswith(b'$'):
                # "$id$salt$hash" -> "$id$salt"
                return b'$'.join(crypted.split(b'$', 3)[:-1])
            elif crypted.startswith(b'_'):
                return crypted[:9]
            else:
                return crypted[:2]
        else:
            hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
            if len(hashed_passord) < cls._schemes_to_len[scheme]:
                raise cls.BadHash("Hash too short for the scheme %s" % scheme)
            return hashed_passord[cls._schemes_to_len[scheme]:]
def check_password(method, password, hashed_password, charset):
    """
    Check that `password` match `hashed_password` using `method`,
    assuming the encoding is `charset`.

    :param str method: ``"plain"``, ``"crypt"``, ``"ldap"`` or ``"hex_<algo>"`` with
        ``<algo>`` one of md5, sha1, sha224, sha256, sha384, sha512
    :param password: the clear text password to check, as text or bytes
    :param hashed_password: the reference value, as text or bytes. For the ``hex_``
        methods its case does not matter.
    :param str charset: encoding used to convert text to bytes
    :return: ``True`` if `password` matches `hashed_password`
    :rtype: bool
    :raises ValueError: if `method` is unknown, if the system crypt rejects the salt, or
        (via LdapHashUserPassword) if an ldap hashed value is malformed
    """
    # Local import keeps this block self-contained. compare_digest compares in constant
    # time, so response timing does not reveal how much of a secret matched.
    import hmac

    if not isinstance(password, bytes):
        password = password.encode(charset)
    if not isinstance(hashed_password, bytes):
        hashed_password = hashed_password.encode(charset)

    if method == "plain":
        return hmac.compare_digest(password, hashed_password)
    elif method == "crypt":
        # The salt is "$id$salt" for "$id$salt$hash" values, 9 chars for values starting
        # with "_", and otherwise the first 2 chars.
        if hashed_password.startswith(b'$'):
            salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
        elif hashed_password.startswith(b'_'):
            salt = hashed_password[:9]
        else:
            salt = hashed_password[:2]
        if six.PY3:
            password = password.decode(charset)
            salt = salt.decode(charset)
        crypted_password = crypt.crypt(password, salt)
        if crypted_password is None:
            raise ValueError("System crypt implementation do not support the salt %r" % salt)
        if six.PY3:
            # Compare as bytes: compare_digest only accepts ASCII text.
            crypted_password = crypted_password.encode(charset)
        return hmac.compare_digest(crypted_password, hashed_password)
    elif method == "ldap":
        scheme = LdapHashUserPassword.get_scheme(hashed_password)
        salt = LdapHashUserPassword.get_salt(hashed_password)
        return hmac.compare_digest(
            LdapHashUserPassword.hash(scheme, password, salt, charset=charset),
            hashed_password
        )
    elif (
        method.startswith("hex_") and
        method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"}
    ):
        digest = getattr(hashlib, method[4:])(password).hexdigest().encode("ascii")
        return hmac.compare_digest(digest, hashed_password.lower())
    else:
        raise ValueError("Unknown password method check %r" % method)