diff --git a/cas_server/forms.py b/cas_server/forms.py index 1036cf5..e5d2176 100644 --- a/cas_server/forms.py +++ b/cas_server/forms.py @@ -41,6 +41,7 @@ class FederateSelect(forms.Form): method = forms.CharField(widget=forms.HiddenInput(), required=False) remember = forms.BooleanField(label=_('Remember the identity provider'), required=False) warn = forms.BooleanField(label=_('warn'), required=False) + renew = forms.BooleanField(widget=forms.HiddenInput(), required=False) class UserCredential(forms.Form): @@ -51,6 +52,7 @@ class UserCredential(forms.Form): lt = forms.CharField(widget=forms.HiddenInput(), required=False) method = forms.CharField(widget=forms.HiddenInput(), required=False) warn = forms.BooleanField(label=_('warn'), required=False) + renew = forms.BooleanField(widget=forms.HiddenInput(), required=False) def __init__(self, *args, **kwargs): super(UserCredential, self).__init__(*args, **kwargs) @@ -74,6 +76,7 @@ class FederateUserCredential(UserCredential): lt = forms.CharField(widget=forms.HiddenInput(), required=False) method = forms.CharField(widget=forms.HiddenInput(), required=False) warn = forms.BooleanField(widget=forms.HiddenInput(), required=False) + renew = forms.BooleanField(widget=forms.HiddenInput(), required=False) def clean(self): cleaned_data = super(FederateUserCredential, self).clean() diff --git a/cas_server/tests.py b/cas_server/tests.py index 916a6d4..d774fb2 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -447,6 +447,27 @@ class LoginTestCase(TestCase): self.assertEqual(response.status_code, 302) self.assertEqual(response["Location"], service) + def test_renew(self): + service = "https://www.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service, 'renew': 'on'}) + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"Authentication renewal required by " + b"service example (https://www.example.com)" + ) in response.content + ) + params = copy_form(response.context["form"]) + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + self.assertEqual(params["renew"], True) + response = client.post("/login", params) + self.assertEqual(response.status_code, 302) + ticket_value = response['Location'].split('ticket=')[-1] + ticket = models.ServiceTicket.objects.get(value=ticket_value) + self.assertEqual(ticket.renew, True) + class LogoutTestCase(TestCase): diff --git a/cas_server/views.py b/cas_server/views.py index f1434b0..e7d5419 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -432,7 +432,8 @@ class LoginView(View, LogoutMixin): 'service': self.service, 'method': self.method, 'warn': self.warn or self.request.session.get("warn"), - 'lt': self.request.session['lt'][-1] + 'lt': self.request.session['lt'][-1], + 'renew': self.renew } if settings.CAS_FEDERATE: if self.username and self.ticket: @@ -489,7 +490,7 @@ class LoginView(View, LogoutMixin): redirect_url = self.user.get_service_url( self.service, service_pattern, - renew=self.renew + renew=self.renewed ) if not self.ajax: return HttpResponseRedirect(redirect_url)