diff --git a/cas_server/auth.py b/cas_server/auth.py
index afcb722..d666ec5 100644
--- a/cas_server/auth.py
+++ b/cas_server/auth.py
@@ -171,7 +171,7 @@ class CASFederateAuth(AuthUser):
def attributs(self):
"""return a dict of user attributes"""
- if not self.user:
+ if not self.user: # pragma: no cover (should not happen)
return {}
else:
return self.user.attributs
diff --git a/cas_server/default_settings.py b/cas_server/default_settings.py
index 07b420e..be3f064 100644
--- a/cas_server/default_settings.py
+++ b/cas_server/default_settings.py
@@ -14,7 +14,6 @@ from django.conf import settings
from django.contrib.staticfiles.templatetags.staticfiles import static
import re
-import six
def setting_default(name, default_value):
@@ -112,13 +111,10 @@ except AttributeError:
key = settings.CAS_FEDERATE_PROVIDERS[key][2].lower()
else:
key = key.lower()
- if isinstance(key, six.string_types) or isinstance(key, six.text_type):
- return tuple(
- int(num) if num else alpha
- for num, alpha in __cas_federate_providers_list_sort.tokenize(key)
- )
- else:
- return key
+ return tuple(
+ int(num) if num else alpha
+ for num, alpha in __cas_federate_providers_list_sort.tokenize(key)
+ )
__cas_federate_providers_list_sort.tokenize = re.compile(r'(\d+)|(\D+)').findall
__CAS_FEDERATE_PROVIDERS_LIST.sort(key=__cas_federate_providers_list_sort)
diff --git a/cas_server/federate.py b/cas_server/federate.py
index 453a778..2f6489a 100644
--- a/cas_server/federate.py
+++ b/cas_server/federate.py
@@ -15,6 +15,7 @@ from .cas import CASClient
from .models import FederatedUser, FederateSLO, User
from importlib import import_module
+from six.moves import urllib
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
@@ -27,7 +28,7 @@ class CASFederateValidateUser(object):
def __init__(self, provider, service_url):
self.provider = provider
- if provider in settings.CAS_FEDERATE_PROVIDERS:
+ if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be True)
(server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider][:2]
self.client = CASClient(
service_url=service_url,
@@ -44,9 +45,12 @@ class CASFederateValidateUser(object):
def verify_ticket(self, ticket):
"""test `password` agains the user"""
- if self.client is None:
+ if self.client is None: # pragma: no cover (should not happen)
+ return False
+ try:
+ username, attributs = self.client.verify_ticket(ticket)[:2]
+ except urllib.error.URLError:
return False
- username, attributs = self.client.verify_ticket(ticket)[:2]
if username is not None:
if attributs is None:
attributs = {}
@@ -83,23 +87,20 @@ class CASFederateValidateUser(object):
def clean_sessions(self, logout_request):
try:
- slos = self.client.get_saml_slos(logout_request)
- except NameError:
+ slos = self.client.get_saml_slos(logout_request) or []
+ except NameError: # pragma: no cover (should not happen)
slos = []
for slo in slos:
- try:
- for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
- session = SessionStore(session_key=federate_slo.session_key)
- session.flush()
- try:
- user = User.objects.get(
- username=federate_slo.username,
- session_key=federate_slo.session_key
- )
- user.logout()
- user.delete()
- except User.DoesNotExist:
- pass
- federate_slo.delete()
- except FederateSLO.DoesNotExist:
- pass
+ for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
+ session = SessionStore(session_key=federate_slo.session_key)
+ session.flush()
+ try:
+ user = User.objects.get(
+ username=federate_slo.username,
+ session_key=federate_slo.session_key
+ )
+ user.logout()
+ user.delete()
+ except User.DoesNotExist: # pragma: no cover (should not happen)
+ pass
+ federate_slo.delete()
diff --git a/cas_server/forms.py b/cas_server/forms.py
index b5cf4d0..dc0e866 100644
--- a/cas_server/forms.py
+++ b/cas_server/forms.py
@@ -31,6 +31,8 @@ class WarnForm(forms.Form):
class FederateSelect(forms.Form):
provider = forms.ChoiceField(
label=_('Identity provider'),
+ # with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS
+ # this is usefull to use the override_settings decorator in tests
choices=[
(
p,
@@ -88,8 +90,12 @@ class FederateUserCredential(UserCredential):
user = models.FederatedUser.objects.get(username=username, provider=provider)
user.ticket = ""
user.save()
- except models.FederatedUser.DoesNotExist:
- raise
+ # should not happed as is the FederatedUser do not exists, super should
+ # raise before a ValidationError("bad user")
+ except models.FederatedUser.DoesNotExist: # pragma: no cover (should not happend)
+ raise forms.ValidationError(
+ _(u"User not found in the temporary database, please try to reconnect")
+ )
return cleaned_data
diff --git a/cas_server/management/commands/cas_clean_federate.py b/cas_server/management/commands/cas_clean_federate.py
index 04e0608..8d91935 100644
--- a/cas_server/management/commands/cas_clean_federate.py
+++ b/cas_server/management/commands/cas_clean_federate.py
@@ -1,11 +1,7 @@
from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _
-from django.utils import timezone
-
-from datetime import timedelta
from ... import models
-from ...default_settings import settings
class Command(BaseCommand):
@@ -13,11 +9,5 @@ class Command(BaseCommand):
help = _(u"Clean old federated users")
def handle(self, *args, **options):
- federated_users = models.FederatedUser.objects.filter(
- last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
- )
- known_users = {user.username for user in models.User.objects.all()}
- for user in federated_users:
- if not ('%s@%s' % (user.username, user.provider)) in known_users:
- user.delete()
+ models.FederatedUser.clean_old_entries()
models.FederateSLO.clean_deleted_sessions()
diff --git a/cas_server/models.py b/cas_server/models.py
index aea270b..3d1f17f 100644
--- a/cas_server/models.py
+++ b/cas_server/models.py
@@ -46,6 +46,16 @@ class FederatedUser(models.Model):
def __unicode__(self):
return u"%s@%s" % (self.username, self.provider)
+ @classmethod
+ def clean_old_entries(cls):
+ federated_users = cls.objects.filter(
+ last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
+ )
+ known_users = {user.username for user in User.objects.all()}
+ for user in federated_users:
+ if not ('%s@%s' % (user.username, user.provider)) in known_users:
+ user.delete()
+
class FederateSLO(models.Model):
class Meta:
@@ -54,11 +64,6 @@ class FederateSLO(models.Model):
session_key = models.CharField(max_length=40, blank=True, null=True)
ticket = models.CharField(max_length=255)
- @property
- def provider(self):
- component = self.username.split("@")
- return component[-1]
-
@classmethod
def clean_deleted_sessions(cls):
for federate_slo in cls.objects.all():
@@ -76,6 +81,14 @@ class User(models.Model):
username = models.CharField(max_length=30)
date = models.DateTimeField(auto_now=True)
+ def delete(self, *args, **kwargs):
+ if settings.CAS_FEDERATE:
+ FederateSLO.objects.filter(
+ username=self.username,
+ session_key=self.session_key
+ ).delete()
+ super(User, self).delete(*args, **kwargs)
+
@classmethod
def clean_old_entries(cls):
"""Remove users inactive since more that SESSION_COOKIE_AGE"""
diff --git a/cas_server/tests/mixin.py b/cas_server/tests/mixin.py
index ddbf2d2..4612fd2 100644
--- a/cas_server/tests/mixin.py
+++ b/cas_server/tests/mixin.py
@@ -191,3 +191,50 @@ class UserModels(object):
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
+
+
+class CanLogin(object):
+ """Assertion about login"""
+ def assert_logged(
+ self, client, response, warn=False,
+ code=200, username=settings.CAS_TEST_USER
+ ):
+ """Assertions testing that client is well authenticated"""
+ self.assertEqual(response.status_code, code)
+ # this message is displayed to the user upon successful authentication
+ self.assertIn(
+ (
+ b"You have successfully logged into "
+ b"the Central Authentication Service"
+ ),
+ response.content
+ )
+ # these session variables a set if usccessfully authenticated
+ self.assertEqual(client.session["username"], username)
+ self.assertIs(client.session["warn"], warn)
+ self.assertIs(client.session["authenticated"], True)
+
+ # on successfull authentication, a corresponding user object is created
+ self.assertTrue(
+ models.User.objects.get(
+ username=username,
+ session_key=client.session.session_key
+ )
+ )
+
+ def assert_login_failed(self, client, response, code=200):
+ """Assertions testing a failed login attempt"""
+ self.assertEqual(response.status_code, code)
+ # this message is displayed to the user upon successful authentication, so it should not
+ # appear
+ self.assertFalse(
+ (
+ b"You have successfully logged into "
+ b"the Central Authentication Service"
+ ) in response.content
+ )
+
+ # if authentication has failed, these session variables should not be set
+ self.assertTrue(client.session.get("username") is None)
+ self.assertTrue(client.session.get("warn") is None)
+ self.assertTrue(client.session.get("authenticated") is None)
diff --git a/cas_server/tests/test_federate.py b/cas_server/tests/test_federate.py
new file mode 100644
index 0000000..b4e76b2
--- /dev/null
+++ b/cas_server/tests/test_federate.py
@@ -0,0 +1,344 @@
+# -*- coding: utf-8 -*-
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
+# more details.
+#
+# You should have received a copy of the GNU General Public License version 3
+# along with this program; if not, write to the Free Software Foundation, Inc., 51
+# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+#
+# (c) 2016 Valentin Samir
+"""tests for the CAS federate mode"""
+from cas_server import default_settings
+from cas_server.default_settings import settings
+
+import django
+from django.test import TestCase, Client
+from django.test.utils import override_settings
+
+from six.moves import reload_module
+
+from cas_server import utils, forms
+from cas_server.tests.mixin import BaseServicePattern, CanLogin
+from cas_server.tests import utils as tests_utils
+
+PROVIDERS = {
+ "example.com": ("http://127.0.0.1:8080", 1, "Example dot com"),
+ "example.org": ("http://127.0.0.1:8081", 2, "Example dot org"),
+ "example.net": ("http://127.0.0.1:8082", 3, "Example dot net"),
+ "example.test": ("http://127.0.0.1:8083", 'CAS_2_SAML_1_0'),
+}
+
+PROVIDERS_LIST = list(PROVIDERS.keys())
+PROVIDERS_LIST.sort()
+
+
+def getaddrinfo_mock(name, port, *args, **kwargs):
+ return [(2, 1, 6, '', ('127.0.0.1', 80))]
+
+
+@override_settings(
+ CAS_FEDERATE=True,
+ CAS_FEDERATE_PROVIDERS=PROVIDERS,
+ CAS_FEDERATE_PROVIDERS_LIST=PROVIDERS_LIST,
+ CAS_AUTH_CLASS="cas_server.auth.CASFederateAuth",
+ # test with a non ascii username
+ CAS_TEST_USER=u"dédé"
+)
+class FederateAuthLoginLogoutTestCase(TestCase, BaseServicePattern, CanLogin):
+ """tests for the views login logout and federate then the federated mode is enabled"""
+ def setUp(self):
+ """Prepare the test context"""
+ self.setup_service_patterns()
+ reload_module(forms)
+
+ def test_default_settings(self):
+ """default settings should populated some default variable then CAS_FEDERATE is True"""
+ provider_list = settings.CAS_FEDERATE_PROVIDERS_LIST
+ del settings.CAS_FEDERATE_PROVIDERS_LIST
+ del settings.CAS_AUTH_CLASS
+ reload_module(default_settings)
+ self.assertEqual(settings.CAS_FEDERATE_PROVIDERS_LIST, provider_list)
+ self.assertEqual(settings.CAS_AUTH_CLASS, "cas_server.auth.CASFederateAuth")
+
+ def test_login_get_provider(self):
+ """some assertion about the login page in federated mode"""
+ client = Client()
+ response = client.get("/login")
+ self.assertEqual(response.status_code, 200)
+ for key, value in settings.CAS_FEDERATE_PROVIDERS.items():
+ self.assertTrue('' % (
+ key,
+ utils.get_tuple(value, 2, key)
+ ) in response.content.decode("utf-8"))
+ self.assertEqual(response.context['post_url'], '/federate')
+
+ def test_login_post_provider(self, remember=False):
+ """test a successful login wrokflow"""
+ tickets = []
+ # choose the example.com provider
+ for (provider, cas_port) in [
+ ("example.com", 8080), ("example.org", 8081),
+ ("example.net", 8082), ("example.test", 8083)
+ ]:
+ # get a bare client
+ client = Client()
+ # fetch the login page
+ response = client.get("/login")
+ # in federated mode, we shoudl POST do /federate on the login page
+ self.assertEqual(response.context['post_url'], '/federate')
+ # get current form parameter
+ params = tests_utils.copy_form(response.context["form"])
+ params['provider'] = provider
+ if remember:
+ params['remember'] = 'on'
+ # post the choosed provider
+ response = client.post('/federate', params)
+ # we are redirected to the provider CAS client url
+ self.assertEqual(response.status_code, 302)
+ if remember:
+ self.assertEqual(response["Location"], '%s/federate/%s?remember=on' % (
+ 'http://testserver' if django.VERSION < (1, 9) else "",
+ provider
+ ))
+ else:
+ self.assertEqual(response["Location"], '%s/federate/%s' % (
+ 'http://testserver' if django.VERSION < (1, 9) else "",
+ provider
+ ))
+ # let's follow the redirect
+ response = client.get('/federate/%s' % provider)
+ # we are redirected to the provider CAS for authentication
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(
+ response["Location"],
+ "%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
+ settings.CAS_FEDERATE_PROVIDERS[provider][0],
+ provider
+ )
+ )
+ # let's generate a ticket
+ ticket = utils.gen_st()
+ # we lauch a dummy CAS server that only validate once for the service
+ # http://testserver/federate/example.com with `ticket`
+ tests_utils.DummyCAS.run(
+ ("http://testserver/federate/%s" % provider).encode("ascii"),
+ ticket.encode("ascii"),
+ settings.CAS_TEST_USER.encode("utf8"),
+ [],
+ cas_port
+ )
+ # we normally provide a good ticket and should be redirected to /login as the ticket
+ # get successfully validated again the dummy CAS
+ response = client.get('/federate/%s' % provider, {'ticket': ticket})
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(response["Location"], "%s/login" % (
+ 'http://testserver' if django.VERSION < (1, 9) else ""
+ ))
+ # follow the redirect
+ response = client.get("/login")
+ # we should get a page with a from with all widget hidden that auto POST to /login using
+ # javascript. If javascript is disabled, a "connect" button is showed
+ self.assertTrue(response.context['auto_submit'])
+ self.assertEqual(response.context['post_url'], '/login')
+ params = tests_utils.copy_form(response.context["form"])
+ # POST ge prefiled from parameters
+ response = client.post("/login", params)
+ # the user should now being authenticated using username test@`provider`
+ self.assert_logged(
+ client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
+ )
+ tickets.append((provider, ticket, client))
+
+ # try to get a ticket
+ response = client.get("/login", {'service': self.service})
+ self.assertEqual(response.status_code, 302)
+ self.assertTrue(response["Location"].startswith("%s?ticket=" % self.service))
+ return tickets
+
+ def test_login_twice(self):
+ """Test that user id db is used for the second login (cf coverage)"""
+ self.test_login_post_provider()
+ self.test_login_post_provider()
+
+ @override_settings(CAS_FEDERATE=False)
+ def test_auth_federate_false(self):
+ """federated view should redirect to /login then CAS_FEDERATE is False"""
+ provider = "example.com"
+ client = Client()
+ response = client.get("/federate/%s" % provider)
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(response["Location"], "%s/login" % (
+ 'http://testserver' if django.VERSION < (1, 9) else ""
+ ))
+ response = client.post("%s/federate/%s" % (
+ 'http://testserver' if django.VERSION < (1, 9) else "",
+ provider
+ ))
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(response["Location"], "%s/login" % (
+ 'http://testserver' if django.VERSION < (1, 9) else ""
+ ))
+
+ def test_auth_federate_errors(self):
+ """
+ The federated view should redirect to /login if the provider is unknown or not provided,
+ try to fetch a new ticket if the provided ticket validation fail
+ (network error or bad ticket)
+ """
+ return
+ good_provider = "example.com"
+ bad_provider = "exemple.fr"
+ client = Client()
+ response = client.get("/federate/%s" % bad_provider)
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(response["Location"], "%s/login" % (
+ 'http://testserver' if django.VERSION < (1, 9) else ""
+ ))
+
+ # test CAS not avaible
+ response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()})
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(
+ response["Location"],
+ "%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
+ settings.CAS_FEDERATE_PROVIDERS[good_provider][0],
+ good_provider
+ )
+ )
+
+ # test CAS avaible but bad ticket
+ tests_utils.DummyCAS.run(
+ ("http://testserver/federate/%s" % good_provider).encode("ascii"),
+ utils.gen_st().encode("ascii"),
+ settings.CAS_TEST_USER.encode("utf-8"),
+ [],
+ 8080
+ )
+ response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()})
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(
+ response["Location"],
+ "%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
+ settings.CAS_FEDERATE_PROVIDERS[good_provider][0],
+ good_provider
+ )
+ )
+
+ response = client.post("/federate")
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(response["Location"], "%s/login" % (
+ 'http://testserver' if django.VERSION < (1, 9) else ""
+ ))
+
+ def test_auth_federate_slo(self):
+ """test that SLO receive from backend CAS log out the users"""
+ # get tickets and connected clients
+ tickets = self.test_login_post_provider()
+ for (provider, ticket, client) in tickets:
+ # SLO for an unkown ticket should do nothing
+ response = client.post(
+ "/federate/%s" % provider,
+ {'logoutRequest': tests_utils.logout_request(utils.gen_st())}
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b"ok")
+ # Bad SLO format should do nothing
+ response = client.post(
+ "/federate/%s" % provider,
+ {'logoutRequest': ""}
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b"ok")
+ # Bad SLO format should do nothing
+ response = client.post(
+ "/federate/%s" % provider,
+ {'logoutRequest': ""}
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b"ok")
+ response = client.get("/login")
+ self.assert_logged(
+ client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
+ )
+
+ # SLO for a previously logged ticket should log out the user if CAS version is
+ # 3 or 'CAS_2_SAML_1_0'
+ response = client.post(
+ "/federate/%s" % provider,
+ {'logoutRequest': tests_utils.logout_request(ticket)}
+ )
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.content, b"ok")
+
+ response = client.get("/login")
+ if settings.CAS_FEDERATE_PROVIDERS[provider][1] in {3, 'CAS_2_SAML_1_0'}: # support SLO
+ self.assert_login_failed(client, response)
+ else:
+ self.assert_logged(
+ client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
+ )
+
+ def test_federate_logout(self):
+ """
+ test the logout function: the user should be log out
+ and redirected to his CAS logout page
+ """
+ # get tickets and connected clients
+ tickets = self.test_login_post_provider()
+ for (provider, _, client) in tickets:
+ response = client.get("/logout")
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(
+ response["Location"],
+ "%s/logout" % settings.CAS_FEDERATE_PROVIDERS[provider][0]
+ )
+ response = client.get("/login")
+ self.assert_login_failed(client, response)
+
+ def test_remember_provider(self):
+ """
+ If the user check remember, next login should not offer the chose of the backend CAS
+ and use the one store in the cookie
+ """
+ tickets = self.test_login_post_provider(remember=True)
+ for (provider, _, client) in tickets:
+ client.get("/logout")
+ response = client.get("/login")
+ self.assertEqual(response.status_code, 302)
+ self.assertEqual(response["Location"], "%s/federate/%s" % (
+ 'http://testserver' if django.VERSION < (1, 9) else "",
+ provider
+ ))
+
+ def test_login_bad_ticket(self):
+ """
+ Try login with a bad ticket:
+ login should fail and the main login page should be displayed to the user
+ """
+ provider = "example.com"
+ # get a bare client
+ client = Client()
+ session = client.session
+ session["federate_username"] = '%s@%s' % (settings.CAS_TEST_USER, provider)
+ session["federate_ticket"] = utils.gen_st()
+ try:
+ session.save()
+ response = client.get("/login")
+ # we should get a page with a from with all widget hidden that auto POST to /login using
+ # javascript. If javascript is disabled, a "connect" button is showed
+ self.assertTrue(response.context['auto_submit'])
+ self.assertEqual(response.context['post_url'], '/login')
+ params = tests_utils.copy_form(response.context["form"])
+ # POST, as (username, ticket) are not valid, we should get the federate login page
+ response = client.post("/login", params)
+ self.assertEqual(response.status_code, 200)
+ for key, value in settings.CAS_FEDERATE_PROVIDERS.items():
+ self.assertTrue('' % (
+ key,
+ utils.get_tuple(value, 2, key)
+ ) in response.content.decode("utf-8"))
+ self.assertEqual(response.context['post_url'], '/federate')
+ except AttributeError:
+ pass
diff --git a/cas_server/tests/test_models.py b/cas_server/tests/test_models.py
index e75f54f..cdaece8 100644
--- a/cas_server/tests/test_models.py
+++ b/cas_server/tests/test_models.py
@@ -12,20 +12,81 @@
"""Tests module for models"""
from cas_server.default_settings import settings
-from django.test import TestCase
+from django.test import TestCase, Client
from django.test.utils import override_settings
from django.utils import timezone
from datetime import timedelta
from importlib import import_module
-from cas_server import models
+from cas_server import models, utils
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
+class FederatedUserTestCase(TestCase, UserModels):
+ """test for the federated user model"""
+ def test_clean_old_entries(self):
+ """tests for clean_old_entries that should delete federated user no longer used"""
+ client = Client()
+ client.get("/login")
+ models.FederatedUser.objects.create(
+ username="test1", provider="example.com", attributs={}, ticket=""
+ )
+ models.FederatedUser.objects.create(
+ username="test2", provider="example.com", attributs={}, ticket=""
+ )
+ models.FederatedUser.objects.all().update(
+ last_update=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT + 10))
+ )
+ models.FederatedUser.objects.create(
+ username="test3", provider="example.com", attributs={}, ticket=""
+ )
+ models.User.objects.create(
+ username="test1@example.com", session_key=client.session.session_key
+ )
+ models.FederatedUser.clean_old_entries()
+ self.assertEqual(len(models.FederatedUser.objects.all()), 2)
+ with self.assertRaises(models.FederatedUser.DoesNotExist):
+ models.FederatedUser.objects.get(username="test2")
+
+
+class FederateSLOTestCase(TestCase, UserModels):
+ """test for the federated SLO model"""
+ def test_clean_deleted_sessions(self):
+ """
+ tests for clean_deleted_sessions that should delete object for which matching session
+ do not exists anymore
+ """
+ client1 = Client()
+ client2 = Client()
+ client1.get("/login")
+ client2.get("/login")
+ session = client2.session
+ session['authenticated'] = True
+ try:
+ session.save()
+ except AttributeError:
+ pass
+ models.FederateSLO.objects.create(
+ username="test1@example.com",
+ session_key=client1.session.session_key,
+ ticket=utils.gen_st()
+ )
+ models.FederateSLO.objects.create(
+ username="test2@example.com",
+ session_key=client2.session.session_key,
+ ticket=utils.gen_st()
+ )
+ self.assertEqual(len(models.FederateSLO.objects.all()), 2)
+ models.FederateSLO.clean_deleted_sessions()
+ self.assertEqual(len(models.FederateSLO.objects.all()), 1)
+ with self.assertRaises(models.FederateSLO.DoesNotExist):
+ models.FederateSLO.objects.get(username="test1@example.com")
+
+
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class UserTestCase(TestCase, UserModels):
"""tests for the user models"""
diff --git a/cas_server/tests/test_utils.py b/cas_server/tests/test_utils.py
index 76fa2cc..411848a 100644
--- a/cas_server/tests/test_utils.py
+++ b/cas_server/tests/test_utils.py
@@ -10,7 +10,7 @@
#
# (c) 2016 Valentin Samir
"""Tests module for utils"""
-from django.test import TestCase
+from django.test import TestCase, RequestFactory
import six
@@ -189,3 +189,22 @@ class UtilsTestCase(TestCase):
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known
+
+ def test_get_current_url(self):
+ """test the function get_current_url"""
+ factory = RequestFactory()
+ request = factory.get('/truc/muche?test=1')
+ self.assertEqual(utils.get_current_url(request), 'http://testserver/truc/muche?test=1')
+ self.assertEqual(
+ utils.get_current_url(request, ignore_params={'test'}),
+ 'http://testserver/truc/muche'
+ )
+
+ def test_get_tuple(self):
+ """test the function get_tuple"""
+ test_tuple = (1, 2, 3)
+ for index, value in enumerate(test_tuple):
+ self.assertEqual(utils.get_tuple(test_tuple, index), value)
+ self.assertEqual(utils.get_tuple(test_tuple, 3), None)
+ self.assertEqual(utils.get_tuple(test_tuple, 3, 'toto'), 'toto')
+ self.assertEqual(utils.get_tuple(None, 3), None)
diff --git a/cas_server/tests/test_view.py b/cas_server/tests/test_view.py
index 0acd52f..49fa2d2 100644
--- a/cas_server/tests/test_view.py
+++ b/cas_server/tests/test_view.py
@@ -1,4 +1,4 @@
-# ⁻*- coding: utf-8 -*-
+# -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
@@ -36,57 +36,17 @@ from cas_server.tests.utils import (
HttpParamsHandler,
Http404Handler
)
-from cas_server.tests.mixin import BaseServicePattern, XmlContent
+from cas_server.tests.mixin import BaseServicePattern, XmlContent, CanLogin
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
-class LoginTestCase(TestCase, BaseServicePattern):
+class LoginTestCase(TestCase, BaseServicePattern, CanLogin):
"""Tests for the login view"""
def setUp(self):
"""Prepare the test context:"""
# we prepare a bunch a service url and service patterns for tests
self.setup_service_patterns()
- def assert_logged(self, client, response, warn=False, code=200):
- """Assertions testing that client is well authenticated"""
- self.assertEqual(response.status_code, code)
- # this message is displayed to the user upon successful authentication
- self.assertTrue(
- (
- b"You have successfully logged into "
- b"the Central Authentication Service"
- ) in response.content
- )
- # these session variables a set if usccessfully authenticated
- self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
- self.assertTrue(client.session["warn"] is warn)
- self.assertTrue(client.session["authenticated"] is True)
-
- # on successfull authentication, a corresponding user object is created
- self.assertTrue(
- models.User.objects.get(
- username=settings.CAS_TEST_USER,
- session_key=client.session.session_key
- )
- )
-
- def assert_login_failed(self, client, response, code=200):
- """Assertions testing a failed login attempt"""
- self.assertEqual(response.status_code, code)
- # this message is displayed to the user upon successful authentication, so it should not
- # appear
- self.assertFalse(
- (
- b"You have successfully logged into "
- b"the Central Authentication Service"
- ) in response.content
- )
-
- # if authentication has failed, these session variables should not be set
- self.assertTrue(client.session.get("username") is None)
- self.assertTrue(client.session.get("warn") is None)
- self.assertTrue(client.session.get("authenticated") is None)
-
def test_login_view_post_goodpass_goodlt(self):
"""Test a successul login"""
# we get a client who fetch a frist time the login page and the login form default
diff --git a/cas_server/tests/utils.py b/cas_server/tests/utils.py
index bd692e9..b8419c6 100644
--- a/cas_server/tests/utils.py
+++ b/cas_server/tests/utils.py
@@ -13,14 +13,33 @@
from cas_server.default_settings import settings
from django.test import Client
+from django.template import loader, Context
+from django.utils import timezone
import cgi
+import six
from threading import Thread
from lxml import etree
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, parse_qsl
+from datetime import timedelta
from cas_server import models
+from cas_server import utils
+
+
+def return_unicode(string, charset):
+ if not isinstance(string, six.text_type):
+ return string.decode(charset)
+ else:
+ return string
+
+
+def return_bytes(string, charset):
+ if isinstance(string, six.text_type):
+ return string.encode(charset)
+ else:
+ return string
def copy_form(form):
@@ -149,10 +168,10 @@ class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
return
@classmethod
- def run(cls):
+ def run(cls, port=0):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
- httpd = server_class(("127.0.0.1", 0), cls)
+ httpd = server_class(("127.0.0.1", port), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
@@ -178,3 +197,143 @@ class Http404Handler(HttpParamsHandler):
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()
+
+
+class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
+
+ def test_params(self):
+ if (
+ self.server.ticket is not None and
+ self.params.get("service").encode("ascii") == self.server.service and
+ self.params.get("ticket").encode("ascii") == self.server.ticket
+ ):
+ self.server.ticket = None
+ print("good")
+ return True
+ else:
+ print("bad (%r, %r) != (%r, %r)" % (
+ self.params.get("service").encode("ascii"),
+ self.params.get("ticket").encode("ascii"),
+ self.server.service,
+ self.server.ticket
+ ))
+
+ return False
+
+ def send_headers(self, code, content_type):
+ self.send_response(200)
+ self.send_header("Content-type", content_type)
+ self.end_headers()
+
+ def do_GET(self):
+ url = urlparse(self.path)
+ self.params = dict(parse_qsl(url.query))
+ if url.path == "/validate":
+ self.send_headers(200, "text/plain; charset=utf-8")
+ if self.test_params():
+ self.wfile.write(b"yes\n" + self.server.username + b"\n")
+ self.server.ticket = None
+ else:
+ self.wfile.write(b"no\n")
+ elif url.path in {
+ '/serviceValidate', '/serviceValidate',
+ '/p3/serviceValidate', '/p3/proxyValidate'
+ }:
+ self.send_headers(200, "text/xml; charset=utf-8")
+ if self.test_params():
+ t = loader.get_template('cas_server/serviceValidate.xml')
+ c = Context({
+ 'username': self.server.username,
+ 'attributes': self.server.attributes
+ })
+ self.wfile.write(return_bytes(t.render(c), "utf8"))
+ else:
+ t = loader.get_template('cas_server/serviceValidateError.xml')
+ c = Context({
+ 'code': 'BAD_SERVICE_TICKET',
+ 'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
+ })
+ self.wfile.write(return_bytes(t.render(c), "utf8"))
+ else:
+ self.return_404()
+
+ def do_POST(self):
+ url = urlparse(self.path)
+ self.params = dict(parse_qsl(url.query))
+ if url.path == "/samlValidate":
+ self.send_headers(200, "text/xml; charset=utf-8")
+ length = int(self.headers.get('content-length'))
+ root = etree.fromstring(self.rfile.read(length))
+ auth_req = root.getchildren()[1].getchildren()[0]
+ ticket = auth_req.getchildren()[0].text.encode("ascii")
+ if (
+ self.server.ticket is not None and
+ self.params.get("TARGET").encode("ascii") == self.server.service and
+ ticket == self.server.ticket
+ ):
+ self.server.ticket = None
+ t = loader.get_template('cas_server/samlValidate.xml')
+ c = Context({
+ 'IssueInstant': timezone.now().isoformat(),
+ 'expireInstant': (timezone.now() + timedelta(seconds=60)).isoformat(),
+ 'Recipient': self.server.service,
+ 'ResponseID': utils.gen_saml_id(),
+ 'username': self.server.username,
+ 'attributes': self.server.attributes,
+ })
+ self.wfile.write(return_bytes(t.render(c), "utf8"))
+ else:
+ t = loader.get_template('cas_server/samlValidateError.xml')
+ c = Context({
+ 'IssueInstant': timezone.now().isoformat(),
+ 'ResponseID': utils.gen_saml_id(),
+ 'code': 'BAD_SERVICE_TICKET',
+ 'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
+ })
+ self.wfile.write(return_bytes(t.render(c), "utf8"))
+ else:
+ self.return_404()
+
+ def return_404(self):
+ self.send_response(404)
+ self.send_header(b"Content-type", "text/plain")
+ self.end_headers()
+ self.wfile.write("not found")
+
+ def log_message(self, *args):
+ """silent any log message"""
+ return
+
+ @classmethod
+ def run(cls, service, ticket, username, attributes, port=0):
+ """Run a BaseHTTPServer using this class as handler"""
+ server_class = BaseHTTPServer.HTTPServer
+ httpd = server_class(("127.0.0.1", port), cls)
+ httpd.service = service
+ httpd.ticket = ticket
+ httpd.username = username
+ httpd.attributes = attributes
+ (host, port) = httpd.socket.getsockname()
+
+ def lauch():
+ """routine to lauch in a background thread"""
+ httpd.handle_request()
+ httpd.server_close()
+
+ httpd_thread = Thread(target=lauch)
+ httpd_thread.daemon = True
+ httpd_thread.start()
+ return (httpd, host, port)
+
+
+def logout_request(ticket):
+ return u"""
+
+%(ticket)s
+""" % \
+ {
+ 'id': utils.gen_saml_id(),
+ 'datetime': timezone.now().isoformat(),
+ 'ticket': ticket
+ }
diff --git a/cas_server/views.py b/cas_server/views.py
index 9543c6f..05ce47d 100644
--- a/cas_server/views.py
+++ b/cas_server/views.py
@@ -123,15 +123,18 @@ class LogoutView(View, LogoutMixin):
self.init_get(request)
# if CAS federation mode is enable, bakup the provider before flushing the sessions
if settings.CAS_FEDERATE:
- component = self.request.session.get("username").split('@')
- provider = component[-1]
- auth = CASFederateValidateUser(provider, service_url="")
+ if "username" in self.request.session:
+ component = self.request.session["username"].split('@')
+ provider = component[-1]
+ auth = CASFederateValidateUser(provider, service_url="")
+ else:
+ auth = None
session_nb = self.logout(self.request.GET.get("all"))
# if CAS federation mode is enable, redirect to user CAS logout page
if settings.CAS_FEDERATE:
- params = utils.copy_params(request.GET)
- url = utils.update_url(auth.get_logout_url(), params)
- if url:
+ if auth is not None:
+ params = utils.copy_params(request.GET)
+ url = utils.update_url(auth.get_logout_url(), params)
return HttpResponseRedirect(url)
# if service is set, redirect to service after logout
if self.service:
@@ -195,7 +198,7 @@ class FederateAuth(View):
@staticmethod
def get_cas_client(request, provider):
- if provider in settings.CAS_FEDERATE_PROVIDERS:
+ if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true)
service_url = utils.get_current_url(request, {"ticket", "provider"})
return CASFederateValidateUser(provider, service_url)
@@ -207,14 +210,14 @@ class FederateAuth(View):
auth = self.get_cas_client(request, provider)
try:
auth.clean_sessions(request.POST['logoutRequest'])
- except KeyError:
+ except (KeyError, AttributeError):
pass
return HttpResponse("ok")
# else, a User is trying to log in using an identity provider
else:
# Manually checking for csrf to protect the code below
reason = CsrfViewMiddleware().process_view(request, None, (), {})
- if reason is not None:
+ if reason is not None: # pragma: no cover (csrf checks are disabled during tests)
return reason # Failed the test, stop here.
form = forms.FederateSelect(request.POST)
if form.is_valid():
@@ -252,7 +255,7 @@ class FederateAuth(View):
ticket = request.GET['ticket']
if auth.verify_ticket(ticket):
params = utils.copy_params(request.GET, ignore={"ticket"})
- username = "%s@%s" % (auth.username, auth.provider)
+ username = u"%s@%s" % (auth.username, auth.provider)
request.session["federate_username"] = username
request.session["federate_ticket"] = ticket
auth.register_slo(username, request.session.session_key, ticket)
@@ -281,9 +284,9 @@ class LoginView(View, LogoutMixin):
renewed = False
warned = False
- if settings.CAS_FEDERATE:
- username = None
- ticket = None
+ # used if CAS_FEDERATE is True
+ username = None
+ ticket = None
INVALID_LOGIN_TICKET = 1
USER_LOGIN_OK = 2
@@ -354,7 +357,7 @@ class LoginView(View, LogoutMixin):
elif ret == self.USER_LOGIN_FAILURE: # bad user login
if settings.CAS_FEDERATE:
self.ticket = None
- self.usernalme = None
+ self.username = None
self.init_form()
self.logout()
elif ret == self.USER_ALREADY_LOGGED:
@@ -682,11 +685,14 @@ class Auth(View):
secret = request.POST.get('secret')
if not settings.CAS_AUTH_SHARED_SECRET:
- return HttpResponse("no\nplease set CAS_AUTH_SHARED_SECRET", content_type="text/plain")
+ return HttpResponse(
+ "no\nplease set CAS_AUTH_SHARED_SECRET",
+ content_type="text/plain; charset=utf-8"
+ )
if secret != settings.CAS_AUTH_SHARED_SECRET:
- return HttpResponse("no\n", content_type="text/plain")
+ return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
if not username or not password or not service:
- return HttpResponse("no\n", content_type="text/plain")
+ return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
form = forms.UserCredential(
request.POST,
initial={
@@ -714,11 +720,11 @@ class Auth(View):
service_pattern.check_user(user)
if not request.session.get("authenticated"):
user.delete()
- return HttpResponse("yes\n", content_type="text/plain")
+ return HttpResponse(u"yes\n", content_type="text/plain; charset=utf-8")
except (ServicePattern.DoesNotExist, models.ServicePatternException):
- return HttpResponse("no\n", content_type="text/plain")
+ return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
else:
- return HttpResponse("no\n", content_type="text/plain")
+ return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
class Validate(View):
@@ -758,7 +764,10 @@ class Validate(View):
username = username[0]
else:
username = ticket.user.username
- return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
+ return HttpResponse(
+ u"yes\n%s\n" % username,
+ content_type="text/plain; charset=utf-8"
+ )
except ServiceTicket.DoesNotExist:
logger.warning(
(
@@ -769,10 +778,10 @@ class Validate(View):
service
)
)
- return HttpResponse("no\n", content_type="text/plain")
+ return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
else:
logger.warning("Validate: service or ticket missing")
- return HttpResponse("no\n", content_type="text/plain")
+ return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
class ValidateError(Exception):
@@ -815,8 +824,8 @@ class ValidateService(View, AttributesMixin):
if not self.service or not self.ticket:
logger.warning("ValidateService: missing ticket or service")
return ValidateError(
- 'INVALID_REQUEST',
- "you must specify a service and a ticket"
+ u'INVALID_REQUEST',
+ u"you must specify a service and a ticket"
).render(request)
else:
try:
@@ -886,14 +895,14 @@ class ValidateService(View, AttributesMixin):
for prox in ticket.proxies.all():
proxies.append(prox.url)
else:
- raise ValidateError('INVALID_TICKET', self.ticket)
+ raise ValidateError(u'INVALID_TICKET', self.ticket)
ticket.validate = True
ticket.save()
if ticket.service != self.service:
- raise ValidateError('INVALID_SERVICE', self.service)
+ raise ValidateError(u'INVALID_SERVICE', self.service)
return ticket, proxies
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
- raise ValidateError('INVALID_TICKET', 'ticket not found')
+ raise ValidateError(u'INVALID_TICKET', 'ticket not found')
def process_pgturl(self, params):
"""Handle PGT request"""
@@ -939,18 +948,18 @@ class ValidateService(View, AttributesMixin):
except requests.exceptions.RequestException as error:
error = utils.unpack_nested_exception(error)
raise ValidateError(
- 'INVALID_PROXY_CALLBACK',
- "%s: %s" % (type(error), str(error))
+ u'INVALID_PROXY_CALLBACK',
+ u"%s: %s" % (type(error), str(error))
)
else:
raise ValidateError(
- 'INVALID_PROXY_CALLBACK',
- "callback url not allowed by configuration"
+ u'INVALID_PROXY_CALLBACK',
+ u"callback url not allowed by configuration"
)
except ServicePattern.DoesNotExist:
raise ValidateError(
- 'INVALID_PROXY_CALLBACK',
- 'callback url not allowed by configuration'
+ u'INVALID_PROXY_CALLBACK',
+ u'callback url not allowed by configuration'
)
@@ -971,8 +980,8 @@ class Proxy(View):
return self.process_proxy()
else:
raise ValidateError(
- 'INVALID_REQUEST',
- "you must specify and pgt and targetService"
+ u'INVALID_REQUEST',
+ u"you must specify and pgt and targetService"
)
except ValidateError as error:
logger.warning("Proxy: validation error: %s %s" % (error.code, error.msg))
@@ -985,8 +994,8 @@ class Proxy(View):
pattern = ServicePattern.validate(self.target_service)
if not pattern.proxy:
raise ValidateError(
- 'UNAUTHORIZED_SERVICE',
- 'the service %s do not allow proxy ticket' % self.target_service
+ u'UNAUTHORIZED_SERVICE',
+ u'the service %s do not allow proxy ticket' % self.target_service
)
# is the proxy granting ticket valid
ticket = ProxyGrantingTicket.objects.get(
@@ -1015,13 +1024,13 @@ class Proxy(View):
content_type="text/xml; charset=utf-8"
)
except ProxyGrantingTicket.DoesNotExist:
- raise ValidateError('INVALID_TICKET', 'PGT %s not found' % self.pgt)
+ raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt)
except ServicePattern.DoesNotExist:
- raise ValidateError('UNAUTHORIZED_SERVICE', self.target_service)
+ raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service)
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
raise ValidateError(
- 'UNAUTHORIZED_USER',
- 'User %s not allowed on %s' % (ticket.user.username, self.target_service)
+ u'UNAUTHORIZED_USER',
+ u'User %s not allowed on %s' % (ticket.user.username, self.target_service)
)
@@ -1129,18 +1138,18 @@ class SamlValidate(View, AttributesMixin):
)
else:
raise SamlValidateError(
- 'AuthnFailed',
- 'ticket %s should begin with PT- or ST-' % ticket
+ u'AuthnFailed',
+ u'ticket %s should begin with PT- or ST-' % ticket
)
ticket.validate = True
ticket.save()
if ticket.service != self.target:
raise SamlValidateError(
- 'AuthnFailed',
- 'TARGET %s do not match ticket service' % self.target
+ u'AuthnFailed',
+ u'TARGET %s do not match ticket service' % self.target
)
return ticket
except (IndexError, KeyError):
- raise SamlValidateError('VersionMismatch')
+ raise SamlValidateError(u'VersionMismatch')
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
- raise SamlValidateError('AuthnFailed', 'ticket %s not found' % ticket)
+ raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket)