From e5efdadde0c55ac800fc2c01a5c22fe24e825855 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Wed, 29 Jun 2016 17:07:49 +0200 Subject: [PATCH] Full coverage for saml + split tests --- cas_server/tests/mixin.py | 145 +++++++++ cas_server/tests/test_utils.py | 67 ++++ cas_server/tests/test_view.py | 542 ++++++++++++++++----------------- cas_server/tests/utils.py | 86 ++++++ cas_server/views.py | 8 +- 5 files changed, 568 insertions(+), 280 deletions(-) create mode 100644 cas_server/tests/mixin.py create mode 100644 cas_server/tests/test_utils.py create mode 100644 cas_server/tests/utils.py diff --git a/cas_server/tests/mixin.py b/cas_server/tests/mixin.py new file mode 100644 index 0000000..84b7b39 --- /dev/null +++ b/cas_server/tests/mixin.py @@ -0,0 +1,145 @@ +"""Some mixin classes for tests""" +from cas_server.default_settings import settings + +import re +from lxml import etree + +from cas_server import models + + +class BaseServicePattern(object): + """Mixing for setting up service pattern for testing""" + def setup_service_patterns(self, proxy=False): + """setting up service pattern""" + # For general purpose testing + self.service = "https://www.example.com" + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$", + proxy=proxy, + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + # For testing the restrict_users attributes + self.service_restrict_user_fail = "https://restrict_user_fail.example.com" + self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create( + name="restrict_user_fail", + pattern="^https://restrict_user_fail\.example\.com(/.*)?$", + restrict_users=True, + proxy=proxy, + ) + self.service_restrict_user_success = "https://restrict_user_success.example.com" + self.service_pattern_restrict_user_success = models.ServicePattern.objects.create( + name="restrict_user_success", + pattern="^https://restrict_user_success\.example\.com(/.*)?$", + restrict_users=True, + proxy=proxy, + ) + models.Username.objects.create( + value=settings.CAS_TEST_USER, + service_pattern=self.service_pattern_restrict_user_success + ) + + # For testing the user attributes filtering conditions + self.service_filter_fail = "https://filter_fail.example.com" + self.service_pattern_filter_fail = models.ServicePattern.objects.create( + name="filter_fail", + pattern="^https://filter_fail\.example\.com(/.*)?$", + proxy=proxy, + ) + models.FilterAttributValue.objects.create( + attribut="right", + pattern="^admin$", + service_pattern=self.service_pattern_filter_fail + ) + self.service_filter_success = "https://filter_success.example.com" + self.service_pattern_filter_success = models.ServicePattern.objects.create( + name="filter_success", + pattern="^https://filter_success\.example\.com(/.*)?$", + proxy=proxy, + ) + models.FilterAttributValue.objects.create( + attribut="email", + pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']), + service_pattern=self.service_pattern_filter_success + ) + + # For testing the user_field attributes + self.service_field_needed_fail = "https://field_needed_fail.example.com" + self.service_pattern_field_needed_fail = models.ServicePattern.objects.create( + name="field_needed_fail", + pattern="^https://field_needed_fail\.example\.com(/.*)?$", + user_field="uid", + proxy=proxy, + ) + self.service_field_needed_success = "https://field_needed_success.example.com" + self.service_pattern_field_needed_success = models.ServicePattern.objects.create( + name="field_needed_success", + pattern="^https://field_needed_success\.example\.com(/.*)?$", + user_field="alias", + proxy=proxy, + ) + self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com" + self.service_pattern_field_needed_success = models.ServicePattern.objects.create( + name="field_needed_success_alt", + pattern="^https://field_needed_success_alt\.example\.com(/.*)?$", + user_field="nom", + proxy=proxy, + ) + + +class XmlContent(object): + """Mixin for test on CAS XML responses""" + def assert_error(self, response, code, text=None): + """Assert a validation error""" + self.assertEqual(response.status_code, 200) + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], code) + if text is not None: + self.assertEqual(error[0].text, text) + + def assert_success(self, response, username, original_attributes): + """assert a ticket validation success""" + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath( + "//cas:authenticationSuccess", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertTrue(sucess) + + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, username) + + attributes = root.xpath( + "//cas:attributes", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(attributes), 1) + attrs1 = set() + for attr in attributes[0]: + attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text)) + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = set() + for attr in attributes: + attrs2.add((attr.attrib['name'], attr.attrib['value'])) + original = set() + for key, value in original_attributes.items(): + if isinstance(value, list): + for sub_value in value: + original.add((key, sub_value)) + else: + original.add((key, value)) + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, original) + + return root diff --git a/cas_server/tests/test_utils.py b/cas_server/tests/test_utils.py new file mode 100644 index 0000000..b42c18c --- /dev/null +++ b/cas_server/tests/test_utils.py @@ -0,0 +1,67 @@ +from django.test import TestCase + +import six + +from cas_server import utils + + +class CheckPasswordCase(TestCase): + """Tests for the utils function `utils.check_password`""" + + def setUp(self): + """Generate random bytes string that will be used ass passwords""" + self.password1 = utils.gen_saml_id() + self.password2 = utils.gen_saml_id() + if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3 + self.password1 = self.password1.encode("utf8") + self.password2 = self.password2.encode("utf8") + + def test_setup(self): + """check that generated password are bytes""" + self.assertIsInstance(self.password1, bytes) + self.assertIsInstance(self.password2, bytes) + + def test_plain(self): + """test the plain auth method""" + self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8")) + self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8")) + + def test_crypt(self): + """test the crypt auth method""" + if six.PY3: + hashed_password1 = utils.crypt.crypt( + self.password1.decode("utf8"), + "$6$UVVAQvrMyXMF3FF3" + ).encode("utf8") + else: + hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3") + + self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8")) + self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8")) + + def test_ldap_ssha(self): + """test the ldap auth method with a {SSHA} scheme""" + salt = b"UVVAQvrMyXMF3FF3" + hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8") + + self.assertIsInstance(hashed_password1, bytes) + self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8")) + self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8")) + + def test_hex_md5(self): + """test the hex_md5 auth method""" + hashed_password1 = utils.hashlib.md5(self.password1).hexdigest() + + self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) + self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) + + def test_hex_sha512(self): + """test the hex_sha512 auth method""" + hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() + + self.assertTrue( + utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8") + ) + self.assertFalse( + utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8") + ) diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py index 72185a6..8326d77 100644 --- a/cas_server/tests/test_view.py +++ b/cas_server/tests/test_view.py @@ -1,12 +1,12 @@ -"""Tests module""" +"""Tests module for views""" from cas_server.default_settings import settings import django from django.test import TestCase, Client from django.test.utils import override_settings +from django.utils import timezone + -import re -import six import random import json from lxml import etree @@ -14,203 +14,15 @@ from six.moves import range from cas_server import models from cas_server import utils - - -def copy_form(form): - """Copy form value into a dict""" - params = {} - for field in form: - if field.value(): - params[field.name] = field.value() - else: - params[field.name] = "" - return params - - -def get_login_page_params(client=None): - """Return a client and the POST params for the client to login""" - if client is None: - client = Client() - response = client.get('/login') - params = copy_form(response.context["form"]) - return client, params - - -def get_auth_client(**update): - """return a authenticated client""" - client, params = get_login_page_params() - params["username"] = settings.CAS_TEST_USER - params["password"] = settings.CAS_TEST_PASSWORD - params.update(update) - - client.post('/login', params) - return client - - -def get_user_ticket_request(service): - """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)""" - client = get_auth_client() - response = client.get("/login", {"service": service}) - ticket_value = response['Location'].split('ticket=')[-1] - user = models.User.objects.get( - username=settings.CAS_TEST_USER, - session_key=client.session.session_key - ) - ticket = models.ServiceTicket.objects.get(value=ticket_value) - return (user, ticket) - - -def get_pgt(): - """return a dict contening a service, user and PGT ticket for this service""" - (host, port) = utils.PGTUrlHandler.run()[1:3] - service = "http://%s:%s" % (host, port) - - (user, ticket) = get_user_ticket_request(service) - - client = Client() - client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) - params = utils.PGTUrlHandler.PARAMS.copy() - - params["service"] = service - params["user"] = user - - return params - - -class CheckPasswordCase(TestCase): - """Tests for the utils function `utils.check_password`""" - - def setUp(self): - """Generate random bytes string that will be used ass passwords""" - self.password1 = utils.gen_saml_id() - self.password2 = utils.gen_saml_id() - if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3 - self.password1 = self.password1.encode("utf8") - self.password2 = self.password2.encode("utf8") - - def test_setup(self): - """check that generated password are bytes""" - self.assertIsInstance(self.password1, bytes) - self.assertIsInstance(self.password2, bytes) - - def test_plain(self): - """test the plain auth method""" - self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8")) - self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8")) - - def test_crypt(self): - """test the crypt auth method""" - if six.PY3: - hashed_password1 = utils.crypt.crypt( - self.password1.decode("utf8"), - "$6$UVVAQvrMyXMF3FF3" - ).encode("utf8") - else: - hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3") - - self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8")) - self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8")) - - def test_ldap_ssha(self): - """test the ldap auth method with a {SSHA} scheme""" - salt = b"UVVAQvrMyXMF3FF3" - hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8") - - self.assertIsInstance(hashed_password1, bytes) - self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8")) - self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8")) - - def test_hex_md5(self): - """test the hex_md5 auth method""" - hashed_password1 = utils.hashlib.md5(self.password1).hexdigest() - - self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) - self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) - - def test_hex_sha512(self): - """test the hex_sha512 auth method""" - hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() - - self.assertTrue( - utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8") - ) - self.assertFalse( - utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8") - ) - - -class BaseServicePattern(object): - """Mixing for setting up service pattern for testing""" - def setup_service_patterns(self, proxy=False): - """setting up service pattern""" - # For general purpose testing - self.service = "https://www.example.com" - self.service_pattern = models.ServicePattern.objects.create( - name="example", - pattern="^https://www\.example\.com(/.*)?$", - proxy=proxy, - ) - models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) - - # For testing the restrict_users attributes - self.service_restrict_user_fail = "https://restrict_user_fail.example.com" - self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create( - name="restrict_user_fail", - pattern="^https://restrict_user_fail\.example\.com(/.*)?$", - restrict_users=True, - proxy=proxy, - ) - self.service_restrict_user_success = "https://restrict_user_success.example.com" - self.service_pattern_restrict_user_success = models.ServicePattern.objects.create( - name="restrict_user_success", - pattern="^https://restrict_user_success\.example\.com(/.*)?$", - restrict_users=True, - proxy=proxy, - ) - models.Username.objects.create( - value=settings.CAS_TEST_USER, - service_pattern=self.service_pattern_restrict_user_success - ) - - # For testing the user attributes filtering conditions - self.service_filter_fail = "https://filter_fail.example.com" - self.service_pattern_filter_fail = models.ServicePattern.objects.create( - name="filter_fail", - pattern="^https://filter_fail\.example\.com(/.*)?$", - proxy=proxy, - ) - models.FilterAttributValue.objects.create( - attribut="right", - pattern="^admin$", - service_pattern=self.service_pattern_filter_fail - ) - self.service_filter_success = "https://filter_success.example.com" - self.service_pattern_filter_success = models.ServicePattern.objects.create( - name="filter_success", - pattern="^https://filter_success\.example\.com(/.*)?$", - proxy=proxy, - ) - models.FilterAttributValue.objects.create( - attribut="email", - pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']), - service_pattern=self.service_pattern_filter_success - ) - - # For testing the user_field attributes - self.service_field_needed_fail = "https://field_needed_fail.example.com" - self.service_pattern_field_needed_fail = models.ServicePattern.objects.create( - name="field_needed_fail", - pattern="^https://field_needed_fail\.example\.com(/.*)?$", - user_field="uid", - proxy=proxy, - ) - self.service_field_needed_success = "https://field_needed_success.example.com" - self.service_pattern_field_needed_success = models.ServicePattern.objects.create( - name="field_needed_success", - pattern="^https://field_needed_success\.example\.com(/.*)?$", - user_field="nom", - proxy=proxy, - ) +from cas_server.tests.utils import ( + copy_form, + get_login_page_params, + get_auth_client, + get_user_ticket_request, + get_pgt, + get_proxy_ticket +) +from cas_server.tests.mixin import BaseServicePattern, XmlContent @override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') @@ -449,7 +261,7 @@ class LoginTestCase(TestCase, BaseServicePattern): response["Location"].startswith("%s?ticket=" % self.service_field_needed_success) ) - @override_settings(CAS_TEST_ATTRIBUTES={'nom': []}) + @override_settings(CAS_TEST_ATTRIBUTES={'alias': []}) def test_service_user_field_evaluate_to_false(self): """ Test using a user attribute as username: @@ -458,7 +270,7 @@ class LoginTestCase(TestCase, BaseServicePattern): client = get_auth_client() response = client.get("/login", {"service": self.service_field_needed_success}) self.assertEqual(response.status_code, 200) - self.assertTrue(b"The attribut nom is needed to use that service" in response.content) + self.assertTrue(b"The attribut alias is needed to use that service" in response.content) def test_gateway(self): """test gateway parameter""" @@ -743,6 +555,22 @@ class AuthTestCase(TestCase): @override_settings(CAS_AUTH_SHARED_SECRET='test') def test_auth_view_goodpass(self): """successful request are awsered by yes""" + client = get_auth_client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\n') + + @override_settings(CAS_AUTH_SHARED_SECRET='test') + def test_auth_view_goodpass_logged(self): + """successful request are awsered by yes, using a logged sessions""" client = Client() response = client.post( '/auth', @@ -853,6 +681,12 @@ class ValidateTestCase(TestCase): pattern="^https://user_field\.example\.com(/.*)?$", user_field="alias" ) + self.service_user_field_alt = "https://user_field_alt.example.com" + self.service_pattern_user_field_alt = models.ServicePattern.objects.create( + name="user field alt", + pattern="^https://user_field_alt\.example\.com(/.*)?$", + user_field="nom" + ) models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) def test_validate_view_ok(self): @@ -893,14 +727,18 @@ class ValidateTestCase(TestCase): test with a good user_field. A bad user_field (that evaluate to False) wont happed cause it is filtered in the login view """ - ticket = get_user_ticket_request(self.service_user_field)[1] - client = Client() - response = client.get( - '/validate', - {'ticket': ticket.value, 'service': self.service_user_field} - ) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, b'yes\ndemo1\n') + for (service, username) in [ + (self.service_user_field, b"demo1"), + (self.service_user_field_alt, b"Nymous") + ]: + ticket = get_user_ticket_request(service)[1] + client = Client() + response = client.get( + '/validate', + {'ticket': ticket.value, 'service': service} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\n' + username + b'\n') def test_validate_missing_parameter(self): """test with a missing GET parameter among [service, ticket]""" @@ -916,63 +754,6 @@ class ValidateTestCase(TestCase): self.assertEqual(response.content, b'no\n') -class XmlContent(object): - """Mixin for test on CAS XML responses""" - def assert_error(self, response, code, text=None): - """Assert a validation error""" - self.assertEqual(response.status_code, 200) - root = etree.fromstring(response.content) - error = root.xpath( - "//cas:authenticationFailure", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(error), 1) - self.assertEqual(error[0].attrib['code'], code) - if text is not None: - self.assertEqual(error[0].text, text) - - def assert_success(self, response, username, original_attributes): - """assert a ticket validation success""" - self.assertEqual(response.status_code, 200) - - root = etree.fromstring(response.content) - sucess = root.xpath( - "//cas:authenticationSuccess", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertTrue(sucess) - - users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - self.assertEqual(len(users), 1) - self.assertEqual(users[0].text, username) - - attributes = root.xpath( - "//cas:attributes", - namespaces={'cas': "http://www.yale.edu/tp/cas"} - ) - self.assertEqual(len(attributes), 1) - attrs1 = set() - for attr in attributes[0]: - attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text)) - - attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - self.assertEqual(len(attributes), len(attrs1)) - attrs2 = set() - for attr in attributes: - attrs2.add((attr.attrib['name'], attr.attrib['value'])) - original = set() - for key, value in original_attributes.items(): - if isinstance(value, list): - for sub_value in value: - original.add((key, sub_value)) - else: - original.add((key, value)) - self.assertEqual(attrs1, attrs2) - self.assertEqual(attrs1, original) - - return root - - @override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') class ValidateServiceTestCase(TestCase, XmlContent): """tests for the serviceValidate view""" @@ -992,6 +773,12 @@ class ValidateServiceTestCase(TestCase, XmlContent): pattern="^https://user_field\.example\.com(/.*)?$", user_field="alias" ) + self.service_user_field_alt = "https://user_field_alt.example.com" + self.service_pattern_user_field_alt = models.ServicePattern.objects.create( + name="user field alt", + pattern="^https://user_field_alt\.example\.com(/.*)?$", + user_field="nom" + ) self.service_one_attribute = "https://one_attribute.example.com" self.service_pattern_one_attribute = models.ServicePattern.objects.create( @@ -1171,17 +958,21 @@ class ValidateServiceTestCase(TestCase, XmlContent): test with a good user_field. A bad user_field (that evaluate to False) wont happed cause it is filtered in the login view """ - ticket = get_user_ticket_request(self.service_user_field)[1] - client = Client() - response = client.get( - '/serviceValidate', - {'ticket': ticket.value, 'service': self.service_user_field} - ) - self.assert_success( - response, - settings.CAS_TEST_ATTRIBUTES["alias"][0], - {} - ) + for (service, username) in [ + (self.service_user_field, settings.CAS_TEST_ATTRIBUTES["alias"][0]), + (self.service_user_field_alt, settings.CAS_TEST_ATTRIBUTES["nom"]) + ]: + ticket = get_user_ticket_request(service)[1] + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': service} + ) + self.assert_success( + response, + username, + {} + ) def test_validate_missing_parameter(self): """test with a missing GET parameter among [service, ticket]""" @@ -1349,3 +1140,198 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent): "INVALID_REQUEST", 'you must specify and pgt and targetService' ) + + +@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') +class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent): + """tests for the proxy view""" + def setUp(self): + """preparing test context""" + self.setup_service_patterns(proxy=True) + + self.service_pgt = 'http://127.0.0.1' + self.service_pattern_pgt = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy=True, + proxy_callback=True + ) + models.ReplaceAttributName.objects.create( + name="*", + service_pattern=self.service_pattern_pgt + ) + + xml_template = """ + + + + + %(ticket)s + + +""" + + def assert_success(self, response, username, original_attributes): + """assert ticket validation success""" + self.assertEqual(response.status_code, 200) + root = etree.fromstring(response.content) + success = root.xpath( + "//samlp:StatusCode", + namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"} + ) + self.assertEqual(len(success), 1) + self.assertTrue(success[0].attrib['Value'].endswith(":Success")) + + user = root.xpath( + "//samla:NameIdentifier", + namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"} + ) + self.assertTrue(user) + self.assertEqual(user[0].text, username) + + attributes = root.xpath( + "//samla:AttributeStatement/samla:Attribute", + namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"} + ) + attrs = set() + for attr in attributes: + attrs.add((attr.attrib['AttributeName'], attr.getchildren()[0].text)) + original = set() + for key, value in original_attributes.items(): + if isinstance(value, list): + for subval in value: + original.add((key, subval)) + else: + original.add((key, value)) + self.assertEqual(original, attrs) + + def assert_error(self, response, code, msg=None): + """assert ticket validation error""" + self.assertEqual(response.status_code, 200) + root = etree.fromstring(response.content) + error = root.xpath( + "//samlp:StatusCode", + namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"} + ) + self.assertEqual(len(error), 1) + self.assertTrue(error[0].attrib['Value'].endswith(":%s" % code)) + if msg is not None: + self.assertEqual(error[0].text, msg) + + def test_saml_ok(self): + """ + test with a valid (ticket, service), with a ST and a PT, + the username and all attributes are transmited""" + tickets = [ + get_user_ticket_request(self.service)[1], + get_proxy_ticket(self.service) + ] + + for ticket in tickets: + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % self.service, + self.xml_template % { + 'ticket': ticket.value, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES) + + def test_saml_ok_user_field(self): + """test with a valid(ticket, service), use a attributes as transmitted username""" + for (service, username) in [ + (self.service_field_needed_success, settings.CAS_TEST_ATTRIBUTES['alias'][0]), + (self.service_field_needed_success_alt, settings.CAS_TEST_ATTRIBUTES['nom']) + ]: + ticket = get_user_ticket_request(service)[1] + + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % service, + self.xml_template % { + 'ticket': ticket.value, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + self.assert_success(response, username, {}) + + def test_saml_bad_ticket(self): + """test validation with a bad ST and a bad PT, validation should fail""" + tickets = [utils.gen_st(), utils.gen_pt()] + + for ticket in tickets: + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % self.service, + self.xml_template % { + 'ticket': ticket, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + self.assert_error( + response, + "AuthnFailed", + 'ticket %s not found' % ticket + ) + + def test_saml_bad_ticket_prefix(self): + """test validation with a bad ticket prefix. Validation should fail with 'AuthnFailed'""" + bad_ticket = "RANDOM-NOT-BEGINING-WITH-ST-OR-ST" + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % self.service, + self.xml_template % { + 'ticket': bad_ticket, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + self.assert_error( + response, + "AuthnFailed", + 'ticket %s should begin with PT- or ST-' % bad_ticket + ) + + def test_saml_bad_target(self): + """test with a valid(ticket, service), but using a bad target""" + bad_target = "https://www.example.org" + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % bad_target, + self.xml_template % { + 'ticket': ticket.value, + 'request_id': utils.gen_saml_id(), + 'issue_instant': timezone.now().isoformat() + }, + content_type="text/xml; encoding='utf-8'" + ) + self.assert_error( + response, + "AuthnFailed", + 'TARGET %s do not match ticket service' % bad_target + ) + + def test_saml_bad_xml(self): + """test validation with a bad xml request, validation should fail""" + client = Client() + response = client.post( + '/samlValidate?TARGET=%s' % self.service, + "", + content_type="text/xml; encoding='utf-8'" + ) + self.assert_error(response, 'VersionMismatch') diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py new file mode 100644 index 0000000..db49dc9 --- /dev/null +++ b/cas_server/tests/utils.py @@ -0,0 +1,86 @@ +"""Some utils functions for tests""" +from cas_server.default_settings import settings + +from django.test import Client + +from lxml import etree + +from cas_server import models +from cas_server import utils + + +def copy_form(form): + """Copy form value into a dict""" + params = {} + for field in form: + if field.value(): + params[field.name] = field.value() + else: + params[field.name] = "" + return params + + +def get_login_page_params(client=None): + """Return a client and the POST params for the client to login""" + if client is None: + client = Client() + response = client.get('/login') + params = copy_form(response.context["form"]) + return client, params + + +def get_auth_client(**update): + """return a authenticated client""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params.update(update) + + client.post('/login', params) + return client + + +def get_user_ticket_request(service): + """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)""" + client = get_auth_client() + response = client.get("/login", {"service": service}) + ticket_value = response['Location'].split('ticket=')[-1] + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + return (user, ticket) + + +def get_pgt(): + """return a dict contening a service, user and PGT ticket for this service""" + (host, port) = utils.PGTUrlHandler.run()[1:3] + service = "http://%s:%s" % (host, port) + + (user, ticket) = get_user_ticket_request(service) + + client = Client() + client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + params = utils.PGTUrlHandler.PARAMS.copy() + + params["service"] = service + params["user"] = user + + return params + + +def get_proxy_ticket(service): + params = get_pgt() + + # get a proxy ticket + client = Client() + response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service}) + root = etree.fromstring(response.content) + proxy_ticket = root.xpath( + "//cas:proxyTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + proxy_ticket = proxy_ticket[0].text + ticket = models.ProxyTicket.objects.get(value=proxy_ticket) + return ticket diff --git a/cas_server/views.py b/cas_server/views.py index bbd25d4..7d42768 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -923,11 +923,15 @@ class SamlValidate(View, AttributesMixin): 'username': self.ticket.user.username, 'attributes': attributes } - if self.ticket.service_pattern.user_field and \ - self.ticket.user.attributs.get(self.ticket.service_pattern.user_field): + if (self.ticket.service_pattern.user_field and + self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)): params['username'] = self.ticket.user.attributs.get( self.ticket.service_pattern.user_field ) + if isinstance(params['username'], list): + # the list is not empty because we wont generate a ticket with a user_field + # that evaluate to False + params['username'] = params['username'][0] logger.info( "SamlValidate: ticket %s validated for user %s on service %s." % ( self.ticket.value,