From 16fb7b502139bcda31a846d73757d9268869f084 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Tue, 28 Jun 2016 15:24:50 +0200 Subject: [PATCH] Fix renew request from service --- cas_server/forms.py | 2 ++ cas_server/tests.py | 21 +++++++++++++++++++++ cas_server/views.py | 7 +++++-- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/cas_server/forms.py b/cas_server/forms.py index f970ccd..28893cf 100644 --- a/cas_server/forms.py +++ b/cas_server/forms.py @@ -35,6 +35,7 @@ class UserCredential(forms.Form): lt = forms.CharField(widget=forms.HiddenInput(), required=False) method = forms.CharField(widget=forms.HiddenInput(), required=False) warn = forms.BooleanField(label=_('warn'), required=False) + renew = forms.BooleanField(widget=forms.HiddenInput(), required=False) def __init__(self, *args, **kwargs): super(UserCredential, self).__init__(*args, **kwargs) @@ -46,6 +47,7 @@ class UserCredential(forms.Form): cleaned_data["username"] = auth.username else: raise forms.ValidationError(_(u"Bad user")) + return cleaned_data class TicketForm(forms.ModelForm): diff --git a/cas_server/tests.py b/cas_server/tests.py index 916a6d4..d774fb2 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -447,6 +447,27 @@ class LoginTestCase(TestCase): self.assertEqual(response.status_code, 302) self.assertEqual(response["Location"], service) + def test_renew(self): + service = "https://www.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service, 'renew': 'on'}) + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"Authentication renewal required by " + b"service example (https://www.example.com)" + ) in response.content + ) + params = copy_form(response.context["form"]) + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + self.assertEqual(params["renew"], True) + response = client.post("/login", params) + self.assertEqual(response.status_code, 302) + ticket_value = response['Location'].split('ticket=')[-1] + ticket = models.ServiceTicket.objects.get(value=ticket_value) + self.assertEqual(ticket.renew, True) + class LogoutTestCase(TestCase): diff --git a/cas_server/views.py b/cas_server/views.py index a48dd7e..c9dd05b 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -206,6 +206,7 @@ class LoginView(View, LogoutMixin): self.ajax = 'HTTP_X_AJAX' in request.META if request.POST.get('warned') and request.POST['warned'] != "False": self.warned = True + self.warn = request.POST.get('warn') def gen_lt(self): """Generate a new LoginTicket and add it to the list of valid LT for the user""" @@ -298,6 +299,7 @@ class LoginView(View, LogoutMixin): self.gateway = request.GET.get('gateway') self.method = request.GET.get('method') self.ajax = 'HTTP_X_AJAX' in request.META + self.warn = request.GET.get('warn') def get(self, request, *args, **kwargs): """methode called on GET request on this view""" @@ -322,7 +324,8 @@ class LoginView(View, LogoutMixin): 'service': self.service, 'method': self.method, 'warn': self.request.session.get("warn"), - 'lt': self.request.session['lt'][-1] + 'lt': self.request.session['lt'][-1], + 'renew': self.renew } ) @@ -364,7 +367,7 @@ class LoginView(View, LogoutMixin): redirect_url = self.user.get_service_url( self.service, service_pattern, - renew=self.renew + renew=self.renewed ) if not self.ajax: return HttpResponseRedirect(redirect_url)