By session logout
This commit is contained in:
parent
41fcc06200
commit
245086f6ef
7 changed files with 111 additions and 22 deletions
|
@ -22,22 +22,24 @@ class UserCredential(forms.Form):
|
|||
username = forms.CharField(label=_('login'))
|
||||
service = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||
password = forms.CharField(label=_('password'), widget=forms.PasswordInput)
|
||||
lt = forms.CharField(widget=forms.HiddenInput())
|
||||
lt = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||
method = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||
warn = forms.BooleanField(label=_('warn'), required=False)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, request, *args, **kwargs):
|
||||
self.request = request
|
||||
super(UserCredential, self).__init__(*args, **kwargs)
|
||||
|
||||
def clean(self):
|
||||
cleaned_data = super(UserCredential, self).clean()
|
||||
auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username"))
|
||||
if auth.test_password(cleaned_data.get("password")):
|
||||
session = utils.get_session(self.request)
|
||||
try:
|
||||
user = models.User.objects.get(username=auth.username)
|
||||
user = models.User.objects.get(username=auth.username, session=session)
|
||||
user.save()
|
||||
except models.User.DoesNotExist:
|
||||
user = models.User.objects.create(username=auth.username)
|
||||
user = models.User.objects.create(username=auth.username, session=session)
|
||||
user.save()
|
||||
else:
|
||||
raise forms.ValidationError(_(u"Bad user"))
|
||||
|
|
|
@ -8,5 +8,6 @@ class Command(BaseCommand):
|
|||
help = _(u"Clean old trickets")
|
||||
|
||||
def handle(self, *args, **options):
|
||||
models.User.clean_old_entries()
|
||||
for ticket_class in [models.ServiceTicket, models.ProxyTicket, models.ProxyGrantingTicket]:
|
||||
ticket_class.clean()
|
||||
ticket_class.clean_old_entries()
|
||||
|
|
31
cas_server/migrations/0019_auto_20150609_1903.py
Normal file
31
cas_server/migrations/0019_auto_20150609_1903.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import models, migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('sessions', '0001_initial'),
|
||||
('cas_server', '0018_auto_20150608_1621'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='user',
|
||||
name='session',
|
||||
field=models.OneToOneField(related_name='cas_server_user', null=True, blank=True, to='sessions.Session'),
|
||||
preserve_default=True,
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='username',
|
||||
field=models.CharField(max_length=30),
|
||||
preserve_default=True,
|
||||
),
|
||||
migrations.AlterUniqueTogether(
|
||||
name='user',
|
||||
unique_together=set([('username', 'session')]),
|
||||
),
|
||||
]
|
21
cas_server/migrations/0020_auto_20150609_1917.py
Normal file
21
cas_server/migrations/0020_auto_20150609_1917.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import models, migrations
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('cas_server', '0019_auto_20150609_1903'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='user',
|
||||
name='session',
|
||||
field=models.OneToOneField(related_name='cas_server_user', null=True, on_delete=django.db.models.deletion.SET_NULL, blank=True, to='sessions.Session'),
|
||||
preserve_default=True,
|
||||
),
|
||||
]
|
|
@ -17,6 +17,7 @@ from django.db.models import Q
|
|||
from django.contrib import messages
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from django.utils import timezone
|
||||
from django.contrib.sessions.models import Session
|
||||
from picklefield.fields import PickledObjectField
|
||||
|
||||
import re
|
||||
|
@ -30,18 +31,31 @@ import utils
|
|||
|
||||
class User(models.Model):
|
||||
"""A user logged into the CAS"""
|
||||
username = models.CharField(max_length=30, unique=True)
|
||||
class Meta:
|
||||
unique_together = ("username", "session")
|
||||
session = models.OneToOneField(Session, related_name="cas_server_user", blank=True, null=True, on_delete=models.SET_NULL)
|
||||
username = models.CharField(max_length=30)
|
||||
date = models.DateTimeField(auto_now_add=True, auto_now=True)
|
||||
|
||||
@classmethod
|
||||
def clean_old_entries(cls):
|
||||
users = cls.objects.filter(session=None)
|
||||
for user in users:
|
||||
user.logout()
|
||||
users.delete()
|
||||
|
||||
@property
|
||||
def attributs(self):
|
||||
"""return a fresh dict for the user attributs"""
|
||||
return utils.import_attr(settings.CAS_AUTH_CLASS)(self.username).attributs()
|
||||
|
||||
def __unicode__(self):
|
||||
return self.username
|
||||
if self.session:
|
||||
return u"%s - %s" % (self.username, self.session.session_key)
|
||||
else:
|
||||
return self.username
|
||||
|
||||
def logout(self, request):
|
||||
def logout(self, request=None):
|
||||
"""Sending SLO request to all services the user logged in"""
|
||||
async_list = []
|
||||
session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10))
|
||||
|
@ -59,12 +73,13 @@ class User(models.Model):
|
|||
try:
|
||||
future.result()
|
||||
except Exception as error:
|
||||
error = utils.unpack_nested_exception(error)
|
||||
messages.add_message(
|
||||
request,
|
||||
messages.WARNING,
|
||||
_(u'Error during service logout %s') % error
|
||||
)
|
||||
if request is not None:
|
||||
error = utils.unpack_nested_exception(error)
|
||||
messages.add_message(
|
||||
request,
|
||||
messages.WARNING,
|
||||
_(u'Error during service logout %s') % error
|
||||
)
|
||||
|
||||
def get_ticket(self, ticket_class, service, service_pattern, renew):
|
||||
"""
|
||||
|
@ -309,7 +324,7 @@ class Ticket(models.Model):
|
|||
return u"Ticket(%s, %s)" % (self.user, self.service)
|
||||
|
||||
@classmethod
|
||||
def clean(cls):
|
||||
def clean_old_entries(cls):
|
||||
"""Remove old ticket and send SLO to timed-out services"""
|
||||
# removing old validated ticket and non validated expired tickets
|
||||
cls.objects.filter(
|
||||
|
|
|
@ -15,6 +15,7 @@ from .default_settings import settings
|
|||
from django.utils.importlib import import_module
|
||||
from django.core.urlresolvers import reverse
|
||||
from django.http import HttpResponseRedirect
|
||||
from django.contrib.sessions.models import Session
|
||||
|
||||
import urlparse
|
||||
import urllib
|
||||
|
@ -101,3 +102,8 @@ def gen_pgtiou():
|
|||
def gen_saml_id():
|
||||
"""Generate an saml id"""
|
||||
return _gen_ticket('_')
|
||||
|
||||
def get_session(request):
|
||||
if not request.session.exists(request.session.session_key):
|
||||
request.session.create()
|
||||
return Session.objects.get(session_key=request.session.session_key)
|
||||
|
|
|
@ -69,7 +69,10 @@ class LogoutMixin(object):
|
|||
def logout(self):
|
||||
"""effectively destroy CAS session"""
|
||||
try:
|
||||
user = models.User.objects.get(username=self.request.session.get("username"))
|
||||
user = models.User.objects.get(
|
||||
username=self.request.session.get("username"),
|
||||
session=utils.get_session(self.request)
|
||||
)
|
||||
user.logout(self.request)
|
||||
user.delete()
|
||||
except models.User.DoesNotExist:
|
||||
|
@ -151,7 +154,10 @@ class LoginView(View, LogoutMixin):
|
|||
elif not request.session.get("authenticated") or self.renew:
|
||||
self.init_form(request.POST)
|
||||
if self.form.is_valid():
|
||||
self.user = models.User.objects.get(username=self.form.cleaned_data['username'])
|
||||
self.user = models.User.objects.get(
|
||||
username=self.form.cleaned_data['username'],
|
||||
session=utils.get_session(self.request)
|
||||
)
|
||||
request.session.set_expiry(0)
|
||||
request.session["username"] = self.form.cleaned_data['username']
|
||||
request.session["warn"] = True if self.form.cleaned_data.get("warn") else False
|
||||
|
@ -179,6 +185,7 @@ class LoginView(View, LogoutMixin):
|
|||
|
||||
def init_form(self, values=None):
|
||||
self.form = forms.UserCredential(
|
||||
self.request,
|
||||
values,
|
||||
initial={
|
||||
'service':self.service,
|
||||
|
@ -254,7 +261,10 @@ class LoginView(View, LogoutMixin):
|
|||
def authenticated(self):
|
||||
"""Processing authenticated users"""
|
||||
try:
|
||||
self.user = models.User.objects.get(username=self.request.session.get("username"))
|
||||
self.user = models.User.objects.get(
|
||||
username=self.request.session.get("username"),
|
||||
session=utils.get_session(self.request)
|
||||
)
|
||||
except models.User.DoesNotExist:
|
||||
self.logout()
|
||||
return utils.redirect_params("cas_server:login", params=self.request.GET)
|
||||
|
@ -329,6 +339,7 @@ class Auth(View):
|
|||
if not username or not password or not service:
|
||||
return HttpResponse("no\n", content_type="text/plain")
|
||||
form = forms.UserCredential(
|
||||
request,
|
||||
request.POST,
|
||||
initial={
|
||||
'service':service,
|
||||
|
@ -338,18 +349,20 @@ class Auth(View):
|
|||
)
|
||||
if form.is_valid():
|
||||
try:
|
||||
user = models.User.objects.get(username=form.cleaned_data['username'])
|
||||
user = models.User.objects.get(
|
||||
username=form.cleaned_data['username'],
|
||||
session=utils.get_session(request)
|
||||
)
|
||||
# is the service allowed
|
||||
service_pattern = ServicePattern.validate(service)
|
||||
# is the current user allowed on this service
|
||||
service_pattern.check_user(user)
|
||||
# if the user has asked to be warned before any login to a service
|
||||
if not request.session.get("authenticated"):
|
||||
user.delete()
|
||||
return HttpResponse("yes\n", content_type="text/plain")
|
||||
except (ServicePattern.DoesNotExist, ServicePatternException) as error:
|
||||
print "error: %r" % error
|
||||
return HttpResponse("no\n", content_type="text/plain")
|
||||
else:
|
||||
print "bad password"
|
||||
return HttpResponse("no\n", content_type="text/plain")
|
||||
|
||||
class Validate(View):
|
||||
|
|
Loading…
Reference in a new issue