# 1033 lines, 39 KiB, Python
from .default_settings import settings
|
|
|
|
import django
|
|
from django.test import TestCase
|
|
from django.test import Client
|
|
|
|
import re
|
|
import six
|
|
import random
|
|
import json
|
|
from lxml import etree
|
|
from six.moves import range
|
|
|
|
from cas_server import models
|
|
from cas_server import utils
|
|
|
|
|
|
def copy_form(form):
    """
    Snapshot the values of ``form`` into a plain dict.

    Each field contributes one entry keyed by its ``name``. Falsy values
    (``None``, ``False``, empty strings...) are replaced by ``""`` so the
    resulting dict can be POSTed back as is.
    """
    return dict(
        (field.name, field.value() or "")
        for field in form
    )
|
|
|
|
|
|
def get_login_page_params(client=None):
    """
    Load the login page and extract the values of its form.

    :param client: a test client to reuse (its session is kept), or ``None``
        to start from a brand new :class:`Client`.
    :return: the tuple ``(client, params)``, ``params`` being the login form
        values, ready to be completed and POSTed to ``/login``.
    """
    client = Client() if client is None else client
    login_page = client.get('/login')
    return client, copy_form(login_page.context["form"])
|
|
|
|
|
|
def get_auth_client(**update):
    """
    Build a test client logged in with the test user credentials.

    :param update: extra login form values applied on top of the credentials
        before POSTing (e.g. ``warn="on"``); they may also override them.
    :return: the authenticated :class:`Client`.
    """
    client, params = get_login_page_params()
    params.update(
        username=settings.CAS_TEST_USER,
        password=settings.CAS_TEST_PASSWORD,
    )
    params.update(update)
    client.post('/login', params)
    return client
|
|
|
|
|
|
def get_user_ticket_request(service):
    """
    Log a fresh client in and make it request a service ticket for ``service``.

    :param str service: the service URL the ticket is requested for.
    :return: the tuple ``(user, ticket)``: the :class:`models.User` bound to the
        client session and the emitted :class:`models.ServiceTicket`.
    """
    client = get_auth_client()
    redirect = client.get("/login", {"service": service})
    # the ticket value is everything after the last "ticket=" of the redirect url
    ticket_value = redirect['Location'].rsplit('ticket=', 1)[-1]

    session_user = models.User.objects.get(
        username=settings.CAS_TEST_USER,
        session_key=client.session.session_key
    )
    service_ticket = models.ServiceTicket.objects.get(value=ticket_value)
    return (session_user, service_ticket)
|
|
|
|
|
|
def get_pgt():
    """
    Obtain a proxy granting ticket through a local ``pgtUrl`` callback handler.

    :return: a dict containing the parameters collected by
        ``utils.PGTUrlHandler`` (``pgtId``, ``pgtIou``...), plus ``service``
        (the handler url, also used as the service) and ``user`` (the
        :class:`models.User` the ticket was emitted for).
    """
    host, port = utils.PGTUrlHandler.run()[1:3]
    callback_url = "http://%s:%s" % (host, port)

    user, ticket = get_user_ticket_request(callback_url)

    # validating with a pgtUrl is expected to make the server call the handler
    # started above, which records what it received in PGTUrlHandler.PARAMS
    Client().get(
        '/serviceValidate',
        {'ticket': ticket.value, 'service': callback_url, 'pgtUrl': callback_url}
    )

    params = utils.PGTUrlHandler.PARAMS.copy()
    params.update(service=callback_url, user=user)
    return params
|
|
|
|
|
|
class CheckPasswordCase(TestCase):
    """Tests for the utils function `utils.check_password`"""

    def setUp(self):
        """Generate two random byte strings that the tests use as passwords"""
        first = utils.gen_saml_id()
        second = utils.gen_saml_id()
        # both values come from the same generator: when it yields text, encode both
        if not isinstance(first, bytes):
            first = first.encode("utf8")
            second = second.encode("utf8")
        self.password1, self.password2 = first, second

    def _assert_only_password1_matches(self, method, hashed):
        """Check that, with `method`, `hashed` validates password1 but not password2"""
        self.assertTrue(utils.check_password(method, self.password1, hashed, "utf8"))
        self.assertFalse(utils.check_password(method, self.password2, hashed, "utf8"))

    def test_setup(self):
        """check that generated passwords are bytes"""
        for password in (self.password1, self.password2):
            self.assertIsInstance(password, bytes)

    def test_plain(self):
        """test the plain auth method"""
        self.assertTrue(utils.check_password("plain", self.password1, self.password1, "utf8"))
        self.assertFalse(utils.check_password("plain", self.password1, self.password2, "utf8"))

    def test_crypt(self):
        """test the crypt auth method"""
        salt = "$6$UVVAQvrMyXMF3FF3"
        if not six.PY3:
            hashed_password1 = utils.crypt.crypt(self.password1, salt)
        else:
            # under python 3 crypt works on text: decode the password, encode the hash
            hashed_password1 = utils.crypt.crypt(
                self.password1.decode("utf8"),
                salt
            ).encode("utf8")
        self._assert_only_password1_matches("crypt", hashed_password1)

    def test_ldap_ssha(self):
        """test the ldap auth method with a {SSHA} scheme"""
        hashed_password1 = utils.LdapHashUserPassword.hash(
            b'{SSHA}', self.password1, b"UVVAQvrMyXMF3FF3", "utf8"
        )
        self.assertIsInstance(hashed_password1, bytes)
        self._assert_only_password1_matches("ldap", hashed_password1)

    def test_hex_md5(self):
        """test the hex_md5 auth method"""
        self._assert_only_password1_matches(
            "hex_md5",
            utils.hashlib.md5(self.password1).hexdigest()
        )

    def test_hex_sha512(self):
        """test the hex_sha512 auth method"""
        self._assert_only_password1_matches(
            "hex_sha512",
            utils.hashlib.sha512(self.password1).hexdigest()
        )
|
|
|
|
|
|
class LoginTestCase(TestCase):
    """Tests for the login view"""

    # message displayed by the login page once the user is authenticated
    LOGGED_MSG = (
        b"You have successfully logged into "
        b"the Central Authentication Service"
    )

    def setUp(self):
        """
        Prepare the test context:
            * set the auth class to 'cas_server.auth.TestAuthUser'
            * create a service pattern for https://www.example.com/**
            * Set the service pattern to return all user attributes
            * create the service patterns used by the restrict_users, user
              attributes filtering and user_field tests
        """
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'

        # For general purpose testing
        self.service_pattern = models.ServicePattern.objects.create(
            name="example",
            pattern=r"^https://www\.example\.com(/.*)?$",
        )
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)

        # For testing the restrict_users attributes
        self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
            name="restrict_user_fail",
            pattern=r"^https://restrict_user_fail\.example\.com(/.*)?$",
            restrict_users=True,
        )
        self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
            name="restrict_user_success",
            pattern=r"^https://restrict_user_success\.example\.com(/.*)?$",
            restrict_users=True,
        )
        models.Username.objects.create(
            value=settings.CAS_TEST_USER,
            service_pattern=self.service_pattern_restrict_user_success
        )

        # For testing the user attributes filtering conditions
        self.service_pattern_filter_fail = models.ServicePattern.objects.create(
            name="filter_fail",
            pattern=r"^https://filter_fail\.example\.com(/.*)?$",
        )
        models.FilterAttributValue.objects.create(
            attribut="right",
            pattern="^admin$",
            service_pattern=self.service_pattern_filter_fail
        )
        self.service_pattern_filter_success = models.ServicePattern.objects.create(
            name="filter_success",
            pattern=r"^https://filter_success\.example\.com(/.*)?$",
        )
        models.FilterAttributValue.objects.create(
            attribut="email",
            pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
            service_pattern=self.service_pattern_filter_success
        )

        # For testing the user_field attributes
        self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
            name="field_needed_fail",
            pattern=r"^https://field_needed_fail\.example\.com(/.*)?$",
            user_field="uid"
        )
        self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
            name="field_needed_success",
            pattern=r"^https://field_needed_success\.example\.com(/.*)?$",
            user_field="nom"
        )

    def assert_logged(self, client, response, warn=False, code=200):
        """
        Assertions testing that client is well authenticated

        :param client: the test client whose session is checked
        :param response: the last response returned to ``client``
        :param bool warn: the expected value of the ``warn`` session flag
        :param int code: the expected HTTP status code of ``response``
        """
        self.assertEqual(response.status_code, code)
        self.assertIn(self.LOGGED_MSG, response.content)
        self.assertEqual(client.session["username"], settings.CAS_TEST_USER)
        self.assertIs(client.session["warn"], warn)
        self.assertIs(client.session["authenticated"], True)

        # errors out if no user is bound to the client session
        self.assertTrue(
            models.User.objects.get(
                username=settings.CAS_TEST_USER,
                session_key=client.session.session_key
            )
        )

    def assert_login_failed(self, client, response, code=200):
        """
        Assertions testing a failed login attempt

        :param client: the test client whose session is checked
        :param response: the last response returned to ``client``
        :param int code: the expected HTTP status code of ``response``
        """
        self.assertEqual(response.status_code, code)
        self.assertNotIn(self.LOGGED_MSG, response.content)

        self.assertIsNone(client.session.get("username"))
        self.assertIsNone(client.session.get("warn"))
        self.assertIsNone(client.session.get("authenticated"))

    def test_login_view_post_goodpass_goodlt(self):
        """Test a successful login"""
        client, params = get_login_page_params()
        params["username"] = settings.CAS_TEST_USER
        params["password"] = settings.CAS_TEST_PASSWORD
        self.assertIn(params['lt'], client.session['lt'])

        response = client.post('/login', params)
        self.assert_logged(client, response)
        # LoginTicket consumed
        self.assertNotIn(params['lt'], client.session['lt'])

    def test_login_view_post_goodpass_goodlt_warn(self):
        """Test a successful login requesting to be warned before creating services tickets"""
        client, params = get_login_page_params()
        params["username"] = settings.CAS_TEST_USER
        params["password"] = settings.CAS_TEST_PASSWORD
        params["warn"] = "on"

        response = client.post('/login', params)
        self.assert_logged(client, response, warn=True)

    def test_lt_max(self):
        """Check we only keep the last 100 Login Ticket for a user"""
        client, params = get_login_page_params()
        current_lt = params["lt"]
        # Assuming each GET /login adds one login ticket, i + 1 tickets have been
        # generated before iteration i, current_lt being the first one. It is kept
        # while at most 100 tickets exist (i in [0, 99]) and is evicted afterwards
        # (i in [100, 149], 149 being the last value of range(150)).
        # randint bounds are inclusive.
        i_in_test = random.randint(0, 99)
        i_not_in_test = random.randint(100, 149)
        for i in range(150):
            if i == i_in_test:
                self.assertIn(current_lt, client.session['lt'])
            if i == i_not_in_test:
                self.assertNotIn(current_lt, client.session['lt'])
            self.assertLessEqual(len(client.session['lt']), 100)
            client, params = get_login_page_params(client)
        self.assertLessEqual(len(client.session['lt']), 100)

    def test_login_view_post_badlt(self):
        """Login attempt with a bad LoginTicket"""
        client, params = get_login_page_params()
        params["username"] = settings.CAS_TEST_USER
        params["password"] = settings.CAS_TEST_PASSWORD
        params["lt"] = 'LT-random'

        response = client.post('/login', params)

        self.assert_login_failed(client, response)
        self.assertIn(b"Invalid login ticket", response.content)

    def test_login_view_post_badpass_good_lt(self):
        """Login attempt with a bad password"""
        client, params = get_login_page_params()
        params["username"] = settings.CAS_TEST_USER
        params["password"] = "test2"
        response = client.post('/login', params)

        self.assert_login_failed(client, response)
        self.assertIn(
            (
                b"The credentials you provided cannot be "
                b"determined to be authentic"
            ),
            response.content
        )

    def assert_ticket_attributes(self, client, ticket_value):
        """
        check the ticket attributes in the db

        :param client: the authenticated client the ticket was emitted for
        :param str ticket_value: the value of the service ticket to check
        """
        user = models.User.objects.get(
            username=settings.CAS_TEST_USER,
            session_key=client.session.session_key
        )
        self.assertTrue(user)
        ticket = models.ServiceTicket.objects.get(value=ticket_value)
        self.assertEqual(ticket.user, user)
        self.assertEqual(ticket.attributs, settings.CAS_TEST_ATTRIBUTES)
        self.assertEqual(ticket.validate, False)
        self.assertEqual(ticket.service_pattern, self.service_pattern)

    def assert_service_ticket(self, client, response):
        """check that a ticket is well emitted when requested on an allowed service"""
        self.assertEqual(response.status_code, 302)
        self.assertTrue(response.has_header('Location'))
        self.assertTrue(
            response['Location'].startswith(
                "https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
            )
        )

        ticket_value = response['Location'].split('ticket=')[-1]
        self.assert_ticket_attributes(client, ticket_value)

    def test_view_login_get_allowed_service(self):
        """Request a ticket for an allowed service by an unauthenticated client"""
        client = Client()
        response = client.get("/login?service=https://www.example.com")
        self.assertEqual(response.status_code, 200)
        self.assertIn(
            (
                b"Authentication required by service "
                b"example (https://www.example.com)"
            ),
            response.content
        )

    def test_view_login_get_denied_service(self):
        """Request a ticket for a denied service by an unauthenticated client"""
        client = Client()
        response = client.get("/login?service=https://www.example.net")
        self.assertEqual(response.status_code, 200)
        self.assertIn(b"Service https://www.example.net non allowed", response.content)

    def test_view_login_get_auth_allowed_service(self):
        """Request a ticket for an allowed service by an authenticated client"""
        # client is already authenticated
        client = get_auth_client()
        response = client.get("/login?service=https://www.example.com")
        self.assert_service_ticket(client, response)

    def test_view_login_get_auth_allowed_service_warn(self):
        """
        Request a ticket for an allowed service by an authenticated client that
        asked to be warned: a confirmation form is displayed before the ticket
        """
        # client is already authenticated
        client = get_auth_client(warn="on")
        response = client.get("/login?service=https://www.example.com")
        self.assertEqual(response.status_code, 200)
        self.assertIn(
            (
                b"Authentication has been required by service "
                b"example (https://www.example.com)"
            ),
            response.content
        )

        params = copy_form(response.context["form"])
        response = client.post("/login", params)
        self.assert_service_ticket(client, response)

    def test_view_login_get_auth_denied_service(self):
        """Request a ticket for a not allowed service by an authenticated client"""
        client = get_auth_client()
        response = client.get("/login?service=https://www.example.org")
        self.assertEqual(response.status_code, 200)
        self.assertIn(b"Service https://www.example.org non allowed", response.content)

    def test_user_logged_not_in_db(self):
        """If the user is logged but has been deleted from the database, it should be logged out"""
        client = get_auth_client()
        models.User.objects.get(
            username=settings.CAS_TEST_USER,
            session_key=client.session.session_key
        ).delete()
        response = client.get("/login")

        self.assert_login_failed(client, response, code=302)
        # the redirect Location is absolute before Django 1.9 and relative afterwards
        if django.VERSION < (1, 9):
            self.assertEqual(response["Location"], "http://testserver/login")
        else:
            self.assertEqual(response["Location"], "/login?")

    def test_service_restrict_user(self):
        """Testing the restrict user capability for a service"""
        service = "https://restrict_user_fail.example.com"
        client = get_auth_client()
        response = client.get("/login", {'service': service})
        self.assertEqual(response.status_code, 200)
        self.assertIn(b"Username non allowed", response.content)

        service = "https://restrict_user_success.example.com"
        response = client.get("/login", {'service': service})
        self.assertEqual(response.status_code, 302)
        self.assertTrue(response["Location"].startswith("%s?ticket=" % service))

    def test_service_filter(self):
        """Test the filtering on user attributes"""
        service = "https://filter_fail.example.com"
        client = get_auth_client()
        response = client.get("/login", {'service': service})
        self.assertEqual(response.status_code, 200)
        self.assertIn(b"User charateristics non allowed", response.content)

        service = "https://filter_success.example.com"
        response = client.get("/login", {'service': service})
        self.assertEqual(response.status_code, 302)
        self.assertTrue(response["Location"].startswith("%s?ticket=" % service))

    def test_service_user_field(self):
        """Test using a user attribute as username: case on if the attribute exists or not"""
        service = "https://field_needed_fail.example.com"
        client = get_auth_client()
        response = client.get("/login", {'service': service})
        self.assertEqual(response.status_code, 200)
        self.assertIn(b"The attribut uid is needed to use that service", response.content)

        service = "https://field_needed_success.example.com"
        response = client.get("/login", {'service': service})
        self.assertEqual(response.status_code, 302)
        self.assertTrue(response["Location"].startswith("%s?ticket=" % service))

    def test_gateway(self):
        """test gateway parameter"""

        # First with an authenticated client that fails to get a ticket for a service
        service = "https://restrict_user_fail.example.com"
        client = get_auth_client()
        response = client.get("/login", {'service': service, 'gateway': 'on'})
        self.assertEqual(response.status_code, 302)
        self.assertEqual(response["Location"], service)

        # second, for a user not yet authenticated on a valid service
        client = Client()
        response = client.get('/login', {'service': service, 'gateway': 'on'})
        self.assertEqual(response.status_code, 302)
        self.assertEqual(response["Location"], service)

    def test_renew(self):
        """test the authentication renewal request from a service"""
        service = "https://www.example.com"
        client = get_auth_client()
        response = client.get("/login", {'service': service, 'renew': 'on'})
        # we are asked to reauthenticate
        self.assertEqual(response.status_code, 200)
        self.assertIn(
            (
                b"Authentication renewal required by "
                b"service example (https://www.example.com)"
            ),
            response.content
        )
        params = copy_form(response.context["form"])
        params["username"] = settings.CAS_TEST_USER
        params["password"] = settings.CAS_TEST_PASSWORD
        self.assertEqual(params["renew"], True)
        response = client.post("/login", params)
        self.assertEqual(response.status_code, 302)
        ticket_value = response['Location'].split('ticket=')[-1]
        ticket = models.ServiceTicket.objects.get(value=ticket_value)
        # the created ticket is marked as being gotten after a renew
        self.assertEqual(ticket.renew, True)

    @staticmethod
    def ajax_login(client, params=None):
        """
        GET /login as an ajax request (``X-AJAX: on`` header)

        :param client: the test client doing the request
        :param dict params: optional query string parameters
        :return: the decoded JSON answer of the view
        """
        response = client.get("/login", params or {}, HTTP_X_AJAX='on')
        return json.loads(response.content.decode("utf8"))

    def test_ajax_login_required(self):
        """test ajax, login required"""
        data = self.ajax_login(Client())
        self.assertEqual(data["status"], "error")
        self.assertEqual(data["detail"], "login required")
        self.assertEqual(data["url"], "/login?")

    def test_ajax_logged_user_deleted(self):
        """test ajax user logged deleted: login required"""
        client = get_auth_client()
        user = models.User.objects.get(
            username=settings.CAS_TEST_USER,
            session_key=client.session.session_key
        )
        user.delete()
        data = self.ajax_login(client)
        self.assertEqual(data["status"], "error")
        self.assertEqual(data["detail"], "login required")
        self.assertEqual(data["url"], "/login?")

    def test_ajax_logged(self):
        """test ajax user is successfully logged"""
        data = self.ajax_login(get_auth_client())
        self.assertEqual(data["status"], "success")
        self.assertEqual(data["detail"], "logged")

    def test_ajax_get_ticket_success(self):
        """test ajax retrieve a ticket for an allowed service"""
        service = "https://www.example.com"
        data = self.ajax_login(get_auth_client(), {'service': service})
        self.assertEqual(data["status"], "success")
        self.assertEqual(data["detail"], "auth")
        self.assertTrue(data["url"].startswith('%s?ticket=' % service))

    def test_ajax_get_ticket_fail(self):
        """test ajax retrieve a ticket for a denied service"""
        service = "https://www.example.org"
        data = self.ajax_login(get_auth_client(), {'service': service})
        self.assertEqual(data["status"], "error")
        self.assertEqual(data["detail"], "auth")
        self.assertEqual(data["messages"][0]["level"], "error")
        self.assertEqual(
            data["messages"][0]["message"],
            "Service https://www.example.org non allowed."
        )

    def test_ajax_get_ticket_warn(self):
        """test get a ticket but user asked to be warned"""
        service = "https://www.example.com"
        data = self.ajax_login(get_auth_client(warn="on"), {'service': service})
        self.assertEqual(data["status"], "error")
        self.assertEqual(data["detail"], "confirmation needed")
|
|
|
|
|
|
class LogoutTestCase(TestCase):
    """Tests for the logout view"""

    # message displayed by the login page to an authenticated user
    LOGGED_MSG = (
        b"You have successfully logged into "
        b"the Central Authentication Service"
    )

    def setUp(self):
        """Use the test authentication backend"""
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'

    def _assert_logged_out(self, client):
        """Check that the login page no longer considers `client` as authenticated"""
        login_page = client.get("/login")
        self.assertEqual(login_page.status_code, 200)
        self.assertFalse(self.LOGGED_MSG in login_page.content)

    def _assert_redirect(self, response, url):
        """Check that `response` is a redirection to `url`"""
        self.assertEqual(response.status_code, 302)
        self.assertTrue(response.has_header("Location"))
        self.assertEqual(response["Location"], url)

    def test_logout_view(self):
        """A plain logout displays a confirmation page and ends the session"""
        client = get_auth_client()

        login_page = client.get("/login")
        self.assertEqual(login_page.status_code, 200)
        self.assertTrue(self.LOGGED_MSG in login_page.content)

        logout_page = client.get("/logout")
        self.assertEqual(logout_page.status_code, 200)
        self.assertTrue(
            (
                b"You have successfully logged out from "
                b"the Central Authentication Service"
            ) in logout_page.content
        )

        self._assert_logged_out(client)

    def test_logout_view_url(self):
        """A logout with a `url` parameter redirects to that url and ends the session"""
        client = get_auth_client()
        self._assert_redirect(
            client.get('/logout?url=https://www.example.com'),
            "https://www.example.com"
        )
        self._assert_logged_out(client)

    def test_logout_view_service(self):
        """A logout with a `service` parameter redirects to that service and ends the session"""
        client = get_auth_client()
        self._assert_redirect(
            client.get('/logout?service=https://www.example.com'),
            "https://www.example.com"
        )
        self._assert_logged_out(client)
|
|
class AuthTestCase(TestCase):
    """
    Tests for the /auth view, which answers ``yes`` or ``no`` to a POSTed
    username, password, service and shared secret
    """

    def setUp(self):
        """
        Use the test authentication backend, register a service pattern for
        https://www.example.com and make sure CAS_AUTH_SHARED_SECRET gets its
        previous value back once each test is done
        """
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
        # every test below overwrites this global setting (one of them with None):
        # restore it so the value does not leak into other test cases
        self.addCleanup(
            setattr,
            settings,
            'CAS_AUTH_SHARED_SECRET',
            getattr(settings, 'CAS_AUTH_SHARED_SECRET', None)
        )
        self.service = 'https://www.example.com'
        models.ServicePattern.objects.create(
            name="example",
            pattern=r"^https://www\.example\.com(/.*)?$"
        )

    def post_auth(self, **update):
        """
        POST the test user credentials to /auth

        :param update: values overriding the default POST parameters (good
            username and password, ``self.service`` and the ``test`` secret)
        :return: the response of the view
        """
        params = {
            'username': settings.CAS_TEST_USER,
            'password': settings.CAS_TEST_PASSWORD,
            'service': self.service,
            'secret': 'test'
        }
        params.update(update)
        return Client().post('/auth', params)

    def test_auth_view_goodpass(self):
        """Good credentials, allowed service and right secret: the answer is yes"""
        settings.CAS_AUTH_SHARED_SECRET = 'test'
        response = self.post_auth()
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, b'yes\n')

    def test_auth_view_badpass(self):
        """A bad password is answered with no"""
        settings.CAS_AUTH_SHARED_SECRET = 'test'
        response = self.post_auth(password='badpass')
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, b'no\n')

    def test_auth_view_badservice(self):
        """A service matching no service pattern is answered with no"""
        settings.CAS_AUTH_SHARED_SECRET = 'test'
        response = self.post_auth(service='https://www.example.org')
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, b'no\n')

    def test_auth_view_badsecret(self):
        """A wrong shared secret is answered with no"""
        settings.CAS_AUTH_SHARED_SECRET = 'test'
        response = self.post_auth(secret='badsecret')
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, b'no\n')

    def test_auth_view_badsettings(self):
        """Without a configured shared secret the view refuses and explains why"""
        settings.CAS_AUTH_SHARED_SECRET = None
        response = self.post_auth()
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, b"no\nplease set CAS_AUTH_SHARED_SECRET")
|
|
|
|
|
class ValidateTestCase(TestCase):
    """Tests for the /validate view (CAS 1.0 ticket validation)"""

    def setUp(self):
        """
        Use the test authentication backend and register a service pattern for
        https://www.example.com returning all user attributes
        """
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
        self.service = 'https://www.example.com'
        self.service_pattern = models.ServicePattern.objects.create(
            name="example",
            pattern=r"^https://www\.example\.com(/.*)?$"
        )
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)

    def test_validate_view_ok(self):
        """A valid ticket for its service is answered with yes and the username"""
        ticket = get_user_ticket_request(self.service)[1]

        client = Client()
        response = client.get('/validate', {'ticket': ticket.value, 'service': self.service})
        self.assertEqual(response.status_code, 200)
        # use the configured test user rather than a hard-coded username
        self.assertEqual(
            response.content,
            ("yes\n%s\n" % settings.CAS_TEST_USER).encode("utf8")
        )

    def test_validate_view_badservice(self):
        """A valid ticket presented for another service is answered with no"""
        ticket = get_user_ticket_request(self.service)[1]

        client = Client()
        response = client.get(
            '/validate',
            {'ticket': ticket.value, 'service': "https://www.example.org"}
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, b'no\n')

    def test_validate_view_badticket(self):
        """An unknown ticket with a valid prefix is answered with no"""
        get_user_ticket_request(self.service)

        client = Client()
        response = client.get(
            '/validate',
            {'ticket': "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX, 'service': self.service}
        )
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, b'no\n')
|
|
|
|
|
|
class ValidateServiceTestCase(TestCase):
    """Tests for the /serviceValidate view (CAS 2.0 ticket validation)"""

    # XML namespace of the CAS protocol responses
    CAS_NS = "http://www.yale.edu/tp/cas"
    NSMAP = {'cas': CAS_NS}

    def setUp(self):
        """
        Use the test authentication backend and register a service pattern for
        http://127.0.0.1 (any port) returning all user attributes and allowing
        proxy callbacks
        """
        settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
        self.service = 'http://127.0.0.1:45678'
        self.service_pattern = models.ServicePattern.objects.create(
            name="localhost",
            pattern=r"^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
            proxy_callback=True
        )
        models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)

    def assert_validation_failure(self, response, code, text):
        """
        Check that `response` holds exactly one cas:authenticationFailure element

        :param response: the response of the /serviceValidate view
        :param str code: the expected value of the ``code`` XML attribute
        :param str text: the expected text of the element
        """
        self.assertEqual(response.status_code, 200)

        root = etree.fromstring(response.content)
        error = root.xpath("//cas:authenticationFailure", namespaces=self.NSMAP)
        self.assertEqual(len(error), 1)
        self.assertEqual(error[0].attrib['code'], code)
        self.assertEqual(error[0].text, text)

    def test_validate_service_view_ok(self):
        """A valid ticket returns the user and all its attributes, in both formats"""
        ticket = get_user_ticket_request(self.service)[1]

        client = Client()
        response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': self.service})
        self.assertEqual(response.status_code, 200)

        root = etree.fromstring(response.content)
        success = root.xpath("//cas:authenticationSuccess", namespaces=self.NSMAP)
        self.assertTrue(success)

        users = root.xpath("//cas:user", namespaces=self.NSMAP)
        self.assertEqual(len(users), 1)
        self.assertEqual(users[0].text, settings.CAS_TEST_USER)

        # attributes as children of a single cas:attributes element...
        attributes = root.xpath("//cas:attributes", namespaces=self.NSMAP)
        self.assertEqual(len(attributes), 1)
        # child tags look like "{namespace}name": drop the namespace and its braces
        ns_prefix_len = len(self.CAS_NS) + 2
        attrs1 = set((attr.tag[ns_prefix_len:], attr.text) for attr in attributes[0])

        # ...and as cas:attribute elements carrying name/value XML attributes
        attributes = root.xpath("//cas:attribute", namespaces=self.NSMAP)
        self.assertEqual(len(attributes), len(attrs1))
        attrs2 = set((attr.attrib['name'], attr.attrib['value']) for attr in attributes)

        # list valued attributes are expected as one (name, value) pair per item
        original = set()
        for key, value in settings.CAS_TEST_ATTRIBUTES.items():
            if isinstance(value, list):
                for sub_value in value:
                    original.add((key, sub_value))
            else:
                original.add((key, value))
        self.assertEqual(attrs1, attrs2)
        self.assertEqual(attrs1, original)

    def test_validate_service_view_badservice(self):
        """A valid ticket presented for another service fails with INVALID_SERVICE"""
        ticket = get_user_ticket_request(self.service)[1]

        client = Client()
        bad_service = "https://www.example.org"
        response = client.get('/serviceValidate', {'ticket': ticket.value, 'service': bad_service})
        self.assert_validation_failure(response, "INVALID_SERVICE", bad_service)

    def test_validate_service_view_badticket_goodprefix(self):
        """An unknown ticket with a valid prefix fails with INVALID_TICKET"""
        get_user_ticket_request(self.service)

        client = Client()
        bad_ticket = "%s-RANDOM" % settings.CAS_SERVICE_TICKET_PREFIX
        response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
        self.assert_validation_failure(response, "INVALID_TICKET", 'ticket not found')

    def test_validate_service_view_badticket_badprefix(self):
        """A ticket with an unknown prefix fails with INVALID_TICKET"""
        get_user_ticket_request(self.service)

        client = Client()
        bad_ticket = "RANDOM"
        response = client.get('/serviceValidate', {'ticket': bad_ticket, 'service': self.service})
        self.assert_validation_failure(response, "INVALID_TICKET", bad_ticket)

    def test_validate_service_view_ok_pgturl(self):
        """Validating with a pgtUrl sends a PGT to the callback and returns the matching PGTIOU"""
        host, port = utils.PGTUrlHandler.run()[1:3]
        service = "http://%s:%s" % (host, port)

        ticket = get_user_ticket_request(service)[1]

        client = Client()
        response = client.get(
            '/serviceValidate',
            {'ticket': ticket.value, 'service': service, 'pgtUrl': service}
        )
        pgt_params = utils.PGTUrlHandler.PARAMS.copy()
        self.assertEqual(response.status_code, 200)

        root = etree.fromstring(response.content)
        pgtiou = root.xpath("//cas:proxyGrantingTicket", namespaces=self.NSMAP)
        self.assertEqual(len(pgtiou), 1)
        self.assertEqual(pgt_params["pgtIou"], pgtiou[0].text)
        self.assertIn("pgtId", pgt_params)

    def test_validate_service_pgturl_bad_proxy_callback(self):
        """A pgtUrl on a service pattern without proxy_callback fails with INVALID_PROXY_CALLBACK"""
        self.service_pattern.proxy_callback = False
        self.service_pattern.save()
        ticket = get_user_ticket_request(self.service)[1]

        client = Client()
        response = client.get(
            '/serviceValidate',
            {'ticket': ticket.value, 'service': self.service, 'pgtUrl': self.service}
        )
        self.assert_validation_failure(
            response,
            "INVALID_PROXY_CALLBACK",
            "callback url not allowed by configuration"
        )
|
|
|
|
|
|
class ProxyTestCase(TestCase):
|
|
|
|
def setUp(self):
    """Select the test auth backend and create a proxy-enabled service pattern.

    The pattern matches ``http://127.0.0.1`` with an optional port and path. It is
    created with ``proxy=True`` and ``proxy_callback=True``. A
    ``ReplaceAttributName`` with ``name="*"`` is attached to it; presumably this
    releases every user attribute to the service (TODO confirm in models).
    """
    # NOTE(review): this mutates the shared settings object and is never restored.
    settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
    self.service = 'http://127.0.0.1'
    self.service_pattern = models.ServicePattern.objects.create(
        name="localhost",
        # Raw string: "\." is an invalid escape in a plain literal (DeprecationWarning,
        # SyntaxWarning on Python 3.12+). The resulting pattern value is unchanged.
        pattern=r"^http://127\.0\.0\.1(:[0-9]+)?(/.*)?$",
        proxy=True,
        proxy_callback=True
    )
    models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
|
|
|
|
def test_validate_proxy_ok(self):
    """Run a full proxy round trip: PGT, then ``/proxy``, then ``/proxyValidate``.

    Obtains a PGT via ``get_pgt()`` and asks ``/proxy`` for a proxy ticket for
    ``self.service``. It then validates that ticket on ``/proxyValidate`` and checks
    four things:

    - the proxy chain contains exactly the PGT's service;
    - the user is ``CAS_TEST_USER``;
    - the ``cas:attributes`` children match the ``cas:attribute`` name/value elements;
    - both attribute forms equal ``CAS_TEST_ATTRIBUTES``, with list values expanded
      to one pair per element.
    """
    ns = {'cas': "http://www.yale.edu/tp/cas"}
    params = get_pgt()

    # get a proxy ticket
    client1 = Client()
    response = client1.get('/proxy', {'pgt': params['pgtId'], 'targetService': self.service})
    self.assertEqual(response.status_code, 200)

    root = etree.fromstring(response.content)
    success = root.xpath("//cas:proxySuccess", namespaces=ns)
    self.assertTrue(success)

    proxy_ticket = root.xpath("//cas:proxyTicket", namespaces=ns)
    self.assertEqual(len(proxy_ticket), 1)
    proxy_ticket = proxy_ticket[0].text

    # validate the proxy ticket
    client2 = Client()
    response = client2.get('/proxyValidate', {'ticket': proxy_ticket, 'service': self.service})
    self.assertEqual(response.status_code, 200)

    root = etree.fromstring(response.content)
    success = root.xpath("//cas:authenticationSuccess", namespaces=ns)
    self.assertTrue(success)

    # check that the proxy is sent to the end service
    proxies = root.xpath("//cas:proxies", namespaces=ns)
    self.assertEqual(len(proxies), 1)
    # Relative path on purpose: "//cas:proxy" is absolute and would search the whole
    # document instead of only the children of this <cas:proxies> element.
    proxy = proxies[0].xpath("cas:proxy", namespaces=ns)
    self.assertEqual(len(proxy), 1)
    self.assertEqual(proxy[0].text, params["service"])

    # same tests as those for serviceValidate
    users = root.xpath("//cas:user", namespaces=ns)
    self.assertEqual(len(users), 1)
    self.assertEqual(users[0].text, settings.CAS_TEST_USER)

    attributes = root.xpath("//cas:attributes", namespaces=ns)
    self.assertEqual(len(attributes), 1)
    # Child tags are in Clark notation "{namespace}name"; strip the "{namespace}" part.
    prefix_len = len(ns['cas']) + 2
    attrs1 = {(attr.tag[prefix_len:], attr.text) for attr in attributes[0]}

    attributes = root.xpath("//cas:attribute", namespaces=ns)
    self.assertEqual(len(attributes), len(attrs1))
    attrs2 = {(attr.attrib['name'], attr.attrib['value']) for attr in attributes}

    # Expected pairs: list-valued settings contribute one pair per element.
    original = set()
    for key, value in settings.CAS_TEST_ATTRIBUTES.items():
        if isinstance(value, list):
            for sub_value in value:
                original.add((key, sub_value))
        else:
            original.add((key, value))
    self.assertEqual(attrs1, attrs2)
    self.assertEqual(attrs1, original)
|
|
|
|
def test_validate_proxy_bad(self):
|
|
params = get_pgt()
|
|
|
|
# bad PGT
|
|
client1 = Client()
|
|
response = client1.get(
|
|
'/proxy',
|
|
{
|
|
'pgt': "%s-RANDOM" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX,
|
|
'targetService': params['service']
|
|
}
|
|
)
|
|
self.assertEqual(response.status_code, 200)
|
|
|
|
root = etree.fromstring(response.content)
|
|
error = root.xpath(
|
|
"//cas:authenticationFailure",
|
|
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
)
|
|
self.assertEqual(len(error), 1)
|
|
self.assertEqual(error[0].attrib['code'], "INVALID_TICKET")
|
|
self.assertEqual(
|
|
error[0].text,
|
|
"PGT %s-RANDOM not found" % settings.CAS_PROXY_GRANTING_TICKET_PREFIX
|
|
)
|
|
|
|
# bad targetService
|
|
client2 = Client()
|
|
response = client2.get(
|
|
'/proxy',
|
|
{'pgt': params['pgtId'], 'targetService': "https://www.example.org"}
|
|
)
|
|
self.assertEqual(response.status_code, 200)
|
|
|
|
root = etree.fromstring(response.content)
|
|
error = root.xpath(
|
|
"//cas:authenticationFailure",
|
|
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
)
|
|
self.assertEqual(len(error), 1)
|
|
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
|
|
self.assertEqual(error[0].text, "https://www.example.org")
|
|
|
|
# service do not allow proxy ticket
|
|
self.service_pattern.proxy = False
|
|
self.service_pattern.save()
|
|
|
|
client3 = Client()
|
|
response = client3.get(
|
|
'/proxy',
|
|
{'pgt': params['pgtId'], 'targetService': params['service']}
|
|
)
|
|
self.assertEqual(response.status_code, 200)
|
|
|
|
root = etree.fromstring(response.content)
|
|
error = root.xpath(
|
|
"//cas:authenticationFailure",
|
|
namespaces={'cas': "http://www.yale.edu/tp/cas"}
|
|
)
|
|
self.assertEqual(len(error), 1)
|
|
self.assertEqual(error[0].attrib['code'], "UNAUTHORIZED_SERVICE")
|
|
self.assertEqual(
|
|
error[0].text,
|
|
'the service %s do not allow proxy ticket' % params['service']
|
|
)
|