django-cas-server/cas_server/tests/test_models.py

234 lines
9.6 KiB
Python

# -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Tests module for models"""
from cas_server.default_settings import settings
import django
from django.test import TestCase, Client
from django.test.utils import override_settings
from django.utils import timezone
from datetime import timedelta
from importlib import import_module
from cas_server import models, utils
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern, FederatedIendityProviderModel
from cas_server.tests.test_federate import PROVIDERS
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
class FederatedUserTestCase(TestCase, UserModels, FederatedIendityProviderModel):
"""test for the federated user model"""
def setUp(self):
"""Prepare the test context"""
self.setup_federated_identity_provider(PROVIDERS)
def test_clean_old_entries(self):
"""tests for clean_old_entries that should delete federated user no longer used"""
client = Client()
client.get("/login")
provider = models.FederatedIendityProvider.objects.get(suffix="example.com")
models.FederatedUser.objects.create(
username="test1", provider=provider, attributs={}, ticket=""
)
models.FederatedUser.objects.create(
username="test2", provider=provider, attributs={}, ticket=""
)
models.FederatedUser.objects.all().update(
last_update=(timezone.now() - timedelta(seconds=settings.CAS_TICKET_TIMEOUT + 10))
)
models.FederatedUser.objects.create(
username="test3", provider=provider, attributs={}, ticket=""
)
models.User.objects.create(
username="test1@example.com", session_key=client.session.session_key
)
self.assertEqual(len(models.FederatedUser.objects.all()), 3)
models.FederatedUser.clean_old_entries()
self.assertEqual(len(models.FederatedUser.objects.all()), 2)
with self.assertRaises(models.FederatedUser.DoesNotExist):
models.FederatedUser.objects.get(username="test2")
class FederateSLOTestCase(TestCase, UserModels):
"""test for the federated SLO model"""
def test_clean_deleted_sessions(self):
"""
tests for clean_deleted_sessions that should delete object for which matching session
do not exists anymore
"""
if django.VERSION >= (1, 8):
client1 = Client()
client2 = Client()
client1.get("/login")
client2.get("/login")
session = client2.session
session['authenticated'] = True
session.save()
models.FederateSLO.objects.create(
username="test1@example.com",
session_key=client1.session.session_key,
ticket=utils.gen_st()
)
models.FederateSLO.objects.create(
username="test2@example.com",
session_key=client2.session.session_key,
ticket=utils.gen_st()
)
self.assertEqual(len(models.FederateSLO.objects.all()), 2)
models.FederateSLO.clean_deleted_sessions()
self.assertEqual(len(models.FederateSLO.objects.all()), 1)
with self.assertRaises(models.FederateSLO.DoesNotExist):
models.FederateSLO.objects.get(username="test1@example.com")
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class UserTestCase(TestCase, UserModels):
"""tests for the user models"""
def setUp(self):
"""Prepare the test context"""
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_clean_old_entries(self):
"""test clean_old_entries"""
# get an authenticated client
client = self.expire_user()
# assert the user exists before being cleaned
self.assertEqual(len(models.User.objects.all()), 1)
# assert the last activity date is before the expiry date
self.assertTrue(
self.get_user(client).date < (
timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
)
)
# delete old inactive users
models.User.clean_old_entries()
# assert the user has being well delete
self.assertEqual(len(models.User.objects.all()), 0)
def test_clean_deleted_sessions(self):
"""test clean_deleted_sessions"""
# get an authenticated client
client1 = get_auth_client()
client2 = get_auth_client()
# generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
# on self.service)
ticket = self.get_user(client1).get_ticket(
models.ServiceTicket,
self.service,
self.service_pattern,
renew=False
)
ticket.validate = True
ticket.save()
# simulated expired session being garbage collected for client1
session = SessionStore(session_key=client1.session.session_key)
session.flush()
# assert the user exists before being cleaned
self.assertTrue(self.get_user(client1))
self.assertTrue(self.get_user(client2))
self.assertEqual(len(models.User.objects.all()), 2)
# session has being remove so the user of client1 is no longer authenticated
self.assertFalse(client1.session.get("authenticated"))
# the user a client2 should still be authenticated
self.assertTrue(client2.session.get("authenticated"))
# the user should be deleted
models.User.clean_deleted_sessions()
# assert the user with expired sessions has being well deleted but the other remain
self.assertEqual(len(models.User.objects.all()), 1)
self.assertFalse(models.ServiceTicket.objects.all())
self.assertTrue(client2.session.get("authenticated"))
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class TicketTestCase(TestCase, UserModels, BaseServicePattern):
"""tests for the tickets models"""
def setUp(self):
"""Prepare the test context"""
self.setup_service_patterns()
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
@staticmethod
def get_ticket(
user,
ticket_class,
service,
service_pattern,
renew=False,
validate=False,
validity_expired=False,
timeout_expired=False,
single_log_out=False,
):
"""Return a ticket"""
ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
ticket.validate = validate
ticket.single_log_out = single_log_out
if validity_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
)
if timeout_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
)
ticket.save()
return ticket
def test_clean_old_service_ticket(self):
"""test tickets clean_old_entries"""
# ge an authenticated client
client = get_auth_client()
# get the user associated to the client
user = self.get_user(client)
# generate a ticket for that client, waiting for validation
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
# generate another ticket for those validation time has expired
self.get_ticket(
user, models.ServiceTicket,
self.service, self.service_pattern, validity_expired=True
)
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
# generate a ticket with SLO having timeout reach
self.get_ticket(
user, models.ServiceTicket,
service, self.service_pattern, timeout_expired=True,
validate=True, single_log_out=True
)
# there should be 3 tickets in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
# we call the clean_old_entries method that should delete validated non SLO ticket and
# expired non validated ticket and send SLO for SLO expired ticket before deleting then
models.ServiceTicket.clean_old_entries()
params = httpd.PARAMS
# we successfully got a SLO request
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
# only 1 ticket remain in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)