Models full coverage
This commit is contained in:
parent
9814bb0a6f
commit
ecf4e66e11
|
@ -1,7 +1,9 @@
|
||||||
[run]
|
[run]
|
||||||
branch = True
|
branch = True
|
||||||
source = cas_server
|
source = cas_server
|
||||||
omit = cas_server/migrations*
|
omit =
|
||||||
|
cas_server/migrations*
|
||||||
|
cas_server/management/*
|
||||||
|
|
||||||
[report]
|
[report]
|
||||||
exclude_lines =
|
exclude_lines =
|
||||||
|
|
|
@ -84,7 +84,7 @@ class User(models.Model):
|
||||||
ticket.logout(session, async_list)
|
ticket.logout(session, async_list)
|
||||||
queryset.delete()
|
queryset.delete()
|
||||||
for future in async_list:
|
for future in async_list:
|
||||||
if future:
|
if future: # pragma: no branch (should always be true)
|
||||||
try:
|
try:
|
||||||
future.result()
|
future.result()
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
|
@ -112,12 +112,20 @@ class User(models.Model):
|
||||||
(a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
|
(a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
|
||||||
)
|
)
|
||||||
replacements = dict(
|
replacements = dict(
|
||||||
(a.name, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
|
(a.attribut, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
|
||||||
)
|
)
|
||||||
service_attributs = {}
|
service_attributs = {}
|
||||||
for (key, value) in self.attributs.items():
|
for (key, value) in self.attributs.items():
|
||||||
if key in attributs or '*' in attributs:
|
if key in attributs or '*' in attributs:
|
||||||
if key in replacements:
|
if key in replacements:
|
||||||
|
if isinstance(value, list):
|
||||||
|
for index, subval in enumerate(value):
|
||||||
|
value[index] = re.sub(
|
||||||
|
replacements[key][0],
|
||||||
|
replacements[key][1],
|
||||||
|
subval
|
||||||
|
)
|
||||||
|
else:
|
||||||
value = re.sub(replacements[key][0], replacements[key][1], value)
|
value = re.sub(replacements[key][0], replacements[key][1], value)
|
||||||
service_attributs[attributs.get(key, key)] = value
|
service_attributs[attributs.get(key, key)] = value
|
||||||
ticket = ticket_class.objects.create(
|
ticket = ticket_class.objects.create(
|
||||||
|
@ -396,7 +404,6 @@ class Ticket(models.Model):
|
||||||
).delete()
|
).delete()
|
||||||
|
|
||||||
# sending SLO to timed-out validated tickets
|
# sending SLO to timed-out validated tickets
|
||||||
if cls.TIMEOUT and cls.TIMEOUT > 0:
|
|
||||||
async_list = []
|
async_list = []
|
||||||
session = FuturesSession(
|
session = FuturesSession(
|
||||||
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
|
||||||
|
@ -405,10 +412,10 @@ class Ticket(models.Model):
|
||||||
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
|
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
|
||||||
)
|
)
|
||||||
for ticket in queryset:
|
for ticket in queryset:
|
||||||
ticket.logout(None, session, async_list)
|
ticket.logout(session, async_list)
|
||||||
queryset.delete()
|
queryset.delete()
|
||||||
for future in async_list:
|
for future in async_list:
|
||||||
if future:
|
if future: # pragma: no branch (should always be true)
|
||||||
try:
|
try:
|
||||||
future.result()
|
future.result()
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
|
@ -420,7 +427,7 @@ class Ticket(models.Model):
|
||||||
# On logout invalidate the Ticket
|
# On logout invalidate the Ticket
|
||||||
self.validate = True
|
self.validate = True
|
||||||
self.save()
|
self.save()
|
||||||
if self.validate and self.single_log_out:
|
if self.validate and self.single_log_out: # pragma: no branch (should always be true)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Sending SLO requests to service %s for user %s" % (
|
"Sending SLO requests to service %s for user %s" % (
|
||||||
self.service,
|
self.service,
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
"""Some mixin classes for tests"""
|
"""Some mixin classes for tests"""
|
||||||
from cas_server.default_settings import settings
|
from cas_server.default_settings import settings
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from lxml import etree
|
from lxml import etree
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
from cas_server import models
|
from cas_server import models
|
||||||
|
from cas_server.tests.utils import get_auth_client
|
||||||
|
|
||||||
|
|
||||||
class BaseServicePattern(object):
|
class BaseServicePattern(object):
|
||||||
|
@ -52,6 +55,17 @@ class BaseServicePattern(object):
|
||||||
pattern="^admin$",
|
pattern="^admin$",
|
||||||
service_pattern=self.service_pattern_filter_fail
|
service_pattern=self.service_pattern_filter_fail
|
||||||
)
|
)
|
||||||
|
self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
|
||||||
|
self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
|
||||||
|
name="filter_fail_alt",
|
||||||
|
pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
|
||||||
|
proxy=proxy,
|
||||||
|
)
|
||||||
|
models.FilterAttributValue.objects.create(
|
||||||
|
attribut="nom",
|
||||||
|
pattern="^toto$",
|
||||||
|
service_pattern=self.service_pattern_filter_fail_alt
|
||||||
|
)
|
||||||
self.service_filter_success = "https://filter_success.example.com"
|
self.service_filter_success = "https://filter_success.example.com"
|
||||||
self.service_pattern_filter_success = models.ServicePattern.objects.create(
|
self.service_pattern_filter_success = models.ServicePattern.objects.create(
|
||||||
name="filter_success",
|
name="filter_success",
|
||||||
|
@ -143,3 +157,24 @@ class XmlContent(object):
|
||||||
self.assertEqual(attrs1, original)
|
self.assertEqual(attrs1, original)
|
||||||
|
|
||||||
return root
|
return root
|
||||||
|
|
||||||
|
|
||||||
|
class UserModels(object):
|
||||||
|
"""Mixin for test on CAS user models"""
|
||||||
|
def expire_user(self):
|
||||||
|
"""return an expired user"""
|
||||||
|
client = get_auth_client()
|
||||||
|
|
||||||
|
new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
|
||||||
|
models.User.objects.filter(
|
||||||
|
username=settings.CAS_TEST_USER,
|
||||||
|
session_key=client.session.session_key
|
||||||
|
).update(date=new_date)
|
||||||
|
return client
|
||||||
|
|
||||||
|
def get_user(self, client):
|
||||||
|
"""return the user associated with an authenticated client"""
|
||||||
|
return models.User.objects.get(
|
||||||
|
username=settings.CAS_TEST_USER,
|
||||||
|
session_key=client.session.session_key
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,146 @@
|
||||||
|
"""Tests module for models"""
|
||||||
|
from cas_server.default_settings import settings
|
||||||
|
|
||||||
|
from django.test import TestCase
|
||||||
|
from django.test.utils import override_settings
|
||||||
|
from django.utils import timezone
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from importlib import import_module
|
||||||
|
|
||||||
|
from cas_server import models
|
||||||
|
from cas_server import utils
|
||||||
|
from cas_server.tests.utils import get_auth_client
|
||||||
|
from cas_server.tests.mixin import UserModels, BaseServicePattern
|
||||||
|
|
||||||
|
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
||||||
|
class UserTestCase(TestCase, UserModels):
|
||||||
|
"""tests for the user models"""
|
||||||
|
def setUp(self):
|
||||||
|
"""Prepare the test context"""
|
||||||
|
self.service = 'http://127.0.0.1:45678'
|
||||||
|
self.service_pattern = models.ServicePattern.objects.create(
|
||||||
|
name="localhost",
|
||||||
|
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
||||||
|
single_log_out=True
|
||||||
|
)
|
||||||
|
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||||
|
|
||||||
|
def test_clean_old_entries(self):
|
||||||
|
"""test clean_old_entries"""
|
||||||
|
# get an authenticated client
|
||||||
|
client = self.expire_user()
|
||||||
|
# assert the user exists before being cleaned
|
||||||
|
self.assertEqual(len(models.User.objects.all()), 1)
|
||||||
|
# assert the last activity date is before the expiry date
|
||||||
|
self.assertTrue(
|
||||||
|
self.get_user(client).date < (
|
||||||
|
timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# delete old inactive users
|
||||||
|
models.User.clean_old_entries()
|
||||||
|
# assert the user has being well delete
|
||||||
|
self.assertEqual(len(models.User.objects.all()), 0)
|
||||||
|
|
||||||
|
def test_clean_deleted_sessions(self):
|
||||||
|
"""test clean_deleted_sessions"""
|
||||||
|
# get an authenticated client
|
||||||
|
client1 = get_auth_client()
|
||||||
|
client2 = get_auth_client()
|
||||||
|
# generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
|
||||||
|
# on self.service)
|
||||||
|
ticket = self.get_user(client1).get_ticket(
|
||||||
|
models.ServiceTicket,
|
||||||
|
self.service,
|
||||||
|
self.service_pattern,
|
||||||
|
renew=False
|
||||||
|
)
|
||||||
|
ticket.validate = True
|
||||||
|
ticket.save()
|
||||||
|
# simulated expired session being garbage collected for client1
|
||||||
|
session = SessionStore(session_key=client1.session.session_key)
|
||||||
|
session.flush()
|
||||||
|
# assert the user exists before being cleaned
|
||||||
|
self.assertTrue(self.get_user(client1))
|
||||||
|
self.assertTrue(self.get_user(client2))
|
||||||
|
self.assertEqual(len(models.User.objects.all()), 2)
|
||||||
|
# session has being remove so the user of client1 is no longer authenticated
|
||||||
|
self.assertFalse(client1.session.get("authenticated"))
|
||||||
|
# the user a client2 should still be authenticated
|
||||||
|
self.assertTrue(client2.session.get("authenticated"))
|
||||||
|
# the user should be deleted
|
||||||
|
models.User.clean_deleted_sessions()
|
||||||
|
# assert the user with expired sessions has being well deleted but the other remain
|
||||||
|
self.assertEqual(len(models.User.objects.all()), 1)
|
||||||
|
self.assertFalse(models.ServiceTicket.objects.all())
|
||||||
|
self.assertTrue(client2.session.get("authenticated"))
|
||||||
|
|
||||||
|
|
||||||
|
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
|
||||||
|
class TicketTestCase(TestCase, UserModels, BaseServicePattern):
|
||||||
|
"""tests for the tickets models"""
|
||||||
|
def setUp(self):
|
||||||
|
"""Prepare the test context"""
|
||||||
|
self.setup_service_patterns()
|
||||||
|
self.service = 'http://127.0.0.1:45678'
|
||||||
|
self.service_pattern = models.ServicePattern.objects.create(
|
||||||
|
name="localhost",
|
||||||
|
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
||||||
|
single_log_out=True
|
||||||
|
)
|
||||||
|
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||||
|
|
||||||
|
def get_ticket(
|
||||||
|
self,
|
||||||
|
user,
|
||||||
|
ticket_class,
|
||||||
|
service,
|
||||||
|
service_pattern,
|
||||||
|
renew=False,
|
||||||
|
validate=False,
|
||||||
|
validity_expired=False,
|
||||||
|
timeout_expired=False,
|
||||||
|
single_log_out=False,
|
||||||
|
):
|
||||||
|
"""Return a ticket"""
|
||||||
|
ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
|
||||||
|
ticket.validate = validate
|
||||||
|
ticket.single_log_out = single_log_out
|
||||||
|
if validity_expired:
|
||||||
|
ticket.creation = min(
|
||||||
|
ticket.creation,
|
||||||
|
(timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
|
||||||
|
)
|
||||||
|
if timeout_expired:
|
||||||
|
ticket.creation = min(
|
||||||
|
ticket.creation,
|
||||||
|
(timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
|
||||||
|
)
|
||||||
|
ticket.save()
|
||||||
|
return ticket
|
||||||
|
|
||||||
|
def test_clean_old_service_ticket(self):
|
||||||
|
"""test tickets clean_old_entries"""
|
||||||
|
client = get_auth_client()
|
||||||
|
user = self.get_user(client)
|
||||||
|
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
|
||||||
|
self.get_ticket(
|
||||||
|
user, models.ServiceTicket,
|
||||||
|
self.service, self.service_pattern, validity_expired=True
|
||||||
|
)
|
||||||
|
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
|
||||||
|
service = "http://%s:%s" % (host, port)
|
||||||
|
self.get_ticket(
|
||||||
|
user, models.ServiceTicket,
|
||||||
|
service, self.service_pattern, timeout_expired=True,
|
||||||
|
validate=True, single_log_out=True
|
||||||
|
)
|
||||||
|
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
|
||||||
|
models.ServiceTicket.clean_old_entries()
|
||||||
|
params = httpd.PARAMS
|
||||||
|
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
|
||||||
|
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)
|
|
@ -240,7 +240,8 @@ class LoginTestCase(TestCase, BaseServicePattern):
|
||||||
"""Test the filtering on user attributes"""
|
"""Test the filtering on user attributes"""
|
||||||
client = get_auth_client()
|
client = get_auth_client()
|
||||||
|
|
||||||
response = client.get("/login", {'service': self.service_filter_fail})
|
for service in [self.service_filter_fail, self.service_filter_fail_alt]:
|
||||||
|
response = client.get("/login", {'service': service})
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
self.assertTrue(b"User charateristics non allowed" in response.content)
|
self.assertTrue(b"User charateristics non allowed" in response.content)
|
||||||
|
|
||||||
|
@ -388,6 +389,7 @@ class LoginTestCase(TestCase, BaseServicePattern):
|
||||||
class LogoutTestCase(TestCase):
|
class LogoutTestCase(TestCase):
|
||||||
"""test fot the logout view"""
|
"""test fot the logout view"""
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
"""Prepare the test context"""
|
||||||
self.service = 'http://127.0.0.1:45678'
|
self.service = 'http://127.0.0.1:45678'
|
||||||
self.service_pattern = models.ServicePattern.objects.create(
|
self.service_pattern = models.ServicePattern.objects.create(
|
||||||
name="localhost",
|
name="localhost",
|
||||||
|
@ -489,14 +491,24 @@ class LogoutTestCase(TestCase):
|
||||||
|
|
||||||
def test_logout_slo(self):
|
def test_logout_slo(self):
|
||||||
"""test logout from a service with SLO support"""
|
"""test logout from a service with SLO support"""
|
||||||
|
parameters = []
|
||||||
|
|
||||||
|
# test normal SLO
|
||||||
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
|
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
|
||||||
service = "http://%s:%s" % (host, port)
|
service = "http://%s:%s" % (host, port)
|
||||||
|
|
||||||
(client, ticket) = get_validated_ticket(service)[:2]
|
(client, ticket) = get_validated_ticket(service)[:2]
|
||||||
|
|
||||||
client.get('/logout')
|
client.get('/logout')
|
||||||
|
parameters.append((httpd.PARAMS, ticket))
|
||||||
|
|
||||||
params = httpd.PARAMS
|
# text SLO with a single_log_out_callback
|
||||||
|
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
|
||||||
|
self.service_pattern.single_log_out_callback = "http://%s:%s" % (host, port)
|
||||||
|
self.service_pattern.save()
|
||||||
|
(client, ticket) = get_validated_ticket(self.service)[:2]
|
||||||
|
client.get('/logout')
|
||||||
|
parameters.append((httpd.PARAMS, ticket))
|
||||||
|
|
||||||
|
for (params, ticket) in parameters:
|
||||||
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
|
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
|
||||||
|
|
||||||
root = etree.fromstring(params[b'logoutRequest'][0])
|
root = etree.fromstring(params[b'logoutRequest'][0])
|
||||||
|
@ -831,6 +843,39 @@ class ValidateServiceTestCase(TestCase, XmlContent):
|
||||||
service_pattern=self.service_pattern_one_attribute
|
service_pattern=self.service_pattern_one_attribute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.service_replace_attribute_list = "https://replace_attribute_list.example.com"
|
||||||
|
self.service_pattern_replace_attribute_list = models.ServicePattern.objects.create(
|
||||||
|
name="replace_attribute_list",
|
||||||
|
pattern="^https://replace_attribute_list\.example\.com(/.*)?$",
|
||||||
|
)
|
||||||
|
models.ReplaceAttributValue.objects.create(
|
||||||
|
attribut="alias",
|
||||||
|
pattern="^demo",
|
||||||
|
replace="truc",
|
||||||
|
service_pattern=self.service_pattern_replace_attribute_list
|
||||||
|
)
|
||||||
|
models.ReplaceAttributName.objects.create(
|
||||||
|
name="alias",
|
||||||
|
replace="ALIAS",
|
||||||
|
service_pattern=self.service_pattern_replace_attribute_list
|
||||||
|
)
|
||||||
|
self.service_replace_attribute = "https://replace_attribute.example.com"
|
||||||
|
self.service_pattern_replace_attribute = models.ServicePattern.objects.create(
|
||||||
|
name="replace_attribute",
|
||||||
|
pattern="^https://replace_attribute\.example\.com(/.*)?$",
|
||||||
|
)
|
||||||
|
models.ReplaceAttributValue.objects.create(
|
||||||
|
attribut="nom",
|
||||||
|
pattern="N",
|
||||||
|
replace="P",
|
||||||
|
service_pattern=self.service_pattern_replace_attribute
|
||||||
|
)
|
||||||
|
models.ReplaceAttributName.objects.create(
|
||||||
|
name="nom",
|
||||||
|
replace="NOM",
|
||||||
|
service_pattern=self.service_pattern_replace_attribute
|
||||||
|
)
|
||||||
|
|
||||||
def test_validate_service_view_ok(self):
|
def test_validate_service_view_ok(self):
|
||||||
"""test with a valid (ticket, service), the username and all attributes are transmited"""
|
"""test with a valid (ticket, service), the username and all attributes are transmited"""
|
||||||
ticket = get_user_ticket_request(self.service)[1]
|
ticket = get_user_ticket_request(self.service)[1]
|
||||||
|
@ -857,6 +902,32 @@ class ValidateServiceTestCase(TestCase, XmlContent):
|
||||||
{'nom': settings.CAS_TEST_ATTRIBUTES['nom']}
|
{'nom': settings.CAS_TEST_ATTRIBUTES['nom']}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_validate_replace_attributes(self):
|
||||||
|
"""test with a valid (ticket, service), attributes name and value replacement"""
|
||||||
|
ticket = get_user_ticket_request(self.service_replace_attribute)[1]
|
||||||
|
client = Client()
|
||||||
|
response = client.get(
|
||||||
|
'/serviceValidate',
|
||||||
|
{'ticket': ticket.value, 'service': self.service_replace_attribute}
|
||||||
|
)
|
||||||
|
self.assert_success(
|
||||||
|
response,
|
||||||
|
settings.CAS_TEST_USER,
|
||||||
|
{'NOM': 'Pymous'}
|
||||||
|
)
|
||||||
|
|
||||||
|
ticket = get_user_ticket_request(self.service_replace_attribute_list)[1]
|
||||||
|
client = Client()
|
||||||
|
response = client.get(
|
||||||
|
'/serviceValidate',
|
||||||
|
{'ticket': ticket.value, 'service': self.service_replace_attribute_list}
|
||||||
|
)
|
||||||
|
self.assert_success(
|
||||||
|
response,
|
||||||
|
settings.CAS_TEST_USER,
|
||||||
|
{'ALIAS': ['truc1', 'truc2']}
|
||||||
|
)
|
||||||
|
|
||||||
def test_validate_service_view_badservice(self):
|
def test_validate_service_view_badservice(self):
|
||||||
"""test with a valid ticket but a bad service, the validatin should fail"""
|
"""test with a valid ticket but a bad service, the validatin should fail"""
|
||||||
ticket = get_user_ticket_request(self.service)[1]
|
ticket = get_user_ticket_request(self.service)[1]
|
||||||
|
|
|
@ -37,6 +37,8 @@ def get_auth_client(**update):
|
||||||
params.update(update)
|
params.update(update)
|
||||||
|
|
||||||
client.post('/login', params)
|
client.post('/login', params)
|
||||||
|
assert client.session.get("authenticated")
|
||||||
|
|
||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue