Add some docstrings

This commit is contained in:
Valentin Samir 2016-07-03 17:54:11 +02:00
parent 7cc3ba689f
commit 8e5b75e090
7 changed files with 42 additions and 4 deletions

View file

@ -144,6 +144,7 @@ class DjangoAuthUser(AuthUser): # pragma: no cover
class CASFederateAuth(AuthUser): class CASFederateAuth(AuthUser):
"""Authentication class used then CAS_FEDERATE is True"""
user = None user = None
def __init__(self, username): def __init__(self, username):

View file

@ -1,4 +1,4 @@
# *- coding: utf-8 -*- # -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT # This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for # FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
@ -9,6 +9,7 @@
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015 Valentin Samir
"""federated mode helper classes"""
from .default_settings import settings from .default_settings import settings
from .cas import CASClient from .cas import CASClient
@ -21,6 +22,7 @@ SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
class CASFederateValidateUser(object): class CASFederateValidateUser(object):
"""Class CAS client used to authenticate the user again a CAS provider"""
username = None username = None
attributs = {} attributs = {}
client = None client = None
@ -38,13 +40,15 @@ class CASFederateValidateUser(object):
) )
def get_login_url(self): def get_login_url(self):
"""return the CAS provider login url"""
return self.client.get_login_url() if self.client is not None else False return self.client.get_login_url() if self.client is not None else False
def get_logout_url(self, redirect_url=None): def get_logout_url(self, redirect_url=None):
"""return the CAS provider logout url"""
return self.client.get_logout_url(redirect_url) if self.client is not None else False return self.client.get_logout_url(redirect_url) if self.client is not None else False
def verify_ticket(self, ticket): def verify_ticket(self, ticket):
"""test `password` agains the user""" """test `ticket` agains the CAS provider, if valid, create the local federated user"""
if self.client is None: # pragma: no cover (should not happen) if self.client is None: # pragma: no cover (should not happen)
return False return False
try: try:
@ -79,6 +83,7 @@ class CASFederateValidateUser(object):
@staticmethod @staticmethod
def register_slo(username, session_key, ticket): def register_slo(username, session_key, ticket):
"""association a ticket with a (username, session) for processing later SLO request"""
FederateSLO.objects.create( FederateSLO.objects.create(
username=username, username=username,
session_key=session_key, session_key=session_key,
@ -86,6 +91,7 @@ class CASFederateValidateUser(object):
) )
def clean_sessions(self, logout_request): def clean_sessions(self, logout_request):
"""process a SLO request"""
try: try:
slos = self.client.get_saml_slos(logout_request) or [] slos = self.client.get_saml_slos(logout_request) or []
except NameError: # pragma: no cover (should not happen) except NameError: # pragma: no cover (should not happen)

View file

@ -29,6 +29,10 @@ class WarnForm(forms.Form):
class FederateSelect(forms.Form): class FederateSelect(forms.Form):
"""
Form used on the login page when CAS_FEDERATE is True
allowing the user to choose a identity provider.
"""
provider = forms.ChoiceField( provider = forms.ChoiceField(
label=_('Identity provider'), label=_('Identity provider'),
# with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS # with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS

View file

@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
class FederatedUser(models.Model): class FederatedUser(models.Model):
"""A federated user as returner by a CAS provider (username and attributes)"""
class Meta: class Meta:
unique_together = ("username", "provider") unique_together = ("username", "provider")
username = models.CharField(max_length=124) username = models.CharField(max_length=124)
@ -48,6 +49,7 @@ class FederatedUser(models.Model):
@classmethod @classmethod
def clean_old_entries(cls): def clean_old_entries(cls):
"""remove old unused federated users"""
federated_users = cls.objects.filter( federated_users = cls.objects.filter(
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT)) last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
) )
@ -58,6 +60,7 @@ class FederatedUser(models.Model):
class FederateSLO(models.Model): class FederateSLO(models.Model):
"""An association between a CAS provider ticket and a (username, session) for processing SLO"""
class Meta: class Meta:
unique_together = ("username", "session_key") unique_together = ("username", "session_key")
username = models.CharField(max_length=30) username = models.CharField(max_length=30)
@ -66,6 +69,7 @@ class FederateSLO(models.Model):
@classmethod @classmethod
def clean_deleted_sessions(cls): def clean_deleted_sessions(cls):
"""remove old object for which the session do not exists anymore"""
for federate_slo in cls.objects.all(): for federate_slo in cls.objects.all():
if not SessionStore(session_key=federate_slo.session_key).get('authenticated'): if not SessionStore(session_key=federate_slo.session_key).get('authenticated'):
federate_slo.delete() federate_slo.delete()
@ -82,6 +86,7 @@ class User(models.Model):
date = models.DateTimeField(auto_now=True) date = models.DateTimeField(auto_now=True)
def delete(self, *args, **kwargs): def delete(self, *args, **kwargs):
"""remove the User"""
if settings.CAS_FEDERATE: if settings.CAS_FEDERATE:
FederateSLO.objects.filter( FederateSLO.objects.filter(
username=self.username, username=self.username,

View file

@ -29,6 +29,7 @@ from cas_server import utils
def return_unicode(string, charset): def return_unicode(string, charset):
"""make `string` a unicode if `string` is a unicode or bytes encoded with `charset`"""
if not isinstance(string, six.text_type): if not isinstance(string, six.text_type):
return string.decode(charset) return string.decode(charset)
else: else:
@ -36,6 +37,10 @@ def return_unicode(string, charset):
def return_bytes(string, charset): def return_bytes(string, charset):
"""
make `string` a bytes encoded with `charset` if `string` is a unicode
or bytes encoded with `charset`
"""
if isinstance(string, six.text_type): if isinstance(string, six.text_type):
return string.encode(charset) return string.encode(charset)
else: else:
@ -200,8 +205,9 @@ class Http404Handler(HttpParamsHandler):
class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler): class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
"""A dummy CAS that validate for only one (service, ticket) used in federated mode tests"""
def test_params(self): def test_params(self):
"""check that internal and provided (service, ticket) matches"""
if ( if (
self.server.ticket is not None and self.server.ticket is not None and
self.params.get("service").encode("ascii") == self.server.service and self.params.get("service").encode("ascii") == self.server.service and
@ -213,11 +219,13 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
return False return False
def send_headers(self, code, content_type): def send_headers(self, code, content_type):
"""send http headers"""
self.send_response(code) self.send_response(code)
self.send_header("Content-type", content_type) self.send_header("Content-type", content_type)
self.end_headers() self.end_headers()
def do_GET(self): def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
url = urlparse(self.path) url = urlparse(self.path)
self.params = dict(parse_qsl(url.query)) self.params = dict(parse_qsl(url.query))
if url.path == "/validate": if url.path == "/validate":
@ -250,6 +258,7 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
self.return_404() self.return_404()
def do_POST(self): def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
url = urlparse(self.path) url = urlparse(self.path)
self.params = dict(parse_qsl(url.query)) self.params = dict(parse_qsl(url.query))
if url.path == "/samlValidate": if url.path == "/samlValidate":
@ -287,6 +296,7 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
self.return_404() self.return_404()
def return_404(self): def return_404(self):
"""return a 404 error"""
self.send_headers(404, "text/plain; charset=utf-8") self.send_headers(404, "text/plain; charset=utf-8")
self.wfile.write("not found") self.wfile.write("not found")
@ -317,6 +327,7 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
def logout_request(ticket): def logout_request(ticket):
"""build a SLO request XML, ready to be send"""
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s"> ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID> <saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>

View file

@ -76,6 +76,7 @@ def reverse_params(url_name, params=None, **kwargs):
def copy_params(get_or_post_params, ignore=None): def copy_params(get_or_post_params, ignore=None):
"""copy from a dictionnary like `get_or_post_params` ignoring keys in the set `ignore`"""
if ignore is None: if ignore is None:
ignore = set() ignore = set()
params = {} params = {}
@ -86,6 +87,7 @@ def copy_params(get_or_post_params, ignore=None):
def set_cookie(response, key, value, max_age): def set_cookie(response, key, value, max_age):
"""Set the cookie `key` on `response` with value `value` valid for `max_age` secondes"""
expires = datetime.strftime( expires = datetime.strftime(
datetime.utcnow() + timedelta(seconds=max_age), datetime.utcnow() + timedelta(seconds=max_age),
"%a, %d-%b-%Y %H:%M:%S GMT" "%a, %d-%b-%Y %H:%M:%S GMT"
@ -101,6 +103,7 @@ def set_cookie(response, key, value, max_age):
def get_current_url(request, ignore_params=None): def get_current_url(request, ignore_params=None):
"""Giving a django request, return the current http url, possibly ignoring some GET params"""
if ignore_params is None: if ignore_params is None:
ignore_params = set() ignore_params = set()
protocol = 'https' if request.is_secure() else "http" protocol = 'https' if request.is_secure() else "http"
@ -194,6 +197,10 @@ def gen_saml_id():
def get_tuple(nuplet, index, default=None): def get_tuple(nuplet, index, default=None):
"""
return the value in index `index` of the tuple `nuplet` if it exists,
else return `default`
"""
if nuplet is None: if nuplet is None:
return default return default
try: try:

View file

@ -192,18 +192,21 @@ class LogoutView(View, LogoutMixin):
class FederateAuth(View): class FederateAuth(View):
"""view to authenticated user agains a backend CAS then CAS_FEDERATE is True"""
@method_decorator(csrf_exempt) @method_decorator(csrf_exempt)
def dispatch(self, request, *args, **kwargs): def dispatch(self, request, *args, **kwargs):
"""dispatch different http request to the methods of the same name"""
return super(FederateAuth, self).dispatch(request, *args, **kwargs) return super(FederateAuth, self).dispatch(request, *args, **kwargs)
@staticmethod @staticmethod
def get_cas_client(request, provider): def get_cas_client(request, provider):
"""return a CAS client object matching provider"""
if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true) if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true)
service_url = utils.get_current_url(request, {"ticket", "provider"}) service_url = utils.get_current_url(request, {"ticket", "provider"})
return CASFederateValidateUser(provider, service_url) return CASFederateValidateUser(provider, service_url)
def post(self, request, provider=None): def post(self, request, provider=None):
"""method called on POST request"""
if not settings.CAS_FEDERATE: if not settings.CAS_FEDERATE:
return redirect("cas_server:login") return redirect("cas_server:login")
# POST with a provider, this is probably an SLO request # POST with a provider, this is probably an SLO request
@ -245,6 +248,7 @@ class FederateAuth(View):
return redirect("cas_server:login") return redirect("cas_server:login")
def get(self, request, provider=None): def get(self, request, provider=None):
"""method called on GET request"""
if not settings.CAS_FEDERATE: if not settings.CAS_FEDERATE:
return redirect("cas_server:login") return redirect("cas_server:login")
if provider not in settings.CAS_FEDERATE_PROVIDERS: if provider not in settings.CAS_FEDERATE_PROVIDERS: