Fix renew request from service

This commit is contained in:
Valentin Samir 2016-06-28 15:24:50 +02:00
parent fc57288c30
commit 16fb7b5021
3 changed files with 28 additions and 2 deletions

View file

@ -35,6 +35,7 @@ class UserCredential(forms.Form):
lt = forms.CharField(widget=forms.HiddenInput(), required=False) lt = forms.CharField(widget=forms.HiddenInput(), required=False)
method = forms.CharField(widget=forms.HiddenInput(), required=False) method = forms.CharField(widget=forms.HiddenInput(), required=False)
warn = forms.BooleanField(label=_('warn'), required=False) warn = forms.BooleanField(label=_('warn'), required=False)
renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(UserCredential, self).__init__(*args, **kwargs) super(UserCredential, self).__init__(*args, **kwargs)
@ -46,6 +47,7 @@ class UserCredential(forms.Form):
cleaned_data["username"] = auth.username cleaned_data["username"] = auth.username
else: else:
raise forms.ValidationError(_(u"Bad user")) raise forms.ValidationError(_(u"Bad user"))
return cleaned_data
class TicketForm(forms.ModelForm): class TicketForm(forms.ModelForm):

View file

@ -447,6 +447,27 @@ class LoginTestCase(TestCase):
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service) self.assertEqual(response["Location"], service)
def test_renew(self):
service = "https://www.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service, 'renew': 'on'})
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication renewal required by "
b"service example (https://www.example.com)"
) in response.content
)
params = copy_form(response.context["form"])
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
self.assertEqual(params["renew"], True)
response = client.post("/login", params)
self.assertEqual(response.status_code, 302)
ticket_value = response['Location'].split('ticket=')[-1]
ticket = models.ServiceTicket.objects.get(value=ticket_value)
self.assertEqual(ticket.renew, True)
class LogoutTestCase(TestCase): class LogoutTestCase(TestCase):

View file

@ -206,6 +206,7 @@ class LoginView(View, LogoutMixin):
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = 'HTTP_X_AJAX' in request.META
if request.POST.get('warned') and request.POST['warned'] != "False": if request.POST.get('warned') and request.POST['warned'] != "False":
self.warned = True self.warned = True
self.warn = request.POST.get('warn')
def gen_lt(self): def gen_lt(self):
"""Generate a new LoginTicket and add it to the list of valid LT for the user""" """Generate a new LoginTicket and add it to the list of valid LT for the user"""
@ -298,6 +299,7 @@ class LoginView(View, LogoutMixin):
self.gateway = request.GET.get('gateway') self.gateway = request.GET.get('gateway')
self.method = request.GET.get('method') self.method = request.GET.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = 'HTTP_X_AJAX' in request.META
self.warn = request.GET.get('warn')
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""methode called on GET request on this view""" """methode called on GET request on this view"""
@ -322,7 +324,8 @@ class LoginView(View, LogoutMixin):
'service': self.service, 'service': self.service,
'method': self.method, 'method': self.method,
'warn': self.request.session.get("warn"), 'warn': self.request.session.get("warn"),
'lt': self.request.session['lt'][-1] 'lt': self.request.session['lt'][-1],
'renew': self.renew
} }
) )
@ -364,7 +367,7 @@ class LoginView(View, LogoutMixin):
redirect_url = self.user.get_service_url( redirect_url = self.user.get_service_url(
self.service, self.service,
service_pattern, service_pattern,
renew=self.renew renew=self.renewed
) )
if not self.ajax: if not self.ajax:
return HttpResponseRedirect(redirect_url) return HttpResponseRedirect(redirect_url)