diff --git a/cas_server/models.py b/cas_server/models.py index 675260c..2314c4f 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -595,6 +595,12 @@ class Ticket(models.Model): ) ) + @staticmethod + def get_class(ticket): + for ticket_class in [ServiceTicket, ProxyTicket, ProxyGrantingTicket]: + if ticket.startswith(ticket_class.PREFIX): + return ticket_class + @python_2_unicode_compatible class ServiceTicket(Ticket): diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py index 49fa2d2..bfa4c32 100644 --- a/cas_server/tests/test_view.py +++ b/cas_server/tests/test_view.py @@ -990,6 +990,52 @@ class ValidateTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.content, b'yes\ntest\n') + def test_validate_service_renew(self): + """test with a valid (ticket, service) asking for auth renewal""" + # case 1 client is renewing and service ask for renew + (client1, response) = get_auth_client(renew="True", service=self.service) + self.assertEqual(response.status_code, 302) + ticket_value = response['Location'].split('ticket=')[-1] + # get a bare client + client = Client() + # requesting validation with a good (ticket, service) + response = client.get( + '/validate', + {'ticket': ticket_value, 'service': self.service, 'renew': 'True'} + ) + # the validation should succes with username settings.CAS_TEST_USER and transmit + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\ntest\n') + + # cas2 client is renewing and service do not ask for renew + (client2, response) = get_auth_client(renew="True", service=self.service) + self.assertEqual(response.status_code, 302) + ticket_value = response['Location'].split('ticket=')[-1] + # get a bare client + client = Client() + # requesting validation with a good (ticket, service) + response = client.get( + '/validate', + {'ticket': ticket_value, 'service': self.service} + ) + # the validation should succes with username settings.CAS_TEST_USER and transmit + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\ntest\n') + + # case 3, client is not renewing and service ask for renew (client is authenticated) + response = client2.get("/login", {"service": self.service}) + self.assertEqual(response.status_code, 302) + ticket_value = response['Location'].split('ticket=')[-1] + client = Client() + # requesting validation with a good (ticket, service) + response = client.get( + '/validate', + {'ticket': ticket_value, 'service': self.service, 'renew': 'True'} + ) + # the validation should fail + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + def test_validate_view_badservice(self): """test for a valid ticket but bad service""" ticket = get_user_ticket_request(self.service)[1] @@ -1144,6 +1190,55 @@ class ValidateServiceTestCase(TestCase, XmlContent): # the attributes settings.CAS_TEST_ATTRIBUTES self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES) + def test_validate_service_renew(self): + """test with a valid (ticket, service) asking for auth renewal""" + # case 1 client is renewing and service ask for renew + (client1, response) = get_auth_client(renew="True", service=self.service) + self.assertEqual(response.status_code, 302) + ticket_value = response['Location'].split('ticket=')[-1] + # get a bare client + client = Client() + # requesting validation with a good (ticket, service) + response = client.get( + '/serviceValidate', + {'ticket': ticket_value, 'service': self.service, 'renew': 'True'} + ) + # the validation should succes with username settings.CAS_TEST_USER and transmit + # the attributes settings.CAS_TEST_ATTRIBUTES + self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES) + + # cas2 client is renewing and service do not ask for renew + (client2, response) = get_auth_client(renew="True", service=self.service) + self.assertEqual(response.status_code, 302) + ticket_value = response['Location'].split('ticket=')[-1] + # get a bare client + client = Client() + # requesting validation with a good (ticket, service) + response = client.get( + '/serviceValidate', + {'ticket': ticket_value, 'service': self.service} + ) + # the validation should succes with username settings.CAS_TEST_USER and transmit + # the attributes settings.CAS_TEST_ATTRIBUTES + self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES) + + # case 3, client is not renewing and service ask for renew (client is authenticated) + response = client2.get("/login", {"service": self.service}) + self.assertEqual(response.status_code, 302) + ticket_value = response['Location'].split('ticket=')[-1] + client = Client() + # requesting validation with a good (ticket, service) + response = client.get( + '/serviceValidate', + {'ticket': ticket_value, 'service': self.service, 'renew': 'True'} + ) + # the validation should fail + self.assert_error( + response, + "INVALID_TICKET", + 'ticket not found' + ) + def test_validate_service_view_ok_one_attribute(self): """ test with a valid (ticket, service), the username and diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py index 67e7c7b..515b653 100644 --- a/cas_server/tests/utils.py +++ b/cas_server/tests/utils.py @@ -74,10 +74,13 @@ def get_auth_client(**update): params["password"] = settings.CAS_TEST_PASSWORD params.update(update) - client.post('/login', params) + response = client.post('/login', params) assert client.session.get("authenticated") - return client + if params.get("service"): + return (client, response) + else: + return client def get_user_ticket_request(service): diff --git a/cas_server/views.py b/cas_server/views.py index c85cc52..9d3fcc2 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -751,13 +751,16 @@ class Validate(View): renew = True if request.GET.get('renew') else False if service and ticket: try: - ticket = ServiceTicket.objects.get( + ticket_queryset = ServiceTicket.objects.filter( value=ticket, service=service, validate=False, - renew=renew, creation__gt=(timezone.now() - timedelta(seconds=ServiceTicket.VALIDITY)) ) + if renew: + ticket = ticket_queryset.get(renew=True) + else: + ticket = ticket_queryset.get() ticket.validate = True ticket.save() logger.info( @@ -893,22 +896,20 @@ class ValidateService(View, AttributesMixin): """fetch the ticket angains the database and check its validity""" try: proxies = [] - if self.ticket.startswith(ServiceTicket.PREFIX): - ticket = ServiceTicket.objects.get( + ticket_class = models.Ticket.get_class(self.ticket) + if ticket_class: + ticket_queryset = ticket_class.objects.filter( value=self.ticket, validate=False, - renew=self.renew, creation__gt=(timezone.now() - timedelta(seconds=ServiceTicket.VALIDITY)) ) - elif self.allow_proxy_ticket and self.ticket.startswith(ProxyTicket.PREFIX): - ticket = ProxyTicket.objects.get( - value=self.ticket, - validate=False, - renew=self.renew, - creation__gt=(timezone.now() - timedelta(seconds=ProxyTicket.VALIDITY)) - ) - for prox in ticket.proxies.all(): - proxies.append(prox.url) + if self.renew: + ticket = ticket_queryset.get(renew=True) + else: + ticket = ticket_queryset.get() + if ticket_class == models.ProxyTicket: + for prox in ticket.proxies.all(): + proxies.append(prox.url) else: raise ValidateError(u'INVALID_TICKET', self.ticket) ticket.validate = True @@ -1140,18 +1141,13 @@ class SamlValidate(View, AttributesMixin): try: auth_req = self.root.getchildren()[1].getchildren()[0] ticket = auth_req.getchildren()[0].text - if ticket.startswith(ServiceTicket.PREFIX): - ticket = ServiceTicket.objects.get( + ticket_class = models.Ticket.get_class(ticket) + if ticket_class: + ticket = ticket_class.objects.get( value=ticket, validate=False, creation__gt=(timezone.now() - timedelta(seconds=ServiceTicket.VALIDITY)) ) - elif ticket.startswith(ProxyTicket.PREFIX): - ticket = ProxyTicket.objects.get( - value=ticket, - validate=False, - creation__gt=(timezone.now() - timedelta(seconds=ProxyTicket.VALIDITY)) - ) else: raise SamlValidateError( u'AuthnFailed',