@@ -1,3 +1,11 @@
 | 
				
			|||||||
 | 
					[run]
 | 
				
			||||||
 | 
					branch = True
 | 
				
			||||||
 | 
					source = cas_server
 | 
				
			||||||
 | 
					omit =
 | 
				
			||||||
 | 
					    cas_server/migrations*
 | 
				
			||||||
 | 
					    cas_server/management/*
 | 
				
			||||||
 | 
					    cas_server/tests/*
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[report]
 | 
					[report]
 | 
				
			||||||
exclude_lines =
 | 
					exclude_lines =
 | 
				
			||||||
    pragma: no cover
 | 
					    pragma: no cover
 | 
				
			||||||
@@ -5,3 +13,4 @@ exclude_lines =
 | 
				
			|||||||
    def __unicode__
 | 
					    def __unicode__
 | 
				
			||||||
    raise AssertionError
 | 
					    raise AssertionError
 | 
				
			||||||
    raise NotImplementedError
 | 
					    raise NotImplementedError
 | 
				
			||||||
 | 
					    if six.PY3:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,19 +2,19 @@ language: python
 | 
				
			|||||||
python:
 | 
					python:
 | 
				
			||||||
  - "2.7"
 | 
					  - "2.7"
 | 
				
			||||||
env:
 | 
					env:
 | 
				
			||||||
  global:
 | 
					 | 
				
			||||||
    - PIP_DOWNLOAD_CACHE=$HOME/.pip_cache
 | 
					 | 
				
			||||||
  matrix:
 | 
					  matrix:
 | 
				
			||||||
 | 
					    - TOX_ENV=coverage
 | 
				
			||||||
 | 
					    - TOX_ENV=flake8
 | 
				
			||||||
    - TOX_ENV=py27-django17
 | 
					    - TOX_ENV=py27-django17
 | 
				
			||||||
    - TOX_ENV=py27-django18
 | 
					    - TOX_ENV=py27-django18
 | 
				
			||||||
    - TOX_ENV=py27-django19
 | 
					    - TOX_ENV=py27-django19
 | 
				
			||||||
    - TOX_ENV=py34-django17
 | 
					    - TOX_ENV=py34-django17
 | 
				
			||||||
    - TOX_ENV=py34-django18
 | 
					    - TOX_ENV=py34-django18
 | 
				
			||||||
    - TOX_ENV=py34-django19
 | 
					    - TOX_ENV=py34-django19
 | 
				
			||||||
    - TOX_ENV=flake8
 | 
					 | 
				
			||||||
cache:
 | 
					cache:
 | 
				
			||||||
  directories:
 | 
					  directories:
 | 
				
			||||||
    - $HOME/.pip-cache/
 | 
					    - $HOME/.cache/pip/
 | 
				
			||||||
 | 
					    - $HOME/build/nitmir/django-cas-server/.tox/
 | 
				
			||||||
install:
 | 
					install:
 | 
				
			||||||
  - "travis_retry pip install setuptools --upgrade"
 | 
					  - "travis_retry pip install setuptools --upgrade"
 | 
				
			||||||
  - "pip install tox"
 | 
					  - "pip install tox"
 | 
				
			||||||
@@ -22,4 +22,3 @@ script:
 | 
				
			|||||||
  - tox -e $TOX_ENV
 | 
					  - tox -e $TOX_ENV
 | 
				
			||||||
after_script:
 | 
					after_script:
 | 
				
			||||||
  - cat .tox/$TOX_ENV/log/*.log
 | 
					  - cat .tox/$TOX_ENV/log/*.log
 | 
				
			||||||
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										45
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										45
									
								
								Makefile
									
									
									
									
									
								
							@@ -1,11 +1,15 @@
 | 
				
			|||||||
.PHONY: clean build install dist test_venv test_project
 | 
					.PHONY: build dist
 | 
				
			||||||
VERSION=`python setup.py -V`
 | 
					VERSION=`python setup.py -V`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
build:
 | 
					build:
 | 
				
			||||||
	python setup.py build
 | 
						python setup.py build
 | 
				
			||||||
 | 
					
 | 
				
			||||||
install:
 | 
					install: dist
 | 
				
			||||||
	python setup.py install
 | 
						pip -V
 | 
				
			||||||
 | 
						pip install --no-deps --upgrade --force-reinstall --find-links ./dist/django-cas-server-${VERSION}.tar.gz django-cas-server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					uninstall:
 | 
				
			||||||
 | 
						pip uninstall django-cas-server || true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
clean_pyc:
 | 
					clean_pyc:
 | 
				
			||||||
	find ./ -name '*.pyc' -delete
 | 
						find ./ -name '*.pyc' -delete
 | 
				
			||||||
@@ -16,18 +20,23 @@ clean_tox:
 | 
				
			|||||||
	rm -rf .tox
 | 
						rm -rf .tox
 | 
				
			||||||
clean_test_venv:
 | 
					clean_test_venv:
 | 
				
			||||||
	rm -rf test_venv
 | 
						rm -rf test_venv
 | 
				
			||||||
clean: clean_pyc clean_build
 | 
					clean_coverage:
 | 
				
			||||||
clean_all: clean_pyc clean_build clean_tox clean_test_venv
 | 
						rm -rf coverage.xml .coverage htmlcov
 | 
				
			||||||
 | 
					clean_tild_backup:
 | 
				
			||||||
 | 
						find ./ -name '*~' -delete
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					clean: clean_pyc clean_build clean_coverage clean_tild_backup
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					clean_all: clean clean_tox clean_test_venv
 | 
				
			||||||
 | 
					
 | 
				
			||||||
dist:
 | 
					dist:
 | 
				
			||||||
	python setup.py sdist
 | 
						python setup.py sdist
 | 
				
			||||||
 | 
					
 | 
				
			||||||
test_venv:
 | 
					test_venv/bin/python:
 | 
				
			||||||
	mkdir -p test_venv
 | 
					 | 
				
			||||||
	virtualenv test_venv
 | 
						virtualenv test_venv
 | 
				
			||||||
	test_venv/bin/pip install -U --requirement requirements.txt
 | 
						test_venv/bin/pip install -U --requirement requirements-dev.txt Django
 | 
				
			||||||
 | 
					
 | 
				
			||||||
test_venv/cas/manage.py:
 | 
					test_venv/cas/manage.py: test_venv
 | 
				
			||||||
	mkdir -p test_venv/cas
 | 
						mkdir -p test_venv/cas
 | 
				
			||||||
	test_venv/bin/django-admin startproject cas test_venv/cas
 | 
						test_venv/bin/django-admin startproject cas test_venv/cas
 | 
				
			||||||
	ln -s ../../cas_server test_venv/cas/cas_server
 | 
						ln -s ../../cas_server test_venv/cas/cas_server
 | 
				
			||||||
@@ -38,19 +47,15 @@ test_venv/cas/manage.py:
 | 
				
			|||||||
	test_venv/bin/python test_venv/cas/manage.py migrate
 | 
						test_venv/bin/python test_venv/cas/manage.py migrate
 | 
				
			||||||
	test_venv/bin/python test_venv/cas/manage.py createsuperuser
 | 
						test_venv/bin/python test_venv/cas/manage.py createsuperuser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
test_project: test_venv test_venv/cas/manage.py
 | 
					test_venv: test_venv/bin/python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					test_project: test_venv/cas/manage.py
 | 
				
			||||||
	@echo "##############################################################"
 | 
						@echo "##############################################################"
 | 
				
			||||||
	@echo "A test django project was created in $(realpath test_venv/cas)"
 | 
						@echo "A test django project was created in $(realpath test_venv/cas)"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
run_test_server: test_project
 | 
					run_server: test_project
 | 
				
			||||||
	test_venv/bin/python test_venv/cas/manage.py runserver
 | 
						test_venv/bin/python test_venv/cas/manage.py runserver
 | 
				
			||||||
 | 
					
 | 
				
			||||||
coverage: test_venv
 | 
					run_tests: test_venv
 | 
				
			||||||
	test_venv/bin/pip install coverage
 | 
						test_venv/bin/py.test --cov=cas_server --cov-report html
 | 
				
			||||||
	test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
 | 
						rm htmlcov/coverage_html.js  # I am really pissed off by those keybord shortcuts
 | 
				
			||||||
	test_venv/bin/coverage html
 | 
					 | 
				
			||||||
	test_venv/bin/coverage xml
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
coverage_codacy: coverage
 | 
					 | 
				
			||||||
	test_venv/bin/pip install codacy-coverage
 | 
					 | 
				
			||||||
	test_venv/bin/python-codacy-coverage -r coverage.xml
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
 | 
				
			|||||||
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
 | 
					* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
 | 
				
			||||||
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
 | 
					* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
 | 
				
			||||||
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
 | 
					* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
 | 
				
			||||||
  ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``.
 | 
					  ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
 | 
				
			||||||
 | 
					  'alias': ['demo1', 'demo2']}``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Authentication backend
 | 
					Authentication backend
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,6 +7,6 @@
 | 
				
			|||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# (c) 2015 Valentin Samir
 | 
					# (c) 2015-2016 Valentin Samir
 | 
				
			||||||
 | 
					"""A django CAS server application"""
 | 
				
			||||||
default_app_config = 'cas_server.apps.CasAppConfig'
 | 
					default_app_config = 'cas_server.apps.CasAppConfig'
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,7 +7,7 @@
 | 
				
			|||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# (c) 2015 Valentin Samir
 | 
					# (c) 2015-2016 Valentin Samir
 | 
				
			||||||
"""module for the admin interface of the app"""
 | 
					"""module for the admin interface of the app"""
 | 
				
			||||||
from django.contrib import admin
 | 
					from django.contrib import admin
 | 
				
			||||||
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern
 | 
					from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, ServicePattern
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,19 @@
 | 
				
			|||||||
 | 
					# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
				
			||||||
 | 
					# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
				
			||||||
 | 
					# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
				
			||||||
 | 
					# more details.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# You should have received a copy of the GNU General Public License version 3
 | 
				
			||||||
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# (c) 2015-2016 Valentin Samir
 | 
				
			||||||
 | 
					"""django config module"""
 | 
				
			||||||
from django.utils.translation import ugettext_lazy as _
 | 
					from django.utils.translation import ugettext_lazy as _
 | 
				
			||||||
from django.apps import AppConfig
 | 
					from django.apps import AppConfig
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CasAppConfig(AppConfig):
 | 
					class CasAppConfig(AppConfig):
 | 
				
			||||||
 | 
					    """django CAS application config class"""
 | 
				
			||||||
    name = 'cas_server'
 | 
					    name = 'cas_server'
 | 
				
			||||||
    verbose_name = _('Central Authentication Service')
 | 
					    verbose_name = _('Central Authentication Service')
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,7 +8,7 @@
 | 
				
			|||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# (c) 2015 Valentin Samir
 | 
					# (c) 2015-2016 Valentin Samir
 | 
				
			||||||
"""Some authentication classes for the CAS"""
 | 
					"""Some authentication classes for the CAS"""
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
from django.contrib.auth import get_user_model
 | 
					from django.contrib.auth import get_user_model
 | 
				
			||||||
@@ -21,6 +21,7 @@ except ImportError:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AuthUser(object):
 | 
					class AuthUser(object):
 | 
				
			||||||
 | 
					    """Authentication base class"""
 | 
				
			||||||
    def __init__(self, username):
 | 
					    def __init__(self, username):
 | 
				
			||||||
        self.username = username
 | 
					        self.username = username
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -78,5 +78,12 @@ setting_default('CAS_TEST_USER', 'test')
 | 
				
			|||||||
setting_default('CAS_TEST_PASSWORD', 'test')
 | 
					setting_default('CAS_TEST_PASSWORD', 'test')
 | 
				
			||||||
setting_default(
 | 
					setting_default(
 | 
				
			||||||
    'CAS_TEST_ATTRIBUTES',
 | 
					    'CAS_TEST_ATTRIBUTES',
 | 
				
			||||||
    {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}
 | 
					    {
 | 
				
			||||||
 | 
					        'nom': 'Nymous',
 | 
				
			||||||
 | 
					        'prenom': 'Ano',
 | 
				
			||||||
 | 
					        'email': 'anonymous@example.net',
 | 
				
			||||||
 | 
					        'alias': ['demo1', 'demo2']
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					setting_default('CAS_ENABLE_AJAX_AUTH', False)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,7 +7,7 @@
 | 
				
			|||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# (c) 2015 Valentin Samir
 | 
					# (c) 2015-2016 Valentin Samir
 | 
				
			||||||
"""forms for the app"""
 | 
					"""forms for the app"""
 | 
				
			||||||
from .default_settings import settings
 | 
					from .default_settings import settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -19,6 +19,7 @@ import cas_server.models as models
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class WarnForm(forms.Form):
 | 
					class WarnForm(forms.Form):
 | 
				
			||||||
 | 
					    """Form used on warn page before emiting a ticket"""
 | 
				
			||||||
    service = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
					    service = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
				
			||||||
    renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
 | 
					    renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
 | 
				
			||||||
    gateway = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
					    gateway = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
				
			||||||
@@ -35,6 +36,7 @@ class UserCredential(forms.Form):
 | 
				
			|||||||
    lt = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
					    lt = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
				
			||||||
    method = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
					    method = forms.CharField(widget=forms.HiddenInput(), required=False)
 | 
				
			||||||
    warn = forms.BooleanField(label=_('warn'), required=False)
 | 
					    warn = forms.BooleanField(label=_('warn'), required=False)
 | 
				
			||||||
 | 
					    renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, *args, **kwargs):
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
        super(UserCredential, self).__init__(*args, **kwargs)
 | 
					        super(UserCredential, self).__init__(*args, **kwargs)
 | 
				
			||||||
@@ -46,6 +48,7 @@ class UserCredential(forms.Form):
 | 
				
			|||||||
            cleaned_data["username"] = auth.username
 | 
					            cleaned_data["username"] = auth.username
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise forms.ValidationError(_(u"Bad user"))
 | 
					            raise forms.ValidationError(_(u"Bad user"))
 | 
				
			||||||
 | 
					        return cleaned_data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TicketForm(forms.ModelForm):
 | 
					class TicketForm(forms.ModelForm):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,3 +1,4 @@
 | 
				
			|||||||
 | 
					"""Clean deleted sessions management command"""
 | 
				
			||||||
from django.core.management.base import BaseCommand
 | 
					from django.core.management.base import BaseCommand
 | 
				
			||||||
from django.utils.translation import ugettext_lazy as _
 | 
					from django.utils.translation import ugettext_lazy as _
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -5,6 +6,7 @@ from ... import models
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Command(BaseCommand):
 | 
					class Command(BaseCommand):
 | 
				
			||||||
 | 
					    """Clean deleted sessions"""
 | 
				
			||||||
    args = ''
 | 
					    args = ''
 | 
				
			||||||
    help = _(u"Clean deleted sessions")
 | 
					    help = _(u"Clean deleted sessions")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,3 +1,4 @@
 | 
				
			|||||||
 | 
					"""Clean old trickets management command"""
 | 
				
			||||||
from django.core.management.base import BaseCommand
 | 
					from django.core.management.base import BaseCommand
 | 
				
			||||||
from django.utils.translation import ugettext_lazy as _
 | 
					from django.utils.translation import ugettext_lazy as _
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -5,6 +6,7 @@ from ... import models
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Command(BaseCommand):
 | 
					class Command(BaseCommand):
 | 
				
			||||||
 | 
					    """Clean old trickets"""
 | 
				
			||||||
    args = ''
 | 
					    args = ''
 | 
				
			||||||
    help = _(u"Clean old trickets")
 | 
					    help = _(u"Clean old trickets")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,7 +8,7 @@
 | 
				
			|||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# (c) 2015 Valentin Samir
 | 
					# (c) 2015-2016 Valentin Samir
 | 
				
			||||||
"""models for the app"""
 | 
					"""models for the app"""
 | 
				
			||||||
from .default_settings import settings
 | 
					from .default_settings import settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -20,7 +20,6 @@ from django.utils import timezone
 | 
				
			|||||||
from picklefield.fields import PickledObjectField
 | 
					from picklefield.fields import PickledObjectField
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
import os
 | 
					 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
from importlib import import_module
 | 
					from importlib import import_module
 | 
				
			||||||
@@ -47,6 +46,7 @@ class User(models.Model):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def clean_old_entries(cls):
 | 
					    def clean_old_entries(cls):
 | 
				
			||||||
 | 
					        """Remove users inactive since more that SESSION_COOKIE_AGE"""
 | 
				
			||||||
        users = cls.objects.filter(
 | 
					        users = cls.objects.filter(
 | 
				
			||||||
            date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE))
 | 
					            date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE))
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
@@ -56,6 +56,7 @@ class User(models.Model):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def clean_deleted_sessions(cls):
 | 
					    def clean_deleted_sessions(cls):
 | 
				
			||||||
 | 
					        """Remove user where the session do not exists anymore"""
 | 
				
			||||||
        for user in cls.objects.all():
 | 
					        for user in cls.objects.all():
 | 
				
			||||||
            if not SessionStore(session_key=user.session_key).get('authenticated'):
 | 
					            if not SessionStore(session_key=user.session_key).get('authenticated'):
 | 
				
			||||||
                user.logout()
 | 
					                user.logout()
 | 
				
			||||||
@@ -80,10 +81,10 @@ class User(models.Model):
 | 
				
			|||||||
        for ticket_class in ticket_classes:
 | 
					        for ticket_class in ticket_classes:
 | 
				
			||||||
            queryset = ticket_class.objects.filter(user=self)
 | 
					            queryset = ticket_class.objects.filter(user=self)
 | 
				
			||||||
            for ticket in queryset:
 | 
					            for ticket in queryset:
 | 
				
			||||||
                ticket.logout(request, session, async_list)
 | 
					                ticket.logout(session, async_list)
 | 
				
			||||||
            queryset.delete()
 | 
					            queryset.delete()
 | 
				
			||||||
        for future in async_list:
 | 
					        for future in async_list:
 | 
				
			||||||
            if future:
 | 
					            if future:  # pragma: no branch (should always be true)
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    future.result()
 | 
					                    future.result()
 | 
				
			||||||
                except Exception as error:
 | 
					                except Exception as error:
 | 
				
			||||||
@@ -111,13 +112,21 @@ class User(models.Model):
 | 
				
			|||||||
            (a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
 | 
					            (a.name, a.replace if a.replace else a.name) for a in service_pattern.attributs.all()
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        replacements = dict(
 | 
					        replacements = dict(
 | 
				
			||||||
            (a.name, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
 | 
					            (a.attribut, (a.pattern, a.replace)) for a in service_pattern.replacements.all()
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        service_attributs = {}
 | 
					        service_attributs = {}
 | 
				
			||||||
        for (key, value) in self.attributs.items():
 | 
					        for (key, value) in self.attributs.items():
 | 
				
			||||||
            if key in attributs or '*' in attributs:
 | 
					            if key in attributs or '*' in attributs:
 | 
				
			||||||
                if key in replacements:
 | 
					                if key in replacements:
 | 
				
			||||||
                    value = re.sub(replacements[key][0], replacements[key][1], value)
 | 
					                    if isinstance(value, list):
 | 
				
			||||||
 | 
					                        for index, subval in enumerate(value):
 | 
				
			||||||
 | 
					                            value[index] = re.sub(
 | 
				
			||||||
 | 
					                                replacements[key][0],
 | 
				
			||||||
 | 
					                                replacements[key][1],
 | 
				
			||||||
 | 
					                                subval
 | 
				
			||||||
 | 
					                            )
 | 
				
			||||||
 | 
					                    else:
 | 
				
			||||||
 | 
					                        value = re.sub(replacements[key][0], replacements[key][1], value)
 | 
				
			||||||
                service_attributs[attributs.get(key, key)] = value
 | 
					                service_attributs[attributs.get(key, key)] = value
 | 
				
			||||||
        ticket = ticket_class.objects.create(
 | 
					        ticket = ticket_class.objects.create(
 | 
				
			||||||
            user=self,
 | 
					            user=self,
 | 
				
			||||||
@@ -141,6 +150,7 @@ class User(models.Model):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ServicePatternException(Exception):
 | 
					class ServicePatternException(Exception):
 | 
				
			||||||
 | 
					    """Base exception of exceptions raised in the ServicePattern model"""
 | 
				
			||||||
    pass
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -394,77 +404,57 @@ class Ticket(models.Model):
 | 
				
			|||||||
        ).delete()
 | 
					        ).delete()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # sending SLO to timed-out validated tickets
 | 
					        # sending SLO to timed-out validated tickets
 | 
				
			||||||
        if cls.TIMEOUT and cls.TIMEOUT > 0:
 | 
					        async_list = []
 | 
				
			||||||
            async_list = []
 | 
					        session = FuturesSession(
 | 
				
			||||||
            session = FuturesSession(
 | 
					            executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
 | 
				
			||||||
                executor=ThreadPoolExecutor(max_workers=settings.CAS_SLO_MAX_PARALLEL_REQUESTS)
 | 
					        )
 | 
				
			||||||
            )
 | 
					        queryset = cls.objects.filter(
 | 
				
			||||||
            queryset = cls.objects.filter(
 | 
					            creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
 | 
				
			||||||
                creation__lt=(timezone.now() - timedelta(seconds=cls.TIMEOUT))
 | 
					        )
 | 
				
			||||||
            )
 | 
					        for ticket in queryset:
 | 
				
			||||||
            for ticket in queryset:
 | 
					            ticket.logout(session, async_list)
 | 
				
			||||||
                ticket.logout(None, session, async_list)
 | 
					        queryset.delete()
 | 
				
			||||||
            queryset.delete()
 | 
					        for future in async_list:
 | 
				
			||||||
            for future in async_list:
 | 
					            if future:  # pragma: no branch (should always be true)
 | 
				
			||||||
                if future:
 | 
					                try:
 | 
				
			||||||
                    try:
 | 
					                    future.result()
 | 
				
			||||||
                        future.result()
 | 
					                except Exception as error:
 | 
				
			||||||
                    except Exception as error:
 | 
					                    logger.warning("Error durring SLO %s" % error)
 | 
				
			||||||
                        logger.warning("Error durring SLO %s" % error)
 | 
					                    sys.stderr.write("%r\n" % error)
 | 
				
			||||||
                        sys.stderr.write("%r\n" % error)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def logout(self, request, session, async_list=None):
 | 
					    def logout(self, session, async_list=None):
 | 
				
			||||||
        """Send a SLO request to the ticket service"""
 | 
					        """Send a SLO request to the ticket service"""
 | 
				
			||||||
        # On logout invalidate the Ticket
 | 
					        # On logout invalidate the Ticket
 | 
				
			||||||
        self.validate = True
 | 
					        self.validate = True
 | 
				
			||||||
        self.save()
 | 
					        self.save()
 | 
				
			||||||
        if self.validate and self.single_log_out:
 | 
					        if self.validate and self.single_log_out:  # pragma: no branch (should always be true)
 | 
				
			||||||
            logger.info(
 | 
					            logger.info(
 | 
				
			||||||
                "Sending SLO requests to service %s for user %s" % (
 | 
					                "Sending SLO requests to service %s for user %s" % (
 | 
				
			||||||
                    self.service,
 | 
					                    self.service,
 | 
				
			||||||
                    self.user.username
 | 
					                    self.user.username
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            try:
 | 
					            xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
 | 
				
			||||||
                xml = u"""<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
 | 
					 ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
 | 
				
			||||||
     ID="%(id)s" Version="2.0" IssueInstant="%(datetime)s">
 | 
					<saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
 | 
				
			||||||
    <saml:NameID xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"></saml:NameID>
 | 
					<samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
 | 
				
			||||||
    <samlp:SessionIndex>%(ticket)s</samlp:SessionIndex>
 | 
					</samlp:LogoutRequest>""" % \
 | 
				
			||||||
  </samlp:LogoutRequest>""" % \
 | 
					                {
 | 
				
			||||||
                    {
 | 
					                    'id': utils.gen_saml_id(),
 | 
				
			||||||
                        'id': os.urandom(20).encode("hex"),
 | 
					                    'datetime': timezone.now().isoformat(),
 | 
				
			||||||
                        'datetime': timezone.now().isoformat(),
 | 
					                    'ticket':  self.value
 | 
				
			||||||
                        'ticket':  self.value
 | 
					                }
 | 
				
			||||||
                    }
 | 
					            if self.service_pattern.single_log_out_callback:
 | 
				
			||||||
                if self.service_pattern.single_log_out_callback:
 | 
					                url = self.service_pattern.single_log_out_callback
 | 
				
			||||||
                    url = self.service_pattern.single_log_out_callback
 | 
					            else:
 | 
				
			||||||
                else:
 | 
					                url = self.service
 | 
				
			||||||
                    url = self.service
 | 
					            async_list.append(
 | 
				
			||||||
                async_list.append(
 | 
					                session.post(
 | 
				
			||||||
                    session.post(
 | 
					                    url.encode('utf-8'),
 | 
				
			||||||
                        url.encode('utf-8'),
 | 
					                    data={'logoutRequest': xml.encode('utf-8')},
 | 
				
			||||||
                        data={'logoutRequest': xml.encode('utf-8')},
 | 
					                    timeout=settings.CAS_SLO_TIMEOUT
 | 
				
			||||||
                        timeout=settings.CAS_SLO_TIMEOUT
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
            except Exception as error:
 | 
					            )
 | 
				
			||||||
                error = utils.unpack_nested_exception(error)
 | 
					 | 
				
			||||||
                logger.warning(
 | 
					 | 
				
			||||||
                    "Error durring SLO for user %s on service %s: %s" % (
 | 
					 | 
				
			||||||
                        self.user.username,
 | 
					 | 
				
			||||||
                        self.service,
 | 
					 | 
				
			||||||
                        error
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                )
 | 
					 | 
				
			||||||
                if request is not None:
 | 
					 | 
				
			||||||
                    messages.add_message(
 | 
					 | 
				
			||||||
                        request,
 | 
					 | 
				
			||||||
                        messages.WARNING,
 | 
					 | 
				
			||||||
                        _(u'Error during service logout %(service)s:\n%(error)s') %
 | 
					 | 
				
			||||||
                        {'service':  self.service, 'error': error}
 | 
					 | 
				
			||||||
                    )
 | 
					 | 
				
			||||||
                else:
 | 
					 | 
				
			||||||
                    sys.stderr.write("%r\n" % error)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ServiceTicket(Ticket):
 | 
					class ServiceTicket(Ticket):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,702 +0,0 @@
 | 
				
			|||||||
from .default_settings import settings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from django.test import TestCase
 | 
					 | 
				
			||||||
from django.test import Client
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import six
 | 
					 | 
				
			||||||
from lxml import etree
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from cas_server import models
 | 
					 | 
				
			||||||
from cas_server import utils
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_login_page_params():
 | 
					 | 
				
			||||||
    client = Client()
 | 
					 | 
				
			||||||
    response = client.get('/login')
 | 
					 | 
				
			||||||
    form = response.context["form"]
 | 
					 | 
				
			||||||
    params = {}
 | 
					 | 
				
			||||||
    for field in form:
 | 
					 | 
				
			||||||
        if field.value():
 | 
					 | 
				
			||||||
            params[field.name] = field.value()
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            params[field.name] = ""
 | 
					 | 
				
			||||||
    return client, params
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_auth_client():
 | 
					 | 
				
			||||||
    client, params = get_login_page_params()
 | 
					 | 
				
			||||||
    params["username"] = settings.CAS_TEST_USER
 | 
					 | 
				
			||||||
    params["password"] = settings.CAS_TEST_PASSWORD
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    client.post('/login', params)
 | 
					 | 
				
			||||||
    return client
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_user_ticket_request(service):
 | 
					 | 
				
			||||||
    client = get_auth_client()
 | 
					 | 
				
			||||||
    response = client.get("/login", {"service": service})
 | 
					 | 
				
			||||||
    ticket_value = response['Location'].split('ticket=')[-1]
 | 
					 | 
				
			||||||
    user = models.User.objects.get(
 | 
					 | 
				
			||||||
        username=settings.CAS_TEST_USER,
 | 
					 | 
				
			||||||
        session_key=client.session.session_key
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    ticket = models.ServiceTicket.objects.get(value=ticket_value)
 | 
					 | 
				
			||||||
    return (user, ticket)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
def get_pgt():
 | 
					 | 
				
			||||||
    (host, port) = utils.PGTUrlHandler.run()[1:3]
 | 
					 | 
				
			||||||
    service = "http://%s:%s" % (host, port)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    (user, ticket) = get_user_ticket_request(service)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    client = Client()
 | 
					 | 
				
			||||||
    client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
 | 
					 | 
				
			||||||
    params = utils.PGTUrlHandler.PARAMS.copy()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    params["service"] = service
 | 
					 | 
				
			||||||
    params["user"] = user
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    return params
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class CheckPasswordCase(TestCase):
 | 
					 | 
				
			||||||
    """Tests for the utils function `utils.check_password`"""
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        """Generate random bytes string that will be used ass passwords"""
 | 
					 | 
				
			||||||
        self.password1 = utils.gen_saml_id()
 | 
					 | 
				
			||||||
        self.password2 = utils.gen_saml_id()
 | 
					 | 
				
			||||||
        if not isinstance(self.password1, bytes):
 | 
					 | 
				
			||||||
            self.password1 = self.password1.encode("utf8")
 | 
					 | 
				
			||||||
            self.password2 = self.password2.encode("utf8")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_setup(self):
 | 
					 | 
				
			||||||
        """check that generated password are bytes"""
 | 
					 | 
				
			||||||
        self.assertIsInstance(self.password1, bytes)
 | 
					 | 
				
			||||||
        self.assertIsInstance(self.password2, bytes)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_plain(self):
 | 
					 | 
				
			||||||
        """test the plain auth method"""
 | 
					 | 
				
			||||||
        self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
 | 
					 | 
				
			||||||
        self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_crypt(self):
 | 
					 | 
				
			||||||
        """test the crypt auth method"""
 | 
					 | 
				
			||||||
        if six.PY3:
 | 
					 | 
				
			||||||
            hashed_password1 = utils.crypt.crypt(
 | 
					 | 
				
			||||||
                self.password1.decode("utf8"),
 | 
					 | 
				
			||||||
                "$6$UVVAQvrMyXMF3FF3"
 | 
					 | 
				
			||||||
            ).encode("utf8")
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
 | 
					 | 
				
			||||||
        self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_ldap_ssha(self):
 | 
					 | 
				
			||||||
        """test the ldap auth method with a {SSHA} scheme"""
 | 
					 | 
				
			||||||
        salt = b"UVVAQvrMyXMF3FF3"
 | 
					 | 
				
			||||||
        hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertIsInstance(hashed_password1, bytes)
 | 
					 | 
				
			||||||
        self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
 | 
					 | 
				
			||||||
        self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_hex_md5(self):
 | 
					 | 
				
			||||||
        """test the hex_md5 auth method"""
 | 
					 | 
				
			||||||
        hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
 | 
					 | 
				
			||||||
        self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_hox_sha512(self):
 | 
					 | 
				
			||||||
        """test the hex_sha512 auth method"""
 | 
					 | 
				
			||||||
        hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertFalse(
 | 
					 | 
				
			||||||
            utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class LoginTestCase(TestCase):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
					 | 
				
			||||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
					 | 
				
			||||||
            name="example",
 | 
					 | 
				
			||||||
            pattern="^https://www\.example\.com(/.*)?$",
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_login_view_post_goodpass_goodlt(self):
 | 
					 | 
				
			||||||
        client, params = get_login_page_params()
 | 
					 | 
				
			||||||
        params["username"] = settings.CAS_TEST_USER
 | 
					 | 
				
			||||||
        params["password"] = settings.CAS_TEST_PASSWORD
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = client.post('/login', params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                b"You have successfully logged into "
 | 
					 | 
				
			||||||
                b"the Central Authentication Service"
 | 
					 | 
				
			||||||
            ) in response.content
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            models.User.objects.get(
 | 
					 | 
				
			||||||
                username=settings.CAS_TEST_USER,
 | 
					 | 
				
			||||||
                session_key=client.session.session_key
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_login_view_post_badlt(self):
 | 
					 | 
				
			||||||
        client, params = get_login_page_params()
 | 
					 | 
				
			||||||
        params["username"] = settings.CAS_TEST_USER
 | 
					 | 
				
			||||||
        params["password"] = settings.CAS_TEST_PASSWORD
 | 
					 | 
				
			||||||
        params["lt"] = 'LT-random'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = client.post('/login', params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertTrue(b"Invalid login ticket" in response.content)
 | 
					 | 
				
			||||||
        self.assertFalse(
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                b"You have successfully logged into "
 | 
					 | 
				
			||||||
                b"the Central Authentication Service"
 | 
					 | 
				
			||||||
            ) in response.content
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_login_view_post_badpass_good_lt(self):
 | 
					 | 
				
			||||||
        client, params = get_login_page_params()
 | 
					 | 
				
			||||||
        params["username"] = settings.CAS_TEST_USER
 | 
					 | 
				
			||||||
        params["password"] = "test2"
 | 
					 | 
				
			||||||
        response = client.post('/login', params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                b"The credentials you provided cannot be "
 | 
					 | 
				
			||||||
                b"determined to be authentic"
 | 
					 | 
				
			||||||
            ) in response.content
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertFalse(
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                b"You have successfully logged into "
 | 
					 | 
				
			||||||
                b"the Central Authentication Service"
 | 
					 | 
				
			||||||
            ) in response.content
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_view_login_get_auth_allowed_service(self):
 | 
					 | 
				
			||||||
        client = get_auth_client()
 | 
					 | 
				
			||||||
        response = client.get("/login?service=https://www.example.com")
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 302)
 | 
					 | 
				
			||||||
        self.assertTrue(response.has_header('Location'))
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            response['Location'].startswith(
 | 
					 | 
				
			||||||
                "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ticket_value = response['Location'].split('ticket=')[-1]
 | 
					 | 
				
			||||||
        user = models.User.objects.get(
 | 
					 | 
				
			||||||
            username=settings.CAS_TEST_USER,
 | 
					 | 
				
			||||||
            session_key=client.session.session_key
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertTrue(user)
 | 
					 | 
				
			||||||
        ticket = models.ServiceTicket.objects.get(value=ticket_value)
 | 
					 | 
				
			||||||
        self.assertEqual(ticket.user, user)
 | 
					 | 
				
			||||||
        self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
 | 
					 | 
				
			||||||
        self.assertEqual(ticket.validate, False)
 | 
					 | 
				
			||||||
        self.assertEqual(ticket.service_pattern, self.service_pattern)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_view_login_get_auth_denied_service(self):
 | 
					 | 
				
			||||||
        client = get_auth_client()
 | 
					 | 
				
			||||||
        response = client.get("/login?service=https://www.example.org")
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class LogoutTestCase(TestCase):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_logout_view(self):
 | 
					 | 
				
			||||||
        client = get_auth_client()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = client.get("/login")
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                b"You have successfully logged into "
 | 
					 | 
				
			||||||
                b"the Central Authentication Service"
 | 
					 | 
				
			||||||
            ) in response.content
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = client.get("/logout")
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertTrue(
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                b"You have successfully logged out from "
 | 
					 | 
				
			||||||
                b"the Central Authentication Service"
 | 
					 | 
				
			||||||
            ) in response.content
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = client.get("/login")
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertFalse(
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                b"You have successfully logged into "
 | 
					 | 
				
			||||||
                b"the Central Authentication Service"
 | 
					 | 
				
			||||||
            ) in response.content
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_logout_view_url(self):
 | 
					 | 
				
			||||||
        client = get_auth_client()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = client.get('/logout?url=https://www.example.com')
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 302)
 | 
					 | 
				
			||||||
        self.assertTrue(response.has_header("Location"))
 | 
					 | 
				
			||||||
        self.assertEqual(response["Location"], "https://www.example.com")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = client.get("/login")
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertFalse(
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                b"You have successfully logged into "
 | 
					 | 
				
			||||||
                b"the Central Authentication Service"
 | 
					 | 
				
			||||||
            ) in response.content
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_logout_view_service(self):
 | 
					 | 
				
			||||||
        client = get_auth_client()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = client.get('/logout?service=https://www.example.com')
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 302)
 | 
					 | 
				
			||||||
        self.assertTrue(response.has_header("Location"))
 | 
					 | 
				
			||||||
        self.assertEqual(response["Location"], "https://www.example.com")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        response = client.get("/login")
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertFalse(
 | 
					 | 
				
			||||||
            (
 | 
					 | 
				
			||||||
                b"You have successfully logged into "
 | 
					 | 
				
			||||||
                b"the Central Authentication Service"
 | 
					 | 
				
			||||||
            ) in response.content
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class AuthTestCase(TestCase):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
					 | 
				
			||||||
        self.service = 'https://www.example.com'
 | 
					 | 
				
			||||||
        models.ServicePattern.objects.create(
 | 
					 | 
				
			||||||
            name="example",
 | 
					 | 
				
			||||||
            pattern="^https://www\.example\.com(/.*)?$"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_auth_view_goodpass(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_SHARED_SECRET = 'test'
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.post(
 | 
					 | 
				
			||||||
            '/auth',
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                'username': settings.CAS_TEST_USER,
 | 
					 | 
				
			||||||
                'password': settings.CAS_TEST_PASSWORD,
 | 
					 | 
				
			||||||
                'service': self.service,
 | 
					 | 
				
			||||||
                'secret': 'test'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertEqual(response.content, b'yes\n')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_auth_view_badpass(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_SHARED_SECRET = 'test'
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.post(
 | 
					 | 
				
			||||||
            '/auth',
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                'username': settings.CAS_TEST_USER,
 | 
					 | 
				
			||||||
                'password': 'badpass',
 | 
					 | 
				
			||||||
                'service': self.service,
 | 
					 | 
				
			||||||
                'secret': 'test'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertEqual(response.content, b'no\n')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_auth_view_badservice(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_SHARED_SECRET = 'test'
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.post(
 | 
					 | 
				
			||||||
            '/auth',
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                'username': settings.CAS_TEST_USER,
 | 
					 | 
				
			||||||
                'password': settings.CAS_TEST_PASSWORD,
 | 
					 | 
				
			||||||
                'service': 'https://www.example.org',
 | 
					 | 
				
			||||||
                'secret': 'test'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertEqual(response.content, b'no\n')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_auth_view_badsecret(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_SHARED_SECRET = 'test'
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.post(
 | 
					 | 
				
			||||||
            '/auth',
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                'username': settings.CAS_TEST_USER,
 | 
					 | 
				
			||||||
                'password': settings.CAS_TEST_PASSWORD,
 | 
					 | 
				
			||||||
                'service': self.service,
 | 
					 | 
				
			||||||
                'secret': 'badsecret'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertEqual(response.content, b'no\n')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_auth_view_badsettings(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_SHARED_SECRET = None
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.post(
 | 
					 | 
				
			||||||
            '/auth',
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                'username': settings.CAS_TEST_USER,
 | 
					 | 
				
			||||||
                'password': settings.CAS_TEST_PASSWORD,
 | 
					 | 
				
			||||||
                'service': self.service,
 | 
					 | 
				
			||||||
                'secret': 'test'
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ValidateTestCase(TestCase):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
					 | 
				
			||||||
        self.service = 'https://www.example.com'
 | 
					 | 
				
			||||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
					 | 
				
			||||||
            name="example",
 | 
					 | 
				
			||||||
            pattern="^https://www\.example\.com(/.*)?$"
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_view_ok(self):
 | 
					 | 
				
			||||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertEqual(response.content, b'yes\ntest\n')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_view_badservice(self):
 | 
					 | 
				
			||||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.get(
 | 
					 | 
				
			||||||
            '/validate',
 | 
					 | 
				
			||||||
            {'ticket': ticket.value, 'service': "https://www.example.org"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertEqual(response.content, b'no\n')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_view_badticket(self):
 | 
					 | 
				
			||||||
        get_user_ticket_request(self.service)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.get(
 | 
					 | 
				
			||||||
            '/validate',
 | 
					 | 
				
			||||||
            {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
        self.assertEqual(response.content, b'no\n')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ValidateServiceTestCase(TestCase):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
					 | 
				
			||||||
        self.service = 'http://127.0.0.1:45678'
 | 
					 | 
				
			||||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
					 | 
				
			||||||
            name="localhost",
 | 
					 | 
				
			||||||
            pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
 | 
					 | 
				
			||||||
            proxy_callback=True
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_service_view_ok(self):
 | 
					 | 
				
			||||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        sucess = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:authenticationSuccess",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertTrue(sucess)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
					 | 
				
			||||||
        self.assertEqual(len(users), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(users[0].text, settings.CAS_TEST_USER)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        attributes = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:attributes",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(attributes), 1)
 | 
					 | 
				
			||||||
        attrs1 = {}
 | 
					 | 
				
			||||||
        for attr in attributes[0]:
 | 
					 | 
				
			||||||
            attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
					 | 
				
			||||||
        self.assertEqual(len(attributes), len(attrs1))
 | 
					 | 
				
			||||||
        attrs2 = {}
 | 
					 | 
				
			||||||
        for attr in attributes:
 | 
					 | 
				
			||||||
            attrs2[attr.attrib['name']] = attr.attrib['value']
 | 
					 | 
				
			||||||
        self.assertEqual(attrs1, attrs2)
 | 
					 | 
				
			||||||
        self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_service_view_badservice(self):
 | 
					 | 
				
			||||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        bad_service = "https://www.example.org"
 | 
					 | 
				
			||||||
        response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        error = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:authenticationFailure",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(error), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE")
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].text, bad_service)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_service_view_badticket_goodprefix(self):
 | 
					 | 
				
			||||||
        get_user_ticket_request(self.service)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
 | 
					 | 
				
			||||||
        response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        error = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:authenticationFailure",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(error), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].text, 'ticket not found')
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_service_view_badticket_badprefix(self):
 | 
					 | 
				
			||||||
        get_user_ticket_request(self.service)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        bad_ticket = "RANDOM"
 | 
					 | 
				
			||||||
        response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        error = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:authenticationFailure",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(error), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].text, bad_ticket)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_service_view_ok_pgturl(self):
 | 
					 | 
				
			||||||
        (host, port) = utils.PGTUrlHandler.run()[1:3]
 | 
					 | 
				
			||||||
        service = "http://%s:%s" % (host, port)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ticket = get_user_ticket_request(service)[1]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.get(
 | 
					 | 
				
			||||||
            '/serviceValidate',
 | 
					 | 
				
			||||||
            {'ticket': ticket.value, 'service': service, 'pgtUrl': service}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        pgt_params = utils.PGTUrlHandler.PARAMS.copy()
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        pgtiou = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:proxyGrantingTicket",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(pgtiou), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
 | 
					 | 
				
			||||||
        self.assertTrue("pgtId" in pgt_params)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_service_pgturl_bad_proxy_callback(self):
 | 
					 | 
				
			||||||
        self.service_pattern.proxy_callback = False
 | 
					 | 
				
			||||||
        self.service_pattern.save()
 | 
					 | 
				
			||||||
        ticket = get_user_ticket_request(self.service)[1]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client = Client()
 | 
					 | 
				
			||||||
        response = client.get(
 | 
					 | 
				
			||||||
            '/serviceValidate',
 | 
					 | 
				
			||||||
            {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        error = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:authenticationFailure",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(error), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].text, "callback url not allowed by configuration")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ProxyTestCase(TestCase):
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def setUp(self):
 | 
					 | 
				
			||||||
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
 | 
					 | 
				
			||||||
        self.service = 'http://127.0.0.1'
 | 
					 | 
				
			||||||
        self.service_pattern = models.ServicePattern.objects.create(
 | 
					 | 
				
			||||||
            name="localhost",
 | 
					 | 
				
			||||||
            pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
 | 
					 | 
				
			||||||
            proxy=True,
 | 
					 | 
				
			||||||
            proxy_callback=True
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_proxy_ok(self):
 | 
					 | 
				
			||||||
        params = get_pgt()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # get a proxy ticket
 | 
					 | 
				
			||||||
        client1 = Client()
 | 
					 | 
				
			||||||
        response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
					 | 
				
			||||||
        self.assertTrue(sucess)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        proxy_ticket = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:proxyTicket",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(proxy_ticket), 1)
 | 
					 | 
				
			||||||
        proxy_ticket = proxy_ticket[0].text
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # validate the proxy ticket
 | 
					 | 
				
			||||||
        client2 = Client()
 | 
					 | 
				
			||||||
        response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        sucess = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:authenticationSuccess",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertTrue(sucess)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # check that the proxy is send to the end service
 | 
					 | 
				
			||||||
        proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
					 | 
				
			||||||
        self.assertEqual(len(proxies), 1)
 | 
					 | 
				
			||||||
        proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
					 | 
				
			||||||
        self.assertEqual(len(proxy), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(proxy[0].text, params["service"])
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # same tests than those for serviceValidate
 | 
					 | 
				
			||||||
        users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
					 | 
				
			||||||
        self.assertEqual(len(users), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(users[0].text, settings.CAS_TEST_USER)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        attributes = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:attributes",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(attributes), 1)
 | 
					 | 
				
			||||||
        attrs1 = {}
 | 
					 | 
				
			||||||
        for attr in attributes[0]:
 | 
					 | 
				
			||||||
            attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
					 | 
				
			||||||
        self.assertEqual(len(attributes), len(attrs1))
 | 
					 | 
				
			||||||
        attrs2 = {}
 | 
					 | 
				
			||||||
        for attr in attributes:
 | 
					 | 
				
			||||||
            attrs2[attr.attrib['name']] = attr.attrib['value']
 | 
					 | 
				
			||||||
        self.assertEqual(attrs1, attrs2)
 | 
					 | 
				
			||||||
        self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def test_validate_proxy_bad(self):
 | 
					 | 
				
			||||||
        params = get_pgt()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # bad PGT
 | 
					 | 
				
			||||||
        client1 = Client()
 | 
					 | 
				
			||||||
        response = client1.get(
 | 
					 | 
				
			||||||
            '/proxy',
 | 
					 | 
				
			||||||
            {
 | 
					 | 
				
			||||||
                'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
 | 
					 | 
				
			||||||
                'targetService': params['service']
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        error = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:authenticationFailure",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(error), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
 | 
					 | 
				
			||||||
        self.assertEqual(
 | 
					 | 
				
			||||||
            error[0].text,
 | 
					 | 
				
			||||||
            "PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # bad targetService
 | 
					 | 
				
			||||||
        client2 = Client()
 | 
					 | 
				
			||||||
        response = client2.get(
 | 
					 | 
				
			||||||
            '/proxy',
 | 
					 | 
				
			||||||
            {'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        error = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:authenticationFailure",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(error), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].text, "https://www.example.org")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # service do not allow proxy ticket
 | 
					 | 
				
			||||||
        self.service_pattern.proxy = False
 | 
					 | 
				
			||||||
        self.service_pattern.save()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        client3 = Client()
 | 
					 | 
				
			||||||
        response = client3.get(
 | 
					 | 
				
			||||||
            '/proxy',
 | 
					 | 
				
			||||||
            {'pgt': params['pgtId'], 'targetService': params['service']}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        root = etree.fromstring(response.content)
 | 
					 | 
				
			||||||
        error = root.xpath(
 | 
					 | 
				
			||||||
            "//cas:authenticationFailure",
 | 
					 | 
				
			||||||
            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        self.assertEqual(len(error), 1)
 | 
					 | 
				
			||||||
        self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
 | 
					 | 
				
			||||||
        self.assertEqual(
 | 
					 | 
				
			||||||
            error[0].text,
 | 
					 | 
				
			||||||
            'the service %s do not allow proxy ticket' % params['service']
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
							
								
								
									
										0
									
								
								cas_server/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								cas_server/tests/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										193
									
								
								cas_server/tests/mixin.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										193
									
								
								cas_server/tests/mixin.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,193 @@
 | 
				
			|||||||
 | 
					# ⁻*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
				
			||||||
 | 
					# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
				
			||||||
 | 
					# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
				
			||||||
 | 
					# more details.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# You should have received a copy of the GNU General Public License version 3
 | 
				
			||||||
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# (c) 2016 Valentin Samir
 | 
				
			||||||
 | 
					"""Some mixin classes for tests"""
 | 
				
			||||||
 | 
					from cas_server.default_settings import settings
 | 
				
			||||||
 | 
					from django.utils import timezone
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
 | 
					from lxml import etree
 | 
				
			||||||
 | 
					from datetime import timedelta
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cas_server import models
 | 
				
			||||||
 | 
					from cas_server.tests.utils import get_auth_client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BaseServicePattern(object):
 | 
				
			||||||
 | 
					    """Mixing for setting up service pattern for testing"""
 | 
				
			||||||
 | 
					    def setup_service_patterns(self, proxy=False):
 | 
				
			||||||
 | 
					        """setting up service pattern"""
 | 
				
			||||||
 | 
					        # For general purpose testing
 | 
				
			||||||
 | 
					        self.service = "https://www.example.com"
 | 
				
			||||||
 | 
					        self.service_pattern = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="example",
 | 
				
			||||||
 | 
					            pattern="^https://www\.example\.com(/.*)?$",
 | 
				
			||||||
 | 
					            proxy=proxy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # For testing the restrict_users attributes
 | 
				
			||||||
 | 
					        self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
 | 
				
			||||||
 | 
					        self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="restrict_user_fail",
 | 
				
			||||||
 | 
					            pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
 | 
				
			||||||
 | 
					            restrict_users=True,
 | 
				
			||||||
 | 
					            proxy=proxy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.service_restrict_user_success = "https://restrict_user_success.example.com"
 | 
				
			||||||
 | 
					        self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="restrict_user_success",
 | 
				
			||||||
 | 
					            pattern="^https://restrict_user_success\.example\.com(/.*)?$",
 | 
				
			||||||
 | 
					            restrict_users=True,
 | 
				
			||||||
 | 
					            proxy=proxy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        models.Username.objects.create(
 | 
				
			||||||
 | 
					            value=settings.CAS_TEST_USER,
 | 
				
			||||||
 | 
					            service_pattern=self.service_pattern_restrict_user_success
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # For testing the user attributes filtering conditions
 | 
				
			||||||
 | 
					        self.service_filter_fail = "https://filter_fail.example.com"
 | 
				
			||||||
 | 
					        self.service_pattern_filter_fail = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="filter_fail",
 | 
				
			||||||
 | 
					            pattern="^https://filter_fail\.example\.com(/.*)?$",
 | 
				
			||||||
 | 
					            proxy=proxy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        models.FilterAttributValue.objects.create(
 | 
				
			||||||
 | 
					            attribut="right",
 | 
				
			||||||
 | 
					            pattern="^admin$",
 | 
				
			||||||
 | 
					            service_pattern=self.service_pattern_filter_fail
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.service_filter_fail_alt = "https://filter_fail_alt.example.com"
 | 
				
			||||||
 | 
					        self.service_pattern_filter_fail_alt = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="filter_fail_alt",
 | 
				
			||||||
 | 
					            pattern="^https://filter_fail_alt\.example\.com(/.*)?$",
 | 
				
			||||||
 | 
					            proxy=proxy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        models.FilterAttributValue.objects.create(
 | 
				
			||||||
 | 
					            attribut="nom",
 | 
				
			||||||
 | 
					            pattern="^toto$",
 | 
				
			||||||
 | 
					            service_pattern=self.service_pattern_filter_fail_alt
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.service_filter_success = "https://filter_success.example.com"
 | 
				
			||||||
 | 
					        self.service_pattern_filter_success = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="filter_success",
 | 
				
			||||||
 | 
					            pattern="^https://filter_success\.example\.com(/.*)?$",
 | 
				
			||||||
 | 
					            proxy=proxy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        models.FilterAttributValue.objects.create(
 | 
				
			||||||
 | 
					            attribut="email",
 | 
				
			||||||
 | 
					            pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
 | 
				
			||||||
 | 
					            service_pattern=self.service_pattern_filter_success
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # For testing the user_field attributes
 | 
				
			||||||
 | 
					        self.service_field_needed_fail = "https://field_needed_fail.example.com"
 | 
				
			||||||
 | 
					        self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="field_needed_fail",
 | 
				
			||||||
 | 
					            pattern="^https://field_needed_fail\.example\.com(/.*)?$",
 | 
				
			||||||
 | 
					            user_field="uid",
 | 
				
			||||||
 | 
					            proxy=proxy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.service_field_needed_success = "https://field_needed_success.example.com"
 | 
				
			||||||
 | 
					        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="field_needed_success",
 | 
				
			||||||
 | 
					            pattern="^https://field_needed_success\.example\.com(/.*)?$",
 | 
				
			||||||
 | 
					            user_field="alias",
 | 
				
			||||||
 | 
					            proxy=proxy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.service_field_needed_success_alt = "https://field_needed_success_alt.example.com"
 | 
				
			||||||
 | 
					        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="field_needed_success_alt",
 | 
				
			||||||
 | 
					            pattern="^https://field_needed_success_alt\.example\.com(/.*)?$",
 | 
				
			||||||
 | 
					            user_field="nom",
 | 
				
			||||||
 | 
					            proxy=proxy,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class XmlContent(object):
 | 
				
			||||||
 | 
					    """Mixin for test on CAS XML responses"""
 | 
				
			||||||
 | 
					    def assert_error(self, response, code, text=None):
 | 
				
			||||||
 | 
					        """Assert a validation error"""
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					        root = etree.fromstring(response.content)
 | 
				
			||||||
 | 
					        error = root.xpath(
 | 
				
			||||||
 | 
					            "//cas:authenticationFailure",
 | 
				
			||||||
 | 
					            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(len(error), 1)
 | 
				
			||||||
 | 
					        self.assertEqual(error[0].attrib['code'], code)
 | 
				
			||||||
 | 
					        if text is not None:
 | 
				
			||||||
 | 
					            self.assertEqual(error[0].text, text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def assert_success(self, response, username, original_attributes):
 | 
				
			||||||
 | 
					        """assert a ticket validation success"""
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        root = etree.fromstring(response.content)
 | 
				
			||||||
 | 
					        sucess = root.xpath(
 | 
				
			||||||
 | 
					            "//cas:authenticationSuccess",
 | 
				
			||||||
 | 
					            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertTrue(sucess)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
				
			||||||
 | 
					        self.assertEqual(len(users), 1)
 | 
				
			||||||
 | 
					        self.assertEqual(users[0].text, username)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attributes = root.xpath(
 | 
				
			||||||
 | 
					            "//cas:attributes",
 | 
				
			||||||
 | 
					            namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(len(attributes), 1)
 | 
				
			||||||
 | 
					        attrs1 = set()
 | 
				
			||||||
 | 
					        for attr in attributes[0]:
 | 
				
			||||||
 | 
					            attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
 | 
				
			||||||
 | 
					        self.assertEqual(len(attributes), len(attrs1))
 | 
				
			||||||
 | 
					        attrs2 = set()
 | 
				
			||||||
 | 
					        for attr in attributes:
 | 
				
			||||||
 | 
					            attrs2.add((attr.attrib['name'], attr.attrib['value']))
 | 
				
			||||||
 | 
					        original = set()
 | 
				
			||||||
 | 
					        for key, value in original_attributes.items():
 | 
				
			||||||
 | 
					            if isinstance(value, list):
 | 
				
			||||||
 | 
					                for sub_value in value:
 | 
				
			||||||
 | 
					                    original.add((key, sub_value))
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                original.add((key, value))
 | 
				
			||||||
 | 
					        self.assertEqual(attrs1, attrs2)
 | 
				
			||||||
 | 
					        self.assertEqual(attrs1, original)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return root
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UserModels(object):
 | 
				
			||||||
 | 
					    """Mixin for test on CAS user models"""
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def expire_user():
 | 
				
			||||||
 | 
					        """return an expired user"""
 | 
				
			||||||
 | 
					        client = get_auth_client()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        new_date = timezone.now() - timedelta(seconds=(settings.SESSION_COOKIE_AGE + 600))
 | 
				
			||||||
 | 
					        models.User.objects.filter(
 | 
				
			||||||
 | 
					            username=settings.CAS_TEST_USER,
 | 
				
			||||||
 | 
					            session_key=client.session.session_key
 | 
				
			||||||
 | 
					        ).update(date=new_date)
 | 
				
			||||||
 | 
					        return client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def get_user(client):
 | 
				
			||||||
 | 
					        """return the user associated with an authenticated client"""
 | 
				
			||||||
 | 
					        return models.User.objects.get(
 | 
				
			||||||
 | 
					            username=settings.CAS_TEST_USER,
 | 
				
			||||||
 | 
					            session_key=client.session.session_key
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
@@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
 | 
				
			|||||||
    'django.middleware.locale.LocaleMiddleware',
 | 
					    'django.middleware.locale.LocaleMiddleware',
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ROOT_URLCONF = 'cas_server.urls'
 | 
					ROOT_URLCONF = 'cas_server.tests.urls'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Database
 | 
					# Database
 | 
				
			||||||
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
 | 
					# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
 | 
				
			||||||
@@ -60,6 +60,7 @@ ROOT_URLCONF = 'cas_server.urls'
 | 
				
			|||||||
DATABASES = {
 | 
					DATABASES = {
 | 
				
			||||||
    'default': {
 | 
					    'default': {
 | 
				
			||||||
        'ENGINE': 'django.db.backends.sqlite3',
 | 
					        'ENGINE': 'django.db.backends.sqlite3',
 | 
				
			||||||
 | 
					        'NAME': ':memory:',
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							
								
								
									
										166
									
								
								cas_server/tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								cas_server/tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,166 @@
 | 
				
			|||||||
 | 
					# ⁻*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
				
			||||||
 | 
					# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
				
			||||||
 | 
					# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
				
			||||||
 | 
					# more details.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# You should have received a copy of the GNU General Public License version 3
 | 
				
			||||||
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# (c) 2016 Valentin Samir
 | 
				
			||||||
 | 
					"""Tests module for models"""
 | 
				
			||||||
 | 
					from cas_server.default_settings import settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.test import TestCase
 | 
				
			||||||
 | 
					from django.test.utils import override_settings
 | 
				
			||||||
 | 
					from django.utils import timezone
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from datetime import timedelta
 | 
				
			||||||
 | 
					from importlib import import_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cas_server import models
 | 
				
			||||||
 | 
					from cas_server.tests.utils import get_auth_client, HttpParamsHandler
 | 
				
			||||||
 | 
					from cas_server.tests.mixin import UserModels, BaseServicePattern
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
 | 
				
			||||||
 | 
					class UserTestCase(TestCase, UserModels):
 | 
				
			||||||
 | 
					    """tests for the user models"""
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        """Prepare the test context"""
 | 
				
			||||||
 | 
					        self.service = 'http://127.0.0.1:45678'
 | 
				
			||||||
 | 
					        self.service_pattern = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="localhost",
 | 
				
			||||||
 | 
					            pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
 | 
				
			||||||
 | 
					            single_log_out=True
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_clean_old_entries(self):
 | 
				
			||||||
 | 
					        """test clean_old_entries"""
 | 
				
			||||||
 | 
					        # get an authenticated client
 | 
				
			||||||
 | 
					        client = self.expire_user()
 | 
				
			||||||
 | 
					        # assert the user exists before being cleaned
 | 
				
			||||||
 | 
					        self.assertEqual(len(models.User.objects.all()), 1)
 | 
				
			||||||
 | 
					        # assert the last activity date is before the expiry date
 | 
				
			||||||
 | 
					        self.assertTrue(
 | 
				
			||||||
 | 
					            self.get_user(client).date < (
 | 
				
			||||||
 | 
					                timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # delete old inactive users
 | 
				
			||||||
 | 
					        models.User.clean_old_entries()
 | 
				
			||||||
 | 
					        # assert the user has being well delete
 | 
				
			||||||
 | 
					        self.assertEqual(len(models.User.objects.all()), 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_clean_deleted_sessions(self):
 | 
				
			||||||
 | 
					        """test clean_deleted_sessions"""
 | 
				
			||||||
 | 
					        # get an authenticated client
 | 
				
			||||||
 | 
					        client1 = get_auth_client()
 | 
				
			||||||
 | 
					        client2 = get_auth_client()
 | 
				
			||||||
 | 
					        # generate a ticket to fire SLO during user cleaning (SLO should fail a nothing listen
 | 
				
			||||||
 | 
					        # on self.service)
 | 
				
			||||||
 | 
					        ticket = self.get_user(client1).get_ticket(
 | 
				
			||||||
 | 
					            models.ServiceTicket,
 | 
				
			||||||
 | 
					            self.service,
 | 
				
			||||||
 | 
					            self.service_pattern,
 | 
				
			||||||
 | 
					            renew=False
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        ticket.validate = True
 | 
				
			||||||
 | 
					        ticket.save()
 | 
				
			||||||
 | 
					        # simulated expired session being garbage collected for client1
 | 
				
			||||||
 | 
					        session = SessionStore(session_key=client1.session.session_key)
 | 
				
			||||||
 | 
					        session.flush()
 | 
				
			||||||
 | 
					        # assert the user exists before being cleaned
 | 
				
			||||||
 | 
					        self.assertTrue(self.get_user(client1))
 | 
				
			||||||
 | 
					        self.assertTrue(self.get_user(client2))
 | 
				
			||||||
 | 
					        self.assertEqual(len(models.User.objects.all()), 2)
 | 
				
			||||||
 | 
					        # session has being remove so the user of client1 is no longer authenticated
 | 
				
			||||||
 | 
					        self.assertFalse(client1.session.get("authenticated"))
 | 
				
			||||||
 | 
					        # the user a client2 should still be authenticated
 | 
				
			||||||
 | 
					        self.assertTrue(client2.session.get("authenticated"))
 | 
				
			||||||
 | 
					        # the user should be deleted
 | 
				
			||||||
 | 
					        models.User.clean_deleted_sessions()
 | 
				
			||||||
 | 
					        # assert the user with expired sessions has being well deleted but the other remain
 | 
				
			||||||
 | 
					        self.assertEqual(len(models.User.objects.all()), 1)
 | 
				
			||||||
 | 
					        self.assertFalse(models.ServiceTicket.objects.all())
 | 
				
			||||||
 | 
					        self.assertTrue(client2.session.get("authenticated"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
 | 
				
			||||||
 | 
					class TicketTestCase(TestCase, UserModels, BaseServicePattern):
 | 
				
			||||||
 | 
					    """tests for the tickets models"""
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        """Prepare the test context"""
 | 
				
			||||||
 | 
					        self.setup_service_patterns()
 | 
				
			||||||
 | 
					        self.service = 'http://127.0.0.1:45678'
 | 
				
			||||||
 | 
					        self.service_pattern = models.ServicePattern.objects.create(
 | 
				
			||||||
 | 
					            name="localhost",
 | 
				
			||||||
 | 
					            pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
 | 
				
			||||||
 | 
					            single_log_out=True
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def get_ticket(
 | 
				
			||||||
 | 
					        user,
 | 
				
			||||||
 | 
					        ticket_class,
 | 
				
			||||||
 | 
					        service,
 | 
				
			||||||
 | 
					        service_pattern,
 | 
				
			||||||
 | 
					        renew=False,
 | 
				
			||||||
 | 
					        validate=False,
 | 
				
			||||||
 | 
					        validity_expired=False,
 | 
				
			||||||
 | 
					        timeout_expired=False,
 | 
				
			||||||
 | 
					        single_log_out=False,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        """Return a ticket"""
 | 
				
			||||||
 | 
					        ticket = user.get_ticket(ticket_class, service, service_pattern, renew)
 | 
				
			||||||
 | 
					        ticket.validate = validate
 | 
				
			||||||
 | 
					        ticket.single_log_out = single_log_out
 | 
				
			||||||
 | 
					        if validity_expired:
 | 
				
			||||||
 | 
					            ticket.creation = min(
 | 
				
			||||||
 | 
					                ticket.creation,
 | 
				
			||||||
 | 
					                (timezone.now() - timedelta(seconds=(ticket_class.VALIDITY + 10)))
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        if timeout_expired:
 | 
				
			||||||
 | 
					            ticket.creation = min(
 | 
				
			||||||
 | 
					                ticket.creation,
 | 
				
			||||||
 | 
					                (timezone.now() - timedelta(seconds=(ticket_class.TIMEOUT + 10)))
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        ticket.save()
 | 
				
			||||||
 | 
					        return ticket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_clean_old_service_ticket(self):
 | 
				
			||||||
 | 
					        """test tickets clean_old_entries"""
 | 
				
			||||||
 | 
					        # ge an authenticated client
 | 
				
			||||||
 | 
					        client = get_auth_client()
 | 
				
			||||||
 | 
					        # get the user associated to the client
 | 
				
			||||||
 | 
					        user = self.get_user(client)
 | 
				
			||||||
 | 
					        # generate a ticket for that client, waiting for validation
 | 
				
			||||||
 | 
					        self.get_ticket(user, models.ServiceTicket, self.service, self.service_pattern)
 | 
				
			||||||
 | 
					        # generate another ticket for those validation time has expired
 | 
				
			||||||
 | 
					        self.get_ticket(
 | 
				
			||||||
 | 
					            user, models.ServiceTicket,
 | 
				
			||||||
 | 
					            self.service, self.service_pattern, validity_expired=True
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        (httpd, host, port) = HttpParamsHandler.run()[0:3]
 | 
				
			||||||
 | 
					        service = "http://%s:%s" % (host, port)
 | 
				
			||||||
 | 
					        # generate a ticket with SLO having timeout reach
 | 
				
			||||||
 | 
					        self.get_ticket(
 | 
				
			||||||
 | 
					            user, models.ServiceTicket,
 | 
				
			||||||
 | 
					            service, self.service_pattern, timeout_expired=True,
 | 
				
			||||||
 | 
					            validate=True, single_log_out=True
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # there should be 3 tickets in the db
 | 
				
			||||||
 | 
					        self.assertEqual(len(models.ServiceTicket.objects.all()), 3)
 | 
				
			||||||
 | 
					        # we call the clean_old_entries method that should delete validated non SLO ticket and
 | 
				
			||||||
 | 
					        # expired non validated ticket and send SLO for SLO expired ticket before deleting then
 | 
				
			||||||
 | 
					        models.ServiceTicket.clean_old_entries()
 | 
				
			||||||
 | 
					        params = httpd.PARAMS
 | 
				
			||||||
 | 
					        # we successfully got a SLO request
 | 
				
			||||||
 | 
					        self.assertTrue(b'logoutRequest' in params and params[b'logoutRequest'])
 | 
				
			||||||
 | 
					        # only 1 ticket remain in the db
 | 
				
			||||||
 | 
					        self.assertEqual(len(models.ServiceTicket.objects.all()), 1)
 | 
				
			||||||
							
								
								
									
										191
									
								
								cas_server/tests/test_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										191
									
								
								cas_server/tests/test_utils.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,191 @@
 | 
				
			|||||||
 | 
					# ⁻*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
				
			||||||
 | 
					# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
				
			||||||
 | 
					# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
				
			||||||
 | 
					# more details.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# You should have received a copy of the GNU General Public License version 3
 | 
				
			||||||
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# (c) 2016 Valentin Samir
 | 
				
			||||||
 | 
					"""Tests module for utils"""
 | 
				
			||||||
 | 
					from django.test import TestCase
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import six
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cas_server import utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CheckPasswordCase(TestCase):
 | 
				
			||||||
 | 
					    """Tests for the utils function `utils.check_password`"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        """Generate random bytes string that will be used ass passwords"""
 | 
				
			||||||
 | 
					        self.password1 = utils.gen_saml_id()
 | 
				
			||||||
 | 
					        self.password2 = utils.gen_saml_id()
 | 
				
			||||||
 | 
					        if not isinstance(self.password1, bytes):  # pragma: no cover executed only in python3
 | 
				
			||||||
 | 
					            self.password1 = self.password1.encode("utf8")
 | 
				
			||||||
 | 
					            self.password2 = self.password2.encode("utf8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_setup(self):
 | 
				
			||||||
 | 
					        """check that generated password are bytes"""
 | 
				
			||||||
 | 
					        self.assertIsInstance(self.password1, bytes)
 | 
				
			||||||
 | 
					        self.assertIsInstance(self.password2, bytes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_plain(self):
 | 
				
			||||||
 | 
					        """test the plain auth method"""
 | 
				
			||||||
 | 
					        self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
 | 
				
			||||||
 | 
					        self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_plain_unicode(self):
 | 
				
			||||||
 | 
					        """test the plain auth method with unicode input"""
 | 
				
			||||||
 | 
					        self.assertTrue(
 | 
				
			||||||
 | 
					            utils.check_password(
 | 
				
			||||||
 | 
					                "plain",
 | 
				
			||||||
 | 
					                self.password1.decode("utf8"),
 | 
				
			||||||
 | 
					                self.password1.decode("utf8"),
 | 
				
			||||||
 | 
					                "utf8"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertFalse(
 | 
				
			||||||
 | 
					            utils.check_password(
 | 
				
			||||||
 | 
					                "plain",
 | 
				
			||||||
 | 
					                self.password1.decode("utf8"),
 | 
				
			||||||
 | 
					                self.password2.decode("utf8"),
 | 
				
			||||||
 | 
					                "utf8"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_crypt(self):
 | 
				
			||||||
 | 
					        """test the crypt auth method"""
 | 
				
			||||||
 | 
					        salts = ["$6$UVVAQvrMyXMF3FF3", "aa"]
 | 
				
			||||||
 | 
					        hashed_password1 = []
 | 
				
			||||||
 | 
					        for salt in salts:
 | 
				
			||||||
 | 
					            if six.PY3:
 | 
				
			||||||
 | 
					                hashed_password1.append(
 | 
				
			||||||
 | 
					                    utils.crypt.crypt(
 | 
				
			||||||
 | 
					                        self.password1.decode("utf8"),
 | 
				
			||||||
 | 
					                        salt
 | 
				
			||||||
 | 
					                    ).encode("utf8")
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                hashed_password1.append(utils.crypt.crypt(self.password1, salt))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for hp1 in hashed_password1:
 | 
				
			||||||
 | 
					            self.assertTrue(utils.check_password("crypt", self.password1, hp1, "utf8"))
 | 
				
			||||||
 | 
					            self.assertFalse(utils.check_password("crypt", self.password2, hp1, "utf8"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					            utils.check_password("crypt", self.password1, b"$truc$s$dsdsd", "utf8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_ldap_password_valid(self):
 | 
				
			||||||
 | 
					        """test the ldap auth method with all the schemes"""
 | 
				
			||||||
 | 
					        salt = b"UVVAQvrMyXMF3FF3"
 | 
				
			||||||
 | 
					        schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
 | 
				
			||||||
 | 
					        schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
 | 
				
			||||||
 | 
					        hashed_password1 = []
 | 
				
			||||||
 | 
					        for scheme in schemes_salt:
 | 
				
			||||||
 | 
					            hashed_password1.append(
 | 
				
			||||||
 | 
					                utils.LdapHashUserPassword.hash(scheme, self.password1, salt, charset="utf8")
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        for scheme in schemes_nosalt:
 | 
				
			||||||
 | 
					            hashed_password1.append(
 | 
				
			||||||
 | 
					                utils.LdapHashUserPassword.hash(scheme, self.password1, charset="utf8")
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        hashed_password1.append(
 | 
				
			||||||
 | 
					            utils.LdapHashUserPassword.hash(
 | 
				
			||||||
 | 
					                b"{CRYPT}",
 | 
				
			||||||
 | 
					                self.password1,
 | 
				
			||||||
 | 
					                b"$6$UVVAQvrMyXMF3FF3",
 | 
				
			||||||
 | 
					                charset="utf8"
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        for hp1 in hashed_password1:
 | 
				
			||||||
 | 
					            self.assertIsInstance(hp1, bytes)
 | 
				
			||||||
 | 
					            self.assertTrue(utils.check_password("ldap", self.password1, hp1, "utf8"))
 | 
				
			||||||
 | 
					            self.assertFalse(utils.check_password("ldap", self.password2, hp1, "utf8"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_ldap_password_fail(self):
 | 
				
			||||||
 | 
					        """test the ldap auth method with malformed hash or bad schemes"""
 | 
				
			||||||
 | 
					        salt = b"UVVAQvrMyXMF3FF3"
 | 
				
			||||||
 | 
					        schemes_salt = [b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}"]
 | 
				
			||||||
 | 
					        schemes_nosalt = [b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # first try to hash with bad parameters
 | 
				
			||||||
 | 
					        with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
 | 
				
			||||||
 | 
					            utils.LdapHashUserPassword.hash(b"TOTO", self.password1)
 | 
				
			||||||
 | 
					        for scheme in schemes_nosalt:
 | 
				
			||||||
 | 
					            with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
 | 
				
			||||||
 | 
					                utils.LdapHashUserPassword.hash(scheme, self.password1, salt)
 | 
				
			||||||
 | 
					        for scheme in schemes_salt:
 | 
				
			||||||
 | 
					            with self.assertRaises(utils.LdapHashUserPassword.BadScheme):
 | 
				
			||||||
 | 
					                utils.LdapHashUserPassword.hash(scheme, self.password1)
 | 
				
			||||||
 | 
					        with self.assertRaises(utils.LdapHashUserPassword.BadSalt):
 | 
				
			||||||
 | 
					            utils.LdapHashUserPassword.hash(b'{CRYPT}', self.password1, b"$truc$toto")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # then try to check hash with bad hashes
 | 
				
			||||||
 | 
					        with self.assertRaises(utils.LdapHashUserPassword.BadHash):
 | 
				
			||||||
 | 
					            utils.check_password("ldap", self.password1, b"TOTOssdsdsd", "utf8")
 | 
				
			||||||
 | 
					        for scheme in schemes_salt:
 | 
				
			||||||
 | 
					            with self.assertRaises(utils.LdapHashUserPassword.BadHash):
 | 
				
			||||||
 | 
					                utils.check_password("ldap", self.password1, scheme + b"dG90b3E8ZHNkcw==", "utf8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_hex(self):
 | 
				
			||||||
 | 
					        """test all the hex_HASH method: the hashed password is a simple hash of the password"""
 | 
				
			||||||
 | 
					        hashes = ["md5", "sha1", "sha224", "sha256", "sha384", "sha512"]
 | 
				
			||||||
 | 
					        hashed_password1 = []
 | 
				
			||||||
 | 
					        for hash in hashes:
 | 
				
			||||||
 | 
					            hashed_password1.append(
 | 
				
			||||||
 | 
					                ("hex_%s" % hash, getattr(utils.hashlib, hash)(self.password1).hexdigest())
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        for (method, hp1) in hashed_password1:
 | 
				
			||||||
 | 
					            self.assertTrue(utils.check_password(method, self.password1, hp1, "utf8"))
 | 
				
			||||||
 | 
					            self.assertFalse(utils.check_password(method, self.password2, hp1, "utf8"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_bad_method(self):
 | 
				
			||||||
 | 
					        """try to check password with a bad method, should raise a ValueError"""
 | 
				
			||||||
 | 
					        with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					            utils.check_password("test", self.password1, b"$truc$s$dsdsd", "utf8")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class UtilsTestCase(TestCase):
 | 
				
			||||||
 | 
					    """tests for some little utils functions"""
 | 
				
			||||||
 | 
					    def test_import_attr(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					            test the import_attr function. Feeded with a dotted path string, it should
 | 
				
			||||||
 | 
					            import the dotted module and return that last componend of the dotted path
 | 
				
			||||||
 | 
					            (function, class or variable)
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        with self.assertRaises(ImportError):
 | 
				
			||||||
 | 
					            utils.import_attr('toto.titi.tutu')
 | 
				
			||||||
 | 
					        with self.assertRaises(AttributeError):
 | 
				
			||||||
 | 
					            utils.import_attr('cas_server.utils.toto')
 | 
				
			||||||
 | 
					        with self.assertRaises(ValueError):
 | 
				
			||||||
 | 
					            utils.import_attr('toto')
 | 
				
			||||||
 | 
					        self.assertEqual(
 | 
				
			||||||
 | 
					            utils.import_attr('cas_server.default_app_config'),
 | 
				
			||||||
 | 
					            'cas_server.apps.CasAppConfig'
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        self.assertEqual(utils.import_attr(utils), utils)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_update_url(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					            test the update_url function. Given an url with possible GET parameter and a dict
 | 
				
			||||||
 | 
					            the function build a url with GET parameters updated by the dictionnary
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        url1 = utils.update_url(u"https://www.example.com?toto=1", {u"tata": u"2"})
 | 
				
			||||||
 | 
					        url2 = utils.update_url(b"https://www.example.com?toto=1", {b"tata": b"2"})
 | 
				
			||||||
 | 
					        self.assertEqual(url1, u"https://www.example.com?tata=2&toto=1")
 | 
				
			||||||
 | 
					        self.assertEqual(url2, u"https://www.example.com?tata=2&toto=1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        url3 = utils.update_url(u"https://www.example.com?toto=1", {u"toto": u"2"})
 | 
				
			||||||
 | 
					        self.assertEqual(url3, u"https://www.example.com?toto=2")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_crypt_salt_is_valid(self):
 | 
				
			||||||
 | 
					        """test the function crypt_salt_is_valid who test if a crypt salt is valid"""
 | 
				
			||||||
 | 
					        self.assertFalse(utils.crypt_salt_is_valid(""))  # len 0
 | 
				
			||||||
 | 
					        self.assertFalse(utils.crypt_salt_is_valid("a"))  # len 1
 | 
				
			||||||
 | 
					        self.assertFalse(utils.crypt_salt_is_valid("$$"))  # start with $ followed by $
 | 
				
			||||||
 | 
					        self.assertFalse(utils.crypt_salt_is_valid("$toto"))  # start with $ but no secondary $
 | 
				
			||||||
 | 
					        self.assertFalse(utils.crypt_salt_is_valid("$toto$toto"))  # algorithm toto not known
 | 
				
			||||||
							
								
								
									
										1813
									
								
								cas_server/tests/test_view.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1813
									
								
								cas_server/tests/test_view.py
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										22
									
								
								cas_server/tests/urls.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								cas_server/tests/urls.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,22 @@
 | 
				
			|||||||
 | 
					"""cas URL Configuration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The `urlpatterns` list routes URLs to views. For more information please see:
 | 
				
			||||||
 | 
					    https://docs.djangoproject.com/en/1.9/topics/http/urls/
 | 
				
			||||||
 | 
					Examples:
 | 
				
			||||||
 | 
					Function views
 | 
				
			||||||
 | 
					    1. Add an import:  from my_app import views
 | 
				
			||||||
 | 
					    2. Add a URL to urlpatterns:  url(r'^$', views.home, name='home')
 | 
				
			||||||
 | 
					Class-based views
 | 
				
			||||||
 | 
					    1. Add an import:  from other_app.views import Home
 | 
				
			||||||
 | 
					    2. Add a URL to urlpatterns:  url(r'^$', Home.as_view(), name='home')
 | 
				
			||||||
 | 
					Including another URLconf
 | 
				
			||||||
 | 
					    1. Import the include() function: from django.conf.urls import url, include, include
 | 
				
			||||||
 | 
					    2. Add a URL to urlpatterns:  url(r'^blog/', include('blog.urls'))
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					from django.conf.urls import url, include
 | 
				
			||||||
 | 
					from django.contrib import admin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					urlpatterns = [
 | 
				
			||||||
 | 
					    url(r'^admin/', admin.site.urls),
 | 
				
			||||||
 | 
					    url(r'^', include('cas_server.urls', namespace='cas_server')),
 | 
				
			||||||
 | 
					]
 | 
				
			||||||
							
								
								
									
										180
									
								
								cas_server/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										180
									
								
								cas_server/tests/utils.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,180 @@
 | 
				
			|||||||
 | 
					# ⁻*- coding: utf-8 -*-
 | 
				
			||||||
 | 
					# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
				
			||||||
 | 
					# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
				
			||||||
 | 
					# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
				
			||||||
 | 
					# more details.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# You should have received a copy of the GNU General Public License version 3
 | 
				
			||||||
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
 | 
					#
 | 
				
			||||||
 | 
					# (c) 2016 Valentin Samir
 | 
				
			||||||
 | 
					"""Some utils functions for tests"""
 | 
				
			||||||
 | 
					from cas_server.default_settings import settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from django.test import Client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import cgi
 | 
				
			||||||
 | 
					from threading import Thread
 | 
				
			||||||
 | 
					from lxml import etree
 | 
				
			||||||
 | 
					from six.moves import BaseHTTPServer
 | 
				
			||||||
 | 
					from six.moves.urllib.parse import urlparse, parse_qsl
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from cas_server import models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def copy_form(form):
 | 
				
			||||||
 | 
					    """Copy form value into a dict"""
 | 
				
			||||||
 | 
					    params = {}
 | 
				
			||||||
 | 
					    for field in form:
 | 
				
			||||||
 | 
					        if field.value():
 | 
				
			||||||
 | 
					            params[field.name] = field.value()
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            params[field.name] = ""
 | 
				
			||||||
 | 
					    return params
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_login_page_params(client=None):
 | 
				
			||||||
 | 
					    """Return a client and the POST params for the client to login"""
 | 
				
			||||||
 | 
					    if client is None:
 | 
				
			||||||
 | 
					        client = Client()
 | 
				
			||||||
 | 
					    response = client.get('/login')
 | 
				
			||||||
 | 
					    params = copy_form(response.context["form"])
 | 
				
			||||||
 | 
					    return client, params
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_auth_client(**update):
 | 
				
			||||||
 | 
					    """return a authenticated client"""
 | 
				
			||||||
 | 
					    client, params = get_login_page_params()
 | 
				
			||||||
 | 
					    params["username"] = settings.CAS_TEST_USER
 | 
				
			||||||
 | 
					    params["password"] = settings.CAS_TEST_PASSWORD
 | 
				
			||||||
 | 
					    params.update(update)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    client.post('/login', params)
 | 
				
			||||||
 | 
					    assert client.session.get("authenticated")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return client
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_user_ticket_request(service):
 | 
				
			||||||
 | 
					    """Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
 | 
				
			||||||
 | 
					    client = get_auth_client()
 | 
				
			||||||
 | 
					    response = client.get("/login", {"service": service})
 | 
				
			||||||
 | 
					    ticket_value = response['Location'].split('ticket=')[-1]
 | 
				
			||||||
 | 
					    user = models.User.objects.get(
 | 
				
			||||||
 | 
					        username=settings.CAS_TEST_USER,
 | 
				
			||||||
 | 
					        session_key=client.session.session_key
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    ticket = models.ServiceTicket.objects.get(value=ticket_value)
 | 
				
			||||||
 | 
					    return (user, ticket, client)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_validated_ticket(service):
 | 
				
			||||||
 | 
					    """Return a tick that has being already validated. Used to test SLO"""
 | 
				
			||||||
 | 
					    (ticket, auth_client) = get_user_ticket_request(service)[1:3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    client = Client()
 | 
				
			||||||
 | 
					    response = client.get('/validate', {'ticket': ticket.value, 'service': service})
 | 
				
			||||||
 | 
					    assert (response.status_code == 200)
 | 
				
			||||||
 | 
					    assert (response.content == b'yes\ntest\n')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    ticket = models.ServiceTicket.objects.get(value=ticket.value)
 | 
				
			||||||
 | 
					    return (auth_client, ticket)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_pgt():
 | 
				
			||||||
 | 
					    """return a dict contening a service, user and PGT ticket for this service"""
 | 
				
			||||||
 | 
					    (httpd, host, port) = HttpParamsHandler.run()[0:3]
 | 
				
			||||||
 | 
					    service = "http://%s:%s" % (host, port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    (user, ticket) = get_user_ticket_request(service)[:2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    client = Client()
 | 
				
			||||||
 | 
					    client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
 | 
				
			||||||
 | 
					    params = httpd.PARAMS
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    params["service"] = service
 | 
				
			||||||
 | 
					    params["user"] = user
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return params
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_proxy_ticket(service):
 | 
				
			||||||
 | 
					    """Return a ProxyTicket waiting for validation"""
 | 
				
			||||||
 | 
					    params = get_pgt()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # get a proxy ticket
 | 
				
			||||||
 | 
					    client = Client()
 | 
				
			||||||
 | 
					    response = client.get('/proxy', {'pgt': params['pgtId'], 'targetService': service})
 | 
				
			||||||
 | 
					    root = etree.fromstring(response.content)
 | 
				
			||||||
 | 
					    proxy_ticket = root.xpath(
 | 
				
			||||||
 | 
					        "//cas:proxyTicket",
 | 
				
			||||||
 | 
					        namespaces={'cas': "http://www.yale.edu/tp/cas"}
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    proxy_ticket = proxy_ticket[0].text
 | 
				
			||||||
 | 
					    ticket = models.ProxyTicket.objects.get(value=proxy_ticket)
 | 
				
			||||||
 | 
					    return ticket
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HttpParamsHandler(BaseHTTPServer.BaseHTTPRequestHandler):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					        A simple http server that return 200 on GET or POST
 | 
				
			||||||
 | 
					        and store GET or POST parameters. Used in unit tests
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def do_GET(self):
 | 
				
			||||||
 | 
					        """Called on a GET request on the BaseHTTPServer"""
 | 
				
			||||||
 | 
					        self.send_response(200)
 | 
				
			||||||
 | 
					        self.send_header(b"Content-type", "text/plain")
 | 
				
			||||||
 | 
					        self.end_headers()
 | 
				
			||||||
 | 
					        self.wfile.write(b"ok")
 | 
				
			||||||
 | 
					        url = urlparse(self.path)
 | 
				
			||||||
 | 
					        params = dict(parse_qsl(url.query))
 | 
				
			||||||
 | 
					        self.server.PARAMS = params
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def do_POST(self):
 | 
				
			||||||
 | 
					        """Called on a POST request on the BaseHTTPServer"""
 | 
				
			||||||
 | 
					        ctype, pdict = cgi.parse_header(self.headers.get('content-type'))
 | 
				
			||||||
 | 
					        if ctype == 'multipart/form-data':
 | 
				
			||||||
 | 
					            postvars = cgi.parse_multipart(self.rfile, pdict)
 | 
				
			||||||
 | 
					        elif ctype == 'application/x-www-form-urlencoded':
 | 
				
			||||||
 | 
					            length = int(self.headers.get('content-length'))
 | 
				
			||||||
 | 
					            postvars = cgi.parse_qs(self.rfile.read(length), keep_blank_values=1)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            postvars = {}
 | 
				
			||||||
 | 
					        self.server.PARAMS = postvars
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def log_message(self, *args):
 | 
				
			||||||
 | 
					        """silent any log message"""
 | 
				
			||||||
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @classmethod
 | 
				
			||||||
 | 
					    def run(cls):
 | 
				
			||||||
 | 
					        """Run a BaseHTTPServer using this class as handler"""
 | 
				
			||||||
 | 
					        server_class = BaseHTTPServer.HTTPServer
 | 
				
			||||||
 | 
					        httpd = server_class(("127.0.0.1", 0), cls)
 | 
				
			||||||
 | 
					        (host, port) = httpd.socket.getsockname()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def lauch():
 | 
				
			||||||
 | 
					            """routine to lauch in a background thread"""
 | 
				
			||||||
 | 
					            httpd.handle_request()
 | 
				
			||||||
 | 
					            httpd.server_close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        httpd_thread = Thread(target=lauch)
 | 
				
			||||||
 | 
					        httpd_thread.daemon = True
 | 
				
			||||||
 | 
					        httpd_thread.start()
 | 
				
			||||||
 | 
					        return (httpd, host, port)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Http404Handler(HttpParamsHandler):
 | 
				
			||||||
 | 
					    """A simple http server that always return 404 not found. Used in unit tests"""
 | 
				
			||||||
 | 
					    def do_GET(self):
 | 
				
			||||||
 | 
					        """Called on a GET request on the BaseHTTPServer"""
 | 
				
			||||||
 | 
					        self.send_response(404)
 | 
				
			||||||
 | 
					        self.send_header(b"Content-type", "text/plain")
 | 
				
			||||||
 | 
					        self.end_headers()
 | 
				
			||||||
 | 
					        self.wfile.write(b"error 404 not found")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def do_POST(self):
 | 
				
			||||||
 | 
					        """Called on a POST request on the BaseHTTPServer"""
 | 
				
			||||||
 | 
					        return self.do_GET()
 | 
				
			||||||
@@ -8,7 +8,7 @@
 | 
				
			|||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# (c) 2015 Valentin Samir
 | 
					# (c) 2015-2016 Valentin Samir
 | 
				
			||||||
"""urls for the app"""
 | 
					"""urls for the app"""
 | 
				
			||||||
from django.conf.urls import patterns, url
 | 
					from django.conf.urls import patterns, url
 | 
				
			||||||
from django.views.generic import RedirectView
 | 
					from django.views.generic import RedirectView
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
# ⁻*- coding: utf-8 -*-
 | 
					# -*- coding: utf-8 -*-
 | 
				
			||||||
# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
					# This program is distributed in the hope that it will be useful, but WITHOUT
 | 
				
			||||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
					# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 | 
				
			||||||
# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
					# FOR A PARTICULAR PURPOSE. See the GNU General Public License version 3 for
 | 
				
			||||||
@@ -8,7 +8,7 @@
 | 
				
			|||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# (c) 2015 Valentin Samir
 | 
					# (c) 2015-2016 Valentin Samir
 | 
				
			||||||
"""Some util function for the app"""
 | 
					"""Some util function for the app"""
 | 
				
			||||||
from .default_settings import settings
 | 
					from .default_settings import settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -23,18 +23,19 @@ import hashlib
 | 
				
			|||||||
import crypt
 | 
					import crypt
 | 
				
			||||||
import base64
 | 
					import base64
 | 
				
			||||||
import six
 | 
					import six
 | 
				
			||||||
from threading import Thread
 | 
					
 | 
				
			||||||
from importlib import import_module
 | 
					from importlib import import_module
 | 
				
			||||||
from six.moves import BaseHTTPServer
 | 
					 | 
				
			||||||
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
 | 
					from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def context(params):
 | 
					def context(params):
 | 
				
			||||||
 | 
					    """Function that add somes variable to the context before template rendering"""
 | 
				
			||||||
    params["settings"] = settings
 | 
					    params["settings"] = settings
 | 
				
			||||||
    return params
 | 
					    return params
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def json_response(request, data):
 | 
					def json_response(request, data):
 | 
				
			||||||
 | 
					    """Wrapper dumping `data` to a json and sending it to the user with an HttpResponse"""
 | 
				
			||||||
    data["messages"] = []
 | 
					    data["messages"] = []
 | 
				
			||||||
    for msg in messages.get_messages(request):
 | 
					    for msg in messages.get_messages(request):
 | 
				
			||||||
        data["messages"].append({'message': msg.message, 'level': msg.level_tag})
 | 
					        data["messages"].append({'message': msg.message, 'level': msg.level_tag})
 | 
				
			||||||
@@ -64,6 +65,7 @@ def redirect_params(url_name, params=None):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def reverse_params(url_name, params=None, **kwargs):
 | 
					def reverse_params(url_name, params=None, **kwargs):
 | 
				
			||||||
 | 
					    """compule the reverse url or `url_name` and add GET parameters from `params` to it"""
 | 
				
			||||||
    url = reverse(url_name, **kwargs)
 | 
					    url = reverse(url_name, **kwargs)
 | 
				
			||||||
    params = urlencode(params if params else {})
 | 
					    params = urlencode(params if params else {})
 | 
				
			||||||
    return url + "?%s" % params
 | 
					    return url + "?%s" % params
 | 
				
			||||||
@@ -83,10 +85,13 @@ def update_url(url, params):
 | 
				
			|||||||
    url_parts = list(urlparse(url))
 | 
					    url_parts = list(urlparse(url))
 | 
				
			||||||
    query = dict(parse_qsl(url_parts[4]))
 | 
					    query = dict(parse_qsl(url_parts[4]))
 | 
				
			||||||
    query.update(params)
 | 
					    query.update(params)
 | 
				
			||||||
    url_parts[4] = urlencode(query)
 | 
					    # make the params order deterministic
 | 
				
			||||||
    for i, url_part in enumerate(url_parts):
 | 
					    query = list(query.items())
 | 
				
			||||||
        if not isinstance(url_part, bytes):
 | 
					    query.sort()
 | 
				
			||||||
            url_parts[i] = url_part.encode('utf-8')
 | 
					    url_query = urlencode(query)
 | 
				
			||||||
 | 
					    if not isinstance(url_query, bytes):  # pragma: no cover in python3 urlencode return an unicode
 | 
				
			||||||
 | 
					        url_query = url_query.encode("utf-8")
 | 
				
			||||||
 | 
					    url_parts[4] = url_query
 | 
				
			||||||
    return urlunparse(url_parts).decode('utf-8')
 | 
					    return urlunparse(url_parts).decode('utf-8')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -147,35 +152,25 @@ def gen_saml_id():
 | 
				
			|||||||
    return _gen_ticket('_')
 | 
					    return _gen_ticket('_')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
 | 
					def crypt_salt_is_valid(salt):
 | 
				
			||||||
    PARAMS = {}
 | 
					    """Return True is salt is valid has a crypt salt, False otherwise"""
 | 
				
			||||||
 | 
					    if len(salt) < 2:
 | 
				
			||||||
    def do_GET(self):
 | 
					        return False
 | 
				
			||||||
        self.send_response(200)
 | 
					    else:
 | 
				
			||||||
        self.send_header(b"Content-type", "text/plain")
 | 
					        if salt[0] == '$':
 | 
				
			||||||
        self.end_headers()
 | 
					            if salt[1] == '$':
 | 
				
			||||||
        self.wfile.write(b"ok")
 | 
					                return False
 | 
				
			||||||
        url = urlparse(self.path)
 | 
					            else:
 | 
				
			||||||
        params = dict(parse_qsl(url.query))
 | 
					                if '$' not in salt[1:]:
 | 
				
			||||||
        PGTUrlHandler.PARAMS.update(params)
 | 
					                    return False
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
    def log_message(self, *args):
 | 
					                    hashed = crypt.crypt("", salt)
 | 
				
			||||||
        return
 | 
					                    if not hashed or '$' not in hashed[1:]:
 | 
				
			||||||
 | 
					                        return False
 | 
				
			||||||
    @staticmethod
 | 
					                    else:
 | 
				
			||||||
    def run():
 | 
					                        return True
 | 
				
			||||||
        server_class = BaseHTTPServer.HTTPServer
 | 
					        else:
 | 
				
			||||||
        httpd = server_class(("127.0.0.1", 0), PGTUrlHandler)
 | 
					            return True
 | 
				
			||||||
        (host, port) = httpd.socket.getsockname()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        def lauch():
 | 
					 | 
				
			||||||
            httpd.handle_request()
 | 
					 | 
				
			||||||
            httpd.server_close()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        httpd_thread = Thread(target=lauch)
 | 
					 | 
				
			||||||
        httpd_thread.daemon = True
 | 
					 | 
				
			||||||
        httpd_thread.start()
 | 
					 | 
				
			||||||
        return (httpd_thread, host, port)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LdapHashUserPassword(object):
 | 
					class LdapHashUserPassword(object):
 | 
				
			||||||
@@ -268,7 +263,7 @@ class LdapHashUserPassword(object):
 | 
				
			|||||||
        if salt is None or salt == b"":
 | 
					        if salt is None or salt == b"":
 | 
				
			||||||
            salt = b""
 | 
					            salt = b""
 | 
				
			||||||
            cls._test_scheme_nosalt(scheme)
 | 
					            cls._test_scheme_nosalt(scheme)
 | 
				
			||||||
        elif salt is not None:
 | 
					        else:
 | 
				
			||||||
            cls._test_scheme_salt(scheme)
 | 
					            cls._test_scheme_salt(scheme)
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            return scheme + base64.b64encode(
 | 
					            return scheme + base64.b64encode(
 | 
				
			||||||
@@ -278,9 +273,9 @@ class LdapHashUserPassword(object):
 | 
				
			|||||||
            if six.PY3:
 | 
					            if six.PY3:
 | 
				
			||||||
                password = password.decode(charset)
 | 
					                password = password.decode(charset)
 | 
				
			||||||
                salt = salt.decode(charset)
 | 
					                salt = salt.decode(charset)
 | 
				
			||||||
            hashed_password = crypt.crypt(password, salt)
 | 
					            if not crypt_salt_is_valid(salt):
 | 
				
			||||||
            if hashed_password is None:
 | 
					 | 
				
			||||||
                raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
 | 
					                raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
 | 
				
			||||||
 | 
					            hashed_password = crypt.crypt(password, salt)
 | 
				
			||||||
            if six.PY3:
 | 
					            if six.PY3:
 | 
				
			||||||
                hashed_password = hashed_password.encode(charset)
 | 
					                hashed_password = hashed_password.encode(charset)
 | 
				
			||||||
            return scheme + hashed_password
 | 
					            return scheme + hashed_password
 | 
				
			||||||
@@ -302,7 +297,7 @@ class LdapHashUserPassword(object):
 | 
				
			|||||||
        if scheme in cls.schemes_nosalt:
 | 
					        if scheme in cls.schemes_nosalt:
 | 
				
			||||||
            return b""
 | 
					            return b""
 | 
				
			||||||
        elif scheme == b'{CRYPT}':
 | 
					        elif scheme == b'{CRYPT}':
 | 
				
			||||||
            return b'$'.join(hashed_passord.split(b'$', 3)[:-1])
 | 
					            return b'$'.join(hashed_passord.split(b'$', 3)[:-1])[len(scheme):]
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
 | 
					            hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
 | 
				
			||||||
            if len(hashed_passord) < cls._schemes_to_len[scheme]:
 | 
					            if len(hashed_passord) < cls._schemes_to_len[scheme]:
 | 
				
			||||||
@@ -324,7 +319,7 @@ def check_password(method, password, hashed_password, charset):
 | 
				
			|||||||
    elif method == "crypt":
 | 
					    elif method == "crypt":
 | 
				
			||||||
        if hashed_password.startswith(b'$'):
 | 
					        if hashed_password.startswith(b'$'):
 | 
				
			||||||
            salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
 | 
					            salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
 | 
				
			||||||
        elif hashed_password.startswith(b'_'):
 | 
					        elif hashed_password.startswith(b'_'):  # pragma: no cover old BSD format not supported
 | 
				
			||||||
            salt = hashed_password[:9]
 | 
					            salt = hashed_password[:9]
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            salt = hashed_password[:2]
 | 
					            salt = hashed_password[:2]
 | 
				
			||||||
@@ -332,9 +327,9 @@ def check_password(method, password, hashed_password, charset):
 | 
				
			|||||||
            password = password.decode(charset)
 | 
					            password = password.decode(charset)
 | 
				
			||||||
            salt = salt.decode(charset)
 | 
					            salt = salt.decode(charset)
 | 
				
			||||||
            hashed_password = hashed_password.decode(charset)
 | 
					            hashed_password = hashed_password.decode(charset)
 | 
				
			||||||
        crypted_password = crypt.crypt(password, salt)
 | 
					        if not crypt_salt_is_valid(salt):
 | 
				
			||||||
        if crypted_password is None:
 | 
					 | 
				
			||||||
            raise ValueError("System crypt implementation do not support the salt %r" % salt)
 | 
					            raise ValueError("System crypt implementation do not support the salt %r" % salt)
 | 
				
			||||||
 | 
					        crypted_password = crypt.crypt(password, salt)
 | 
				
			||||||
        return crypted_password == hashed_password
 | 
					        return crypted_password == hashed_password
 | 
				
			||||||
    elif method == "ldap":
 | 
					    elif method == "ldap":
 | 
				
			||||||
        scheme = LdapHashUserPassword.get_scheme(hashed_password)
 | 
					        scheme = LdapHashUserPassword.get_scheme(hashed_password)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,7 +8,7 @@
 | 
				
			|||||||
# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
					# along with this program; if not, write to the Free Software Foundation, Inc., 51
 | 
				
			||||||
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
					# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
# (c) 2015 Valentin Samir
 | 
					# (c) 2015-2016 Valentin Samir
 | 
				
			||||||
"""views for the app"""
 | 
					"""views for the app"""
 | 
				
			||||||
from .default_settings import settings
 | 
					from .default_settings import settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -105,10 +105,11 @@ class LogoutView(View, LogoutMixin):
 | 
				
			|||||||
    service = None
 | 
					    service = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_get(self, request):
 | 
					    def init_get(self, request):
 | 
				
			||||||
 | 
					        """Initialize GET received parameters"""
 | 
				
			||||||
        self.request = request
 | 
					        self.request = request
 | 
				
			||||||
        self.service = request.GET.get('service')
 | 
					        self.service = request.GET.get('service')
 | 
				
			||||||
        self.url = request.GET.get('url')
 | 
					        self.url = request.GET.get('url')
 | 
				
			||||||
        self.ajax = 'HTTP_X_AJAX' in request.META
 | 
					        self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get(self, request, *args, **kwargs):
 | 
					    def get(self, request, *args, **kwargs):
 | 
				
			||||||
        """methode called on GET request on this view"""
 | 
					        """methode called on GET request on this view"""
 | 
				
			||||||
@@ -196,24 +197,30 @@ class LoginView(View, LogoutMixin):
 | 
				
			|||||||
    USER_NOT_AUTHENTICATED = 6
 | 
					    USER_NOT_AUTHENTICATED = 6
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_post(self, request):
 | 
					    def init_post(self, request):
 | 
				
			||||||
 | 
					        """Initialize POST received parameters"""
 | 
				
			||||||
        self.request = request
 | 
					        self.request = request
 | 
				
			||||||
        self.service = request.POST.get('service')
 | 
					        self.service = request.POST.get('service')
 | 
				
			||||||
        self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
 | 
					        self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
 | 
				
			||||||
        self.gateway = request.POST.get('gateway')
 | 
					        self.gateway = request.POST.get('gateway')
 | 
				
			||||||
        self.method = request.POST.get('method')
 | 
					        self.method = request.POST.get('method')
 | 
				
			||||||
        self.ajax = 'HTTP_X_AJAX' in request.META
 | 
					        self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
 | 
				
			||||||
        if request.POST.get('warned') and request.POST['warned'] != "False":
 | 
					        if request.POST.get('warned') and request.POST['warned'] != "False":
 | 
				
			||||||
            self.warned = True
 | 
					            self.warned = True
 | 
				
			||||||
 | 
					        self.warn = request.POST.get('warn')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def check_lt(self):
 | 
					    def gen_lt(self):
 | 
				
			||||||
        # save LT for later check
 | 
					        """Generate a new LoginTicket and add it to the list of valid LT for the user"""
 | 
				
			||||||
        lt_valid = self.request.session.get('lt', [])
 | 
					 | 
				
			||||||
        lt_send = self.request.POST.get('lt')
 | 
					 | 
				
			||||||
        # generate a new LT (by posting the LT has been consumed)
 | 
					 | 
				
			||||||
        self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
 | 
					        self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
 | 
				
			||||||
        if len(self.request.session['lt']) > 100:
 | 
					        if len(self.request.session['lt']) > 100:
 | 
				
			||||||
            self.request.session['lt'] = self.request.session['lt'][-100:]
 | 
					            self.request.session['lt'] = self.request.session['lt'][-100:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def check_lt(self):
 | 
				
			||||||
 | 
					        """Check is the POSTed LoginTicket is valid, if yes invalide it"""
 | 
				
			||||||
 | 
					        # save LT for later check
 | 
				
			||||||
 | 
					        lt_valid = self.request.session.get('lt', [])
 | 
				
			||||||
 | 
					        lt_send = self.request.POST.get('lt')
 | 
				
			||||||
 | 
					        # generate a new LT (by posting the LT has been consumed)
 | 
				
			||||||
 | 
					        self.gen_lt()
 | 
				
			||||||
        # check if send LT is valid
 | 
					        # check if send LT is valid
 | 
				
			||||||
        if lt_valid is None or lt_send not in lt_valid:
 | 
					        if lt_valid is None or lt_send not in lt_valid:
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
@@ -238,7 +245,7 @@ class LoginView(View, LogoutMixin):
 | 
				
			|||||||
                    username=self.request.session['username'],
 | 
					                    username=self.request.session['username'],
 | 
				
			||||||
                    session_key=self.request.session.session_key
 | 
					                    session_key=self.request.session.session_key
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                self.user.save()
 | 
					                self.user.save()  # pragma: no cover (should not happend)
 | 
				
			||||||
            except models.User.DoesNotExist:
 | 
					            except models.User.DoesNotExist:
 | 
				
			||||||
                self.user = models.User.objects.create(
 | 
					                self.user = models.User.objects.create(
 | 
				
			||||||
                    username=self.request.session['username'],
 | 
					                    username=self.request.session['username'],
 | 
				
			||||||
@@ -250,10 +257,15 @@ class LoginView(View, LogoutMixin):
 | 
				
			|||||||
        elif ret == self.USER_ALREADY_LOGGED:
 | 
					        elif ret == self.USER_ALREADY_LOGGED:
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise EnvironmentError("invalid output for LoginView.process_post")
 | 
					            raise EnvironmentError("invalid output for LoginView.process_post")  # pragma: no cover
 | 
				
			||||||
        return self.common()
 | 
					        return self.common()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def process_post(self):
 | 
					    def process_post(self):
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					            Analyse the POST request:
 | 
				
			||||||
 | 
					                * check that the LoginTicket is valid
 | 
				
			||||||
 | 
					                * check that the user sumited credentials are valid
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        if not self.check_lt():
 | 
					        if not self.check_lt():
 | 
				
			||||||
            values = self.request.POST.copy()
 | 
					            values = self.request.POST.copy()
 | 
				
			||||||
            # if not set a new LT and fail
 | 
					            # if not set a new LT and fail
 | 
				
			||||||
@@ -280,12 +292,14 @@ class LoginView(View, LogoutMixin):
 | 
				
			|||||||
            return self.USER_ALREADY_LOGGED
 | 
					            return self.USER_ALREADY_LOGGED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_get(self, request):
 | 
					    def init_get(self, request):
 | 
				
			||||||
 | 
					        """Initialize GET received parameters"""
 | 
				
			||||||
        self.request = request
 | 
					        self.request = request
 | 
				
			||||||
        self.service = request.GET.get('service')
 | 
					        self.service = request.GET.get('service')
 | 
				
			||||||
        self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
 | 
					        self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
 | 
				
			||||||
        self.gateway = request.GET.get('gateway')
 | 
					        self.gateway = request.GET.get('gateway')
 | 
				
			||||||
        self.method = request.GET.get('method')
 | 
					        self.method = request.GET.get('method')
 | 
				
			||||||
        self.ajax = 'HTTP_X_AJAX' in request.META
 | 
					        self.ajax = settings.CAS_ENABLE_AJAX_AUTH and 'HTTP_X_AJAX' in request.META
 | 
				
			||||||
 | 
					        self.warn = request.GET.get('warn')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get(self, request, *args, **kwargs):
 | 
					    def get(self, request, *args, **kwargs):
 | 
				
			||||||
        """methode called on GET request on this view"""
 | 
					        """methode called on GET request on this view"""
 | 
				
			||||||
@@ -294,22 +308,24 @@ class LoginView(View, LogoutMixin):
 | 
				
			|||||||
        return self.common()
 | 
					        return self.common()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def process_get(self):
 | 
					    def process_get(self):
 | 
				
			||||||
        # generate a new LT if none is present
 | 
					        """Analyse the GET request"""
 | 
				
			||||||
        self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
 | 
					        # generate a new LT
 | 
				
			||||||
 | 
					        self.gen_lt()
 | 
				
			||||||
        if not self.request.session.get("authenticated") or self.renew:
 | 
					        if not self.request.session.get("authenticated") or self.renew:
 | 
				
			||||||
            self.init_form()
 | 
					            self.init_form()
 | 
				
			||||||
            return self.USER_NOT_AUTHENTICATED
 | 
					            return self.USER_NOT_AUTHENTICATED
 | 
				
			||||||
        return self.USER_AUTHENTICATED
 | 
					        return self.USER_AUTHENTICATED
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def init_form(self, values=None):
 | 
					    def init_form(self, values=None):
 | 
				
			||||||
 | 
					        """Initialization of the good form depending of POST and GET parameters"""
 | 
				
			||||||
        self.form = forms.UserCredential(
 | 
					        self.form = forms.UserCredential(
 | 
				
			||||||
            values,
 | 
					            values,
 | 
				
			||||||
            initial={
 | 
					            initial={
 | 
				
			||||||
                'service': self.service,
 | 
					                'service': self.service,
 | 
				
			||||||
                'method': self.method,
 | 
					                'method': self.method,
 | 
				
			||||||
                'warn': self.request.session.get("warn"),
 | 
					                'warn': self.request.session.get("warn"),
 | 
				
			||||||
                'lt': self.request.session['lt'][-1]
 | 
					                'lt': self.request.session['lt'][-1],
 | 
				
			||||||
 | 
					                'renew': self.renew
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -351,7 +367,7 @@ class LoginView(View, LogoutMixin):
 | 
				
			|||||||
                redirect_url = self.user.get_service_url(
 | 
					                redirect_url = self.user.get_service_url(
 | 
				
			||||||
                    self.service,
 | 
					                    self.service,
 | 
				
			||||||
                    service_pattern,
 | 
					                    service_pattern,
 | 
				
			||||||
                    renew=self.renew
 | 
					                    renew=self.renewed
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                if not self.ajax:
 | 
					                if not self.ajax:
 | 
				
			||||||
                    return HttpResponseRedirect(redirect_url)
 | 
					                    return HttpResponseRedirect(redirect_url)
 | 
				
			||||||
@@ -580,12 +596,9 @@ class Validate(View):
 | 
				
			|||||||
                        ticket.service_pattern.user_field
 | 
					                        ticket.service_pattern.user_field
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                    if isinstance(username, list):
 | 
					                    if isinstance(username, list):
 | 
				
			||||||
                        try:
 | 
					                        # the list is not empty because we wont generate a ticket with a user_field
 | 
				
			||||||
                            username = username[0]
 | 
					                        # that evaluate to False
 | 
				
			||||||
                        except IndexError:
 | 
					                        username = username[0]
 | 
				
			||||||
                            username = None
 | 
					 | 
				
			||||||
                    if not username:
 | 
					 | 
				
			||||||
                        username = ""
 | 
					 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    username = ticket.user.username
 | 
					                    username = ticket.user.username
 | 
				
			||||||
                return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
 | 
					                return HttpResponse("yes\n%s\n" % username, content_type="text/plain")
 | 
				
			||||||
@@ -661,6 +674,10 @@ class ValidateService(View, AttributesMixin):
 | 
				
			|||||||
                    params['username'] = self.ticket.user.attributs.get(
 | 
					                    params['username'] = self.ticket.user.attributs.get(
 | 
				
			||||||
                        self.ticket.service_pattern.user_field
 | 
					                        self.ticket.service_pattern.user_field
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					                    if isinstance(params['username'], list):
 | 
				
			||||||
 | 
					                        # the list is not empty because we wont generate a ticket with a user_field
 | 
				
			||||||
 | 
					                        # that evaluate to False
 | 
				
			||||||
 | 
					                        params['username'] = params['username'][0]
 | 
				
			||||||
                if self.pgt_url and (
 | 
					                if self.pgt_url and (
 | 
				
			||||||
                    self.pgt_url.startswith("https://") or
 | 
					                    self.pgt_url.startswith("https://") or
 | 
				
			||||||
                    re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
 | 
					                    re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
 | 
				
			||||||
@@ -762,9 +779,12 @@ class ValidateService(View, AttributesMixin):
 | 
				
			|||||||
                        params,
 | 
					                        params,
 | 
				
			||||||
                        content_type="text/xml; charset=utf-8"
 | 
					                        content_type="text/xml; charset=utf-8"
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                except requests.exceptions.SSLError as error:
 | 
					                except requests.exceptions.RequestException as error:
 | 
				
			||||||
                    error = utils.unpack_nested_exception(error)
 | 
					                    error = utils.unpack_nested_exception(error)
 | 
				
			||||||
                    raise ValidateError('INVALID_PROXY_CALLBACK', str(error))
 | 
					                    raise ValidateError(
 | 
				
			||||||
 | 
					                        'INVALID_PROXY_CALLBACK',
 | 
				
			||||||
 | 
					                        "%s: %s" % (type(error), str(error))
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                raise ValidateError(
 | 
					                raise ValidateError(
 | 
				
			||||||
                    'INVALID_PROXY_CALLBACK',
 | 
					                    'INVALID_PROXY_CALLBACK',
 | 
				
			||||||
@@ -844,7 +864,7 @@ class Proxy(View):
 | 
				
			|||||||
        except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
 | 
					        except (models.BadUsername, models.BadFilter, models.UserFieldNotDefined):
 | 
				
			||||||
            raise ValidateError(
 | 
					            raise ValidateError(
 | 
				
			||||||
                'UNAUTHORIZED_USER',
 | 
					                'UNAUTHORIZED_USER',
 | 
				
			||||||
                '%s not allowed on %s' % (ticket.user, self.target_service)
 | 
					                'User %s not allowed on %s' % (ticket.user.username, self.target_service)
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -903,11 +923,15 @@ class SamlValidate(View, AttributesMixin):
 | 
				
			|||||||
                'username': self.ticket.user.username,
 | 
					                'username': self.ticket.user.username,
 | 
				
			||||||
                'attributes': attributes
 | 
					                'attributes': attributes
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            if self.ticket.service_pattern.user_field and \
 | 
					            if (self.ticket.service_pattern.user_field and
 | 
				
			||||||
                    self.ticket.user.attributs.get(self.ticket.service_pattern.user_field):
 | 
					                    self.ticket.user.attributs.get(self.ticket.service_pattern.user_field)):
 | 
				
			||||||
                params['username'] = self.ticket.user.attributs.get(
 | 
					                params['username'] = self.ticket.user.attributs.get(
 | 
				
			||||||
                    self.ticket.service_pattern.user_field
 | 
					                    self.ticket.service_pattern.user_field
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
 | 
					                if isinstance(params['username'], list):
 | 
				
			||||||
 | 
					                    # the list is not empty because we wont generate a ticket with a user_field
 | 
				
			||||||
 | 
					                    # that evaluate to False
 | 
				
			||||||
 | 
					                    params['username'] = params['username'][0]
 | 
				
			||||||
            logger.info(
 | 
					            logger.info(
 | 
				
			||||||
                "SamlValidate: ticket %s validated for user %s on service %s." % (
 | 
					                "SamlValidate: ticket %s validated for user %s on service %s." % (
 | 
				
			||||||
                    self.ticket.value,
 | 
					                    self.ticket.value,
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										5
									
								
								pytest.ini
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								pytest.ini
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					[pytest]
 | 
				
			||||||
 | 
					testpaths = cas_server/tests/
 | 
				
			||||||
 | 
					DJANGO_SETTINGS_MODULE = cas_server.tests.settings
 | 
				
			||||||
 | 
					norecursedirs = .* build dist docs
 | 
				
			||||||
 | 
					python_paths = .
 | 
				
			||||||
@@ -1,10 +1,12 @@
 | 
				
			|||||||
tox==1.8.1
 | 
					setuptools>=5.5
 | 
				
			||||||
pytest==2.6.4
 | 
					tox>=1.8.1
 | 
				
			||||||
pytest-django==2.7.0
 | 
					pytest>=2.6.4
 | 
				
			||||||
pytest-pythonpath==0.3
 | 
					pytest-django>=2.8.0
 | 
				
			||||||
 | 
					pytest-pythonpath>=0.3
 | 
				
			||||||
 | 
					pytest-cov>=2.2.1
 | 
				
			||||||
requests>=2.4
 | 
					requests>=2.4
 | 
				
			||||||
django-picklefield>=0.3.1
 | 
					 | 
				
			||||||
requests_futures>=0.9.5
 | 
					requests_futures>=0.9.5
 | 
				
			||||||
 | 
					django-picklefield>=0.3.1
 | 
				
			||||||
django-bootstrap3>=5.4
 | 
					django-bootstrap3>=5.4
 | 
				
			||||||
lxml>=3.4
 | 
					lxml>=3.4
 | 
				
			||||||
six>=1
 | 
					six>=1
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										22
									
								
								run_tests
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								run_tests
									
									
									
									
									
								
							@@ -1,22 +0,0 @@
 | 
				
			|||||||
#!/usr/bin/env python
 | 
					 | 
				
			||||||
import os, sys
 | 
					 | 
				
			||||||
import django
 | 
					 | 
				
			||||||
from django.conf import settings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import settings_tests
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
settings.configure(**settings_tests.__dict__)
 | 
					 | 
				
			||||||
django.setup()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
try:
 | 
					 | 
				
			||||||
    # Django <= 1.8
 | 
					 | 
				
			||||||
    from django.test.simple import DjangoTestSuiteRunner
 | 
					 | 
				
			||||||
    test_runner = DjangoTestSuiteRunner(verbosity=1)
 | 
					 | 
				
			||||||
except ImportError:
 | 
					 | 
				
			||||||
    # Django >= 1.8
 | 
					 | 
				
			||||||
    from django.test.runner import DiscoverRunner
 | 
					 | 
				
			||||||
    test_runner = DiscoverRunner(verbosity=1)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
failures = test_runner.run_tests(['cas_server'])
 | 
					 | 
				
			||||||
if failures:
 | 
					 | 
				
			||||||
    sys.exit(failures)
 | 
					 | 
				
			||||||
							
								
								
									
										3
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								setup.py
									
									
									
									
									
								
							@@ -34,7 +34,8 @@ setup(
 | 
				
			|||||||
    version='0.4.4',
 | 
					    version='0.4.4',
 | 
				
			||||||
    packages=[
 | 
					    packages=[
 | 
				
			||||||
        'cas_server', 'cas_server.migrations',
 | 
					        'cas_server', 'cas_server.migrations',
 | 
				
			||||||
        'cas_server.management', 'cas_server.management.commands'
 | 
					        'cas_server.management', 'cas_server.management.commands',
 | 
				
			||||||
 | 
					        'cas_server.tests'
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
    include_package_data=True,
 | 
					    include_package_data=True,
 | 
				
			||||||
    license='GPLv3',
 | 
					    license='GPLv3',
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										14
									
								
								tox.ini
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								tox.ini
									
									
									
									
									
								
							@@ -1,12 +1,12 @@
 | 
				
			|||||||
[tox]
 | 
					[tox]
 | 
				
			||||||
envlist=
 | 
					envlist=
 | 
				
			||||||
 | 
					    flake8,
 | 
				
			||||||
    py27-django17,
 | 
					    py27-django17,
 | 
				
			||||||
    py27-django18,
 | 
					    py27-django18,
 | 
				
			||||||
    py27-django19,
 | 
					    py27-django19,
 | 
				
			||||||
    py34-django17,
 | 
					    py34-django17,
 | 
				
			||||||
    py34-django18,
 | 
					    py34-django18,
 | 
				
			||||||
    py34-django19,
 | 
					    py34-django19,
 | 
				
			||||||
    flake8,
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
[flake8]
 | 
					[flake8]
 | 
				
			||||||
max-line-length=100
 | 
					max-line-length=100
 | 
				
			||||||
@@ -17,7 +17,7 @@ deps =
 | 
				
			|||||||
    -r{toxinidir}/requirements-dev.txt
 | 
					    -r{toxinidir}/requirements-dev.txt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[testenv]
 | 
					[testenv]
 | 
				
			||||||
commands=python run_tests {posargs:tests}
 | 
					commands=py.test {posargs:cas_server/tests/}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[testenv:py27-django17]
 | 
					[testenv:py27-django17]
 | 
				
			||||||
basepython=python2.7
 | 
					basepython=python2.7
 | 
				
			||||||
@@ -60,3 +60,13 @@ basepython=python
 | 
				
			|||||||
deps=flake8
 | 
					deps=flake8
 | 
				
			||||||
commands=flake8 {toxinidir}/cas_server
 | 
					commands=flake8 {toxinidir}/cas_server
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[testenv:coverage]
 | 
				
			||||||
 | 
					basepython=python
 | 
				
			||||||
 | 
					passenv=CODACY_PROJECT_TOKEN
 | 
				
			||||||
 | 
					deps=
 | 
				
			||||||
 | 
					    -r{toxinidir}/requirements-dev.txt
 | 
				
			||||||
 | 
					    codacy-coverage
 | 
				
			||||||
 | 
					commands=
 | 
				
			||||||
 | 
					    py.test --cov=cas_server --cov-report xml
 | 
				
			||||||
 | 
					    python-codacy-coverage -r {toxinidir}/coverage.xml
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user