diff --git a/.gitignore b/.gitignore index 0b5a2a6..3b1bcb6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ *.pyc *.egg-info +*~ +*.swp build/ bootstrap3 @@ -7,6 +9,9 @@ cas/ dist/ db.sqlite3 manage.py +coverage.xml .tox test_venv +.coverage +htmlcov/ diff --git a/Makefile b/Makefile index e5b19f1..9088fba 100644 --- a/Makefile +++ b/Makefile @@ -44,3 +44,13 @@ test_project: test_venv test_venv/cas/manage.py run_test_server: test_project test_venv/bin/python test_venv/cas/manage.py runserver + +coverage: test_venv + test_venv/bin/pip install coverage + test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests + test_venv/bin/coverage html + test_venv/bin/coverage xml + +coverage_codacy: coverage + test_venv/bin/pip install codacy-coverage + test_venv/bin/python-codacy-coverage -r coverage.xml diff --git a/README.rst b/README.rst index 78eaae1..85b2dc4 100644 --- a/README.rst +++ b/README.rst @@ -10,8 +10,14 @@ CAS Server .. image:: https://img.shields.io/pypi/l/django-cas-server.svg :target: https://www.gnu.org/licenses/gpl-3.0.html +.. image:: https://api.codacy.com/project/badge/Grade/255c21623d6946ef8802fa7995b61366 + :target: https://www.codacy.com/app/valentin-samir/django-cas-server + +.. image:: https://api.codacy.com/project/badge/Coverage/255c21623d6946ef8802fa7995b61366 + :target: https://www.codacy.com/app/valentin-samir/django-cas-server + CAS Server is a Django application implementing the `CAS Protocol 3.0 Specification -`_. +`_. By defaut, the authentication process use django internal users but you can easily use any sources (see auth classes in the auth.py file) @@ -70,7 +76,7 @@ Quick start 4. You should add some management commands to a crontab: ``clearsessions``, ``cas_clean_tickets`` and ``cas_clean_sessions``. - * ``clearsessions``: please see `Clearing the session store `_. + * ``clearsessions``: please see `Clearing the session store `_. * ``cas_clean_tickets``: old tickets and timed-out tickets do not get purge from the database automatically. They are just marked as invalid. ``cas_clean_tickets`` is a clean-up management command for this purpose. It send SingleLogOut request @@ -204,7 +210,7 @@ Logs ---- ``django-cas-server`` logs most of its actions. To enable login, you must set the ``LOGGING`` -(https://docs.djangoproject.com/en/dev/topics/logging) variable is ``settings.py``. +(https://docs.djangoproject.com/en/stable/topics/logging) variable is ``settings.py``. Users successful actions (login, logout) are logged with the level ``INFO``, failures are logged with the level ``WARNING`` and user attributes transmitted to a service are logged with the level ``DEBUG``. diff --git a/cas_server/__init__.py b/cas_server/__init__.py index 1bb1fa4..f830740 100644 --- a/cas_server/__init__.py +++ b/cas_server/__init__.py @@ -9,4 +9,4 @@ # # (c) 2015 Valentin Samir -default_app_config = 'cas_server.apps.AppConfig' +default_app_config = 'cas_server.apps.CasAppConfig' diff --git a/cas_server/apps.py b/cas_server/apps.py index bb93d57..c34b6eb 100644 --- a/cas_server/apps.py +++ b/cas_server/apps.py @@ -2,6 +2,6 @@ from django.utils.translation import ugettext_lazy as _ from django.apps import AppConfig -class AppConfig(AppConfig): +class CasAppConfig(AppConfig): name = 'cas_server' verbose_name = _('Central Authentication Service') diff --git a/cas_server/auth.py b/cas_server/auth.py index 7ccacae..c2a4b19 100644 --- a/cas_server/auth.py +++ b/cas_server/auth.py @@ -26,11 +26,11 @@ class AuthUser(object): def test_password(self, password): """test `password` agains the user""" - raise NotImplemented() + raise NotImplementedError() def attributs(self): """return a dict of user attributes""" - raise NotImplemented() + raise NotImplementedError() class DummyAuthUser(AuthUser): @@ -57,11 +57,11 @@ class TestAuthUser(AuthUser): def test_password(self, password): """test `password` agains the user""" - return self.username == "test" and password == "test" + return self.username == settings.CAS_TEST_USER and password == settings.CAS_TEST_PASSWORD def attributs(self): """return a dict of user attributes""" - return {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} + return settings.CAS_TEST_ATTRIBUTES class MysqlAuthUser(AuthUser): diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py index 9ad6f53..2824991 100644 --- a/cas_server/default_settings.py +++ b/cas_server/default_settings.py @@ -73,3 +73,10 @@ setting_default('CAS_SQL_DBCHARSET', 'utf8') setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS ' 'password, users.* FROM users WHERE user = %s') setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain + +setting_default('CAS_TEST_USER', 'test') +setting_default('CAS_TEST_PASSWORD', 'test') +setting_default( + 'CAS_TEST_ATTRIBUTES', + {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} +) diff --git a/cas_server/static/cas_server/cas.js b/cas_server/static/cas_server/cas.js index 4c42dde..06e1a5d 100644 --- a/cas_server/static/cas_server/cas.js +++ b/cas_server/static/cas_server/cas.js @@ -1,55 +1,52 @@ function cas_login(cas_server_login, service, login_service, callback){ - url = cas_server_login + '?service=' + encodeURIComponent(service); + var url = cas_server_login + "?service=" + encodeURIComponent(service); $.ajax({ - type: 'GET', - url:url, - beforeSend: function (request) { + type: "GET", + url, + beforeSend(request) { request.setRequestHeader("X-AJAX", "1"); }, xhrFields: { withCredentials: true }, - success: function(data, textStatus, request){ - if(data.status == 'success'){ + success(data, textStatus, request){ + if(data.status === "success"){ $.ajax({ - type: 'GET', + type: "GET", url: data.url, xhrFields: { withCredentials: true }, success: callback, - error: function (request, textStatus, errorThrown) {}, + error(request, textStatus, errorThrown) {}, }); } else { - if(data.detail == "login required"){ - window.location.href = cas_server_login + '?service=' + encodeURIComponent(login_service); + if(data.detail === "login required"){ + window.location.href = cas_server_login + "?service=" + encodeURIComponent(login_service); } else { - alert('error: ' + data.messages[1].message); + alert("error: " + data.messages[1].message); } } }, - error: function (request, textStatus, errorThrown) {}, + error(request, textStatus, errorThrown) {}, }); } function cas_logout(cas_server_logout){ $.ajax({ - type: 'GET', - url:cas_server_logout, - beforeSend: function (request) { + type: "GET", + url: cas_server_logout, + beforeSend(request) { request.setRequestHeader("X-AJAX", "1"); }, xhrFields: { withCredentials: true }, - error: function (request, textStatus, errorThrown) {}, - success: function(data, textStatus, request){ - if(data.status == 'error'){ - alert('error: ' + data.messages[1].message); + error(request, textStatus, errorThrown) {}, + success(data, textStatus, request){ + if(data.status === "error"){ + alert("error: " + data.messages[1].message); } }, }); } - - - diff --git a/cas_server/static/cas_server/login.css b/cas_server/static/cas_server/login.css index b29433d..6d3524b 100644 --- a/cas_server/static/cas_server/login.css +++ b/cas_server/static/cas_server/login.css @@ -43,14 +43,14 @@ body { @media screen and (max-width: 680px) { #app-name { - margin: 0px; + margin: 0; } #app-name img { display: block; margin: auto; } body { - padding-top: 0px; - padding-bottom: 0px; + padding-top: 0; + padding-bottom: 0; } } diff --git a/cas_server/tests.py b/cas_server/tests.py new file mode 100644 index 0000000..222596e --- /dev/null +++ b/cas_server/tests.py @@ -0,0 +1,639 @@ +from .default_settings import settings + +from django.test import TestCase +from django.test import Client + +from lxml import etree + +from cas_server import models +from cas_server import utils + + +def get_login_page_params(): + client = Client() + response = client.get('/login') + form = response.context["form"] + params = {} + for field in form: + if field.value(): + params[field.name] = field.value() + else: + params[field.name] = "" + return client, params + + +def get_auth_client(): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + + client.post('/login', params) + return client + + +def get_user_ticket_request(service): + client = get_auth_client() + response = client.get("/login", {"service": service}) + ticket_value = response['Location'].split('ticket=')[-1] + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + return (user, ticket) + + +def get_pgt(): + (host, port) = utils.PGTUrlHandler.run()[1:3] + service = "http://%s:%s" % (host, port) + + (user, ticket) = get_user_ticket_request(service) + + client = Client() + client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service}) + params = utils.PGTUrlHandler.PARAMS.copy() + + params["service"] = service + params["user"] = user + + return params + + +class LoginTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$", + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_login_view_post_goodpass_goodlt(self): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + + response = client.post('/login', params) + + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + self.assertTrue( + models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + ) + + def test_login_view_post_badlt(self): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = settings.CAS_TEST_PASSWORD + params["lt"] = 'LT-random' + + response = client.post('/login', params) + + self.assertEqual(response.status_code, 200) + self.assertTrue(b"Invalid login ticket" in response.content) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + def test_login_view_post_badpass_good_lt(self): + client, params = get_login_page_params() + params["username"] = settings.CAS_TEST_USER + params["password"] = "test2" + response = client.post('/login', params) + + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"The credentials you provided cannot be " + b"determined to be authentic" + ) in response.content + ) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + def test_view_login_get_auth_allowed_service(self): + client = get_auth_client() + response = client.get("/login?service=https://www.example.com") + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header('Location')) + self.assertTrue( + response['Location'].startswith( + "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX + ) + ) + + ticket_value = response['Location'].split('ticket=')[-1] + user = models.User.objects.get( + username=settings.CAS_TEST_USER, + session_key=client.session.session_key + ) + self.assertTrue(user) + ticket = models.ServiceTicket.objects.get(value=ticket_value) + self.assertEqual(ticket.user, user) + self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES) + self.assertEqual(ticket.validate, False) + self.assertEqual(ticket.service_pattern, self.service_pattern) + + def test_view_login_get_auth_denied_service(self): + client = get_auth_client() + response = client.get("/login?service=https://www.example.org") + self.assertEqual(response.status_code, 200) + self.assertTrue(b"Service https://www.example.org non allowed" in response.content) + + +class LogoutTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + + def test_logout_view(self): + client = get_auth_client() + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + response = client.get("/logout") + self.assertEqual(response.status_code, 200) + self.assertTrue( + ( + b"You have successfully logged out from " + b"the Central Authentication Service" + ) in response.content + ) + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + def test_logout_view_url(self): + client = get_auth_client() + + response = client.get('/logout?url=https://www.example.com') + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header("Location")) + self.assertEqual(response["Location"], "https://www.example.com") + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + def test_logout_view_service(self): + client = get_auth_client() + + response = client.get('/logout?service=https://www.example.com') + self.assertEqual(response.status_code, 302) + self.assertTrue(response.has_header("Location")) + self.assertEqual(response["Location"], "https://www.example.com") + + response = client.get("/login") + self.assertEqual(response.status_code, 200) + self.assertFalse( + ( + b"You have successfully logged into " + b"the Central Authentication Service" + ) in response.content + ) + + +class AuthTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'https://www.example.com' + models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + + def test_auth_view_goodpass(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\n') + + def test_auth_view_badpass(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': 'badpass', + 'service': self.service, + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_auth_view_badservice(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': 'https://www.example.org', + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_auth_view_badsecret(self): + settings.CAS_AUTH_SHARED_SECRET = 'test' + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'badsecret' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_auth_view_badsettings(self): + settings.CAS_AUTH_SHARED_SECRET = None + client = Client() + response = client.post( + '/auth', + { + 'username': settings.CAS_TEST_USER, + 'password': settings.CAS_TEST_PASSWORD, + 'service': self.service, + 'secret': 'test' + } + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET") + + +class ValidateTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'https://www.example.com' + self.service_pattern = models.ServicePattern.objects.create( + name="example", + pattern="^https://www\.example\.com(/.*)?$" + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_view_ok(self): + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.get('/validate', {'ticket': ticket.value, 'service': self.service}) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'yes\ntest\n') + + def test_validate_view_badservice(self): + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.get( + '/validate', + {'ticket': ticket.value, 'service': "https://www.example.org"} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + def test_validate_view_badticket(self): + get_user_ticket_request(self.service) + + client = Client() + response = client.get( + '/validate', + {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service} + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.content, b'no\n') + + +class ValidateServiceTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'http://127.0.0.1:45678' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy_callback=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_service_view_ok(self): + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath( + "//cas:authenticationSuccess", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertTrue(sucess) + + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, settings.CAS_TEST_USER) + + attributes = root.xpath( + "//cas:attributes", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(attributes), 1) + attrs1 = {} + for attr in attributes[0]: + attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = {} + for attr in attributes: + attrs2[attr.attrib['name']] = attr.attrib['value'] + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + + def test_validate_service_view_badservice(self): + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + bad_service = "https://www.example.org" + response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE") + self.assertEqual(error[0].text, bad_service) + + def test_validate_service_view_badticket_goodprefix(self): + get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, 'ticket not found') + + def test_validate_service_view_badticket_badprefix(self): + get_user_ticket_request(self.service) + + client = Client() + bad_ticket = "RANDOM" + response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual(error[0].text, bad_ticket) + + def test_validate_service_view_ok_pgturl(self): + (host, port) = utils.PGTUrlHandler.run()[1:3] + service = "http://%s:%s" % (host, port) + + ticket = get_user_ticket_request(service)[1] + + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': service, 'pgtUrl': service} + ) + pgt_params = utils.PGTUrlHandler.PARAMS.copy() + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + pgtiou = root.xpath( + "//cas:proxyGrantingTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(pgtiou), 1) + self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text) + self.assertTrue("pgtId" in pgt_params) + + def test_validate_service_pgturl_bad_proxy_callback(self): + self.service_pattern.proxy_callback = False + self.service_pattern.save() + ticket = get_user_ticket_request(self.service)[1] + + client = Client() + response = client.get( + '/serviceValidate', + {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service} + ) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK") + self.assertEqual(error[0].text, "callback url not allowed by configuration") + + +class ProxyTestCase(TestCase): + + def setUp(self): + settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' + self.service = 'http://127.0.0.1' + self.service_pattern = models.ServicePattern.objects.create( + name="localhost", + pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$", + proxy=True, + proxy_callback=True + ) + models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) + + def test_validate_proxy_ok(self): + params = get_pgt() + + # get a proxy ticket + client1 = Client() + response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertTrue(sucess) + + proxy_ticket = root.xpath( + "//cas:proxyTicket", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(proxy_ticket), 1) + proxy_ticket = proxy_ticket[0].text + + # validate the proxy ticket + client2 = Client() + response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service}) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + sucess = root.xpath( + "//cas:authenticationSuccess", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertTrue(sucess) + + # check that the proxy is send to the end service + proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxies), 1) + proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(proxy), 1) + self.assertEqual(proxy[0].text, params["service"]) + + # same tests than those for serviceValidate + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(users), 1) + self.assertEqual(users[0].text, settings.CAS_TEST_USER) + + attributes = root.xpath( + "//cas:attributes", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(attributes), 1) + attrs1 = {} + for attr in attributes[0]: + attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text + + attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + self.assertEqual(len(attributes), len(attrs1)) + attrs2 = {} + for attr in attributes: + attrs2[attr.attrib['name']] = attr.attrib['value'] + self.assertEqual(attrs1, attrs2) + self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) + + def test_validate_proxy_bad(self): + params = get_pgt() + + # bad PGT + client1 = Client() + response = client1.get( + '/proxy', + { + 'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX, + 'targetService': params['service'] + } + ) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "INVALID_TICKET") + self.assertEqual( + error[0].text, + "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX + ) + + # bad targetService + client2 = Client() + response = client2.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': "https://www.example.org"} + ) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") + self.assertEqual(error[0].text, "https://www.example.org") + + # service do not allow proxy ticket + self.service_pattern.proxy = False + self.service_pattern.save() + + client3 = Client() + response = client3.get( + '/proxy', + {'pgt': params['pgtId'], 'targetService': params['service']} + ) + self.assertEqual(response.status_code, 200) + + root = etree.fromstring(response.content) + error = root.xpath( + "//cas:authenticationFailure", + namespaces={'cas': "http://www.yale.edu/tp/cas"} + ) + self.assertEqual(len(error), 1) + self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE") + self.assertEqual( + error[0].text, + 'the service %s do not allow proxy ticket' % params['service'] + ) diff --git a/cas_server/urls.py b/cas_server/urls.py index b2ed38b..982ef9d 100644 --- a/cas_server/urls.py +++ b/cas_server/urls.py @@ -14,7 +14,7 @@ from django.conf.urls import patterns, url from django.views.generic import RedirectView from django.views.decorators.debug import sensitive_post_parameters, sensitive_variables -import views +from cas_server import views urlpatterns = patterns( '', diff --git a/cas_server/utils.py b/cas_server/utils.py index fdb8f46..8a2a040 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -19,13 +19,10 @@ from django.contrib import messages import random import string import json +from threading import Thread from importlib import import_module - -try: - from urlparse import urlparse, urlunparse, parse_qsl - from urllib import urlencode -except ImportError: - from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode +from six.moves import BaseHTTPServer +from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode def context(params): @@ -83,9 +80,9 @@ def update_url(url, params): query = dict(parse_qsl(url_parts[4])) query.update(params) url_parts[4] = urlencode(query) - for i in range(len(url_parts)): - if not isinstance(url_parts[i], bytes): - url_parts[i] = url_parts[i].encode('utf-8') + for i, url_part in enumerate(url_parts): + if not isinstance(url_part, bytes): + url_parts[i] = url_part.encode('utf-8') return urlunparse(url_parts).decode('utf-8') @@ -144,3 +141,34 @@ def gen_pgtiou(): def gen_saml_id(): """Generate an saml id""" return _gen_ticket('_') + + +class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): + PARAMS = {} + + def do_GET(s): + s.send_response(200) + s.send_header(b"Content-type", "text/plain") + s.end_headers() + s.wfile.write(b"ok") + url = urlparse(s.path) + params = dict(parse_qsl(url.query)) + PGTUrlHandler.PARAMS.update(params) + + def log_message(self, template, *args): + return + + @staticmethod + def run(): + server_class = BaseHTTPServer.HTTPServer + httpd = server_class(("127.0.0.1", 0), PGTUrlHandler) + (host, port) = httpd.socket.getsockname() + + def lauch(): + httpd.handle_request() + httpd.server_close() + + httpd_thread = Thread(target=lauch) + httpd_thread.daemon = True + httpd_thread.start() + return (httpd_thread, host, port) diff --git a/cas_server/views.py b/cas_server/views.py index e431499..37fe179 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -23,6 +23,7 @@ from django.views.decorators.csrf import csrf_exempt from django.views.generic import View +import re import logging import pprint import requests @@ -62,12 +63,12 @@ class AttributesMixin(object): class LogoutMixin(object): """destroy CAS session utils""" - def logout(self, all=False): + def logout(self, all_session=False): """effectively destroy CAS session""" session_nb = 0 username = self.request.session.get("username") if username: - if all: + if all_session: logger.info("Logging out user %s from all of they sessions." % username) else: logger.info("Logging out user %s." % username) @@ -85,8 +86,8 @@ class LogoutMixin(object): # if user not found in database, flush the session anyway self.request.session.flush() - # If all is set logout user from alternative sessions - if all: + # If all_session is set logout user from alternative sessions + if all_session: for user in models.User.objects.filter(username=username): session = SessionStore(session_key=user.session_key) session.flush() @@ -197,10 +198,7 @@ class LoginView(View, LogoutMixin): def init_post(self, request): self.request = request self.service = request.POST.get('service') - if request.POST.get('renew') and request.POST['renew'] != "False": - self.renew = True - else: - self.renew = False + self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") self.gateway = request.POST.get('gateway') self.method = request.POST.get('method') self.ajax = 'HTTP_X_AJAX' in request.META @@ -284,10 +282,7 @@ class LoginView(View, LogoutMixin): def init_get(self, request): self.request = request self.service = request.GET.get('service') - if request.GET.get('renew') and request.GET['renew'] != "False": - self.renew = True - else: - self.renew = False + self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") self.gateway = request.GET.get('gateway') self.method = request.GET.get('method') self.ajax = 'HTTP_X_AJAX' in request.META @@ -666,7 +661,10 @@ class ValidateService(View, AttributesMixin): params['username'] = self.ticket.user.attributs.get( self.ticket.service_pattern.user_field ) - if self.pgt_url and self.pgt_url.startswith("https://"): + if self.pgt_url and ( + self.pgt_url.startswith("https://") or + re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url) + ): return self.process_pgturl(params) else: logger.info( diff --git a/requirements-dev.txt b/requirements-dev.txt index 9998ce7..e6ef993 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,3 +7,4 @@ django-picklefield>=0.3.1 requests_futures>=0.9.5 django-bootstrap3>=5.4 lxml>=3.4 +six>=1 diff --git a/requirements.txt b/requirements.txt index 8d64df0..97d4f1c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ requests_futures>=0.9.5 django-picklefield>=0.3.1 django-bootstrap3>=5.4 lxml>=3.4 - +six>=1 diff --git a/run_tests b/run_tests new file mode 100755 index 0000000..4ea21ee --- /dev/null +++ b/run_tests @@ -0,0 +1,22 @@ +#!/usr/bin/env python +import os, sys +import django +from django.conf import settings + +import settings_tests + +settings.configure(**settings_tests.__dict__) +django.setup() + +try: + # Django <= 1.8 + from django.test.simple import DjangoTestSuiteRunner + test_runner = DjangoTestSuiteRunner(verbosity=1) +except ImportError: + # Django >= 1.8 + from django.test.runner import DiscoverRunner + test_runner = DiscoverRunner(verbosity=1) + +failures = test_runner.run_tests(['cas_server']) +if failures: + sys.exit(failures) diff --git a/settings_tests.py b/settings_tests.py new file mode 100644 index 0000000..4588c2c --- /dev/null +++ b/settings_tests.py @@ -0,0 +1,83 @@ +""" +Django test settings for cas_server application. + +Generated by 'django-admin startproject' using Django 1.9.7. + +For more information on this file, see +https://docs.djangoproject.com/en/1.9/topics/settings/ + +For the full list of settings and their values, see +https://docs.djangoproject.com/en/1.9/ref/settings/ +""" + +import os + +# Build paths inside the project like this: os.path.join(BASE_DIR, ...) +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = 'changeme' + +# SECURITY WARNING: don't run with debug turned on in production! +DEBUG = True + +ALLOWED_HOSTS = [] + + +# Application definition + +INSTALLED_APPS = [ + 'django.contrib.admin', + 'django.contrib.auth', + 'django.contrib.contenttypes', + 'django.contrib.sessions', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'bootstrap3', + 'cas_server', +] + +MIDDLEWARE_CLASSES = [ + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.middleware.csrf.CsrfViewMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.auth.middleware.SessionAuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'django.middleware.clickjacking.XFrameOptionsMiddleware', + 'django.middleware.locale.LocaleMiddleware', +] + +ROOT_URLCONF = 'cas_server.urls' + +# Database +# https://docs.djangoproject.com/en/1.9/ref/settings/#databases + +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', + } +} + +# Internationalization +# https://docs.djangoproject.com/en/1.9/topics/i18n/ + +LANGUAGE_CODE = 'en-us' + +TIME_ZONE = 'UTC' + +USE_I18N = True + +USE_L10N = True + +USE_TZ = True + + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/1.9/howto/static-files/ + +STATIC_URL = '/static/' diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/dummy.py b/tests/dummy.py deleted file mode 100644 index 8266d7b..0000000 --- a/tests/dummy.py +++ /dev/null @@ -1,136 +0,0 @@ -import functools -from cas_server import models - -class DummyUserManager(object): - def __init__(self, username, session_key): - self.username = username - self.session_key = session_key - def get(self, username=None, session_key=None): - if username == self.username and session_key == self.session_key: - return models.User(username=username, session_key=session_key) - else: - raise models.User.DoesNotExist() - - -def dummy(*args, **kwds): - pass - -def dummy_service_pattern(**kwargs): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - service_validate = models.ServicePattern.validate - models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs)) - ret = func(*args, **kwds) - models.ServicePattern.validate = service_validate - return ret - return wrapper - return decorator - -def dummy_user(username, session_key): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - user_manager = models.User.objects - user_save = models.User.save - user_delete = models.User.delete - models.User.objects = DummyUserManager(username, session_key) - models.User.save = dummy - models.User.delete = dummy - ret = func(*args, **kwds) - models.User.objects = user_manager - models.User.save = user_save - models.User.delete = user_delete - return ret - return wrapper - return decorator - -def dummy_ticket(ticket_class, service, ticket): - def decorator(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - ticket_manager = ticket_class.objects - ticket_save = ticket_class.save - ticket_delete = ticket_class.delete - ticket_class.objects = DummyTicketManager(ticket_class, service, ticket) - ticket_class.save = dummy - ticket_class.delete = dummy - ret = func(*args, **kwds) - ticket_class.objects = ticket_manager - ticket_class.save = ticket_save - ticket_class.delete = ticket_delete - return ret - return wrapper - return decorator - - -def dummy_proxy(func): - @functools.wraps(func) - def wrapper(*args, **kwds): - proxy_manager = models.Proxy.objects - models.Proxy.objects = DummyProxyManager() - ret = func(*args, **kwds) - models.Proxy.objects = proxy_manager - return ret - return wrapper - -class DummyProxyManager(object): - def create(self, **kwargs): - for field in models.Proxy._meta.fields: - field.allow_unsaved_instance_assignment = True - return models.Proxy(**kwargs) - -class DummyTicketManager(object): - def __init__(self, ticket_class, service, ticket): - self.ticket_class = ticket_class - self.service = service - self.ticket = ticket - - def create(self, **kwargs): - for field in self.ticket_class._meta.fields: - field.allow_unsaved_instance_assignment = True - return self.ticket_class(**kwargs) - - def filter(self, *args, **kwargs): - return DummyQuerySet() - - def get(self, **kwargs): - for field in self.ticket_class._meta.fields: - field.allow_unsaved_instance_assignment = True - if 'value' in kwargs: - if kwargs['value'] != self.ticket: - raise self.ticket_class.DoesNotExist() - else: - kwargs['value'] = self.ticket - - if 'service' in kwargs: - if kwargs['service'] != self.service: - raise self.ticket_class.DoesNotExist() - else: - kwargs['service'] = self.service - if not 'user' in kwargs: - kwargs['user'] = models.User(username="test") - - for field in models.ServiceTicket._meta.fields: - field.allow_unsaved_instance_assignment = True - for key in list(kwargs): - if '__' in key: - del kwargs[key] - kwargs['attributs'] = {'mail': 'test@example.com'} - kwargs['service_pattern'] = models.ServicePattern() - return self.ticket_class(**kwargs) - - - -class DummySession(dict): - session_key = "test_session" - - def set_expiry(self, int): - pass - - def flush(self): - self.clear() - - -class DummyQuerySet(set): - pass diff --git a/tests/init.py b/tests/init.py deleted file mode 100644 index f6ede9e..0000000 --- a/tests/init.py +++ /dev/null @@ -1,32 +0,0 @@ -import django -from django.conf import settings -from django.contrib import messages - -settings.configure() -settings.STATIC_URL = "/static/" -settings.DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': '/dev/null', - } -} -settings.INSTALLED_APPS = ( - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', - 'bootstrap3', - 'cas_server', -) - -settings.ROOT_URLCONF = "/" -settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' - -try: - django.setup() -except AttributeError: - pass -messages.add_message = lambda x,y,z:None - diff --git a/tests/test_proxy.py b/tests/test_proxy.py deleted file mode 100644 index 963d834..0000000 --- a/tests/test_proxy.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import absolute_import -from tests.init import * - -from django.test import RequestFactory - -import os -import pytest -from lxml import etree -from cas_server.views import ValidateService, Proxy -from cas_server import models - -from tests.dummy import * - -@pytest.mark.django_db -@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random") -@dummy_service_pattern(proxy=True) -@dummy_user(username="test", session_key="test_session") -@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random") -@dummy_proxy -def test_proxy_ok(): - factory = RequestFactory() - request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com') - - request.session = DummySession() - - proxy = Proxy() - response = proxy.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(proxy_tickets) == 1 - - factory = RequestFactory() - request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com') - - validate = ValidateService() - validate.allow_proxy_ticket = True - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(users) == 1 - assert users[0].text == "test" - - - diff --git a/tests/test_validate_service.py b/tests/test_validate_service.py deleted file mode 100644 index 940e23b..0000000 --- a/tests/test_validate_service.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest -from lxml import etree -from cas_server.views import ValidateService -from cas_server import models - -from .dummy import * - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_validate_service_view_ok(): - factory = RequestFactory() - request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com') - - request.session = DummySession() - - validate = ValidateService() - validate.allow_proxy_ticket = False - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(users) == 1 - assert users[0].text == "test" - - attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(attributes) == 1 - - attrs = {} - for attr in attributes[0]: - attrs[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text - - assert 'mail' in attrs - assert attrs['mail'] == 'test@example.com' - - - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random") -def test_validate_service_view_badservice(): - factory = RequestFactory() - request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com') - - request.session = DummySession() - - validate = ValidateService() - validate.allow_proxy_ticket = False - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(error) == 1 - assert error[0].attrib['code'] == 'INVALID_SERVICE' - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2") -def test_validate_service_view_badticket(): - factory = RequestFactory() - request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com') - - request.session = DummySession() - - validate = ValidateService() - validate.allow_proxy_ticket = False - response = validate.get(request) - - assert response.status_code == 200 - - root = etree.fromstring(response.content) - - error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"}) - - assert len(error) == 1 - assert error[0].attrib['code'] == 'INVALID_TICKET' diff --git a/tests/test_views_auth.py b/tests/test_views_auth.py deleted file mode 100644 index 4b4a9eb..0000000 --- a/tests/test_views_auth.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import Auth -from cas_server import models - -from .dummy import * - -settings.CAS_AUTH_SHARED_SECRET = "test" - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -@dummy_user(username="test", session_key="test_session") -@dummy_service_pattern() -def test_auth_view_goodpass(): - factory = RequestFactory() - request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'}) - - request.session = DummySession() - - auth = Auth() - response = auth.post(request) - - assert response.status_code == 200 - assert response.content == b"yes\n" - -@dummy_service_pattern() -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -@dummy_user(username="test", session_key="test_session") -def test_auth_view_badpass(): - factory = RequestFactory() - request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'}) - - request.session = DummySession() - - auth = Auth() - response = auth.post(request) - - assert response.status_code == 200 - assert response.content == b"no\n" - diff --git a/tests/test_views_login.py b/tests/test_views_login.py deleted file mode 100644 index 6aabe80..0000000 --- a/tests/test_views_login.py +++ /dev/null @@ -1,163 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import LoginView -from cas_server import models - -from .dummy import * - - - -def test_login_view_post_goodpass_goodlt(): - factory = RequestFactory() - request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random'}) - request.session = DummySession() - - request.session['lt'] = ['LT-random'] - - request.session["username"] = os.urandom(20) - request.session["warn"] = os.urandom(20) - - login = LoginView() - login.init_post(request) - - ret = login.process_post(pytest=True) - - assert ret == LoginView.USER_LOGIN_OK - assert request.session.get("authenticated") == True - assert request.session.get("username") == "test" - assert request.session.get("warn") == False - -def test_login_view_post_badlt(): - factory = RequestFactory() - request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random1'}) - request.session = DummySession() - - request.session['lt'] = ['LT-random2'] - - authenticated = os.urandom(20) - username = os.urandom(20) - warn = os.urandom(20) - - request.session["authenticated"] = authenticated - request.session["username"] = username - request.session["warn"] = warn - - login = LoginView() - login.init_post(request) - - ret = login.process_post(pytest=True) - - assert ret == LoginView.INVALID_LOGIN_TICKET - assert request.session.get("authenticated") == authenticated - assert request.session.get("username") == username - assert request.session.get("warn") == warn - -def test_login_view_post_badpass_good_lt(): - factory = RequestFactory() - request = factory.post('/login', {'username':'test', 'password':'badpassword', 'lt':'LT-random'}) - request.session = DummySession() - - request.session['lt'] = ['LT-random'] - - login = LoginView() - login.init_post(request) - ret = login.process_post() - - assert ret == LoginView.USER_LOGIN_FAILURE - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - -def test_view_login_get_unauth(): - factory = RequestFactory() - request = factory.post('/login') - request.session = DummySession() - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_NOT_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 200 - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_view_login_get_auth(): - factory = RequestFactory() - request = factory.post('/login') - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 200 - -@pytest.mark.django_db -@dummy_service_pattern() -@dummy_user(username="test", session_key="test_session") -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_view_login_get_auth_service(): - factory = RequestFactory() - request = factory.post('/login?service=https://www.example.com') - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 302 - assert response['Location'].startswith('https://www.example.com?ticket=ST-') - -@pytest.mark.django_db -@dummy_service_pattern() -@dummy_user(username="test", session_key="test_session") -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_view_login_get_auth_service_warn(): - factory = RequestFactory() - request = factory.post('/login?service=https://www.example.com') - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = True - - login = LoginView() - login.init_get(request) - ret = login.process_get() - - assert ret == LoginView.USER_AUTHENTICATED - - login = LoginView() - response = login.get(request) - - assert response.status_code == 200 diff --git a/tests/test_views_logout.py b/tests/test_views_logout.py deleted file mode 100644 index 03410bd..0000000 --- a/tests/test_views_logout.py +++ /dev/null @@ -1,80 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import LogoutView -from cas_server import models - -from .dummy import * - - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_logout_view(): - factory = RequestFactory() - request = factory.get('/logout') - - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - logout = LogoutView() - response = logout.get(request) - - assert response.status_code == 200 - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_logout_view_url(): - factory = RequestFactory() - request = factory.get('/logout?url=https://www.example.com') - - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - logout = LogoutView() - response = logout.get(request) - - assert response.status_code == 302 - assert response['Location'] == 'https://www.example.com' - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - - -@pytest.mark.django_db -@dummy_user(username="test", session_key="test_session") -def test_logout_view_service(): - factory = RequestFactory() - request = factory.get('/logout?service=https://www.example.com') - - request.session = DummySession() - - request.session["authenticated"] = True - request.session["username"] = "test" - request.session["warn"] = False - - logout = LogoutView() - response = logout.get(request) - - assert response.status_code == 302 - assert response['Location'] == 'https://www.example.com' - assert not request.session.get("authenticated") - assert not request.session.get("username") - assert not request.session.get("warn") - - diff --git a/tests/test_views_validate.py b/tests/test_views_validate.py deleted file mode 100644 index 201387f..0000000 --- a/tests/test_views_validate.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import absolute_import -from .init import * - -from django.test import RequestFactory - -import os -import pytest - -from cas_server.views import Validate -from cas_server import models - -from .dummy import * - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_validate_view_ok(): - factory = RequestFactory() - request = factory.get('/validate?ticket=ST-random&service=https://www.example.com') - - request.session = DummySession() - - validate = Validate() - response = validate.get(request) - - assert response.status_code == 200 - assert response.content == b"yes\ntest\n" - - - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") -def test_validate_view_badservice(): - factory = RequestFactory() - request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com') - - request.session = DummySession() - - validate = Validate() - response = validate.get(request) - - assert response.status_code == 200 - assert response.content == b"no\n" - - - -@pytest.mark.django_db -@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1") -def test_validate_view_badticket(): - factory = RequestFactory() - request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com') - - request.session = DummySession() - - validate = Validate() - response = validate.get(request) - - assert response.status_code == 200 - assert response.content == b"no\n" diff --git a/tox.ini b/tox.ini index 997620a..0b65c56 100644 --- a/tox.ini +++ b/tox.ini @@ -17,7 +17,7 @@ deps = -r{toxinidir}/requirements-dev.txt [testenv] -commands=py.test --tb native {posargs:tests} +commands=python run_tests {posargs:tests} [testenv:py27-django17] basepython=python2.7