Remember up to 100 login ticket insted of 1

This commit is contained in:
Valentin Samir 2015-11-14 01:05:53 +01:00
parent df9dd5364f
commit ee987f6d00
2 changed files with 20 additions and 19 deletions

View file

@ -68,15 +68,13 @@ class User(models.Model):
"""Sending SLO request to all services the user logged in""" """Sending SLO request to all services the user logged in"""
async_list = [] async_list = []
session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10)) session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10))
# first invalidate all PGTs # first invalidate all Tickets
ticket_classes = [ProxyGrantingTicket, ProxyTicket, ServiceTicket] ticket_classes = [ProxyGrantingTicket, ServiceTicket, ProxyTicket]
for ticket_class in ticket_classes: for ticket_class in ticket_classes:
for ticket in ticket_class.objects.filter( queryset = ticket_class.objects.filter(user=self)
user=self, for ticket in queryset:
validate=True if ticket_class != ProxyGrantingTicket else False,
):
ticket.logout(request, session, async_list) ticket.logout(request, session, async_list)
ticket.delete() queryset.delete()
for future in async_list: for future in async_list:
if future: if future:
try: try:
@ -361,7 +359,6 @@ class Ticket(models.Model):
async_list = [] async_list = []
session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10)) session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10))
queryset = cls.objects.filter( queryset = cls.objects.filter(
validate=True if cls != ProxyGrantingTicket else False,
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
) )
for ticket in queryset: for ticket in queryset:
@ -376,10 +373,9 @@ class Ticket(models.Model):
def logout(self, request, session, async_list=None): def logout(self, request, session, async_list=None):
"""Send a SLO request to the ticket service""" """Send a SLO request to the ticket service"""
if isinstance(self, ProxyGrantingTicket): # On logout invalidate the Ticket
# On logout invalidate the PGT self.validate = True
self.validate = True self.save()
self.save()
if self.validate and self.single_log_out: if self.validate and self.single_log_out:
try: try:
xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"

View file

@ -76,10 +76,11 @@ class LogoutMixin(object):
session_key=self.request.session.session_key session_key=self.request.session.session_key
) )
self.clean_session_variables() self.clean_session_variables()
self.request.session.flush()
user.logout(self.request) user.logout(self.request)
user.delete() user.delete()
except models.User.DoesNotExist: except models.User.DoesNotExist:
self.clean_session_variables() pass
class LogoutView(View, LogoutMixin): class LogoutView(View, LogoutMixin):
@ -148,15 +149,19 @@ class LoginView(View, LogoutMixin):
def check_lt(self): def check_lt(self):
# save LT for later check # save LT for later check
lt_valid = self.request.session.get('lt') lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt') lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed) # generate a new LT (by posting the LT has been consumed)
self.request.session['lt'] = utils.gen_lt() self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:]
# check if send LT is valid # check if send LT is valid
if lt_valid is None or lt_valid != lt_send: if lt_valid is None or lt_send not in lt_valid:
return False return False
else: else:
self.request.session['lt'].remove(lt_send)
self.request.session['lt'] = self.request.session['lt']
return True return True
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
@ -194,7 +199,7 @@ class LoginView(View, LogoutMixin):
if not self.check_lt(): if not self.check_lt():
values = self.request.POST.copy() values = self.request.POST.copy()
# if not set a new LT and fail # if not set a new LT and fail
values['lt'] = self.request.session['lt'] values['lt'] = self.request.session['lt'][-1]
self.init_form(values) self.init_form(values)
return self.INVALID_LOGIN_TICKET return self.INVALID_LOGIN_TICKET
elif not self.request.session.get("authenticated") or self.renew: elif not self.request.session.get("authenticated") or self.renew:
@ -227,7 +232,7 @@ class LoginView(View, LogoutMixin):
def process_get(self): def process_get(self):
# generate a new LT if none is present # generate a new LT if none is present
self.request.session['lt'] = self.request.session.get('lt', utils.gen_lt()) self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
if not self.request.session.get("authenticated") or self.renew: if not self.request.session.get("authenticated") or self.renew:
self.init_form() self.init_form()
@ -241,7 +246,7 @@ class LoginView(View, LogoutMixin):
'service': self.service, 'service': self.service,
'method': self.method, 'method': self.method,
'warn': self.request.session.get("warn"), 'warn': self.request.session.get("warn"),
'lt': self.request.session['lt'] 'lt': self.request.session['lt'][-1]
} }
) )