From ee987f6d00e13e86b7f5db7782d06293d4b7eb2e Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sat, 14 Nov 2015 01:05:53 +0100 Subject: [PATCH] Remember up to 100 login ticket insted of 1 --- cas_server/models.py | 20 ++++++++------------ cas_server/views.py | 19 ++++++++++++------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/cas_server/models.py b/cas_server/models.py index 6543cdc..c735537 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -68,15 +68,13 @@ class User(models.Model): """Sending SLO request to all services the user logged in""" async_list = [] session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10)) - # first invalidate all PGTs - ticket_classes = [ProxyGrantingTicket, ProxyTicket, ServiceTicket] + # first invalidate all Tickets + ticket_classes = [ProxyGrantingTicket, ServiceTicket, ProxyTicket] for ticket_class in ticket_classes: - for ticket in ticket_class.objects.filter( - user=self, - validate=True if ticket_class != ProxyGrantingTicket else False, - ): + queryset = ticket_class.objects.filter(user=self) + for ticket in queryset: ticket.logout(request, session, async_list) - ticket.delete() + queryset.delete() for future in async_list: if future: try: @@ -361,7 +359,6 @@ class Ticket(models.Model): async_list = [] session = FuturesSession(executor=ThreadPoolExecutor(max_workers=10)) queryset = cls.objects.filter( - validate=True if cls != ProxyGrantingTicket else False, creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) ) for ticket in queryset: @@ -376,10 +373,9 @@ class Ticket(models.Model): def logout(self, request, session, async_list=None): """Send a SLO request to the ticket service""" - if isinstance(self, ProxyGrantingTicket): - # On logout invalidate the PGT - self.validate = True - self.save() + # On logout invalidate the Ticket + self.validate = True + self.save() if self.validate and self.single_log_out: try: xml = u""" 100: + self.request.session['lt'] = self.request.session['lt'][-100:] # check if send LT is valid - if lt_valid is None or lt_valid != lt_send: + if lt_valid is None or lt_send not in lt_valid: return False else: + self.request.session['lt'].remove(lt_send) + self.request.session['lt'] = self.request.session['lt'] return True def post(self, request, *args, **kwargs): @@ -194,7 +199,7 @@ class LoginView(View, LogoutMixin): if not self.check_lt(): values = self.request.POST.copy() # if not set a new LT and fail - values['lt'] = self.request.session['lt'] + values['lt'] = self.request.session['lt'][-1] self.init_form(values) return self.INVALID_LOGIN_TICKET elif not self.request.session.get("authenticated") or self.renew: @@ -227,7 +232,7 @@ class LoginView(View, LogoutMixin): def process_get(self): # generate a new LT if none is present - self.request.session['lt'] = self.request.session.get('lt', utils.gen_lt()) + self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] if not self.request.session.get("authenticated") or self.renew: self.init_form() @@ -241,7 +246,7 @@ class LoginView(View, LogoutMixin): 'service': self.service, 'method': self.method, 'warn': self.request.session.get("warn"), - 'lt': self.request.session['lt'] + 'lt': self.request.session['lt'][-1] } )