diff --git a/cas_server/tests/mixin.py b/cas_server/tests/mixin.py
new file mode 100644
index 0000000..84b7b39
--- /dev/null
+++ b/cas_server/tests/mixin.py
@@ -0,0 +1,145 @@
+"""Some mixin classes for tests"""
+from cas_server.default_settings import settings
+
+import re
+from lxml import etree
+
+from cas_server import models
+
+
+class BaseServicePattern(object):
+ """Mixing for setting up service pattern for testing"""
+ def setup_service_patterns(self, proxy=False):
+ """setting up service pattern"""
+ # For general purpose testing
+ self.service = "https://www.example.com"
+ self.service_pattern = models.ServicePattern.objects.create(
+ name="example",
+ pattern="^https://www\.example\.com(/.*)?$",
+ proxy=proxy,
+ )
+ models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
+
+ # For testing the restrict_users attributes
+ self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
+ self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
+ name="restrict_user_fail",
+ pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
+ restrict_users=True,
+ proxy=proxy,
+ )
+ self.service_restrict_user_success = "https://restrict_user_success.example.com"
+ self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
+ name="restrict_user_success",
+ pattern="^https://restrict_user_success\.example\.com(/.*)?$",
+ restrict_users=True,
+ proxy=proxy,
+ )
+ models.Username.objects.create(
+ value=settings.CAS_TEST_USER,
+ service_pattern=self.service_pattern_restrict_user_success
+ )
+
+ # For testing the user attributes filtering conditions
+ self.service_filter_fail = "https://filter_fail.example.com"
+ self.service_pattern_filter_fail = models.ServicePattern.objects.create(
+ name="filter_fail",
+ pattern="^https://filter_fail\.example\.com(/.*)?$",
+ proxy=proxy,
+ )
+ models.FilterAttributValue.objects.create(
+ attribut="right",
+ pattern="^admin$",
+ service_pattern=self.service_pattern_filter_fail
+ )
+ self.service_filter_success = "https://filter_success.example.com"
+ self.service_pattern_filter_success = models.ServicePattern.objects.create(
+ name="filter_success",
+ pattern="^https://filter_success\.example\.com(/.*)?$",
+ proxy=proxy,
+ )
+ models.FilterAttributValue.objects.create(
+ attribut="email",
+ pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
+ service_pattern=self.service_pattern_filter_success
+ )
+
+ # For testing the user_field attributes
+ self.service_field_needed_fail = "https://field_needed_fail.example.com"
+ self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
+ name="field_needed_fail",
+ pattern="^https://field_needed_fail\.example\.com(/.*)?$",
+ user_field="uid",
+ proxy=proxy,
+ )
+ self.service_field_needed_success = "https://field_needed_success.example.com"
+ self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
+ name="field_needed_success",
+ pattern="^https://field_needed_success\.example\.com(/.*)?$",
+ user_field="alias",
+ proxy=proxy,
+ )
+ self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
+ self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
+ name="field_needed_success_alt",
+ pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
+ user_field="nom",
+ proxy=proxy,
+ )
+
+
+class XmlContent(object):
+ """Mixin for test on CAS XML responses"""
+ def assert_error(self, response, code, text=None):
+ """Assert a validation error"""
+ self.assertEqual(response.status_code, 200)
+ root = etree.fromstring(response.content)
+ error = root.xpath(
+ "//cas:authenticationFailure",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ self.assertEqual(len(error), 1)
+ self.assertEqual(error[0].attrib['code'], code)
+ if text is not None:
+ self.assertEqual(error[0].text, text)
+
+ def assert_success(self, response, username, original_attributes):
+ """assert a ticket validation success"""
+ self.assertEqual(response.status_code, 200)
+
+ root = etree.fromstring(response.content)
+ sucess = root.xpath(
+ "//cas:authenticationSuccess",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ self.assertTrue(sucess)
+
+ users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
+ self.assertEqual(len(users), 1)
+ self.assertEqual(users[0].text, username)
+
+ attributes = root.xpath(
+ "//cas:attributes",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ self.assertEqual(len(attributes), 1)
+ attrs1 = set()
+ for attr in attributes[0]:
+ attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
+
+ attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
+ self.assertEqual(len(attributes), len(attrs1))
+ attrs2 = set()
+ for attr in attributes:
+ attrs2.add((attr.attrib['name'], attr.attrib['value']))
+ original = set()
+ for key, value in original_attributes.items():
+ if isinstance(value, list):
+ for sub_value in value:
+ original.add((key, sub_value))
+ else:
+ original.add((key, value))
+ self.assertEqual(attrs1, attrs2)
+ self.assertEqual(attrs1, original)
+
+ return root
diff --git a/cas_server/tests/test_utils.py b/cas_server/tests/test_utils.py
new file mode 100644
index 0000000..b42c18c
--- /dev/null
+++ b/cas_server/tests/test_utils.py
@@ -0,0 +1,67 @@
+from django.test import TestCase
+
+import six
+
+from cas_server import utils
+
+
+class CheckPasswordCase(TestCase):
+ """Tests for the utils function `utils.check_password`"""
+
+ def setUp(self):
+ """Generate random bytes string that will be used ass passwords"""
+ self.password1 = utils.gen_saml_id()
+ self.password2 = utils.gen_saml_id()
+ if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
+ self.password1 = self.password1.encode("utf8")
+ self.password2 = self.password2.encode("utf8")
+
+ def test_setup(self):
+ """check that generated password are bytes"""
+ self.assertIsInstance(self.password1, bytes)
+ self.assertIsInstance(self.password2, bytes)
+
+ def test_plain(self):
+ """test the plain auth method"""
+ self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
+ self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
+
+ def test_crypt(self):
+ """test the crypt auth method"""
+ if six.PY3:
+ hashed_password1 = utils.crypt.crypt(
+ self.password1.decode("utf8"),
+ "$6$UVVAQvrMyXMF3FF3"
+ ).encode("utf8")
+ else:
+ hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
+
+ self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
+ self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
+
+ def test_ldap_ssha(self):
+ """test the ldap auth method with a {SSHA} scheme"""
+ salt = b"UVVAQvrMyXMF3FF3"
+ hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
+
+ self.assertIsInstance(hashed_password1, bytes)
+ self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
+ self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
+
+ def test_hex_md5(self):
+ """test the hex_md5 auth method"""
+ hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
+
+ self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
+ self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
+
+ def test_hex_sha512(self):
+ """test the hex_sha512 auth method"""
+ hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
+
+ self.assertTrue(
+ utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
+ )
+ self.assertFalse(
+ utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
+ )
diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py
index 72185a6..8326d77 100644
--- a/cas_server/tests/test_view.py
+++ b/cas_server/tests/test_view.py
@@ -1,12 +1,12 @@
-"""Tests module"""
+"""Tests module for views"""
from cas_server.default_settings import settings
import django
from django.test import TestCase, Client
from django.test.utils import override_settings
+from django.utils import timezone
+
-import re
-import six
import random
import json
from lxml import etree
@@ -14,203 +14,15 @@ from six.moves import range
from cas_server import models
from cas_server import utils
-
-
-def copy_form(form):
- """Copy form value into a dict"""
- params = {}
- for field in form:
- if field.value():
- params[field.name] = field.value()
- else:
- params[field.name] = ""
- return params
-
-
-def get_login_page_params(client=None):
- """Return a client and the POST params for the client to login"""
- if client is None:
- client = Client()
- response = client.get('/login')
- params = copy_form(response.context["form"])
- return client, params
-
-
-def get_auth_client(**update):
- """return a authenticated client"""
- client, params = get_login_page_params()
- params["username"] = settings.CAS_TEST_USER
- params["password"] = settings.CAS_TEST_PASSWORD
- params.update(update)
-
- client.post('/login', params)
- return client
-
-
-def get_user_ticket_request(service):
- """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
- client = get_auth_client()
- response = client.get("/login", {"service": service})
- ticket_value = response['Location'].split('ticket=')[-1]
- user = models.User.objects.get(
- username=settings.CAS_TEST_USER,
- session_key=client.session.session_key
- )
- ticket = models.ServiceTicket.objects.get(value=ticket_value)
- return (user, ticket)
-
-
-def get_pgt():
- """return a dict contening a service, user and PGT ticket for this service"""
- (host, port) = utils.PGTUrlHandler.run()[1:3]
- service = "http://%s:%s" % (host, port)
-
- (user, ticket) = get_user_ticket_request(service)
-
- client = Client()
- client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
- params = utils.PGTUrlHandler.PARAMS.copy()
-
- params["service"] = service
- params["user"] = user
-
- return params
-
-
-class CheckPasswordCase(TestCase):
- """Tests for the utils function `utils.check_password`"""
-
- def setUp(self):
- """Generate random bytes string that will be used ass passwords"""
- self.password1 = utils.gen_saml_id()
- self.password2 = utils.gen_saml_id()
- if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
- self.password1 = self.password1.encode("utf8")
- self.password2 = self.password2.encode("utf8")
-
- def test_setup(self):
- """check that generated password are bytes"""
- self.assertIsInstance(self.password1, bytes)
- self.assertIsInstance(self.password2, bytes)
-
- def test_plain(self):
- """test the plain auth method"""
- self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
- self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
-
- def test_crypt(self):
- """test the crypt auth method"""
- if six.PY3:
- hashed_password1 = utils.crypt.crypt(
- self.password1.decode("utf8"),
- "$6$UVVAQvrMyXMF3FF3"
- ).encode("utf8")
- else:
- hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
-
- self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
- self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
-
- def test_ldap_ssha(self):
- """test the ldap auth method with a {SSHA} scheme"""
- salt = b"UVVAQvrMyXMF3FF3"
- hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
-
- self.assertIsInstance(hashed_password1, bytes)
- self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
- self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
-
- def test_hex_md5(self):
- """test the hex_md5 auth method"""
- hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
-
- self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
- self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
-
- def test_hex_sha512(self):
- """test the hex_sha512 auth method"""
- hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
-
- self.assertTrue(
- utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
- )
- self.assertFalse(
- utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
- )
-
-
-class BaseServicePattern(object):
- """Mixing for setting up service pattern for testing"""
- def setup_service_patterns(self, proxy=False):
- """setting up service pattern"""
- # For general purpose testing
- self.service = "https://www.example.com"
- self.service_pattern = models.ServicePattern.objects.create(
- name="example",
- pattern="^https://www\.example\.com(/.*)?$",
- proxy=proxy,
- )
- models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
-
- # For testing the restrict_users attributes
- self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
- self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
- name="restrict_user_fail",
- pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
- restrict_users=True,
- proxy=proxy,
- )
- self.service_restrict_user_success = "https://restrict_user_success.example.com"
- self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
- name="restrict_user_success",
- pattern="^https://restrict_user_success\.example\.com(/.*)?$",
- restrict_users=True,
- proxy=proxy,
- )
- models.Username.objects.create(
- value=settings.CAS_TEST_USER,
- service_pattern=self.service_pattern_restrict_user_success
- )
-
- # For testing the user attributes filtering conditions
- self.service_filter_fail = "https://filter_fail.example.com"
- self.service_pattern_filter_fail = models.ServicePattern.objects.create(
- name="filter_fail",
- pattern="^https://filter_fail\.example\.com(/.*)?$",
- proxy=proxy,
- )
- models.FilterAttributValue.objects.create(
- attribut="right",
- pattern="^admin$",
- service_pattern=self.service_pattern_filter_fail
- )
- self.service_filter_success = "https://filter_success.example.com"
- self.service_pattern_filter_success = models.ServicePattern.objects.create(
- name="filter_success",
- pattern="^https://filter_success\.example\.com(/.*)?$",
- proxy=proxy,
- )
- models.FilterAttributValue.objects.create(
- attribut="email",
- pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
- service_pattern=self.service_pattern_filter_success
- )
-
- # For testing the user_field attributes
- self.service_field_needed_fail = "https://field_needed_fail.example.com"
- self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
- name="field_needed_fail",
- pattern="^https://field_needed_fail\.example\.com(/.*)?$",
- user_field="uid",
- proxy=proxy,
- )
- self.service_field_needed_success = "https://field_needed_success.example.com"
- self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
- name="field_needed_success",
- pattern="^https://field_needed_success\.example\.com(/.*)?$",
- user_field="nom",
- proxy=proxy,
- )
+from cas_server.tests.utils import (
+ copy_form,
+ get_login_page_params,
+ get_auth_client,
+ get_user_ticket_request,
+ get_pgt,
+ get_proxy_ticket
+)
+from cas_server.tests.mixin import BaseServicePattern, XmlContent
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
@@ -449,7 +261,7 @@ class LoginTestCase(TestCase, BaseServicePattern):
response["Location"].startswith("%s?ticket=" % self.service_field_needed_success)
)
- @override_settings(CAS_TEST_ATTRIBUTES={'nom': []})
+ @override_settings(CAS_TEST_ATTRIBUTES={'alias': []})
def test_service_user_field_evaluate_to_false(self):
"""
Test using a user attribute as username:
@@ -458,7 +270,7 @@ class LoginTestCase(TestCase, BaseServicePattern):
client = get_auth_client()
response = client.get("/login", {"service": self.service_field_needed_success})
self.assertEqual(response.status_code, 200)
- self.assertTrue(b"The attribut nom is needed to use that service" in response.content)
+ self.assertTrue(b"The attribut alias is needed to use that service" in response.content)
def test_gateway(self):
"""test gateway parameter"""
@@ -743,6 +555,22 @@ class AuthTestCase(TestCase):
@override_settings(CAS_AUTH_SHARED_SECRET='test')
def test_auth_view_goodpass(self):
"""successful request are awsered by yes"""
+ client = get_auth_client()
+ response = client.post(
+ '/auth',
+ {
+ 'username': settings.CAS_TEST_USER,
+ 'password': settings.CAS_TEST_PASSWORD,
+ 'service': self.service,
+ 'secret': 'test'
+ }
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'yes\n')
+
+ @override_settings(CAS_AUTH_SHARED_SECRET='test')
+ def test_auth_view_goodpass_logged(self):
+ """successful request are awsered by yes, using a logged sessions"""
client = Client()
response = client.post(
'/auth',
@@ -853,6 +681,12 @@ class ValidateTestCase(TestCase):
pattern="^https://user_field\.example\.com(/.*)?$",
user_field="alias"
)
+ self.service_user_field_alt = "https://user_field_alt.example.com"
+ self.service_pattern_user_field_alt = models.ServicePattern.objects.create(
+ name="user field alt",
+ pattern="^https://user_field_alt\.example\.com(/.*)?$",
+ user_field="nom"
+ )
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_view_ok(self):
@@ -893,14 +727,18 @@ class ValidateTestCase(TestCase):
test with a good user_field. A bad user_field (that evaluate to False)
wont happed cause it is filtered in the login view
"""
- ticket = get_user_ticket_request(self.service_user_field)[1]
- client = Client()
- response = client.get(
- '/validate',
- {'ticket': ticket.value, 'service': self.service_user_field}
- )
- self.assertEqual(response.status_code, 200)
- self.assertEqual(response.content, b'yes\ndemo1\n')
+ for (service, username) in [
+ (self.service_user_field, b"demo1"),
+ (self.service_user_field_alt, b"Nymous")
+ ]:
+ ticket = get_user_ticket_request(service)[1]
+ client = Client()
+ response = client.get(
+ '/validate',
+ {'ticket': ticket.value, 'service': service}
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b'yes\n' + username + b'\n')
def test_validate_missing_parameter(self):
"""test with a missing GET parameter among [service, ticket]"""
@@ -916,63 +754,6 @@ class ValidateTestCase(TestCase):
self.assertEqual(response.content, b'no\n')
-class XmlContent(object):
- """Mixin for test on CAS XML responses"""
- def assert_error(self, response, code, text=None):
- """Assert a validation error"""
- self.assertEqual(response.status_code, 200)
- root = etree.fromstring(response.content)
- error = root.xpath(
- "//cas:authenticationFailure",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(error), 1)
- self.assertEqual(error[0].attrib['code'], code)
- if text is not None:
- self.assertEqual(error[0].text, text)
-
- def assert_success(self, response, username, original_attributes):
- """assert a ticket validation success"""
- self.assertEqual(response.status_code, 200)
-
- root = etree.fromstring(response.content)
- sucess = root.xpath(
- "//cas:authenticationSuccess",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertTrue(sucess)
-
- users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
- self.assertEqual(len(users), 1)
- self.assertEqual(users[0].text, username)
-
- attributes = root.xpath(
- "//cas:attributes",
- namespaces={'cas': "http://www.yale.edu/tp/cas"}
- )
- self.assertEqual(len(attributes), 1)
- attrs1 = set()
- for attr in attributes[0]:
- attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
-
- attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
- self.assertEqual(len(attributes), len(attrs1))
- attrs2 = set()
- for attr in attributes:
- attrs2.add((attr.attrib['name'], attr.attrib['value']))
- original = set()
- for key, value in original_attributes.items():
- if isinstance(value, list):
- for sub_value in value:
- original.add((key, sub_value))
- else:
- original.add((key, value))
- self.assertEqual(attrs1, attrs2)
- self.assertEqual(attrs1, original)
-
- return root
-
-
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class ValidateServiceTestCase(TestCase, XmlContent):
"""tests for the serviceValidate view"""
@@ -992,6 +773,12 @@ class ValidateServiceTestCase(TestCase, XmlContent):
pattern="^https://user_field\.example\.com(/.*)?$",
user_field="alias"
)
+ self.service_user_field_alt = "https://user_field_alt.example.com"
+ self.service_pattern_user_field_alt = models.ServicePattern.objects.create(
+ name="user field alt",
+ pattern="^https://user_field_alt\.example\.com(/.*)?$",
+ user_field="nom"
+ )
self.service_one_attribute = "https://one_attribute.example.com"
self.service_pattern_one_attribute = models.ServicePattern.objects.create(
@@ -1171,17 +958,21 @@ class ValidateServiceTestCase(TestCase, XmlContent):
test with a good user_field. A bad user_field (that evaluate to False)
wont happed cause it is filtered in the login view
"""
- ticket = get_user_ticket_request(self.service_user_field)[1]
- client = Client()
- response = client.get(
- '/serviceValidate',
- {'ticket': ticket.value, 'service': self.service_user_field}
- )
- self.assert_success(
- response,
- settings.CAS_TEST_ATTRIBUTES["alias"][0],
- {}
- )
+ for (service, username) in [
+ (self.service_user_field, settings.CAS_TEST_ATTRIBUTES["alias"][0]),
+ (self.service_user_field_alt, settings.CAS_TEST_ATTRIBUTES["nom"])
+ ]:
+ ticket = get_user_ticket_request(service)[1]
+ client = Client()
+ response = client.get(
+ '/serviceValidate',
+ {'ticket': ticket.value, 'service': service}
+ )
+ self.assert_success(
+ response,
+ username,
+ {}
+ )
def test_validate_missing_parameter(self):
"""test with a missing GET parameter among [service, ticket]"""
@@ -1349,3 +1140,198 @@ class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
"INVALID_REQUEST",
'you must specify and pgt and targetService'
)
+
+
+@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
+class SamlValidateTestCase(TestCase, BaseServicePattern, XmlContent):
+ """tests for the proxy view"""
+ def setUp(self):
+ """preparing test context"""
+ self.setup_service_patterns(proxy=True)
+
+ self.service_pgt = 'http://127.0.0.1'
+ self.service_pattern_pgt = models.ServicePattern.objects.create(
+ name="localhost",
+ pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
+ proxy=True,
+ proxy_callback=True
+ )
+ models.ReplaceAttributName.objects.create(
+ name="*",
+ service_pattern=self.service_pattern_pgt
+ )
+
+ xml_template = """
+
+
+
+
+ %(ticket)s
+
+
+"""
+
+ def assert_success(self, response, username, original_attributes):
+ """assert ticket validation success"""
+ self.assertEqual(response.status_code, 200)
+ root = etree.fromstring(response.content)
+ success = root.xpath(
+ "//samlp:StatusCode",
+ namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"}
+ )
+ self.assertEqual(len(success), 1)
+ self.assertTrue(success[0].attrib['Value'].endswith(":Success"))
+
+ user = root.xpath(
+ "//samla:NameIdentifier",
+ namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"}
+ )
+ self.assertTrue(user)
+ self.assertEqual(user[0].text, username)
+
+ attributes = root.xpath(
+ "//samla:AttributeStatement/samla:Attribute",
+ namespaces={'samla': "urn:oasis:names:tc:SAML:1.0:assertion"}
+ )
+ attrs = set()
+ for attr in attributes:
+ attrs.add((attr.attrib['AttributeName'], attr.getchildren()[0].text))
+ original = set()
+ for key, value in original_attributes.items():
+ if isinstance(value, list):
+ for subval in value:
+ original.add((key, subval))
+ else:
+ original.add((key, value))
+ self.assertEqual(original, attrs)
+
+ def assert_error(self, response, code, msg=None):
+ """assert ticket validation error"""
+ self.assertEqual(response.status_code, 200)
+ root = etree.fromstring(response.content)
+ error = root.xpath(
+ "//samlp:StatusCode",
+ namespaces={'samlp': "urn:oasis:names:tc:SAML:1.0:protocol"}
+ )
+ self.assertEqual(len(error), 1)
+ self.assertTrue(error[0].attrib['Value'].endswith(":%s" % code))
+ if msg is not None:
+ self.assertEqual(error[0].text, msg)
+
+ def test_saml_ok(self):
+ """
+ test with a valid (ticket, service), with a ST and a PT,
+ the username and all attributes are transmited"""
+ tickets = [
+ get_user_ticket_request(self.service)[1],
+ get_proxy_ticket(self.service)
+ ]
+
+ for ticket in tickets:
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % self.service,
+ self.xml_template % {
+ 'ticket': ticket.value,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES)
+
+ def test_saml_ok_user_field(self):
+ """test with a valid(ticket, service), use a attributes as transmitted username"""
+ for (service, username) in [
+ (self.service_field_needed_success, settings.CAS_TEST_ATTRIBUTES['alias'][0]),
+ (self.service_field_needed_success_alt, settings.CAS_TEST_ATTRIBUTES['nom'])
+ ]:
+ ticket = get_user_ticket_request(service)[1]
+
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % service,
+ self.xml_template % {
+ 'ticket': ticket.value,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_success(response, username, {})
+
+ def test_saml_bad_ticket(self):
+ """test validation with a bad ST and a bad PT, validation should fail"""
+ tickets = [utils.gen_st(), utils.gen_pt()]
+
+ for ticket in tickets:
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % self.service,
+ self.xml_template % {
+ 'ticket': ticket,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_error(
+ response,
+ "AuthnFailed",
+ 'ticket %s not found' % ticket
+ )
+
+ def test_saml_bad_ticket_prefix(self):
+ """test validation with a bad ticket prefix. Validation should fail with 'AuthnFailed'"""
+ bad_ticket = "RANDOM-NOT-BEGINING-WITH-ST-OR-ST"
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % self.service,
+ self.xml_template % {
+ 'ticket': bad_ticket,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_error(
+ response,
+ "AuthnFailed",
+ 'ticket %s should begin with PT- or ST-' % bad_ticket
+ )
+
+ def test_saml_bad_target(self):
+ """test with a valid(ticket, service), but using a bad target"""
+ bad_target = "https://www.example.org"
+ ticket = get_user_ticket_request(self.service)[1]
+
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % bad_target,
+ self.xml_template % {
+ 'ticket': ticket.value,
+ 'request_id': utils.gen_saml_id(),
+ 'issue_instant': timezone.now().isoformat()
+ },
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_error(
+ response,
+ "AuthnFailed",
+ 'TARGET %s do not match ticket service' % bad_target
+ )
+
+ def test_saml_bad_xml(self):
+ """test validation with a bad xml request, validation should fail"""
+ client = Client()
+ response = client.post(
+ '/samlValidate?TARGET=%s' % self.service,
+ "",
+ content_type="text/xml; encoding='utf-8'"
+ )
+ self.assert_error(response, 'VersionMismatch')
diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py
new file mode 100644
index 0000000..db49dc9
--- /dev/null
+++ b/cas_server/tests/utils.py
@@ -0,0 +1,86 @@
+"""Some utils functions for tests"""
+from cas_server.default_settings import settings
+
+from django.test import Client
+
+from lxml import etree
+
+from cas_server import models
+from cas_server import utils
+
+
+def copy_form(form):
+ """Copy form value into a dict"""
+ params = {}
+ for field in form:
+ if field.value():
+ params[field.name] = field.value()
+ else:
+ params[field.name] = ""
+ return params
+
+
+def get_login_page_params(client=None):
+ """Return a client and the POST params for the client to login"""
+ if client is None:
+ client = Client()
+ response = client.get('/login')
+ params = copy_form(response.context["form"])
+ return client, params
+
+
+def get_auth_client(**update):
+ """return a authenticated client"""
+ client, params = get_login_page_params()
+ params["username"] = settings.CAS_TEST_USER
+ params["password"] = settings.CAS_TEST_PASSWORD
+ params.update(update)
+
+ client.post('/login', params)
+ return client
+
+
+def get_user_ticket_request(service):
+ """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
+ client = get_auth_client()
+ response = client.get("/login", {"service": service})
+ ticket_value = response['Location'].split('ticket=')[-1]
+ user = models.User.objects.get(
+ username=settings.CAS_TEST_USER,
+ session_key=client.session.session_key
+ )
+ ticket = models.ServiceTicket.objects.get(value=ticket_value)
+ return (user, ticket)
+
+
+def get_pgt():
+ """return a dict contening a service, user and PGT ticket for this service"""
+ (host, port) = utils.PGTUrlHandler.run()[1:3]
+ service = "http://%s:%s" % (host, port)
+
+ (user, ticket) = get_user_ticket_request(service)
+
+ client = Client()
+ client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
+ params = utils.PGTUrlHandler.PARAMS.copy()
+
+ params["service"] = service
+ params["user"] = user
+
+ return params
+
+
+def get_proxy_ticket(service):
+ params = get_pgt()
+
+ # get a proxy ticket
+ client = Client()
+ response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
+ root = etree.fromstring(response.content)
+ proxy_ticket = root.xpath(
+ "//cas:proxyTicket",
+ namespaces={'cas': "http://www.yale.edu/tp/cas"}
+ )
+ proxy_ticket = proxy_ticket[0].text
+ ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
+ return ticket
diff --git a/cas_server/views.py b/cas_server/views.py
index bbd25d4..7d42768 100644
--- a/cas_server/views.py
+++ b/cas_server/views.py
@@ -923,11 +923,15 @@ class SamlValidate(View, AttributesMixin):
'username': self.ticket.user.username,
'attributes': attributes
}
- if self.ticket.service_pattern.user_field and \
- self.ticket.user.attributs.get(self.ticket.service_pattern.user_field):
+ if (self.ticket.service_pattern.user_field and
+ self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field
)
+ if isinstance(params['username'], list):
+ # the list is not empty because we wont generate a ticket with a user_field
+ # that evaluate to False
+ params['username'] = params['username'][0]
logger.info(
"SamlValidate: ticket %s validated for user %s on service %s." % (
self.ticket.value,