From bab79c4de54686c3b305069f3dd1cf655bf547d6 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Mon, 27 Jun 2016 23:55:17 +0200 Subject: [PATCH] More unit tests (essentially for the login view) and some docstrings --- .coveragerc | 1 + Makefile | 3 +- README.rst | 3 +- cas_server/default_settings.py | 7 +- cas_server/tests.py | 335 ++++++++++++++++++++++++++++----- cas_server/views.py | 33 +++- settings_tests.py | 2 +- urls_tests.py | 22 +++ 8 files changed, 343 insertions(+), 63 deletions(-) create mode 100644 urls_tests.py diff --git a/.coveragerc b/.coveragerc index f11c9de..b4da6da 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,3 +5,4 @@ exclude_lines = def __unicode__ raise AssertionError raise NotImplementedError + if six.PY3: diff --git a/Makefile b/Makefile index 9088fba..2273da9 100644 --- a/Makefile +++ b/Makefile @@ -49,8 +49,9 @@ coverage: test_venv test_venv/bin/pip install coverage test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests test_venv/bin/coverage html - test_venv/bin/coverage xml + rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts coverage_codacy: coverage + test_venv/bin/coverage xml test_venv/bin/pip install codacy-coverage test_venv/bin/python-codacy-coverage -r coverage.xml diff --git a/README.rst b/README.rst index 29b8057..bb148a3 100644 --- a/README.rst +++ b/README.rst @@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac * ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``. * ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``. * ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is - ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``. + ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net', + 'alias': ['demo1', 'demo2']}``. Authentication backend diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py index 2824991..00bb6fa 100644 --- a/cas_server/default_settings.py +++ b/cas_server/default_settings.py @@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test') setting_default('CAS_TEST_PASSWORD', 'test') setting_default( 'CAS_TEST_ATTRIBUTES', - {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} + { + 'nom': 'Nymous', + 'prenom': 'Ano', + 'email': 'anonymous@example.net', + 'alias': ['demo1', 'demo2'] + } ) diff --git a/cas_server/tests.py b/cas_server/tests.py index 7d355cb..5f0a29d 100644 --- a/cas_server/tests.py +++ b/cas_server/tests.py @@ -3,36 +3,49 @@ from .default_settings import settings from django.test import TestCase from django.test import Client +import re import six +import random from lxml import etree +from six.moves import range from cas_server import models from cas_server import utils -def get_login_page_params(): - client = Client() - response = client.get('/login') - form = response.context["form"] +def copy_form(form): + """Copy form value into a dict""" params = {} for field in form: if field.value(): params[field.name] = field.value() else: params[field.name] = "" + return params + + +def get_login_page_params(client=None): + """Return a client and the POST params for the client to login""" + if client is None: + client = Client() + response = client.get('/login') + params = copy_form(response.context["form"]) return client, params -def get_auth_client(): +def get_auth_client(**update): + """return a authenticated client""" client, params = get_login_page_params() params["username"] = settings.CAS_TEST_USER params["password"] = settings.CAS_TEST_PASSWORD + params.update(update) client.post('/login', params) return client def get_user_ticket_request(service): + """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)""" client = get_auth_client() response = client.get("/login", {"service": service}) ticket_value = response['Location'].split('ticket=')[-1] @@ -45,6 +58,7 @@ def get_user_ticket_request(service): def get_pgt(): + """return a dict contening a service, user and PGT ticket for this service""" (host, port) = utils.PGTUrlHandler.run()[1:3] service = "http://%s:%s" % (host, port) @@ -110,7 +124,7 @@ class CheckPasswordCase(TestCase): self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) - def test_hox_sha512(self): + def test_hex_sha512(self): """test the hex_sha512 auth method""" hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() @@ -123,29 +137,83 @@ class CheckPasswordCase(TestCase): class LoginTestCase(TestCase): - + """Tests for the login view""" def setUp(self): + """ + Prepare the test context: + * set the auth class to 'cas_server.auth.TestAuthUser' + * create a service pattern for https://www.example.com/** + * Set the service pattern to return all user attributes + """ settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + + # For general purpose testing self.service_pattern = models.ServicePattern.objects.create( name="example", pattern="^https://www\.example\.com(/.*)?$", ) models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) - def test_login_view_post_goodpass_goodlt(self): - client, params = get_login_page_params() - params["username"] = settings.CAS_TEST_USER - params["password"] = settings.CAS_TEST_PASSWORD + # For testing the restrict_users attributes + self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create( + name="restrict_user_fail", + pattern="^https://restrict_user_fail\.example\.com(/.*)?$", + restrict_users=True, + ) + self.service_pattern_restrict_user_success = models.ServicePattern.objects.create( + name="restrict_user_success", + pattern="^https://restrict_user_success\.example\.com(/.*)?$", + restrict_users=True, + ) + models.Username.objects.create( + value=settings.CAS_TEST_USER, + service_pattern=self.service_pattern_restrict_user_success + ) - response = client.post('/login', params) + # For testing the user attributes filtering conditions + self.service_pattern_filter_fail = models.ServicePattern.objects.create( + name="filter_fail", + pattern="^https://filter_fail\.example\.com(/.*)?$", + ) + models.FilterAttributValue.objects.create( + attribut="right", + pattern="^admin$", + service_pattern=self.service_pattern_filter_fail + ) + self.service_pattern_filter_success = models.ServicePattern.objects.create( + name="filter_success", + pattern="^https://filter_success\.example\.com(/.*)?$", + ) + models.FilterAttributValue.objects.create( + attribut="email", + pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']), + service_pattern=self.service_pattern_filter_success + ) - self.assertEqual(response.status_code, 200) + # For testing the user_field attributes + self.service_pattern_field_needed_fail = models.ServicePattern.objects.create( + name="field_needed_fail", + pattern="^https://field_needed_fail\.example\.com(/.*)?$", + user_field="uid" + ) + self.service_pattern_field_needed_success = models.ServicePattern.objects.create( + name="field_needed_success", + pattern="^https://field_needed_success\.example\.com(/.*)?$", + user_field="nom" + ) + + def assert_logged(self, client, response, warn=False, code=200): + """Assertions testing that client is well authenticated""" + self.assertEqual(response.status_code, code) self.assertTrue( ( b"You have successfully logged into " b"the Central Authentication Service" ) in response.content ) + self.assertTrue(client.session["username"] == settings.CAS_TEST_USER) + self.assertTrue(client.session["warn"] is warn) + self.assertTrue(client.session["authenticated"] is True) self.assertTrue( models.User.objects.get( @@ -154,7 +222,59 @@ class LoginTestCase(TestCase): ) ) + def assert_login_failed(self, client, response, code=200): + """Assertions testing a failed login attempt""" + self.assertEqual(response.status_code, code) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + self.assertTrue(client.session.get("username") is None) + self.assertTrue(client.session.get("warn") is None) + self.assertTrue(client.session.get("authenticated") is None) + + def test_login_view_post_goodpass_goodlt(self): + """Test a successul login""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + self.assertTrue(params['lt'] in client.session['lt']) + + response = client.post('/login', params) + self.assert_logged(client, response) + # LoginTicket conssumed + self.assertTrue(params['lt'] not in client.session['lt']) + + def test_login_view_post_goodpass_goodlt_warn(self): + """Test a successul login requesting to be warned before creating services tickets""" + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params["warn"] = "on" + + response = client.post('/login', params) + self.assert_logged(client, response, warn=True) + + def test_lt_max(self): + """Check we only keep the last 100 Login Ticket for a user""" + client, params = get_login_page_params() + current_lt = params["lt"] + i_in_test = random.randint(0, 100) + i_not_in_test = random.randint(100, 150) + for i in range(150): + if i == i_in_test: + self.assertTrue(current_lt in client.session['lt']) + if i == i_not_in_test: + self.assertTrue(current_lt not in client.session['lt']) + self.assertTrue(len(client.session['lt']) <= 100) + client, params = get_login_page_params(client) + self.assertTrue(len(client.session['lt']) <= 100) + def test_login_view_post_badlt(self): + """Login attempt with a bad LoginTicket""" client, params = get_login_page_params() params["username"] = settings.CAS_TEST_USER params["password"] = settings.CAS_TEST_PASSWORD @@ -162,47 +282,26 @@ class LoginTestCase(TestCase): response = client.post('/login', params) - self.assertEqual(response.status_code, 200) + self.assert_login_failed(client, response) self.assertTrue(b"Invalid login ticket" in response.content) - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) def test_login_view_post_badpass_good_lt(self): + """Login attempt with a bad password""" client, params = get_login_page_params() params["username"] = settings.CAS_TEST_USER params["password"] = "test2" response = client.post('/login', params) - self.assertEqual(response.status_code, 200) + self.assert_login_failed(client, response) self.assertTrue( ( b"The credentials you provided cannot be " b"determined to be authentic" ) in response.content ) - self.assertFalse( - ( - b"You have successfully logged into " - b"the Central Authentication Service" - ) in response.content - ) - def test_view_login_get_auth_allowed_service(self): - client = get_auth_client() - response = client.get("/login?service=https://www.example.com") - self.assertEqual(response.status_code, 302) - self.assertTrue(response.has_header('Location')) - self.assertTrue( - response['Location'].startswith( - "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX - ) - ) - - ticket_value = response['Location'].split('ticket=')[-1] + def assert_ticket_attributes(self, client, ticket_value): + """check the ticket attributes in the db""" user = models.User.objects.get( username=settings.CAS_TEST_USER, session_key=client.session.session_key @@ -214,12 +313,136 @@ class LoginTestCase(TestCase): self.assertEqual(ticket.validate, False) self.assertEqual(ticket.service_pattern, self.service_pattern) + def assert_service_ticket(self, client, response): + """check that a ticket is well emited when requested on a allowed service""" + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header('Location')) + self.assertTrue( + response['Location'].startswith( + "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX + ) + ) + + ticket_value = response['Location'].split('ticket=')[-1] + self.assert_ticket_attributes(client, ticket_value) + + def test_view_login_get_allowed_service(self): + """Request a ticket for an allowed service by an unauthenticated client""" + client = Client() + response = client.get("/login?service=https://www.example.com") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + "Authentication required by service " + "example (https://www.example.com)" + ) in response.content + ) + + def test_view_login_get_denied_service(self): + """Request a ticket for an denied service by an unauthenticated client""" + client = Client() + response = client.get("/login?service=https://www.example.net") + self.assertEqual(response.status_code, 200) + self.assertTrue("Service https://www.example.net non allowed" in response.content) + + def test_view_login_get_auth_allowed_service(self): + """Request a ticket for an allowed service by an authenticated client""" + # client is already authenticated + client = get_auth_client() + response = client.get("/login?service=https://www.example.com") + self.assert_service_ticket(client, response) + + def test_view_login_get_auth_allowed_service_warn(self): + """Request a ticket for an allowed service by an authenticated client""" + # client is already authenticated + client = get_auth_client(warn="on") + response = client.get("/login?service=https://www.example.com") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + "Authentication has been required by service " + "example (https://www.example.com)" + ) in response.content + ) + + params = copy_form(response.context["form"]) + response = client.post("/login", params) + self.assert_service_ticket(client, response) + def test_view_login_get_auth_denied_service(self): + """Request a ticket for a not allowed service by an authenticated client""" client = get_auth_client() response = client.get("/login?service=https://www.example.org") self.assertEqual(response.status_code, 200) self.assertTrue(b"Service https://www.example.org non allowed" in response.content) + def test_user_logged_not_in_db(self): + """If the user is logged but has been delete from the database, it should be logged out""" + client = get_auth_client() + models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ).delete() + response = client.get("/login") + + self.assert_login_failed(client, response, code=302) + self.assertEqual(response["Location"], "/login?") + + def test_service_restrict_user(self): + """Testing the restric user capability fro a service""" + service = "https://restrict_user_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue("Username non allowed" in response.content) + + service = "https://restrict_user_success.example.com" + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % service)) + + def test_service_filter(self): + """Test the filtering on user attributes""" + service = "https://filter_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue("User charateristics non allowed" in response.content) + + service = "https://filter_success.example.com" + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % service)) + + def test_service_user_field(self): + """Test using a user attribute as username: case on if the attribute exists or not""" + service = "https://field_needed_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 200) + self.assertTrue("The attribut uid is needed to use that service" in response.content) + + service = "https://field_needed_success.example.com" + response = client.get("/login", {'service': service}) + self.assertEqual(response.status_code, 302) + self.assertTrue(response["Location"].startswith("%s?ticket=" % service)) + + def test_gateway(self): + """test gateway parameter""" + + # First with an authenticated client that fail to get a ticket for a service + service = "https://restrict_user_fail.example.com" + client = get_auth_client() + response = client.get("/login", {'service': service, 'gateway': 'on'}) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], service) + + # second for an user not yet authenticated on a valid service + client = Client() + response = client.get('/login', {'service': service, 'gateway': 'on'}) + self.assertEqual(response.status_code, 302) + self.assertEqual(response["Location"], service) + class LogoutTestCase(TestCase): @@ -454,17 +677,24 @@ class ValidateServiceTestCase(TestCase): namespaces={'cas': "http://www.yale.edu/tp/cas"} ) self.assertEqual(len(attributes), 1) - attrs1 = {} + attrs1 = set() for attr in attributes[0]: - attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text + attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text)) attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertEqual(len(attributes), len(attrs1)) - attrs2 = {} + attrs2 = set() for attr in attributes: - attrs2[attr.attrib['name']] = attr.attrib['value'] + attrs2.add((attr.attrib['name'], attr.attrib['value'])) + original = set() + for key, value in settings.CAS_TEST_ATTRIBUTES.items(): + if isinstance(value, list): + for v in value: + original.add((key, v)) + else: + original.add((key, value)) self.assertEqual(attrs1, attrs2) - self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + self.assertEqual(attrs1, original) def test_validate_service_view_badservice(self): ticket = get_user_ticket_request(self.service)[1] @@ -623,17 +853,24 @@ class ProxyTestCase(TestCase): namespaces={'cas': "http://www.yale.edu/tp/cas"} ) self.assertEqual(len(attributes), 1) - attrs1 = {} + attrs1 = set() for attr in attributes[0]: - attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text + attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text)) attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) self.assertEqual(len(attributes), len(attrs1)) - attrs2 = {} + attrs2 = set() for attr in attributes: - attrs2[attr.attrib['name']] = attr.attrib['value'] + attrs2.add((attr.attrib['name'], attr.attrib['value'])) + original = set() + for key, value in settings.CAS_TEST_ATTRIBUTES.items(): + if isinstance(value, list): + for v in value: + original.add((key, v)) + else: + original.add((key, value)) self.assertEqual(attrs1, attrs2) - self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + self.assertEqual(attrs1, original) def test_validate_proxy_bad(self): params = get_pgt() diff --git a/cas_server/views.py b/cas_server/views.py index 2b33a6c..a48dd7e 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin): service = None def init_get(self, request): + """Initialize GET received parameters""" self.request = request self.service = request.GET.get('service') self.url = request.GET.get('url') @@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin): USER_NOT_AUTHENTICATED = 6 def init_post(self, request): + """Initialize POST received parameters""" self.request = request self.service = request.POST.get('service') self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") @@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin): if request.POST.get('warned') and request.POST['warned'] != "False": self.warned = True - def check_lt(self): - # save LT for later check - lt_valid = self.request.session.get('lt', []) - lt_send = self.request.POST.get('lt') - # generate a new LT (by posting the LT has been consumed) + def gen_lt(self): + """Generate a new LoginTicket and add it to the list of valid LT for the user""" self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] if len(self.request.session['lt']) > 100: self.request.session['lt'] = self.request.session['lt'][-100:] + def check_lt(self): + """Check is the POSTed LoginTicket is valid, if yes invalide it""" + # save LT for later check + lt_valid = self.request.session.get('lt', []) + lt_send = self.request.POST.get('lt') + # generate a new LT (by posting the LT has been consumed) + self.gen_lt() # check if send LT is valid if lt_valid is None or lt_send not in lt_valid: return False @@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin): username=self.request.session['username'], session_key=self.request.session.session_key ) - self.user.save() + self.user.save() # pragma: no cover (should not happend) except models.User.DoesNotExist: self.user = models.User.objects.create( username=self.request.session['username'], @@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin): elif ret == self.USER_ALREADY_LOGGED: pass else: - raise EnvironmentError("invalid output for LoginView.process_post") + raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover return self.common() def process_post(self): + """ + Analyse the POST request: + * check that the LoginTicket is valid + * check that the user sumited credentials are valid + """ if not self.check_lt(): values = self.request.POST.copy() # if not set a new LT and fail @@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin): return self.USER_ALREADY_LOGGED def init_get(self, request): + """Initialize GET received parameters""" self.request = request self.service = request.GET.get('service') self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") @@ -294,15 +306,16 @@ class LoginView(View, LogoutMixin): return self.common() def process_get(self): - # generate a new LT if none is present - self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] - + """Analyse the GET request""" + # generate a new LT + self.gen_lt() if not self.request.session.get("authenticated") or self.renew: self.init_form() return self.USER_NOT_AUTHENTICATED return self.USER_AUTHENTICATED def init_form(self, values=None): + """Initialization of the good form depending of POST and GET parameters""" self.form = forms.UserCredential( values, initial={ diff --git a/settings_tests.py b/settings_tests.py index 4588c2c..e1c0558 100644 --- a/settings_tests.py +++ b/settings_tests.py @@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [ 'django.middleware.locale.LocaleMiddleware', ] -ROOT_URLCONF = 'cas_server.urls' +ROOT_URLCONF = 'urls_tests' # Database # https://docs.djangoproject.com/en/1.9/ref/settings/#databases diff --git a/urls_tests.py b/urls_tests.py new file mode 100644 index 0000000..a9ed25c --- /dev/null +++ b/urls_tests.py @@ -0,0 +1,22 @@ +"""cas URL Configuration + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/1.9/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: url(r'^$', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.conf.urls import url, include, include + 2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls')) +""" +from django.conf.urls import url, include +from django.contrib import admin + +urlpatterns = [ + url(r'^admin/', admin.site.urls), + url(r'^', include('cas_server.urls', namespace='cas_server')), +]