diff --git a/cas_server/tests/test_federate.py b/cas_server/tests/test_federate.py index 1418b15..42bef71 100644 --- a/cas_server/tests/test_federate.py +++ b/cas_server/tests/test_federate.py @@ -88,19 +88,23 @@ class FederateAuthLoginLogoutTestCase( response = client.post('/federate', params) # we are redirected to the provider CAS client url self.assertEqual(response.status_code, 302) - self.assertEqual(response["Location"], '%s/federate/%s' % ( + self.assertEqual(response["Location"], '%s/federate/%s%s' % ( 'http://testserver' if django.VERSION < (1, 9) else "", - provider.suffix + provider.suffix, + "?remember=on" if remember else "" )) # let's follow the redirect - response = client.get('/federate/%s' % provider.suffix) + response = client.get( + '/federate/%s%s' % (provider.suffix, "?remember=on" if remember else "") + ) # we are redirected to the provider CAS for authentication self.assertEqual(response.status_code, 302) self.assertEqual( response["Location"], - "%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % ( + "%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s%s" % ( provider.server_url, - provider.suffix + provider.suffix, + "%3Fremember%3Don" if remember else "" ) ) # let's generate a ticket @@ -108,7 +112,10 @@ class FederateAuthLoginLogoutTestCase( # we lauch a dummy CAS server that only validate once for the service # http://testserver/federate/example.com with `ticket` tests_utils.DummyCAS.run( - ("http://testserver/federate/%s" % provider.suffix).encode("ascii"), + ("http://testserver/federate/%s%s" % ( + provider.suffix, + "?remember=on" if remember else "" + )).encode("ascii"), ticket.encode("ascii"), settings.CAS_TEST_USER.encode("utf8"), [], @@ -116,7 +123,13 @@ class FederateAuthLoginLogoutTestCase( ) # we normally provide a good ticket and should be redirected to /login as the ticket # get successfully validated again the dummy CAS - response = client.get('/federate/%s' % provider.suffix, {'ticket': ticket}) + response = client.get( + '/federate/%s' % provider.suffix, + {'ticket': ticket, 'remember': 'on' if remember else ''} + ) + if remember: + self.assertIn("_remember_provider", client.cookies) + self.assertEqual(client.cookies["_remember_provider"].value, provider.suffix) self.assertEqual(response.status_code, 302) self.assertEqual(response["Location"], "%s/login" % ( 'http://testserver' if django.VERSION < (1, 9) else "" diff --git a/cas_server/views.py b/cas_server/views.py index b6a8e5f..2a74c4f 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -218,8 +218,7 @@ class FederateAuth(View): """ return super(FederateAuth, self).dispatch(request, *args, **kwargs) - @staticmethod - def get_cas_client(request, provider): + def get_cas_client(self, request, provider): """ return a CAS client object matching provider @@ -231,6 +230,7 @@ class FederateAuth(View): """ # compute the current url, ignoring ticket dans provider GET parameters service_url = utils.get_current_url(request, {"ticket", "provider"}) + self.service_url = service_url return CASFederateValidateUser(provider, service_url) def post(self, request, provider=None): @@ -264,7 +264,7 @@ class FederateAuth(View): if form.is_valid(): params = utils.copy_params( request.POST, - ignore={"provider", "csrfmiddlewaretoken", "ticket", "lt", "remember"} + ignore={"provider", "csrfmiddlewaretoken", "ticket", "lt"} ) if params.get("renew") == "False": del params["renew"] @@ -273,17 +273,7 @@ class FederateAuth(View): kwargs=dict(provider=form.cleaned_data["provider"].suffix), params=params ) - response = HttpResponseRedirect(url) - # If the user has checked "remember my identity provider" store it in a cookie - if form.cleaned_data["remember"]: - max_age = settings.CAS_FEDERATE_REMEMBER_TIMEOUT - utils.set_cookie( - response, - "_remember_provider", - form.cleaned_data["provider"].suffix, - max_age - ) - return response + return HttpResponseRedirect(url) else: return redirect("cas_server:login") @@ -323,7 +313,7 @@ class FederateAuth(View): auth.provider.server_url ) ) - params = utils.copy_params(request.GET, ignore={"ticket"}) + params = utils.copy_params(request.GET, ignore={"ticket", "remember"}) request.session["federate_username"] = auth.federated_username request.session["federate_ticket"] = ticket auth.register_slo( @@ -334,13 +324,28 @@ class FederateAuth(View): # redirect to the the login page for the user to become authenticated # thanks to the `federate_username` and `federate_ticket` session parameters url = utils.reverse_params("cas_server:login", params) - return HttpResponseRedirect(url) + response = HttpResponseRedirect(url) + # If the user has checked "remember my identity provider" store it in a + # cookie + if request.GET.get("remember"): + max_age = settings.CAS_FEDERATE_REMEMBER_TIMEOUT + utils.set_cookie( + response, + "_remember_provider", + provider.suffix, + max_age + ) + return response # else redirect to the identity provider CAS login page else: logger.info( - "Got a invalid ticket for %s from %s. Retrying to authenticate" % ( - auth.username, - auth.provider.server_url + ( + "Got a invalid ticket %s from %s for service %s. " + "Retrying to authenticate" + ) % ( + ticket, + auth.provider.server_url, + self.service_url ) ) return HttpResponseRedirect(auth.get_login_url())