Models full coverage

This commit is contained in:
Valentin Samir 2016-06-30 13:55:19 +02:00
parent 9814bb0a6f
commit ecf4e66e11
6 changed files with 304 additions and 41 deletions

View File

@ -1,7 +1,9 @@
[run] [run]
branch = True branch = True
source = cas_server source = cas_server
omit = cas_server/migrations* omit =
cas_server/migrations*
cas_server/management/*
[report] [report]
exclude_lines = exclude_lines =

View File

@ -84,7 +84,7 @@ class User(models.Model):
ticket.logout(session, async_list) ticket.logout(session, async_list)
queryset.delete() queryset.delete()
for future in async_list: for future in async_list:
if future: if future: # pragma: no branch (should always be true)
try: try:
future.result() future.result()
except Exception as error: except Exception as error:
@ -112,12 +112,20 @@ class User(models.Model):
(a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all() (a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
) )
replacements = dict( replacements = dict(
(a.name, (a.pattern, a.replace)) for a in service_pattern.replacements.all() (a.attribut, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
) )
service_attributs = {} service_attributs = {}
for (key, value) in self.attributs.items(): for (key, value) in self.attributs.items():
if key in attributs or '*' in attributs: if key in attributs or '*' in attributs:
if key in replacements: if key in replacements:
if isinstance(value, list):
for index, subval in enumerate(value):
value[index] = re.sub(
replacements[key][0],
replacements[key][1],
subval
)
else:
value = re.sub(replacements[key][0], replacements[key][1], value) value = re.sub(replacements[key][0], replacements[key][1], value)
service_attributs[attributs.get(key, key)] = value service_attributs[attributs.get(key, key)] = value
ticket = ticket_class.objects.create( ticket = ticket_class.objects.create(
@ -396,7 +404,6 @@ class Ticket(models.Model):
).delete() ).delete()
# sending SLO to timed-out validated tickets # sending SLO to timed-out validated tickets
if cls.TIMEOUT and cls.TIMEOUT > 0:
async_list = [] async_list = []
session = FuturesSession( session = FuturesSession(
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS) executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
@ -405,10 +412,10 @@ class Ticket(models.Model):
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
) )
for ticket in queryset: for ticket in queryset:
ticket.logout(None, session, async_list) ticket.logout(session, async_list)
queryset.delete() queryset.delete()
for future in async_list: for future in async_list:
if future: if future: # pragma: no branch (should always be true)
try: try:
future.result() future.result()
except Exception as error: except Exception as error:
@ -420,7 +427,7 @@ class Ticket(models.Model):
# On logout invalidate the Ticket # On logout invalidate the Ticket
self.validate = True self.validate = True
self.save() self.save()
if self.validate and self.single_log_out: if self.validate and self.single_log_out: # pragma: no branch (should always be true)
logger.info( logger.info(
"Sending SLO requests to service %s for user %s" % ( "Sending SLO requests to service %s for user %s" % (
self.service, self.service,

View File

@ -1,10 +1,13 @@
"""Some mixin classes for tests""" """Some mixin classes for tests"""
from cas_server.default_settings import settings from cas_server.default_settings import settings
from django.utils import timezone
import re import re
from lxml import etree from lxml import etree
from datetime import timedelta
from cas_server import models from cas_server import models
from cas_server.tests.utils import get_auth_client
class BaseServicePattern(object): class BaseServicePattern(object):
@ -52,6 +55,17 @@ class BaseServicePattern(object):
pattern="^admin$", pattern="^admin$",
service_pattern=self.service_pattern_filter_fail service_pattern=self.service_pattern_filter_fail
) )
self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
name="filter_fail_alt",
pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="nom",
pattern="^toto$",
service_pattern=self.service_pattern_filter_fail_alt
)
self.service_filter_success = "https://filter_success.example.com" self.service_filter_success = "https://filter_success.example.com"
self.service_pattern_filter_success = models.ServicePattern.objects.create( self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success", name="filter_success",
@ -143,3 +157,24 @@ class XmlContent(object):
self.assertEqual(attrs1, original) self.assertEqual(attrs1, original)
return root return root
class UserModels(object):
"""Mixin for test on CAS user models"""
def expire_user(self):
"""return an expired user"""
client = get_auth_client()
new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
models.User.objects.filter(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).update(date=new_date)
return client
def get_user(self, client):
"""return the user associated with an authenticated client"""
return models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)

View File

@ -0,0 +1,146 @@
"""Tests module for models"""
from cas_server.default_settings import settings
from django.test import TestCase
from django.test.utils import override_settings
from django.utils import timezone
from datetime import timedelta
from importlib import import_module
from cas_server import models
from cas_server import utils
from cas_server.tests.utils import get_auth_client
from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class UserTestCase(TestCase, UserModels):
"""tests for the user models"""
def setUp(self):
"""Prepare the test context"""
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_clean_old_entries(self):
"""test clean_old_entries"""
# get an authenticated client
client = self.expire_user()
# assert the user exists before being cleaned
self.assertEqual(len(models.User.objects.all()), 1)
# assert the last activity date is before the expiry date
self.assertTrue(
self.get_user(client).date < (
timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
)
)
# delete old inactive users
models.User.clean_old_entries()
# assert the user has being well delete
self.assertEqual(len(models.User.objects.all()), 0)
def test_clean_deleted_sessions(self):
"""test clean_deleted_sessions"""
# get an authenticated client
client1 = get_auth_client()
client2 = get_auth_client()
# generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
# on self.service)
ticket = self.get_user(client1).get_ticket(
models.ServiceTicket,
self.service,
self.service_pattern,
renew=False
)
ticket.validate = True
ticket.save()
# simulated expired session being garbage collected for client1
session = SessionStore(session_key=client1.session.session_key)
session.flush()
# assert the user exists before being cleaned
self.assertTrue(self.get_user(client1))
self.assertTrue(self.get_user(client2))
self.assertEqual(len(models.User.objects.all()), 2)
# session has being remove so the user of client1 is no longer authenticated
self.assertFalse(client1.session.get("authenticated"))
# the user a client2 should still be authenticated
self.assertTrue(client2.session.get("authenticated"))
# the user should be deleted
models.User.clean_deleted_sessions()
# assert the user with expired sessions has being well deleted but the other remain
self.assertEqual(len(models.User.objects.all()), 1)
self.assertFalse(models.ServiceTicket.objects.all())
self.assertTrue(client2.session.get("authenticated"))
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class TicketTestCase(TestCase, UserModels, BaseServicePattern):
"""tests for the tickets models"""
def setUp(self):
"""Prepare the test context"""
self.setup_service_patterns()
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def get_ticket(
self,
user,
ticket_class,
service,
service_pattern,
renew=False,
validate=False,
validity_expired=False,
timeout_expired=False,
single_log_out=False,
):
"""Return a ticket"""
ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
ticket.validate = validate
ticket.single_log_out = single_log_out
if validity_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
)
if timeout_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
)
ticket.save()
return ticket
def test_clean_old_service_ticket(self):
"""test tickets clean_old_entries"""
client = get_auth_client()
user = self.get_user(client)
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
self.get_ticket(
user, models.ServiceTicket,
self.service, self.service_pattern, validity_expired=True
)
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
self.get_ticket(
user, models.ServiceTicket,
service, self.service_pattern, timeout_expired=True,
validate=True, single_log_out=True
)
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
models.ServiceTicket.clean_old_entries()
params = httpd.PARAMS
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)

View File

@ -240,7 +240,8 @@ class LoginTestCase(TestCase, BaseServicePattern):
"""Test the filtering on user attributes""" """Test the filtering on user attributes"""
client = get_auth_client() client = get_auth_client()
response = client.get("/login", {'service': self.service_filter_fail}) for service in [self.service_filter_fail, self.service_filter_fail_alt]:
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertTrue(b"User charateristics non allowed" in response.content) self.assertTrue(b"User charateristics non allowed" in response.content)
@ -388,6 +389,7 @@ class LoginTestCase(TestCase, BaseServicePattern):
class LogoutTestCase(TestCase): class LogoutTestCase(TestCase):
"""test fot the logout view""" """test fot the logout view"""
def setUp(self): def setUp(self):
"""Prepare the test context"""
self.service = 'http://127.0.0.1:45678' self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create( self.service_pattern = models.ServicePattern.objects.create(
name="localhost", name="localhost",
@ -489,14 +491,24 @@ class LogoutTestCase(TestCase):
def test_logout_slo(self): def test_logout_slo(self):
"""test logout from a service with SLO support""" """test logout from a service with SLO support"""
parameters = []
# test normal SLO
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3] (httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port) service = "http://%s:%s" % (host, port)
(client, ticket) = get_validated_ticket(service)[:2] (client, ticket) = get_validated_ticket(service)[:2]
client.get('/logout') client.get('/logout')
parameters.append((httpd.PARAMS, ticket))
params = httpd.PARAMS # text SLO with a single_log_out_callback
(httpd, host, port) = utils.HttpParamsHandler.run()[0:3]
self.service_pattern.single_log_out_callback = "http://%s:%s" % (host, port)
self.service_pattern.save()
(client, ticket) = get_validated_ticket(self.service)[:2]
client.get('/logout')
parameters.append((httpd.PARAMS, ticket))
for (params, ticket) in parameters:
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest']) self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
root = etree.fromstring(params[b'logoutRequest'][0]) root = etree.fromstring(params[b'logoutRequest'][0])
@ -831,6 +843,39 @@ class ValidateServiceTestCase(TestCase, XmlContent):
service_pattern=self.service_pattern_one_attribute service_pattern=self.service_pattern_one_attribute
) )
self.service_replace_attribute_list = "https://replace_attribute_list.example.com"
self.service_pattern_replace_attribute_list = models.ServicePattern.objects.create(
name="replace_attribute_list",
pattern="^https://replace_attribute_list\.example\.com(/.*)?$",
)
models.ReplaceAttributValue.objects.create(
attribut="alias",
pattern="^demo",
replace="truc",
service_pattern=self.service_pattern_replace_attribute_list
)
models.ReplaceAttributName.objects.create(
name="alias",
replace="ALIAS",
service_pattern=self.service_pattern_replace_attribute_list
)
self.service_replace_attribute = "https://replace_attribute.example.com"
self.service_pattern_replace_attribute = models.ServicePattern.objects.create(
name="replace_attribute",
pattern="^https://replace_attribute\.example\.com(/.*)?$",
)
models.ReplaceAttributValue.objects.create(
attribut="nom",
pattern="N",
replace="P",
service_pattern=self.service_pattern_replace_attribute
)
models.ReplaceAttributName.objects.create(
name="nom",
replace="NOM",
service_pattern=self.service_pattern_replace_attribute
)
def test_validate_service_view_ok(self): def test_validate_service_view_ok(self):
"""test with a valid (ticket, service), the username and all attributes are transmited""" """test with a valid (ticket, service), the username and all attributes are transmited"""
ticket = get_user_ticket_request(self.service)[1] ticket = get_user_ticket_request(self.service)[1]
@ -857,6 +902,32 @@ class ValidateServiceTestCase(TestCase, XmlContent):
{'nom': settings.CAS_TEST_ATTRIBUTES['nom']} {'nom': settings.CAS_TEST_ATTRIBUTES['nom']}
) )
def test_validate_replace_attributes(self):
"""test with a valid (ticket, service), attributes name and value replacement"""
ticket = get_user_ticket_request(self.service_replace_attribute)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service_replace_attribute}
)
self.assert_success(
response,
settings.CAS_TEST_USER,
{'NOM': 'Pymous'}
)
ticket = get_user_ticket_request(self.service_replace_attribute_list)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service_replace_attribute_list}
)
self.assert_success(
response,
settings.CAS_TEST_USER,
{'ALIAS': ['truc1', 'truc2']}
)
def test_validate_service_view_badservice(self): def test_validate_service_view_badservice(self):
"""test with a valid ticket but a bad service, the validatin should fail""" """test with a valid ticket but a bad service, the validatin should fail"""
ticket = get_user_ticket_request(self.service)[1] ticket = get_user_ticket_request(self.service)[1]

View File

@ -37,6 +37,8 @@ def get_auth_client(**update):
params.update(update) params.update(update)
client.post('/login', params) client.post('/login', params)
assert client.session.get("authenticated")
return client return client