Merge pull request #3 from nitmir/tests_django
Convert tests to django builting tests, add coverage, add codacy
This commit is contained in:
commit
aff76e70ec
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,5 +1,7 @@
|
||||
*.pyc
|
||||
*.egg-info
|
||||
*~
|
||||
*.swp
|
||||
|
||||
build/
|
||||
bootstrap3
|
||||
@ -7,6 +9,9 @@ cas/
|
||||
dist/
|
||||
db.sqlite3
|
||||
manage.py
|
||||
coverage.xml
|
||||
|
||||
.tox
|
||||
test_venv
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
10
Makefile
10
Makefile
@ -44,3 +44,13 @@ test_project: test_venv test_venv/cas/manage.py
|
||||
|
||||
run_test_server: test_project
|
||||
test_venv/bin/python test_venv/cas/manage.py runserver
|
||||
|
||||
coverage: test_venv
|
||||
test_venv/bin/pip install coverage
|
||||
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
|
||||
test_venv/bin/coverage html
|
||||
test_venv/bin/coverage xml
|
||||
|
||||
coverage_codacy: coverage
|
||||
test_venv/bin/pip install codacy-coverage
|
||||
test_venv/bin/python-codacy-coverage -r coverage.xml
|
||||
|
12
README.rst
12
README.rst
@ -10,8 +10,14 @@ CAS Server
|
||||
.. image:: https://img.shields.io/pypi/l/django-cas-server.svg
|
||||
:target: https://www.gnu.org/licenses/gpl-3.0.html
|
||||
|
||||
.. image:: https://api.codacy.com/project/badge/Grade/255c21623d6946ef8802fa7995b61366
|
||||
:target: https://www.codacy.com/app/valentin-samir/django-cas-server
|
||||
|
||||
.. image:: https://api.codacy.com/project/badge/Coverage/255c21623d6946ef8802fa7995b61366
|
||||
:target: https://www.codacy.com/app/valentin-samir/django-cas-server
|
||||
|
||||
CAS Server is a Django application implementing the `CAS Protocol 3.0 Specification
|
||||
<https://jasig.github.io/cas/development/protocol/CAS-Protocol-Specification.html>`_.
|
||||
<https://apereo.github.io/cas/4.2.x/protocol/CAS-Protocol-Specification.html>`_.
|
||||
|
||||
By defaut, the authentication process use django internal users but you can easily
|
||||
use any sources (see auth classes in the auth.py file)
|
||||
@ -70,7 +76,7 @@ Quick start
|
||||
4. You should add some management commands to a crontab: ``clearsessions``,
|
||||
``cas_clean_tickets`` and ``cas_clean_sessions``.
|
||||
|
||||
* ``clearsessions``: please see `Clearing the session store <https://docs.djangoproject.com/en/1.9/topics/http/sessions/#clearing-the-session-store>`_.
|
||||
* ``clearsessions``: please see `Clearing the session store <https://docs.djangoproject.com/en/stable/topics/http/sessions/#clearing-the-session-store>`_.
|
||||
* ``cas_clean_tickets``: old tickets and timed-out tickets do not get purge from
|
||||
the database automatically. They are just marked as invalid. ``cas_clean_tickets``
|
||||
is a clean-up management command for this purpose. It send SingleLogOut request
|
||||
@ -204,7 +210,7 @@ Logs
|
||||
----
|
||||
|
||||
``django-cas-server`` logs most of its actions. To enable login, you must set the ``LOGGING``
|
||||
(https://docs.djangoproject.com/en/dev/topics/logging) variable is ``settings.py``.
|
||||
(https://docs.djangoproject.com/en/stable/topics/logging) variable is ``settings.py``.
|
||||
|
||||
Users successful actions (login, logout) are logged with the level ``INFO``, failures are logged
|
||||
with the level ``WARNING`` and user attributes transmitted to a service are logged with the level ``DEBUG``.
|
||||
|
@ -9,4 +9,4 @@
|
||||
#
|
||||
# (c) 2015 Valentin Samir
|
||||
|
||||
default_app_config = 'cas_server.apps.AppConfig'
|
||||
default_app_config = 'cas_server.apps.CasAppConfig'
|
||||
|
@ -2,6 +2,6 @@ from django.utils.translation import ugettext_lazy as _
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class AppConfig(AppConfig):
|
||||
class CasAppConfig(AppConfig):
|
||||
name = 'cas_server'
|
||||
verbose_name = _('Central Authentication Service')
|
||||
|
@ -26,11 +26,11 @@ class AuthUser(object):
|
||||
|
||||
def test_password(self, password):
|
||||
"""test `password` agains the user"""
|
||||
raise NotImplemented()
|
||||
raise NotImplementedError()
|
||||
|
||||
def attributs(self):
|
||||
"""return a dict of user attributes"""
|
||||
raise NotImplemented()
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DummyAuthUser(AuthUser):
|
||||
@ -57,11 +57,11 @@ class TestAuthUser(AuthUser):
|
||||
|
||||
def test_password(self, password):
|
||||
"""test `password` agains the user"""
|
||||
return self.username == "test" and password == "test"
|
||||
return self.username == settings.CAS_TEST_USER and password == settings.CAS_TEST_PASSWORD
|
||||
|
||||
def attributs(self):
|
||||
"""return a dict of user attributes"""
|
||||
return {'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}
|
||||
return settings.CAS_TEST_ATTRIBUTES
|
||||
|
||||
|
||||
class MysqlAuthUser(AuthUser):
|
||||
|
@ -73,3 +73,10 @@ setting_default('CAS_SQL_DBCHARSET', 'utf8')
|
||||
setting_default('CAS_SQL_USER_QUERY', 'SELECT user AS usersame, pass AS '
|
||||
'password, users.* FROM users WHERE user = %s')
|
||||
setting_default('CAS_SQL_PASSWORD_CHECK', 'crypt') # crypt or plain
|
||||
|
||||
setting_default('CAS_TEST_USER', 'test')
|
||||
setting_default('CAS_TEST_PASSWORD', 'test')
|
||||
setting_default(
|
||||
'CAS_TEST_ATTRIBUTES',
|
||||
{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}
|
||||
)
|
||||
|
@ -1,55 +1,52 @@
|
||||
function cas_login(cas_server_login, service, login_service, callback){
|
||||
url = cas_server_login + '?service=' + encodeURIComponent(service);
|
||||
var url = cas_server_login + "?service=" + encodeURIComponent(service);
|
||||
$.ajax({
|
||||
type: 'GET',
|
||||
url:url,
|
||||
beforeSend: function (request) {
|
||||
type: "GET",
|
||||
url,
|
||||
beforeSend(request) {
|
||||
request.setRequestHeader("X-AJAX", "1");
|
||||
},
|
||||
xhrFields: {
|
||||
withCredentials: true
|
||||
},
|
||||
success: function(data, textStatus, request){
|
||||
if(data.status == 'success'){
|
||||
success(data, textStatus, request){
|
||||
if(data.status === "success"){
|
||||
$.ajax({
|
||||
type: 'GET',
|
||||
type: "GET",
|
||||
url: data.url,
|
||||
xhrFields: {
|
||||
withCredentials: true
|
||||
},
|
||||
success: callback,
|
||||
error: function (request, textStatus, errorThrown) {},
|
||||
error(request, textStatus, errorThrown) {},
|
||||
});
|
||||
} else {
|
||||
if(data.detail == "login required"){
|
||||
window.location.href = cas_server_login + '?service=' + encodeURIComponent(login_service);
|
||||
if(data.detail === "login required"){
|
||||
window.location.href = cas_server_login + "?service=" + encodeURIComponent(login_service);
|
||||
} else {
|
||||
alert('error: ' + data.messages[1].message);
|
||||
alert("error: " + data.messages[1].message);
|
||||
}
|
||||
}
|
||||
},
|
||||
error: function (request, textStatus, errorThrown) {},
|
||||
error(request, textStatus, errorThrown) {},
|
||||
});
|
||||
}
|
||||
|
||||
function cas_logout(cas_server_logout){
|
||||
$.ajax({
|
||||
type: 'GET',
|
||||
type: "GET",
|
||||
url: cas_server_logout,
|
||||
beforeSend: function (request) {
|
||||
beforeSend(request) {
|
||||
request.setRequestHeader("X-AJAX", "1");
|
||||
},
|
||||
xhrFields: {
|
||||
withCredentials: true
|
||||
},
|
||||
error: function (request, textStatus, errorThrown) {},
|
||||
success: function(data, textStatus, request){
|
||||
if(data.status == 'error'){
|
||||
alert('error: ' + data.messages[1].message);
|
||||
error(request, textStatus, errorThrown) {},
|
||||
success(data, textStatus, request){
|
||||
if(data.status === "error"){
|
||||
alert("error: " + data.messages[1].message);
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
@ -43,14 +43,14 @@ body {
|
||||
|
||||
@media screen and (max-width: 680px) {
|
||||
#app-name {
|
||||
margin: 0px;
|
||||
margin: 0;
|
||||
}
|
||||
#app-name img {
|
||||
display: block;
|
||||
margin: auto;
|
||||
}
|
||||
body {
|
||||
padding-top: 0px;
|
||||
padding-bottom: 0px;
|
||||
padding-top: 0;
|
||||
padding-bottom: 0;
|
||||
}
|
||||
}
|
||||
|
639
cas_server/tests.py
Normal file
639
cas_server/tests.py
Normal file
@ -0,0 +1,639 @@
|
||||
from .default_settings import settings
|
||||
|
||||
from django.test import TestCase
|
||||
from django.test import Client
|
||||
|
||||
from lxml import etree
|
||||
|
||||
from cas_server import models
|
||||
from cas_server import utils
|
||||
|
||||
|
||||
def get_login_page_params():
|
||||
client = Client()
|
||||
response = client.get('/login')
|
||||
form = response.context["form"]
|
||||
params = {}
|
||||
for field in form:
|
||||
if field.value():
|
||||
params[field.name] = field.value()
|
||||
else:
|
||||
params[field.name] = ""
|
||||
return client, params
|
||||
|
||||
|
||||
def get_auth_client():
|
||||
client, params = get_login_page_params()
|
||||
params["username"] = settings.CAS_TEST_USER
|
||||
params["password"] = settings.CAS_TEST_PASSWORD
|
||||
|
||||
client.post('/login', params)
|
||||
return client
|
||||
|
||||
|
||||
def get_user_ticket_request(service):
|
||||
client = get_auth_client()
|
||||
response = client.get("/login", {"service": service})
|
||||
ticket_value = response['Location'].split('ticket=')[-1]
|
||||
user = models.User.objects.get(
|
||||
username=settings.CAS_TEST_USER,
|
||||
session_key=client.session.session_key
|
||||
)
|
||||
ticket = models.ServiceTicket.objects.get(value=ticket_value)
|
||||
return (user, ticket)
|
||||
|
||||
|
||||
def get_pgt():
|
||||
(host, port) = utils.PGTUrlHandler.run()[1:3]
|
||||
service = "http://%s:%s" % (host, port)
|
||||
|
||||
(user, ticket) = get_user_ticket_request(service)
|
||||
|
||||
client = Client()
|
||||
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
|
||||
params = utils.PGTUrlHandler.PARAMS.copy()
|
||||
|
||||
params["service"] = service
|
||||
params["user"] = user
|
||||
|
||||
return params
|
||||
|
||||
|
||||
class LoginTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
||||
self.service_pattern = models.ServicePattern.objects.create(
|
||||
name="example",
|
||||
pattern="^https://www\.example\.com(/.*)?$",
|
||||
)
|
||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||
|
||||
def test_login_view_post_goodpass_goodlt(self):
|
||||
client, params = get_login_page_params()
|
||||
params["username"] = settings.CAS_TEST_USER
|
||||
params["password"] = settings.CAS_TEST_PASSWORD
|
||||
|
||||
response = client.post('/login', params)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(
|
||||
(
|
||||
b"You have successfully logged into "
|
||||
b"the Central Authentication Service"
|
||||
) in response.content
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
models.User.objects.get(
|
||||
username=settings.CAS_TEST_USER,
|
||||
session_key=client.session.session_key
|
||||
)
|
||||
)
|
||||
|
||||
def test_login_view_post_badlt(self):
|
||||
client, params = get_login_page_params()
|
||||
params["username"] = settings.CAS_TEST_USER
|
||||
params["password"] = settings.CAS_TEST_PASSWORD
|
||||
params["lt"] = 'LT-random'
|
||||
|
||||
response = client.post('/login', params)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(b"Invalid login ticket" in response.content)
|
||||
self.assertFalse(
|
||||
(
|
||||
b"You have successfully logged into "
|
||||
b"the Central Authentication Service"
|
||||
) in response.content
|
||||
)
|
||||
|
||||
def test_login_view_post_badpass_good_lt(self):
|
||||
client, params = get_login_page_params()
|
||||
params["username"] = settings.CAS_TEST_USER
|
||||
params["password"] = "test2"
|
||||
response = client.post('/login', params)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(
|
||||
(
|
||||
b"The credentials you provided cannot be "
|
||||
b"determined to be authentic"
|
||||
) in response.content
|
||||
)
|
||||
self.assertFalse(
|
||||
(
|
||||
b"You have successfully logged into "
|
||||
b"the Central Authentication Service"
|
||||
) in response.content
|
||||
)
|
||||
|
||||
def test_view_login_get_auth_allowed_service(self):
|
||||
client = get_auth_client()
|
||||
response = client.get("/login?service=https://www.example.com")
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertTrue(response.has_header('Location'))
|
||||
self.assertTrue(
|
||||
response['Location'].startswith(
|
||||
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
|
||||
)
|
||||
)
|
||||
|
||||
ticket_value = response['Location'].split('ticket=')[-1]
|
||||
user = models.User.objects.get(
|
||||
username=settings.CAS_TEST_USER,
|
||||
session_key=client.session.session_key
|
||||
)
|
||||
self.assertTrue(user)
|
||||
ticket = models.ServiceTicket.objects.get(value=ticket_value)
|
||||
self.assertEqual(ticket.user, user)
|
||||
self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
|
||||
self.assertEqual(ticket.validate, False)
|
||||
self.assertEqual(ticket.service_pattern, self.service_pattern)
|
||||
|
||||
def test_view_login_get_auth_denied_service(self):
|
||||
client = get_auth_client()
|
||||
response = client.get("/login?service=https://www.example.org")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
|
||||
|
||||
|
||||
class LogoutTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
||||
|
||||
def test_logout_view(self):
|
||||
client = get_auth_client()
|
||||
|
||||
response = client.get("/login")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(
|
||||
(
|
||||
b"You have successfully logged into "
|
||||
b"the Central Authentication Service"
|
||||
) in response.content
|
||||
)
|
||||
|
||||
response = client.get("/logout")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertTrue(
|
||||
(
|
||||
b"You have successfully logged out from "
|
||||
b"the Central Authentication Service"
|
||||
) in response.content
|
||||
)
|
||||
|
||||
response = client.get("/login")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertFalse(
|
||||
(
|
||||
b"You have successfully logged into "
|
||||
b"the Central Authentication Service"
|
||||
) in response.content
|
||||
)
|
||||
|
||||
def test_logout_view_url(self):
|
||||
client = get_auth_client()
|
||||
|
||||
response = client.get('/logout?url=https://www.example.com')
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertTrue(response.has_header("Location"))
|
||||
self.assertEqual(response["Location"], "https://www.example.com")
|
||||
|
||||
response = client.get("/login")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertFalse(
|
||||
(
|
||||
b"You have successfully logged into "
|
||||
b"the Central Authentication Service"
|
||||
) in response.content
|
||||
)
|
||||
|
||||
def test_logout_view_service(self):
|
||||
client = get_auth_client()
|
||||
|
||||
response = client.get('/logout?service=https://www.example.com')
|
||||
self.assertEqual(response.status_code, 302)
|
||||
self.assertTrue(response.has_header("Location"))
|
||||
self.assertEqual(response["Location"], "https://www.example.com")
|
||||
|
||||
response = client.get("/login")
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertFalse(
|
||||
(
|
||||
b"You have successfully logged into "
|
||||
b"the Central Authentication Service"
|
||||
) in response.content
|
||||
)
|
||||
|
||||
|
||||
class AuthTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
||||
self.service = 'https://www.example.com'
|
||||
models.ServicePattern.objects.create(
|
||||
name="example",
|
||||
pattern="^https://www\.example\.com(/.*)?$"
|
||||
)
|
||||
|
||||
def test_auth_view_goodpass(self):
|
||||
settings.CAS_AUTH_SHARED_SECRET = 'test'
|
||||
client = Client()
|
||||
response = client.post(
|
||||
'/auth',
|
||||
{
|
||||
'username': settings.CAS_TEST_USER,
|
||||
'password': settings.CAS_TEST_PASSWORD,
|
||||
'service': self.service,
|
||||
'secret': 'test'
|
||||
}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b'yes\n')
|
||||
|
||||
def test_auth_view_badpass(self):
|
||||
settings.CAS_AUTH_SHARED_SECRET = 'test'
|
||||
client = Client()
|
||||
response = client.post(
|
||||
'/auth',
|
||||
{
|
||||
'username': settings.CAS_TEST_USER,
|
||||
'password': 'badpass',
|
||||
'service': self.service,
|
||||
'secret': 'test'
|
||||
}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b'no\n')
|
||||
|
||||
def test_auth_view_badservice(self):
|
||||
settings.CAS_AUTH_SHARED_SECRET = 'test'
|
||||
client = Client()
|
||||
response = client.post(
|
||||
'/auth',
|
||||
{
|
||||
'username': settings.CAS_TEST_USER,
|
||||
'password': settings.CAS_TEST_PASSWORD,
|
||||
'service': 'https://www.example.org',
|
||||
'secret': 'test'
|
||||
}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b'no\n')
|
||||
|
||||
def test_auth_view_badsecret(self):
|
||||
settings.CAS_AUTH_SHARED_SECRET = 'test'
|
||||
client = Client()
|
||||
response = client.post(
|
||||
'/auth',
|
||||
{
|
||||
'username': settings.CAS_TEST_USER,
|
||||
'password': settings.CAS_TEST_PASSWORD,
|
||||
'service': self.service,
|
||||
'secret': 'badsecret'
|
||||
}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b'no\n')
|
||||
|
||||
def test_auth_view_badsettings(self):
|
||||
settings.CAS_AUTH_SHARED_SECRET = None
|
||||
client = Client()
|
||||
response = client.post(
|
||||
'/auth',
|
||||
{
|
||||
'username': settings.CAS_TEST_USER,
|
||||
'password': settings.CAS_TEST_PASSWORD,
|
||||
'service': self.service,
|
||||
'secret': 'test'
|
||||
}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
|
||||
|
||||
|
||||
class ValidateTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
||||
self.service = 'https://www.example.com'
|
||||
self.service_pattern = models.ServicePattern.objects.create(
|
||||
name="example",
|
||||
pattern="^https://www\.example\.com(/.*)?$"
|
||||
)
|
||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||
|
||||
def test_validate_view_ok(self):
|
||||
ticket = get_user_ticket_request(self.service)[1]
|
||||
|
||||
client = Client()
|
||||
response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b'yes\ntest\n')
|
||||
|
||||
def test_validate_view_badservice(self):
|
||||
ticket = get_user_ticket_request(self.service)[1]
|
||||
|
||||
client = Client()
|
||||
response = client.get(
|
||||
'/validate',
|
||||
{'ticket': ticket.value, 'service': "https://www.example.org"}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b'no\n')
|
||||
|
||||
def test_validate_view_badticket(self):
|
||||
get_user_ticket_request(self.service)
|
||||
|
||||
client = Client()
|
||||
response = client.get(
|
||||
'/validate',
|
||||
{'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response.content, b'no\n')
|
||||
|
||||
|
||||
class ValidateServiceTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
||||
self.service = 'http://127.0.0.1:45678'
|
||||
self.service_pattern = models.ServicePattern.objects.create(
|
||||
name="localhost",
|
||||
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
||||
proxy_callback=True
|
||||
)
|
||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||
|
||||
def test_validate_service_view_ok(self):
|
||||
ticket = get_user_ticket_request(self.service)[1]
|
||||
|
||||
client = Client()
|
||||
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
sucess = root.xpath(
|
||||
"//cas:authenticationSuccess",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertTrue(sucess)
|
||||
|
||||
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
self.assertEqual(len(users), 1)
|
||||
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
|
||||
|
||||
attributes = root.xpath(
|
||||
"//cas:attributes",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(attributes), 1)
|
||||
attrs1 = {}
|
||||
for attr in attributes[0]:
|
||||
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
|
||||
|
||||
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
self.assertEqual(len(attributes), len(attrs1))
|
||||
attrs2 = {}
|
||||
for attr in attributes:
|
||||
attrs2[attr.attrib['name']] = attr.attrib['value']
|
||||
self.assertEqual(attrs1, attrs2)
|
||||
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
|
||||
|
||||
def test_validate_service_view_badservice(self):
|
||||
ticket = get_user_ticket_request(self.service)[1]
|
||||
|
||||
client = Client()
|
||||
bad_service = "https://www.example.org"
|
||||
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
error = root.xpath(
|
||||
"//cas:authenticationFailure",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(error), 1)
|
||||
self.assertEqual(error[0].attrib['code'], "INVALID_SERVICE")
|
||||
self.assertEqual(error[0].text, bad_service)
|
||||
|
||||
def test_validate_service_view_badticket_goodprefix(self):
|
||||
get_user_ticket_request(self.service)
|
||||
|
||||
client = Client()
|
||||
bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
|
||||
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
error = root.xpath(
|
||||
"//cas:authenticationFailure",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(error), 1)
|
||||
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
|
||||
self.assertEqual(error[0].text, 'ticket not found')
|
||||
|
||||
def test_validate_service_view_badticket_badprefix(self):
|
||||
get_user_ticket_request(self.service)
|
||||
|
||||
client = Client()
|
||||
bad_ticket = "RANDOM"
|
||||
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
error = root.xpath(
|
||||
"//cas:authenticationFailure",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(error), 1)
|
||||
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
|
||||
self.assertEqual(error[0].text, bad_ticket)
|
||||
|
||||
def test_validate_service_view_ok_pgturl(self):
|
||||
(host, port) = utils.PGTUrlHandler.run()[1:3]
|
||||
service = "http://%s:%s" % (host, port)
|
||||
|
||||
ticket = get_user_ticket_request(service)[1]
|
||||
|
||||
client = Client()
|
||||
response = client.get(
|
||||
'/serviceValidate',
|
||||
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
|
||||
)
|
||||
pgt_params = utils.PGTUrlHandler.PARAMS.copy()
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
pgtiou = root.xpath(
|
||||
"//cas:proxyGrantingTicket",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(pgtiou), 1)
|
||||
self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
|
||||
self.assertTrue("pgtId" in pgt_params)
|
||||
|
||||
def test_validate_service_pgturl_bad_proxy_callback(self):
|
||||
self.service_pattern.proxy_callback = False
|
||||
self.service_pattern.save()
|
||||
ticket = get_user_ticket_request(self.service)[1]
|
||||
|
||||
client = Client()
|
||||
response = client.get(
|
||||
'/serviceValidate',
|
||||
{'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
error = root.xpath(
|
||||
"//cas:authenticationFailure",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(error), 1)
|
||||
self.assertEqual(error[0].attrib['code'], "INVALID_PROXY_CALLBACK")
|
||||
self.assertEqual(error[0].text, "callback url not allowed by configuration")
|
||||
|
||||
|
||||
class ProxyTestCase(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
||||
self.service = 'http://127.0.0.1'
|
||||
self.service_pattern = models.ServicePattern.objects.create(
|
||||
name="localhost",
|
||||
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
|
||||
proxy=True,
|
||||
proxy_callback=True
|
||||
)
|
||||
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
||||
|
||||
def test_validate_proxy_ok(self):
|
||||
params = get_pgt()
|
||||
|
||||
# get a proxy ticket
|
||||
client1 = Client()
|
||||
response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
self.assertTrue(sucess)
|
||||
|
||||
proxy_ticket = root.xpath(
|
||||
"//cas:proxyTicket",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(proxy_ticket), 1)
|
||||
proxy_ticket = proxy_ticket[0].text
|
||||
|
||||
# validate the proxy ticket
|
||||
client2 = Client()
|
||||
response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
sucess = root.xpath(
|
||||
"//cas:authenticationSuccess",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertTrue(sucess)
|
||||
|
||||
# check that the proxy is send to the end service
|
||||
proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
self.assertEqual(len(proxies), 1)
|
||||
proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
self.assertEqual(len(proxy), 1)
|
||||
self.assertEqual(proxy[0].text, params["service"])
|
||||
|
||||
# same tests than those for serviceValidate
|
||||
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
self.assertEqual(len(users), 1)
|
||||
self.assertEqual(users[0].text, settings.CAS_TEST_USER)
|
||||
|
||||
attributes = root.xpath(
|
||||
"//cas:attributes",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(attributes), 1)
|
||||
attrs1 = {}
|
||||
for attr in attributes[0]:
|
||||
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text
|
||||
|
||||
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
self.assertEqual(len(attributes), len(attrs1))
|
||||
attrs2 = {}
|
||||
for attr in attributes:
|
||||
attrs2[attr.attrib['name']] = attr.attrib['value']
|
||||
self.assertEqual(attrs1, attrs2)
|
||||
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES)
|
||||
|
||||
def test_validate_proxy_bad(self):
|
||||
params = get_pgt()
|
||||
|
||||
# bad PGT
|
||||
client1 = Client()
|
||||
response = client1.get(
|
||||
'/proxy',
|
||||
{
|
||||
'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
|
||||
'targetService': params['service']
|
||||
}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
error = root.xpath(
|
||||
"//cas:authenticationFailure",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(error), 1)
|
||||
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
|
||||
self.assertEqual(
|
||||
error[0].text,
|
||||
"PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
|
||||
)
|
||||
|
||||
# bad targetService
|
||||
client2 = Client()
|
||||
response = client2.get(
|
||||
'/proxy',
|
||||
{'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
error = root.xpath(
|
||||
"//cas:authenticationFailure",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(error), 1)
|
||||
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
|
||||
self.assertEqual(error[0].text, "https://www.example.org")
|
||||
|
||||
# service do not allow proxy ticket
|
||||
self.service_pattern.proxy = False
|
||||
self.service_pattern.save()
|
||||
|
||||
client3 = Client()
|
||||
response = client3.get(
|
||||
'/proxy',
|
||||
{'pgt': params['pgtId'], 'targetService': params['service']}
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
error = root.xpath(
|
||||
"//cas:authenticationFailure",
|
||||
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
||||
)
|
||||
self.assertEqual(len(error), 1)
|
||||
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
|
||||
self.assertEqual(
|
||||
error[0].text,
|
||||
'the service %s do not allow proxy ticket' % params['service']
|
||||
)
|
@ -14,7 +14,7 @@ from django.conf.urls import patterns, url
|
||||
from django.views.generic import RedirectView
|
||||
from django.views.decorators.debug import sensitive_post_parameters, sensitive_variables
|
||||
|
||||
import views
|
||||
from cas_server import views
|
||||
|
||||
urlpatterns = patterns(
|
||||
'',
|
||||
|
@ -19,13 +19,10 @@ from django.contrib import messages
|
||||
import random
|
||||
import string
|
||||
import json
|
||||
from threading import Thread
|
||||
from importlib import import_module
|
||||
|
||||
try:
|
||||
from urlparse import urlparse, urlunparse, parse_qsl
|
||||
from urllib import urlencode
|
||||
except ImportError:
|
||||
from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
|
||||
from six.moves import BaseHTTPServer
|
||||
from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
|
||||
|
||||
|
||||
def context(params):
|
||||
@ -83,9 +80,9 @@ def update_url(url, params):
|
||||
query = dict(parse_qsl(url_parts[4]))
|
||||
query.update(params)
|
||||
url_parts[4] = urlencode(query)
|
||||
for i in range(len(url_parts)):
|
||||
if not isinstance(url_parts[i], bytes):
|
||||
url_parts[i] = url_parts[i].encode('utf-8')
|
||||
for i, url_part in enumerate(url_parts):
|
||||
if not isinstance(url_part, bytes):
|
||||
url_parts[i] = url_part.encode('utf-8')
|
||||
return urlunparse(url_parts).decode('utf-8')
|
||||
|
||||
|
||||
@ -144,3 +141,34 @@ def gen_pgtiou():
|
||||
def gen_saml_id():
|
||||
"""Generate an saml id"""
|
||||
return _gen_ticket('_')
|
||||
|
||||
|
||||
class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
|
||||
PARAMS = {}
|
||||
|
||||
def do_GET(s):
|
||||
s.send_response(200)
|
||||
s.send_header(b"Content-type", "text/plain")
|
||||
s.end_headers()
|
||||
s.wfile.write(b"ok")
|
||||
url = urlparse(s.path)
|
||||
params = dict(parse_qsl(url.query))
|
||||
PGTUrlHandler.PARAMS.update(params)
|
||||
|
||||
def log_message(self, template, *args):
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def run():
|
||||
server_class = BaseHTTPServer.HTTPServer
|
||||
httpd = server_class(("127.0.0.1", 0), PGTUrlHandler)
|
||||
(host, port) = httpd.socket.getsockname()
|
||||
|
||||
def lauch():
|
||||
httpd.handle_request()
|
||||
httpd.server_close()
|
||||
|
||||
httpd_thread = Thread(target=lauch)
|
||||
httpd_thread.daemon = True
|
||||
httpd_thread.start()
|
||||
return (httpd_thread, host, port)
|
||||
|
@ -23,6 +23,7 @@ from django.views.decorators.csrf import csrf_exempt
|
||||
|
||||
from django.views.generic import View
|
||||
|
||||
import re
|
||||
import logging
|
||||
import pprint
|
||||
import requests
|
||||
@ -62,12 +63,12 @@ class AttributesMixin(object):
|
||||
|
||||
class LogoutMixin(object):
|
||||
"""destroy CAS session utils"""
|
||||
def logout(self, all=False):
|
||||
def logout(self, all_session=False):
|
||||
"""effectively destroy CAS session"""
|
||||
session_nb = 0
|
||||
username = self.request.session.get("username")
|
||||
if username:
|
||||
if all:
|
||||
if all_session:
|
||||
logger.info("Logging out user %s from all of they sessions." % username)
|
||||
else:
|
||||
logger.info("Logging out user %s." % username)
|
||||
@ -85,8 +86,8 @@ class LogoutMixin(object):
|
||||
# if user not found in database, flush the session anyway
|
||||
self.request.session.flush()
|
||||
|
||||
# If all is set logout user from alternative sessions
|
||||
if all:
|
||||
# If all_session is set logout user from alternative sessions
|
||||
if all_session:
|
||||
for user in models.User.objects.filter(username=username):
|
||||
session = SessionStore(session_key=user.session_key)
|
||||
session.flush()
|
||||
@ -197,10 +198,7 @@ class LoginView(View, LogoutMixin):
|
||||
def init_post(self, request):
|
||||
self.request = request
|
||||
self.service = request.POST.get('service')
|
||||
if request.POST.get('renew') and request.POST['renew'] != "False":
|
||||
self.renew = True
|
||||
else:
|
||||
self.renew = False
|
||||
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
|
||||
self.gateway = request.POST.get('gateway')
|
||||
self.method = request.POST.get('method')
|
||||
self.ajax = 'HTTP_X_AJAX' in request.META
|
||||
@ -284,10 +282,7 @@ class LoginView(View, LogoutMixin):
|
||||
def init_get(self, request):
|
||||
self.request = request
|
||||
self.service = request.GET.get('service')
|
||||
if request.GET.get('renew') and request.GET['renew'] != "False":
|
||||
self.renew = True
|
||||
else:
|
||||
self.renew = False
|
||||
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
|
||||
self.gateway = request.GET.get('gateway')
|
||||
self.method = request.GET.get('method')
|
||||
self.ajax = 'HTTP_X_AJAX' in request.META
|
||||
@ -666,7 +661,10 @@ class ValidateService(View, AttributesMixin):
|
||||
params['username'] = self.ticket.user.attributs.get(
|
||||
self.ticket.service_pattern.user_field
|
||||
)
|
||||
if self.pgt_url and self.pgt_url.startswith("https://"):
|
||||
if self.pgt_url and (
|
||||
self.pgt_url.startswith("https://") or
|
||||
re.match("^http://(127\.0\.0\.1|localhost)(:[0-9]+)?(/.*)?$", self.pgt_url)
|
||||
):
|
||||
return self.process_pgturl(params)
|
||||
else:
|
||||
logger.info(
|
||||
|
@ -7,3 +7,4 @@ django-picklefield>=0.3.1
|
||||
requests_futures>=0.9.5
|
||||
django-bootstrap3>=5.4
|
||||
lxml>=3.4
|
||||
six>=1
|
||||
|
@ -5,4 +5,4 @@ requests_futures>=0.9.5
|
||||
django-picklefield>=0.3.1
|
||||
django-bootstrap3>=5.4
|
||||
lxml>=3.4
|
||||
|
||||
six>=1
|
||||
|
22
run_tests
Executable file
22
run_tests
Executable file
@ -0,0 +1,22 @@
|
||||
#!/usr/bin/env python
|
||||
import os, sys
|
||||
import django
|
||||
from django.conf import settings
|
||||
|
||||
import settings_tests
|
||||
|
||||
settings.configure(**settings_tests.__dict__)
|
||||
django.setup()
|
||||
|
||||
try:
|
||||
# Django <= 1.8
|
||||
from django.test.simple import DjangoTestSuiteRunner
|
||||
test_runner = DjangoTestSuiteRunner(verbosity=1)
|
||||
except ImportError:
|
||||
# Django >= 1.8
|
||||
from django.test.runner import DiscoverRunner
|
||||
test_runner = DiscoverRunner(verbosity=1)
|
||||
|
||||
failures = test_runner.run_tests(['cas_server'])
|
||||
if failures:
|
||||
sys.exit(failures)
|
83
settings_tests.py
Normal file
83
settings_tests.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""
|
||||
Django test settings for cas_server application.
|
||||
|
||||
Generated by 'django-admin startproject' using Django 1.9.7.
|
||||
|
||||
For more information on this file, see
|
||||
https://docs.djangoproject.com/en/1.9/topics/settings/
|
||||
|
||||
For the full list of settings and their values, see
|
||||
https://docs.djangoproject.com/en/1.9/ref/settings/
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Build paths inside the project like this: os.path.join(BASE_DIR, ...)
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
# Quick-start development settings - unsuitable for production
|
||||
# See https://docs.djangoproject.com/en/1.9/howto/deployment/checklist/
|
||||
|
||||
# SECURITY WARNING: keep the secret key used in production secret!
|
||||
SECRET_KEY = 'changeme'
|
||||
|
||||
# SECURITY WARNING: don't run with debug turned on in production!
|
||||
DEBUG = True
|
||||
|
||||
ALLOWED_HOSTS = []
|
||||
|
||||
|
||||
# Application definition
|
||||
|
||||
INSTALLED_APPS = [
|
||||
'django.contrib.admin',
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
'django.contrib.sessions',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.staticfiles',
|
||||
'bootstrap3',
|
||||
'cas_server',
|
||||
]
|
||||
|
||||
MIDDLEWARE_CLASSES = [
|
||||
'django.contrib.sessions.middleware.SessionMiddleware',
|
||||
'django.middleware.common.CommonMiddleware',
|
||||
'django.middleware.csrf.CsrfViewMiddleware',
|
||||
'django.contrib.auth.middleware.AuthenticationMiddleware',
|
||||
'django.contrib.auth.middleware.SessionAuthenticationMiddleware',
|
||||
'django.contrib.messages.middleware.MessageMiddleware',
|
||||
'django.middleware.clickjacking.XFrameOptionsMiddleware',
|
||||
'django.middleware.locale.LocaleMiddleware',
|
||||
]
|
||||
|
||||
ROOT_URLCONF = 'cas_server.urls'
|
||||
|
||||
# Database
|
||||
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases
|
||||
|
||||
DATABASES = {
|
||||
'default': {
|
||||
'ENGINE': 'django.db.backends.sqlite3',
|
||||
}
|
||||
}
|
||||
|
||||
# Internationalization
|
||||
# https://docs.djangoproject.com/en/1.9/topics/i18n/
|
||||
|
||||
LANGUAGE_CODE = 'en-us'
|
||||
|
||||
TIME_ZONE = 'UTC'
|
||||
|
||||
USE_I18N = True
|
||||
|
||||
USE_L10N = True
|
||||
|
||||
USE_TZ = True
|
||||
|
||||
|
||||
# Static files (CSS, JavaScript, Images)
|
||||
# https://docs.djangoproject.com/en/1.9/howto/static-files/
|
||||
|
||||
STATIC_URL = '/static/'
|
136
tests/dummy.py
136
tests/dummy.py
@ -1,136 +0,0 @@
|
||||
import functools
|
||||
from cas_server import models
|
||||
|
||||
class DummyUserManager(object):
|
||||
def __init__(self, username, session_key):
|
||||
self.username = username
|
||||
self.session_key = session_key
|
||||
def get(self, username=None, session_key=None):
|
||||
if username == self.username and session_key == self.session_key:
|
||||
return models.User(username=username, session_key=session_key)
|
||||
else:
|
||||
raise models.User.DoesNotExist()
|
||||
|
||||
|
||||
def dummy(*args, **kwds):
|
||||
pass
|
||||
|
||||
def dummy_service_pattern(**kwargs):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwds):
|
||||
service_validate = models.ServicePattern.validate
|
||||
models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs))
|
||||
ret = func(*args, **kwds)
|
||||
models.ServicePattern.validate = service_validate
|
||||
return ret
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
def dummy_user(username, session_key):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwds):
|
||||
user_manager = models.User.objects
|
||||
user_save = models.User.save
|
||||
user_delete = models.User.delete
|
||||
models.User.objects = DummyUserManager(username, session_key)
|
||||
models.User.save = dummy
|
||||
models.User.delete = dummy
|
||||
ret = func(*args, **kwds)
|
||||
models.User.objects = user_manager
|
||||
models.User.save = user_save
|
||||
models.User.delete = user_delete
|
||||
return ret
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
def dummy_ticket(ticket_class, service, ticket):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwds):
|
||||
ticket_manager = ticket_class.objects
|
||||
ticket_save = ticket_class.save
|
||||
ticket_delete = ticket_class.delete
|
||||
ticket_class.objects = DummyTicketManager(ticket_class, service, ticket)
|
||||
ticket_class.save = dummy
|
||||
ticket_class.delete = dummy
|
||||
ret = func(*args, **kwds)
|
||||
ticket_class.objects = ticket_manager
|
||||
ticket_class.save = ticket_save
|
||||
ticket_class.delete = ticket_delete
|
||||
return ret
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def dummy_proxy(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwds):
|
||||
proxy_manager = models.Proxy.objects
|
||||
models.Proxy.objects = DummyProxyManager()
|
||||
ret = func(*args, **kwds)
|
||||
models.Proxy.objects = proxy_manager
|
||||
return ret
|
||||
return wrapper
|
||||
|
||||
class DummyProxyManager(object):
|
||||
def create(self, **kwargs):
|
||||
for field in models.Proxy._meta.fields:
|
||||
field.allow_unsaved_instance_assignment = True
|
||||
return models.Proxy(**kwargs)
|
||||
|
||||
class DummyTicketManager(object):
|
||||
def __init__(self, ticket_class, service, ticket):
|
||||
self.ticket_class = ticket_class
|
||||
self.service = service
|
||||
self.ticket = ticket
|
||||
|
||||
def create(self, **kwargs):
|
||||
for field in self.ticket_class._meta.fields:
|
||||
field.allow_unsaved_instance_assignment = True
|
||||
return self.ticket_class(**kwargs)
|
||||
|
||||
def filter(self, *args, **kwargs):
|
||||
return DummyQuerySet()
|
||||
|
||||
def get(self, **kwargs):
|
||||
for field in self.ticket_class._meta.fields:
|
||||
field.allow_unsaved_instance_assignment = True
|
||||
if 'value' in kwargs:
|
||||
if kwargs['value'] != self.ticket:
|
||||
raise self.ticket_class.DoesNotExist()
|
||||
else:
|
||||
kwargs['value'] = self.ticket
|
||||
|
||||
if 'service' in kwargs:
|
||||
if kwargs['service'] != self.service:
|
||||
raise self.ticket_class.DoesNotExist()
|
||||
else:
|
||||
kwargs['service'] = self.service
|
||||
if not 'user' in kwargs:
|
||||
kwargs['user'] = models.User(username="test")
|
||||
|
||||
for field in models.ServiceTicket._meta.fields:
|
||||
field.allow_unsaved_instance_assignment = True
|
||||
for key in list(kwargs):
|
||||
if '__' in key:
|
||||
del kwargs[key]
|
||||
kwargs['attributs'] = {'mail': 'test@example.com'}
|
||||
kwargs['service_pattern'] = models.ServicePattern()
|
||||
return self.ticket_class(**kwargs)
|
||||
|
||||
|
||||
|
||||
class DummySession(dict):
|
||||
session_key = "test_session"
|
||||
|
||||
def set_expiry(self, int):
|
||||
pass
|
||||
|
||||
def flush(self):
|
||||
self.clear()
|
||||
|
||||
|
||||
class DummyQuerySet(set):
|
||||
pass
|
@ -1,32 +0,0 @@
|
||||
import django
|
||||
from django.conf import settings
|
||||
from django.contrib import messages
|
||||
|
||||
settings.configure()
|
||||
settings.STATIC_URL = "/static/"
|
||||
settings.DATABASES = {
|
||||
'default': {
|
||||
'ENGINE': 'django.db.backends.sqlite3',
|
||||
'NAME': '/dev/null',
|
||||
}
|
||||
}
|
||||
settings.INSTALLED_APPS = (
|
||||
'django.contrib.admin',
|
||||
'django.contrib.auth',
|
||||
'django.contrib.contenttypes',
|
||||
'django.contrib.sessions',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.staticfiles',
|
||||
'bootstrap3',
|
||||
'cas_server',
|
||||
)
|
||||
|
||||
settings.ROOT_URLCONF = "/"
|
||||
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
|
||||
|
||||
try:
|
||||
django.setup()
|
||||
except AttributeError:
|
||||
pass
|
||||
messages.add_message = lambda x,y,z:None
|
||||
|
@ -1,52 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from tests.init import *
|
||||
|
||||
from django.test import RequestFactory
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from lxml import etree
|
||||
from cas_server.views import ValidateService, Proxy
|
||||
from cas_server import models
|
||||
|
||||
from tests.dummy import *
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random")
|
||||
@dummy_service_pattern(proxy=True)
|
||||
@dummy_user(username="test", session_key="test_session")
|
||||
@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random")
|
||||
@dummy_proxy
|
||||
def test_proxy_ok():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
proxy = Proxy()
|
||||
response = proxy.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
|
||||
assert len(proxy_tickets) == 1
|
||||
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com')
|
||||
|
||||
validate = ValidateService()
|
||||
validate.allow_proxy_ticket = True
|
||||
response = validate.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
|
||||
assert len(users) == 1
|
||||
assert users[0].text == "test"
|
||||
|
||||
|
||||
|
@ -1,87 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from .init import *
|
||||
|
||||
from django.test import RequestFactory
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from lxml import etree
|
||||
from cas_server.views import ValidateService
|
||||
from cas_server import models
|
||||
|
||||
from .dummy import *
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||
def test_validate_service_view_ok():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
validate = ValidateService()
|
||||
validate.allow_proxy_ticket = False
|
||||
response = validate.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
|
||||
assert len(users) == 1
|
||||
assert users[0].text == "test"
|
||||
|
||||
attributes = root.xpath("//cas:attributes", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
|
||||
assert len(attributes) == 1
|
||||
|
||||
attrs = {}
|
||||
for attr in attributes[0]:
|
||||
attrs[attr.tag[len("http://www.yale.edu/tp/cas")+2:]]=attr.text
|
||||
|
||||
assert 'mail' in attrs
|
||||
assert attrs['mail'] == 'test@example.com'
|
||||
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random")
|
||||
def test_validate_service_view_badservice():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
validate = ValidateService()
|
||||
validate.allow_proxy_ticket = False
|
||||
response = validate.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
|
||||
error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
|
||||
assert len(error) == 1
|
||||
assert error[0].attrib['code'] == 'INVALID_SERVICE'
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2")
|
||||
def test_validate_service_view_badticket():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
validate = ValidateService()
|
||||
validate.allow_proxy_ticket = False
|
||||
response = validate.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
root = etree.fromstring(response.content)
|
||||
|
||||
error = root.xpath("//cas:authenticationFailure", namespaces={'cas': "http://www.yale.edu/tp/cas"})
|
||||
|
||||
assert len(error) == 1
|
||||
assert error[0].attrib['code'] == 'INVALID_TICKET'
|
@ -1,46 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from .init import *
|
||||
|
||||
from django.test import RequestFactory
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from cas_server.views import Auth
|
||||
from cas_server import models
|
||||
|
||||
from .dummy import *
|
||||
|
||||
settings.CAS_AUTH_SHARED_SECRET = "test"
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||
@dummy_user(username="test", session_key="test_session")
|
||||
@dummy_service_pattern()
|
||||
def test_auth_view_goodpass():
|
||||
factory = RequestFactory()
|
||||
request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'})
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
auth = Auth()
|
||||
response = auth.post(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"yes\n"
|
||||
|
||||
@dummy_service_pattern()
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||
@dummy_user(username="test", session_key="test_session")
|
||||
def test_auth_view_badpass():
|
||||
factory = RequestFactory()
|
||||
request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'})
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
auth = Auth()
|
||||
response = auth.post(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"no\n"
|
||||
|
@ -1,163 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from .init import *
|
||||
|
||||
from django.test import RequestFactory
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from cas_server.views import LoginView
|
||||
from cas_server import models
|
||||
|
||||
from .dummy import *
|
||||
|
||||
|
||||
|
||||
def test_login_view_post_goodpass_goodlt():
|
||||
factory = RequestFactory()
|
||||
request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random'})
|
||||
request.session = DummySession()
|
||||
|
||||
request.session['lt'] = ['LT-random']
|
||||
|
||||
request.session["username"] = os.urandom(20)
|
||||
request.session["warn"] = os.urandom(20)
|
||||
|
||||
login = LoginView()
|
||||
login.init_post(request)
|
||||
|
||||
ret = login.process_post(pytest=True)
|
||||
|
||||
assert ret == LoginView.USER_LOGIN_OK
|
||||
assert request.session.get("authenticated") == True
|
||||
assert request.session.get("username") == "test"
|
||||
assert request.session.get("warn") == False
|
||||
|
||||
def test_login_view_post_badlt():
|
||||
factory = RequestFactory()
|
||||
request = factory.post('/login', {'username':'test', 'password':'test', 'lt':'LT-random1'})
|
||||
request.session = DummySession()
|
||||
|
||||
request.session['lt'] = ['LT-random2']
|
||||
|
||||
authenticated = os.urandom(20)
|
||||
username = os.urandom(20)
|
||||
warn = os.urandom(20)
|
||||
|
||||
request.session["authenticated"] = authenticated
|
||||
request.session["username"] = username
|
||||
request.session["warn"] = warn
|
||||
|
||||
login = LoginView()
|
||||
login.init_post(request)
|
||||
|
||||
ret = login.process_post(pytest=True)
|
||||
|
||||
assert ret == LoginView.INVALID_LOGIN_TICKET
|
||||
assert request.session.get("authenticated") == authenticated
|
||||
assert request.session.get("username") == username
|
||||
assert request.session.get("warn") == warn
|
||||
|
||||
def test_login_view_post_badpass_good_lt():
|
||||
factory = RequestFactory()
|
||||
request = factory.post('/login', {'username':'test', 'password':'badpassword', 'lt':'LT-random'})
|
||||
request.session = DummySession()
|
||||
|
||||
request.session['lt'] = ['LT-random']
|
||||
|
||||
login = LoginView()
|
||||
login.init_post(request)
|
||||
ret = login.process_post()
|
||||
|
||||
assert ret == LoginView.USER_LOGIN_FAILURE
|
||||
assert not request.session.get("authenticated")
|
||||
assert not request.session.get("username")
|
||||
assert not request.session.get("warn")
|
||||
|
||||
|
||||
def test_view_login_get_unauth():
|
||||
factory = RequestFactory()
|
||||
request = factory.post('/login')
|
||||
request.session = DummySession()
|
||||
|
||||
login = LoginView()
|
||||
login.init_get(request)
|
||||
ret = login.process_get()
|
||||
|
||||
assert ret == LoginView.USER_NOT_AUTHENTICATED
|
||||
|
||||
login = LoginView()
|
||||
response = login.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_user(username="test", session_key="test_session")
|
||||
def test_view_login_get_auth():
|
||||
factory = RequestFactory()
|
||||
request = factory.post('/login')
|
||||
request.session = DummySession()
|
||||
|
||||
request.session["authenticated"] = True
|
||||
request.session["username"] = "test"
|
||||
request.session["warn"] = False
|
||||
|
||||
login = LoginView()
|
||||
login.init_get(request)
|
||||
ret = login.process_get()
|
||||
|
||||
assert ret == LoginView.USER_AUTHENTICATED
|
||||
|
||||
login = LoginView()
|
||||
response = login.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_service_pattern()
|
||||
@dummy_user(username="test", session_key="test_session")
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||
def test_view_login_get_auth_service():
|
||||
factory = RequestFactory()
|
||||
request = factory.post('/login?service=https://www.example.com')
|
||||
request.session = DummySession()
|
||||
|
||||
request.session["authenticated"] = True
|
||||
request.session["username"] = "test"
|
||||
request.session["warn"] = False
|
||||
|
||||
login = LoginView()
|
||||
login.init_get(request)
|
||||
ret = login.process_get()
|
||||
|
||||
assert ret == LoginView.USER_AUTHENTICATED
|
||||
|
||||
login = LoginView()
|
||||
response = login.get(request)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert response['Location'].startswith('https://www.example.com?ticket=ST-')
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_service_pattern()
|
||||
@dummy_user(username="test", session_key="test_session")
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||
def test_view_login_get_auth_service_warn():
|
||||
factory = RequestFactory()
|
||||
request = factory.post('/login?service=https://www.example.com')
|
||||
request.session = DummySession()
|
||||
|
||||
request.session["authenticated"] = True
|
||||
request.session["username"] = "test"
|
||||
request.session["warn"] = True
|
||||
|
||||
login = LoginView()
|
||||
login.init_get(request)
|
||||
ret = login.process_get()
|
||||
|
||||
assert ret == LoginView.USER_AUTHENTICATED
|
||||
|
||||
login = LoginView()
|
||||
response = login.get(request)
|
||||
|
||||
assert response.status_code == 200
|
@ -1,80 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from .init import *
|
||||
|
||||
from django.test import RequestFactory
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from cas_server.views import LogoutView
|
||||
from cas_server import models
|
||||
|
||||
from .dummy import *
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_user(username="test", session_key="test_session")
|
||||
def test_logout_view():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/logout')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
request.session["authenticated"] = True
|
||||
request.session["username"] = "test"
|
||||
request.session["warn"] = False
|
||||
|
||||
logout = LogoutView()
|
||||
response = logout.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert not request.session.get("authenticated")
|
||||
assert not request.session.get("username")
|
||||
assert not request.session.get("warn")
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_user(username="test", session_key="test_session")
|
||||
def test_logout_view_url():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/logout?url=https://www.example.com')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
request.session["authenticated"] = True
|
||||
request.session["username"] = "test"
|
||||
request.session["warn"] = False
|
||||
|
||||
logout = LogoutView()
|
||||
response = logout.get(request)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert response['Location'] == 'https://www.example.com'
|
||||
assert not request.session.get("authenticated")
|
||||
assert not request.session.get("username")
|
||||
assert not request.session.get("warn")
|
||||
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_user(username="test", session_key="test_session")
|
||||
def test_logout_view_service():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/logout?service=https://www.example.com')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
request.session["authenticated"] = True
|
||||
request.session["username"] = "test"
|
||||
request.session["warn"] = False
|
||||
|
||||
logout = LogoutView()
|
||||
response = logout.get(request)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert response['Location'] == 'https://www.example.com'
|
||||
assert not request.session.get("authenticated")
|
||||
assert not request.session.get("username")
|
||||
assert not request.session.get("warn")
|
||||
|
||||
|
@ -1,58 +0,0 @@
|
||||
from __future__ import absolute_import
|
||||
from .init import *
|
||||
|
||||
from django.test import RequestFactory
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from cas_server.views import Validate
|
||||
from cas_server import models
|
||||
|
||||
from .dummy import *
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||
def test_validate_view_ok():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/validate?ticket=ST-random&service=https://www.example.com')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
validate = Validate()
|
||||
response = validate.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"yes\ntest\n"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random")
|
||||
def test_validate_view_badservice():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
validate = Validate()
|
||||
response = validate.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"no\n"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1")
|
||||
def test_validate_view_badticket():
|
||||
factory = RequestFactory()
|
||||
request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com')
|
||||
|
||||
request.session = DummySession()
|
||||
|
||||
validate = Validate()
|
||||
response = validate.get(request)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.content == b"no\n"
|
Loading…
Reference in New Issue
Block a user