diff --git a/.coveragerc b/.coveragerc index abcfd58..aa192af 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,7 +1,9 @@ [run] branch = True source = cas_server -omit = cas_server/migrations* +omit = + cas_server/migrations* + cas_server/management/* [report] exclude_lines = diff --git a/cas_server/models.py b/cas_server/models.py index 83433a5..0694866 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -84,7 +84,7 @@ class User(models.Model): ticket.logout(session, async_list) queryset.delete() for future in async_list: - if future: + if future: # pragma: no branch (should always be true) try: future.result() except Exception as error: @@ -112,13 +112,21 @@ class User(models.Model): (a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all() ) replacements = dict( - (a.name, (a.pattern, a.replace)) for a in service_pattern.replacements.all() + (a.attribut, (a.pattern, a.replace)) for a in service_pattern.replacements.all() ) service_attributs = {} for (key, value) in self.attributs.items(): if key in attributs or '*' in attributs: if key in replacements: - value = re.sub(replacements[key][0], replacements[key][1], value) + if isinstance(value, list): + for index, subval in enumerate(value): + value[index] = re.sub( + replacements[key][0], + replacements[key][1], + subval + ) + else: + value = re.sub(replacements[key][0], replacements[key][1], value) service_attributs[attributs.get(key, key)] = value ticket = ticket_class.objects.create( user=self, @@ -396,31 +404,30 @@ class Ticket(models.Model): ).delete() # sending SLO to timed-out validated tickets - if cls.TIMEOUT and cls.TIMEOUT > 0: - async_list = [] - session = FuturesSession( - executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS) - ) - queryset = cls.objects.filter( - creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) - ) - for ticket in queryset: - ticket.logout(None, session, async_list) - queryset.delete() - for future in async_list: - if future: - try: - future.result() - except Exception as error: - logger.warning("Error durring SLO %s" % error) - sys.stderr.write("%r\n" % error) + async_list = [] + session = FuturesSession( + executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS) + ) + queryset = cls.objects.filter( + creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) + ) + for ticket in queryset: + ticket.logout(session, async_list) + queryset.delete() + for future in async_list: + if future: # pragma: no branch (should always be true) + try: + future.result() + except Exception as error: + logger.warning("Error durring SLO %s" % error) + sys.stderr.write("%r\n" % error) def logout(self, session, async_list=None): """Send a SLO request to the ticket service""" # On logout invalidate the Ticket self.validate = True self.save() - if self.validate and self.single_log_out: + if self.validate and self.single_log_out: # pragma: no branch (should always be true) logger.info( "Sending SLO requests to service %s for user %s" % ( self.service, diff --git a/cas_server/tests/mixin.py b/cas_server/tests/mixin.py index 84b7b39..246a1c6 100644 --- a/cas_server/tests/mixin.py +++ b/cas_server/tests/mixin.py @@ -1,10 +1,13 @@ """Some mixin classes for tests""" from cas_server.default_settings import settings +from django.utils import timezone import re from lxml import etree +from datetime import timedelta from cas_server import models +from cas_server.tests.utils import get_auth_client class BaseServicePattern(object): @@ -52,6 +55,17 @@ class BaseServicePattern(object): pattern="^admin$", service_pattern=self.service_pattern_filter_fail ) + self.service_filter_fail_alt = "https://filter_fail_alt.example.com" + self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create( + name="filter_fail_alt", + pattern="^https://filter_fail_alt\.example\.com(/.*)?$", + proxy=proxy, + ) + models.FilterAttributValue.objects.create( + attribut="nom", + pattern="^toto$", + service_pattern=self.service_pattern_filter_fail_alt + ) self.service_filter_success = "https://filter_success.example.com" self.service_pattern_filter_success = models.ServicePattern.objects.create( name="filter_success", @@ -143,3 +157,24 @@ class XmlContent(object): self.assertEqual(attrs1, original) return root + + +class UserModels(object): + """Mixin for test on CAS user models""" + def expire_user(self): + """return an expired user""" + client = get_auth_client() + + new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600)) + models.User.objects.filter( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ).update(date=new_date) + return client + + def get_user(self, client): + """return the user associated with an authenticated client""" + return models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) diff --git a/cas_server/tests/test_models.py b/cas_server/tests/test_models.py new file mode 100644 index 0000000..d2f999e --- /dev/null +++ b/cas_server/tests/test_models.py @@ -0,0 +1,146 @@ +"""Tests module for models""" +from cas_server.default_settings import settings + +from django.test import TestCase +from django.test.utils import override_settings +from django.utils import timezone + +from datetime import timedelta +from importlib import import_module + +from cas_server import models +from cas_server import utils +from cas_server.tests.utils import get_auth_client +from cas_server.tests.mixin import UserModels, BaseServicePattern + +SessionStore = import_module(settings.SESSION_ENGINE).SessionStore + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class UserTestCase(TestCase, UserModels): + """tests for the user models""" + def setUp(self): + """Prepare the test context""" + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + single_log_out=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_clean_old_entries(self): + """test clean_old_entries""" + # get an authenticated client + client = self.expire_user() + # assert the user exists before being cleaned + self.assertEqual(len(models.User.objects.all()), 1) + # assert the last activity date is before the expiry date + self.assertTrue( + self.get_user(client).date < ( + timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE) + ) + ) + # delete old inactive users + models.User.clean_old_entries() + # assert the user has being well delete + self.assertEqual(len(models.User.objects.all()), 0) + + def test_clean_deleted_sessions(self): + """test clean_deleted_sessions""" + # get an authenticated client + client1 = get_auth_client() + client2 = get_auth_client() + # generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen + # on self.service) + ticket = self.get_user(client1).get_ticket( + models.ServiceTicket, + self.service, + self.service_pattern, + renew=False + ) + ticket.validate = True + ticket.save() + # simulated expired session being garbage collected for client1 + session = SessionStore(session_key=client1.session.session_key) + session.flush() + # assert the user exists before being cleaned + self.assertTrue(self.get_user(client1)) + self.assertTrue(self.get_user(client2)) + self.assertEqual(len(models.User.objects.all()), 2) + # session has being remove so the user of client1 is no longer authenticated + self.assertFalse(client1.session.get("authenticated")) + # the user a client2 should still be authenticated + self.assertTrue(client2.session.get("authenticated")) + # the user should be deleted + models.User.clean_deleted_sessions() + # assert the user with expired sessions has being well deleted but the other remain + self.assertEqual(len(models.User.objects.all()), 1) + self.assertFalse(models.ServiceTicket.objects.all()) + self.assertTrue(client2.session.get("authenticated")) + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class TicketTestCase(TestCase, UserModels, BaseServicePattern): + """tests for the tickets models""" + def setUp(self): + """Prepare the test context""" + self.setup_service_patterns() + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + single_log_out=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def get_ticket( + self, + user, + ticket_class, + service, + service_pattern, + renew=False, + validate=False, + validity_expired=False, + timeout_expired=False, + single_log_out=False, + ): + """Return a ticket""" + ticket = user.get_ticket(ticket_class, service, service_pattern, renew) + ticket.validate = validate + ticket.single_log_out = single_log_out + if validity_expired: + ticket.creation = min( + ticket.creation, + (timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10))) + ) + if timeout_expired: + ticket.creation = min( + ticket.creation, + (timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10))) + ) + ticket.save() + return ticket + + def test_clean_old_service_ticket(self): + """test tickets clean_old_entries""" + client = get_auth_client() + user = self.get_user(client) + self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern) + self.get_ticket( + user, models.ServiceTicket, + self.service, self.service_pattern, validity_expired=True + ) + (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] + service = "http://%s:%s" % (host, port) + self.get_ticket( + user, models.ServiceTicket, + service, self.service_pattern, timeout_expired=True, + validate=True, single_log_out=True + ) + self.assertEqual(len(models.ServiceTicket.objects.all()), 3) + models.ServiceTicket.clean_old_entries() + params = httpd.PARAMS + self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest']) + self.assertEqual(len(models.ServiceTicket.objects.all()), 1) diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py index 7abaeb4..32b79c9 100644 --- a/cas_server/tests/test_view.py +++ b/cas_server/tests/test_view.py @@ -240,9 +240,10 @@ class LoginTestCase(TestCase, BaseServicePattern): """Test the filtering on user attributes""" client = get_auth_client() - response = client.get("/login", {'service': self.service_filter_fail}) - self.assertEqual(response.status_code, 200) - self.assertTrue(b"User charateristics non allowed" in response.content) + for service in [self.service_filter_fail, self.service_filter_fail_alt]: + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue(b"User charateristics non allowed" in response.content) response = client.get("/login", {'service': self.service_filter_success}) self.assertEqual(response.status_code, 302) @@ -388,6 +389,7 @@ class LoginTestCase(TestCase, BaseServicePattern): class LogoutTestCase(TestCase): """test fot the logout view""" def setUp(self): + """Prepare the test context""" self.service = 'http://127.0.0.1:45678' self.service_pattern = models.ServicePattern.objects.create( name="localhost", @@ -489,29 +491,39 @@ class LogoutTestCase(TestCase): def test_logout_slo(self): """test logout from a service with SLO support""" + parameters = [] + + # test normal SLO (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] service = "http://%s:%s" % (host, port) - (client, ticket) = get_validated_ticket(service)[:2] - client.get('/logout') + parameters.append((httpd.PARAMS, ticket)) - params = httpd.PARAMS - self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest']) + # text SLO with a single_log_out_callback + (httpd, host, port) = utils.HttpParamsHandler.run()[0:3] + self.service_pattern.single_log_out_callback = "http://%s:%s" % (host, port) + self.service_pattern.save() + (client, ticket) = get_validated_ticket(self.service)[:2] + client.get('/logout') + parameters.append((httpd.PARAMS, ticket)) - root = etree.fromstring(params[b'logoutRequest'][0]) - self.assertTrue( - root.xpath( - "//samlp:LogoutRequest", + for (params, ticket) in parameters: + self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest']) + + root = etree.fromstring(params[b'logoutRequest'][0]) + self.assertTrue( + root.xpath( + "//samlp:LogoutRequest", + namespaces={"samlp": "urn:oasis:names:tc:SAML:2.0:protocol"} + ) + ) + session_index = root.xpath( + "//samlp:SessionIndex", namespaces={"samlp": "urn:oasis:names:tc:SAML:2.0:protocol"} ) - ) - session_index = root.xpath( - "//samlp:SessionIndex", - namespaces={"samlp": "urn:oasis:names:tc:SAML:2.0:protocol"} - ) - self.assertEqual(len(session_index), 1) - self.assertEqual(session_index[0].text, ticket.value) + self.assertEqual(len(session_index), 1) + self.assertEqual(session_index[0].text, ticket.value) # SLO error are displayed on logout page (client, ticket) = get_validated_ticket(self.service)[:2] @@ -831,6 +843,39 @@ class ValidateServiceTestCase(TestCase, XmlContent): service_pattern=self.service_pattern_one_attribute ) + self.service_replace_attribute_list = "https://replace_attribute_list.example.com" + self.service_pattern_replace_attribute_list = models.ServicePattern.objects.create( + name="replace_attribute_list", + pattern="^https://replace_attribute_list\.example\.com(/.*)?$", + ) + models.ReplaceAttributValue.objects.create( + attribut="alias", + pattern="^demo", + replace="truc", + service_pattern=self.service_pattern_replace_attribute_list + ) + models.ReplaceAttributName.objects.create( + name="alias", + replace="ALIAS", + service_pattern=self.service_pattern_replace_attribute_list + ) + self.service_replace_attribute = "https://replace_attribute.example.com" + self.service_pattern_replace_attribute = models.ServicePattern.objects.create( + name="replace_attribute", + pattern="^https://replace_attribute\.example\.com(/.*)?$", + ) + models.ReplaceAttributValue.objects.create( + attribut="nom", + pattern="N", + replace="P", + service_pattern=self.service_pattern_replace_attribute + ) + models.ReplaceAttributName.objects.create( + name="nom", + replace="NOM", + service_pattern=self.service_pattern_replace_attribute + ) + def test_validate_service_view_ok(self): """test with a valid (ticket, service), the username and all attributes are transmited""" ticket = get_user_ticket_request(self.service)[1] @@ -857,6 +902,32 @@ class ValidateServiceTestCase(TestCase, XmlContent): {'nom': settings.CAS_TEST_ATTRIBUTES['nom']} ) + def test_validate_replace_attributes(self): + """test with a valid (ticket, service), attributes name and value replacement""" + ticket = get_user_ticket_request(self.service_replace_attribute)[1] + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service_replace_attribute} + ) + self.assert_success( + response, + settings.CAS_TEST_USER, + {'NOM': 'Pymous'} + ) + + ticket = get_user_ticket_request(self.service_replace_attribute_list)[1] + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service_replace_attribute_list} + ) + self.assert_success( + response, + settings.CAS_TEST_USER, + {'ALIAS': ['truc1', 'truc2']} + ) + def test_validate_service_view_badservice(self): """test with a valid ticket but a bad service, the validatin should fail""" ticket = get_user_ticket_request(self.service)[1] diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py index 162269d..ef06c93 100644 --- a/cas_server/tests/utils.py +++ b/cas_server/tests/utils.py @@ -37,6 +37,8 @@ def get_auth_client(**update): params.update(update) client.post('/login', params) + assert client.session.get("authenticated") + return client