diff --git a/cas_server/tests.py b/cas_server/tests.py new file mode 100644 index 0000000..75683e6 --- /dev/null +++ b/cas_server/tests.py @@ -0,0 +1,478 @@ +from .default_settings import settings + +from django.test import TestCase +from django.test import Client + +from lxml import etree +import BaseHTTPServer + +import models +import utils + +def get_login_page_params(): + client = Client() + response = client.get('/login') + form = response.context["form"] + params = {} + for field in form: + if field.value(): + params[field.name] = field.value() + else: + params[field.name] = "" + return client, params + +def get_auth_client(): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + + response = client.post('/login', params) + return client + +def get_user_ticket_request(service): + client = get_auth_client() + response = client.get("/login", {"service": service}) + ticket_value = response['Location'].split('ticket=')[-1] + user = models.User.objects.get(username=settings.CAS_TEST_USER, session_key=client.session.session_key) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + return (user, ticket) + + + +def get_pgt(): + (httpd_thread, host, port) = utils.PGTUrlHandler.run() + service = "http://%s:%s" % (host, port) + + (user, ticket) = get_user_ticket_request(service) + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + params = utils.PGTUrlHandler.PARAMS.copy() + + params["service"] = service + params["user"] = user + + return params + +class LoginTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$", + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_login_view_post_goodpass_goodlt(self): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + + response = client.post('/login', params) + + self.assertEqual(response.status_code, 200) + self.assertTrue("You have successfully logged into the Central Authentication Service" in response.content) + + self.assertTrue(models.User.objects.get(username=settings.CAS_TEST_USER, session_key=client.session.session_key)) + + + def test_login_view_post_badlt(self): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params["lt"] = 'LT-random' + + response = client.post('/login', params) + + self.assertEqual(response.status_code, 200) + self.assertTrue("Invalid login ticket" in response.content) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + + def test_login_view_post_badpass_good_lt(self): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = "test2" + response = client.post('/login', params) + + self.assertEqual(response.status_code, 200) + self.assertTrue(" The credentials you provided cannot be determined to be authentic" in response.content) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + + def test_view_login_get_auth_allowed_service(self): + client = get_auth_client() + response = client.get("/login?service=https://www.example.com") + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header('Location')) + self.assertTrue(response['Location'].startswith("https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX)) + + ticket_value = response['Location'].split('ticket=')[-1] + user = models.User.objects.get(username=settings.CAS_TEST_USER, session_key=client.session.session_key) + self.assertTrue(user) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + self.assertEqual(ticket.user, user) + self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES) + self.assertEqual(ticket.validate, False) + self.assertEqual(ticket.service_pattern, self.service_pattern) + + def test_view_login_get_auth_denied_service(self): + client = get_auth_client() + response = client.get("/login?service=https://www.example.org") + self.assertEqual(response.status_code, 200) + self.assertTrue("Service https://www.example.org non allowed" in response.content) + + +class LogoutTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + + def test_logout_view(self): + client = get_auth_client() + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertTrue("You have successfully logged into the Central Authentication Service" in response.content) + + response = client.get("/logout") + self.assertEqual(response.status_code, 200) + self.assertTrue("You have successfully logged out from the Central Authentication Service" in response.content) + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + def test_logout_view_url(self): + client = get_auth_client() + + response = client.get('/logout?url=https://www.example.com') + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header("Location")) + self.assertEqual(response["Location"], "https://www.example.com") + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + def test_logout_view_service(self): + client = get_auth_client() + + response = client.get('/logout?service=https://www.example.com') + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header("Location")) + self.assertEqual(response["Location"], "https://www.example.com") + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + + open("/tmp/test.html", "w").write(response.content) + + + +class AuthTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'https://www.example.com' + models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + + def test_auth_view_goodpass(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':self.service, 'secret':'test'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'yes\n') + + def test_auth_view_badpass(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':'badpass', 'service':self.service, 'secret':'test'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + + def test_auth_view_badservice(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':'https://www.example.org', 'secret':'test'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + + def test_auth_view_badsecret(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':self.service, 'secret':'badsecret'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + + def test_auth_view_badsettings(self): + settings.CAS_AUTH_SHARED_SECRET = None + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':self.service, 'secret':'test'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, "no\nplease set CAS_AUTH_SHARED_SECRET") + + +class ValidateTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'https://www.example.com' + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_view_ok(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/validate', {'ticket': ticket.value, 'service': self.service}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'yes\ntest\n') + + def test_validate_view_badservice(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/validate', {'ticket': ticket.value, 'service': "https://www.example.org"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + + def test_validate_view_badticket(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/validate', {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + +class ValidateServiceTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy_callback=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_service_view_ok(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath("//cas:authenticationSuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertTrue(sucess) + + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, settings.CAS_TEST_USER) + + attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), 1) + attrs1 = {} + for attr in attributes[0]: + attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = {} + for attr in attributes: + attrs2[attr.attrib['name']] = attr.attrib['value'] + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + + def test_validate_service_view_badservice(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + bad_service = "https://www.example.org" + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE") + self.assertEqual(error[0].text, bad_service) + + def test_validate_service_view_badticket_goodprefix(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, 'ticket not found') + + def test_validate_service_view_badticket_badprefix(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "RANDOM" + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, bad_ticket) + + def test_validate_service_view_ok_pgturl(self): + (httpd_thread, host, port) = utils.PGTUrlHandler.run() + service = "http://%s:%s" % (host, port) + + (user, ticket) = get_user_ticket_request(service) + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + pgt_params = utils.PGTUrlHandler.PARAMS.copy() + self.assertEqual(response.status_code, 200) + + + root = etree.fromstring(response.content) + pgtiou = root.xpath("//cas:proxyGrantingTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(pgtiou), 1) + self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text) + self.assertTrue("pgtId" in pgt_params) + + def test_validate_service_pgturl_bad_proxy_callback(self): + self.service_pattern.proxy_callback = False + self.service_pattern.save() + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK") + self.assertEqual(error[0].text, "callback url not allowed by configuration") + +class ProxyTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'http://127.0.0.1' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy=True, + proxy_callback=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + + def test_validate_proxy_ok(self): + params = get_pgt() + + # get a proxy ticket + client1 = Client() + response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertTrue(sucess) + + proxy_ticket = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxy_ticket), 1) + proxy_ticket = proxy_ticket[0].text + + + # validate the proxy ticket + client2 = Client() + response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath("//cas:authenticationSuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertTrue(sucess) + + # check that the proxy is send to the end service + proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxies), 1) + proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxy), 1) + self.assertEqual(proxy[0].text, params["service"]) + + # same tests than those for serviceValidate + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, settings.CAS_TEST_USER) + + attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), 1) + attrs1 = {} + for attr in attributes[0]: + attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = {} + for attr in attributes: + attrs2[attr.attrib['name']] = attr.attrib['value'] + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + + + def test_validate_proxy_bad(self): + params = get_pgt() + + # bad PGT + client1 = Client() + response = client1.get('/proxy', {'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX, 'targetService': params['service']}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX) + + # bad targetService + client2 = Client() + response = client2.get('/proxy', {'pgt': params['pgtId'], 'targetService': "https://www.example.org"}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") + self.assertEqual(error[0].text, "https://www.example.org") + + + # service do not allow proxy ticket + self.service_pattern.proxy = False + self.service_pattern.save() + + client3 = Client() + response = client3.get('/proxy', {'pgt': params['pgtId'], 'targetService': params['service']}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") + self.assertEqual(error[0].text, 'the service %s do not allow proxy ticket' % params['service']) diff --git a/cas_server/utils.py b/cas_server/utils.py index fdb8f46..69e5623 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -19,6 +19,8 @@ from django.contrib import messages import random import string import json +import BaseHTTPServer +from threading import Thread from importlib import import_module try: @@ -144,3 +146,32 @@ def gen_pgtiou(): def gen_saml_id(): """Generate an saml id""" return _gen_ticket('_') + + +class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): + PARAMS={} + def do_GET(s): + s.send_response(200) + s.send_header("Content-type", "text/plain") + s.end_headers() + s.wfile.write("ok") + url = urlparse(s.path) + params = dict(parse_qsl(url.query)) + PGTUrlHandler.PARAMS.update(params) + s.wfile.write("%s" % params) + def log_message(self, format, *args): + return + + @staticmethod + def run(): + server_class = BaseHTTPServer.HTTPServer + httpd = server_class(("127.0.0.1", 0), PGTUrlHandler) + (host, port) = httpd.socket.getsockname() + def lauch(): + httpd.handle_request() + #httpd.serve_forever() + httpd.server_close() + httpd_thread = Thread(target=lauch) + httpd_thread.daemon = True + httpd_thread.start() + return (httpd_thread, host, port)