Add unit tests for when CAS_FEDERATE is True

Also fix some unicode related bugs
This commit is contained in:
Valentin Samir 2016-07-03 13:51:00 +02:00
parent fcd906ca78
commit 90daf3d2a0
13 changed files with 749 additions and 144 deletions

View File

@ -171,7 +171,7 @@ class CASFederateAuth(AuthUser):
def attributs(self): def attributs(self):
"""return a dict of user attributes""" """return a dict of user attributes"""
if not self.user: if not self.user: # pragma: no cover (should not happen)
return {} return {}
else: else:
return self.user.attributs return self.user.attributs

View File

@ -14,7 +14,6 @@ from django.conf import settings
from django.contrib.staticfiles.templatetags.staticfiles import static from django.contrib.staticfiles.templatetags.staticfiles import static
import re import re
import six
def setting_default(name, default_value): def setting_default(name, default_value):
@ -112,13 +111,10 @@ except AttributeError:
key = settings.CAS_FEDERATE_PROVIDERS[key][2].lower() key = settings.CAS_FEDERATE_PROVIDERS[key][2].lower()
else: else:
key = key.lower() key = key.lower()
if isinstance(key, six.string_types) or isinstance(key, six.text_type):
return tuple( return tuple(
int(num) if num else alpha int(num) if num else alpha
for num, alpha in __cas_federate_providers_list_sort.tokenize(key) for num, alpha in __cas_federate_providers_list_sort.tokenize(key)
) )
else:
return key
__cas_federate_providers_list_sort.tokenize = re.compile(r'(\d+)|(\D+)').findall __cas_federate_providers_list_sort.tokenize = re.compile(r'(\d+)|(\D+)').findall
__CAS_FEDERATE_PROVIDERS_LIST.sort(key=__cas_federate_providers_list_sort) __CAS_FEDERATE_PROVIDERS_LIST.sort(key=__cas_federate_providers_list_sort)

View File

@ -15,6 +15,7 @@ from .cas import CASClient
from .models import FederatedUser, FederateSLO, User from .models import FederatedUser, FederateSLO, User
from importlib import import_module from importlib import import_module
from six.moves import urllib
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
@ -27,7 +28,7 @@ class CASFederateValidateUser(object):
def __init__(self, provider, service_url): def __init__(self, provider, service_url):
self.provider = provider self.provider = provider
if provider in settings.CAS_FEDERATE_PROVIDERS: if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be True)
(server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider][:2] (server_url, version) = settings.CAS_FEDERATE_PROVIDERS[provider][:2]
self.client = CASClient( self.client = CASClient(
service_url=service_url, service_url=service_url,
@ -44,9 +45,12 @@ class CASFederateValidateUser(object):
def verify_ticket(self, ticket): def verify_ticket(self, ticket):
"""test `password` agains the user""" """test `password` agains the user"""
if self.client is None: if self.client is None: # pragma: no cover (should not happen)
return False return False
try:
username, attributs = self.client.verify_ticket(ticket)[:2] username, attributs = self.client.verify_ticket(ticket)[:2]
except urllib.error.URLError:
return False
if username is not None: if username is not None:
if attributs is None: if attributs is None:
attributs = {} attributs = {}
@ -83,11 +87,10 @@ class CASFederateValidateUser(object):
def clean_sessions(self, logout_request): def clean_sessions(self, logout_request):
try: try:
slos = self.client.get_saml_slos(logout_request) slos = self.client.get_saml_slos(logout_request) or []
except NameError: except NameError: # pragma: no cover (should not happen)
slos = [] slos = []
for slo in slos: for slo in slos:
try:
for federate_slo in FederateSLO.objects.filter(ticket=slo.text): for federate_slo in FederateSLO.objects.filter(ticket=slo.text):
session = SessionStore(session_key=federate_slo.session_key) session = SessionStore(session_key=federate_slo.session_key)
session.flush() session.flush()
@ -98,8 +101,6 @@ class CASFederateValidateUser(object):
) )
user.logout() user.logout()
user.delete() user.delete()
except User.DoesNotExist: except User.DoesNotExist: # pragma: no cover (should not happen)
pass pass
federate_slo.delete() federate_slo.delete()
except FederateSLO.DoesNotExist:
pass

View File

@ -31,6 +31,8 @@ class WarnForm(forms.Form):
class FederateSelect(forms.Form): class FederateSelect(forms.Form):
provider = forms.ChoiceField( provider = forms.ChoiceField(
label=_('Identity provider'), label=_('Identity provider'),
# with use a lambda abstraction to delay the access to settings.CAS_FEDERATE_PROVIDERS
# this is usefull to use the override_settings decorator in tests
choices=[ choices=[
( (
p, p,
@ -88,8 +90,12 @@ class FederateUserCredential(UserCredential):
user = models.FederatedUser.objects.get(username=username, provider=provider) user = models.FederatedUser.objects.get(username=username, provider=provider)
user.ticket = "" user.ticket = ""
user.save() user.save()
except models.FederatedUser.DoesNotExist: # should not happed as is the FederatedUser do not exists, super should
raise # raise before a ValidationError("bad user")
except models.FederatedUser.DoesNotExist: # pragma: no cover (should not happend)
raise forms.ValidationError(
_(u"User not found in the temporary database, please try to reconnect")
)
return cleaned_data return cleaned_data

View File

@ -1,11 +1,7 @@
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils import timezone
from datetime import timedelta
from ... import models from ... import models
from ...default_settings import settings
class Command(BaseCommand): class Command(BaseCommand):
@ -13,11 +9,5 @@ class Command(BaseCommand):
help = _(u"Clean old federated users") help = _(u"Clean old federated users")
def handle(self, *args, **options): def handle(self, *args, **options):
federated_users = models.FederatedUser.objects.filter( models.FederatedUser.clean_old_entries()
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
)
known_users = {user.username for user in models.User.objects.all()}
for user in federated_users:
if not ('%s@%s' % (user.username, user.provider)) in known_users:
user.delete()
models.FederateSLO.clean_deleted_sessions() models.FederateSLO.clean_deleted_sessions()

View File

@ -46,6 +46,16 @@ class FederatedUser(models.Model):
def __unicode__(self): def __unicode__(self):
return u"%s@%s" % (self.username, self.provider) return u"%s@%s" % (self.username, self.provider)
@classmethod
def clean_old_entries(cls):
federated_users = cls.objects.filter(
last_update__lt=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT))
)
known_users = {user.username for user in User.objects.all()}
for user in federated_users:
if not ('%s@%s' % (user.username, user.provider)) in known_users:
user.delete()
class FederateSLO(models.Model): class FederateSLO(models.Model):
class Meta: class Meta:
@ -54,11 +64,6 @@ class FederateSLO(models.Model):
session_key = models.CharField(max_length=40, blank=True, null=True) session_key = models.CharField(max_length=40, blank=True, null=True)
ticket = models.CharField(max_length=255) ticket = models.CharField(max_length=255)
@property
def provider(self):
component = self.username.split("@")
return component[-1]
@classmethod @classmethod
def clean_deleted_sessions(cls): def clean_deleted_sessions(cls):
for federate_slo in cls.objects.all(): for federate_slo in cls.objects.all():
@ -76,6 +81,14 @@ class User(models.Model):
username = models.CharField(max_length=30) username = models.CharField(max_length=30)
date = models.DateTimeField(auto_now=True) date = models.DateTimeField(auto_now=True)
def delete(self, *args, **kwargs):
if settings.CAS_FEDERATE:
FederateSLO.objects.filter(
username=self.username,
session_key=self.session_key
).delete()
super(User, self).delete(*args, **kwargs)
@classmethod @classmethod
def clean_old_entries(cls): def clean_old_entries(cls):
"""Remove users inactive since more that SESSION_COOKIE_AGE""" """Remove users inactive since more that SESSION_COOKIE_AGE"""

View File

@ -191,3 +191,50 @@ class UserModels(object):
username=settings.CAS_TEST_USER, username=settings.CAS_TEST_USER,
session_key=client.session.session_key session_key=client.session.session_key
) )
class CanLogin(object):
"""Assertion about login"""
def assert_logged(
self, client, response, warn=False,
code=200, username=settings.CAS_TEST_USER
):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication
self.assertIn(
(
b"You have successfully logged into "
b"the Central Authentication Service"
),
response.content
)
# these session variables a set if usccessfully authenticated
self.assertEqual(client.session["username"], username)
self.assertIs(client.session["warn"], warn)
self.assertIs(client.session["authenticated"], True)
# on successfull authentication, a corresponding user object is created
self.assertTrue(
models.User.objects.get(
username=username,
session_key=client.session.session_key
)
)
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication, so it should not
# appear
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
# if authentication has failed, these session variables should not be set
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)

View File

@ -0,0 +1,344 @@
# -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""tests for the CAS federate mode"""
from cas_server import default_settings
from cas_server.default_settings import settings
import django
from django.test import TestCase, Client
from django.test.utils import override_settings
from six.moves import reload_module
from cas_server import utils, forms
from cas_server.tests.mixin import BaseServicePattern, CanLogin
from cas_server.tests import utils as tests_utils
PROVIDERS = {
"example.com": ("http://127.0.0.1:8080", 1, "Example dot com"),
"example.org": ("http://127.0.0.1:8081", 2, "Example dot org"),
"example.net": ("http://127.0.0.1:8082", 3, "Example dot net"),
"example.test": ("http://127.0.0.1:8083", 'CAS_2_SAML_1_0'),
}
PROVIDERS_LIST = list(PROVIDERS.keys())
PROVIDERS_LIST.sort()
def getaddrinfo_mock(name, port, *args, **kwargs):
return [(2, 1, 6, '', ('127.0.0.1', 80))]
@override_settings(
CAS_FEDERATE=True,
CAS_FEDERATE_PROVIDERS=PROVIDERS,
CAS_FEDERATE_PROVIDERS_LIST=PROVIDERS_LIST,
CAS_AUTH_CLASS="cas_server.auth.CASFederateAuth",
# test with a non ascii username
CAS_TEST_USER=u"dédé"
)
class FederateAuthLoginLogoutTestCase(TestCase, BaseServicePattern, CanLogin):
"""tests for the views login logout and federate then the federated mode is enabled"""
def setUp(self):
"""Prepare the test context"""
self.setup_service_patterns()
reload_module(forms)
def test_default_settings(self):
"""default settings should populated some default variable then CAS_FEDERATE is True"""
provider_list = settings.CAS_FEDERATE_PROVIDERS_LIST
del settings.CAS_FEDERATE_PROVIDERS_LIST
del settings.CAS_AUTH_CLASS
reload_module(default_settings)
self.assertEqual(settings.CAS_FEDERATE_PROVIDERS_LIST, provider_list)
self.assertEqual(settings.CAS_AUTH_CLASS, "cas_server.auth.CASFederateAuth")
def test_login_get_provider(self):
"""some assertion about the login page in federated mode"""
client = Client()
response = client.get("/login")
self.assertEqual(response.status_code, 200)
for key, value in settings.CAS_FEDERATE_PROVIDERS.items():
self.assertTrue('<option value="%s">%s</option>' % (
key,
utils.get_tuple(value, 2, key)
) in response.content.decode("utf-8"))
self.assertEqual(response.context['post_url'], '/federate')
def test_login_post_provider(self, remember=False):
"""test a successful login wrokflow"""
tickets = []
# choose the example.com provider
for (provider, cas_port) in [
("example.com", 8080), ("example.org", 8081),
("example.net", 8082), ("example.test", 8083)
]:
# get a bare client
client = Client()
# fetch the login page
response = client.get("/login")
# in federated mode, we shoudl POST do /federate on the login page
self.assertEqual(response.context['post_url'], '/federate')
# get current form parameter
params = tests_utils.copy_form(response.context["form"])
params['provider'] = provider
if remember:
params['remember'] = 'on'
# post the choosed provider
response = client.post('/federate', params)
# we are redirected to the provider CAS client url
self.assertEqual(response.status_code, 302)
if remember:
self.assertEqual(response["Location"], '%s/federate/%s?remember=on' % (
'http://testserver' if django.VERSION < (1, 9) else "",
provider
))
else:
self.assertEqual(response["Location"], '%s/federate/%s' % (
'http://testserver' if django.VERSION < (1, 9) else "",
provider
))
# let's follow the redirect
response = client.get('/federate/%s' % provider)
# we are redirected to the provider CAS for authentication
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
"%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
settings.CAS_FEDERATE_PROVIDERS[provider][0],
provider
)
)
# let's generate a ticket
ticket = utils.gen_st()
# we lauch a dummy CAS server that only validate once for the service
# http://testserver/federate/example.com with `ticket`
tests_utils.DummyCAS.run(
("http://testserver/federate/%s" % provider).encode("ascii"),
ticket.encode("ascii"),
settings.CAS_TEST_USER.encode("utf8"),
[],
cas_port
)
# we normally provide a good ticket and should be redirected to /login as the ticket
# get successfully validated again the dummy CAS
response = client.get('/federate/%s' % provider, {'ticket': ticket})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
# follow the redirect
response = client.get("/login")
# we should get a page with a from with all widget hidden that auto POST to /login using
# javascript. If javascript is disabled, a "connect" button is showed
self.assertTrue(response.context['auto_submit'])
self.assertEqual(response.context['post_url'], '/login')
params = tests_utils.copy_form(response.context["form"])
# POST ge prefiled from parameters
response = client.post("/login", params)
# the user should now being authenticated using username test@`provider`
self.assert_logged(
client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
)
tickets.append((provider, ticket, client))
# try to get a ticket
response = client.get("/login", {'service': self.service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % self.service))
return tickets
def test_login_twice(self):
"""Test that user id db is used for the second login (cf coverage)"""
self.test_login_post_provider()
self.test_login_post_provider()
@override_settings(CAS_FEDERATE=False)
def test_auth_federate_false(self):
"""federated view should redirect to /login then CAS_FEDERATE is False"""
provider = "example.com"
client = Client()
response = client.get("/federate/%s" % provider)
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
response = client.post("%s/federate/%s" % (
'http://testserver' if django.VERSION < (1, 9) else "",
provider
))
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
def test_auth_federate_errors(self):
"""
The federated view should redirect to /login if the provider is unknown or not provided,
try to fetch a new ticket if the provided ticket validation fail
(network error or bad ticket)
"""
return
good_provider = "example.com"
bad_provider = "exemple.fr"
client = Client()
response = client.get("/federate/%s" % bad_provider)
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
# test CAS not avaible
response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()})
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
"%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
settings.CAS_FEDERATE_PROVIDERS[good_provider][0],
good_provider
)
)
# test CAS avaible but bad ticket
tests_utils.DummyCAS.run(
("http://testserver/federate/%s" % good_provider).encode("ascii"),
utils.gen_st().encode("ascii"),
settings.CAS_TEST_USER.encode("utf-8"),
[],
8080
)
response = client.get("/federate/%s" % good_provider, {'ticket': utils.gen_st()})
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
"%s/login?service=http%%3A%%2F%%2Ftestserver%%2Ffederate%%2F%s" % (
settings.CAS_FEDERATE_PROVIDERS[good_provider][0],
good_provider
)
)
response = client.post("/federate")
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/login" % (
'http://testserver' if django.VERSION < (1, 9) else ""
))
def test_auth_federate_slo(self):
"""test that SLO receive from backend CAS log out the users"""
# get tickets and connected clients
tickets = self.test_login_post_provider()
for (provider, ticket, client) in tickets:
# SLO for an unkown ticket should do nothing
response = client.post(
"/federate/%s" % provider,
{'logoutRequest': tests_utils.logout_request(utils.gen_st())}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok")
# Bad SLO format should do nothing
response = client.post(
"/federate/%s" % provider,
{'logoutRequest': ""}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok")
# Bad SLO format should do nothing
response = client.post(
"/federate/%s" % provider,
{'logoutRequest': "<root></root>"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok")
response = client.get("/login")
self.assert_logged(
client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
)
# SLO for a previously logged ticket should log out the user if CAS version is
# 3 or 'CAS_2_SAML_1_0'
response = client.post(
"/federate/%s" % provider,
{'logoutRequest': tests_utils.logout_request(ticket)}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"ok")
response = client.get("/login")
if settings.CAS_FEDERATE_PROVIDERS[provider][1] in {3, 'CAS_2_SAML_1_0'}: # support SLO
self.assert_login_failed(client, response)
else:
self.assert_logged(
client, response, username='%s@%s' % (settings.CAS_TEST_USER, provider)
)
def test_federate_logout(self):
"""
test the logout function: the user should be log out
and redirected to his CAS logout page
"""
# get tickets and connected clients
tickets = self.test_login_post_provider()
for (provider, _, client) in tickets:
response = client.get("/logout")
self.assertEqual(response.status_code, 302)
self.assertEqual(
response["Location"],
"%s/logout" % settings.CAS_FEDERATE_PROVIDERS[provider][0]
)
response = client.get("/login")
self.assert_login_failed(client, response)
def test_remember_provider(self):
"""
If the user check remember, next login should not offer the chose of the backend CAS
and use the one store in the cookie
"""
tickets = self.test_login_post_provider(remember=True)
for (provider, _, client) in tickets:
client.get("/logout")
response = client.get("/login")
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], "%s/federate/%s" % (
'http://testserver' if django.VERSION < (1, 9) else "",
provider
))
def test_login_bad_ticket(self):
"""
Try login with a bad ticket:
login should fail and the main login page should be displayed to the user
"""
provider = "example.com"
# get a bare client
client = Client()
session = client.session
session["federate_username"] = '%s@%s' % (settings.CAS_TEST_USER, provider)
session["federate_ticket"] = utils.gen_st()
try:
session.save()
response = client.get("/login")
# we should get a page with a from with all widget hidden that auto POST to /login using
# javascript. If javascript is disabled, a "connect" button is showed
self.assertTrue(response.context['auto_submit'])
self.assertEqual(response.context['post_url'], '/login')
params = tests_utils.copy_form(response.context["form"])
# POST, as (username, ticket) are not valid, we should get the federate login page
response = client.post("/login", params)
self.assertEqual(response.status_code, 200)
for key, value in settings.CAS_FEDERATE_PROVIDERS.items():
self.assertTrue('<option value="%s">%s</option>' % (
key,
utils.get_tuple(value, 2, key)
) in response.content.decode("utf-8"))
self.assertEqual(response.context['post_url'], '/federate')
except AttributeError:
pass

View File

@ -12,20 +12,81 @@
"""Tests module for models""" """Tests module for models"""
from cas_server.default_settings import settings from cas_server.default_settings import settings
from django.test import TestCase from django.test import TestCase, Client
from django.test.utils import override_settings from django.test.utils import override_settings
from django.utils import timezone from django.utils import timezone
from datetime import timedelta from datetime import timedelta
from importlib import import_module from importlib import import_module
from cas_server import models from cas_server import models, utils
from cas_server.tests.utils import get_auth_client, HttpParamsHandler from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
class FederatedUserTestCase(TestCase, UserModels):
"""test for the federated user model"""
def test_clean_old_entries(self):
"""tests for clean_old_entries that should delete federated user no longer used"""
client = Client()
client.get("/login")
models.FederatedUser.objects.create(
username="test1", provider="example.com", attributs={}, ticket=""
)
models.FederatedUser.objects.create(
username="test2", provider="example.com", attributs={}, ticket=""
)
models.FederatedUser.objects.all().update(
last_update=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT + 10))
)
models.FederatedUser.objects.create(
username="test3", provider="example.com", attributs={}, ticket=""
)
models.User.objects.create(
username="test1@example.com", session_key=client.session.session_key
)
models.FederatedUser.clean_old_entries()
self.assertEqual(len(models.FederatedUser.objects.all()), 2)
with self.assertRaises(models.FederatedUser.DoesNotExist):
models.FederatedUser.objects.get(username="test2")
class FederateSLOTestCase(TestCase, UserModels):
"""test for the federated SLO model"""
def test_clean_deleted_sessions(self):
"""
tests for clean_deleted_sessions that should delete object for which matching session
do not exists anymore
"""
client1 = Client()
client2 = Client()
client1.get("/login")
client2.get("/login")
session = client2.session
session['authenticated'] = True
try:
session.save()
except AttributeError:
pass
models.FederateSLO.objects.create(
username="test1@example.com",
session_key=client1.session.session_key,
ticket=utils.gen_st()
)
models.FederateSLO.objects.create(
username="test2@example.com",
session_key=client2.session.session_key,
ticket=utils.gen_st()
)
self.assertEqual(len(models.FederateSLO.objects.all()), 2)
models.FederateSLO.clean_deleted_sessions()
self.assertEqual(len(models.FederateSLO.objects.all()), 1)
with self.assertRaises(models.FederateSLO.DoesNotExist):
models.FederateSLO.objects.get(username="test1@example.com")
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') @override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class UserTestCase(TestCase, UserModels): class UserTestCase(TestCase, UserModels):
"""tests for the user models""" """tests for the user models"""

View File

@ -10,7 +10,7 @@
# #
# (c) 2016 Valentin Samir # (c) 2016 Valentin Samir
"""Tests module for utils""" """Tests module for utils"""
from django.test import TestCase from django.test import TestCase, RequestFactory
import six import six
@ -189,3 +189,22 @@ class UtilsTestCase(TestCase):
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $ self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $ self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known
def test_get_current_url(self):
"""test the function get_current_url"""
factory = RequestFactory()
request = factory.get('/truc/muche?test=1')
self.assertEqual(utils.get_current_url(request), 'http://testserver/truc/muche?test=1')
self.assertEqual(
utils.get_current_url(request, ignore_params={'test'}),
'http://testserver/truc/muche'
)
def test_get_tuple(self):
"""test the function get_tuple"""
test_tuple = (1, 2, 3)
for index, value in enumerate(test_tuple):
self.assertEqual(utils.get_tuple(test_tuple, index), value)
self.assertEqual(utils.get_tuple(test_tuple, 3), None)
self.assertEqual(utils.get_tuple(test_tuple, 3, 'toto'), 'toto')
self.assertEqual(utils.get_tuple(None, 3), None)

View File

@ -1,4 +1,4 @@
# *- coding: utf-8 -*- # -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT # This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for # FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
@ -36,57 +36,17 @@ from cas_server.tests.utils import (
HttpParamsHandler, HttpParamsHandler,
Http404Handler Http404Handler
) )
from cas_server.tests.mixin import BaseServicePattern, XmlContent from cas_server.tests.mixin import BaseServicePattern, XmlContent, CanLogin
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') @override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class LoginTestCase(TestCase, BaseServicePattern): class LoginTestCase(TestCase, BaseServicePattern, CanLogin):
"""Tests for the login view""" """Tests for the login view"""
def setUp(self): def setUp(self):
"""Prepare the test context:""" """Prepare the test context:"""
# we prepare a bunch a service url and service patterns for tests # we prepare a bunch a service url and service patterns for tests
self.setup_service_patterns() self.setup_service_patterns()
def assert_logged(self, client, response, warn=False, code=200):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
# these session variables a set if usccessfully authenticated
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client.session["warn"] is warn)
self.assertTrue(client.session["authenticated"] is True)
# on successfull authentication, a corresponding user object is created
self.assertTrue(
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
)
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
# this message is displayed to the user upon successful authentication, so it should not
# appear
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
# if authentication has failed, these session variables should not be set
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
def test_login_view_post_goodpass_goodlt(self): def test_login_view_post_goodpass_goodlt(self):
"""Test a successul login""" """Test a successul login"""
# we get a client who fetch a frist time the login page and the login form default # we get a client who fetch a frist time the login page and the login form default

View File

@ -13,14 +13,33 @@
from cas_server.default_settings import settings from cas_server.default_settings import settings
from django.test import Client from django.test import Client
from django.template import loader, Context
from django.utils import timezone
import cgi import cgi
import six
from threading import Thread from threading import Thread
from lxml import etree from lxml import etree
from six.moves import BaseHTTPServer from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, parse_qsl from six.moves.urllib.parse import urlparse, parse_qsl
from datetime import timedelta
from cas_server import models from cas_server import models
from cas_server import utils
def return_unicode(string, charset):
if not isinstance(string, six.text_type):
return string.decode(charset)
else:
return string
def return_bytes(string, charset):
if isinstance(string, six.text_type):
return string.encode(charset)
else:
return string
def copy_form(form): def copy_form(form):
@ -149,10 +168,10 @@ class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
return return
@classmethod @classmethod
def run(cls): def run(cls, port=0):
"""Run a BaseHTTPServer using this class as handler""" """Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls) httpd = server_class(("127.0.0.1", port), cls)
(host, port) = httpd.socket.getsockname() (host, port) = httpd.socket.getsockname()
def lauch(): def lauch():
@ -178,3 +197,143 @@ class Http404Handler(HttpParamsHandler):
def do_POST(self): def do_POST(self):
"""Called on a POST request on the BaseHTTPServer""" """Called on a POST request on the BaseHTTPServer"""
return self.do_GET() return self.do_GET()
class DummyCAS(BaseHTTPServer.BaseHTTPRequestHandler):
def test_params(self):
if (
self.server.ticket is not None and
self.params.get("service").encode("ascii") == self.server.service and
self.params.get("ticket").encode("ascii") == self.server.ticket
):
self.server.ticket = None
print("good")
return True
else:
print("bad (%r, %r) != (%r, %r)" % (
self.params.get("service").encode("ascii"),
self.params.get("ticket").encode("ascii"),
self.server.service,
self.server.ticket
))
return False
def send_headers(self, code, content_type):
self.send_response(200)
self.send_header("Content-type", content_type)
self.end_headers()
def do_GET(self):
url = urlparse(self.path)
self.params = dict(parse_qsl(url.query))
if url.path == "/validate":
self.send_headers(200, "text/plain; charset=utf-8")
if self.test_params():
self.wfile.write(b"yes\n" + self.server.username + b"\n")
self.server.ticket = None
else:
self.wfile.write(b"no\n")
elif url.path in {
'/serviceValidate', '/serviceValidate',
'/p3/serviceValidate', '/p3/proxyValidate'
}:
self.send_headers(200, "text/xml; charset=utf-8")
if self.test_params():
t = loader.get_template('cas_server/serviceValidate.xml')
c = Context({
'username': self.server.username,
'attributes': self.server.attributes
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
t = loader.get_template('cas_server/serviceValidateError.xml')
c = Context({
'code': 'BAD_SERVICE_TICKET',
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
self.return_404()
def do_POST(self):
url = urlparse(self.path)
self.params = dict(parse_qsl(url.query))
if url.path == "/samlValidate":
self.send_headers(200, "text/xml; charset=utf-8")
length = int(self.headers.get('content-length'))
root = etree.fromstring(self.rfile.read(length))
auth_req = root.getchildren()[1].getchildren()[0]
ticket = auth_req.getchildren()[0].text.encode("ascii")
if (
self.server.ticket is not None and
self.params.get("TARGET").encode("ascii") == self.server.service and
ticket == self.server.ticket
):
self.server.ticket = None
t = loader.get_template('cas_server/samlValidate.xml')
c = Context({
'IssueInstant': timezone.now().isoformat(),
'expireInstant': (timezone.now() + timedelta(seconds=60)).isoformat(),
'Recipient': self.server.service,
'ResponseID': utils.gen_saml_id(),
'username': self.server.username,
'attributes': self.server.attributes,
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
t = loader.get_template('cas_server/samlValidateError.xml')
c = Context({
'IssueInstant': timezone.now().isoformat(),
'ResponseID': utils.gen_saml_id(),
'code': 'BAD_SERVICE_TICKET',
'msg': 'Valids are (%r, %r)' % (self.server.service, self.server.ticket)
})
self.wfile.write(return_bytes(t.render(c), "utf8"))
else:
self.return_404()
def return_404(self):
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write("not found")
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls, service, ticket, username, attributes, port=0):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", port), cls)
httpd.service = service
httpd.ticket = ticket
httpd.username = username
httpd.attributes = attributes
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
def logout_request(ticket):
return u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
</samlp:LogoutRequest>""" % \
{
'id': utils.gen_saml_id(),
'datetime': timezone.now().isoformat(),
'ticket': ticket
}

View File

@ -123,15 +123,18 @@ class LogoutView(View, LogoutMixin):
self.init_get(request) self.init_get(request)
# if CAS federation mode is enable, bakup the provider before flushing the sessions # if CAS federation mode is enable, bakup the provider before flushing the sessions
if settings.CAS_FEDERATE: if settings.CAS_FEDERATE:
component = self.request.session.get("username").split('@') if "username" in self.request.session:
component = self.request.session["username"].split('@')
provider = component[-1] provider = component[-1]
auth = CASFederateValidateUser(provider, service_url="") auth = CASFederateValidateUser(provider, service_url="")
else:
auth = None
session_nb = self.logout(self.request.GET.get("all")) session_nb = self.logout(self.request.GET.get("all"))
# if CAS federation mode is enable, redirect to user CAS logout page # if CAS federation mode is enable, redirect to user CAS logout page
if settings.CAS_FEDERATE: if settings.CAS_FEDERATE:
if auth is not None:
params = utils.copy_params(request.GET) params = utils.copy_params(request.GET)
url = utils.update_url(auth.get_logout_url(), params) url = utils.update_url(auth.get_logout_url(), params)
if url:
return HttpResponseRedirect(url) return HttpResponseRedirect(url)
# if service is set, redirect to service after logout # if service is set, redirect to service after logout
if self.service: if self.service:
@ -195,7 +198,7 @@ class FederateAuth(View):
@staticmethod @staticmethod
def get_cas_client(request, provider): def get_cas_client(request, provider):
if provider in settings.CAS_FEDERATE_PROVIDERS: if provider in settings.CAS_FEDERATE_PROVIDERS: # pragma: no branch (should always be true)
service_url = utils.get_current_url(request, {"ticket", "provider"}) service_url = utils.get_current_url(request, {"ticket", "provider"})
return CASFederateValidateUser(provider, service_url) return CASFederateValidateUser(provider, service_url)
@ -207,14 +210,14 @@ class FederateAuth(View):
auth = self.get_cas_client(request, provider) auth = self.get_cas_client(request, provider)
try: try:
auth.clean_sessions(request.POST['logoutRequest']) auth.clean_sessions(request.POST['logoutRequest'])
except KeyError: except (KeyError, AttributeError):
pass pass
return HttpResponse("ok") return HttpResponse("ok")
# else, a User is trying to log in using an identity provider # else, a User is trying to log in using an identity provider
else: else:
# Manually checking for csrf to protect the code below # Manually checking for csrf to protect the code below
reason = CsrfViewMiddleware().process_view(request, None, (), {}) reason = CsrfViewMiddleware().process_view(request, None, (), {})
if reason is not None: if reason is not None: # pragma: no cover (csrf checks are disabled during tests)
return reason # Failed the test, stop here. return reason # Failed the test, stop here.
form = forms.FederateSelect(request.POST) form = forms.FederateSelect(request.POST)
if form.is_valid(): if form.is_valid():
@ -252,7 +255,7 @@ class FederateAuth(View):
ticket = request.GET['ticket'] ticket = request.GET['ticket']
if auth.verify_ticket(ticket): if auth.verify_ticket(ticket):
params = utils.copy_params(request.GET, ignore={"ticket"}) params = utils.copy_params(request.GET, ignore={"ticket"})
username = "%s@%s" % (auth.username, auth.provider) username = u"%s@%s" % (auth.username, auth.provider)
request.session["federate_username"] = username request.session["federate_username"] = username
request.session["federate_ticket"] = ticket request.session["federate_ticket"] = ticket
auth.register_slo(username, request.session.session_key, ticket) auth.register_slo(username, request.session.session_key, ticket)
@ -281,7 +284,7 @@ class LoginView(View, LogoutMixin):
renewed = False renewed = False
warned = False warned = False
if settings.CAS_FEDERATE: # used if CAS_FEDERATE is True
username = None username = None
ticket = None ticket = None
@ -354,7 +357,7 @@ class LoginView(View, LogoutMixin):
elif ret == self.USER_LOGIN_FAILURE: # bad user login elif ret == self.USER_LOGIN_FAILURE: # bad user login
if settings.CAS_FEDERATE: if settings.CAS_FEDERATE:
self.ticket = None self.ticket = None
self.usernalme = None self.username = None
self.init_form() self.init_form()
self.logout() self.logout()
elif ret == self.USER_ALREADY_LOGGED: elif ret == self.USER_ALREADY_LOGGED:
@ -682,11 +685,14 @@ class Auth(View):
secret = request.POST.get('secret') secret = request.POST.get('secret')
if not settings.CAS_AUTH_SHARED_SECRET: if not settings.CAS_AUTH_SHARED_SECRET:
return HttpResponse("no\nplease set CAS_AUTH_SHARED_SECRET", content_type="text/plain") return HttpResponse(
"no\nplease set CAS_AUTH_SHARED_SECRET",
content_type="text/plain; charset=utf-8"
)
if secret != settings.CAS_AUTH_SHARED_SECRET: if secret != settings.CAS_AUTH_SHARED_SECRET:
return HttpResponse("no\n", content_type="text/plain") return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
if not username or not password or not service: if not username or not password or not service:
return HttpResponse("no\n", content_type="text/plain") return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
form = forms.UserCredential( form = forms.UserCredential(
request.POST, request.POST,
initial={ initial={
@ -714,11 +720,11 @@ class Auth(View):
service_pattern.check_user(user) service_pattern.check_user(user)
if not request.session.get("authenticated"): if not request.session.get("authenticated"):
user.delete() user.delete()
return HttpResponse("yes\n", content_type="text/plain") return HttpResponse(u"yes\n", content_type="text/plain; charset=utf-8")
except (ServicePattern.DoesNotExist, models.ServicePatternException): except (ServicePattern.DoesNotExist, models.ServicePatternException):
return HttpResponse("no\n", content_type="text/plain") return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
else: else:
return HttpResponse("no\n", content_type="text/plain") return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
class Validate(View): class Validate(View):
@ -758,7 +764,10 @@ class Validate(View):
username = username[0] username = username[0]
else: else:
username = ticket.user.username username = ticket.user.username
return HttpResponse("yes\n%s\n" % username, content_type="text/plain") return HttpResponse(
u"yes\n%s\n" % username,
content_type="text/plain; charset=utf-8"
)
except ServiceTicket.DoesNotExist: except ServiceTicket.DoesNotExist:
logger.warning( logger.warning(
( (
@ -769,10 +778,10 @@ class Validate(View):
service service
) )
) )
return HttpResponse("no\n", content_type="text/plain") return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
else: else:
logger.warning("Validate: service or ticket missing") logger.warning("Validate: service or ticket missing")
return HttpResponse("no\n", content_type="text/plain") return HttpResponse(u"no\n", content_type="text/plain; charset=utf-8")
class ValidateError(Exception): class ValidateError(Exception):
@ -815,8 +824,8 @@ class ValidateService(View, AttributesMixin):
if not self.service or not self.ticket: if not self.service or not self.ticket:
logger.warning("ValidateService: missing ticket or service") logger.warning("ValidateService: missing ticket or service")
return ValidateError( return ValidateError(
'INVALID_REQUEST', u'INVALID_REQUEST',
"you must specify a service and a ticket" u"you must specify a service and a ticket"
).render(request) ).render(request)
else: else:
try: try:
@ -886,14 +895,14 @@ class ValidateService(View, AttributesMixin):
for prox in ticket.proxies.all(): for prox in ticket.proxies.all():
proxies.append(prox.url) proxies.append(prox.url)
else: else:
raise ValidateError('INVALID_TICKET', self.ticket) raise ValidateError(u'INVALID_TICKET', self.ticket)
ticket.validate = True ticket.validate = True
ticket.save() ticket.save()
if ticket.service != self.service: if ticket.service != self.service:
raise ValidateError('INVALID_SERVICE', self.service) raise ValidateError(u'INVALID_SERVICE', self.service)
return ticket, proxies return ticket, proxies
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist): except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
raise ValidateError('INVALID_TICKET', 'ticket not found') raise ValidateError(u'INVALID_TICKET', 'ticket not found')
def process_pgturl(self, params): def process_pgturl(self, params):
"""Handle PGT request""" """Handle PGT request"""
@ -939,18 +948,18 @@ class ValidateService(View, AttributesMixin):
except requests.exceptions.RequestException as error: except requests.exceptions.RequestException as error:
error = utils.unpack_nested_exception(error) error = utils.unpack_nested_exception(error)
raise ValidateError( raise ValidateError(
'INVALID_PROXY_CALLBACK', u'INVALID_PROXY_CALLBACK',
"%s: %s" % (type(error), str(error)) u"%s: %s" % (type(error), str(error))
) )
else: else:
raise ValidateError( raise ValidateError(
'INVALID_PROXY_CALLBACK', u'INVALID_PROXY_CALLBACK',
"callback url not allowed by configuration" u"callback url not allowed by configuration"
) )
except ServicePattern.DoesNotExist: except ServicePattern.DoesNotExist:
raise ValidateError( raise ValidateError(
'INVALID_PROXY_CALLBACK', u'INVALID_PROXY_CALLBACK',
'callback url not allowed by configuration' u'callback url not allowed by configuration'
) )
@ -971,8 +980,8 @@ class Proxy(View):
return self.process_proxy() return self.process_proxy()
else: else:
raise ValidateError( raise ValidateError(
'INVALID_REQUEST', u'INVALID_REQUEST',
"you must specify and pgt and targetService" u"you must specify and pgt and targetService"
) )
except ValidateError as error: except ValidateError as error:
logger.warning("Proxy: validation error: %s %s" % (error.code, error.msg)) logger.warning("Proxy: validation error: %s %s" % (error.code, error.msg))
@ -985,8 +994,8 @@ class Proxy(View):
pattern = ServicePattern.validate(self.target_service) pattern = ServicePattern.validate(self.target_service)
if not pattern.proxy: if not pattern.proxy:
raise ValidateError( raise ValidateError(
'UNAUTHORIZED_SERVICE', u'UNAUTHORIZED_SERVICE',
'the service %s do not allow proxy ticket' % self.target_service u'the service %s do not allow proxy ticket' % self.target_service
) )
# is the proxy granting ticket valid # is the proxy granting ticket valid
ticket = ProxyGrantingTicket.objects.get( ticket = ProxyGrantingTicket.objects.get(
@ -1015,13 +1024,13 @@ class Proxy(View):
content_type="text/xml; charset=utf-8" content_type="text/xml; charset=utf-8"
) )
except ProxyGrantingTicket.DoesNotExist: except ProxyGrantingTicket.DoesNotExist:
raise ValidateError('INVALID_TICKET', 'PGT %s not found' % self.pgt) raise ValidateError(u'INVALID_TICKET', u'PGT %s not found' % self.pgt)
except ServicePattern.DoesNotExist: except ServicePattern.DoesNotExist:
raise ValidateError('UNAUTHORIZED_SERVICE', self.target_service) raise ValidateError(u'UNAUTHORIZED_SERVICE', self.target_service)
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined): except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
raise ValidateError( raise ValidateError(
'UNAUTHORIZED_USER', u'UNAUTHORIZED_USER',
'User %s not allowed on %s' % (ticket.user.username, self.target_service) u'User %s not allowed on %s' % (ticket.user.username, self.target_service)
) )
@ -1129,18 +1138,18 @@ class SamlValidate(View, AttributesMixin):
) )
else: else:
raise SamlValidateError( raise SamlValidateError(
'AuthnFailed', u'AuthnFailed',
'ticket %s should begin with PT- or ST-' % ticket u'ticket %s should begin with PT- or ST-' % ticket
) )
ticket.validate = True ticket.validate = True
ticket.save() ticket.save()
if ticket.service != self.target: if ticket.service != self.target:
raise SamlValidateError( raise SamlValidateError(
'AuthnFailed', u'AuthnFailed',
'TARGET %s do not match ticket service' % self.target u'TARGET %s do not match ticket service' % self.target
) )
return ticket return ticket
except (IndexError, KeyError): except (IndexError, KeyError):
raise SamlValidateError('VersionMismatch') raise SamlValidateError(u'VersionMismatch')
except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist): except (ServiceTicket.DoesNotExist, ProxyTicket.DoesNotExist):
raise SamlValidateError('AuthnFailed', 'ticket %s not found' % ticket) raise SamlValidateError(u'AuthnFailed', u'ticket %s not found' % ticket)