Merge pull request #5 from nitmir/dev

Unit tests full coverage
This commit is contained in:
Valentin Samir 2016-07-01 01:26:41 +02:00 committed by GitHub
commit eb64412612
30 changed files with 2812 additions and 902 deletions

View File

@ -1,3 +1,11 @@
[run]
branch = True
source = cas_server
omit =
cas_server/migrations*
cas_server/management/*
cas_server/tests/*
[report] [report]
exclude_lines = exclude_lines =
pragma: no cover pragma: no cover
@ -5,3 +13,4 @@ exclude_lines =
def __unicode__ def __unicode__
raise AssertionError raise AssertionError
raise NotImplementedError raise NotImplementedError
if six.PY3:

View File

@ -2,19 +2,19 @@ language: python
python: python:
- "2.7" - "2.7"
env: env:
global:
- PIP_DOWNLOAD_CACHE=$HOME/.pip_cache
matrix: matrix:
- TOX_ENV=coverage
- TOX_ENV=flake8
- TOX_ENV=py27-django17 - TOX_ENV=py27-django17
- TOX_ENV=py27-django18 - TOX_ENV=py27-django18
- TOX_ENV=py27-django19 - TOX_ENV=py27-django19
- TOX_ENV=py34-django17 - TOX_ENV=py34-django17
- TOX_ENV=py34-django18 - TOX_ENV=py34-django18
- TOX_ENV=py34-django19 - TOX_ENV=py34-django19
- TOX_ENV=flake8
cache: cache:
directories: directories:
- $HOME/.pip-cache/ - $HOME/.cache/pip/
- $HOME/build/nitmir/django-cas-server/.tox/
install: install:
- "travis_retry pip install setuptools --upgrade" - "travis_retry pip install setuptools --upgrade"
- "pip install tox" - "pip install tox"
@ -22,4 +22,3 @@ script:
- tox -e $TOX_ENV - tox -e $TOX_ENV
after_script: after_script:
- cat .tox/$TOX_ENV/log/*.log - cat .tox/$TOX_ENV/log/*.log

View File

@ -1,11 +1,15 @@
.PHONY: clean build install dist test_venv test_project .PHONY: build dist
VERSION=`python setup.py -V` VERSION=`python setup.py -V`
build: build:
python setup.py build python setup.py build
install: install: dist
python setup.py install pip -V
pip install --no-deps --upgrade --force-reinstall --find-links ./dist/django-cas-server-${VERSION}.tar.gz django-cas-server
uninstall:
pip uninstall django-cas-server || true
clean_pyc: clean_pyc:
find ./ -name '*.pyc' -delete find ./ -name '*.pyc' -delete
@ -16,18 +20,23 @@ clean_tox:
rm -rf .tox rm -rf .tox
clean_test_venv: clean_test_venv:
rm -rf test_venv rm -rf test_venv
clean: clean_pyc clean_build clean_coverage:
clean_all: clean_pyc clean_build clean_tox clean_test_venv rm -rf coverage.xml .coverage htmlcov
clean_tild_backup:
find ./ -name '*~' -delete
clean: clean_pyc clean_build clean_coverage clean_tild_backup
clean_all: clean clean_tox clean_test_venv
dist: dist:
python setup.py sdist python setup.py sdist
test_venv: test_venv/bin/python:
mkdir -p test_venv
virtualenv test_venv virtualenv test_venv
test_venv/bin/pip install -U --requirement requirements.txt test_venv/bin/pip install -U --requirement requirements-dev.txt Django
test_venv/cas/manage.py: test_venv/cas/manage.py: test_venv
mkdir -p test_venv/cas mkdir -p test_venv/cas
test_venv/bin/django-admin startproject cas test_venv/cas test_venv/bin/django-admin startproject cas test_venv/cas
ln -s ../../cas_server test_venv/cas/cas_server ln -s ../../cas_server test_venv/cas/cas_server
@ -38,19 +47,15 @@ test_venv/cas/manage.py:
test_venv/bin/python test_venv/cas/manage.py migrate test_venv/bin/python test_venv/cas/manage.py migrate
test_venv/bin/python test_venv/cas/manage.py createsuperuser test_venv/bin/python test_venv/cas/manage.py createsuperuser
test_project: test_venv test_venv/cas/manage.py test_venv: test_venv/bin/python
test_project: test_venv/cas/manage.py
@echo "##############################################################" @echo "##############################################################"
@echo "A test django project was created in $(realpath test_venv/cas)" @echo "A test django project was created in $(realpath test_venv/cas)"
run_test_server: test_project run_server: test_project
test_venv/bin/python test_venv/cas/manage.py runserver test_venv/bin/python test_venv/cas/manage.py runserver
coverage: test_venv run_tests: test_venv
test_venv/bin/pip install coverage test_venv/bin/py.test --cov=cas_server --cov-report html
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts
test_venv/bin/coverage html
test_venv/bin/coverage xml
coverage_codacy: coverage
test_venv/bin/pip install codacy-coverage
test_venv/bin/python-codacy-coverage -r coverage.xml

View File

@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``. * ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``. * ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is * ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``. ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']}``.
Authentication backend Authentication backend

View File

@ -7,6 +7,6 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""A django CAS server application"""
default_app_config = 'cas_server.apps.CasAppConfig' default_app_config = 'cas_server.apps.CasAppConfig'

View File

@ -7,7 +7,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""module for the admin interface of the app""" """module for the admin interface of the app"""
from django.contrib import admin from django.contrib import admin
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern

View File

@ -1,7 +1,19 @@
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2015-2016 Valentin Samir
"""django config module"""
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.apps import AppConfig from django.apps import AppConfig
class CasAppConfig(AppConfig): class CasAppConfig(AppConfig):
"""django CAS application config class"""
name = 'cas_server' name = 'cas_server'
verbose_name = _('Central Authentication Service') verbose_name = _('Central Authentication Service')

View File

@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""Some authentication classes for the CAS""" """Some authentication classes for the CAS"""
from django.conf import settings from django.conf import settings
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
@ -21,6 +21,7 @@ except ImportError:
class AuthUser(object): class AuthUser(object):
"""Authentication base class"""
def __init__(self, username): def __init__(self, username):
self.username = username self.username = username

View File

@ -78,5 +78,12 @@ setting_default('CAS_TEST_USER', 'test')
setting_default('CAS_TEST_PASSWORD', 'test') setting_default('CAS_TEST_PASSWORD', 'test')
setting_default( setting_default(
'CAS_TEST_ATTRIBUTES', 'CAS_TEST_ATTRIBUTES',
{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} {
'nom': 'Nymous',
'prenom': 'Ano',
'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']
}
) )
setting_default('CAS_ENABLE_AJAX_AUTH', False)

View File

@ -7,7 +7,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""forms for the app""" """forms for the app"""
from .default_settings import settings from .default_settings import settings
@ -19,6 +19,7 @@ import cas_server.models as models
class WarnForm(forms.Form): class WarnForm(forms.Form):
"""Form used on warn page before emiting a ticket"""
service = forms.CharField(widget=forms.HiddenInput(), required=False) service = forms.CharField(widget=forms.HiddenInput(), required=False)
renew = forms.BooleanField(widget=forms.HiddenInput(), required=False) renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
gateway = forms.CharField(widget=forms.HiddenInput(), required=False) gateway = forms.CharField(widget=forms.HiddenInput(), required=False)
@ -35,6 +36,7 @@ class UserCredential(forms.Form):
lt = forms.CharField(widget=forms.HiddenInput(), required=False) lt = forms.CharField(widget=forms.HiddenInput(), required=False)
method = forms.CharField(widget=forms.HiddenInput(), required=False) method = forms.CharField(widget=forms.HiddenInput(), required=False)
warn = forms.BooleanField(label=_('warn'), required=False) warn = forms.BooleanField(label=_('warn'), required=False)
renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(UserCredential, self).__init__(*args, **kwargs) super(UserCredential, self).__init__(*args, **kwargs)
@ -46,6 +48,7 @@ class UserCredential(forms.Form):
cleaned_data["username"] = auth.username cleaned_data["username"] = auth.username
else: else:
raise forms.ValidationError(_(u"Bad user")) raise forms.ValidationError(_(u"Bad user"))
return cleaned_data
class TicketForm(forms.ModelForm): class TicketForm(forms.ModelForm):

View File

@ -1,3 +1,4 @@
"""Clean deleted sessions management command"""
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -5,6 +6,7 @@ from ... import models
class Command(BaseCommand): class Command(BaseCommand):
"""Clean deleted sessions"""
args = '' args = ''
help = _(u"Clean deleted sessions") help = _(u"Clean deleted sessions")

View File

@ -1,3 +1,4 @@
"""Clean old trickets management command"""
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -5,6 +6,7 @@ from ... import models
class Command(BaseCommand): class Command(BaseCommand):
"""Clean old trickets"""
args = '' args = ''
help = _(u"Clean old trickets") help = _(u"Clean old trickets")

View File

@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""models for the app""" """models for the app"""
from .default_settings import settings from .default_settings import settings
@ -20,7 +20,6 @@ from django.utils import timezone
from picklefield.fields import PickledObjectField from picklefield.fields import PickledObjectField
import re import re
import os
import sys import sys
import logging import logging
from importlib import import_module from importlib import import_module
@ -47,6 +46,7 @@ class User(models.Model):
@classmethod @classmethod
def clean_old_entries(cls): def clean_old_entries(cls):
"""Remove users inactive since more that SESSION_COOKIE_AGE"""
users = cls.objects.filter( users = cls.objects.filter(
date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)) date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE))
) )
@ -56,6 +56,7 @@ class User(models.Model):
@classmethod @classmethod
def clean_deleted_sessions(cls): def clean_deleted_sessions(cls):
"""Remove user where the session do not exists anymore"""
for user in cls.objects.all(): for user in cls.objects.all():
if not SessionStore(session_key=user.session_key).get('authenticated'): if not SessionStore(session_key=user.session_key).get('authenticated'):
user.logout() user.logout()
@ -80,10 +81,10 @@ class User(models.Model):
for ticket_class in ticket_classes: for ticket_class in ticket_classes:
queryset = ticket_class.objects.filter(user=self) queryset = ticket_class.objects.filter(user=self)
for ticket in queryset: for ticket in queryset:
ticket.logout(request, session, async_list) ticket.logout(session, async_list)
queryset.delete() queryset.delete()
for future in async_list: for future in async_list:
if future: if future: # pragma: no branch (should always be true)
try: try:
future.result() future.result()
except Exception as error: except Exception as error:
@ -111,13 +112,21 @@ class User(models.Model):
(a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all() (a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
) )
replacements = dict( replacements = dict(
(a.name, (a.pattern, a.replace)) for a in service_pattern.replacements.all() (a.attribut, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
) )
service_attributs = {} service_attributs = {}
for (key, value) in self.attributs.items(): for (key, value) in self.attributs.items():
if key in attributs or '*' in attributs: if key in attributs or '*' in attributs:
if key in replacements: if key in replacements:
value = re.sub(replacements[key][0], replacements[key][1], value) if isinstance(value, list):
for index, subval in enumerate(value):
value[index] = re.sub(
replacements[key][0],
replacements[key][1],
subval
)
else:
value = re.sub(replacements[key][0], replacements[key][1], value)
service_attributs[attributs.get(key, key)] = value service_attributs[attributs.get(key, key)] = value
ticket = ticket_class.objects.create( ticket = ticket_class.objects.create(
user=self, user=self,
@ -141,6 +150,7 @@ class User(models.Model):
class ServicePatternException(Exception): class ServicePatternException(Exception):
"""Base exception of exceptions raised in the ServicePattern model"""
pass pass
@ -394,77 +404,57 @@ class Ticket(models.Model):
).delete() ).delete()
# sending SLO to timed-out validated tickets # sending SLO to timed-out validated tickets
if cls.TIMEOUT and cls.TIMEOUT > 0: async_list = []
async_list = [] session = FuturesSession(
session = FuturesSession( executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS) )
) queryset = cls.objects.filter(
queryset = cls.objects.filter( creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT)) )
) for ticket in queryset:
for ticket in queryset: ticket.logout(session, async_list)
ticket.logout(None, session, async_list) queryset.delete()
queryset.delete() for future in async_list:
for future in async_list: if future: # pragma: no branch (should always be true)
if future: try:
try: future.result()
future.result() except Exception as error:
except Exception as error: logger.warning("Error durring SLO %s" % error)
logger.warning("Error durring SLO %s" % error) sys.stderr.write("%r\n" % error)
sys.stderr.write("%r\n" % error)
def logout(self, request, session, async_list=None): def logout(self, session, async_list=None):
"""Send a SLO request to the ticket service""" """Send a SLO request to the ticket service"""
# On logout invalidate the Ticket # On logout invalidate the Ticket
self.validate = True self.validate = True
self.save() self.save()
if self.validate and self.single_log_out: if self.validate and self.single_log_out: # pragma: no branch (should always be true)
logger.info( logger.info(
"Sending SLO requests to service %s for user %s" % ( "Sending SLO requests to service %s for user %s" % (
self.service, self.service,
self.user.username self.user.username
) )
) )
try: xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol" ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s"> <saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID> <samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex> </samlp:LogoutRequest>""" % \
</samlp:LogoutRequest>""" % \ {
{ 'id': utils.gen_saml_id(),
'id': os.urandom(20).encode("hex"), 'datetime': timezone.now().isoformat(),
'datetime': timezone.now().isoformat(), 'ticket': self.value
'ticket': self.value }
} if self.service_pattern.single_log_out_callback:
if self.service_pattern.single_log_out_callback: url = self.service_pattern.single_log_out_callback
url = self.service_pattern.single_log_out_callback else:
else: url = self.service
url = self.service async_list.append(
async_list.append( session.post(
session.post( url.encode('utf-8'),
url.encode('utf-8'), data={'logoutRequest': xml.encode('utf-8')},
data={'logoutRequest': xml.encode('utf-8')}, timeout=settings.CAS_SLO_TIMEOUT
timeout=settings.CAS_SLO_TIMEOUT
)
) )
except Exception as error: )
error = utils.unpack_nested_exception(error)
logger.warning(
"Error durring SLO for user %s on service %s: %s" % (
self.user.username,
self.service,
error
)
)
if request is not None:
messages.add_message(
request,
messages.WARNING,
_(u'Error during service logout %(service)s:\n%(error)s') %
{'service': self.service, 'error': error}
)
else:
sys.stderr.write("%r\n" % error)
class ServiceTicket(Ticket): class ServiceTicket(Ticket):

View File

@ -1,702 +0,0 @@
from .default_settings import settings
from django.test import TestCase
from django.test import Client
import six
from lxml import etree
from cas_server import models
from cas_server import utils
def get_login_page_params():
client = Client()
response = client.get('/login')
form = response.context["form"]
params = {}
for field in form:
if field.value():
params[field.name] = field.value()
else:
params[field.name] = ""
return client, params
def get_auth_client():
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
client.post('/login', params)
return client
def get_user_ticket_request(service):
client = get_auth_client()
response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
return (user, ticket)
def get_pgt():
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)
client = Client()
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
params = utils.PGTUrlHandler.PARAMS.copy()
params["service"] = service
params["user"] = user
return params
class CheckPasswordCase(TestCase):
"""Tests for the utils function `utils.check_password`"""
def setUp(self):
"""Generate random bytes string that will be used ass passwords"""
self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes):
self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8")
def test_setup(self):
"""check that generated password are bytes"""
self.assertIsInstance(self.password1, bytes)
self.assertIsInstance(self.password2, bytes)
def test_plain(self):
"""test the plain auth method"""
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
def test_crypt(self):
"""test the crypt auth method"""
if six.PY3:
hashed_password1 = utils.crypt.crypt(
self.password1.decode("utf8"),
"$6$UVVAQvrMyXMF3FF3"
).encode("utf8")
else:
hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
def test_ldap_ssha(self):
"""test the ldap auth method with a {SSHA} scheme"""
salt = b"UVVAQvrMyXMF3FF3"
hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
self.assertIsInstance(hashed_password1, bytes)
self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
def test_hex_md5(self):
"""test the hex_md5 auth method"""
hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
def test_hox_sha512(self):
"""test the hex_sha512 auth method"""
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
self.assertTrue(
utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
)
self.assertFalse(
utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
)
class LoginTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$",
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_login_view_post_goodpass_goodlt(self):
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
response = client.post('/login', params)
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
)
def test_login_view_post_badlt(self):
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["lt"] = 'LT-random'
response = client.post('/login', params)
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Invalid login ticket" in response.content)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_login_view_post_badpass_good_lt(self):
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = "test2"
response = client.post('/login', params)
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"The credentials you provided cannot be "
b"determined to be authentic"
) in response.content
)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_view_login_get_auth_allowed_service(self):
client = get_auth_client()
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header('Location'))
self.assertTrue(
response['Location'].startswith(
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
)
)
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
self.assertTrue(user)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
self.assertEqual(ticket.user, user)
self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
self.assertEqual(ticket.validate, False)
self.assertEqual(ticket.service_pattern, self.service_pattern)
def test_view_login_get_auth_denied_service(self):
client = get_auth_client()
response = client.get("/login?service=https://www.example.org")
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
class LogoutTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
def test_logout_view(self):
client = get_auth_client()
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
response = client.get("/logout")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged out from "
b"the Central Authentication Service"
) in response.content
)
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_view_url(self):
client = get_auth_client()
response = client.get('/logout?url=https://www.example.com')
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header("Location"))
self.assertEqual(response["Location"], "https://www.example.com")
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_view_service(self):
client = get_auth_client()
response = client.get('/logout?service=https://www.example.com')
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header("Location"))
self.assertEqual(response["Location"], "https://www.example.com")
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
class AuthTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'https://www.example.com'
models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
def test_auth_view_goodpass(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\n')
def test_auth_view_badpass(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': 'badpass',
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badservice(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': 'https://www.example.org',
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badsecret(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'badsecret'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badsettings(self):
settings.CAS_AUTH_SHARED_SECRET = None
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
class ValidateTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'https://www.example.com'
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_view_ok(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\ntest\n')
def test_validate_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/validate',
{'ticket': ticket.value, 'service': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_validate_view_badticket(self):
get_user_ticket_request(self.service)
client = Client()
response = client.get(
'/validate',
{'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
class ValidateServiceTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_service_view_ok(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = {}
for attr in attributes[0]:
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = {}
for attr in attributes:
attrs2[attr.attrib['name']] = attr.attrib['value']
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
def test_validate_service_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
bad_service = "https://www.example.org"
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE")
self.assertEqual(error[0].text, bad_service)
def test_validate_service_view_badticket_goodprefix(self):
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(error[0].text, 'ticket not found')
def test_validate_service_view_badticket_badprefix(self):
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "RANDOM"
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(error[0].text, bad_ticket)
def test_validate_service_view_ok_pgturl(self):
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
)
pgt_params = utils.PGTUrlHandler.PARAMS.copy()
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
pgtiou = root.xpath(
"//cas:proxyGrantingTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(pgtiou), 1)
self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
self.assertTrue("pgtId" in pgt_params)
def test_validate_service_pgturl_bad_proxy_callback(self):
self.service_pattern.proxy_callback = False
self.service_pattern.save()
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
self.assertEqual(error[0].text, "callback url not allowed by configuration")
class ProxyTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'http://127.0.0.1'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy=True,
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_proxy_ok(self):
params = get_pgt()
# get a proxy ticket
client1 = Client()
response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertTrue(sucess)
proxy_ticket = root.xpath(
"//cas:proxyTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(proxy_ticket), 1)
proxy_ticket = proxy_ticket[0].text
# validate the proxy ticket
client2 = Client()
response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
# check that the proxy is send to the end service
proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxies), 1)
proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxy), 1)
self.assertEqual(proxy[0].text, params["service"])
# same tests than those for serviceValidate
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = {}
for attr in attributes[0]:
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = {}
for attr in attributes:
attrs2[attr.attrib['name']] = attr.attrib['value']
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
def test_validate_proxy_bad(self):
params = get_pgt()
# bad PGT
client1 = Client()
response = client1.get(
'/proxy',
{
'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
'targetService': params['service']
}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(
error[0].text,
"PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
)
# bad targetService
client2 = Client()
response = client2.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
self.assertEqual(error[0].text, "https://www.example.org")
# service do not allow proxy ticket
self.service_pattern.proxy = False
self.service_pattern.save()
client3 = Client()
response = client3.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': params['service']}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
self.assertEqual(
error[0].text,
'the service %s do not allow proxy ticket' % params['service']
)

View File

193
cas_server/tests/mixin.py Normal file
View File

@ -0,0 +1,193 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Some mixin classes for tests"""
from cas_server.default_settings import settings
from django.utils import timezone
import re
from lxml import etree
from datetime import timedelta
from cas_server import models
from cas_server.tests.utils import get_auth_client
class BaseServicePattern(object):
"""Mixing for setting up service pattern for testing"""
def setup_service_patterns(self, proxy=False):
"""setting up service pattern"""
# For general purpose testing
self.service = "https://www.example.com"
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$",
proxy=proxy,
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
# For testing the restrict_users attributes
self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
name="restrict_user_fail",
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
restrict_users=True,
proxy=proxy,
)
self.service_restrict_user_success = "https://restrict_user_success.example.com"
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
name="restrict_user_success",
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
restrict_users=True,
proxy=proxy,
)
models.Username.objects.create(
value=settings.CAS_TEST_USER,
service_pattern=self.service_pattern_restrict_user_success
)
# For testing the user attributes filtering conditions
self.service_filter_fail = "https://filter_fail.example.com"
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
name="filter_fail",
pattern="^https://filter_fail\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="right",
pattern="^admin$",
service_pattern=self.service_pattern_filter_fail
)
self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
name="filter_fail_alt",
pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="nom",
pattern="^toto$",
service_pattern=self.service_pattern_filter_fail_alt
)
self.service_filter_success = "https://filter_success.example.com"
self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success",
pattern="^https://filter_success\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="email",
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
service_pattern=self.service_pattern_filter_success
)
# For testing the user_field attributes
self.service_field_needed_fail = "https://field_needed_fail.example.com"
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
name="field_needed_fail",
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
user_field="uid",
proxy=proxy,
)
self.service_field_needed_success = "https://field_needed_success.example.com"
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success",
pattern="^https://field_needed_success\.example\.com(/.*)?$",
user_field="alias",
proxy=proxy,
)
self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success_alt",
pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
user_field="nom",
proxy=proxy,
)
class XmlContent(object):
"""Mixin for test on CAS XML responses"""
def assert_error(self, response, code, text=None):
"""Assert a validation error"""
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], code)
if text is not None:
self.assertEqual(error[0].text, text)
def assert_success(self, response, username, original_attributes):
"""assert a ticket validation success"""
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, username)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in original_attributes.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
return root
class UserModels(object):
"""Mixin for test on CAS user models"""
@staticmethod
def expire_user():
"""return an expired user"""
client = get_auth_client()
new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
models.User.objects.filter(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).update(date=new_date)
return client
@staticmethod
def get_user(client):
"""return the user associated with an authenticated client"""
return models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)

View File

@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
'django.middleware.locale.LocaleMiddleware', 'django.middleware.locale.LocaleMiddleware',
] ]
ROOT_URLCONF = 'cas_server.urls' ROOT_URLCONF = 'cas_server.tests.urls'
# Database # Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases # https://docs.djangoproject.com/en/1.9/ref/settings/#databases
@ -60,6 +60,7 @@ ROOT_URLCONF = 'cas_server.urls'
DATABASES = { DATABASES = {
'default': { 'default': {
'ENGINE': 'django.db.backends.sqlite3', 'ENGINE': 'django.db.backends.sqlite3',
'NAME': ':memory:',
} }
} }

View File

@ -0,0 +1,166 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Tests module for models"""
from cas_server.default_settings import settings
from django.test import TestCase
from django.test.utils import override_settings
from django.utils import timezone
from datetime import timedelta
from importlib import import_module
from cas_server import models
from cas_server.tests.utils import get_auth_client, HttpParamsHandler
from cas_server.tests.mixin import UserModels, BaseServicePattern
SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class UserTestCase(TestCase, UserModels):
"""tests for the user models"""
def setUp(self):
"""Prepare the test context"""
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_clean_old_entries(self):
"""test clean_old_entries"""
# get an authenticated client
client = self.expire_user()
# assert the user exists before being cleaned
self.assertEqual(len(models.User.objects.all()), 1)
# assert the last activity date is before the expiry date
self.assertTrue(
self.get_user(client).date < (
timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
)
)
# delete old inactive users
models.User.clean_old_entries()
# assert the user has being well delete
self.assertEqual(len(models.User.objects.all()), 0)
def test_clean_deleted_sessions(self):
"""test clean_deleted_sessions"""
# get an authenticated client
client1 = get_auth_client()
client2 = get_auth_client()
# generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
# on self.service)
ticket = self.get_user(client1).get_ticket(
models.ServiceTicket,
self.service,
self.service_pattern,
renew=False
)
ticket.validate = True
ticket.save()
# simulated expired session being garbage collected for client1
session = SessionStore(session_key=client1.session.session_key)
session.flush()
# assert the user exists before being cleaned
self.assertTrue(self.get_user(client1))
self.assertTrue(self.get_user(client2))
self.assertEqual(len(models.User.objects.all()), 2)
# session has being remove so the user of client1 is no longer authenticated
self.assertFalse(client1.session.get("authenticated"))
# the user a client2 should still be authenticated
self.assertTrue(client2.session.get("authenticated"))
# the user should be deleted
models.User.clean_deleted_sessions()
# assert the user with expired sessions has being well deleted but the other remain
self.assertEqual(len(models.User.objects.all()), 1)
self.assertFalse(models.ServiceTicket.objects.all())
self.assertTrue(client2.session.get("authenticated"))
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class TicketTestCase(TestCase, UserModels, BaseServicePattern):
"""tests for the tickets models"""
def setUp(self):
"""Prepare the test context"""
self.setup_service_patterns()
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
single_log_out=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
@staticmethod
def get_ticket(
user,
ticket_class,
service,
service_pattern,
renew=False,
validate=False,
validity_expired=False,
timeout_expired=False,
single_log_out=False,
):
"""Return a ticket"""
ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
ticket.validate = validate
ticket.single_log_out = single_log_out
if validity_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
)
if timeout_expired:
ticket.creation = min(
ticket.creation,
(timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
)
ticket.save()
return ticket
def test_clean_old_service_ticket(self):
"""test tickets clean_old_entries"""
# ge an authenticated client
client = get_auth_client()
# get the user associated to the client
user = self.get_user(client)
# generate a ticket for that client, waiting for validation
self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
# generate another ticket for those validation time has expired
self.get_ticket(
user, models.ServiceTicket,
self.service, self.service_pattern, validity_expired=True
)
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
# generate a ticket with SLO having timeout reach
self.get_ticket(
user, models.ServiceTicket,
service, self.service_pattern, timeout_expired=True,
validate=True, single_log_out=True
)
# there should be 3 tickets in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
# we call the clean_old_entries method that should delete validated non SLO ticket and
# expired non validated ticket and send SLO for SLO expired ticket before deleting then
models.ServiceTicket.clean_old_entries()
params = httpd.PARAMS
# we successfully got a SLO request
self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
# only 1 ticket remain in the db
self.assertEqual(len(models.ServiceTicket.objects.all()), 1)

View File

@ -0,0 +1,191 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Tests module for utils"""
from django.test import TestCase
import six
from cas_server import utils
class CheckPasswordCase(TestCase):
"""Tests for the utils function `utils.check_password`"""
def setUp(self):
"""Generate random bytes string that will be used ass passwords"""
self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8")
def test_setup(self):
"""check that generated password are bytes"""
self.assertIsInstance(self.password1, bytes)
self.assertIsInstance(self.password2, bytes)
def test_plain(self):
"""test the plain auth method"""
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
def test_plain_unicode(self):
"""test the plain auth method with unicode input"""
self.assertTrue(
utils.check_password(
"plain",
self.password1.decode("utf8"),
self.password1.decode("utf8"),
"utf8"
)
)
self.assertFalse(
utils.check_password(
"plain",
self.password1.decode("utf8"),
self.password2.decode("utf8"),
"utf8"
)
)
def test_crypt(self):
"""test the crypt auth method"""
salts = ["$6$UVVAQvrMyXMF3FF3", "aa"]
hashed_password1 = []
for salt in salts:
if six.PY3:
hashed_password1.append(
utils.crypt.crypt(
self.password1.decode("utf8"),
salt
).encode("utf8")
)
else:
hashed_password1.append(utils.crypt.crypt(self.password1, salt))
for hp1 in hashed_password1:
self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8"))
self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8"))
with self.assertRaises(ValueError):
utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8")
def test_ldap_password_valid(self):
"""test the ldap auth method with all the schemes"""
salt = b"UVVAQvrMyXMF3FF3"
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
hashed_password1 = []
for scheme in schemes_salt:
hashed_password1.append(
utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8")
)
for scheme in schemes_nosalt:
hashed_password1.append(
utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8")
)
hashed_password1.append(
utils.LdapHashUserPassword.hash(
b"{CRYPT}",
self.password1,
b"$6$UVVAQvrMyXMF3FF3",
charset="utf8"
)
)
for hp1 in hashed_password1:
self.assertIsInstance(hp1, bytes)
self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8"))
self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8"))
def test_ldap_password_fail(self):
"""test the ldap auth method with malformed hash or bad schemes"""
salt = b"UVVAQvrMyXMF3FF3"
schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
# first try to hash with bad parameters
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
utils.LdapHashUserPassword.hash(b"TOTO", self.password1)
for scheme in schemes_nosalt:
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
utils.LdapHashUserPassword.hash(scheme, self.password1, salt)
for scheme in schemes_salt:
with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
utils.LdapHashUserPassword.hash(scheme, self.password1)
with self.assertRaises(utils.LdapHashUserPassword.BadSalt):
utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, b"$truc$toto")
# then try to check hash with bad hashes
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8")
for scheme in schemes_salt:
with self.assertRaises(utils.LdapHashUserPassword.BadHash):
utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8")
def test_hex(self):
"""test all the hex_HASH method: the hashed password is a simple hash of the password"""
hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]
hashed_password1 = []
for hash in hashes:
hashed_password1.append(
("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest())
)
for (method, hp1) in hashed_password1:
self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8"))
self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8"))
def test_bad_method(self):
"""try to check password with a bad method, should raise a ValueError"""
with self.assertRaises(ValueError):
utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8")
class UtilsTestCase(TestCase):
"""tests for some little utils functions"""
def test_import_attr(self):
"""
test the import_attr function. Feeded with a dotted path string, it should
import the dotted module and return that last componend of the dotted path
(function, class or variable)
"""
with self.assertRaises(ImportError):
utils.import_attr('toto.titi.tutu')
with self.assertRaises(AttributeError):
utils.import_attr('cas_server.utils.toto')
with self.assertRaises(ValueError):
utils.import_attr('toto')
self.assertEqual(
utils.import_attr('cas_server.default_app_config'),
'cas_server.apps.CasAppConfig'
)
self.assertEqual(utils.import_attr(utils), utils)
def test_update_url(self):
"""
test the update_url function. Given an url with possible GET parameter and a dict
the function build a url with GET parameters updated by the dictionnary
"""
url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"})
url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"})
self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1")
self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1")
url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"})
self.assertEqual(url3, u"https://www.example.com?toto=2")
def test_crypt_salt_is_valid(self):
"""test the function crypt_salt_is_valid who test if a crypt salt is valid"""
self.assertFalse(utils.crypt_salt_is_valid("")) # len 0
self.assertFalse(utils.crypt_salt_is_valid("a")) # len 1
self.assertFalse(utils.crypt_salt_is_valid("$$")) # start with $ followed by $
self.assertFalse(utils.crypt_salt_is_valid("$toto")) # start with $ but no secondary $
self.assertFalse(utils.crypt_salt_is_valid("$toto$toto")) # algorithm toto not known

File diff suppressed because it is too large Load Diff

22
cas_server/tests/urls.py Normal file
View File

@ -0,0 +1,22 @@
"""cas URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/1.9/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.conf.urls import url, include, include
2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
"""
from django.conf.urls import url, include
from django.contrib import admin
urlpatterns = [
url(r'^admin/', admin.site.urls),
url(r'^', include('cas_server.urls', namespace='cas_server')),
]

180
cas_server/tests/utils.py Normal file
View File

@ -0,0 +1,180 @@
# ⁻*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
# more details.
#
# You should have received a copy of the GNU General Public License version 3
# along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# (c) 2016 Valentin Samir
"""Some utils functions for tests"""
from cas_server.default_settings import settings
from django.test import Client
import cgi
from threading import Thread
from lxml import etree
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, parse_qsl
from cas_server import models
def copy_form(form):
"""Copy form value into a dict"""
params = {}
for field in form:
if field.value():
params[field.name] = field.value()
else:
params[field.name] = ""
return params
def get_login_page_params(client=None):
"""Return a client and the POST params for the client to login"""
if client is None:
client = Client()
response = client.get('/login')
params = copy_form(response.context["form"])
return client, params
def get_auth_client(**update):
"""return a authenticated client"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params.update(update)
client.post('/login', params)
assert client.session.get("authenticated")
return client
def get_user_ticket_request(service):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client = get_auth_client()
response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
return (user, ticket, client)
def get_validated_ticket(service):
"""Return a tick that has being already validated. Used to test SLO"""
(ticket, auth_client) = get_user_ticket_request(service)[1:3]
client = Client()
response = client.get('/validate', {'ticket': ticket.value, 'service': service})
assert (response.status_code == 200)
assert (response.content == b'yes\ntest\n')
ticket = models.ServiceTicket.objects.get(value=ticket.value)
return (auth_client, ticket)
def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(httpd, host, port) = HttpParamsHandler.run()[0:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)[:2]
client = Client()
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
params = httpd.PARAMS
params["service"] = service
params["user"] = user
return params
def get_proxy_ticket(service):
"""Return a ProxyTicket waiting for validation"""
params = get_pgt()
# get a proxy ticket
client = Client()
response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
root = etree.fromstring(response.content)
proxy_ticket = root.xpath(
"//cas:proxyTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
proxy_ticket = proxy_ticket[0].text
ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
return ticket
class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
"""
A simple http server that return 200 on GET or POST
and store GET or POST parameters. Used in unit tests
"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(200)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
url = urlparse(self.path)
params = dict(parse_qsl(url.query))
self.server.PARAMS = params
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
if ctype == 'multipart/form-data':
postvars = cgi.parse_multipart(self.rfile, pdict)
elif ctype == 'application/x-www-form-urlencoded':
length = int(self.headers.get('content-length'))
postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
else:
postvars = {}
self.server.PARAMS = postvars
def log_message(self, *args):
"""silent any log message"""
return
@classmethod
def run(cls):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls)
(host, port) = httpd.socket.getsockname()
def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd, host, port)
class Http404Handler(HttpParamsHandler):
"""A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(404)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"error 404 not found")
def do_POST(self):
"""Called on a POST request on the BaseHTTPServer"""
return self.do_GET()

View File

@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""urls for the app""" """urls for the app"""
from django.conf.urls import patterns, url from django.conf.urls import patterns, url
from django.views.generic import RedirectView from django.views.generic import RedirectView

View File

@ -1,4 +1,4 @@
# *- coding: utf-8 -*- # -*- coding: utf-8 -*-
# This program is distributed in the hope that it will be useful, but WITHOUT # This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS # ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for # FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""Some util function for the app""" """Some util function for the app"""
from .default_settings import settings from .default_settings import settings
@ -23,18 +23,19 @@ import hashlib
import crypt import crypt
import base64 import base64
import six import six
from threading import Thread
from importlib import import_module from importlib import import_module
from six.moves import BaseHTTPServer
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
def context(params): def context(params):
"""Function that add somes variable to the context before template rendering"""
params["settings"] = settings params["settings"] = settings
return params return params
def json_response(request, data): def json_response(request, data):
"""Wrapper dumping `data` to a json and sending it to the user with an HttpResponse"""
data["messages"] = [] data["messages"] = []
for msg in messages.get_messages(request): for msg in messages.get_messages(request):
data["messages"].append({'message': msg.message, 'level': msg.level_tag}) data["messages"].append({'message': msg.message, 'level': msg.level_tag})
@ -64,6 +65,7 @@ def redirect_params(url_name, params=None):
def reverse_params(url_name, params=None, **kwargs): def reverse_params(url_name, params=None, **kwargs):
"""compule the reverse url or `url_name` and add GET parameters from `params` to it"""
url = reverse(url_name, **kwargs) url = reverse(url_name, **kwargs)
params = urlencode(params if params else {}) params = urlencode(params if params else {})
return url + "?%s" % params return url + "?%s" % params
@ -83,10 +85,13 @@ def update_url(url, params):
url_parts = list(urlparse(url)) url_parts = list(urlparse(url))
query = dict(parse_qsl(url_parts[4])) query = dict(parse_qsl(url_parts[4]))
query.update(params) query.update(params)
url_parts[4] = urlencode(query) # make the params order deterministic
for i, url_part in enumerate(url_parts): query = list(query.items())
if not isinstance(url_part, bytes): query.sort()
url_parts[i] = url_part.encode('utf-8') url_query = urlencode(query)
if not isinstance(url_query, bytes): # pragma: no cover in python3 urlencode return an unicode
url_query = url_query.encode("utf-8")
url_parts[4] = url_query
return urlunparse(url_parts).decode('utf-8') return urlunparse(url_parts).decode('utf-8')
@ -147,35 +152,25 @@ def gen_saml_id():
return _gen_ticket('_') return _gen_ticket('_')
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler): def crypt_salt_is_valid(salt):
PARAMS = {} """Return True is salt is valid has a crypt salt, False otherwise"""
if len(salt) < 2:
def do_GET(self): return False
self.send_response(200) else:
self.send_header(b"Content-type", "text/plain") if salt[0] == '$':
self.end_headers() if salt[1] == '$':
self.wfile.write(b"ok") return False
url = urlparse(self.path) else:
params = dict(parse_qsl(url.query)) if '$' not in salt[1:]:
PGTUrlHandler.PARAMS.update(params) return False
else:
def log_message(self, *args): hashed = crypt.crypt("", salt)
return if not hashed or '$' not in hashed[1:]:
return False
@staticmethod else:
def run(): return True
server_class = BaseHTTPServer.HTTPServer else:
httpd = server_class(("127.0.0.1", 0), PGTUrlHandler) return True
(host, port) = httpd.socket.getsockname()
def lauch():
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd_thread, host, port)
class LdapHashUserPassword(object): class LdapHashUserPassword(object):
@ -268,7 +263,7 @@ class LdapHashUserPassword(object):
if salt is None or salt == b"": if salt is None or salt == b"":
salt = b"" salt = b""
cls._test_scheme_nosalt(scheme) cls._test_scheme_nosalt(scheme)
elif salt is not None: else:
cls._test_scheme_salt(scheme) cls._test_scheme_salt(scheme)
try: try:
return scheme + base64.b64encode( return scheme + base64.b64encode(
@ -278,9 +273,9 @@ class LdapHashUserPassword(object):
if six.PY3: if six.PY3:
password = password.decode(charset) password = password.decode(charset)
salt = salt.decode(charset) salt = salt.decode(charset)
hashed_password = crypt.crypt(password, salt) if not crypt_salt_is_valid(salt):
if hashed_password is None:
raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt) raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
hashed_password = crypt.crypt(password, salt)
if six.PY3: if six.PY3:
hashed_password = hashed_password.encode(charset) hashed_password = hashed_password.encode(charset)
return scheme + hashed_password return scheme + hashed_password
@ -302,7 +297,7 @@ class LdapHashUserPassword(object):
if scheme in cls.schemes_nosalt: if scheme in cls.schemes_nosalt:
return b"" return b""
elif scheme == b'{CRYPT}': elif scheme == b'{CRYPT}':
return b'$'.join(hashed_passord.split(b'$', 3)[:-1]) return b'$'.join(hashed_passord.split(b'$', 3)[:-1])[len(scheme):]
else: else:
hashed_passord = base64.b64decode(hashed_passord[len(scheme):]) hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
if len(hashed_passord) < cls._schemes_to_len[scheme]: if len(hashed_passord) < cls._schemes_to_len[scheme]:
@ -324,7 +319,7 @@ def check_password(method, password, hashed_password, charset):
elif method == "crypt": elif method == "crypt":
if hashed_password.startswith(b'$'): if hashed_password.startswith(b'$'):
salt = b'$'.join(hashed_password.split(b'$', 3)[:-1]) salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
elif hashed_password.startswith(b'_'): elif hashed_password.startswith(b'_'): # pragma: no cover old BSD format not supported
salt = hashed_password[:9] salt = hashed_password[:9]
else: else:
salt = hashed_password[:2] salt = hashed_password[:2]
@ -332,9 +327,9 @@ def check_password(method, password, hashed_password, charset):
password = password.decode(charset) password = password.decode(charset)
salt = salt.decode(charset) salt = salt.decode(charset)
hashed_password = hashed_password.decode(charset) hashed_password = hashed_password.decode(charset)
crypted_password = crypt.crypt(password, salt) if not crypt_salt_is_valid(salt):
if crypted_password is None:
raise ValueError("System crypt implementation do not support the salt %r" % salt) raise ValueError("System crypt implementation do not support the salt %r" % salt)
crypted_password = crypt.crypt(password, salt)
return crypted_password == hashed_password return crypted_password == hashed_password
elif method == "ldap": elif method == "ldap":
scheme = LdapHashUserPassword.get_scheme(hashed_password) scheme = LdapHashUserPassword.get_scheme(hashed_password)

View File

@ -8,7 +8,7 @@
# along with this program; if not, write to the Free Software Foundation, Inc., 51 # along with this program; if not, write to the Free Software Foundation, Inc., 51
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015-2016 Valentin Samir
"""views for the app""" """views for the app"""
from .default_settings import settings from .default_settings import settings
@ -105,10 +105,11 @@ class LogoutView(View, LogoutMixin):
service = None service = None
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.url = request.GET.get('url') self.url = request.GET.get('url')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""methode called on GET request on this view""" """methode called on GET request on this view"""
@ -196,24 +197,30 @@ class LoginView(View, LogoutMixin):
USER_NOT_AUTHENTICATED = 6 USER_NOT_AUTHENTICATED = 6
def init_post(self, request): def init_post(self, request):
"""Initialize POST received parameters"""
self.request = request self.request = request
self.service = request.POST.get('service') self.service = request.POST.get('service')
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
self.gateway = request.POST.get('gateway') self.gateway = request.POST.get('gateway')
self.method = request.POST.get('method') self.method = request.POST.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
if request.POST.get('warned') and request.POST['warned'] != "False": if request.POST.get('warned') and request.POST['warned'] != "False":
self.warned = True self.warned = True
self.warn = request.POST.get('warn')
def check_lt(self): def gen_lt(self):
# save LT for later check """Generate a new LoginTicket and add it to the list of valid LT for the user"""
lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed)
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
if len(self.request.session['lt']) > 100: if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:] self.request.session['lt'] = self.request.session['lt'][-100:]
def check_lt(self):
"""Check is the POSTed LoginTicket is valid, if yes invalide it"""
# save LT for later check
lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed)
self.gen_lt()
# check if send LT is valid # check if send LT is valid
if lt_valid is None or lt_send not in lt_valid: if lt_valid is None or lt_send not in lt_valid:
return False return False
@ -238,7 +245,7 @@ class LoginView(View, LogoutMixin):
username=self.request.session['username'], username=self.request.session['username'],
session_key=self.request.session.session_key session_key=self.request.session.session_key
) )
self.user.save() self.user.save() # pragma: no cover (should not happend)
except models.User.DoesNotExist: except models.User.DoesNotExist:
self.user = models.User.objects.create( self.user = models.User.objects.create(
username=self.request.session['username'], username=self.request.session['username'],
@ -250,10 +257,15 @@ class LoginView(View, LogoutMixin):
elif ret == self.USER_ALREADY_LOGGED: elif ret == self.USER_ALREADY_LOGGED:
pass pass
else: else:
raise EnvironmentError("invalid output for LoginView.process_post") raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover
return self.common() return self.common()
def process_post(self): def process_post(self):
"""
Analyse the POST request:
* check that the LoginTicket is valid
* check that the user sumited credentials are valid
"""
if not self.check_lt(): if not self.check_lt():
values = self.request.POST.copy() values = self.request.POST.copy()
# if not set a new LT and fail # if not set a new LT and fail
@ -280,12 +292,14 @@ class LoginView(View, LogoutMixin):
return self.USER_ALREADY_LOGGED return self.USER_ALREADY_LOGGED
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
self.gateway = request.GET.get('gateway') self.gateway = request.GET.get('gateway')
self.method = request.GET.get('method') self.method = request.GET.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
self.warn = request.GET.get('warn')
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
"""methode called on GET request on this view""" """methode called on GET request on this view"""
@ -294,22 +308,24 @@ class LoginView(View, LogoutMixin):
return self.common() return self.common()
def process_get(self): def process_get(self):
# generate a new LT if none is present """Analyse the GET request"""
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] # generate a new LT
self.gen_lt()
if not self.request.session.get("authenticated") or self.renew: if not self.request.session.get("authenticated") or self.renew:
self.init_form() self.init_form()
return self.USER_NOT_AUTHENTICATED return self.USER_NOT_AUTHENTICATED
return self.USER_AUTHENTICATED return self.USER_AUTHENTICATED
def init_form(self, values=None): def init_form(self, values=None):
"""Initialization of the good form depending of POST and GET parameters"""
self.form = forms.UserCredential( self.form = forms.UserCredential(
values, values,
initial={ initial={
'service': self.service, 'service': self.service,
'method': self.method, 'method': self.method,
'warn': self.request.session.get("warn"), 'warn': self.request.session.get("warn"),
'lt': self.request.session['lt'][-1] 'lt': self.request.session['lt'][-1],
'renew': self.renew
} }
) )
@ -351,7 +367,7 @@ class LoginView(View, LogoutMixin):
redirect_url = self.user.get_service_url( redirect_url = self.user.get_service_url(
self.service, self.service,
service_pattern, service_pattern,
renew=self.renew renew=self.renewed
) )
if not self.ajax: if not self.ajax:
return HttpResponseRedirect(redirect_url) return HttpResponseRedirect(redirect_url)
@ -580,12 +596,9 @@ class Validate(View):
ticket.service_pattern.user_field ticket.service_pattern.user_field
) )
if isinstance(username, list): if isinstance(username, list):
try: # the list is not empty because we wont generate a ticket with a user_field
username = username[0] # that evaluate to False
except IndexError: username = username[0]
username = None
if not username:
username = ""
else: else:
username = ticket.user.username username = ticket.user.username
return HttpResponse("yes\n%s\n" % username, content_type="text/plain") return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
@ -661,6 +674,10 @@ class ValidateService(View, AttributesMixin):
params['username'] = self.ticket.user.attributs.get( params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field self.ticket.service_pattern.user_field
) )
if isinstance(params['username'], list):
# the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
params['username'] = params['username'][0]
if self.pgt_url and ( if self.pgt_url and (
self.pgt_url.startswith("https://") or self.pgt_url.startswith("https://") or
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url) re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
@ -762,9 +779,12 @@ class ValidateService(View, AttributesMixin):
params, params,
content_type="text/xml; charset=utf-8" content_type="text/xml; charset=utf-8"
) )
except requests.exceptions.SSLError as error: except requests.exceptions.RequestException as error:
error = utils.unpack_nested_exception(error) error = utils.unpack_nested_exception(error)
raise ValidateError('INVALID_PROXY_CALLBACK', str(error)) raise ValidateError(
'INVALID_PROXY_CALLBACK',
"%s: %s" % (type(error), str(error))
)
else: else:
raise ValidateError( raise ValidateError(
'INVALID_PROXY_CALLBACK', 'INVALID_PROXY_CALLBACK',
@ -844,7 +864,7 @@ class Proxy(View):
except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined): except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
raise ValidateError( raise ValidateError(
'UNAUTHORIZED_USER', 'UNAUTHORIZED_USER',
'%s not allowed on %s' % (ticket.user, self.target_service) 'User %s not allowed on %s' % (ticket.user.username, self.target_service)
) )
@ -903,11 +923,15 @@ class SamlValidate(View, AttributesMixin):
'username': self.ticket.user.username, 'username': self.ticket.user.username,
'attributes': attributes 'attributes': attributes
} }
if self.ticket.service_pattern.user_field and \ if (self.ticket.service_pattern.user_field and
self.ticket.user.attributs.get(self.ticket.service_pattern.user_field): self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
params['username'] = self.ticket.user.attributs.get( params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field self.ticket.service_pattern.user_field
) )
if isinstance(params['username'], list):
# the list is not empty because we wont generate a ticket with a user_field
# that evaluate to False
params['username'] = params['username'][0]
logger.info( logger.info(
"SamlValidate: ticket %s validated for user %s on service %s." % ( "SamlValidate: ticket %s validated for user %s on service %s." % (
self.ticket.value, self.ticket.value,

5
pytest.ini Normal file
View File

@ -0,0 +1,5 @@
[pytest]
testpaths = cas_server/tests/
DJANGO_SETTINGS_MODULE = cas_server.tests.settings
norecursedirs = .* build dist docs
python_paths = .

View File

@ -1,10 +1,12 @@
tox==1.8.1 setuptools>=5.5
pytest==2.6.4 tox>=1.8.1
pytest-django==2.7.0 pytest>=2.6.4
pytest-pythonpath==0.3 pytest-django>=2.8.0
pytest-pythonpath>=0.3
pytest-cov>=2.2.1
requests>=2.4 requests>=2.4
django-picklefield>=0.3.1
requests_futures>=0.9.5 requests_futures>=0.9.5
django-picklefield>=0.3.1
django-bootstrap3>=5.4 django-bootstrap3>=5.4
lxml>=3.4 lxml>=3.4
six>=1 six>=1

View File

@ -1,22 +0,0 @@
#!/usr/bin/env python
import os, sys
import django
from django.conf import settings
import settings_tests
settings.configure(**settings_tests.__dict__)
django.setup()
try:
# Django <= 1.8
from django.test.simple import DjangoTestSuiteRunner
test_runner = DjangoTestSuiteRunner(verbosity=1)
except ImportError:
# Django >= 1.8
from django.test.runner import DiscoverRunner
test_runner = DiscoverRunner(verbosity=1)
failures = test_runner.run_tests(['cas_server'])
if failures:
sys.exit(failures)

View File

@ -34,7 +34,8 @@ setup(
version='0.4.4', version='0.4.4',
packages=[ packages=[
'cas_server', 'cas_server.migrations', 'cas_server', 'cas_server.migrations',
'cas_server.management', 'cas_server.management.commands' 'cas_server.management', 'cas_server.management.commands',
'cas_server.tests'
], ],
include_package_data=True, include_package_data=True,
license='GPLv3', license='GPLv3',

14
tox.ini
View File

@ -1,12 +1,12 @@
[tox] [tox]
envlist= envlist=
flake8,
py27-django17, py27-django17,
py27-django18, py27-django18,
py27-django19, py27-django19,
py34-django17, py34-django17,
py34-django18, py34-django18,
py34-django19, py34-django19,
flake8,
[flake8] [flake8]
max-line-length=100 max-line-length=100
@ -17,7 +17,7 @@ deps =
-r{toxinidir}/requirements-dev.txt -r{toxinidir}/requirements-dev.txt
[testenv] [testenv]
commands=python run_tests {posargs:tests} commands=py.test {posargs:cas_server/tests/}
[testenv:py27-django17] [testenv:py27-django17]
basepython=python2.7 basepython=python2.7
@ -60,3 +60,13 @@ basepython=python
deps=flake8 deps=flake8
commands=flake8 {toxinidir}/cas_server commands=flake8 {toxinidir}/cas_server
[testenv:coverage]
basepython=python
passenv=CODACY_PROJECT_TOKEN
deps=
-r{toxinidir}/requirements-dev.txt
codacy-coverage
commands=
py.test --cov=cas_server --cov-report xml
python-codacy-coverage -r {toxinidir}/coverage.xml