By session logout
This commit is contained in:
parent
41fcc06200
commit
245086f6ef
7 changed files with 111 additions and 22 deletions
|
@ -22,22 +22,24 @@ class UserCredential(forms.Form):
|
||||||
username = forms.CharField(label=_('login'))
|
username = forms.CharField(label=_('login'))
|
||||||
service = forms.CharField(widget=forms.HiddenInput(), required=False)
|
service = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||||
password = forms.CharField(label=_('password'), widget=forms.PasswordInput)
|
password = forms.CharField(label=_('password'), widget=forms.PasswordInput)
|
||||||
lt = forms.CharField(widget=forms.HiddenInput())
|
lt = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||||
method = forms.CharField(widget=forms.HiddenInput(), required=False)
|
method = forms.CharField(widget=forms.HiddenInput(), required=False)
|
||||||
warn = forms.BooleanField(label=_('warn'), required=False)
|
warn = forms.BooleanField(label=_('warn'), required=False)
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, request, *args, **kwargs):
|
||||||
|
self.request = request
|
||||||
super(UserCredential, self).__init__(*args, **kwargs)
|
super(UserCredential, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
def clean(self):
|
def clean(self):
|
||||||
cleaned_data = super(UserCredential, self).clean()
|
cleaned_data = super(UserCredential, self).clean()
|
||||||
auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username"))
|
auth = utils.import_attr(settings.CAS_AUTH_CLASS)(cleaned_data.get("username"))
|
||||||
if auth.test_password(cleaned_data.get("password")):
|
if auth.test_password(cleaned_data.get("password")):
|
||||||
|
session = utils.get_session(self.request)
|
||||||
try:
|
try:
|
||||||
user = models.User.objects.get(username=auth.username)
|
user = models.User.objects.get(username=auth.username, session=session)
|
||||||
user.save()
|
user.save()
|
||||||
except models.User.DoesNotExist:
|
except models.User.DoesNotExist:
|
||||||
user = models.User.objects.create(username=auth.username)
|
user = models.User.objects.create(username=auth.username, session=session)
|
||||||
user.save()
|
user.save()
|
||||||
else:
|
else:
|
||||||
raise forms.ValidationError(_(u"Bad user"))
|
raise forms.ValidationError(_(u"Bad user"))
|
||||||
|
|
|
@ -8,5 +8,6 @@ class Command(BaseCommand):
|
||||||
help = _(u"Clean old trickets")
|
help = _(u"Clean old trickets")
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
|
models.User.clean_old_entries()
|
||||||
for ticket_class in [models.ServiceTicket, models.ProxyTicket, models.ProxyGrantingTicket]:
|
for ticket_class in [models.ServiceTicket, models.ProxyTicket, models.ProxyGrantingTicket]:
|
||||||
ticket_class.clean()
|
ticket_class.clean_old_entries()
|
||||||
|
|
31
cas_server/migrations/0019_auto_20150609_1903.py
Normal file
31
cas_server/migrations/0019_auto_20150609_1903.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from django.db import models, migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('sessions', '0001_initial'),
|
||||||
|
('cas_server', '0018_auto_20150608_1621'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name='user',
|
||||||
|
name='session',
|
||||||
|
field=models.OneToOneField(related_name='cas_server_user', null=True, blank=True, to='sessions.Session'),
|
||||||
|
preserve_default=True,
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name='user',
|
||||||
|
name='username',
|
||||||
|
field=models.CharField(max_length=30),
|
||||||
|
preserve_default=True,
|
||||||
|
),
|
||||||
|
migrations.AlterUniqueTogether(
|
||||||
|
name='user',
|
||||||
|
unique_together=set([('username', 'session')]),
|
||||||
|
),
|
||||||
|
]
|
21
cas_server/migrations/0020_auto_20150609_1917.py
Normal file
21
cas_server/migrations/0020_auto_20150609_1917.py
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from django.db import models, migrations
|
||||||
|
import django.db.models.deletion
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
('cas_server', '0019_auto_20150609_1903'),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name='user',
|
||||||
|
name='session',
|
||||||
|
field=models.OneToOneField(related_name='cas_server_user', null=True, on_delete=django.db.models.deletion.SET_NULL, blank=True, to='sessions.Session'),
|
||||||
|
preserve_default=True,
|
||||||
|
),
|
||||||
|
]
|
|
@ -17,6 +17,7 @@ from django.db.models import Q
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
from django.utils import timezone
|
from django.utils import timezone
|
||||||
|
from django.contrib.sessions.models import Session
|
||||||
from picklefield.fields import PickledObjectField
|
from picklefield.fields import PickledObjectField
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
@ -30,18 +31,31 @@ import utils
|
||||||
|
|
||||||
class User(models.Model):
|
class User(models.Model):
|
||||||
"""A user logged into the CAS"""
|
"""A user logged into the CAS"""
|
||||||
username = models.CharField(max_length=30, unique=True)
|
class Meta:
|
||||||
|
unique_together = ("username", "session")
|
||||||
|
session = models.OneToOneField(Session, related_name="cas_server_user", blank=True, null=True, on_delete=models.SET_NULL)
|
||||||
|
username = models.CharField(max_length=30)
|
||||||
date = models.DateTimeField(auto_now_add=True, auto_now=True)
|
date = models.DateTimeField(auto_now_add=True, auto_now=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def clean_old_entries(cls):
|
||||||
|
users = cls.objects.filter(session=None)
|
||||||
|
for user in users:
|
||||||
|
user.logout()
|
||||||
|
users.delete()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def attributs(self):
|
def attributs(self):
|
||||||
"""return a fresh dict for the user attributs"""
|
"""return a fresh dict for the user attributs"""
|
||||||
return utils.import_attr(settings.CAS_AUTH_CLASS)(self.username).attributs()
|
return utils.import_attr(settings.CAS_AUTH_CLASS)(self.username).attributs()
|
||||||
|
|
||||||
def __unicode__(self):
|
def __unicode__(self):
|
||||||
|
if self.session:
|
||||||
|
return u"%s - %s" % (self.username, self.session.session_key)
|
||||||
|
else:
|
||||||
return self.username
|
return self.username
|
||||||
|
|
||||||
def logout(self, request):
|
def logout(self, request=None):
|
||||||
"""Sending SLO request to all services the user logged in"""
|
"""Sending SLO request to all services the user logged in"""
|
||||||
async_list = []
|
async_list = []
|
||||||
session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10))
|
session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10))
|
||||||
|
@ -59,6 +73,7 @@ class User(models.Model):
|
||||||
try:
|
try:
|
||||||
future.result()
|
future.result()
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
|
if request is not None:
|
||||||
error = utils.unpack_nested_exception(error)
|
error = utils.unpack_nested_exception(error)
|
||||||
messages.add_message(
|
messages.add_message(
|
||||||
request,
|
request,
|
||||||
|
@ -309,7 +324,7 @@ class Ticket(models.Model):
|
||||||
return u"Ticket(%s, %s)" % (self.user, self.service)
|
return u"Ticket(%s, %s)" % (self.user, self.service)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def clean(cls):
|
def clean_old_entries(cls):
|
||||||
"""Remove old ticket and send SLO to timed-out services"""
|
"""Remove old ticket and send SLO to timed-out services"""
|
||||||
# removing old validated ticket and non validated expired tickets
|
# removing old validated ticket and non validated expired tickets
|
||||||
cls.objects.filter(
|
cls.objects.filter(
|
||||||
|
|
|
@ -15,6 +15,7 @@ from .default_settings import settings
|
||||||
from django.utils.importlib import import_module
|
from django.utils.importlib import import_module
|
||||||
from django.core.urlresolvers import reverse
|
from django.core.urlresolvers import reverse
|
||||||
from django.http import HttpResponseRedirect
|
from django.http import HttpResponseRedirect
|
||||||
|
from django.contrib.sessions.models import Session
|
||||||
|
|
||||||
import urlparse
|
import urlparse
|
||||||
import urllib
|
import urllib
|
||||||
|
@ -101,3 +102,8 @@ def gen_pgtiou():
|
||||||
def gen_saml_id():
|
def gen_saml_id():
|
||||||
"""Generate an saml id"""
|
"""Generate an saml id"""
|
||||||
return _gen_ticket('_')
|
return _gen_ticket('_')
|
||||||
|
|
||||||
|
def get_session(request):
|
||||||
|
if not request.session.exists(request.session.session_key):
|
||||||
|
request.session.create()
|
||||||
|
return Session.objects.get(session_key=request.session.session_key)
|
||||||
|
|
|
@ -69,7 +69,10 @@ class LogoutMixin(object):
|
||||||
def logout(self):
|
def logout(self):
|
||||||
"""effectively destroy CAS session"""
|
"""effectively destroy CAS session"""
|
||||||
try:
|
try:
|
||||||
user = models.User.objects.get(username=self.request.session.get("username"))
|
user = models.User.objects.get(
|
||||||
|
username=self.request.session.get("username"),
|
||||||
|
session=utils.get_session(self.request)
|
||||||
|
)
|
||||||
user.logout(self.request)
|
user.logout(self.request)
|
||||||
user.delete()
|
user.delete()
|
||||||
except models.User.DoesNotExist:
|
except models.User.DoesNotExist:
|
||||||
|
@ -151,7 +154,10 @@ class LoginView(View, LogoutMixin):
|
||||||
elif not request.session.get("authenticated") or self.renew:
|
elif not request.session.get("authenticated") or self.renew:
|
||||||
self.init_form(request.POST)
|
self.init_form(request.POST)
|
||||||
if self.form.is_valid():
|
if self.form.is_valid():
|
||||||
self.user = models.User.objects.get(username=self.form.cleaned_data['username'])
|
self.user = models.User.objects.get(
|
||||||
|
username=self.form.cleaned_data['username'],
|
||||||
|
session=utils.get_session(self.request)
|
||||||
|
)
|
||||||
request.session.set_expiry(0)
|
request.session.set_expiry(0)
|
||||||
request.session["username"] = self.form.cleaned_data['username']
|
request.session["username"] = self.form.cleaned_data['username']
|
||||||
request.session["warn"] = True if self.form.cleaned_data.get("warn") else False
|
request.session["warn"] = True if self.form.cleaned_data.get("warn") else False
|
||||||
|
@ -179,6 +185,7 @@ class LoginView(View, LogoutMixin):
|
||||||
|
|
||||||
def init_form(self, values=None):
|
def init_form(self, values=None):
|
||||||
self.form = forms.UserCredential(
|
self.form = forms.UserCredential(
|
||||||
|
self.request,
|
||||||
values,
|
values,
|
||||||
initial={
|
initial={
|
||||||
'service':self.service,
|
'service':self.service,
|
||||||
|
@ -254,7 +261,10 @@ class LoginView(View, LogoutMixin):
|
||||||
def authenticated(self):
|
def authenticated(self):
|
||||||
"""Processing authenticated users"""
|
"""Processing authenticated users"""
|
||||||
try:
|
try:
|
||||||
self.user = models.User.objects.get(username=self.request.session.get("username"))
|
self.user = models.User.objects.get(
|
||||||
|
username=self.request.session.get("username"),
|
||||||
|
session=utils.get_session(self.request)
|
||||||
|
)
|
||||||
except models.User.DoesNotExist:
|
except models.User.DoesNotExist:
|
||||||
self.logout()
|
self.logout()
|
||||||
return utils.redirect_params("cas_server:login", params=self.request.GET)
|
return utils.redirect_params("cas_server:login", params=self.request.GET)
|
||||||
|
@ -329,6 +339,7 @@ class Auth(View):
|
||||||
if not username or not password or not service:
|
if not username or not password or not service:
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse("no\n", content_type="text/plain")
|
||||||
form = forms.UserCredential(
|
form = forms.UserCredential(
|
||||||
|
request,
|
||||||
request.POST,
|
request.POST,
|
||||||
initial={
|
initial={
|
||||||
'service':service,
|
'service':service,
|
||||||
|
@ -338,18 +349,20 @@ class Auth(View):
|
||||||
)
|
)
|
||||||
if form.is_valid():
|
if form.is_valid():
|
||||||
try:
|
try:
|
||||||
user = models.User.objects.get(username=form.cleaned_data['username'])
|
user = models.User.objects.get(
|
||||||
|
username=form.cleaned_data['username'],
|
||||||
|
session=utils.get_session(request)
|
||||||
|
)
|
||||||
# is the service allowed
|
# is the service allowed
|
||||||
service_pattern = ServicePattern.validate(service)
|
service_pattern = ServicePattern.validate(service)
|
||||||
# is the current user allowed on this service
|
# is the current user allowed on this service
|
||||||
service_pattern.check_user(user)
|
service_pattern.check_user(user)
|
||||||
# if the user has asked to be warned before any login to a service
|
if not request.session.get("authenticated"):
|
||||||
|
user.delete()
|
||||||
return HttpResponse("yes\n", content_type="text/plain")
|
return HttpResponse("yes\n", content_type="text/plain")
|
||||||
except (ServicePattern.DoesNotExist, ServicePatternException) as error:
|
except (ServicePattern.DoesNotExist, ServicePatternException) as error:
|
||||||
print "error: %r" % error
|
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse("no\n", content_type="text/plain")
|
||||||
else:
|
else:
|
||||||
print "bad password"
|
|
||||||
return HttpResponse("no\n", content_type="text/plain")
|
return HttpResponse("no\n", content_type="text/plain")
|
||||||
|
|
||||||
class Validate(View):
|
class Validate(View):
|
||||||
|
|
Loading…
Reference in a new issue