Add some docstrings
This commit is contained in:
parent
7cc3ba689f
commit
8e5b75e090
7 changed files with 42 additions and 4 deletions
|
@ -144,6 +144,7 @@ class DjangoAuthUser(AuthUser): # pragma: no cover
|
|||
|
||||
|
||||
class CASFederateAuth(AuthUser):
|
||||
"""Authentication class used then CAS_FEDERATE is True"""
|
||||
user = None
|
||||
|
||||
def __init__(self, username):
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# ⁻*- coding: utf-8 -*-
|
||||
# -*- coding: utf-8 -*-
|
||||
# This program is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
|
||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
|
||||
|
@ -9,6 +9,7 @@
|
|||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||
#
|
||||
# (c) 2015 Valentin Samir
|
||||
"""federated mode helper classes"""
|
||||
from .default_settings import settings
|
||||
|
||||
from .cas import CASClient
|
||||
|
@ -21,6 +22,7 @@ SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
|||
|
||||
|
||||
class CASFederateValidateUser(object):
|
||||
"""Class CAS client used to authenticate the user again a CAS provider"""
|
||||
username = None
|
||||
attributs = {}
|
||||
client = None
|
||||
|
@ -38,13 +40,15 @@ class CASFederateValidateUser(object):
|
|||
)
|
||||
|
||||
def get_login_url(self):
|
||||
"""return the CAS provider login url"""
|
||||
return self.client.get_login_url() if self.client is not None else False
|
||||
|
||||
def get_logout_url(self, redirect_url=None):
|
||||
"""return the CAS provider logout url"""
|
||||
return self.client.get_logout_url(redirect_url) if self.client is not None else False
|
||||
|
||||
def verify_ticket(self, ticket):
|
||||
"""test `password` agains the user"""
|
||||
"""test `ticket` agains the CAS provider, if valid, create the local federated user"""
|
||||
if self.client is None: # pragma: no cover (should not happen)
|
||||
return False
|
||||
try:
|
||||
|
@ -79,6 +83,7 @@ class CASFederateValidateUser(object):
|
|||
|
||||
@staticmethod
|
||||
def register_slo(username, session_key, ticket):
|
||||
"""association a ticket with a (username, session) for processing later SLO request"""
|
||||
FederateSLO.objects.create(
|
||||
username=username,
|
||||
session_key=session_key,
|
||||
|
@ -86,6 +91,7 @@ class CASFederateValidateUser(object):
|
|||
)
|
||||
|
||||
def clean_sessions(self, logout_request):
|
||||
"""process a SLO request"""
|
||||
try:
|
||||
slos = self.client.get_saml_slos(logout_request) or []
|
||||
except NameError: # pragma: no cover (should not happen)
|
||||
|
|
|
@ -29,6 +29,10 @@ class WarnForm(forms.Form):
|
|||
|
||||
|
||||
class FederateSelect(forms.Form):
|
||||
"""
|
||||
Form used on the login page when CAS_FEDERATE is True
|
||||
allowing the user to choose a identity provider.
|
||||
"""
|
||||
provider = forms.ChoiceField(
|
||||
label=_('Identity provider'),
|
||||
# with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS
|
||||
|
|
|
@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class FederatedUser(models.Model):
|
||||
"""A federated user as returner by a CAS provider (username and attributes)"""
|
||||
class Meta:
|
||||
unique_together = ("username", "provider")
|
||||
username = models.CharField(max_length=124)
|
||||
|
@ -48,6 +49,7 @@ class FederatedUser(models.Model):
|
|||
|
||||
@classmethod
|
||||
def clean_old_entries(cls):
|
||||
"""remove old unused federated users"""
|
||||
federated_users = cls.objects.filter(
|
||||
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
|
||||
)
|
||||
|
@ -58,6 +60,7 @@ class FederatedUser(models.Model):
|
|||
|
||||
|
||||
class FederateSLO(models.Model):
|
||||
"""An association between a CAS provider ticket and a (username, session) for processing SLO"""
|
||||
class Meta:
|
||||
unique_together = ("username", "session_key")
|
||||
username = models.CharField(max_length=30)
|
||||
|
@ -66,6 +69,7 @@ class FederateSLO(models.Model):
|
|||
|
||||
@classmethod
|
||||
def clean_deleted_sessions(cls):
|
||||
"""remove old object for which the session do not exists anymore"""
|
||||
for federate_slo in cls.objects.all():
|
||||
if not SessionStore(session_key=federate_slo.session_key).get('authenticated'):
|
||||
federate_slo.delete()
|
||||
|
@ -82,6 +86,7 @@ class User(models.Model):
|
|||
date = models.DateTimeField(auto_now=True)
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
"""remove the User"""
|
||||
if settings.CAS_FEDERATE:
|
||||
FederateSLO.objects.filter(
|
||||
username=self.username,
|
||||
|
|
|
@ -29,6 +29,7 @@ from cas_server import utils
|
|||
|
||||
|
||||
def return_unicode(string, charset):
|
||||
"""make `string` a unicode if `string` is a unicode or bytes encoded with `charset`"""
|
||||
if not isinstance(string, six.text_type):
|
||||
return string.decode(charset)
|
||||
else:
|
||||
|
@ -36,6 +37,10 @@ def return_unicode(string, charset):
|
|||
|
||||
|
||||
def return_bytes(string, charset):
|
||||
"""
|
||||
make `string` a bytes encoded with `charset` if `string` is a unicode
|
||||
or bytes encoded with `charset`
|
||||
"""
|
||||
if isinstance(string, six.text_type):
|
||||
return string.encode(charset)
|
||||
else:
|
||||
|
@ -200,8 +205,9 @@ class Http404Handler(HttpParamsHandler):
|
|||
|
||||
|
||||
class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
|
||||
"""A dummy CAS that validate for only one (service, ticket) used in federated mode tests"""
|
||||
def test_params(self):
|
||||
"""check that internal and provided (service, ticket) matches"""
|
||||
if (
|
||||
self.server.ticket is not None and
|
||||
self.params.get("service").encode("ascii") == self.server.service and
|
||||
|
@ -213,11 +219,13 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||
return False
|
||||
|
||||
def send_headers(self, code, content_type):
|
||||
"""send http headers"""
|
||||
self.send_response(code)
|
||||
self.send_header("Content-type", content_type)
|
||||
self.end_headers()
|
||||
|
||||
def do_GET(self):
|
||||
"""Called on a GET request on the BaseHTTPServer"""
|
||||
url = urlparse(self.path)
|
||||
self.params = dict(parse_qsl(url.query))
|
||||
if url.path == "/validate":
|
||||
|
@ -250,6 +258,7 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||
self.return_404()
|
||||
|
||||
def do_POST(self):
|
||||
"""Called on a POST request on the BaseHTTPServer"""
|
||||
url = urlparse(self.path)
|
||||
self.params = dict(parse_qsl(url.query))
|
||||
if url.path == "/samlValidate":
|
||||
|
@ -287,6 +296,7 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||
self.return_404()
|
||||
|
||||
def return_404(self):
|
||||
"""return a 404 error"""
|
||||
self.send_headers(404, "text/plain; charset=utf-8")
|
||||
self.wfile.write("not found")
|
||||
|
||||
|
@ -317,6 +327,7 @@ class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
|
|||
|
||||
|
||||
def logout_request(ticket):
|
||||
"""build a SLO request XML, ready to be send"""
|
||||
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
|
||||
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
|
||||
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
|
||||
|
|
|
@ -76,6 +76,7 @@ def reverse_params(url_name, params=None, **kwargs):
|
|||
|
||||
|
||||
def copy_params(get_or_post_params, ignore=None):
|
||||
"""copy from a dictionnary like `get_or_post_params` ignoring keys in the set `ignore`"""
|
||||
if ignore is None:
|
||||
ignore = set()
|
||||
params = {}
|
||||
|
@ -86,6 +87,7 @@ def copy_params(get_or_post_params, ignore=None):
|
|||
|
||||
|
||||
def set_cookie(response, key, value, max_age):
|
||||
"""Set the cookie `key` on `response` with value `value` valid for `max_age` secondes"""
|
||||
expires = datetime.strftime(
|
||||
datetime.utcnow() + timedelta(seconds=max_age),
|
||||
"%a, %d-%b-%Y %H:%M:%S GMT"
|
||||
|
@ -101,6 +103,7 @@ def set_cookie(response, key, value, max_age):
|
|||
|
||||
|
||||
def get_current_url(request, ignore_params=None):
|
||||
"""Giving a django request, return the current http url, possibly ignoring some GET params"""
|
||||
if ignore_params is None:
|
||||
ignore_params = set()
|
||||
protocol = 'https' if request.is_secure() else "http"
|
||||
|
@ -194,6 +197,10 @@ def gen_saml_id():
|
|||
|
||||
|
||||
def get_tuple(nuplet, index, default=None):
|
||||
"""
|
||||
return the value in index `index` of the tuple `nuplet` if it exists,
|
||||
else return `default`
|
||||
"""
|
||||
if nuplet is None:
|
||||
return default
|
||||
try:
|
||||
|
|
|
@ -192,18 +192,21 @@ class LogoutView(View, LogoutMixin):
|
|||
|
||||
|
||||
class FederateAuth(View):
|
||||
|
||||
"""view to authenticated user agains a backend CAS then CAS_FEDERATE is True"""
|
||||
@method_decorator(csrf_exempt)
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
"""dispatch different http request to the methods of the same name"""
|
||||
return super(FederateAuth, self).dispatch(request, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_cas_client(request, provider):
|
||||
"""return a CAS client object matching provider"""
|
||||
if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true)
|
||||
service_url = utils.get_current_url(request, {"ticket", "provider"})
|
||||
return CASFederateValidateUser(provider, service_url)
|
||||
|
||||
def post(self, request, provider=None):
|
||||
"""method called on POST request"""
|
||||
if not settings.CAS_FEDERATE:
|
||||
return redirect("cas_server:login")
|
||||
# POST with a provider, this is probably an SLO request
|
||||
|
@ -245,6 +248,7 @@ class FederateAuth(View):
|
|||
return redirect("cas_server:login")
|
||||
|
||||
def get(self, request, provider=None):
|
||||
"""method called on GET request"""
|
||||
if not settings.CAS_FEDERATE:
|
||||
return redirect("cas_server:login")
|
||||
if provider not in settings.CAS_FEDERATE_PROVIDERS:
|
||||
|
|
Loading…
Reference in a new issue