From a0ab47a4aed341fdfec8a819ed2a270e60ee1fbc Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Fri, 24 Jun 2016 21:05:43 +0200 Subject: [PATCH 01/27] Allow pgtUrl to be localhost without https --- cas_server/views.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/cas_server/views.py b/cas_server/views.py index e431499..149632b 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -23,6 +23,7 @@ from django.views.decorators.csrf import csrf_exempt from django.views.generic import View +import re import logging import pprint import requests @@ -666,7 +667,10 @@ class ValidateService(View, AttributesMixin): params['username'] = self.ticket.user.attributs.get( self.ticket.service_pattern.user_field ) - if self.pgt_url and self.pgt_url.startswith("https://"): + if self.pgt_url and ( + self.pgt_url.startswith("https://") or + re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url) + ): return self.process_pgturl(params) else: logger.info( From 5cb25de99f7a2673706f9464c6a35e1d178686fe Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Fri, 24 Jun 2016 21:06:36 +0200 Subject: [PATCH 02/27] Put test username, password, attributes in settings --- cas_server/auth.py | 4 ++-- cas_server/default_settings.py | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cas_server/auth.py b/cas_server/auth.py index 7ccacae..c9e8e34 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -57,11 +57,11 @@ class TestAuthUser(AuthUser): def test_password(self, password): """test `password` agains the user""" - return self.username == "test" and password == "test" + return self.username == settings.CAS_TEST_USER and password == settings.CAS_TEST_PASSWORD def attributs(self): """return a dict of user attributes""" - return {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} + return settings.CAS_TEST_ATTRIBUTES class MysqlAuthUser(AuthUser): diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py index 9ad6f53..2c421d7 100644 --- a/cas_server/default_settings.py +++ b/cas_server/default_settings.py @@ -73,3 +73,7 @@ setting_default('CAS_SQL_DBCHARSET', 'utf8') setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS ' 'password, users.* FROM users WHERE user = %s') setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain + +setting_default('CAS_TEST_USER', 'test') +setting_default('CAS_TEST_PASSWORD', 'test') +setting_default('CAS_TEST_ATTRIBUTES', {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}) From 12201665deb95342f834db094fe7c77be4eb4676 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Fri, 24 Jun 2016 21:07:19 +0200 Subject: [PATCH 03/27] Add some dango unit tests --- cas_server/tests.py | 478 ++++++++++++++++++++++++++++++++++++++++++++ cas_server/utils.py | 31 +++ 2 files changed, 509 insertions(+) create mode 100644 cas_server/tests.py diff --git a/cas_server/tests.py b/cas_server/tests.py new file mode 100644 index 0000000..75683e6 --- /dev/null +++ b/cas_server/tests.py @@ -0,0 +1,478 @@ +from .default_settings import settings + +from django.test import TestCase +from django.test import Client + +from lxml import etree +import BaseHTTPServer + +import models +import utils + +def get_login_page_params(): + client = Client() + response = client.get('/login') + form = response.context["form"] + params = {} + for field in form: + if field.value(): + params[field.name] = field.value() + else: + params[field.name] = "" + return client, params + +def get_auth_client(): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + + response = client.post('/login', params) + return client + +def get_user_ticket_request(service): + client = get_auth_client() + response = client.get("/login", {"service": service}) + ticket_value = response['Location'].split('ticket=')[-1] + user = models.User.objects.get(username=settings.CAS_TEST_USER, session_key=client.session.session_key) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + return (user, ticket) + + + +def get_pgt(): + (httpd_thread, host, port) = utils.PGTUrlHandler.run() + service = "http://%s:%s" % (host, port) + + (user, ticket) = get_user_ticket_request(service) + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + params = utils.PGTUrlHandler.PARAMS.copy() + + params["service"] = service + params["user"] = user + + return params + +class LoginTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$", + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_login_view_post_goodpass_goodlt(self): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + + response = client.post('/login', params) + + self.assertEqual(response.status_code, 200) + self.assertTrue("You have successfully logged into the Central Authentication Service" in response.content) + + self.assertTrue(models.User.objects.get(username=settings.CAS_TEST_USER, session_key=client.session.session_key)) + + + def test_login_view_post_badlt(self): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params["lt"] = 'LT-random' + + response = client.post('/login', params) + + self.assertEqual(response.status_code, 200) + self.assertTrue("Invalid login ticket" in response.content) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + + def test_login_view_post_badpass_good_lt(self): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = "test2" + response = client.post('/login', params) + + self.assertEqual(response.status_code, 200) + self.assertTrue(" The credentials you provided cannot be determined to be authentic" in response.content) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + + def test_view_login_get_auth_allowed_service(self): + client = get_auth_client() + response = client.get("/login?service=https://www.example.com") + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header('Location')) + self.assertTrue(response['Location'].startswith("https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX)) + + ticket_value = response['Location'].split('ticket=')[-1] + user = models.User.objects.get(username=settings.CAS_TEST_USER, session_key=client.session.session_key) + self.assertTrue(user) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + self.assertEqual(ticket.user, user) + self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES) + self.assertEqual(ticket.validate, False) + self.assertEqual(ticket.service_pattern, self.service_pattern) + + def test_view_login_get_auth_denied_service(self): + client = get_auth_client() + response = client.get("/login?service=https://www.example.org") + self.assertEqual(response.status_code, 200) + self.assertTrue("Service https://www.example.org non allowed" in response.content) + + +class LogoutTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + + def test_logout_view(self): + client = get_auth_client() + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertTrue("You have successfully logged into the Central Authentication Service" in response.content) + + response = client.get("/logout") + self.assertEqual(response.status_code, 200) + self.assertTrue("You have successfully logged out from the Central Authentication Service" in response.content) + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + def test_logout_view_url(self): + client = get_auth_client() + + response = client.get('/logout?url=https://www.example.com') + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header("Location")) + self.assertEqual(response["Location"], "https://www.example.com") + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + def test_logout_view_service(self): + client = get_auth_client() + + response = client.get('/logout?service=https://www.example.com') + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header("Location")) + self.assertEqual(response["Location"], "https://www.example.com") + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + + + open("/tmp/test.html", "w").write(response.content) + + + +class AuthTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'https://www.example.com' + models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + + def test_auth_view_goodpass(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':self.service, 'secret':'test'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'yes\n') + + def test_auth_view_badpass(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':'badpass', 'service':self.service, 'secret':'test'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + + def test_auth_view_badservice(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':'https://www.example.org', 'secret':'test'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + + def test_auth_view_badsecret(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':self.service, 'secret':'badsecret'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + + def test_auth_view_badsettings(self): + settings.CAS_AUTH_SHARED_SECRET = None + client = Client() + response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':self.service, 'secret':'test'}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, "no\nplease set CAS_AUTH_SHARED_SECRET") + + +class ValidateTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'https://www.example.com' + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_view_ok(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/validate', {'ticket': ticket.value, 'service': self.service}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'yes\ntest\n') + + def test_validate_view_badservice(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/validate', {'ticket': ticket.value, 'service': "https://www.example.org"}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + + def test_validate_view_badticket(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/validate', {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, 'no\n') + +class ValidateServiceTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy_callback=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_service_view_ok(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath("//cas:authenticationSuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertTrue(sucess) + + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, settings.CAS_TEST_USER) + + attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), 1) + attrs1 = {} + for attr in attributes[0]: + attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = {} + for attr in attributes: + attrs2[attr.attrib['name']] = attr.attrib['value'] + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + + def test_validate_service_view_badservice(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + bad_service = "https://www.example.org" + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE") + self.assertEqual(error[0].text, bad_service) + + def test_validate_service_view_badticket_goodprefix(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, 'ticket not found') + + def test_validate_service_view_badticket_badprefix(self): + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "RANDOM" + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, bad_ticket) + + def test_validate_service_view_ok_pgturl(self): + (httpd_thread, host, port) = utils.PGTUrlHandler.run() + service = "http://%s:%s" % (host, port) + + (user, ticket) = get_user_ticket_request(service) + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + pgt_params = utils.PGTUrlHandler.PARAMS.copy() + self.assertEqual(response.status_code, 200) + + + root = etree.fromstring(response.content) + pgtiou = root.xpath("//cas:proxyGrantingTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(pgtiou), 1) + self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text) + self.assertTrue("pgtId" in pgt_params) + + def test_validate_service_pgturl_bad_proxy_callback(self): + self.service_pattern.proxy_callback = False + self.service_pattern.save() + (user, ticket) = get_user_ticket_request(self.service) + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK") + self.assertEqual(error[0].text, "callback url not allowed by configuration") + +class ProxyTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'http://127.0.0.1' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy=True, + proxy_callback=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + + def test_validate_proxy_ok(self): + params = get_pgt() + + # get a proxy ticket + client1 = Client() + response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertTrue(sucess) + + proxy_ticket = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxy_ticket), 1) + proxy_ticket = proxy_ticket[0].text + + + # validate the proxy ticket + client2 = Client() + response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath("//cas:authenticationSuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertTrue(sucess) + + # check that the proxy is send to the end service + proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxies), 1) + proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxy), 1) + self.assertEqual(proxy[0].text, params["service"]) + + # same tests than those for serviceValidate + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, settings.CAS_TEST_USER) + + attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), 1) + attrs1 = {} + for attr in attributes[0]: + attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = {} + for attr in attributes: + attrs2[attr.attrib['name']] = attr.attrib['value'] + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + + + def test_validate_proxy_bad(self): + params = get_pgt() + + # bad PGT + client1 = Client() + response = client1.get('/proxy', {'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX, 'targetService': params['service']}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX) + + # bad targetService + client2 = Client() + response = client2.get('/proxy', {'pgt': params['pgtId'], 'targetService': "https://www.example.org"}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") + self.assertEqual(error[0].text, "https://www.example.org") + + + # service do not allow proxy ticket + self.service_pattern.proxy = False + self.service_pattern.save() + + client3 = Client() + response = client3.get('/proxy', {'pgt': params['pgtId'], 'targetService': params['service']}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") + self.assertEqual(error[0].text, 'the service %s do not allow proxy ticket' % params['service']) diff --git a/cas_server/utils.py b/cas_server/utils.py index fdb8f46..69e5623 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -19,6 +19,8 @@ from django.contrib import messages import random import string import json +import BaseHTTPServer +from threading import Thread from importlib import import_module try: @@ -144,3 +146,32 @@ def gen_pgtiou(): def gen_saml_id(): """Generate an saml id""" return _gen_ticket('_') + + +class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): + PARAMS={} + def do_GET(s): + s.send_response(200) + s.send_header("Content-type", "text/plain") + s.end_headers() + s.wfile.write("ok") + url = urlparse(s.path) + params = dict(parse_qsl(url.query)) + PGTUrlHandler.PARAMS.update(params) + s.wfile.write("%s" % params) + def log_message(self, format, *args): + return + + @staticmethod + def run(): + server_class = BaseHTTPServer.HTTPServer + httpd = server_class(("127.0.0.1", 0), PGTUrlHandler) + (host, port) = httpd.socket.getsockname() + def lauch(): + httpd.handle_request() + #httpd.serve_forever() + httpd.server_close() + httpd_thread = Thread(target=lauch) + httpd_thread.daemon = True + httpd_thread.start() + return (httpd_thread, host, port) From 0776e371e8df90f19338650c6e5277bae8af0da7 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Fri, 24 Jun 2016 21:23:33 +0200 Subject: [PATCH 04/27] style --- cas_server/default_settings.py | 5 +- cas_server/tests.py | 277 ++++++++++++++++++++++++++------- cas_server/utils.py | 9 +- 3 files changed, 229 insertions(+), 62 deletions(-) diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py index 2c421d7..2824991 100644 --- a/cas_server/default_settings.py +++ b/cas_server/default_settings.py @@ -76,4 +76,7 @@ setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain setting_default('CAS_TEST_USER', 'test') setting_default('CAS_TEST_PASSWORD', 'test') -setting_default('CAS_TEST_ATTRIBUTES', {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}) +setting_default( + 'CAS_TEST_ATTRIBUTES', + {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} +) diff --git a/cas_server/tests.py b/cas_server/tests.py index 75683e6..b989ee6 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -4,11 +4,11 @@ from django.test import TestCase from django.test import Client from lxml import etree -import BaseHTTPServer import models import utils + def get_login_page_params(): client = Client() response = client.get('/login') @@ -21,24 +21,28 @@ def get_login_page_params(): params[field.name] = "" return client, params + def get_auth_client(): client, params = get_login_page_params() params["username"] = settings.CAS_TEST_USER params["password"] = settings.CAS_TEST_PASSWORD - response = client.post('/login', params) + client.post('/login', params) return client + def get_user_ticket_request(service): client = get_auth_client() response = client.get("/login", {"service": service}) ticket_value = response['Location'].split('ticket=')[-1] - user = models.User.objects.get(username=settings.CAS_TEST_USER, session_key=client.session.session_key) + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) ticket = models.ServiceTicket.objects.get(value=ticket_value) return (user, ticket) - def get_pgt(): (httpd_thread, host, port) = utils.PGTUrlHandler.run() service = "http://%s:%s" % (host, port) @@ -46,7 +50,7 @@ def get_pgt(): (user, ticket) = get_user_ticket_request(service) client = Client() - response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) params = utils.PGTUrlHandler.PARAMS.copy() params["service"] = service @@ -54,6 +58,7 @@ def get_pgt(): return params + class LoginTestCase(TestCase): def setUp(self): @@ -72,10 +77,19 @@ class LoginTestCase(TestCase): response = client.post('/login', params) self.assertEqual(response.status_code, 200) - self.assertTrue("You have successfully logged into the Central Authentication Service" in response.content) - - self.assertTrue(models.User.objects.get(username=settings.CAS_TEST_USER, session_key=client.session.session_key)) + self.assertTrue( + ( + "You have successfully logged into " + "the Central Authentication Service" + ) in response.content + ) + self.assertTrue( + models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + ) def test_login_view_post_badlt(self): client, params = get_login_page_params() @@ -87,8 +101,12 @@ class LoginTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertTrue("Invalid login ticket" in response.content) - self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) - + self.assertFalse( + ( + "You have successfully logged into " + "the Central Authentication Service" + ) in response.content + ) def test_login_view_post_badpass_good_lt(self): client, params = get_login_page_params() @@ -97,19 +115,35 @@ class LoginTestCase(TestCase): response = client.post('/login', params) self.assertEqual(response.status_code, 200) - self.assertTrue(" The credentials you provided cannot be determined to be authentic" in response.content) - self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) - + self.assertTrue( + ( + "The credentials you provided cannot be " + "determined to be authentic" + ) in response.content + ) + self.assertFalse( + ( + "You have successfully logged into " + "the Central Authentication Service" + ) in response.content + ) def test_view_login_get_auth_allowed_service(self): client = get_auth_client() response = client.get("/login?service=https://www.example.com") self.assertEqual(response.status_code, 302) self.assertTrue(response.has_header('Location')) - self.assertTrue(response['Location'].startswith("https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX)) + self.assertTrue( + response['Location'].startswith( + "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX + ) + ) ticket_value = response['Location'].split('ticket=')[-1] - user = models.User.objects.get(username=settings.CAS_TEST_USER, session_key=client.session.session_key) + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) self.assertTrue(user) ticket = models.ServiceTicket.objects.get(value=ticket_value) self.assertEqual(ticket.user, user) @@ -134,15 +168,30 @@ class LogoutTestCase(TestCase): response = client.get("/login") self.assertEqual(response.status_code, 200) - self.assertTrue("You have successfully logged into the Central Authentication Service" in response.content) + self.assertTrue( + ( + "You have successfully logged into " + "the Central Authentication Service" + ) in response.content + ) response = client.get("/logout") self.assertEqual(response.status_code, 200) - self.assertTrue("You have successfully logged out from the Central Authentication Service" in response.content) + self.assertTrue( + ( + "You have successfully logged out from " + "the Central Authentication Service" + ) in response.content + ) response = client.get("/login") self.assertEqual(response.status_code, 200) - self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + self.assertFalse( + ( + "You have successfully logged into " + "the Central Authentication Service" + ) in response.content + ) def test_logout_view_url(self): client = get_auth_client() @@ -154,7 +203,12 @@ class LogoutTestCase(TestCase): response = client.get("/login") self.assertEqual(response.status_code, 200) - self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) + self.assertFalse( + ( + "You have successfully logged into " + "the Central Authentication Service" + ) in response.content + ) def test_logout_view_service(self): client = get_auth_client() @@ -166,11 +220,12 @@ class LogoutTestCase(TestCase): response = client.get("/login") self.assertEqual(response.status_code, 200) - self.assertFalse("You have successfully logged into the Central Authentication Service" in response.content) - - - open("/tmp/test.html", "w").write(response.content) - + self.assertFalse( + ( + "You have successfully logged into " + "the Central Authentication Service" + ) in response.content + ) class AuthTestCase(TestCase): @@ -186,35 +241,75 @@ class AuthTestCase(TestCase): def test_auth_view_goodpass(self): settings.CAS_AUTH_SHARED_SECRET = 'test' client = Client() - response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':self.service, 'secret':'test'}) + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) self.assertEqual(response.status_code, 200) self.assertEqual(response.content, 'yes\n') def test_auth_view_badpass(self): settings.CAS_AUTH_SHARED_SECRET = 'test' client = Client() - response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':'badpass', 'service':self.service, 'secret':'test'}) + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': 'badpass', + 'service': self.service, + 'secret': 'test' + } + ) self.assertEqual(response.status_code, 200) self.assertEqual(response.content, 'no\n') def test_auth_view_badservice(self): settings.CAS_AUTH_SHARED_SECRET = 'test' client = Client() - response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':'https://www.example.org', 'secret':'test'}) + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': 'https://www.example.org', + 'secret': 'test' + } + ) self.assertEqual(response.status_code, 200) self.assertEqual(response.content, 'no\n') def test_auth_view_badsecret(self): settings.CAS_AUTH_SHARED_SECRET = 'test' client = Client() - response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':self.service, 'secret':'badsecret'}) + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'badsecret' + } + ) self.assertEqual(response.status_code, 200) self.assertEqual(response.content, 'no\n') def test_auth_view_badsettings(self): settings.CAS_AUTH_SHARED_SECRET = None client = Client() - response = client.post('/auth', {'username':settings.CAS_TEST_USER, 'password':settings.CAS_TEST_PASSWORD, 'service':self.service, 'secret':'test'}) + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) self.assertEqual(response.status_code, 200) self.assertEqual(response.content, "no\nplease set CAS_AUTH_SHARED_SECRET") @@ -242,7 +337,10 @@ class ValidateTestCase(TestCase): (user, ticket) = get_user_ticket_request(self.service) client = Client() - response = client.get('/validate', {'ticket': ticket.value, 'service': "https://www.example.org"}) + response = client.get( + '/validate', + {'ticket': ticket.value, 'service': "https://www.example.org"} + ) self.assertEqual(response.status_code, 200) self.assertEqual(response.content, 'no\n') @@ -250,10 +348,14 @@ class ValidateTestCase(TestCase): (user, ticket) = get_user_ticket_request(self.service) client = Client() - response = client.get('/validate', {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}) + response = client.get( + '/validate', + {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service} + ) self.assertEqual(response.status_code, 200) self.assertEqual(response.content, 'no\n') + class ValidateServiceTestCase(TestCase): def setUp(self): @@ -274,18 +376,24 @@ class ValidateServiceTestCase(TestCase): self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) - sucess = root.xpath("//cas:authenticationSuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + sucess = root.xpath( + "//cas:authenticationSuccess", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertTrue(sucess) users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertEqual(len(users), 1) self.assertEqual(users[0].text, settings.CAS_TEST_USER) - attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + attributes = root.xpath( + "//cas:attributes", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(attributes), 1) attrs1 = {} for attr in attributes[0]: - attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text + attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertEqual(len(attributes), len(attrs1)) @@ -304,7 +412,10 @@ class ValidateServiceTestCase(TestCase): self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(error), 1) self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE") self.assertEqual(error[0].text, bad_service) @@ -318,7 +429,10 @@ class ValidateServiceTestCase(TestCase): self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(error), 1) self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") self.assertEqual(error[0].text, 'ticket not found') @@ -332,7 +446,10 @@ class ValidateServiceTestCase(TestCase): self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(error), 1) self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") self.assertEqual(error[0].text, bad_ticket) @@ -344,13 +461,18 @@ class ValidateServiceTestCase(TestCase): (user, ticket) = get_user_ticket_request(service) client = Client() - response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': service, 'pgtUrl': service} + ) pgt_params = utils.PGTUrlHandler.PARAMS.copy() self.assertEqual(response.status_code, 200) - root = etree.fromstring(response.content) - pgtiou = root.xpath("//cas:proxyGrantingTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + pgtiou = root.xpath( + "//cas:proxyGrantingTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(pgtiou), 1) self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text) self.assertTrue("pgtId" in pgt_params) @@ -361,15 +483,22 @@ class ValidateServiceTestCase(TestCase): (user, ticket) = get_user_ticket_request(self.service) client = Client() - response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}) + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service} + ) self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(error), 1) self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK") self.assertEqual(error[0].text, "callback url not allowed by configuration") + class ProxyTestCase(TestCase): def setUp(self): @@ -383,7 +512,6 @@ class ProxyTestCase(TestCase): ) models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) - def test_validate_proxy_ok(self): params = get_pgt() @@ -396,18 +524,23 @@ class ProxyTestCase(TestCase): sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertTrue(sucess) - proxy_ticket = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + proxy_ticket = root.xpath( + "//cas:proxyTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(proxy_ticket), 1) proxy_ticket = proxy_ticket[0].text - # validate the proxy ticket client2 = Client() response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service}) self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) - sucess = root.xpath("//cas:authenticationSuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + sucess = root.xpath( + "//cas:authenticationSuccess", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertTrue(sucess) # check that the proxy is send to the end service @@ -422,11 +555,14 @@ class ProxyTestCase(TestCase): self.assertEqual(len(users), 1) self.assertEqual(users[0].text, settings.CAS_TEST_USER) - attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + attributes = root.xpath( + "//cas:attributes", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(attributes), 1) attrs1 = {} for attr in attributes[0]: - attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text + attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertEqual(len(attributes), len(attrs1)) @@ -436,43 +572,68 @@ class ProxyTestCase(TestCase): self.assertEqual(attrs1, attrs2) self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) - def test_validate_proxy_bad(self): params = get_pgt() # bad PGT client1 = Client() - response = client1.get('/proxy', {'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX, 'targetService': params['service']}) + response = client1.get( + '/proxy', + { + 'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX, + 'targetService': params['service'] + } + ) self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(error), 1) self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") - self.assertEqual(error[0].text, "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX) + self.assertEqual( + error[0].text, + "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX + ) # bad targetService client2 = Client() - response = client2.get('/proxy', {'pgt': params['pgtId'], 'targetService': "https://www.example.org"}) + response = client2.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': "https://www.example.org"} + ) self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(error), 1) self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") self.assertEqual(error[0].text, "https://www.example.org") - # service do not allow proxy ticket self.service_pattern.proxy = False self.service_pattern.save() client3 = Client() - response = client3.get('/proxy', {'pgt': params['pgtId'], 'targetService': params['service']}) + response = client3.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': params['service']} + ) self.assertEqual(response.status_code, 200) root = etree.fromstring(response.content) - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) self.assertEqual(len(error), 1) self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") - self.assertEqual(error[0].text, 'the service %s do not allow proxy ticket' % params['service']) + self.assertEqual( + error[0].text, + 'the service %s do not allow proxy ticket' % params['service'] + ) diff --git a/cas_server/utils.py b/cas_server/utils.py index 69e5623..4db8f9e 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -149,7 +149,8 @@ def gen_saml_id(): class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): - PARAMS={} + PARAMS = {} + def do_GET(s): s.send_response(200) s.send_header("Content-type", "text/plain") @@ -159,6 +160,7 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): params = dict(parse_qsl(url.query)) PGTUrlHandler.PARAMS.update(params) s.wfile.write("%s" % params) + def log_message(self, format, *args): return @@ -166,11 +168,12 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): def run(): server_class = BaseHTTPServer.HTTPServer httpd = server_class(("127.0.0.1", 0), PGTUrlHandler) - (host, port) = httpd.socket.getsockname() + (host, port) = httpd.socket.getsockname() + def lauch(): httpd.handle_request() - #httpd.serve_forever() httpd.server_close() + httpd_thread = Thread(target=lauch) httpd_thread.daemon = True httpd_thread.start() From 4bb886f08366b18585c20cbda1612d0e0d335d1f Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Fri, 24 Jun 2016 23:37:24 +0200 Subject: [PATCH 05/27] python3 compatibility --- cas_server/tests.py | 60 ++++++++++++++++++++++----------------------- cas_server/urls.py | 2 +- cas_server/utils.py | 14 +++-------- 3 files changed, 35 insertions(+), 41 deletions(-) diff --git a/cas_server/tests.py b/cas_server/tests.py index b989ee6..6db4518 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -5,8 +5,8 @@ from django.test import Client from lxml import etree -import models -import utils +from cas_server import models +from cas_server import utils def get_login_page_params(): @@ -79,8 +79,8 @@ class LoginTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertTrue( ( - "You have successfully logged into " - "the Central Authentication Service" + b"You have successfully logged into " + b"the Central Authentication Service" ) in response.content ) @@ -100,11 +100,11 @@ class LoginTestCase(TestCase): response = client.post('/login', params) self.assertEqual(response.status_code, 200) - self.assertTrue("Invalid login ticket" in response.content) + self.assertTrue(b"Invalid login ticket" in response.content) self.assertFalse( ( - "You have successfully logged into " - "the Central Authentication Service" + b"You have successfully logged into " + b"the Central Authentication Service" ) in response.content ) @@ -117,14 +117,14 @@ class LoginTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertTrue( ( - "The credentials you provided cannot be " - "determined to be authentic" + b"The credentials you provided cannot be " + b"determined to be authentic" ) in response.content ) self.assertFalse( ( - "You have successfully logged into " - "the Central Authentication Service" + b"You have successfully logged into " + b"the Central Authentication Service" ) in response.content ) @@ -155,7 +155,7 @@ class LoginTestCase(TestCase): client = get_auth_client() response = client.get("/login?service=https://www.example.org") self.assertEqual(response.status_code, 200) - self.assertTrue("Service https://www.example.org non allowed" in response.content) + self.assertTrue(b"Service https://www.example.org non allowed" in response.content) class LogoutTestCase(TestCase): @@ -170,8 +170,8 @@ class LogoutTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertTrue( ( - "You have successfully logged into " - "the Central Authentication Service" + b"You have successfully logged into " + b"the Central Authentication Service" ) in response.content ) @@ -179,8 +179,8 @@ class LogoutTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertTrue( ( - "You have successfully logged out from " - "the Central Authentication Service" + b"You have successfully logged out from " + b"the Central Authentication Service" ) in response.content ) @@ -188,8 +188,8 @@ class LogoutTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertFalse( ( - "You have successfully logged into " - "the Central Authentication Service" + b"You have successfully logged into " + b"the Central Authentication Service" ) in response.content ) @@ -205,8 +205,8 @@ class LogoutTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertFalse( ( - "You have successfully logged into " - "the Central Authentication Service" + b"You have successfully logged into " + b"the Central Authentication Service" ) in response.content ) @@ -222,8 +222,8 @@ class LogoutTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertFalse( ( - "You have successfully logged into " - "the Central Authentication Service" + b"You have successfully logged into " + b"the Central Authentication Service" ) in response.content ) @@ -251,7 +251,7 @@ class AuthTestCase(TestCase): } ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, 'yes\n') + self.assertEqual(response.content, b'yes\n') def test_auth_view_badpass(self): settings.CAS_AUTH_SHARED_SECRET = 'test' @@ -266,7 +266,7 @@ class AuthTestCase(TestCase): } ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, 'no\n') + self.assertEqual(response.content, b'no\n') def test_auth_view_badservice(self): settings.CAS_AUTH_SHARED_SECRET = 'test' @@ -281,7 +281,7 @@ class AuthTestCase(TestCase): } ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, 'no\n') + self.assertEqual(response.content, b'no\n') def test_auth_view_badsecret(self): settings.CAS_AUTH_SHARED_SECRET = 'test' @@ -296,7 +296,7 @@ class AuthTestCase(TestCase): } ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, 'no\n') + self.assertEqual(response.content, b'no\n') def test_auth_view_badsettings(self): settings.CAS_AUTH_SHARED_SECRET = None @@ -311,7 +311,7 @@ class AuthTestCase(TestCase): } ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, "no\nplease set CAS_AUTH_SHARED_SECRET") + self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET") class ValidateTestCase(TestCase): @@ -331,7 +331,7 @@ class ValidateTestCase(TestCase): client = Client() response = client.get('/validate', {'ticket': ticket.value, 'service': self.service}) self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, 'yes\ntest\n') + self.assertEqual(response.content, b'yes\ntest\n') def test_validate_view_badservice(self): (user, ticket) = get_user_ticket_request(self.service) @@ -342,7 +342,7 @@ class ValidateTestCase(TestCase): {'ticket': ticket.value, 'service': "https://www.example.org"} ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, 'no\n') + self.assertEqual(response.content, b'no\n') def test_validate_view_badticket(self): (user, ticket) = get_user_ticket_request(self.service) @@ -353,7 +353,7 @@ class ValidateTestCase(TestCase): {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service} ) self.assertEqual(response.status_code, 200) - self.assertEqual(response.content, 'no\n') + self.assertEqual(response.content, b'no\n') class ValidateServiceTestCase(TestCase): diff --git a/cas_server/urls.py b/cas_server/urls.py index b2ed38b..982ef9d 100644 --- a/cas_server/urls.py +++ b/cas_server/urls.py @@ -14,7 +14,7 @@ from django.conf.urls import patterns, url from django.views.generic import RedirectView from django.views.decorators.debug import sensitive_post_parameters, sensitive_variables -import views +from cas_server import views urlpatterns = patterns( '', diff --git a/cas_server/utils.py b/cas_server/utils.py index 4db8f9e..2c2b77f 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -19,15 +19,10 @@ from django.contrib import messages import random import string import json -import BaseHTTPServer from threading import Thread from importlib import import_module - -try: - from urlparse import urlparse, urlunparse, parse_qsl - from urllib import urlencode -except ImportError: - from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode +from six.moves import BaseHTTPServer +from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode def context(params): @@ -153,13 +148,12 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): def do_GET(s): s.send_response(200) - s.send_header("Content-type", "text/plain") + s.send_header(b"Content-type", "text/plain") s.end_headers() - s.wfile.write("ok") + s.wfile.write(b"ok") url = urlparse(s.path) params = dict(parse_qsl(url.query)) PGTUrlHandler.PARAMS.update(params) - s.wfile.write("%s" % params) def log_message(self, format, *args): return From 64b90c50778caf073f9ec2ba1fcb3c5a332c8c4e Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sat, 25 Jun 2016 11:09:46 +0200 Subject: [PATCH 06/27] Use django integrated unit tests --- requirements-dev.txt | 1 + run_tests | 22 +++++ settings_tests.py | 83 +++++++++++++++++ tests/__init__.py | 0 tests/dummy.py | 136 --------------------------- tests/init.py | 32 ------- tests/test_proxy.py | 52 ----------- tests/test_validate_service.py | 87 ------------------ tests/test_views_auth.py | 46 ---------- tests/test_views_login.py | 163 --------------------------------- tests/test_views_logout.py | 80 ---------------- tests/test_views_validate.py | 58 ------------ tox.ini | 2 +- 13 files changed, 107 insertions(+), 655 deletions(-) create mode 100755 run_tests create mode 100644 settings_tests.py delete mode 100644 tests/__init__.py delete mode 100644 tests/dummy.py delete mode 100644 tests/init.py delete mode 100644 tests/test_proxy.py delete mode 100644 tests/test_validate_service.py delete mode 100644 tests/test_views_auth.py delete mode 100644 tests/test_views_login.py delete mode 100644 tests/test_views_logout.py delete mode 100644 tests/test_views_validate.py diff --git a/requirements-dev.txt b/requirements-dev.txt index 9998ce7..e6ef993 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,3 +7,4 @@ django-picklefield>=0.3.1 requests_futures>=0.9.5 django-bootstrap3>=5.4 lxml>=3.4 +six>=1 diff --git a/run_tests b/run_tests new file mode 100755 index 0000000..4ea21ee --- /dev/null +++ b/run_tests @@ -0,0 +1,22 @@ +#!/usr/bin/env python +import os, sys +import django +from django.conf import settings + +import settings_tests + +settings.configure(**settings_tests.__dict__) +django.setup() + +try: + # Django <= 1.8 + from django.test.simple import DjangoTestSuiteRunner + test_runner = DjangoTestSuiteRunner(verbosity=1) +except ImportError: + # Django >= 1.8 + from django.test.runner import DiscoverRunner + test_runner = DiscoverRunner(verbosity=1) + +failures = test_runner.run_tests(['cas_server']) +if failures: + sys.exit(failures) diff --git a/settings_tests.py b/settings_tests.py new file mode 100644 index 0000000..4588c2c --- /dev/null +++ b/settings_tests.py @@ -0,0 +1,83 @@ +""" +Django test settings for cas_server application. + +Generated by 'django-admin startproject' using Django 1.9.7. + +For more information on this file, see +https://docs.djangoproject.com/en/1.9/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/1.9/ref/settings/ +""" + +import os + +# Build paths inside the project like this: os.path.join(BASE_DIR, ...) +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = 'changeme' + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + + +# Application definition + +INSTALLED_APPS = [ + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'bootstrap3', + 'cas_server', +] + +MIDDLEWARE_CLASSES = [ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.auth.middleware.SessionAuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'django.middleware.locale.LocaleMiddleware', +] + +ROOT_URLCONF = 'cas_server.urls' + +# Database +# https://docs.djangoproject.com/en/1.9/ref/settings/#databases + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', + } +} + +# Internationalization +# https://docs.djangoproject.com/en/1.9/topics/i18n/ + +LANGUAGE_CODE = 'en-us' + +TIME_ZONE = 'UTC' + +USE_I18N = True + +USE_L10N = True + +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/1.9/howto/static-files/ + +STATIC_URL = '/static/' diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/dummy.py b/tests/dummy.py deleted file mode 100644 index 8266d7b..0000000 --- a/tests/dummy.py +++ /dev/null @@ -1,136 +0,0 @@ -import functools -from cas_server import models - -class DummyUserManager(object): - def __init__(self, username, session_key): - self.username = username - self.session_key = session_key - def get(self, username=None, session_key=None): - if username == self.username and session_key == self.session_key: - return models.User(username=username, session_key=session_key) - else: - raise models.User.DoesNotExist() - - -def dummy(*args, **kwds): - pass - -def dummy_service_pattern(**kwargs): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - service_validate = models.ServicePattern.validate - models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs)) - ret = func(*args, **kwds) - models.ServicePattern.validate = service_validate - return ret - return wrapper - return decorator - -def dummy_user(username, session_key): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - user_manager = models.User.objects - user_save = models.User.save - user_delete = models.User.delete - models.User.objects = DummyUserManager(username, session_key) - models.User.save = dummy - models.User.delete = dummy - ret = func(*args, **kwds) - models.User.objects = user_manager - models.User.save = user_save - models.User.delete = user_delete - return ret - return wrapper - return decorator - -def dummy_ticket(ticket_class, service, ticket): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - ticket_manager = ticket_class.objects - ticket_save = ticket_class.save - ticket_delete = ticket_class.delete - ticket_class.objects = DummyTicketManager(ticket_class, service, ticket) - ticket_class.save = dummy - ticket_class.delete = dummy - ret = func(*args, **kwds) - ticket_class.objects = ticket_manager - ticket_class.save = ticket_save - ticket_class.delete = ticket_delete - return ret - return wrapper - return decorator - - -def dummy_proxy(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - proxy_manager = models.Proxy.objects - models.Proxy.objects = DummyProxyManager() - ret = func(*args, **kwds) - models.Proxy.objects = proxy_manager - return ret - return wrapper - -class DummyProxyManager(object): - def create(self, **kwargs): - for field in models.Proxy._meta.fields: - field.allow_unsaved_instance_assignment = True - return models.Proxy(**kwargs) - -class DummyTicketManager(object): - def __init__(self, ticket_class, service, ticket): - self.ticket_class = ticket_class - self.service = service - self.ticket = ticket - - def create(self, **kwargs): - for field in self.ticket_class._meta.fields: - field.allow_unsaved_instance_assignment = True - return self.ticket_class(**kwargs) - - def filter(self, *args, **kwargs): - return DummyQuerySet() - - def get(self, **kwargs): - for field in self.ticket_class._meta.fields: - field.allow_unsaved_instance_assignment = True - if 'value' in kwargs: - if kwargs['value'] != self.ticket: - raise self.ticket_class.DoesNotExist() - else: - kwargs['value'] = self.ticket - - if 'service' in kwargs: - if kwargs['service'] != self.service: - raise self.ticket_class.DoesNotExist() - else: - kwargs['service'] = self.service - if not 'user' in kwargs: - kwargs['user'] = models.User(username="test") - - for field in models.ServiceTicket._meta.fields: - field.allow_unsaved_instance_assignment = True - for key in list(kwargs): - if '__' in key: - del kwargs[key] - kwargs['attributs'] = {'mail': 'test@example.com'} - kwargs['service_pattern'] = models.ServicePattern() - return self.ticket_class(**kwargs) - - - -class DummySession(dict): - session_key = "test_session" - - def set_expiry(self, int): - pass - - def flush(self): - self.clear() - - -class DummyQuerySet(set): - pass diff --git a/tests/init.py b/tests/init.py deleted file mode 100644 index f6ede9e..0000000 --- a/tests/init.py +++ /dev/null @@ -1,32 +0,0 @@ -import django -from django.conf import settings -from django.contrib import messages - -settings.configure() -settings.STATIC_URL = "/static/" -settings.DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': '/dev/null', - } -} -settings.INSTALLED_APPS = ( - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'bootstrap3', - 'cas_server', -) - -settings.ROOT_URLCONF = "/" -settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' - -try: - django.setup() -except AttributeError: - pass -messages.add_message = lambda x,y,z:None - diff --git a/tests/test_proxy.py b/tests/test_proxy.py deleted file mode 100644 index 963d834..0000000 --- a/tests/test_proxy.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import absolute_import -from tests.init import * - -from django.test import RequestFactory - -import os -import pytest -from lxml import etree -from cas_server.views import ValidateService, Proxy -from cas_server import models - -from tests.dummy import * - -@pytest.mark.django_db -@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random") -@dummy_service_pattern(proxy=True) -@dummy_user(username="test", session_key="test_session") -@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random") -@dummy_proxy -def test_proxy_ok(): - factory = RequestFactory() - request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com') - - request.session = DummySession() - - proxy = Proxy() - response = proxy.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(proxy_tickets) == 1 - - factory = RequestFactory() - request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com') - - validate = ValidateService() - validate.allow_proxy_ticket = True - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(users) == 1 - assert users[0].text == "test" - - - diff --git a/tests/test_validate_service.py b/tests/test_validate_service.py deleted file mode 100644 index 940e23b..0000000 --- a/tests/test_validate_service.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest -from lxml import etree -from cas_server.views import ValidateService -from cas_server import models - -from .dummy import * - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_validate_service_view_ok(): - factory = RequestFactory() - request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com') - - request.session = DummySession() - - validate = ValidateService() - validate.allow_proxy_ticket = False - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(users) == 1 - assert users[0].text == "test" - - attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(attributes) == 1 - - attrs = {} - for attr in attributes[0]: - attrs[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text - - assert 'mail' in attrs - assert attrs['mail'] == 'test@example.com' - - - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random") -def test_validate_service_view_badservice(): - factory = RequestFactory() - request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com') - - request.session = DummySession() - - validate = ValidateService() - validate.allow_proxy_ticket = False - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(error) == 1 - assert error[0].attrib['code'] == 'INVALID_SERVICE' - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2") -def test_validate_service_view_badticket(): - factory = RequestFactory() - request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com') - - request.session = DummySession() - - validate = ValidateService() - validate.allow_proxy_ticket = False - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(error) == 1 - assert error[0].attrib['code'] == 'INVALID_TICKET' diff --git a/tests/test_views_auth.py b/tests/test_views_auth.py deleted file mode 100644 index 4b4a9eb..0000000 --- a/tests/test_views_auth.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import Auth -from cas_server import models - -from .dummy import * - -settings.CAS_AUTH_SHARED_SECRET = "test" - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -@dummy_user(username="test", session_key="test_session") -@dummy_service_pattern() -def test_auth_view_goodpass(): - factory = RequestFactory() - request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'}) - - request.session = DummySession() - - auth = Auth() - response = auth.post(request) - - assert response.status_code == 200 - assert response.content == b"yes\n" - -@dummy_service_pattern() -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -@dummy_user(username="test", session_key="test_session") -def test_auth_view_badpass(): - factory = RequestFactory() - request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'}) - - request.session = DummySession() - - auth = Auth() - response = auth.post(request) - - assert response.status_code == 200 - assert response.content == b"no\n" - diff --git a/tests/test_views_login.py b/tests/test_views_login.py deleted file mode 100644 index 6aabe80..0000000 --- a/tests/test_views_login.py +++ /dev/null @@ -1,163 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import LoginView -from cas_server import models - -from .dummy import * - - - -def test_login_view_post_goodpass_goodlt(): - factory = RequestFactory() - request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random'}) - request.session = DummySession() - - request.session['lt'] = ['LT-random'] - - request.session["username"] = os.urandom(20) - request.session["warn"] = os.urandom(20) - - login = LoginView() - login.init_post(request) - - ret = login.process_post(pytest=True) - - assert ret == LoginView.USER_LOGIN_OK - assert request.session.get("authenticated") == True - assert request.session.get("username") == "test" - assert request.session.get("warn") == False - -def test_login_view_post_badlt(): - factory = RequestFactory() - request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random1'}) - request.session = DummySession() - - request.session['lt'] = ['LT-random2'] - - authenticated = os.urandom(20) - username = os.urandom(20) - warn = os.urandom(20) - - request.session["authenticated"] = authenticated - request.session["username"] = username - request.session["warn"] = warn - - login = LoginView() - login.init_post(request) - - ret = login.process_post(pytest=True) - - assert ret == LoginView.INVALID_LOGIN_TICKET - assert request.session.get("authenticated") == authenticated - assert request.session.get("username") == username - assert request.session.get("warn") == warn - -def test_login_view_post_badpass_good_lt(): - factory = RequestFactory() - request = factory.post('/login', {'username':'test', 'password':'badpassword', 'lt':'LT-random'}) - request.session = DummySession() - - request.session['lt'] = ['LT-random'] - - login = LoginView() - login.init_post(request) - ret = login.process_post() - - assert ret == LoginView.USER_LOGIN_FAILURE - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - -def test_view_login_get_unauth(): - factory = RequestFactory() - request = factory.post('/login') - request.session = DummySession() - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_NOT_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 200 - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_view_login_get_auth(): - factory = RequestFactory() - request = factory.post('/login') - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 200 - -@pytest.mark.django_db -@dummy_service_pattern() -@dummy_user(username="test", session_key="test_session") -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_view_login_get_auth_service(): - factory = RequestFactory() - request = factory.post('/login?service=https://www.example.com') - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 302 - assert response['Location'].startswith('https://www.example.com?ticket=ST-') - -@pytest.mark.django_db -@dummy_service_pattern() -@dummy_user(username="test", session_key="test_session") -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_view_login_get_auth_service_warn(): - factory = RequestFactory() - request = factory.post('/login?service=https://www.example.com') - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = True - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 200 diff --git a/tests/test_views_logout.py b/tests/test_views_logout.py deleted file mode 100644 index 03410bd..0000000 --- a/tests/test_views_logout.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import LogoutView -from cas_server import models - -from .dummy import * - - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_logout_view(): - factory = RequestFactory() - request = factory.get('/logout') - - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - logout = LogoutView() - response = logout.get(request) - - assert response.status_code == 200 - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_logout_view_url(): - factory = RequestFactory() - request = factory.get('/logout?url=https://www.example.com') - - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - logout = LogoutView() - response = logout.get(request) - - assert response.status_code == 302 - assert response['Location'] == 'https://www.example.com' - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_logout_view_service(): - factory = RequestFactory() - request = factory.get('/logout?service=https://www.example.com') - - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - logout = LogoutView() - response = logout.get(request) - - assert response.status_code == 302 - assert response['Location'] == 'https://www.example.com' - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - diff --git a/tests/test_views_validate.py b/tests/test_views_validate.py deleted file mode 100644 index 201387f..0000000 --- a/tests/test_views_validate.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import Validate -from cas_server import models - -from .dummy import * - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_validate_view_ok(): - factory = RequestFactory() - request = factory.get('/validate?ticket=ST-random&service=https://www.example.com') - - request.session = DummySession() - - validate = Validate() - response = validate.get(request) - - assert response.status_code == 200 - assert response.content == b"yes\ntest\n" - - - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_validate_view_badservice(): - factory = RequestFactory() - request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com') - - request.session = DummySession() - - validate = Validate() - response = validate.get(request) - - assert response.status_code == 200 - assert response.content == b"no\n" - - - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1") -def test_validate_view_badticket(): - factory = RequestFactory() - request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com') - - request.session = DummySession() - - validate = Validate() - response = validate.get(request) - - assert response.status_code == 200 - assert response.content == b"no\n" diff --git a/tox.ini b/tox.ini index 997620a..0b65c56 100644 --- a/tox.ini +++ b/tox.ini @@ -17,7 +17,7 @@ deps = -r{toxinidir}/requirements-dev.txt [testenv] -commands=py.test --tb native {posargs:tests} +commands=python run_tests {posargs:tests} [testenv:py27-django17] basepython=python2.7 From 269cfb463b1fdb0ac11556d041eedbe79267db64 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sat, 25 Jun 2016 11:26:56 +0200 Subject: [PATCH 07/27] Add coverage to Makefile --- .gitignore | 4 ++++ Makefile | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/.gitignore b/.gitignore index 0b5a2a6..42d76c1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ *.pyc *.egg-info +*~ +*.swp build/ bootstrap3 @@ -10,3 +12,5 @@ manage.py .tox test_venv +.coverage +htmlcov/ diff --git a/Makefile b/Makefile index e5b19f1..df1a6a6 100644 --- a/Makefile +++ b/Makefile @@ -44,3 +44,8 @@ test_project: test_venv test_venv/cas/manage.py run_test_server: test_project test_venv/bin/python test_venv/cas/manage.py runserver + +coverage: test_venv + test_venv/bin/pip install coverage + test_venv/bin/coverage run --source='cas_server' run_tests + test_venv/bin/coverage html From 560b5b7a21c3d46c59cf1f6c3088ec6bdcb1cb87 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sat, 25 Jun 2016 11:37:44 +0200 Subject: [PATCH 08/27] Add six to requirements --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8d64df0..97d4f1c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ requests_futures>=0.9.5 django-picklefield>=0.3.1 django-bootstrap3>=5.4 lxml>=3.4 - +six>=1 From a6c77b54d8bd6ec156f9ec45b6921846b2aaf016 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 10:28:13 +0200 Subject: [PATCH 09/27] Update some README links --- README.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index 78eaae1..f154049 100644 --- a/README.rst +++ b/README.rst @@ -11,7 +11,7 @@ CAS Server :target: https://www.gnu.org/licenses/gpl-3.0.html CAS Server is a Django application implementing the `CAS Protocol 3.0 Specification -`_. +`_. By defaut, the authentication process use django internal users but you can easily use any sources (see auth classes in the auth.py file) @@ -70,7 +70,7 @@ Quick start 4. You should add some management commands to a crontab: ``clearsessions``, ``cas_clean_tickets`` and ``cas_clean_sessions``. - * ``clearsessions``: please see `Clearing the session store `_. + * ``clearsessions``: please see `Clearing the session store `_. * ``cas_clean_tickets``: old tickets and timed-out tickets do not get purge from the database automatically. They are just marked as invalid. ``cas_clean_tickets`` is a clean-up management command for this purpose. It send SingleLogOut request @@ -204,7 +204,7 @@ Logs ---- ``django-cas-server`` logs most of its actions. To enable login, you must set the ``LOGGING`` -(https://docs.djangoproject.com/en/dev/topics/logging) variable is ``settings.py``. +(https://docs.djangoproject.com/en/stable/topics/logging) variable is ``settings.py``. Users successful actions (login, logout) are logged with the level ``INFO``, failures are logged with the level ``WARNING`` and user attributes transmitted to a service are logged with the level ``DEBUG``. From 23bbd8080a57c28ae9e2ceb6c5c3b4fec092a45f Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 10:41:15 +0200 Subject: [PATCH 10/27] Codacy badges and coverage --- .gitignore | 1 + Makefile | 5 +++++ README.rst | 6 ++++++ 3 files changed, 12 insertions(+) diff --git a/.gitignore b/.gitignore index 42d76c1..3b1bcb6 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ cas/ dist/ db.sqlite3 manage.py +coverage.xml .tox test_venv diff --git a/Makefile b/Makefile index df1a6a6..6ada128 100644 --- a/Makefile +++ b/Makefile @@ -49,3 +49,8 @@ coverage: test_venv test_venv/bin/pip install coverage test_venv/bin/coverage run --source='cas_server' run_tests test_venv/bin/coverage html + test_venv/bin/coverage xml + +coverage_codacy: coverage + test_venv/bin/pip install codacy-coverage + test_venv/bin/python-codacy-coverage -r coverage.xml diff --git a/README.rst b/README.rst index f154049..85b2dc4 100644 --- a/README.rst +++ b/README.rst @@ -10,6 +10,12 @@ CAS Server .. image:: https://img.shields.io/pypi/l/django-cas-server.svg :target: https://www.gnu.org/licenses/gpl-3.0.html +.. image:: https://api.codacy.com/project/badge/Grade/255c21623d6946ef8802fa7995b61366 + :target: https://www.codacy.com/app/valentin-samir/django-cas-server + +.. image:: https://api.codacy.com/project/badge/Coverage/255c21623d6946ef8802fa7995b61366 + :target: https://www.codacy.com/app/valentin-samir/django-cas-server + CAS Server is a Django application implementing the `CAS Protocol 3.0 Specification `_. From 173f4d8a82c1e54b95e3fad1f1a97e76f4123386 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 10:55:08 +0200 Subject: [PATCH 11/27] Omit django migrations in coverage --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 6ada128..9088fba 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ run_test_server: test_project coverage: test_venv test_venv/bin/pip install coverage - test_venv/bin/coverage run --source='cas_server' run_tests + test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests test_venv/bin/coverage html test_venv/bin/coverage xml From 03cbab37f428e727e9c2ba8b46aac0f14931822e Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 11:01:37 +0200 Subject: [PATCH 12/27] Javascript style --- cas_server/static/cas_server/cas.js | 41 +++++++++++++---------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/cas_server/static/cas_server/cas.js b/cas_server/static/cas_server/cas.js index 4c42dde..06e1a5d 100644 --- a/cas_server/static/cas_server/cas.js +++ b/cas_server/static/cas_server/cas.js @@ -1,55 +1,52 @@ function cas_login(cas_server_login, service, login_service, callback){ - url = cas_server_login + '?service=' + encodeURIComponent(service); + var url = cas_server_login + "?service=" + encodeURIComponent(service); $.ajax({ - type: 'GET', - url:url, - beforeSend: function (request) { + type: "GET", + url, + beforeSend(request) { request.setRequestHeader("X-AJAX", "1"); }, xhrFields: { withCredentials: true }, - success: function(data, textStatus, request){ - if(data.status == 'success'){ + success(data, textStatus, request){ + if(data.status === "success"){ $.ajax({ - type: 'GET', + type: "GET", url: data.url, xhrFields: { withCredentials: true }, success: callback, - error: function (request, textStatus, errorThrown) {}, + error(request, textStatus, errorThrown) {}, }); } else { - if(data.detail == "login required"){ - window.location.href = cas_server_login + '?service=' + encodeURIComponent(login_service); + if(data.detail === "login required"){ + window.location.href = cas_server_login + "?service=" + encodeURIComponent(login_service); } else { - alert('error: ' + data.messages[1].message); + alert("error: " + data.messages[1].message); } } }, - error: function (request, textStatus, errorThrown) {}, + error(request, textStatus, errorThrown) {}, }); } function cas_logout(cas_server_logout){ $.ajax({ - type: 'GET', - url:cas_server_logout, - beforeSend: function (request) { + type: "GET", + url: cas_server_logout, + beforeSend(request) { request.setRequestHeader("X-AJAX", "1"); }, xhrFields: { withCredentials: true }, - error: function (request, textStatus, errorThrown) {}, - success: function(data, textStatus, request){ - if(data.status == 'error'){ - alert('error: ' + data.messages[1].message); + error(request, textStatus, errorThrown) {}, + success(data, textStatus, request){ + if(data.status === "error"){ + alert("error: " + data.messages[1].message); } }, }); } - - - From bf7da7e805c102885fc34b830b0b3255e6649c23 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 11:02:57 +0200 Subject: [PATCH 13/27] More descriptive name for default_app_config --- cas_server/__init__.py | 2 +- cas_server/apps.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cas_server/__init__.py b/cas_server/__init__.py index 1bb1fa4..f830740 100644 --- a/cas_server/__init__.py +++ b/cas_server/__init__.py @@ -9,4 +9,4 @@ # # (c) 2015 Valentin Samir -default_app_config = 'cas_server.apps.AppConfig' +default_app_config = 'cas_server.apps.CasAppConfig' diff --git a/cas_server/apps.py b/cas_server/apps.py index bb93d57..c34b6eb 100644 --- a/cas_server/apps.py +++ b/cas_server/apps.py @@ -2,6 +2,6 @@ from django.utils.translation import ugettext_lazy as _ from django.apps import AppConfig -class AppConfig(AppConfig): +class CasAppConfig(AppConfig): name = 'cas_server' verbose_name = _('Central Authentication Service') From 3e80a018dd0bbe100c74928da5bf24fa68231c5d Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 11:04:05 +0200 Subject: [PATCH 14/27] Css style --- cas_server/static/cas_server/login.css | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cas_server/static/cas_server/login.css b/cas_server/static/cas_server/login.css index b29433d..6d3524b 100644 --- a/cas_server/static/cas_server/login.css +++ b/cas_server/static/cas_server/login.css @@ -43,14 +43,14 @@ body { @media screen and (max-width: 680px) { #app-name { - margin: 0px; + margin: 0; } #app-name img { display: block; margin: auto; } body { - padding-top: 0px; - padding-bottom: 0px; + padding-top: 0; + padding-bottom: 0; } } From ac5f3590636f1590614d162033b02205cf441573 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 11:16:41 +0200 Subject: [PATCH 15/27] style --- cas_server/auth.py | 4 ++-- cas_server/tests.py | 22 +++++++++++----------- cas_server/utils.py | 8 ++++---- cas_server/views.py | 18 ++++++------------ 4 files changed, 23 insertions(+), 29 deletions(-) diff --git a/cas_server/auth.py b/cas_server/auth.py index c9e8e34..c2a4b19 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -26,11 +26,11 @@ class AuthUser(object): def test_password(self, password): """test `password` agains the user""" - raise NotImplemented() + raise NotImplementedError() def attributs(self): """return a dict of user attributes""" - raise NotImplemented() + raise NotImplementedError() class DummyAuthUser(AuthUser): diff --git a/cas_server/tests.py b/cas_server/tests.py index 6db4518..222596e 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -44,7 +44,7 @@ def get_user_ticket_request(service): def get_pgt(): - (httpd_thread, host, port) = utils.PGTUrlHandler.run() + (host, port) = utils.PGTUrlHandler.run()[1:3] service = "http://%s:%s" % (host, port) (user, ticket) = get_user_ticket_request(service) @@ -326,7 +326,7 @@ class ValidateTestCase(TestCase): models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) def test_validate_view_ok(self): - (user, ticket) = get_user_ticket_request(self.service) + ticket = get_user_ticket_request(self.service)[1] client = Client() response = client.get('/validate', {'ticket': ticket.value, 'service': self.service}) @@ -334,7 +334,7 @@ class ValidateTestCase(TestCase): self.assertEqual(response.content, b'yes\ntest\n') def test_validate_view_badservice(self): - (user, ticket) = get_user_ticket_request(self.service) + ticket = get_user_ticket_request(self.service)[1] client = Client() response = client.get( @@ -345,7 +345,7 @@ class ValidateTestCase(TestCase): self.assertEqual(response.content, b'no\n') def test_validate_view_badticket(self): - (user, ticket) = get_user_ticket_request(self.service) + get_user_ticket_request(self.service) client = Client() response = client.get( @@ -369,7 +369,7 @@ class ValidateServiceTestCase(TestCase): models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) def test_validate_service_view_ok(self): - (user, ticket) = get_user_ticket_request(self.service) + ticket = get_user_ticket_request(self.service)[1] client = Client() response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service}) @@ -404,7 +404,7 @@ class ValidateServiceTestCase(TestCase): self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) def test_validate_service_view_badservice(self): - (user, ticket) = get_user_ticket_request(self.service) + ticket = get_user_ticket_request(self.service)[1] client = Client() bad_service = "https://www.example.org" @@ -421,7 +421,7 @@ class ValidateServiceTestCase(TestCase): self.assertEqual(error[0].text, bad_service) def test_validate_service_view_badticket_goodprefix(self): - (user, ticket) = get_user_ticket_request(self.service) + get_user_ticket_request(self.service) client = Client() bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX @@ -438,7 +438,7 @@ class ValidateServiceTestCase(TestCase): self.assertEqual(error[0].text, 'ticket not found') def test_validate_service_view_badticket_badprefix(self): - (user, ticket) = get_user_ticket_request(self.service) + get_user_ticket_request(self.service) client = Client() bad_ticket = "RANDOM" @@ -455,10 +455,10 @@ class ValidateServiceTestCase(TestCase): self.assertEqual(error[0].text, bad_ticket) def test_validate_service_view_ok_pgturl(self): - (httpd_thread, host, port) = utils.PGTUrlHandler.run() + (host, port) = utils.PGTUrlHandler.run()[1:3] service = "http://%s:%s" % (host, port) - (user, ticket) = get_user_ticket_request(service) + ticket = get_user_ticket_request(service)[1] client = Client() response = client.get( @@ -480,7 +480,7 @@ class ValidateServiceTestCase(TestCase): def test_validate_service_pgturl_bad_proxy_callback(self): self.service_pattern.proxy_callback = False self.service_pattern.save() - (user, ticket) = get_user_ticket_request(self.service) + ticket = get_user_ticket_request(self.service)[1] client = Client() response = client.get( diff --git a/cas_server/utils.py b/cas_server/utils.py index 2c2b77f..8a2a040 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -80,9 +80,9 @@ def update_url(url, params): query = dict(parse_qsl(url_parts[4])) query.update(params) url_parts[4] = urlencode(query) - for i in range(len(url_parts)): - if not isinstance(url_parts[i], bytes): - url_parts[i] = url_parts[i].encode('utf-8') + for i, url_part in enumerate(url_parts): + if not isinstance(url_part, bytes): + url_parts[i] = url_part.encode('utf-8') return urlunparse(url_parts).decode('utf-8') @@ -155,7 +155,7 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): params = dict(parse_qsl(url.query)) PGTUrlHandler.PARAMS.update(params) - def log_message(self, format, *args): + def log_message(self, template, *args): return @staticmethod diff --git a/cas_server/views.py b/cas_server/views.py index 149632b..37fe179 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -63,12 +63,12 @@ class AttributesMixin(object): class LogoutMixin(object): """destroy CAS session utils""" - def logout(self, all=False): + def logout(self, all_session=False): """effectively destroy CAS session""" session_nb = 0 username = self.request.session.get("username") if username: - if all: + if all_session: logger.info("Logging out user %s from all of they sessions." % username) else: logger.info("Logging out user %s." % username) @@ -86,8 +86,8 @@ class LogoutMixin(object): # if user not found in database, flush the session anyway self.request.session.flush() - # If all is set logout user from alternative sessions - if all: + # If all_session is set logout user from alternative sessions + if all_session: for user in models.User.objects.filter(username=username): session = SessionStore(session_key=user.session_key) session.flush() @@ -198,10 +198,7 @@ class LoginView(View, LogoutMixin): def init_post(self, request): self.request = request self.service = request.POST.get('service') - if request.POST.get('renew') and request.POST['renew'] != "False": - self.renew = True - else: - self.renew = False + self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") self.gateway = request.POST.get('gateway') self.method = request.POST.get('method') self.ajax = 'HTTP_X_AJAX' in request.META @@ -285,10 +282,7 @@ class LoginView(View, LogoutMixin): def init_get(self, request): self.request = request self.service = request.GET.get('service') - if request.GET.get('renew') and request.GET['renew'] != "False": - self.renew = True - else: - self.renew = False + self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") self.gateway = request.GET.get('gateway') self.method = request.GET.get('method') self.ajax = 'HTTP_X_AJAX' in request.META From 86b9d72d4ce6faf843216ce39777a57321db07ae Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 12:54:18 +0200 Subject: [PATCH 16/27] Update README.rst --- README.rst | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/README.rst b/README.rst index 85b2dc4..9eeb2be 100644 --- a/README.rst +++ b/README.rst @@ -43,6 +43,15 @@ Features Quick start ----------- +0. If you want to make a virtualenv for ``django-cas-server``, you will need the following + dependencies on a bare debian like system:: + + virtualenv build-essential python-dev libxml2-dev libxslt1-dev zlib1g-dev + + If you want to use python3 instead of python2, replace ``python-dev`` with ``python3-dev``. + + If you intend to run the tox tests you will also need ``python3.4-dev`` depending of the current + version of python3 on your system. 1. Add "cas_server" to your INSTALLED_APPS setting like this:: @@ -128,14 +137,14 @@ Template settings: Authentication settings: -* ``CAS_AUTH_CLASS``: A dotted path to a class implementing ``cas_server.auth.AuthUser``. - The default is ``"cas_server.auth.DjangoAuthUser"`` +* ``CAS_AUTH_CLASS``: A dotted path to a class or a class implementing + ``cas_server.auth.AuthUser``. The default is ``"cas_server.auth.DjangoAuthUser"`` * ``SESSION_COOKIE_AGE``: This is a django settings. Here, it control the delay in seconds after which inactive users are logged out. The default is ``1209600`` (2 weeks). You probably should reduce it to something like ``86400`` seconds (1 day). -* ``CAS_PROXY_CA_CERTIFICATE_PATH``: Path to certificates authority file. Usually on linux +* ``CAS_PROXY_CA_CERTIFICATE_PATH``: Path to certificate authorities file. Usually on linux the local CAs are in ``/etc/ssl/certs/ca-certificates.crt``. The default is ``True`` which tell requests to use its internal certificat authorities. Settings it to ``False`` should disable all x509 certificates validation and MUST not be done in production. @@ -152,7 +161,7 @@ Tickets validity settings: application. The default is ``60``. * ``CAS_PGT_VALIDITY``: Number of seconds the proxy granting tickets are valid. The default is ``3600`` (1 hour). -* ``CAS_TICKET_TIMEOUT``: Number of seconds a ticket is kept is the database before sending +* ``CAS_TICKET_TIMEOUT``: Number of seconds a ticket is kept in the database before sending Single Log Out request and being cleared. The default is ``86400`` (24 hours). Tickets miscellaneous settings: @@ -174,12 +183,12 @@ Tickets miscellaneous settings: * ``CAS_SERVICE_TICKET_PREFIX``: Prefix of service tickets. The default is ``"ST"``. The CAS specification mandate that service tickets MUST begin with the characters ST so you should not change this. -* ``CAS_PROXY_TICKET_PREFIX``: Prefix of proxy ticket. The default is ``"ST"``. +* ``CAS_PROXY_TICKET_PREFIX``: Prefix of proxy ticket. The default is ``"PT"``. * ``CAS_PROXY_GRANTING_TICKET_PREFIX``: Prefix of proxy granting ticket. The default is ``"PGT"``. * ``CAS_PROXY_GRANTING_TICKET_IOU_PREFIX``: Prefix of proxy granting ticket IOU. The default is ``"PGTIOU"``. -Mysql backend settings. Only usefull is you use the mysql authentication backend: +Mysql backend settings. Only usefull if you are using the mysql authentication backend: * ``CAS_SQL_HOST``: Host for the SQL server. The default is ``"localhost"``. * ``CAS_SQL_USERNAME``: Username for connecting to the SQL server. @@ -193,14 +202,23 @@ Mysql backend settings. Only usefull is you use the mysql authentication backend * ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be ``"crypt"`` or ``"plain``". The default is ``"crypt"``. + +Test backend settings. Only usefull if you are using the test authentication backend: + +* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``. +* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``. +* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is + ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``. + + Authentication backend ---------------------- ``django-cas-server`` comes with some authentication backends: * dummy backend ``cas_server.auth.DummyAuthUser``: all authentication attempt fails. -* test backend ``cas_server.auth.TestAuthUser``: username is ``test`` and password is ``test`` - the returned attributes for the user are: ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}`` +* test backend ``cas_server.auth.TestAuthUser``: username, password and returned attributes + for the user are defined by the ``CAS_TEST_*`` settings. * django backend ``cas_server.auth.DjangoAuthUser``: Users are authenticated agains django users system. This is the default backend. The returned attributes are the fields available on the user model. * mysql backend ``cas_server.auth.MysqlAuthUser``: see the 'Mysql backend settings' section. @@ -210,7 +228,7 @@ Logs ---- ``django-cas-server`` logs most of its actions. To enable login, you must set the ``LOGGING`` -(https://docs.djangoproject.com/en/stable/topics/logging) variable is ``settings.py``. +(https://docs.djangoproject.com/en/stable/topics/logging) variable in ``settings.py``. Users successful actions (login, logout) are logged with the level ``INFO``, failures are logged with the level ``WARNING`` and user attributes transmitted to a service are logged with the level ``DEBUG``. From 8303f816df1a70e9b8ba6509e1a31666f5d14dee Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 15:34:26 +0200 Subject: [PATCH 17/27] Exclude non test auth from coverage --- cas_server/auth.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/cas_server/auth.py b/cas_server/auth.py index c2a4b19..4d26f09 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -12,7 +12,7 @@ """Some authentication classes for the CAS""" from django.conf import settings from django.contrib.auth import get_user_model -try: +try: # pragma: no cover import MySQLdb import MySQLdb.cursors import crypt @@ -33,7 +33,7 @@ class AuthUser(object): raise NotImplementedError() -class DummyAuthUser(AuthUser): +class DummyAuthUser(AuthUser): # pragma: no cover """A Dummy authentication class""" def __init__(self, username): @@ -64,7 +64,7 @@ class TestAuthUser(AuthUser): return settings.CAS_TEST_ATTRIBUTES -class MysqlAuthUser(AuthUser): +class MysqlAuthUser(AuthUser): # pragma: no cover """A mysql auth class: authentication user agains a mysql database""" user = None @@ -89,9 +89,7 @@ class MysqlAuthUser(AuthUser): def test_password(self, password): """test `password` agains the user""" - if not self.user: - return False - else: + if self.user: if settings.CAS_SQL_PASSWORD_CHECK == "plain": return password == self.user["password"] elif settings.CAS_SQL_PASSWORD_CHECK == "crypt": @@ -103,16 +101,18 @@ class MysqlAuthUser(AuthUser): password, self.user["password"][:2] ) == self.user["password"] + else: + return False def attributs(self): """return a dict of user attributes""" - if not self.user: - return {} - else: + if self.user: return self.user + else: + return {} -class DjangoAuthUser(AuthUser): +class DjangoAuthUser(AuthUser): # pragma: no cover """A django auth class: authenticate user agains django internal users""" user = None @@ -126,17 +126,17 @@ class DjangoAuthUser(AuthUser): def test_password(self, password): """test `password` agains the user""" - if not self.user: - return False - else: + if self.user: return self.user.check_password(password) + else: + return False def attributs(self): """return a dict of user attributes""" - if not self.user: - return {} - else: + if self.user: attr = {} for field in self.user._meta.fields: attr[field.attname] = getattr(self.user, field.attname) return attr + else: + return {} From 164e2f5c28e919d8c1d1763ce100b6d63b5b1064 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 16:02:25 +0200 Subject: [PATCH 18/27] style --- cas_server/tests.py | 30 +++++++++++++++--------------- cas_server/utils.py | 16 ++++++++-------- cas_server/views.py | 20 ++++++++++---------- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/cas_server/tests.py b/cas_server/tests.py index 222596e..710b890 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -10,25 +10,25 @@ from cas_server import utils def get_login_page_params(): - client = Client() - response = client.get('/login') - form = response.context["form"] - params = {} - for field in form: - if field.value(): - params[field.name] = field.value() - else: - params[field.name] = "" - return client, params + client = Client() + response = client.get('/login') + form = response.context["form"] + params = {} + for field in form: + if field.value(): + params[field.name] = field.value() + else: + params[field.name] = "" + return client, params def get_auth_client(): - client, params = get_login_page_params() - params["username"] = settings.CAS_TEST_USER - params["password"] = settings.CAS_TEST_PASSWORD + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD - client.post('/login', params) - return client + client.post('/login', params) + return client def get_user_ticket_request(service): diff --git a/cas_server/utils.py b/cas_server/utils.py index 8a2a040..bd7e273 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -30,7 +30,7 @@ def context(params): return params -def JsonResponse(request, data): +def json_response(request, data): data["messages"] = [] for msg in messages.get_messages(request): data["messages"].append({'message': msg.message, 'level': msg.level_tag}) @@ -146,16 +146,16 @@ def gen_saml_id(): class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): PARAMS = {} - def do_GET(s): - s.send_response(200) - s.send_header(b"Content-type", "text/plain") - s.end_headers() - s.wfile.write(b"ok") - url = urlparse(s.path) + def do_GET(self): + self.send_response(200) + self.send_header(b"Content-type", "text/plain") + self.end_headers() + self.wfile.write(b"ok") + url = urlparse(self.path) params = dict(parse_qsl(url.query)) PGTUrlHandler.PARAMS.update(params) - def log_message(self, template, *args): + def log_message(self, *args): return @staticmethod diff --git a/cas_server/views.py b/cas_server/views.py index 37fe179..2b33a6c 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -35,7 +35,7 @@ import cas_server.utils as utils import cas_server.forms as forms import cas_server.models as models -from .utils import JsonResponse +from .utils import json_response from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket from .models import ServicePattern @@ -154,13 +154,13 @@ class LogoutView(View, LogoutMixin): 'url': url, 'session_nb': session_nb } - return JsonResponse(request, data) + return json_response(request, data) else: return redirect("cas_server:login") else: if self.ajax: data = {'status': 'success', 'detail': 'logout', 'session_nb': session_nb} - return JsonResponse(request, data) + return json_response(request, data) else: return render( request, @@ -253,7 +253,7 @@ class LoginView(View, LogoutMixin): raise EnvironmentError("invalid output for LoginView.process_post") return self.common() - def process_post(self, pytest=False): + def process_post(self): if not self.check_lt(): values = self.request.POST.copy() # if not set a new LT and fail @@ -330,7 +330,7 @@ class LoginView(View, LogoutMixin): ) if self.ajax: data = {"status": "error", "detail": "confirmation needed"} - return JsonResponse(self.request, data) + return json_response(self.request, data) else: warn_form = forms.WarnForm(initial={ 'service': self.service, @@ -357,7 +357,7 @@ class LoginView(View, LogoutMixin): return HttpResponseRedirect(redirect_url) else: data = {"status": "success", "detail": "auth", "url": redirect_url} - return JsonResponse(self.request, data) + return json_response(self.request, data) except ServicePattern.DoesNotExist: error = 1 messages.add_message( @@ -401,7 +401,7 @@ class LoginView(View, LogoutMixin): ) else: data = {"status": "error", "detail": "auth", "code": error} - return JsonResponse(self.request, data) + return json_response(self.request, data) def authenticated(self): """Processing authenticated users""" @@ -423,7 +423,7 @@ class LoginView(View, LogoutMixin): "detail": "login required", "url": utils.reverse_params("cas_server:login", params=self.request.GET) } - return JsonResponse(self.request, data) + return json_response(self.request, data) else: return utils.redirect_params("cas_server:login", params=self.request.GET) @@ -433,7 +433,7 @@ class LoginView(View, LogoutMixin): else: if self.ajax: data = {"status": "success", "detail": "logged"} - return JsonResponse(self.request, data) + return json_response(self.request, data) else: return render( self.request, @@ -476,7 +476,7 @@ class LoginView(View, LogoutMixin): "detail": "login required", "url": utils.reverse_params("cas_server:login", params=self.request.GET) } - return JsonResponse(self.request, data) + return json_response(self.request, data) else: return render( self.request, From 02a566c129da7fd4aa51c2c83e5be8af980171bd Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 16:13:09 +0200 Subject: [PATCH 19/27] Use constant only caps for constants --- cas_server/admin.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cas_server/admin.py b/cas_server/admin.py index bfa5a73..a6a9be4 100644 --- a/cas_server/admin.py +++ b/cas_server/admin.py @@ -14,9 +14,9 @@ from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, Servi from .models import Username, ReplaceAttributName, ReplaceAttributValue, FilterAttributValue from .forms import TicketForm -tickets_readonly_fields = ('validate', 'service', 'service_pattern', +TICKETS_READONLY_FIELDS = ('validate', 'service', 'service_pattern', 'creation', 'renew', 'single_log_out', 'value') -tickets_fields = ('validate', 'service', 'service_pattern', +TICKETS_FIELDS = ('validate', 'service', 'service_pattern', 'creation', 'renew', 'single_log_out') @@ -25,8 +25,8 @@ class ServiceTicketInline(admin.TabularInline): model = ServiceTicket extra = 0 form = TicketForm - readonly_fields = tickets_readonly_fields - fields = tickets_fields + readonly_fields = TICKETS_READONLY_FIELDS + fields = TICKETS_FIELDS class ProxyTicketInline(admin.TabularInline): @@ -34,8 +34,8 @@ class ProxyTicketInline(admin.TabularInline): model = ProxyTicket extra = 0 form = TicketForm - readonly_fields = tickets_readonly_fields - fields = tickets_fields + readonly_fields = TICKETS_READONLY_FIELDS + fields = TICKETS_FIELDS class ProxyGrantingInline(admin.TabularInline): @@ -43,8 +43,8 @@ class ProxyGrantingInline(admin.TabularInline): model = ProxyGrantingTicket extra = 0 form = TicketForm - readonly_fields = tickets_readonly_fields - fields = tickets_fields[1:] + readonly_fields = TICKETS_READONLY_FIELDS + fields = TICKETS_FIELDS[1:] class UserAdmin(admin.ModelAdmin): From b36a9a15235a3f4ce9087d447822b1d9697a1dbc Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 17:20:26 +0200 Subject: [PATCH 20/27] Add a .coveragerc file --- .coveragerc | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .coveragerc diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..f11c9de --- /dev/null +++ b/.coveragerc @@ -0,0 +1,7 @@ +[report] +exclude_lines = + pragma: no cover + def __repr__ + def __unicode__ + raise AssertionError + raise NotImplementedError From ac206d56d6712f63120a6f6b9e1d9332f3b6d6ab Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 20:29:47 +0200 Subject: [PATCH 21/27] Add some password check methods to the MySQL auth backend --- README.rst | 14 +++- cas_server/auth.py | 18 ++--- cas_server/utils.py | 156 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 175 insertions(+), 13 deletions(-) diff --git a/README.rst b/README.rst index 9eeb2be..070a437 100644 --- a/README.rst +++ b/README.rst @@ -199,8 +199,18 @@ Mysql backend settings. Only usefull if you are using the mysql authentication b The username must be in field ``username``, the password in ``password``, additional fields are used as the user attributes. The default is ``"SELECT user AS usersame, pass AS password, users.* FROM users WHERE user = %s"`` -* ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be - ``"crypt"`` or ``"plain``". The default is ``"crypt"``. +* ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be one of the following: + * ``"crypt"`` (see ``), the password in the database + should begin this $ + * ``"ldap"`` (see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html) + the password in the database must begin with one of {MD5}, {SMD5}, {SHA}, {SSHA}, {SHA256}, + {SSHA256}, {SHA384}, {SSHA384}, {SHA512}, {SSHA512}, {CRYPT}. + * ``"hex_HASH_NAME"`` with ``HASH_NAME`` in md5, sha1, sha224, sha256, sha384, sha512. + The hashed password in the database is compare to the hexadecimal digest of the clear + password hashed with the corresponding algorithm. + * ``"plain"``, the password in the database must be in clear. + + The default is ``"crypt"``. Test backend settings. Only usefull if you are using the test authentication backend: diff --git a/cas_server/auth.py b/cas_server/auth.py index 4d26f09..0c147c2 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -16,6 +16,7 @@ try: # pragma: no cover import MySQLdb import MySQLdb.cursors import crypt + from utils import check_password except ImportError: MySQLdb = None @@ -90,17 +91,12 @@ class MysqlAuthUser(AuthUser): # pragma: no cover def test_password(self, password): """test `password` agains the user""" if self.user: - if settings.CAS_SQL_PASSWORD_CHECK == "plain": - return password == self.user["password"] - elif settings.CAS_SQL_PASSWORD_CHECK == "crypt": - if self.user["password"].startswith('$'): - salt = '$'.join(self.user["password"].split('$', 3)[:-1]) - return crypt.crypt(password, salt) == self.user["password"] - else: - return crypt.crypt( - password, - self.user["password"][:2] - ) == self.user["password"] + check_password( + settings.CAS_SQL_PASSWORD_CHECK, + password, + self.user["password"], + settings.CAS_SQL_DBCHARSET + ) else: return False diff --git a/cas_server/utils.py b/cas_server/utils.py index bd7e273..340a898 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -19,6 +19,10 @@ from django.contrib import messages import random import string import json +import hashlib +import crypt +import base64 +import six from threading import Thread from importlib import import_module from six.moves import BaseHTTPServer @@ -172,3 +176,155 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): httpd_thread.daemon = True httpd_thread.start() return (httpd_thread, host, port) + +class LdapHashUserPassword(object): + """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html""" + + schemes_salt = {b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}", b"{CRYPT}"} + schemes_nosalt = {b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"} + + _schemes_to_hash = { + b"{SMD5}": hashlib.md5, + b"{MD5}": hashlib.md5, + b"{SSHA}": hashlib.sha1, + b"{SHA}": hashlib.sha1, + b"{SSHA256}": hashlib.sha256, + b"{SHA256}": hashlib.sha256, + b"{SSHA384}": hashlib.sha384, + b"{SHA384}": hashlib.sha384, + b"{SSHA512}": hashlib.sha512, + b"{SHA512}": hashlib.sha512 + } + + _schemes_to_len = { + b"{SMD5}": 16, + b"{SSHA}": 20, + b"{SSHA256}": 32, + b"{SSHA384}": 48, + b"{SSHA512}": 64, + } + + + + class BadScheme(ValueError): + pass + + class BadHash(ValueError): + pass + + class BadSalt(ValueError): + pass + + @classmethod + def _raise_bad_scheme(cls, scheme, valid, msg): + valid_schemes = [s for s in valid] + valid_schemes.sort() + raise cls.BadScheme(msg % (scheme, ", ".join(valid_schemes))) + + @classmethod + def _test_scheme(cls, scheme): + if scheme not in cls.schemes_salt and scheme not in cls.schemes_nosalt: + cls._raise_bad_scheme( + scheme, + cls.schemes_salt | cls.schemes_nosalt, + "The scheme %r is not valid. Valide schemes are %s." + ) + + @classmethod + def _test_scheme_salt(cls, scheme): + if scheme not in cls.schemes_salt: + cls._raise_bad_scheme( + scheme, + cls.schemes_salt, + "The scheme %r is only valid without a salt. Valide schemes with salt are %s." + ) + + @classmethod + def _test_scheme_nosalt(cls, scheme): + if scheme not in cls.schemes_nosalt: + cls._raise_bad_scheme( + scheme, + cls.schemes_nosalt, + "The scheme %r is only valid with a salt. Valide schemes without salt are %s." + ) + + @classmethod + def hash(cls, scheme, password, salt=None, charset="utf8"): + scheme = scheme.upper() + cls._test_scheme(scheme) + if salt is None or salt == b"": + salt = b"" + cls._test_scheme_nosalt(scheme) + elif salt is not None: + cls._test_scheme_salt(scheme) + try: + return scheme + base64.b64encode(cls._schemes_to_hash[scheme](password + salt).digest() + salt) + except KeyError: + if six.PY3: + password = password.decode(charset) + salt = salt.decode(charset) + hashed_password = crypt.crypt(password, salt) + if hashed_password is None: + raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt) + if six.PY3: + hashed_password = hashed_password.encode(charset) + return scheme + hashed_password + + @classmethod + def get_scheme(cls, hashed_passord): + if not hashed_passord[0] == b'{' or not b'}' in hashed_passord: + raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord) + scheme = hashed_passord.split(b'}', 1)[0] + scheme = scheme.upper() + b"}" + return scheme + + + @classmethod + def get_salt(cls, hashed_passord): + scheme = cls.get_scheme(hashed_passord) + cls._test_scheme(scheme) + if scheme in cls.schemes_nosalt: + return b"" + elif scheme == b'{CRYPT}': + return b'$'.join(hashed_passord.split(b'$', 3)[:-1]) + else: + hashed_passord = base64.b64decode(hashed_passord[len(scheme):]) + if len(hashed_passord) < cls._schemes_to_len[scheme]: + raise cls.BadHash("Hash too short for the scheme %s" % scheme) + return hashed_passord[cls._schemes_to_len[scheme]:] + + + +def check_password(method, password, hashed_password, charset): + if not isinstance(password, six.binary_type): + password = password.encode(charset) + if not isinstance(hashed_password, six.binary_type): + hashed_password = hashed_password.encode(charset) + if method == "plain": + return password == hashed_password + elif method == "crypt": + if hashed_password.startswith(b'$'): + salt = b'$'.join(hashed_password.split(b'$', 3)[:-1]) + elif hashed_password.startswith(b'_'): + salt = hashed_password[:9] + else: + salt = hashed_password[:2] + if six.PY3: + password = password.decode(charset) + salt = salt.decode(charset) + hashed_password = hashed_password.decode(charset) + crypted_password = crypt.crypt(password, salt) + if crypted_password is None: + raise ValueError("System crypt implementation do not support the salt %r" % salt) + return crypted_password == hashed_password + elif method == "ldap": + scheme = LdapHashUserPassword.get_scheme(hashed_password) + salt = LdapHashUserPassword.get_salt(hashed_password) + return LdapHashUserPassword.hash(scheme, password, salt, charset=charset) == hashed_password + elif ( + method.startswith("hex_") and + method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"} + ): + return getattr(hashlib, method[4:])(password).hexdigest() == hashed_password.lower() + else: + raise ValueError("Unknown password method check %r" % method) From 6faeaad57e835c73a2f4e023f2f60ecd479b77cb Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 20:34:26 +0200 Subject: [PATCH 22/27] Typo in README.rst --- README.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 070a437..29b8057 100644 --- a/README.rst +++ b/README.rst @@ -200,7 +200,8 @@ Mysql backend settings. Only usefull if you are using the mysql authentication b additional fields are used as the user attributes. The default is ``"SELECT user AS usersame, pass AS password, users.* FROM users WHERE user = %s"`` * ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be one of the following: - * ``"crypt"`` (see ``), the password in the database + + * ``"crypt"`` (see ), the password in the database should begin this $ * ``"ldap"`` (see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html) the password in the database must begin with one of {MD5}, {SMD5}, {SHA}, {SSHA}, {SHA256}, From 2fac47f0b178a7faf2eb039c9c5aecb45eaafb2f Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 21:44:41 +0200 Subject: [PATCH 23/27] Add unit test for the utils function check_password --- cas_server/auth.py | 1 - cas_server/tests.py | 55 +++++++++++++++++++++++++++++++++++++++++++++ cas_server/utils.py | 20 +++++++++-------- 3 files changed, 66 insertions(+), 10 deletions(-) diff --git a/cas_server/auth.py b/cas_server/auth.py index 0c147c2..7051828 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -15,7 +15,6 @@ from django.contrib.auth import get_user_model try: # pragma: no cover import MySQLdb import MySQLdb.cursors - import crypt from utils import check_password except ImportError: MySQLdb = None diff --git a/cas_server/tests.py b/cas_server/tests.py index 710b890..3f53a04 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -3,6 +3,7 @@ from .default_settings import settings from django.test import TestCase from django.test import Client +import six from lxml import etree from cas_server import models @@ -59,6 +60,60 @@ def get_pgt(): return params +class CheckPasswordCase(TestCase): + + def setUp(self): + self.password1 = utils.gen_saml_id() + self.password2 = utils.gen_saml_id() + if not isinstance(self.password1, bytes): + self.password1 = self.password1.encode("utf8") + self.password2 = self.password2.encode("utf8") + + def test_setup(self): + self.assertIsInstance(self.password1, bytes) + self.assertIsInstance(self.password2, bytes) + + def test_plain(self): + self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8")) + self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8")) + + def test_crypt(self): + if six.PY3: + hashed_password1 = utils.crypt.crypt( + self.password1.decode("utf8"), + "$6$UVVAQvrMyXMF3FF3" + ).encode("utf8") + else: + hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3") + + self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8")) + self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8")) + + def test_ldap_ssha(self): + salt = b"UVVAQvrMyXMF3FF3" + hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8") + + self.assertIsInstance(hashed_password1, bytes) + self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8")) + self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8")) + + def test_hex_md5(self): + hashed_password1 = utils.hashlib.md5(self.password1).hexdigest() + + self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) + self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) + + def test_hox_sha512(self): + hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() + + self.assertTrue( + utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8") + ) + self.assertFalse( + utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8") + ) + + class LoginTestCase(TestCase): def setUp(self): diff --git a/cas_server/utils.py b/cas_server/utils.py index 340a898..68325eb 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -177,6 +177,7 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): httpd_thread.start() return (httpd_thread, host, port) + class LdapHashUserPassword(object): """Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html""" @@ -204,8 +205,6 @@ class LdapHashUserPassword(object): b"{SSHA512}": 64, } - - class BadScheme(ValueError): pass @@ -217,9 +216,9 @@ class LdapHashUserPassword(object): @classmethod def _raise_bad_scheme(cls, scheme, valid, msg): - valid_schemes = [s for s in valid] + valid_schemes = [s.decode() for s in valid] valid_schemes.sort() - raise cls.BadScheme(msg % (scheme, ", ".join(valid_schemes))) + raise cls.BadScheme(msg % (scheme, u", ".join(valid_schemes))) @classmethod def _test_scheme(cls, scheme): @@ -258,7 +257,9 @@ class LdapHashUserPassword(object): elif salt is not None: cls._test_scheme_salt(scheme) try: - return scheme + base64.b64encode(cls._schemes_to_hash[scheme](password + salt).digest() + salt) + return scheme + base64.b64encode( + cls._schemes_to_hash[scheme](password + salt).digest() + salt + ) except KeyError: if six.PY3: password = password.decode(charset) @@ -272,13 +273,12 @@ class LdapHashUserPassword(object): @classmethod def get_scheme(cls, hashed_passord): - if not hashed_passord[0] == b'{' or not b'}' in hashed_passord: + if not hashed_passord[0] == b'{'[0] or b'}' not in hashed_passord: raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord) scheme = hashed_passord.split(b'}', 1)[0] scheme = scheme.upper() + b"}" return scheme - @classmethod def get_salt(cls, hashed_passord): scheme = cls.get_scheme(hashed_passord) @@ -294,7 +294,6 @@ class LdapHashUserPassword(object): return hashed_passord[cls._schemes_to_len[scheme]:] - def check_password(method, password, hashed_password, charset): if not isinstance(password, six.binary_type): password = password.encode(charset) @@ -325,6 +324,9 @@ def check_password(method, password, hashed_password, charset): method.startswith("hex_") and method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"} ): - return getattr(hashlib, method[4:])(password).hexdigest() == hashed_password.lower() + return getattr( + hashlib, + method[4:] + )(password).hexdigest().encode("ascii") == hashed_password.lower() else: raise ValueError("Unknown password method check %r" % method) From 93c2dae96b6658ac7e7780b83f31a5fb8d0d264e Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 26 Jun 2016 22:07:38 +0200 Subject: [PATCH 24/27] Add docstrings --- cas_server/tests.py | 8 ++++++++ cas_server/utils.py | 20 ++++++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/cas_server/tests.py b/cas_server/tests.py index 3f53a04..7d355cb 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -61,8 +61,10 @@ def get_pgt(): class CheckPasswordCase(TestCase): + """Tests for the utils function `utils.check_password`""" def setUp(self): + """Generate random bytes string that will be used ass passwords""" self.password1 = utils.gen_saml_id() self.password2 = utils.gen_saml_id() if not isinstance(self.password1, bytes): @@ -70,14 +72,17 @@ class CheckPasswordCase(TestCase): self.password2 = self.password2.encode("utf8") def test_setup(self): + """check that generated password are bytes""" self.assertIsInstance(self.password1, bytes) self.assertIsInstance(self.password2, bytes) def test_plain(self): + """test the plain auth method""" self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8")) self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8")) def test_crypt(self): + """test the crypt auth method""" if six.PY3: hashed_password1 = utils.crypt.crypt( self.password1.decode("utf8"), @@ -90,6 +95,7 @@ class CheckPasswordCase(TestCase): self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8")) def test_ldap_ssha(self): + """test the ldap auth method with a {SSHA} scheme""" salt = b"UVVAQvrMyXMF3FF3" hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8") @@ -98,12 +104,14 @@ class CheckPasswordCase(TestCase): self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8")) def test_hex_md5(self): + """test the hex_md5 auth method""" hashed_password1 = utils.hashlib.md5(self.password1).hexdigest() self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) def test_hox_sha512(self): + """test the hex_sha512 auth method""" hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() self.assertTrue( diff --git a/cas_server/utils.py b/cas_server/utils.py index 68325eb..c8b345b 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -206,22 +206,30 @@ class LdapHashUserPassword(object): } class BadScheme(ValueError): + """Error raised then the hash scheme is not in schemes_salt + schemes_nosalt""" pass class BadHash(ValueError): + """Error raised then the hash is too short""" pass class BadSalt(ValueError): + """Error raised then with the scheme {CRYPT} the salt is invalid""" pass @classmethod def _raise_bad_scheme(cls, scheme, valid, msg): + """ + Raise BadScheme error for `scheme`, possible valid scheme are + in `valid`, the error message is `msg` + """ valid_schemes = [s.decode() for s in valid] valid_schemes.sort() raise cls.BadScheme(msg % (scheme, u", ".join(valid_schemes))) @classmethod def _test_scheme(cls, scheme): + """Test if a scheme is valide or raise BadScheme""" if scheme not in cls.schemes_salt and scheme not in cls.schemes_nosalt: cls._raise_bad_scheme( scheme, @@ -231,6 +239,7 @@ class LdapHashUserPassword(object): @classmethod def _test_scheme_salt(cls, scheme): + """Test if the scheme need a salt or raise BadScheme""" if scheme not in cls.schemes_salt: cls._raise_bad_scheme( scheme, @@ -240,6 +249,7 @@ class LdapHashUserPassword(object): @classmethod def _test_scheme_nosalt(cls, scheme): + """Test if the scheme need no salt or raise BadScheme""" if scheme not in cls.schemes_nosalt: cls._raise_bad_scheme( scheme, @@ -249,6 +259,10 @@ class LdapHashUserPassword(object): @classmethod def hash(cls, scheme, password, salt=None, charset="utf8"): + """ + Hash `password` with `scheme` using `salt`. + This three variable beeing encoded in `charset`. + """ scheme = scheme.upper() cls._test_scheme(scheme) if salt is None or salt == b"": @@ -273,6 +287,7 @@ class LdapHashUserPassword(object): @classmethod def get_scheme(cls, hashed_passord): + """Return the scheme of `hashed_passord` or raise BadHash""" if not hashed_passord[0] == b'{'[0] or b'}' not in hashed_passord: raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord) scheme = hashed_passord.split(b'}', 1)[0] @@ -281,6 +296,7 @@ class LdapHashUserPassword(object): @classmethod def get_salt(cls, hashed_passord): + """Return the salt of `hashed_passord` possibly empty""" scheme = cls.get_scheme(hashed_passord) cls._test_scheme(scheme) if scheme in cls.schemes_nosalt: @@ -295,6 +311,10 @@ class LdapHashUserPassword(object): def check_password(method, password, hashed_password, charset): + """ + Check that `password` match `hashed_password` using `method`, + assuming the encoding is `charset`. + """ if not isinstance(password, six.binary_type): password = password.encode(charset) if not isinstance(hashed_password, six.binary_type): From 7db31578643b43479009d88993224eb7a410fb6d Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Mon, 27 Jun 2016 14:01:39 +0200 Subject: [PATCH 25/27] Forgotten return --- cas_server/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cas_server/auth.py b/cas_server/auth.py index 7051828..f84fb11 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -90,7 +90,7 @@ class MysqlAuthUser(AuthUser): # pragma: no cover def test_password(self, password): """test `password` agains the user""" if self.user: - check_password( + return check_password( settings.CAS_SQL_PASSWORD_CHECK, password, self.user["password"], From bab79c4de54686c3b305069f3dd1cf655bf547d6 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Mon, 27 Jun 2016 23:55:17 +0200 Subject: [PATCH 26/27] More unit tests (essentially for the login view) and some docstrings --- .coveragerc | 1 + Makefile | 3 +- README.rst | 3 +- cas_server/default_settings.py | 7 +- cas_server/tests.py | 335 ++++++++++++++++++++++++++++----- cas_server/views.py | 33 +++- settings_tests.py | 2 +- urls_tests.py | 22 +++ 8 files changed, 343 insertions(+), 63 deletions(-) create mode 100644 urls_tests.py diff --git a/.coveragerc b/.coveragerc index f11c9de..b4da6da 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,3 +5,4 @@ exclude_lines = def __unicode__ raise AssertionError raise NotImplementedError + if six.PY3: diff --git a/Makefile b/Makefile index 9088fba..2273da9 100644 --- a/Makefile +++ b/Makefile @@ -49,8 +49,9 @@ coverage: test_venv test_venv/bin/pip install coverage test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests test_venv/bin/coverage html - test_venv/bin/coverage xml + rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts coverage_codacy: coverage + test_venv/bin/coverage xml test_venv/bin/pip install codacy-coverage test_venv/bin/python-codacy-coverage -r coverage.xml diff --git a/README.rst b/README.rst index 29b8057..bb148a3 100644 --- a/README.rst +++ b/README.rst @@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac * ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``. * ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``. * ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is - ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``. + ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net', + 'alias': ['demo1', 'demo2']}``. Authentication backend diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py index 2824991..00bb6fa 100644 --- a/cas_server/default_settings.py +++ b/cas_server/default_settings.py @@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test') setting_default('CAS_TEST_PASSWORD', 'test') setting_default( 'CAS_TEST_ATTRIBUTES', - {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} + { + 'nom': 'Nymous', + 'prenom': 'Ano', + 'email': 'anonymous@example.net', + 'alias': ['demo1', 'demo2'] + } ) diff --git a/cas_server/tests.py b/cas_server/tests.py index 7d355cb..5f0a29d 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -3,36 +3,49 @@ from .default_settings import settings from django.test import TestCase from django.test import Client +import re import six +import random from lxml import etree +from six.moves import range from cas_server import models from cas_server import utils -def get_login_page_params(): - client = Client() - response = client.get('/login') - form = response.context["form"] +def copy_form(form): + """Copy form value into a dict""" params = {} for field in form: if field.value(): params[field.name] = field.value() else: params[field.name] = "" + return params + + +def get_login_page_params(client=None): + """Return a client and the POST params for the client to login""" + if client is None: + client = Client() + response = client.get('/login') + params = copy_form(response.context["form"]) return client, params -def get_auth_client(): +def get_auth_client(**update): + """return a authenticated client""" client, params = get_login_page_params() params["username"] = settings.CAS_TEST_USER params["password"] = settings.CAS_TEST_PASSWORD + params.update(update) client.post('/login', params) return client def get_user_ticket_request(service): + """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)""" client = get_auth_client() response = client.get("/login", {"service": service}) ticket_value = response['Location'].split('ticket=')[-1] @@ -45,6 +58,7 @@ def get_user_ticket_request(service): def get_pgt(): + """return a dict contening a service, user and PGT ticket for this service""" (host, port) = utils.PGTUrlHandler.run()[1:3] service = "http://%s:%s" % (host, port) @@ -110,7 +124,7 @@ class CheckPasswordCase(TestCase): self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) - def test_hox_sha512(self): + def test_hex_sha512(self): """test the hex_sha512 auth method""" hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() @@ -123,29 +137,83 @@ class CheckPasswordCase(TestCase): class LoginTestCase(TestCase): - + """Tests for the login view""" def setUp(self): + """ + Prepare the test context: + * set the auth class to 'cas_server.auth.TestAuthUser' + * create a service pattern for https://www.example.com/** + * Set the service pattern to return all user attributes + """ settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + + # For general purpose testing self.service_pattern = models.ServicePattern.objects.create( name="example", pattern="^https://www\.example\.com(/.*)?$", ) models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) - def test_login_view_post_goodpass_goodlt(self): - client, params = get_login_page_params() - params["username"] = settings.CAS_TEST_USER - params["password"] = settings.CAS_TEST_PASSWORD + # For testing the restrict_users attributes + self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create( + name="restrict_user_fail", + pattern="^https://restrict_user_fail\.example\.com(/.*)?$", + restrict_users=True, + ) + self.service_pattern_restrict_user_success = models.ServicePattern.objects.create( + name="restrict_user_success", + pattern="^https://restrict_user_success\.example\.com(/.*)?$", + restrict_users=True, + ) + models.Username.objects.create( + value=settings.CAS_TEST_USER, + service_pattern=self.service_pattern_restrict_user_success + ) - response = client.post('/login', params) + # For testing the user attributes filtering conditions + self.service_pattern_filter_fail = models.ServicePattern.objects.create( + name="filter_fail", + pattern="^https://filter_fail\.example\.com(/.*)?$", + ) + models.FilterAttributValue.objects.create( + attribut="right", + pattern="^admin$", + service_pattern=self.service_pattern_filter_fail + ) + self.service_pattern_filter_success = models.ServicePattern.objects.create( + name="filter_success", + pattern="^https://filter_success\.example\.com(/.*)?$", + ) + models.FilterAttributValue.objects.create( + attribut="email", + pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']), + service_pattern=self.service_pattern_filter_success + ) - self.assertEqual(response.status_code, 200) + # For testing the user_field attributes + self.service_pattern_field_needed_fail = models.ServicePattern.objects.create( + name="field_needed_fail", + pattern="^https://field_needed_fail\.example\.com(/.*)?$", + user_field="uid" + ) + self.service_pattern_field_needed_success = models.ServicePattern.objects.create( + name="field_needed_success", + pattern="^https://field_needed_success\.example\.com(/.*)?$", + user_field="nom" + ) + + def assert_logged(self, client, response, warn=False, code=200): + """Assertions testing that client is well authenticated""" + self.assertEqual(response.status_code, code) self.assertTrue( ( b"You have successfully logged into " b"the Central Authentication Service" ) in response.content ) + self.assertTrue(client.session["username"] == settings.CAS_TEST_USER) + self.assertTrue(client.session["warn"] is warn) + self.assertTrue(client.session["authenticated"] is True) self.assertTrue( models.User.objects.get( @@ -154,7 +222,59 @@ class LoginTestCase(TestCase): ) ) + def assert_login_failed(self, client, response, code=200): + """Assertions testing a failed login attempt""" + self.assertEqual(response.status_code, code) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + self.assertTrue(client.session.get("username") is None) + self.assertTrue(client.session.get("warn") is None) + self.assertTrue(client.session.get("authenticated") is None) + + def test_login_view_post_goodpass_goodlt(self): + """Test a successul login""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + self.assertTrue(params['lt'] in client.session['lt']) + + response = client.post('/login', params) + self.assert_logged(client, response) + # LoginTicket conssumed + self.assertTrue(params['lt'] not in client.session['lt']) + + def test_login_view_post_goodpass_goodlt_warn(self): + """Test a successul login requesting to be warned before creating services tickets""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params["warn"] = "on" + + response = client.post('/login', params) + self.assert_logged(client, response, warn=True) + + def test_lt_max(self): + """Check we only keep the last 100 Login Ticket for a user""" + client, params = get_login_page_params() + current_lt = params["lt"] + i_in_test = random.randint(0, 100) + i_not_in_test = random.randint(100, 150) + for i in range(150): + if i == i_in_test: + self.assertTrue(current_lt in client.session['lt']) + if i == i_not_in_test: + self.assertTrue(current_lt not in client.session['lt']) + self.assertTrue(len(client.session['lt']) <= 100) + client, params = get_login_page_params(client) + self.assertTrue(len(client.session['lt']) <= 100) + def test_login_view_post_badlt(self): + """Login attempt with a bad LoginTicket""" client, params = get_login_page_params() params["username"] = settings.CAS_TEST_USER params["password"] = settings.CAS_TEST_PASSWORD @@ -162,47 +282,26 @@ class LoginTestCase(TestCase): response = client.post('/login', params) - self.assertEqual(response.status_code, 200) + self.assert_login_failed(client, response) self.assertTrue(b"Invalid login ticket" in response.content) - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) def test_login_view_post_badpass_good_lt(self): + """Login attempt with a bad password""" client, params = get_login_page_params() params["username"] = settings.CAS_TEST_USER params["password"] = "test2" response = client.post('/login', params) - self.assertEqual(response.status_code, 200) + self.assert_login_failed(client, response) self.assertTrue( ( b"The credentials you provided cannot be " b"determined to be authentic" ) in response.content ) - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - def test_view_login_get_auth_allowed_service(self): - client = get_auth_client() - response = client.get("/login?service=https://www.example.com") - self.assertEqual(response.status_code, 302) - self.assertTrue(response.has_header('Location')) - self.assertTrue( - response['Location'].startswith( - "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX - ) - ) - - ticket_value = response['Location'].split('ticket=')[-1] + def assert_ticket_attributes(self, client, ticket_value): + """check the ticket attributes in the db""" user = models.User.objects.get( username=settings.CAS_TEST_USER, session_key=client.session.session_key @@ -214,12 +313,136 @@ class LoginTestCase(TestCase): self.assertEqual(ticket.validate, False) self.assertEqual(ticket.service_pattern, self.service_pattern) + def assert_service_ticket(self, client, response): + """check that a ticket is well emited when requested on a allowed service""" + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header('Location')) + self.assertTrue( + response['Location'].startswith( + "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX + ) + ) + + ticket_value = response['Location'].split('ticket=')[-1] + self.assert_ticket_attributes(client, ticket_value) + + def test_view_login_get_allowed_service(self): + """Request a ticket for an allowed service by an unauthenticated client""" + client = Client() + response = client.get("/login?service=https://www.example.com") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + "Authentication required by service " + "example (https://www.example.com)" + ) in response.content + ) + + def test_view_login_get_denied_service(self): + """Request a ticket for an denied service by an unauthenticated client""" + client = Client() + response = client.get("/login?service=https://www.example.net") + self.assertEqual(response.status_code, 200) + self.assertTrue("Service https://www.example.net non allowed" in response.content) + + def test_view_login_get_auth_allowed_service(self): + """Request a ticket for an allowed service by an authenticated client""" + # client is already authenticated + client = get_auth_client() + response = client.get("/login?service=https://www.example.com") + self.assert_service_ticket(client, response) + + def test_view_login_get_auth_allowed_service_warn(self): + """Request a ticket for an allowed service by an authenticated client""" + # client is already authenticated + client = get_auth_client(warn="on") + response = client.get("/login?service=https://www.example.com") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + "Authentication has been required by service " + "example (https://www.example.com)" + ) in response.content + ) + + params = copy_form(response.context["form"]) + response = client.post("/login", params) + self.assert_service_ticket(client, response) + def test_view_login_get_auth_denied_service(self): + """Request a ticket for a not allowed service by an authenticated client""" client = get_auth_client() response = client.get("/login?service=https://www.example.org") self.assertEqual(response.status_code, 200) self.assertTrue(b"Service https://www.example.org non allowed" in response.content) + def test_user_logged_not_in_db(self): + """If the user is logged but has been delete from the database, it should be logged out""" + client = get_auth_client() + models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ).delete() + response = client.get("/login") + + self.assert_login_failed(client, response, code=302) + self.assertEqual(response["Location"], "/login?") + + def test_service_restrict_user(self): + """Testing the restric user capability fro a service""" + service = "https://restrict_user_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue("Username non allowed" in response.content) + + service = "https://restrict_user_success.example.com" + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % service)) + + def test_service_filter(self): + """Test the filtering on user attributes""" + service = "https://filter_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue("User charateristics non allowed" in response.content) + + service = "https://filter_success.example.com" + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % service)) + + def test_service_user_field(self): + """Test using a user attribute as username: case on if the attribute exists or not""" + service = "https://field_needed_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue("The attribut uid is needed to use that service" in response.content) + + service = "https://field_needed_success.example.com" + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % service)) + + def test_gateway(self): + """test gateway parameter""" + + # First with an authenticated client that fail to get a ticket for a service + service = "https://restrict_user_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service, 'gateway': 'on'}) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], service) + + # second for an user not yet authenticated on a valid service + client = Client() + response = client.get('/login', {'service': service, 'gateway': 'on'}) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], service) + class LogoutTestCase(TestCase): @@ -454,17 +677,24 @@ class ValidateServiceTestCase(TestCase): namespaces={'cas': "http://www.yale.edu/tp/cas"} ) self.assertEqual(len(attributes), 1) - attrs1 = {} + attrs1 = set() for attr in attributes[0]: - attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text + attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text)) attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertEqual(len(attributes), len(attrs1)) - attrs2 = {} + attrs2 = set() for attr in attributes: - attrs2[attr.attrib['name']] = attr.attrib['value'] + attrs2.add((attr.attrib['name'], attr.attrib['value'])) + original = set() + for key, value in settings.CAS_TEST_ATTRIBUTES.items(): + if isinstance(value, list): + for v in value: + original.add((key, v)) + else: + original.add((key, value)) self.assertEqual(attrs1, attrs2) - self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + self.assertEqual(attrs1, original) def test_validate_service_view_badservice(self): ticket = get_user_ticket_request(self.service)[1] @@ -623,17 +853,24 @@ class ProxyTestCase(TestCase): namespaces={'cas': "http://www.yale.edu/tp/cas"} ) self.assertEqual(len(attributes), 1) - attrs1 = {} + attrs1 = set() for attr in attributes[0]: - attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text + attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text)) attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertEqual(len(attributes), len(attrs1)) - attrs2 = {} + attrs2 = set() for attr in attributes: - attrs2[attr.attrib['name']] = attr.attrib['value'] + attrs2.add((attr.attrib['name'], attr.attrib['value'])) + original = set() + for key, value in settings.CAS_TEST_ATTRIBUTES.items(): + if isinstance(value, list): + for v in value: + original.add((key, v)) + else: + original.add((key, value)) self.assertEqual(attrs1, attrs2) - self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + self.assertEqual(attrs1, original) def test_validate_proxy_bad(self): params = get_pgt() diff --git a/cas_server/views.py b/cas_server/views.py index 2b33a6c..a48dd7e 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin): service = None def init_get(self, request): + """Initialize GET received parameters""" self.request = request self.service = request.GET.get('service') self.url = request.GET.get('url') @@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin): USER_NOT_AUTHENTICATED = 6 def init_post(self, request): + """Initialize POST received parameters""" self.request = request self.service = request.POST.get('service') self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") @@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin): if request.POST.get('warned') and request.POST['warned'] != "False": self.warned = True - def check_lt(self): - # save LT for later check - lt_valid = self.request.session.get('lt', []) - lt_send = self.request.POST.get('lt') - # generate a new LT (by posting the LT has been consumed) + def gen_lt(self): + """Generate a new LoginTicket and add it to the list of valid LT for the user""" self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] if len(self.request.session['lt']) > 100: self.request.session['lt'] = self.request.session['lt'][-100:] + def check_lt(self): + """Check is the POSTed LoginTicket is valid, if yes invalide it""" + # save LT for later check + lt_valid = self.request.session.get('lt', []) + lt_send = self.request.POST.get('lt') + # generate a new LT (by posting the LT has been consumed) + self.gen_lt() # check if send LT is valid if lt_valid is None or lt_send not in lt_valid: return False @@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin): username=self.request.session['username'], session_key=self.request.session.session_key ) - self.user.save() + self.user.save() # pragma: no cover (should not happend) except models.User.DoesNotExist: self.user = models.User.objects.create( username=self.request.session['username'], @@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin): elif ret == self.USER_ALREADY_LOGGED: pass else: - raise EnvironmentError("invalid output for LoginView.process_post") + raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover return self.common() def process_post(self): + """ + Analyse the POST request: + * check that the LoginTicket is valid + * check that the user sumited credentials are valid + """ if not self.check_lt(): values = self.request.POST.copy() # if not set a new LT and fail @@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin): return self.USER_ALREADY_LOGGED def init_get(self, request): + """Initialize GET received parameters""" self.request = request self.service = request.GET.get('service') self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") @@ -294,15 +306,16 @@ class LoginView(View, LogoutMixin): return self.common() def process_get(self): - # generate a new LT if none is present - self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] - + """Analyse the GET request""" + # generate a new LT + self.gen_lt() if not self.request.session.get("authenticated") or self.renew: self.init_form() return self.USER_NOT_AUTHENTICATED return self.USER_AUTHENTICATED def init_form(self, values=None): + """Initialization of the good form depending of POST and GET parameters""" self.form = forms.UserCredential( values, initial={ diff --git a/settings_tests.py b/settings_tests.py index 4588c2c..e1c0558 100644 --- a/settings_tests.py +++ b/settings_tests.py @@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [ 'django.middleware.locale.LocaleMiddleware', ] -ROOT_URLCONF = 'cas_server.urls' +ROOT_URLCONF = 'urls_tests' # Database # https://docs.djangoproject.com/en/1.9/ref/settings/#databases diff --git a/urls_tests.py b/urls_tests.py new file mode 100644 index 0000000..a9ed25c --- /dev/null +++ b/urls_tests.py @@ -0,0 +1,22 @@ +"""cas URL Configuration + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/1.9/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.conf.urls import url, include, include + 2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls')) +""" +from django.conf.urls import url, include +from django.contrib import admin + +urlpatterns = [ + url(r'^admin/', admin.site.urls), + url(r'^', include('cas_server.urls', namespace='cas_server')), +] From fc57288c3047a0e98caf96069949df692d9dece8 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Tue, 28 Jun 2016 00:10:36 +0200 Subject: [PATCH 27/27] Fix some python3 compat and change in test client behaviour in django 1.9 --- cas_server/tests.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/cas_server/tests.py b/cas_server/tests.py index 5f0a29d..916a6d4 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -1,5 +1,6 @@ from .default_settings import settings +import django from django.test import TestCase from django.test import Client @@ -333,8 +334,8 @@ class LoginTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertTrue( ( - "Authentication required by service " - "example (https://www.example.com)" + b"Authentication required by service " + b"example (https://www.example.com)" ) in response.content ) @@ -343,7 +344,7 @@ class LoginTestCase(TestCase): client = Client() response = client.get("/login?service=https://www.example.net") self.assertEqual(response.status_code, 200) - self.assertTrue("Service https://www.example.net non allowed" in response.content) + self.assertTrue(b"Service https://www.example.net non allowed" in response.content) def test_view_login_get_auth_allowed_service(self): """Request a ticket for an allowed service by an authenticated client""" @@ -360,8 +361,8 @@ class LoginTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertTrue( ( - "Authentication has been required by service " - "example (https://www.example.com)" + b"Authentication has been required by service " + b"example (https://www.example.com)" ) in response.content ) @@ -386,7 +387,10 @@ class LoginTestCase(TestCase): response = client.get("/login") self.assert_login_failed(client, response, code=302) - self.assertEqual(response["Location"], "/login?") + if django.VERSION < (1, 9): + self.assertEqual(response["Location"], "http://testserver/login") + else: + self.assertEqual(response["Location"], "/login?") def test_service_restrict_user(self): """Testing the restric user capability fro a service""" @@ -394,7 +398,7 @@ class LoginTestCase(TestCase): client = get_auth_client() response = client.get("/login", {'service': service}) self.assertEqual(response.status_code, 200) - self.assertTrue("Username non allowed" in response.content) + self.assertTrue(b"Username non allowed" in response.content) service = "https://restrict_user_success.example.com" response = client.get("/login", {'service': service}) @@ -407,7 +411,7 @@ class LoginTestCase(TestCase): client = get_auth_client() response = client.get("/login", {'service': service}) self.assertEqual(response.status_code, 200) - self.assertTrue("User charateristics non allowed" in response.content) + self.assertTrue(b"User charateristics non allowed" in response.content) service = "https://filter_success.example.com" response = client.get("/login", {'service': service}) @@ -420,7 +424,7 @@ class LoginTestCase(TestCase): client = get_auth_client() response = client.get("/login", {'service': service}) self.assertEqual(response.status_code, 200) - self.assertTrue("The attribut uid is needed to use that service" in response.content) + self.assertTrue(b"The attribut uid is needed to use that service" in response.content) service = "https://field_needed_success.example.com" response = client.get("/login", {'service': service}) @@ -689,8 +693,8 @@ class ValidateServiceTestCase(TestCase): original = set() for key, value in settings.CAS_TEST_ATTRIBUTES.items(): if isinstance(value, list): - for v in value: - original.add((key, v)) + for sub_value in value: + original.add((key, sub_value)) else: original.add((key, value)) self.assertEqual(attrs1, attrs2) @@ -865,8 +869,8 @@ class ProxyTestCase(TestCase): original = set() for key, value in settings.CAS_TEST_ATTRIBUTES.items(): if isinstance(value, list): - for v in value: - original.add((key, v)) + for sub_value in value: + original.add((key, sub_value)) else: original.add((key, value)) self.assertEqual(attrs1, attrs2)