django-cas-server/cas_server/tests/test_view.py

1352 lines
52 KiB
Python

"""Tests module"""
from cas_server.default_settings import settings
import django
from django.test import TestCase, Client
from django.test.utils import override_settings
import re
import six
import random
import json
from lxml import etree
from six.moves import range
from cas_server import models
from cas_server import utils
def copy_form(form):
"""Copy form value into a dict"""
params = {}
for field in form:
if field.value():
params[field.name] = field.value()
else:
params[field.name] = ""
return params
def get_login_page_params(client=None):
"""Return a client and the POST params for the client to login"""
if client is None:
client = Client()
response = client.get('/login')
params = copy_form(response.context["form"])
return client, params
def get_auth_client(**update):
"""return a authenticated client"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params.update(update)
client.post('/login', params)
return client
def get_user_ticket_request(service):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client = get_auth_client()
response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
return (user, ticket)
def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
(user, ticket) = get_user_ticket_request(service)
client = Client()
client.get('/serviceValidate', {'ticket': ticket.value, 'service': service, 'pgtUrl': service})
params = utils.PGTUrlHandler.PARAMS.copy()
params["service"] = service
params["user"] = user
return params
class CheckPasswordCase(TestCase):
"""Tests for the utils function `utils.check_password`"""
def setUp(self):
"""Generate random bytes string that will be used ass passwords"""
self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8")
def test_setup(self):
"""check that generated password are bytes"""
self.assertIsInstance(self.password1, bytes)
self.assertIsInstance(self.password2, bytes)
def test_plain(self):
"""test the plain auth method"""
self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))
def test_crypt(self):
"""test the crypt auth method"""
if six.PY3:
hashed_password1 = utils.crypt.crypt(
self.password1.decode("utf8"),
"$6$UVVAQvrMyXMF3FF3"
).encode("utf8")
else:
hashed_password1 = utils.crypt.crypt(self.password1, "$6$UVVAQvrMyXMF3FF3")
self.assertTrue(utils.check_password("crypt", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("crypt", self.password2, hashed_password1, "utf8"))
def test_ldap_ssha(self):
"""test the ldap auth method with a {SSHA} scheme"""
salt = b"UVVAQvrMyXMF3FF3"
hashed_password1 = utils.LdapHashUserPassword.hash(b'{SSHA}', self.password1, salt, "utf8")
self.assertIsInstance(hashed_password1, bytes)
self.assertTrue(utils.check_password("ldap", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("ldap", self.password2, hashed_password1, "utf8"))
def test_hex_md5(self):
"""test the hex_md5 auth method"""
hashed_password1 = utils.hashlib.md5(self.password1).hexdigest()
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
def test_hex_sha512(self):
"""test the hex_sha512 auth method"""
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
self.assertTrue(
utils.check_password("hex_sha512", self.password1, hashed_password1, "utf8")
)
self.assertFalse(
utils.check_password("hex_sha512", self.password2, hashed_password1, "utf8")
)
class BaseServicePattern(object):
"""Mixing for setting up service pattern for testing"""
def setup_service_patterns(self, proxy=False):
"""setting up service pattern"""
# For general purpose testing
self.service = "https://www.example.com"
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$",
proxy=proxy,
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
# For testing the restrict_users attributes
self.service_restrict_user_fail = "https://restrict_user_fail.example.com"
self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
name="restrict_user_fail",
pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
restrict_users=True,
proxy=proxy,
)
self.service_restrict_user_success = "https://restrict_user_success.example.com"
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
name="restrict_user_success",
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
restrict_users=True,
proxy=proxy,
)
models.Username.objects.create(
value=settings.CAS_TEST_USER,
service_pattern=self.service_pattern_restrict_user_success
)
# For testing the user attributes filtering conditions
self.service_filter_fail = "https://filter_fail.example.com"
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
name="filter_fail",
pattern="^https://filter_fail\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="right",
pattern="^admin$",
service_pattern=self.service_pattern_filter_fail
)
self.service_filter_success = "https://filter_success.example.com"
self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success",
pattern="^https://filter_success\.example\.com(/.*)?$",
proxy=proxy,
)
models.FilterAttributValue.objects.create(
attribut="email",
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
service_pattern=self.service_pattern_filter_success
)
# For testing the user_field attributes
self.service_field_needed_fail = "https://field_needed_fail.example.com"
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
name="field_needed_fail",
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
user_field="uid",
proxy=proxy,
)
self.service_field_needed_success = "https://field_needed_success.example.com"
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success",
pattern="^https://field_needed_success\.example\.com(/.*)?$",
user_field="nom",
proxy=proxy,
)
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class LoginTestCase(TestCase, BaseServicePattern):
"""Tests for the login view"""
def setUp(self):
"""Prepare the test context:"""
self.setup_service_patterns()
def assert_logged(self, client, response, warn=False, code=200):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client.session["warn"] is warn)
self.assertTrue(client.session["authenticated"] is True)
self.assertTrue(
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
)
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
def test_login_view_post_goodpass_goodlt(self):
"""Test a successul login"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
self.assertTrue(params['lt'] in client.session['lt'])
response = client.post('/login', params)
self.assert_logged(client, response)
# LoginTicket conssumed
self.assertTrue(params['lt'] not in client.session['lt'])
def test_login_view_post_goodpass_goodlt_warn(self):
"""Test a successul login requesting to be warned before creating services tickets"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["warn"] = "on"
response = client.post('/login', params)
self.assert_logged(client, response, warn=True)
def test_lt_max(self):
"""Check we only keep the last 100 Login Ticket for a user"""
client, params = get_login_page_params()
current_lt = params["lt"]
i_in_test = random.randint(0, 99)
i_not_in_test = random.randint(101, 150)
for i in range(150):
if i == i_in_test:
self.assertTrue(current_lt in client.session['lt'])
if i == i_not_in_test:
self.assertTrue(current_lt not in client.session['lt'])
self.assertTrue(len(client.session['lt']) <= 100)
client, params = get_login_page_params(client)
self.assertTrue(len(client.session['lt']) <= 100)
def test_login_view_post_badlt(self):
"""Login attempt with a bad LoginTicket"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["lt"] = 'LT-random'
response = client.post('/login', params)
self.assert_login_failed(client, response)
self.assertTrue(b"Invalid login ticket" in response.content)
def test_login_view_post_badpass_good_lt(self):
"""Login attempt with a bad password"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = "test2"
response = client.post('/login', params)
self.assert_login_failed(client, response)
self.assertTrue(
(
b"The credentials you provided cannot be "
b"determined to be authentic"
) in response.content
)
def assert_ticket_attributes(self, client, ticket_value):
"""check the ticket attributes in the db"""
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
self.assertTrue(user)
ticket = models.ServiceTicket.objects.get(value=ticket_value)
self.assertEqual(ticket.user, user)
self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
self.assertEqual(ticket.validate, False)
self.assertEqual(ticket.service_pattern, self.service_pattern)
def assert_service_ticket(self, client, response):
"""check that a ticket is well emited when requested on a allowed service"""
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header('Location'))
self.assertTrue(
response['Location'].startswith(
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
)
)
ticket_value = response['Location'].split('ticket=')[-1]
self.assert_ticket_attributes(client, ticket_value)
def test_view_login_get_allowed_service(self):
"""Request a ticket for an allowed service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication required by service "
b"example (https://www.example.com)"
) in response.content
)
def test_view_login_get_denied_service(self):
"""Request a ticket for an denied service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.net")
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.net non allowed" in response.content)
def test_view_login_get_auth_allowed_service(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client()
response = client.get("/login?service=https://www.example.com")
self.assert_service_ticket(client, response)
def test_view_login_get_auth_allowed_service_warn(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client(warn="on")
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication has been required by service "
b"example (https://www.example.com)"
) in response.content
)
params = copy_form(response.context["form"])
response = client.post("/login", params)
self.assert_service_ticket(client, response)
def test_view_login_get_auth_denied_service(self):
"""Request a ticket for a not allowed service by an authenticated client"""
client = get_auth_client()
response = client.get("/login?service=https://www.example.org")
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
def test_user_logged_not_in_db(self):
"""If the user is logged but has been delete from the database, it should be logged out"""
client = get_auth_client()
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).delete()
response = client.get("/login")
self.assert_login_failed(client, response, code=302)
if django.VERSION < (1, 9): # pragma: no cover coverage is computed with dango 1.9
self.assertEqual(response["Location"], "http://testserver/login")
else:
self.assertEqual(response["Location"], "/login?")
def test_service_restrict_user(self):
"""Testing the restric user capability fro a service"""
client = get_auth_client()
response = client.get("/login", {'service': self.service_restrict_user_fail})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"Username non allowed" in response.content)
response = client.get("/login", {'service': self.service_restrict_user_success})
self.assertEqual(response.status_code, 302)
self.assertTrue(
response["Location"].startswith("%s?ticket=" % self.service_restrict_user_success)
)
def test_service_filter(self):
"""Test the filtering on user attributes"""
client = get_auth_client()
response = client.get("/login", {'service': self.service_filter_fail})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"User charateristics non allowed" in response.content)
response = client.get("/login", {'service': self.service_filter_success})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % self.service_filter_success))
def test_service_user_field(self):
"""Test using a user attribute as username: case on if the attribute exists or not"""
client = get_auth_client()
response = client.get("/login", {'service': self.service_field_needed_fail})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"The attribut uid is needed to use that service" in response.content)
response = client.get("/login", {'service': self.service_field_needed_success})
self.assertEqual(response.status_code, 302)
self.assertTrue(
response["Location"].startswith("%s?ticket=" % self.service_field_needed_success)
)
@override_settings(CAS_TEST_ATTRIBUTES={'nom': []})
def test_service_user_field_evaluate_to_false(self):
"""
Test using a user attribute as username:
case the attribute exists but evaluate to False
"""
client = get_auth_client()
response = client.get("/login", {"service": self.service_field_needed_success})
self.assertEqual(response.status_code, 200)
self.assertTrue(b"The attribut nom is needed to use that service" in response.content)
def test_gateway(self):
"""test gateway parameter"""
# First with an authenticated client that fail to get a ticket for a service
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
# second for an user not yet authenticated on a valid service
client = Client()
response = client.get('/login', {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
def test_renew(self):
"""test the authentication renewal request from a service"""
service = "https://www.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service, 'renew': 'on'})
# we are ask to reauthenticate
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"Authentication renewal required by "
b"service example (https://www.example.com)"
) in response.content
)
params = copy_form(response.context["form"])
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
self.assertEqual(params["renew"], True)
response = client.post("/login", params)
self.assertEqual(response.status_code, 302)
ticket_value = response['Location'].split('ticket=')[-1]
ticket = models.ServiceTicket.objects.get(value=ticket_value)
# the created ticket is marked has being gottent after a renew
self.assertEqual(ticket.renew, True)
def test_ajax_login_required(self):
"""test ajax, login required"""
client = Client()
response = client.get("/login", HTTP_X_AJAX='on')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode("utf8"))
self.assertEqual(data["status"], "error")
self.assertEqual(data["detail"], "login required")
self.assertEqual(data["url"], "/login?")
def test_ajax_logged_user_deleted(self):
"""test ajax user logged deleted: login required"""
client = get_auth_client()
user = models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
)
user.delete()
response = client.get("/login", HTTP_X_AJAX='on')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode("utf8"))
self.assertEqual(data["status"], "error")
self.assertEqual(data["detail"], "login required")
self.assertEqual(data["url"], "/login?")
def test_ajax_logged(self):
"""test ajax user is successfully logged"""
client = get_auth_client()
response = client.get("/login", HTTP_X_AJAX='on')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode("utf8"))
self.assertEqual(data["status"], "success")
self.assertEqual(data["detail"], "logged")
def test_ajax_get_ticket_success(self):
"""test ajax retrieve a ticket for an allowed service"""
service = "https://www.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service}, HTTP_X_AJAX='on')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode("utf8"))
self.assertEqual(data["status"], "success")
self.assertEqual(data["detail"], "auth")
self.assertTrue(data["url"].startswith('%s?ticket=' % service))
def test_ajax_get_ticket_fail(self):
"""test ajax retrieve a ticket for a denied service"""
service = "https://www.example.org"
client = get_auth_client()
response = client.get("/login", {'service': service}, HTTP_X_AJAX='on')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode("utf8"))
self.assertEqual(data["status"], "error")
self.assertEqual(data["detail"], "auth")
self.assertEqual(data["messages"][0]["level"], "error")
self.assertEqual(
data["messages"][0]["message"],
"Service https://www.example.org non allowed."
)
def test_ajax_get_ticket_warn(self):
"""test get a ticket but user asked to be warned"""
service = "https://www.example.com"
client = get_auth_client(warn="on")
response = client.get("/login", {'service': service}, HTTP_X_AJAX='on')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode("utf8"))
self.assertEqual(data["status"], "error")
self.assertEqual(data["detail"], "confirmation needed")
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class LogoutTestCase(TestCase):
"""test fot the logout view"""
def test_logout(self):
"""logout is idempotent"""
client = Client()
client.get("/logout")
self.assertFalse(client.session.get("username"))
self.assertFalse(client.session.get("authenticated"))
def test_logout_view(self):
"""test simple logout"""
client = get_auth_client()
client2 = get_auth_client()
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client.session["authenticated"] is True)
response = client.get("/logout")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
b"You have successfully logged out from "
b"the Central Authentication Service"
) in response.content
)
self.assertFalse(client.session.get("username"))
self.assertFalse(client.session.get("authenticated"))
# client2 is still logged
self.assertTrue(client2.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client2.session["authenticated"] is True)
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_from_all_session(self):
"""test logout from all my session"""
client = get_auth_client()
client2 = get_auth_client()
client.get("/logout?all=1")
# both client are logged out
self.assertFalse(client.session.get("username"))
self.assertFalse(client.session.get("authenticated"))
self.assertFalse(client2.session.get("username"))
self.assertFalse(client2.session.get("authenticated"))
def assert_redirect_to_service(self, client, response):
"""assert logout redirect to parameter"""
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header("Location"))
self.assertEqual(response["Location"], "https://www.example.com")
response = client.get("/login")
self.assertEqual(response.status_code, 200)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_logout_view_url(self):
"""test logout redirect to url parameter"""
client = get_auth_client()
response = client.get('/logout?url=https://www.example.com')
self.assert_redirect_to_service(client, response)
def test_logout_view_service(self):
"""test logout redirect to service parameter"""
client = get_auth_client()
response = client.get('/logout?service=https://www.example.com')
self.assert_redirect_to_service(client, response)
def test_ajax_logout(self):
"""test ajax logout"""
client = get_auth_client()
response = client.get('/logout', HTTP_X_AJAX='on')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode("utf8"))
self.assertEqual(data["status"], "success")
self.assertEqual(data["detail"], "logout")
self.assertEqual(data['session_nb'], 1)
def test_ajax_logout_all_session(self):
"""test ajax logout from a random number a sessions"""
nb_client = random.randint(2, 10)
clients = [get_auth_client() for i in range(nb_client)]
response = clients[0].get('/logout?all=1', HTTP_X_AJAX='on')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode("utf8"))
self.assertEqual(data["status"], "success")
self.assertEqual(data["detail"], "logout")
self.assertEqual(data['session_nb'], nb_client)
@override_settings(CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT=True)
def test_redirect_after_logout(self):
"""Test redirect to login after logout parameter"""
client = get_auth_client()
response = client.get('/logout')
self.assertEqual(response.status_code, 302)
if django.VERSION < (1, 9): # pragma: no cover coverage is computed with dango 1.9
self.assertEqual(response["Location"], "http://testserver/login")
else:
self.assertEqual(response["Location"], "/login")
self.assertFalse(client.session.get("username"))
self.assertFalse(client.session.get("authenticated"))
@override_settings(CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT=True)
def test_redirect_after_logout_to_service(self):
"""test prevalence of redirect url/service parameter over redirect to login after logout"""
client = get_auth_client()
response = client.get('/logout?url=https://www.example.com')
self.assert_redirect_to_service(client, response)
response = client.get('/logout?service=https://www.example.com')
self.assert_redirect_to_service(client, response)
@override_settings(CAS_REDIRECT_TO_LOGIN_AFTER_LOGOUT=True)
def test_ajax_redirect_after_logout(self):
"""Test ajax redirect to login after logout parameter"""
client = get_auth_client()
response = client.get('/logout', HTTP_X_AJAX='on')
self.assertEqual(response.status_code, 200)
data = json.loads(response.content.decode("utf8"))
self.assertEqual(data["status"], "success")
self.assertEqual(data["detail"], "logout")
self.assertEqual(data['session_nb'], 1)
self.assertEqual(data['url'], '/login')
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class AuthTestCase(TestCase):
"""
Test for the auth view, used for external services
to validate (user, pass, service) tuples.
"""
def setUp(self):
"""preparing test context"""
self.service = 'https://www.example.com'
models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
@override_settings(CAS_AUTH_SHARED_SECRET='test')
def test_auth_view_goodpass(self):
"""successful request are awsered by yes"""
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\n')
@override_settings(CAS_AUTH_SHARED_SECRET='test')
def test_auth_view_badpass(self):
""" bag user password => no"""
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': 'badpass',
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
@override_settings(CAS_AUTH_SHARED_SECRET='test')
def test_auth_view_badservice(self):
"""bad service => no"""
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': 'https://www.example.org',
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
@override_settings(CAS_AUTH_SHARED_SECRET='test')
def test_auth_view_badsecret(self):
"""bad api key => no"""
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'badsecret'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_auth_view_badsettings(self):
"""api not set => error"""
client = Client()
response = client.post(
'/auth',
{
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
@override_settings(CAS_AUTH_SHARED_SECRET='test')
def test_auth_view_missing_parameter(self):
"""missing parameter in request => no"""
client = Client()
params = {
'username': settings.CAS_TEST_USER,
'password': settings.CAS_TEST_PASSWORD,
'service': self.service,
'secret': 'test'
}
for key in ['username', 'password', 'service']:
send_params = params.copy()
del send_params[key]
response = client.post('/auth', send_params)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class ValidateTestCase(TestCase):
"""tests for the validate view"""
def setUp(self):
"""preparing test context"""
self.service = 'https://www.example.com'
self.service_pattern = models.ServicePattern.objects.create(
name="example",
pattern="^https://www\.example\.com(/.*)?$"
)
self.service_user_field = "https://user_field.example.com"
self.service_pattern_user_field = models.ServicePattern.objects.create(
name="user field",
pattern="^https://user_field\.example\.com(/.*)?$",
user_field="alias"
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_view_ok(self):
"""test for a valid (ticket, service)"""
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\ntest\n')
def test_validate_view_badservice(self):
"""test for a valid ticket but bad service"""
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/validate',
{'ticket': ticket.value, 'service': "https://www.example.org"}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_validate_view_badticket(self):
"""test for a bad ticket but valid service"""
get_user_ticket_request(self.service)
client = Client()
response = client.get(
'/validate',
{'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
def test_validate_user_field_ok(self):
"""
test with a good user_field. A bad user_field (that evaluate to False)
wont happed cause it is filtered in the login view
"""
ticket = get_user_ticket_request(self.service_user_field)[1]
client = Client()
response = client.get(
'/validate',
{'ticket': ticket.value, 'service': self.service_user_field}
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'yes\ndemo1\n')
def test_validate_missing_parameter(self):
"""test with a missing GET parameter among [service, ticket]"""
ticket = get_user_ticket_request(self.service)[1]
client = Client()
params = {'ticket': ticket.value, 'service': self.service}
for key in ['ticket', 'service']:
send_params = params.copy()
del send_params[key]
response = client.get('/validate', send_params)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.content, b'no\n')
class XmlContent(object):
"""Mixin for test on CAS XML responses"""
def assert_error(self, response, code, text=None):
"""Assert a validation error"""
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
error = root.xpath(
"//cas:authenticationFailure",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(error), 1)
self.assertEqual(error[0].attrib['code'], code)
if text is not None:
self.assertEqual(error[0].text, text)
def assert_success(self, response, username, original_attributes):
"""assert a ticket validation success"""
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath(
"//cas:authenticationSuccess",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertTrue(sucess)
users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(users), 1)
self.assertEqual(users[0].text, username)
attributes = root.xpath(
"//cas:attributes",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(attributes), 1)
attrs1 = set()
for attr in attributes[0]:
attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1))
attrs2 = set()
for attr in attributes:
attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in original_attributes.items():
if isinstance(value, list):
for sub_value in value:
original.add((key, sub_value))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, original)
return root
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class ValidateServiceTestCase(TestCase, XmlContent):
"""tests for the serviceValidate view"""
def setUp(self):
"""preparing test context"""
self.service = 'http://127.0.0.1:45678'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^https?://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
self.service_user_field = "https://user_field.example.com"
self.service_pattern_user_field = models.ServicePattern.objects.create(
name="user field",
pattern="^https://user_field\.example\.com(/.*)?$",
user_field="alias"
)
self.service_one_attribute = "https://one_attribute.example.com"
self.service_pattern_one_attribute = models.ServicePattern.objects.create(
name="one_attribute",
pattern="^https://one_attribute\.example\.com(/.*)?$"
)
models.ReplaceAttributName.objects.create(
name="nom",
service_pattern=self.service_pattern_one_attribute
)
def test_validate_service_view_ok(self):
"""test with a valid (ticket, service), the username and all attributes are transmited"""
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES)
def test_validate_service_view_ok_one_attribute(self):
"""
test with a valid (ticket, service), the username and
the 'nom' only attribute are transmited
"""
ticket = get_user_ticket_request(self.service_one_attribute)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service_one_attribute}
)
self.assert_success(
response,
settings.CAS_TEST_USER,
{'nom': settings.CAS_TEST_ATTRIBUTES['nom']}
)
def test_validate_service_view_badservice(self):
"""test with a valid ticket but a bad service, the validatin should fail"""
ticket = get_user_ticket_request(self.service)[1]
client = Client()
bad_service = "https://www.example.org"
response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
self.assert_error(
response,
"INVALID_SERVICE",
bad_service
)
def test_validate_service_view_badticket_goodprefix(self):
"""
test with a good service bud a bad ticket begining with ST-,
the validation should fail with the error (INVALID_TICKET, ticket not found)
"""
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assert_error(
response,
"INVALID_TICKET",
'ticket not found'
)
def test_validate_service_view_badticket_badprefix(self):
"""
test with a good service bud a bad ticket not begining with ST-,
the validation should fail with the error (INVALID_TICKET, `the ticket`)
"""
get_user_ticket_request(self.service)
client = Client()
bad_ticket = "RANDOM"
response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
self.assert_error(
response,
"INVALID_TICKET",
bad_ticket
)
def test_validate_service_view_ok_pgturl(self):
"""test the retrieval of a ProxyGrantingTicket"""
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
)
pgt_params = utils.PGTUrlHandler.PARAMS.copy()
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
pgtiou = root.xpath(
"//cas:proxyGrantingTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(pgtiou), 1)
self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
self.assertTrue("pgtId" in pgt_params)
def test_validate_service_pgturl_sslerror(self):
"""test the retrieval of a ProxyGrantingTicket with a SSL error on the pgtUrl"""
(host, port) = utils.PGTUrlHandler.run()[1:3]
service = "https://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
)
self.assert_error(
response,
"INVALID_PROXY_CALLBACK",
)
def test_validate_service_pgturl_404(self):
"""
test the retrieval on a ProxyGrantingTicket then to pgtUrl return a http error.
PGT creation should be aborted but the ticket still be valid
"""
(host, port) = utils.PGTUrlHandler404.run()[1:3]
service = "http://%s:%s" % (host, port)
ticket = get_user_ticket_request(service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': service, 'pgtUrl': service}
)
root = self.assert_success(response, settings.CAS_TEST_USER, settings.CAS_TEST_ATTRIBUTES)
pgtiou = root.xpath(
"//cas:proxyGrantingTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertFalse(pgtiou)
def test_validate_service_pgturl_bad_proxy_callback(self):
"""test the retrieval of a ProxyGrantingTicket, not allowed pgtUrl should be denied"""
self.service_pattern.proxy_callback = False
self.service_pattern.save()
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
)
self.assert_error(
response,
"INVALID_PROXY_CALLBACK",
"callback url not allowed by configuration"
)
self.service_pattern.proxy_callback = True
ticket = get_user_ticket_request(self.service)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service, 'pgtUrl': "https://www.example.org"}
)
self.assert_error(
response,
"INVALID_PROXY_CALLBACK",
"callback url not allowed by configuration"
)
def test_validate_user_field_ok(self):
"""
test with a good user_field. A bad user_field (that evaluate to False)
wont happed cause it is filtered in the login view
"""
ticket = get_user_ticket_request(self.service_user_field)[1]
client = Client()
response = client.get(
'/serviceValidate',
{'ticket': ticket.value, 'service': self.service_user_field}
)
self.assert_success(
response,
settings.CAS_TEST_ATTRIBUTES["alias"][0],
{}
)
def test_validate_missing_parameter(self):
"""test with a missing GET parameter among [service, ticket]"""
ticket = get_user_ticket_request(self.service)[1]
client = Client()
params = {'ticket': ticket.value, 'service': self.service}
for key in ['ticket', 'service']:
send_params = params.copy()
del send_params[key]
response = client.get('/serviceValidate', send_params)
self.assert_error(
response,
"INVALID_REQUEST",
"you must specify a service and a ticket"
)
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class ProxyTestCase(TestCase, BaseServicePattern, XmlContent):
"""tests for the proxy view"""
def setUp(self):
"""preparing test context"""
self.setup_service_patterns(proxy=True)
self.service = 'http://127.0.0.1'
self.service_pattern = models.ServicePattern.objects.create(
name="localhost",
pattern="^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
proxy=True,
proxy_callback=True
)
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_validate_proxy_ok(self):
"""
Get a PGT, get a proxy ticket, validate it. Validation should succeed and
show the proxy service URL.
"""
params = get_pgt()
# get a proxy ticket
client1 = Client()
response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
self.assertEqual(response.status_code, 200)
root = etree.fromstring(response.content)
sucess = root.xpath("//cas:proxySuccess", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertTrue(sucess)
proxy_ticket = root.xpath(
"//cas:proxyTicket",
namespaces={'cas': "http://www.yale.edu/tp/cas"}
)
self.assertEqual(len(proxy_ticket), 1)
proxy_ticket = proxy_ticket[0].text
# validate the proxy ticket
client2 = Client()
response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
root = self.assert_success(
response,
settings.CAS_TEST_USER,
settings.CAS_TEST_ATTRIBUTES
)
# check that the proxy is send to the end service
proxies = root.xpath("//cas:proxies", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxies), 1)
proxy = proxies[0].xpath("//cas:proxy", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(proxy), 1)
self.assertEqual(proxy[0].text, params["service"])
def test_validate_proxy_bad_pgt(self):
"""Try to get a ProxyTicket with a bad PGT. The PT generation should fail"""
params = get_pgt()
client = Client()
response = client.get(
'/proxy',
{
'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
'targetService': params['service']
}
)
self.assert_error(
response,
"INVALID_TICKET",
"PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
)
def test_validate_proxy_bad_service(self):
"""
Try to get a ProxyTicket for a denied service and
a service that do not allow PT. The PT generation should fail.
"""
params = get_pgt()
client1 = Client()
response = client1.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
)
self.assert_error(
response,
"UNAUTHORIZED_SERVICE",
"https://www.example.org"
)
# service do not allow proxy ticket
self.service_pattern.proxy = False
self.service_pattern.save()
client2 = Client()
response = client2.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': params['service']}
)
self.assert_error(
response,
"UNAUTHORIZED_SERVICE",
'the service %s do not allow proxy ticket' % params['service']
)
self.service_pattern.proxy = True
self.service_pattern.save()
def test_proxy_unauthorized_user(self):
"""
Try to get a PT for services that do not allow the current user:
* first with a service that restrict allower username
* second with a service requiring somes conditions on the user attributes
* third with a service using a particular user attribute as username
All this tests should fail
"""
params = get_pgt()
for service in [
self.service_restrict_user_fail,
self.service_filter_fail,
self.service_field_needed_fail
]:
client = Client()
response = client.get(
'/proxy',
{'pgt': params['pgtId'], 'targetService': service}
)
self.assert_error(
response,
"UNAUTHORIZED_USER",
'User %s not allowed on %s' % (settings.CAS_TEST_USER, service)
)
def test_proxy_missing_parameter(self):
"""Try to get a PGT with some missing GET parameters. The PT should not be emited"""
params = get_pgt()
base_params = {'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
for key in ["pgt", 'targetService']:
send_params = base_params.copy()
del send_params[key]
client = Client()
response = client.get("/proxy", send_params)
self.assert_error(
response,
"INVALID_REQUEST",
'you must specify and pgt and targetService'
)