Merge branch 'dev' into federate

This commit is contained in:
Valentin Samir 2016-06-28 00:34:31 +02:00
commit 32b5627c38
28 changed files with 1472 additions and 779 deletions

8
.coveragerc Normal file
View File

@ -0,0 +1,8 @@
[report]
exclude_lines =
pragma: no cover
def __repr__
def __unicode__
raise AssertionError
raise NotImplementedError
if six.PY3:

4
.gitignore vendored
View File

@ -1,5 +1,6 @@
*.pyc *.pyc
*.egg-info *.egg-info
*~
*.swp *.swp
build/ build/
@ -8,6 +9,9 @@ cas/
dist/ dist/
db.sqlite3 db.sqlite3
manage.py manage.py
coverage.xml
.tox .tox
test_venv test_venv
.coverage
htmlcov/

View File

@ -44,3 +44,14 @@ test_project: test_venv test_venv/cas/manage.py
run_test_server: test_project run_test_server: test_project
test_venv/bin/python test_venv/cas/manage.py runserver test_venv/bin/python test_venv/cas/manage.py runserver
coverage: test_venv
test_venv/bin/pip install coverage
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
test_venv/bin/coverage html
rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts
coverage_codacy: coverage
test_venv/bin/coverage xml
test_venv/bin/pip install codacy-coverage
test_venv/bin/python-codacy-coverage -r coverage.xml

View File

@ -10,8 +10,14 @@ CAS Server
.. image:: https://img.shields.io/pypi/l/django-cas-server.svg .. image:: https://img.shields.io/pypi/l/django-cas-server.svg
:target: https://www.gnu.org/licenses/gpl-3.0.html :target: https://www.gnu.org/licenses/gpl-3.0.html
.. image:: https://api.codacy.com/project/badge/Grade/255c21623d6946ef8802fa7995b61366
:target: https://www.codacy.com/app/valentin-samir/django-cas-server
.. image:: https://api.codacy.com/project/badge/Coverage/255c21623d6946ef8802fa7995b61366
:target: https://www.codacy.com/app/valentin-samir/django-cas-server
CAS Server is a Django application implementing the `CAS Protocol 3.0 Specification CAS Server is a Django application implementing the `CAS Protocol 3.0 Specification
<https://jasig.github.io/cas/development/protocol/CAS-Protocol-Specification.html>`_. <https://apereo.github.io/cas/4.2.x/protocol/CAS-Protocol-Specification.html>`_.
By defaut, the authentication process use django internal users but you can easily By defaut, the authentication process use django internal users but you can easily
use any sources (see auth classes in the auth.py file) use any sources (see auth classes in the auth.py file)
@ -37,6 +43,15 @@ Features
Quick start Quick start
----------- -----------
0. If you want to make a virtualenv for ``django-cas-server``, you will need the following
dependencies on a bare debian like system::
virtualenv build-essential python-dev libxml2-dev libxslt1-dev zlib1g-dev
If you want to use python3 instead of python2, replace ``python-dev`` with ``python3-dev``.
If you intend to run the tox tests you will also need ``python3.4-dev`` depending of the current
version of python3 on your system.
1. Add "cas_server" to your INSTALLED_APPS setting like this:: 1. Add "cas_server" to your INSTALLED_APPS setting like this::
@ -70,7 +85,7 @@ Quick start
4. You should add some management commands to a crontab: ``clearsessions``, 4. You should add some management commands to a crontab: ``clearsessions``,
``cas_clean_tickets`` and ``cas_clean_sessions``. ``cas_clean_tickets`` and ``cas_clean_sessions``.
* ``clearsessions``: please see `Clearing the session store <https://docs.djangoproject.com/en/1.9/topics/http/sessions/#clearing-the-session-store>`_. * ``clearsessions``: please see `Clearing the session store <https://docs.djangoproject.com/en/stable/topics/http/sessions/#clearing-the-session-store>`_.
* ``cas_clean_tickets``: old tickets and timed-out tickets do not get purge from * ``cas_clean_tickets``: old tickets and timed-out tickets do not get purge from
the database automatically. They are just marked as invalid. ``cas_clean_tickets`` the database automatically. They are just marked as invalid. ``cas_clean_tickets``
is a clean-up management command for this purpose. It send SingleLogOut request is a clean-up management command for this purpose. It send SingleLogOut request
@ -122,14 +137,14 @@ Template settings:
Authentication settings: Authentication settings:
* ``CAS_AUTH_CLASS``: A dotted path to a class implementing ``cas_server.auth.AuthUser``. * ``CAS_AUTH_CLASS``: A dotted path to a class or a class implementing
The default is ``"cas_server.auth.DjangoAuthUser"`` ``cas_server.auth.AuthUser``. The default is ``"cas_server.auth.DjangoAuthUser"``
* ``SESSION_COOKIE_AGE``: This is a django settings. Here, it control the delay in seconds after * ``SESSION_COOKIE_AGE``: This is a django settings. Here, it control the delay in seconds after
which inactive users are logged out. The default is ``1209600`` (2 weeks). You probably should which inactive users are logged out. The default is ``1209600`` (2 weeks). You probably should
reduce it to something like ``86400`` seconds (1 day). reduce it to something like ``86400`` seconds (1 day).
* ``CAS_PROXY_CA_CERTIFICATE_PATH``: Path to certificates authority file. Usually on linux * ``CAS_PROXY_CA_CERTIFICATE_PATH``: Path to certificate authorities file. Usually on linux
the local CAs are in ``/etc/ssl/certs/ca-certificates.crt``. The default is ``True`` which the local CAs are in ``/etc/ssl/certs/ca-certificates.crt``. The default is ``True`` which
tell requests to use its internal certificat authorities. Settings it to ``False`` should tell requests to use its internal certificat authorities. Settings it to ``False`` should
disable all x509 certificates validation and MUST not be done in production. disable all x509 certificates validation and MUST not be done in production.
@ -162,7 +177,7 @@ Tickets validity settings:
application. The default is ``60``. application. The default is ``60``.
* ``CAS_PGT_VALIDITY``: Number of seconds the proxy granting tickets are valid. * ``CAS_PGT_VALIDITY``: Number of seconds the proxy granting tickets are valid.
The default is ``3600`` (1 hour). The default is ``3600`` (1 hour).
* ``CAS_TICKET_TIMEOUT``: Number of seconds a ticket is kept is the database before sending * ``CAS_TICKET_TIMEOUT``: Number of seconds a ticket is kept in the database before sending
Single Log Out request and being cleared. The default is ``86400`` (24 hours). Single Log Out request and being cleared. The default is ``86400`` (24 hours).
Tickets miscellaneous settings: Tickets miscellaneous settings:
@ -184,12 +199,12 @@ Tickets miscellaneous settings:
* ``CAS_SERVICE_TICKET_PREFIX``: Prefix of service tickets. The default is ``"ST"``. * ``CAS_SERVICE_TICKET_PREFIX``: Prefix of service tickets. The default is ``"ST"``.
The CAS specification mandate that service tickets MUST begin with the characters ST The CAS specification mandate that service tickets MUST begin with the characters ST
so you should not change this. so you should not change this.
* ``CAS_PROXY_TICKET_PREFIX``: Prefix of proxy ticket. The default is ``"ST"``. * ``CAS_PROXY_TICKET_PREFIX``: Prefix of proxy ticket. The default is ``"PT"``.
* ``CAS_PROXY_GRANTING_TICKET_PREFIX``: Prefix of proxy granting ticket. The default is ``"PGT"``. * ``CAS_PROXY_GRANTING_TICKET_PREFIX``: Prefix of proxy granting ticket. The default is ``"PGT"``.
* ``CAS_PROXY_GRANTING_TICKET_IOU_PREFIX``: Prefix of proxy granting ticket IOU. The default is ``"PGTIOU"``. * ``CAS_PROXY_GRANTING_TICKET_IOU_PREFIX``: Prefix of proxy granting ticket IOU. The default is ``"PGTIOU"``.
Mysql backend settings. Only usefull is you use the mysql authentication backend: Mysql backend settings. Only usefull if you are using the mysql authentication backend:
* ``CAS_SQL_HOST``: Host for the SQL server. The default is ``"localhost"``. * ``CAS_SQL_HOST``: Host for the SQL server. The default is ``"localhost"``.
* ``CAS_SQL_USERNAME``: Username for connecting to the SQL server. * ``CAS_SQL_USERNAME``: Username for connecting to the SQL server.
@ -200,8 +215,29 @@ Mysql backend settings. Only usefull is you use the mysql authentication backend
The username must be in field ``username``, the password in ``password``, The username must be in field ``username``, the password in ``password``,
additional fields are used as the user attributes. additional fields are used as the user attributes.
The default is ``"SELECT user AS usersame, pass AS password, users.* FROM users WHERE user = %s"`` The default is ``"SELECT user AS usersame, pass AS password, users.* FROM users WHERE user = %s"``
* ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be * ``CAS_SQL_PASSWORD_CHECK``: The method used to check the user password. Must be one of the following:
``"crypt"`` or ``"plain``". The default is ``"crypt"``.
* ``"crypt"`` (see <https://en.wikipedia.org/wiki/Crypt_(C)>), the password in the database
should begin this $
* ``"ldap"`` (see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html)
the password in the database must begin with one of {MD5}, {SMD5}, {SHA}, {SSHA}, {SHA256},
{SSHA256}, {SHA384}, {SSHA384}, {SHA512}, {SSHA512}, {CRYPT}.
* ``"hex_HASH_NAME"`` with ``HASH_NAME`` in md5, sha1, sha224, sha256, sha384, sha512.
The hashed password in the database is compare to the hexadecimal digest of the clear
password hashed with the corresponding algorithm.
* ``"plain"``, the password in the database must be in clear.
The default is ``"crypt"``.
Test backend settings. Only usefull if you are using the test authentication backend:
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']}``.
Authentication backend Authentication backend
---------------------- ----------------------
@ -209,8 +245,8 @@ Authentication backend
``django-cas-server`` comes with some authentication backends: ``django-cas-server`` comes with some authentication backends:
* dummy backend ``cas_server.auth.DummyAuthUser``: all authentication attempt fails. * dummy backend ``cas_server.auth.DummyAuthUser``: all authentication attempt fails.
* test backend ``cas_server.auth.TestAuthUser``: username is ``test`` and password is ``test`` * test backend ``cas_server.auth.TestAuthUser``: username, password and returned attributes
the returned attributes for the user are: ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}`` for the user are defined by the ``CAS_TEST_*`` settings.
* django backend ``cas_server.auth.DjangoAuthUser``: Users are authenticated agains django users system. * django backend ``cas_server.auth.DjangoAuthUser``: Users are authenticated agains django users system.
This is the default backend. The returned attributes are the fields available on the user model. This is the default backend. The returned attributes are the fields available on the user model.
* mysql backend ``cas_server.auth.MysqlAuthUser``: see the 'Mysql backend settings' section. * mysql backend ``cas_server.auth.MysqlAuthUser``: see the 'Mysql backend settings' section.
@ -222,7 +258,7 @@ Logs
---- ----
``django-cas-server`` logs most of its actions. To enable login, you must set the ``LOGGING`` ``django-cas-server`` logs most of its actions. To enable login, you must set the ``LOGGING``
(https://docs.djangoproject.com/en/dev/topics/logging) variable is ``settings.py``. (https://docs.djangoproject.com/en/stable/topics/logging) variable in ``settings.py``.
Users successful actions (login, logout) are logged with the level ``INFO``, failures are logged Users successful actions (login, logout) are logged with the level ``INFO``, failures are logged
with the level ``WARNING`` and user attributes transmitted to a service are logged with the level ``DEBUG``. with the level ``WARNING`` and user attributes transmitted to a service are logged with the level ``DEBUG``.

View File

@ -9,4 +9,4 @@
# #
# (c) 2015 Valentin Samir # (c) 2015 Valentin Samir
default_app_config = 'cas_server.apps.AppConfig' default_app_config = 'cas_server.apps.CasAppConfig'

View File

@ -14,9 +14,9 @@ from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket, User, Servi
from .models import Username, ReplaceAttributName, ReplaceAttributValue, FilterAttributValue from .models import Username, ReplaceAttributName, ReplaceAttributValue, FilterAttributValue
from .forms import TicketForm from .forms import TicketForm
tickets_readonly_fields = ('validate', 'service', 'service_pattern', TICKETS_READONLY_FIELDS = ('validate', 'service', 'service_pattern',
'creation', 'renew', 'single_log_out', 'value') 'creation', 'renew', 'single_log_out', 'value')
tickets_fields = ('validate', 'service', 'service_pattern', TICKETS_FIELDS = ('validate', 'service', 'service_pattern',
'creation', 'renew', 'single_log_out') 'creation', 'renew', 'single_log_out')
@ -25,8 +25,8 @@ class ServiceTicketInline(admin.TabularInline):
model = ServiceTicket model = ServiceTicket
extra = 0 extra = 0
form = TicketForm form = TicketForm
readonly_fields = tickets_readonly_fields readonly_fields = TICKETS_READONLY_FIELDS
fields = tickets_fields fields = TICKETS_FIELDS
class ProxyTicketInline(admin.TabularInline): class ProxyTicketInline(admin.TabularInline):
@ -34,8 +34,8 @@ class ProxyTicketInline(admin.TabularInline):
model = ProxyTicket model = ProxyTicket
extra = 0 extra = 0
form = TicketForm form = TicketForm
readonly_fields = tickets_readonly_fields readonly_fields = TICKETS_READONLY_FIELDS
fields = tickets_fields fields = TICKETS_FIELDS
class ProxyGrantingInline(admin.TabularInline): class ProxyGrantingInline(admin.TabularInline):
@ -43,8 +43,8 @@ class ProxyGrantingInline(admin.TabularInline):
model = ProxyGrantingTicket model = ProxyGrantingTicket
extra = 0 extra = 0
form = TicketForm form = TicketForm
readonly_fields = tickets_readonly_fields readonly_fields = TICKETS_READONLY_FIELDS
fields = tickets_fields[1:] fields = TICKETS_FIELDS[1:]
class UserAdmin(admin.ModelAdmin): class UserAdmin(admin.ModelAdmin):

View File

@ -2,6 +2,6 @@ from django.utils.translation import ugettext_lazy as _
from django.apps import AppConfig from django.apps import AppConfig
class AppConfig(AppConfig): class CasAppConfig(AppConfig):
name = 'cas_server' name = 'cas_server'
verbose_name = _('Central Authentication Service') verbose_name = _('Central Authentication Service')

View File

@ -15,10 +15,10 @@ from django.contrib.auth import get_user_model
from django.utils import timezone from django.utils import timezone
from datetime import timedelta from datetime import timedelta
try: try: # pragma: no cover
import MySQLdb import MySQLdb
import MySQLdb.cursors import MySQLdb.cursors
import crypt from utils import check_password
except ImportError: except ImportError:
MySQLdb = None MySQLdb = None
@ -31,14 +31,14 @@ class AuthUser(object):
def test_password(self, password): def test_password(self, password):
"""test `password` agains the user""" """test `password` agains the user"""
raise NotImplemented() raise NotImplementedError()
def attributs(self): def attributs(self):
"""return a dict of user attributes""" """return a dict of user attributes"""
raise NotImplemented() raise NotImplementedError()
class DummyAuthUser(AuthUser): class DummyAuthUser(AuthUser): # pragma: no cover
"""A Dummy authentication class""" """A Dummy authentication class"""
def __init__(self, username): def __init__(self, username):
@ -62,14 +62,14 @@ class TestAuthUser(AuthUser):
def test_password(self, password): def test_password(self, password):
"""test `password` agains the user""" """test `password` agains the user"""
return self.username == "test" and password == "test" return self.username == settings.CAS_TEST_USER and password == settings.CAS_TEST_PASSWORD
def attributs(self): def attributs(self):
"""return a dict of user attributes""" """return a dict of user attributes"""
return {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} return settings.CAS_TEST_ATTRIBUTES
class MysqlAuthUser(AuthUser): class MysqlAuthUser(AuthUser): # pragma: no cover
"""A mysql auth class: authentication user agains a mysql database""" """A mysql auth class: authentication user agains a mysql database"""
user = None user = None
@ -94,30 +94,25 @@ class MysqlAuthUser(AuthUser):
def test_password(self, password): def test_password(self, password):
"""test `password` agains the user""" """test `password` agains the user"""
if not self.user: if self.user:
return False return check_password(
else: settings.CAS_SQL_PASSWORD_CHECK,
if settings.CAS_SQL_PASSWORD_CHECK == "plain":
return password == self.user["password"]
elif settings.CAS_SQL_PASSWORD_CHECK == "crypt":
if self.user["password"].startswith('$'):
salt = '$'.join(self.user["password"].split('$', 3)[:-1])
return crypt.crypt(password, salt) == self.user["password"]
else:
return crypt.crypt(
password, password,
self.user["password"][:2] self.user["password"],
) == self.user["password"] settings.CAS_SQL_DBCHARSET
)
else:
return False
def attributs(self): def attributs(self):
"""return a dict of user attributes""" """return a dict of user attributes"""
if not self.user: if self.user:
return {}
else:
return self.user return self.user
else:
return {}
class DjangoAuthUser(AuthUser): class DjangoAuthUser(AuthUser): # pragma: no cover
"""A django auth class: authenticate user agains django internal users""" """A django auth class: authenticate user agains django internal users"""
user = None user = None
@ -131,21 +126,20 @@ class DjangoAuthUser(AuthUser):
def test_password(self, password): def test_password(self, password):
"""test `password` agains the user""" """test `password` agains the user"""
if not self.user: if self.user:
return False
else:
return self.user.check_password(password) return self.user.check_password(password)
else:
return False
def attributs(self): def attributs(self):
"""return a dict of user attributes""" """return a dict of user attributes"""
if not self.user: if self.user:
return {}
else:
attr = {} attr = {}
for field in self.user._meta.fields: for field in self.user._meta.fields:
attr[field.attname] = getattr(self.user, field.attname) attr[field.attname] = getattr(self.user, field.attname)
return attr return attr
else:
return {}
class CASFederateAuth(AuthUser): class CASFederateAuth(AuthUser):
user = None user = None

View File

@ -78,6 +78,17 @@ setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS '
'password, users.* FROM users WHERE user = %s') 'password, users.* FROM users WHERE user = %s')
setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain
setting_default('CAS_TEST_USER', 'test')
setting_default('CAS_TEST_PASSWORD', 'test')
setting_default(
'CAS_TEST_ATTRIBUTES',
{
'nom': 'Nymous',
'prenom': 'Ano',
'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']
}
)
setting_default('CAS_FEDERATE', False) setting_default('CAS_FEDERATE', False)
# A dict of "provider suffix" -> (provider CAS server url, CAS version, verbose name) # A dict of "provider suffix" -> (provider CAS server url, CAS version, verbose name)

View File

@ -1,55 +1,52 @@
function cas_login(cas_server_login, service, login_service, callback){ function cas_login(cas_server_login, service, login_service, callback){
url = cas_server_login + '?service=' + encodeURIComponent(service); var url = cas_server_login + "?service=" + encodeURIComponent(service);
$.ajax({ $.ajax({
type: 'GET', type: "GET",
url:url, url,
beforeSend: function (request) { beforeSend(request) {
request.setRequestHeader("X-AJAX", "1"); request.setRequestHeader("X-AJAX", "1");
}, },
xhrFields: { xhrFields: {
withCredentials: true withCredentials: true
}, },
success: function(data, textStatus, request){ success(data, textStatus, request){
if(data.status == 'success'){ if(data.status === "success"){
$.ajax({ $.ajax({
type: 'GET', type: "GET",
url: data.url, url: data.url,
xhrFields: { xhrFields: {
withCredentials: true withCredentials: true
}, },
success: callback, success: callback,
error: function (request, textStatus, errorThrown) {}, error(request, textStatus, errorThrown) {},
}); });
} else { } else {
if(data.detail == "login required"){ if(data.detail === "login required"){
window.location.href = cas_server_login + '?service=' + encodeURIComponent(login_service); window.location.href = cas_server_login + "?service=" + encodeURIComponent(login_service);
} else { } else {
alert('error: ' + data.messages[1].message); alert("error: " + data.messages[1].message);
} }
} }
}, },
error: function (request, textStatus, errorThrown) {}, error(request, textStatus, errorThrown) {},
}); });
} }
function cas_logout(cas_server_logout){ function cas_logout(cas_server_logout){
$.ajax({ $.ajax({
type: 'GET', type: "GET",
url: cas_server_logout, url: cas_server_logout,
beforeSend: function (request) { beforeSend(request) {
request.setRequestHeader("X-AJAX", "1"); request.setRequestHeader("X-AJAX", "1");
}, },
xhrFields: { xhrFields: {
withCredentials: true withCredentials: true
}, },
error: function (request, textStatus, errorThrown) {}, error(request, textStatus, errorThrown) {},
success: function(data, textStatus, request){ success(data, textStatus, request){
if(data.status == 'error'){ if(data.status === "error"){
alert('error: ' + data.messages[1].message); alert("error: " + data.messages[1].message);
} }
}, },
}); });
} }

View File

@ -43,14 +43,14 @@ body {
@media screen and (max-width: 680px) { @media screen and (max-width: 680px) {
#app-name { #app-name {
margin: 0px; margin: 0;
} }
#app-name img { #app-name img {
display: block; display: block;
margin: auto; margin: auto;
} }
body { body {
padding-top: 0px; padding-top: 0;
padding-bottom: 0px; padding-bottom: 0;
} }
} }

943
cas_server/tests.py Normal file
View File

@ -0,0 +1,943 @@
from .default_settings import settings
import django
from django.test import TestCase
from django.test import Client
import re
import six
import random
from lxml import etree
from six.moves import range
from cas_server import models
from cas_server import utils
def copy_form(form):
"""Copy form value into a dict"""
params = {}
for field in form:
if field.value():
params[field.name] = field.value()
else:
params[field.name] = ""
return params
def get_login_page_params(client=None):
"""Return a client and the POST params for the client to login"""
if client is None:
client = Client()
response = client.get('/login')
params = copy_form(response.context["form"])
return client, params
def get_auth_client(**update):
"""return a authenticated client"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params.update(update)
client.post('/login', params)
return client
def get_user_ticket_request(service):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client = get_auth_client()
response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
return (user, ticket)
def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)
client = Client()
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
params = utils.PGTUrlHandler.PARAMS.copy()
params["service"] = service
params["user"] = user
return params
class CheckPasswordCase(TestCase):
"""Tests for the utils function `utils.check_password`"""
def setUp(self):
"""Generate random bytes string that will be used ass passwords"""
self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes):
self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8")
def test_setup(self):
"""check that generated password are bytes"""
self.assertIsInstance(self.password1, bytes)
self.assertIsInstance(self.password2, bytes)
def test_plain(self):
"""test the plain auth method"""
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
def test_crypt(self):
"""test the crypt auth method"""
if six.PY3:
hashed_password1 = utils.crypt.crypt(
self.password1.decode("utf8"),
"$6$UVVAQvrMyXMF3FF3"
).encode("utf8")
else:
hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
def test_ldap_ssha(self):
"""test the ldap auth method with a {SSHA} scheme"""
salt = b"UVVAQvrMyXMF3FF3"
hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
self.assertIsInstance(hashed_password1, bytes)
self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
def test_hex_md5(self):
"""test the hex_md5 auth method"""
hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
def test_hex_sha512(self):
"""test the hex_sha512 auth method"""
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
self.assertTrue(
utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
)
self.assertFalse(
utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
)
class LoginTestCase(TestCase):
"""Tests for the login view"""
def setUp(self):
"""
Prepare the test context:
* set the auth class to 'cas_server.auth.TestAuthUser'
* create a service pattern for https://www.example.com/**
* Set the service pattern to return all user attributes
"""
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
# For general purpose testing
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$",
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
# For testing the restrict_users attributes
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
name="restrict_user_fail",
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
restrict_users=True,
)
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
name="restrict_user_success",
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
restrict_users=True,
)
models.Username.objects.create(
value=settings.CAS_TEST_USER,
service_pattern=self.service_pattern_restrict_user_success
)
# For testing the user attributes filtering conditions
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
name="filter_fail",
pattern="^https://filter_fail\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="right",
pattern="^admin$",
service_pattern=self.service_pattern_filter_fail
)
self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success",
pattern="^https://filter_success\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="email",
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
service_pattern=self.service_pattern_filter_success
)
# For testing the user_field attributes
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
name="field_needed_fail",
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
user_field="uid"
)
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success",
pattern="^https://field_needed_success\.example\.com(/.*)?$",
user_field="nom"
)
def assert_logged(self, client, response, warn=False, code=200):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client.session["warn"] is warn)
self.assertTrue(client.session["authenticated"] is True)
self.assertTrue(
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
)
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
def test_login_view_post_goodpass_goodlt(self):
"""Test a successul login"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
self.assertTrue(params['lt'] in client.session['lt'])
response = client.post('/login', params)
self.assert_logged(client, response)
# LoginTicket conssumed
self.assertTrue(params['lt'] not in client.session['lt'])
def test_login_view_post_goodpass_goodlt_warn(self):
"""Test a successul login requesting to be warned before creating services tickets"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["warn"] = "on"
response = client.post('/login', params)
self.assert_logged(client, response, warn=True)
def test_lt_max(self):
"""Check we only keep the last 100 Login Ticket for a user"""
client, params = get_login_page_params()
current_lt = params["lt"]
i_in_test = random.randint(0, 100)
i_not_in_test = random.randint(100, 150)
for i in range(150):
if i == i_in_test:
self.assertTrue(current_lt in client.session['lt'])
if i == i_not_in_test:
self.assertTrue(current_lt not in client.session['lt'])
self.assertTrue(len(client.session['lt']) <= 100)
client, params = get_login_page_params(client)
self.assertTrue(len(client.session['lt']) <= 100)
def test_login_view_post_badlt(self):
"""Login attempt with a bad LoginTicket"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["lt"] = 'LT-random'
response = client.post('/login', params)
self.assert_login_failed(client, response)
self.assertTrue(b"Invalid login ticket" in response.content)
def test_login_view_post_badpass_good_lt(self):
"""Login attempt with a bad password"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = "test2"
response = client.post('/login', params)
self.assert_login_failed(client, response)
self.assertTrue(
(
b"The credentials you provided cannot be "
b"determined to be authentic"
) in response.content
)
def assert_ticket_attributes(self, client, ticket_value):
"""check the ticket attributes in the db"""
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
self.assertTrue(user)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
self.assertEqual(ticket.user, user)
self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
self.assertEqual(ticket.validate, False)
self.assertEqual(ticket.service_pattern, self.service_pattern)
def assert_service_ticket(self, client, response):
"""check that a ticket is well emited when requested on a allowed service"""
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header('Location'))
self.assertTrue(
response['Location'].startswith(
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
)
)
ticket_value = response['Location'].split('ticket=')[-1]
self.assert_ticket_attributes(client, ticket_value)
def test_view_login_get_allowed_service(self):
"""Request a ticket for an allowed service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication required by service "
b"example (https://www.example.com)"
) in response.content
)
def test_view_login_get_denied_service(self):
"""Request a ticket for an denied service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.net")
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.net non allowed" in response.content)
def test_view_login_get_auth_allowed_service(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client()
response = client.get("/login?service=https://www.example.com")
self.assert_service_ticket(client, response)
def test_view_login_get_auth_allowed_service_warn(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client(warn="on")
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication has been required by service "
b"example (https://www.example.com)"
) in response.content
)
params = copy_form(response.context["form"])
response = client.post("/login", params)
self.assert_service_ticket(client, response)
def test_view_login_get_auth_denied_service(self):
"""Request a ticket for a not allowed service by an authenticated client"""
client = get_auth_client()
response = client.get("/login?service=https://www.example.org")
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
def test_user_logged_not_in_db(self):
"""If the user is logged but has been delete from the database, it should be logged out"""
client = get_auth_client()
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).delete()
response = client.get("/login")
self.assert_login_failed(client, response, code=302)
if django.VERSION < (1, 9):
self.assertEqual(response["Location"], "http://testserver/login")
else:
self.assertEqual(response["Location"], "/login?")
def test_service_restrict_user(self):
"""Testing the restric user capability fro a service"""
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Username non allowed" in response.content)
service = "https://restrict_user_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_filter(self):
"""Test the filtering on user attributes"""
service = "https://filter_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"User charateristics non allowed" in response.content)
service = "https://filter_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_user_field(self):
"""Test using a user attribute as username: case on if the attribute exists or not"""
service = "https://field_needed_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"The attribut uid is needed to use that service" in response.content)
service = "https://field_needed_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_gateway(self):
"""test gateway parameter"""
# First with an authenticated client that fail to get a ticket for a service
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
# second for an user not yet authenticated on a valid service
client = Client()
response = client.get('/login', {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
class LogoutTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
def test_logout_view(self):
client = get_auth_client()
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
response = client.get("/logout")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged out from "
b"the Central Authentication Service"
) in response.content
)
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_view_url(self):
client = get_auth_client()
response = client.get('/logout?url=https://www.example.com')
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header("Location"))
self.assertEqual(response["Location"], "https://www.example.com")
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_view_service(self):
client = get_auth_client()
response = client.get('/logout?service=https://www.example.com')
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header("Location"))
self.assertEqual(response["Location"], "https://www.example.com")
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
class AuthTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'https://www.example.com'
models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
def test_auth_view_goodpass(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\n')
def test_auth_view_badpass(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': 'badpass',
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badservice(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': 'https://www.example.org',
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badsecret(self):
settings.CAS_AUTH_SHARED_SECRET = 'test'
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'badsecret'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badsettings(self):
settings.CAS_AUTH_SHARED_SECRET = None
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
class ValidateTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'https://www.example.com'
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_view_ok(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\ntest\n')
def test_validate_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/validate',
{'ticket': ticket.value, 'service': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_validate_view_badticket(self):
get_user_ticket_request(self.service)
client = Client()
response = client.get(
'/validate',
{'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
class ValidateServiceTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_service_view_ok(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
def test_validate_service_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1]
client = Client()
bad_service = "https://www.example.org"
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE")
self.assertEqual(error[0].text, bad_service)
def test_validate_service_view_badticket_goodprefix(self):
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(error[0].text, 'ticket not found')
def test_validate_service_view_badticket_badprefix(self):
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "RANDOM"
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(error[0].text, bad_ticket)
def test_validate_service_view_ok_pgturl(self):
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
)
pgt_params = utils.PGTUrlHandler.PARAMS.copy()
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
pgtiou = root.xpath(
"//cas:proxyGrantingTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(pgtiou), 1)
self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
self.assertTrue("pgtId" in pgt_params)
def test_validate_service_pgturl_bad_proxy_callback(self):
self.service_pattern.proxy_callback = False
self.service_pattern.save()
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
self.assertEqual(error[0].text, "callback url not allowed by configuration")
class ProxyTestCase(TestCase):
def setUp(self):
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
self.service = 'http://127.0.0.1'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy=True,
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_proxy_ok(self):
params = get_pgt()
# get a proxy ticket
client1 = Client()
response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertTrue(sucess)
proxy_ticket = root.xpath(
"//cas:proxyTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(proxy_ticket), 1)
proxy_ticket = proxy_ticket[0].text
# validate the proxy ticket
client2 = Client()
response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
# check that the proxy is send to the end service
proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxies), 1)
proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxy), 1)
self.assertEqual(proxy[0].text, params["service"])
# same tests than those for serviceValidate
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
def test_validate_proxy_bad(self):
params = get_pgt()
# bad PGT
client1 = Client()
response = client1.get(
'/proxy',
{
'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
'targetService': params['service']
}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
self.assertEqual(
error[0].text,
"PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
)
# bad targetService
client2 = Client()
response = client2.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
self.assertEqual(error[0].text, "https://www.example.org")
# service do not allow proxy ticket
self.service_pattern.proxy = False
self.service_pattern.save()
client3 = Client()
response = client3.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': params['service']}
)
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
self.assertEqual(
error[0].text,
'the service %s do not allow proxy ticket' % params['service']
)

View File

@ -14,7 +14,7 @@ from django.conf.urls import patterns, url
from django.views.generic import RedirectView from django.views.generic import RedirectView
from django.views.decorators.debug import sensitive_post_parameters, sensitive_variables from django.views.decorators.debug import sensitive_post_parameters, sensitive_variables
import views from cas_server import views
urlpatterns = patterns( urlpatterns = patterns(
'', '',

View File

@ -19,14 +19,15 @@ from django.contrib import messages
import random import random
import string import string
import json import json
import hashlib
import crypt
import base64
import six
from threading import Thread
from importlib import import_module from importlib import import_module
from datetime import datetime, timedelta from datetime import datetime, timedelta
from six.moves import BaseHTTPServer
try: from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
from urlparse import urlparse, urlunparse, parse_qsl
from urllib import urlencode
except ImportError:
from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
def context(params): def context(params):
@ -34,7 +35,7 @@ def context(params):
return params return params
def JsonResponse(request, data): def json_response(request, data):
data["messages"] = [] data["messages"] = []
for msg in messages.get_messages(request): for msg in messages.get_messages(request):
data["messages"].append({'message': msg.message, 'level': msg.level_tag}) data["messages"].append({'message': msg.message, 'level': msg.level_tag})
@ -120,9 +121,9 @@ def update_url(url, params):
query = dict(parse_qsl(url_parts[4])) query = dict(parse_qsl(url_parts[4]))
query.update(params) query.update(params)
url_parts[4] = urlencode(query) url_parts[4] = urlencode(query)
for i in range(len(url_parts)): for i, url_part in enumerate(url_parts):
if not isinstance(url_parts[i], bytes): if not isinstance(url_part, bytes):
url_parts[i] = url_parts[i].encode('utf-8') url_parts[i] = url_part.encode('utf-8')
return urlunparse(url_parts).decode('utf-8') return urlunparse(url_parts).decode('utf-8')
@ -190,3 +191,207 @@ def get_tuple(tuple, index, default=None):
return tuple[index] return tuple[index]
except IndexError: except IndexError:
return default return default
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
PARAMS = {}
def do_GET(self):
self.send_response(200)
self.send_header(b"Content-type", "text/plain")
self.end_headers()
self.wfile.write(b"ok")
url = urlparse(self.path)
params = dict(parse_qsl(url.query))
PGTUrlHandler.PARAMS.update(params)
def log_message(self, *args):
return
@staticmethod
def run():
server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), PGTUrlHandler)
(host, port) = httpd.socket.getsockname()
def lauch():
httpd.handle_request()
httpd.server_close()
httpd_thread = Thread(target=lauch)
httpd_thread.daemon = True
httpd_thread.start()
return (httpd_thread, host, port)
class LdapHashUserPassword(object):
"""Please see https://tools.ietf.org/id/draft-stroeder-hashed-userpassword-values-01.html"""
schemes_salt = {b"{SMD5}", b"{SSHA}", b"{SSHA256}", b"{SSHA384}", b"{SSHA512}", b"{CRYPT}"}
schemes_nosalt = {b"{MD5}", b"{SHA}", b"{SHA256}", b"{SHA384}", b"{SHA512}"}
_schemes_to_hash = {
b"{SMD5}": hashlib.md5,
b"{MD5}": hashlib.md5,
b"{SSHA}": hashlib.sha1,
b"{SHA}": hashlib.sha1,
b"{SSHA256}": hashlib.sha256,
b"{SHA256}": hashlib.sha256,
b"{SSHA384}": hashlib.sha384,
b"{SHA384}": hashlib.sha384,
b"{SSHA512}": hashlib.sha512,
b"{SHA512}": hashlib.sha512
}
_schemes_to_len = {
b"{SMD5}": 16,
b"{SSHA}": 20,
b"{SSHA256}": 32,
b"{SSHA384}": 48,
b"{SSHA512}": 64,
}
class BadScheme(ValueError):
"""Error raised then the hash scheme is not in schemes_salt + schemes_nosalt"""
pass
class BadHash(ValueError):
"""Error raised then the hash is too short"""
pass
class BadSalt(ValueError):
"""Error raised then with the scheme {CRYPT} the salt is invalid"""
pass
@classmethod
def _raise_bad_scheme(cls, scheme, valid, msg):
"""
Raise BadScheme error for `scheme`, possible valid scheme are
in `valid`, the error message is `msg`
"""
valid_schemes = [s.decode() for s in valid]
valid_schemes.sort()
raise cls.BadScheme(msg % (scheme, u", ".join(valid_schemes)))
@classmethod
def _test_scheme(cls, scheme):
"""Test if a scheme is valide or raise BadScheme"""
if scheme not in cls.schemes_salt and scheme not in cls.schemes_nosalt:
cls._raise_bad_scheme(
scheme,
cls.schemes_salt | cls.schemes_nosalt,
"The scheme %r is not valid. Valide schemes are %s."
)
@classmethod
def _test_scheme_salt(cls, scheme):
"""Test if the scheme need a salt or raise BadScheme"""
if scheme not in cls.schemes_salt:
cls._raise_bad_scheme(
scheme,
cls.schemes_salt,
"The scheme %r is only valid without a salt. Valide schemes with salt are %s."
)
@classmethod
def _test_scheme_nosalt(cls, scheme):
"""Test if the scheme need no salt or raise BadScheme"""
if scheme not in cls.schemes_nosalt:
cls._raise_bad_scheme(
scheme,
cls.schemes_nosalt,
"The scheme %r is only valid with a salt. Valide schemes without salt are %s."
)
@classmethod
def hash(cls, scheme, password, salt=None, charset="utf8"):
"""
Hash `password` with `scheme` using `salt`.
This three variable beeing encoded in `charset`.
"""
scheme = scheme.upper()
cls._test_scheme(scheme)
if salt is None or salt == b"":
salt = b""
cls._test_scheme_nosalt(scheme)
elif salt is not None:
cls._test_scheme_salt(scheme)
try:
return scheme + base64.b64encode(
cls._schemes_to_hash[scheme](password + salt).digest() + salt
)
except KeyError:
if six.PY3:
password = password.decode(charset)
salt = salt.decode(charset)
hashed_password = crypt.crypt(password, salt)
if hashed_password is None:
raise cls.BadSalt("System crypt implementation do not support the salt %r" % salt)
if six.PY3:
hashed_password = hashed_password.encode(charset)
return scheme + hashed_password
@classmethod
def get_scheme(cls, hashed_passord):
"""Return the scheme of `hashed_passord` or raise BadHash"""
if not hashed_passord[0] == b'{'[0] or b'}' not in hashed_passord:
raise cls.BadHash("%r should start with the scheme enclosed with { }" % hashed_passord)
scheme = hashed_passord.split(b'}', 1)[0]
scheme = scheme.upper() + b"}"
return scheme
@classmethod
def get_salt(cls, hashed_passord):
"""Return the salt of `hashed_passord` possibly empty"""
scheme = cls.get_scheme(hashed_passord)
cls._test_scheme(scheme)
if scheme in cls.schemes_nosalt:
return b""
elif scheme == b'{CRYPT}':
return b'$'.join(hashed_passord.split(b'$', 3)[:-1])
else:
hashed_passord = base64.b64decode(hashed_passord[len(scheme):])
if len(hashed_passord) < cls._schemes_to_len[scheme]:
raise cls.BadHash("Hash too short for the scheme %s" % scheme)
return hashed_passord[cls._schemes_to_len[scheme]:]
def check_password(method, password, hashed_password, charset):
"""
Check that `password` match `hashed_password` using `method`,
assuming the encoding is `charset`.
"""
if not isinstance(password, six.binary_type):
password = password.encode(charset)
if not isinstance(hashed_password, six.binary_type):
hashed_password = hashed_password.encode(charset)
if method == "plain":
return password == hashed_password
elif method == "crypt":
if hashed_password.startswith(b'$'):
salt = b'$'.join(hashed_password.split(b'$', 3)[:-1])
elif hashed_password.startswith(b'_'):
salt = hashed_password[:9]
else:
salt = hashed_password[:2]
if six.PY3:
password = password.decode(charset)
salt = salt.decode(charset)
hashed_password = hashed_password.decode(charset)
crypted_password = crypt.crypt(password, salt)
if crypted_password is None:
raise ValueError("System crypt implementation do not support the salt %r" % salt)
return crypted_password == hashed_password
elif method == "ldap":
scheme = LdapHashUserPassword.get_scheme(hashed_password)
salt = LdapHashUserPassword.get_salt(hashed_password)
return LdapHashUserPassword.hash(scheme, password, salt, charset=charset) == hashed_password
elif (
method.startswith("hex_") and
method[4:] in {"md5", "sha1", "sha224", "sha256", "sha384", "sha512"}
):
return getattr(
hashlib,
method[4:]
)(password).hexdigest().encode("ascii") == hashed_password.lower()
else:
raise ValueError("Unknown password method check %r" % method)

View File

@ -23,6 +23,7 @@ from django.views.decorators.csrf import csrf_exempt
from django.middleware.csrf import CsrfViewMiddleware from django.middleware.csrf import CsrfViewMiddleware
from django.views.generic import View from django.views.generic import View
import re
import logging import logging
import pprint import pprint
import requests import requests
@ -34,7 +35,7 @@ import cas_server.utils as utils
import cas_server.forms as forms import cas_server.forms as forms
import cas_server.models as models import cas_server.models as models
from .utils import JsonResponse from .utils import json_response
from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket
from .models import ServicePattern from .models import ServicePattern
from .federate import CASFederateValidateUser from .federate import CASFederateValidateUser
@ -63,12 +64,12 @@ class AttributesMixin(object):
class LogoutMixin(object): class LogoutMixin(object):
"""destroy CAS session utils""" """destroy CAS session utils"""
def logout(self, all=False): def logout(self, all_session=False):
"""effectively destroy CAS session""" """effectively destroy CAS session"""
session_nb = 0 session_nb = 0
username = self.request.session.get("username") username = self.request.session.get("username")
if username: if username:
if all: if all_session:
logger.info("Logging out user %s from all of they sessions." % username) logger.info("Logging out user %s from all of they sessions." % username)
else: else:
logger.info("Logging out user %s." % username) logger.info("Logging out user %s." % username)
@ -91,8 +92,8 @@ class LogoutMixin(object):
# if user not found in database, flush the session anyway # if user not found in database, flush the session anyway
self.request.session.flush() self.request.session.flush()
# If all is set logout user from alternative sessions # If all_session is set logout user from alternative sessions
if all: if all_session:
for user in models.User.objects.filter(username=username): for user in models.User.objects.filter(username=username):
session = SessionStore(session_key=user.session_key) session = SessionStore(session_key=user.session_key)
session.flush() session.flush()
@ -110,6 +111,7 @@ class LogoutView(View, LogoutMixin):
service = None service = None
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.url = request.GET.get('url') self.url = request.GET.get('url')
@ -170,13 +172,13 @@ class LogoutView(View, LogoutMixin):
'url': url, 'url': url,
'session_nb': session_nb 'session_nb': session_nb
} }
return JsonResponse(request, data) return json_response(request, data)
else: else:
return redirect("cas_server:login") return redirect("cas_server:login")
else: else:
if self.ajax: if self.ajax:
data = {'status': 'success', 'detail': 'logout', 'session_nb': session_nb} data = {'status': 'success', 'detail': 'logout', 'session_nb': session_nb}
return JsonResponse(request, data) return json_response(request, data)
else: else:
return render( return render(
request, request,
@ -290,12 +292,10 @@ class LoginView(View, LogoutMixin):
USER_NOT_AUTHENTICATED = 6 USER_NOT_AUTHENTICATED = 6
def init_post(self, request): def init_post(self, request):
"""Initialize POST received parameters"""
self.request = request self.request = request
self.service = request.POST.get('service') self.service = request.POST.get('service')
if request.POST.get('renew') and request.POST['renew'] != "False": self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
self.renew = True
else:
self.renew = False
self.gateway = request.POST.get('gateway') self.gateway = request.POST.get('gateway')
self.method = request.POST.get('method') self.method = request.POST.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = 'HTTP_X_AJAX' in request.META
@ -306,15 +306,19 @@ class LoginView(View, LogoutMixin):
self.username = request.POST.get('username') self.username = request.POST.get('username')
self.ticket = request.POST.get('ticket') self.ticket = request.POST.get('ticket')
def check_lt(self): def gen_lt(self):
# save LT for later check """Generate a new LoginTicket and add it to the list of valid LT for the user"""
lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed)
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
if len(self.request.session['lt']) > 100: if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:] self.request.session['lt'] = self.request.session['lt'][-100:]
def check_lt(self):
"""Check is the POSTed LoginTicket is valid, if yes invalide it"""
# save LT for later check
lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed)
self.gen_lt()
# check if send LT is valid # check if send LT is valid
if lt_valid is None or lt_send not in lt_valid: if lt_valid is None or lt_send not in lt_valid:
return False return False
@ -339,7 +343,7 @@ class LoginView(View, LogoutMixin):
username=self.request.session['username'], username=self.request.session['username'],
session_key=self.request.session.session_key session_key=self.request.session.session_key
) )
self.user.save() self.user.save() # pragma: no cover (should not happend)
except models.User.DoesNotExist: except models.User.DoesNotExist:
self.user = models.User.objects.create( self.user = models.User.objects.create(
username=self.request.session['username'], username=self.request.session['username'],
@ -355,10 +359,15 @@ class LoginView(View, LogoutMixin):
elif ret == self.USER_ALREADY_LOGGED: elif ret == self.USER_ALREADY_LOGGED:
pass pass
else: else:
raise EnvironmentError("invalid output for LoginView.process_post") raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover
return self.common() return self.common()
def process_post(self, pytest=False): def process_post(self):
"""
Analyse the POST request:
* check that the LoginTicket is valid
* check that the user sumited credentials are valid
"""
if not self.check_lt(): if not self.check_lt():
values = self.request.POST.copy() values = self.request.POST.copy()
# if not set a new LT and fail # if not set a new LT and fail
@ -385,12 +394,10 @@ class LoginView(View, LogoutMixin):
return self.USER_ALREADY_LOGGED return self.USER_ALREADY_LOGGED
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
if request.GET.get('renew') and request.GET['renew'] != "False": self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
self.renew = True
else:
self.renew = False
self.gateway = request.GET.get('gateway') self.gateway = request.GET.get('gateway')
self.method = request.GET.get('method') self.method = request.GET.get('method')
self.ajax = 'HTTP_X_AJAX' in request.META self.ajax = 'HTTP_X_AJAX' in request.META
@ -410,15 +417,16 @@ class LoginView(View, LogoutMixin):
return self.common() return self.common()
def process_get(self): def process_get(self):
# generate a new LT if none is present """Analyse the GET request"""
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] # generate a new LT
self.gen_lt()
if not self.request.session.get("authenticated") or self.renew: if not self.request.session.get("authenticated") or self.renew:
self.init_form() self.init_form()
return self.USER_NOT_AUTHENTICATED return self.USER_NOT_AUTHENTICATED
return self.USER_AUTHENTICATED return self.USER_AUTHENTICATED
def init_form(self, values=None): def init_form(self, values=None):
"""Initialization of the good form depending of POST and GET parameters"""
form_initial = { form_initial = {
'service': self.service, 'service': self.service,
'method': self.method, 'method': self.method,
@ -459,7 +467,7 @@ class LoginView(View, LogoutMixin):
) )
if self.ajax: if self.ajax:
data = {"status": "error", "detail": "confirmation needed"} data = {"status": "error", "detail": "confirmation needed"}
return JsonResponse(self.request, data) return json_response(self.request, data)
else: else:
warn_form = forms.WarnForm(initial={ warn_form = forms.WarnForm(initial={
'service': self.service, 'service': self.service,
@ -486,7 +494,7 @@ class LoginView(View, LogoutMixin):
return HttpResponseRedirect(redirect_url) return HttpResponseRedirect(redirect_url)
else: else:
data = {"status": "success", "detail": "auth", "url": redirect_url} data = {"status": "success", "detail": "auth", "url": redirect_url}
return JsonResponse(self.request, data) return json_response(self.request, data)
except ServicePattern.DoesNotExist: except ServicePattern.DoesNotExist:
error = 1 error = 1
messages.add_message( messages.add_message(
@ -530,7 +538,7 @@ class LoginView(View, LogoutMixin):
) )
else: else:
data = {"status": "error", "detail": "auth", "code": error} data = {"status": "error", "detail": "auth", "code": error}
return JsonResponse(self.request, data) return json_response(self.request, data)
def authenticated(self): def authenticated(self):
"""Processing authenticated users""" """Processing authenticated users"""
@ -552,7 +560,7 @@ class LoginView(View, LogoutMixin):
"detail": "login required", "detail": "login required",
"url": utils.reverse_params("cas_server:login", params=self.request.GET) "url": utils.reverse_params("cas_server:login", params=self.request.GET)
} }
return JsonResponse(self.request, data) return json_response(self.request, data)
else: else:
return utils.redirect_params("cas_server:login", params=self.request.GET) return utils.redirect_params("cas_server:login", params=self.request.GET)
@ -562,7 +570,7 @@ class LoginView(View, LogoutMixin):
else: else:
if self.ajax: if self.ajax:
data = {"status": "success", "detail": "logged"} data = {"status": "success", "detail": "logged"}
return JsonResponse(self.request, data) return json_response(self.request, data)
else: else:
return render( return render(
self.request, self.request,
@ -605,7 +613,7 @@ class LoginView(View, LogoutMixin):
"detail": "login required", "detail": "login required",
"url": utils.reverse_params("cas_server:login", params=self.request.GET) "url": utils.reverse_params("cas_server:login", params=self.request.GET)
} }
return JsonResponse(self.request, data) return json_response(self.request, data)
else: else:
if settings.CAS_FEDERATE: if settings.CAS_FEDERATE:
if self.username and self.ticket: if self.username and self.ticket:
@ -824,7 +832,10 @@ class ValidateService(View, AttributesMixin):
params['username'] = self.ticket.user.attributs.get( params['username'] = self.ticket.user.attributs.get(
self.ticket.service_pattern.user_field self.ticket.service_pattern.user_field
) )
if self.pgt_url and self.pgt_url.startswith("https://"): if self.pgt_url and (
self.pgt_url.startswith("https://") or
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
):
return self.process_pgturl(params) return self.process_pgturl(params)
else: else:
logger.info( logger.info(

22
run_tests Executable file
View File

@ -0,0 +1,22 @@
#!/usr/bin/env python
import os, sys
import django
from django.conf import settings
import settings_tests
settings.configure(**settings_tests.__dict__)
django.setup()
try:
# Django <= 1.8
from django.test.simple import DjangoTestSuiteRunner
test_runner = DjangoTestSuiteRunner(verbosity=1)
except ImportError:
# Django >= 1.8
from django.test.runner import DiscoverRunner
test_runner = DiscoverRunner(verbosity=1)
failures = test_runner.run_tests(['cas_server'])
if failures:
sys.exit(failures)

83
settings_tests.py Normal file
View File

@ -0,0 +1,83 @@
"""
Django test settings for cas_server application.
Generated by 'django-admin startproject' using Django 1.9.7.
For more information on this file, see
https://docs.djangoproject.com/en/1.9/topics/settings/
For the full list of settings and their values, see
https://docs.djangoproject.com/en/1.9/ref/settings/
"""
import os
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Quick-start development settings - unsuitable for production
# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/
# SECURITY WARNING: keep the secret key used in production secret!
SECRET_KEY = 'changeme'
# SECURITY WARNING: don't run with debug turned on in production!
DEBUG = True
ALLOWED_HOSTS = []
# Application definition
INSTALLED_APPS = [
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'bootstrap3',
'cas_server',
]
MIDDLEWARE_CLASSES = [
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
'django.middleware.clickjacking.XFrameOptionsMiddleware',
'django.middleware.locale.LocaleMiddleware',
]
ROOT_URLCONF = 'urls_tests'
# Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
}
}
# Internationalization
# https://docs.djangoproject.com/en/1.9/topics/i18n/
LANGUAGE_CODE = 'en-us'
TIME_ZONE = 'UTC'
USE_I18N = True
USE_L10N = True
USE_TZ = True
# Static files (CSS, JavaScript, Images)
# https://docs.djangoproject.com/en/1.9/howto/static-files/
STATIC_URL = '/static/'

View File

View File

@ -1,136 +0,0 @@
import functools
from cas_server import models
class DummyUserManager(object):
def __init__(self, username, session_key):
self.username = username
self.session_key = session_key
def get(self, username=None, session_key=None):
if username == self.username and session_key == self.session_key:
return models.User(username=username, session_key=session_key)
else:
raise models.User.DoesNotExist()
def dummy(*args, **kwds):
pass
def dummy_service_pattern(**kwargs):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
service_validate = models.ServicePattern.validate
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs))
ret = func(*args, **kwds)
models.ServicePattern.validate = service_validate
return ret
return wrapper
return decorator
def dummy_user(username, session_key):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
user_manager = models.User.objects
user_save = models.User.save
user_delete = models.User.delete
models.User.objects = DummyUserManager(username, session_key)
models.User.save = dummy
models.User.delete = dummy
ret = func(*args, **kwds)
models.User.objects = user_manager
models.User.save = user_save
models.User.delete = user_delete
return ret
return wrapper
return decorator
def dummy_ticket(ticket_class, service, ticket):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
ticket_manager = ticket_class.objects
ticket_save = ticket_class.save
ticket_delete = ticket_class.delete
ticket_class.objects = DummyTicketManager(ticket_class, service, ticket)
ticket_class.save = dummy
ticket_class.delete = dummy
ret = func(*args, **kwds)
ticket_class.objects = ticket_manager
ticket_class.save = ticket_save
ticket_class.delete = ticket_delete
return ret
return wrapper
return decorator
def dummy_proxy(func):
@functools.wraps(func)
def wrapper(*args, **kwds):
proxy_manager = models.Proxy.objects
models.Proxy.objects = DummyProxyManager()
ret = func(*args, **kwds)
models.Proxy.objects = proxy_manager
return ret
return wrapper
class DummyProxyManager(object):
def create(self, **kwargs):
for field in models.Proxy._meta.fields:
field.allow_unsaved_instance_assignment = True
return models.Proxy(**kwargs)
class DummyTicketManager(object):
def __init__(self, ticket_class, service, ticket):
self.ticket_class = ticket_class
self.service = service
self.ticket = ticket
def create(self, **kwargs):
for field in self.ticket_class._meta.fields:
field.allow_unsaved_instance_assignment = True
return self.ticket_class(**kwargs)
def filter(self, *args, **kwargs):
return DummyQuerySet()
def get(self, **kwargs):
for field in self.ticket_class._meta.fields:
field.allow_unsaved_instance_assignment = True
if 'value' in kwargs:
if kwargs['value'] != self.ticket:
raise self.ticket_class.DoesNotExist()
else:
kwargs['value'] = self.ticket
if 'service' in kwargs:
if kwargs['service'] != self.service:
raise self.ticket_class.DoesNotExist()
else:
kwargs['service'] = self.service
if not 'user' in kwargs:
kwargs['user'] = models.User(username="test")
for field in models.ServiceTicket._meta.fields:
field.allow_unsaved_instance_assignment = True
for key in list(kwargs):
if '__' in key:
del kwargs[key]
kwargs['attributs'] = {'mail': 'test@example.com'}
kwargs['service_pattern'] = models.ServicePattern()
return self.ticket_class(**kwargs)
class DummySession(dict):
session_key = "test_session"
def set_expiry(self, int):
pass
def flush(self):
self.clear()
class DummyQuerySet(set):
pass

View File

@ -1,32 +0,0 @@
import django
from django.conf import settings
from django.contrib import messages
settings.configure()
settings.STATIC_URL = "/static/"
settings.DATABASES = {
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': '/dev/null',
}
}
settings.INSTALLED_APPS = (
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'bootstrap3',
'cas_server',
)
settings.ROOT_URLCONF = "/"
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
try:
django.setup()
except AttributeError:
pass
messages.add_message = lambda x,y,z:None

View File

@ -1,52 +0,0 @@
from __future__ import absolute_import
from tests.init import *
from django.test import RequestFactory
import os
import pytest
from lxml import etree
from cas_server.views import ValidateService, Proxy
from cas_server import models
from tests.dummy import *
@pytest.mark.django_db
@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random")
@dummy_service_pattern(proxy=True)
@dummy_user(username="test", session_key="test_session")
@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random")
@dummy_proxy
def test_proxy_ok():
factory = RequestFactory()
request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com')
request.session = DummySession()
proxy = Proxy()
response = proxy.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(proxy_tickets) == 1
factory = RequestFactory()
request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com')
validate = ValidateService()
validate.allow_proxy_ticket = True
response = validate.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(users) == 1
assert users[0].text == "test"

View File

@ -1,87 +0,0 @@
from __future__ import absolute_import
from .init import *
from django.test import RequestFactory
import os
import pytest
from lxml import etree
from cas_server.views import ValidateService
from cas_server import models
from .dummy import *
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_validate_service_view_ok():
factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
request.session = DummySession()
validate = ValidateService()
validate.allow_proxy_ticket = False
response = validate.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(users) == 1
assert users[0].text == "test"
attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(attributes) == 1
attrs = {}
for attr in attributes[0]:
attrs[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text
assert 'mail' in attrs
assert attrs['mail'] == 'test@example.com'
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random")
def test_validate_service_view_badservice():
factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
request.session = DummySession()
validate = ValidateService()
validate.allow_proxy_ticket = False
response = validate.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(error) == 1
assert error[0].attrib['code'] == 'INVALID_SERVICE'
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2")
def test_validate_service_view_badticket():
factory = RequestFactory()
request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
request.session = DummySession()
validate = ValidateService()
validate.allow_proxy_ticket = False
response = validate.get(request)
assert response.status_code == 200
root = etree.fromstring(response.content)
error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
assert len(error) == 1
assert error[0].attrib['code'] == 'INVALID_TICKET'

View File

@ -1,46 +0,0 @@
from __future__ import absolute_import
from .init import *
from django.test import RequestFactory
import os
import pytest
from cas_server.views import Auth
from cas_server import models
from .dummy import *
settings.CAS_AUTH_SHARED_SECRET = "test"
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
@dummy_user(username="test", session_key="test_session")
@dummy_service_pattern()
def test_auth_view_goodpass():
factory = RequestFactory()
request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
request.session = DummySession()
auth = Auth()
response = auth.post(request)
assert response.status_code == 200
assert response.content == b"yes\n"
@dummy_service_pattern()
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
@dummy_user(username="test", session_key="test_session")
def test_auth_view_badpass():
factory = RequestFactory()
request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
request.session = DummySession()
auth = Auth()
response = auth.post(request)
assert response.status_code == 200
assert response.content == b"no\n"

View File

@ -1,163 +0,0 @@
from __future__ import absolute_import
from .init import *
from django.test import RequestFactory
import os
import pytest
from cas_server.views import LoginView
from cas_server import models
from .dummy import *
def test_login_view_post_goodpass_goodlt():
factory = RequestFactory()
request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random'})
request.session = DummySession()
request.session['lt'] = ['LT-random']
request.session["username"] = os.urandom(20)
request.session["warn"] = os.urandom(20)
login = LoginView()
login.init_post(request)
ret = login.process_post(pytest=True)
assert ret == LoginView.USER_LOGIN_OK
assert request.session.get("authenticated") == True
assert request.session.get("username") == "test"
assert request.session.get("warn") == False
def test_login_view_post_badlt():
factory = RequestFactory()
request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random1'})
request.session = DummySession()
request.session['lt'] = ['LT-random2']
authenticated = os.urandom(20)
username = os.urandom(20)
warn = os.urandom(20)
request.session["authenticated"] = authenticated
request.session["username"] = username
request.session["warn"] = warn
login = LoginView()
login.init_post(request)
ret = login.process_post(pytest=True)
assert ret == LoginView.INVALID_LOGIN_TICKET
assert request.session.get("authenticated") == authenticated
assert request.session.get("username") == username
assert request.session.get("warn") == warn
def test_login_view_post_badpass_good_lt():
factory = RequestFactory()
request = factory.post('/login', {'username':'test', 'password':'badpassword', 'lt':'LT-random'})
request.session = DummySession()
request.session['lt'] = ['LT-random']
login = LoginView()
login.init_post(request)
ret = login.process_post()
assert ret == LoginView.USER_LOGIN_FAILURE
assert not request.session.get("authenticated")
assert not request.session.get("username")
assert not request.session.get("warn")
def test_view_login_get_unauth():
factory = RequestFactory()
request = factory.post('/login')
request.session = DummySession()
login = LoginView()
login.init_get(request)
ret = login.process_get()
assert ret == LoginView.USER_NOT_AUTHENTICATED
login = LoginView()
response = login.get(request)
assert response.status_code == 200
@pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_view_login_get_auth():
factory = RequestFactory()
request = factory.post('/login')
request.session = DummySession()
request.session["authenticated"] = True
request.session["username"] = "test"
request.session["warn"] = False
login = LoginView()
login.init_get(request)
ret = login.process_get()
assert ret == LoginView.USER_AUTHENTICATED
login = LoginView()
response = login.get(request)
assert response.status_code == 200
@pytest.mark.django_db
@dummy_service_pattern()
@dummy_user(username="test", session_key="test_session")
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_view_login_get_auth_service():
factory = RequestFactory()
request = factory.post('/login?service=https://www.example.com')
request.session = DummySession()
request.session["authenticated"] = True
request.session["username"] = "test"
request.session["warn"] = False
login = LoginView()
login.init_get(request)
ret = login.process_get()
assert ret == LoginView.USER_AUTHENTICATED
login = LoginView()
response = login.get(request)
assert response.status_code == 302
assert response['Location'].startswith('https://www.example.com?ticket=ST-')
@pytest.mark.django_db
@dummy_service_pattern()
@dummy_user(username="test", session_key="test_session")
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_view_login_get_auth_service_warn():
factory = RequestFactory()
request = factory.post('/login?service=https://www.example.com')
request.session = DummySession()
request.session["authenticated"] = True
request.session["username"] = "test"
request.session["warn"] = True
login = LoginView()
login.init_get(request)
ret = login.process_get()
assert ret == LoginView.USER_AUTHENTICATED
login = LoginView()
response = login.get(request)
assert response.status_code == 200

View File

@ -1,80 +0,0 @@
from __future__ import absolute_import
from .init import *
from django.test import RequestFactory
import os
import pytest
from cas_server.views import LogoutView
from cas_server import models
from .dummy import *
@pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_logout_view():
factory = RequestFactory()
request = factory.get('/logout')
request.session = DummySession()
request.session["authenticated"] = True
request.session["username"] = "test"
request.session["warn"] = False
logout = LogoutView()
response = logout.get(request)
assert response.status_code == 200
assert not request.session.get("authenticated")
assert not request.session.get("username")
assert not request.session.get("warn")
@pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_logout_view_url():
factory = RequestFactory()
request = factory.get('/logout?url=https://www.example.com')
request.session = DummySession()
request.session["authenticated"] = True
request.session["username"] = "test"
request.session["warn"] = False
logout = LogoutView()
response = logout.get(request)
assert response.status_code == 302
assert response['Location'] == 'https://www.example.com'
assert not request.session.get("authenticated")
assert not request.session.get("username")
assert not request.session.get("warn")
@pytest.mark.django_db
@dummy_user(username="test", session_key="test_session")
def test_logout_view_service():
factory = RequestFactory()
request = factory.get('/logout?service=https://www.example.com')
request.session = DummySession()
request.session["authenticated"] = True
request.session["username"] = "test"
request.session["warn"] = False
logout = LogoutView()
response = logout.get(request)
assert response.status_code == 302
assert response['Location'] == 'https://www.example.com'
assert not request.session.get("authenticated")
assert not request.session.get("username")
assert not request.session.get("warn")

View File

@ -1,58 +0,0 @@
from __future__ import absolute_import
from .init import *
from django.test import RequestFactory
import os
import pytest
from cas_server.views import Validate
from cas_server import models
from .dummy import *
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_validate_view_ok():
factory = RequestFactory()
request = factory.get('/validate?ticket=ST-random&service=https://www.example.com')
request.session = DummySession()
validate = Validate()
response = validate.get(request)
assert response.status_code == 200
assert response.content == b"yes\ntest\n"
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
def test_validate_view_badservice():
factory = RequestFactory()
request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com')
request.session = DummySession()
validate = Validate()
response = validate.get(request)
assert response.status_code == 200
assert response.content == b"no\n"
@pytest.mark.django_db
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1")
def test_validate_view_badticket():
factory = RequestFactory()
request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com')
request.session = DummySession()
validate = Validate()
response = validate.get(request)
assert response.status_code == 200
assert response.content == b"no\n"

View File

@ -17,7 +17,7 @@ deps =
-r{toxinidir}/requirements-dev.txt -r{toxinidir}/requirements-dev.txt
[testenv] [testenv]
commands=py.test --tb native {posargs:tests} commands=python run_tests {posargs:tests}
[testenv:py27-django17] [testenv:py27-django17]
basepython=python2.7 basepython=python2.7

22
urls_tests.py Normal file
View File

@ -0,0 +1,22 @@
"""cas URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/1.9/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.conf.urls import url, include, include
2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
"""
from django.conf.urls import url, include
from django.contrib import admin
urlpatterns = [
url(r'^admin/', admin.site.urls),
url(r'^', include('cas_server.urls', namespace='cas_server')),
]