More unit tests (essentially for the login view) and some docstrings

This commit is contained in:
Valentin Samir 2016-06-27 23:55:17 +02:00
parent 7db3157864
commit bab79c4de5
8 changed files with 343 additions and 63 deletions

View File

@ -5,3 +5,4 @@ exclude_lines =
def __unicode__ def __unicode__
raise AssertionError raise AssertionError
raise NotImplementedError raise NotImplementedError
if six.PY3:

View File

@ -49,8 +49,9 @@ coverage: test_venv
test_venv/bin/pip install coverage test_venv/bin/pip install coverage
test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests test_venv/bin/coverage run --source='cas_server' --omit='cas_server/migrations*' run_tests
test_venv/bin/coverage html test_venv/bin/coverage html
test_venv/bin/coverage xml rm htmlcov/coverage_html.js # I am really pissed off by those keybord shortcuts
coverage_codacy: coverage coverage_codacy: coverage
test_venv/bin/coverage xml
test_venv/bin/pip install codacy-coverage test_venv/bin/pip install codacy-coverage
test_venv/bin/python-codacy-coverage -r coverage.xml test_venv/bin/python-codacy-coverage -r coverage.xml

View File

@ -219,7 +219,8 @@ Test backend settings. Only usefull if you are using the test authentication bac
* ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``. * ``CAS_TEST_USER``: Username of the test user. The default is ``"test"``.
* ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``. * ``CAS_TEST_PASSWORD``: Password of the test user. The default is ``"test"``.
* ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is * ``CAS_TEST_ATTRIBUTES``: Attributes of the test user. The default is
``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'}``. ``{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']}``.
Authentication backend Authentication backend

View File

@ -78,5 +78,10 @@ setting_default('CAS_TEST_USER', 'test')
setting_default('CAS_TEST_PASSWORD', 'test') setting_default('CAS_TEST_PASSWORD', 'test')
setting_default( setting_default(
'CAS_TEST_ATTRIBUTES', 'CAS_TEST_ATTRIBUTES',
{'nom': 'Nymous', 'prenom': 'Ano', 'email': 'anonymous@example.net'} {
'nom': 'Nymous',
'prenom': 'Ano',
'email': 'anonymous@example.net',
'alias': ['demo1', 'demo2']
}
) )

View File

@ -3,36 +3,49 @@ from .default_settings import settings
from django.test import TestCase from django.test import TestCase
from django.test import Client from django.test import Client
import re
import six import six
import random
from lxml import etree from lxml import etree
from six.moves import range
from cas_server import models from cas_server import models
from cas_server import utils from cas_server import utils
def get_login_page_params(): def copy_form(form):
client = Client() """Copy form value into a dict"""
response = client.get('/login')
form = response.context["form"]
params = {} params = {}
for field in form: for field in form:
if field.value(): if field.value():
params[field.name] = field.value() params[field.name] = field.value()
else: else:
params[field.name] = "" params[field.name] = ""
return params
def get_login_page_params(client=None):
"""Return a client and the POST params for the client to login"""
if client is None:
client = Client()
response = client.get('/login')
params = copy_form(response.context["form"])
return client, params return client, params
def get_auth_client(): def get_auth_client(**update):
"""return a authenticated client"""
client, params = get_login_page_params() client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD params["password"] = settings.CAS_TEST_PASSWORD
params.update(update)
client.post('/login', params) client.post('/login', params)
return client return client
def get_user_ticket_request(service): def get_user_ticket_request(service):
"""Make an auth client to request a ticket for `service`, return the tuple (user, ticket)"""
client = get_auth_client() client = get_auth_client()
response = client.get("/login", {"service": service}) response = client.get("/login", {"service": service})
ticket_value = response['Location'].split('ticket=')[-1] ticket_value = response['Location'].split('ticket=')[-1]
@ -45,6 +58,7 @@ def get_user_ticket_request(service):
def get_pgt(): def get_pgt():
"""return a dict contening a service, user and PGT ticket for this service"""
(host, port) = utils.PGTUrlHandler.run()[1:3] (host, port) = utils.PGTUrlHandler.run()[1:3]
service = "http://%s:%s" % (host, port) service = "http://%s:%s" % (host, port)
@ -110,7 +124,7 @@ class CheckPasswordCase(TestCase):
self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8")) self.assertTrue(utils.check_password("hex_md5", self.password1, hashed_password1, "utf8"))
self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8")) self.assertFalse(utils.check_password("hex_md5", self.password2, hashed_password1, "utf8"))
def test_hox_sha512(self): def test_hex_sha512(self):
"""test the hex_sha512 auth method""" """test the hex_sha512 auth method"""
hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest() hashed_password1 = utils.hashlib.sha512(self.password1).hexdigest()
@ -123,29 +137,83 @@ class CheckPasswordCase(TestCase):
class LoginTestCase(TestCase): class LoginTestCase(TestCase):
"""Tests for the login view"""
def setUp(self): def setUp(self):
"""
Prepare the test context:
* set the auth class to 'cas_server.auth.TestAuthUser'
* create a service pattern for https://www.example.com/**
* Set the service pattern to return all user attributes
"""
settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser' settings.CAS_AUTH_CLASS = 'cas_server.auth.TestAuthUser'
# For general purpose testing
self.service_pattern = models.ServicePattern.objects.create( self.service_pattern = models.ServicePattern.objects.create(
name="example", name="example",
pattern="^https://www\.example\.com(/.*)?$", pattern="^https://www\.example\.com(/.*)?$",
) )
models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern) models.ReplaceAttributName.objects.create(name="*", service_pattern=self.service_pattern)
def test_login_view_post_goodpass_goodlt(self): # For testing the restrict_users attributes
client, params = get_login_page_params() self.service_pattern_restrict_user_fail = models.ServicePattern.objects.create(
params["username"] = settings.CAS_TEST_USER name="restrict_user_fail",
params["password"] = settings.CAS_TEST_PASSWORD pattern="^https://restrict_user_fail\.example\.com(/.*)?$",
restrict_users=True,
)
self.service_pattern_restrict_user_success = models.ServicePattern.objects.create(
name="restrict_user_success",
pattern="^https://restrict_user_success\.example\.com(/.*)?$",
restrict_users=True,
)
models.Username.objects.create(
value=settings.CAS_TEST_USER,
service_pattern=self.service_pattern_restrict_user_success
)
response = client.post('/login', params) # For testing the user attributes filtering conditions
self.service_pattern_filter_fail = models.ServicePattern.objects.create(
name="filter_fail",
pattern="^https://filter_fail\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="right",
pattern="^admin$",
service_pattern=self.service_pattern_filter_fail
)
self.service_pattern_filter_success = models.ServicePattern.objects.create(
name="filter_success",
pattern="^https://filter_success\.example\.com(/.*)?$",
)
models.FilterAttributValue.objects.create(
attribut="email",
pattern="^%s$" % re.escape(settings.CAS_TEST_ATTRIBUTES['email']),
service_pattern=self.service_pattern_filter_success
)
self.assertEqual(response.status_code, 200) # For testing the user_field attributes
self.service_pattern_field_needed_fail = models.ServicePattern.objects.create(
name="field_needed_fail",
pattern="^https://field_needed_fail\.example\.com(/.*)?$",
user_field="uid"
)
self.service_pattern_field_needed_success = models.ServicePattern.objects.create(
name="field_needed_success",
pattern="^https://field_needed_success\.example\.com(/.*)?$",
user_field="nom"
)
def assert_logged(self, client, response, warn=False, code=200):
"""Assertions testing that client is well authenticated"""
self.assertEqual(response.status_code, code)
self.assertTrue( self.assertTrue(
( (
b"You have successfully logged into " b"You have successfully logged into "
b"the Central Authentication Service" b"the Central Authentication Service"
) in response.content ) in response.content
) )
self.assertTrue(client.session["username"] == settings.CAS_TEST_USER)
self.assertTrue(client.session["warn"] is warn)
self.assertTrue(client.session["authenticated"] is True)
self.assertTrue( self.assertTrue(
models.User.objects.get( models.User.objects.get(
@ -154,7 +222,59 @@ class LoginTestCase(TestCase):
) )
) )
def assert_login_failed(self, client, response, code=200):
"""Assertions testing a failed login attempt"""
self.assertEqual(response.status_code, code)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
self.assertTrue(client.session.get("username") is None)
self.assertTrue(client.session.get("warn") is None)
self.assertTrue(client.session.get("authenticated") is None)
def test_login_view_post_goodpass_goodlt(self):
"""Test a successul login"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
self.assertTrue(params['lt'] in client.session['lt'])
response = client.post('/login', params)
self.assert_logged(client, response)
# LoginTicket conssumed
self.assertTrue(params['lt'] not in client.session['lt'])
def test_login_view_post_goodpass_goodlt_warn(self):
"""Test a successul login requesting to be warned before creating services tickets"""
client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD
params["warn"] = "on"
response = client.post('/login', params)
self.assert_logged(client, response, warn=True)
def test_lt_max(self):
"""Check we only keep the last 100 Login Ticket for a user"""
client, params = get_login_page_params()
current_lt = params["lt"]
i_in_test = random.randint(0, 100)
i_not_in_test = random.randint(100, 150)
for i in range(150):
if i == i_in_test:
self.assertTrue(current_lt in client.session['lt'])
if i == i_not_in_test:
self.assertTrue(current_lt not in client.session['lt'])
self.assertTrue(len(client.session['lt']) <= 100)
client, params = get_login_page_params(client)
self.assertTrue(len(client.session['lt']) <= 100)
def test_login_view_post_badlt(self): def test_login_view_post_badlt(self):
"""Login attempt with a bad LoginTicket"""
client, params = get_login_page_params() client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER params["username"] = settings.CAS_TEST_USER
params["password"] = settings.CAS_TEST_PASSWORD params["password"] = settings.CAS_TEST_PASSWORD
@ -162,47 +282,26 @@ class LoginTestCase(TestCase):
response = client.post('/login', params) response = client.post('/login', params)
self.assertEqual(response.status_code, 200) self.assert_login_failed(client, response)
self.assertTrue(b"Invalid login ticket" in response.content) self.assertTrue(b"Invalid login ticket" in response.content)
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_login_view_post_badpass_good_lt(self): def test_login_view_post_badpass_good_lt(self):
"""Login attempt with a bad password"""
client, params = get_login_page_params() client, params = get_login_page_params()
params["username"] = settings.CAS_TEST_USER params["username"] = settings.CAS_TEST_USER
params["password"] = "test2" params["password"] = "test2"
response = client.post('/login', params) response = client.post('/login', params)
self.assertEqual(response.status_code, 200) self.assert_login_failed(client, response)
self.assertTrue( self.assertTrue(
( (
b"The credentials you provided cannot be " b"The credentials you provided cannot be "
b"determined to be authentic" b"determined to be authentic"
) in response.content ) in response.content
) )
self.assertFalse(
(
b"You have successfully logged into "
b"the Central Authentication Service"
) in response.content
)
def test_view_login_get_auth_allowed_service(self): def assert_ticket_attributes(self, client, ticket_value):
client = get_auth_client() """check the ticket attributes in the db"""
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header('Location'))
self.assertTrue(
response['Location'].startswith(
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
)
)
ticket_value = response['Location'].split('ticket=')[-1]
user = models.User.objects.get( user = models.User.objects.get(
username=settings.CAS_TEST_USER, username=settings.CAS_TEST_USER,
session_key=client.session.session_key session_key=client.session.session_key
@ -214,12 +313,136 @@ class LoginTestCase(TestCase):
self.assertEqual(ticket.validate, False) self.assertEqual(ticket.validate, False)
self.assertEqual(ticket.service_pattern, self.service_pattern) self.assertEqual(ticket.service_pattern, self.service_pattern)
def assert_service_ticket(self, client, response):
"""check that a ticket is well emited when requested on a allowed service"""
self.assertEqual(response.status_code, 302)
self.assertTrue(response.has_header('Location'))
self.assertTrue(
response['Location'].startswith(
"https://www.example.com?ticket=%s-" % settings.CAS_SERVICE_TICKET_PREFIX
)
)
ticket_value = response['Location'].split('ticket=')[-1]
self.assert_ticket_attributes(client, ticket_value)
def test_view_login_get_allowed_service(self):
"""Request a ticket for an allowed service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
"Authentication required by service "
"example (https://www.example.com)"
) in response.content
)
def test_view_login_get_denied_service(self):
"""Request a ticket for an denied service by an unauthenticated client"""
client = Client()
response = client.get("/login?service=https://www.example.net")
self.assertEqual(response.status_code, 200)
self.assertTrue("Service https://www.example.net non allowed" in response.content)
def test_view_login_get_auth_allowed_service(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client()
response = client.get("/login?service=https://www.example.com")
self.assert_service_ticket(client, response)
def test_view_login_get_auth_allowed_service_warn(self):
"""Request a ticket for an allowed service by an authenticated client"""
# client is already authenticated
client = get_auth_client(warn="on")
response = client.get("/login?service=https://www.example.com")
self.assertEqual(response.status_code, 200)
self.assertTrue(
(
"Authentication has been required by service "
"example (https://www.example.com)"
) in response.content
)
params = copy_form(response.context["form"])
response = client.post("/login", params)
self.assert_service_ticket(client, response)
def test_view_login_get_auth_denied_service(self): def test_view_login_get_auth_denied_service(self):
"""Request a ticket for a not allowed service by an authenticated client"""
client = get_auth_client() client = get_auth_client()
response = client.get("/login?service=https://www.example.org") response = client.get("/login?service=https://www.example.org")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
self.assertTrue(b"Service https://www.example.org non allowed" in response.content) self.assertTrue(b"Service https://www.example.org non allowed" in response.content)
def test_user_logged_not_in_db(self):
"""If the user is logged but has been delete from the database, it should be logged out"""
client = get_auth_client()
models.User.objects.get(
username=settings.CAS_TEST_USER,
session_key=client.session.session_key
).delete()
response = client.get("/login")
self.assert_login_failed(client, response, code=302)
self.assertEqual(response["Location"], "/login?")
def test_service_restrict_user(self):
"""Testing the restric user capability fro a service"""
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue("Username non allowed" in response.content)
service = "https://restrict_user_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_filter(self):
"""Test the filtering on user attributes"""
service = "https://filter_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue("User charateristics non allowed" in response.content)
service = "https://filter_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_service_user_field(self):
"""Test using a user attribute as username: case on if the attribute exists or not"""
service = "https://field_needed_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 200)
self.assertTrue("The attribut uid is needed to use that service" in response.content)
service = "https://field_needed_success.example.com"
response = client.get("/login", {'service': service})
self.assertEqual(response.status_code, 302)
self.assertTrue(response["Location"].startswith("%s?ticket=" % service))
def test_gateway(self):
"""test gateway parameter"""
# First with an authenticated client that fail to get a ticket for a service
service = "https://restrict_user_fail.example.com"
client = get_auth_client()
response = client.get("/login", {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
# second for an user not yet authenticated on a valid service
client = Client()
response = client.get('/login', {'service': service, 'gateway': 'on'})
self.assertEqual(response.status_code, 302)
self.assertEqual(response["Location"], service)
class LogoutTestCase(TestCase): class LogoutTestCase(TestCase):
@ -454,17 +677,24 @@ class ValidateServiceTestCase(TestCase):
namespaces={'cas': "http://www.yale.edu/tp/cas"} namespaces={'cas': "http://www.yale.edu/tp/cas"}
) )
self.assertEqual(len(attributes), 1) self.assertEqual(len(attributes), 1)
attrs1 = {} attrs1 = set()
for attr in attributes[0]: for attr in attributes[0]:
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1)) self.assertEqual(len(attributes), len(attrs1))
attrs2 = {} attrs2 = set()
for attr in attributes: for attr in attributes:
attrs2[attr.attrib['name']] = attr.attrib['value'] attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for v in value:
original.add((key, v))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2) self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) self.assertEqual(attrs1, original)
def test_validate_service_view_badservice(self): def test_validate_service_view_badservice(self):
ticket = get_user_ticket_request(self.service)[1] ticket = get_user_ticket_request(self.service)[1]
@ -623,17 +853,24 @@ class ProxyTestCase(TestCase):
namespaces={'cas': "http://www.yale.edu/tp/cas"} namespaces={'cas': "http://www.yale.edu/tp/cas"}
) )
self.assertEqual(len(attributes), 1) self.assertEqual(len(attributes), 1)
attrs1 = {} attrs1 = set()
for attr in attributes[0]: for attr in attributes[0]:
attrs1[attr.tag[len("http://www.yale.edu/tp/cas")+2:]] = attr.text attrs1.add((attr.tag[len("http://www.yale.edu/tp/cas")+2:], attr.text))
attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"}) attributes = root.xpath("//cas:attribute", namespaces={'cas': "http://www.yale.edu/tp/cas"})
self.assertEqual(len(attributes), len(attrs1)) self.assertEqual(len(attributes), len(attrs1))
attrs2 = {} attrs2 = set()
for attr in attributes: for attr in attributes:
attrs2[attr.attrib['name']] = attr.attrib['value'] attrs2.add((attr.attrib['name'], attr.attrib['value']))
original = set()
for key, value in settings.CAS_TEST_ATTRIBUTES.items():
if isinstance(value, list):
for v in value:
original.add((key, v))
else:
original.add((key, value))
self.assertEqual(attrs1, attrs2) self.assertEqual(attrs1, attrs2)
self.assertEqual(attrs1, settings.CAS_TEST_ATTRIBUTES) self.assertEqual(attrs1, original)
def test_validate_proxy_bad(self): def test_validate_proxy_bad(self):
params = get_pgt() params = get_pgt()

View File

@ -105,6 +105,7 @@ class LogoutView(View, LogoutMixin):
service = None service = None
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.url = request.GET.get('url') self.url = request.GET.get('url')
@ -196,6 +197,7 @@ class LoginView(View, LogoutMixin):
USER_NOT_AUTHENTICATED = 6 USER_NOT_AUTHENTICATED = 6
def init_post(self, request): def init_post(self, request):
"""Initialize POST received parameters"""
self.request = request self.request = request
self.service = request.POST.get('service') self.service = request.POST.get('service')
self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False") self.renew = bool(request.POST.get('renew') and request.POST['renew'] != "False")
@ -205,15 +207,19 @@ class LoginView(View, LogoutMixin):
if request.POST.get('warned') and request.POST['warned'] != "False": if request.POST.get('warned') and request.POST['warned'] != "False":
self.warned = True self.warned = True
def check_lt(self): def gen_lt(self):
# save LT for later check """Generate a new LoginTicket and add it to the list of valid LT for the user"""
lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed)
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()]
if len(self.request.session['lt']) > 100: if len(self.request.session['lt']) > 100:
self.request.session['lt'] = self.request.session['lt'][-100:] self.request.session['lt'] = self.request.session['lt'][-100:]
def check_lt(self):
"""Check is the POSTed LoginTicket is valid, if yes invalide it"""
# save LT for later check
lt_valid = self.request.session.get('lt', [])
lt_send = self.request.POST.get('lt')
# generate a new LT (by posting the LT has been consumed)
self.gen_lt()
# check if send LT is valid # check if send LT is valid
if lt_valid is None or lt_send not in lt_valid: if lt_valid is None or lt_send not in lt_valid:
return False return False
@ -238,7 +244,7 @@ class LoginView(View, LogoutMixin):
username=self.request.session['username'], username=self.request.session['username'],
session_key=self.request.session.session_key session_key=self.request.session.session_key
) )
self.user.save() self.user.save() # pragma: no cover (should not happend)
except models.User.DoesNotExist: except models.User.DoesNotExist:
self.user = models.User.objects.create( self.user = models.User.objects.create(
username=self.request.session['username'], username=self.request.session['username'],
@ -250,10 +256,15 @@ class LoginView(View, LogoutMixin):
elif ret == self.USER_ALREADY_LOGGED: elif ret == self.USER_ALREADY_LOGGED:
pass pass
else: else:
raise EnvironmentError("invalid output for LoginView.process_post") raise EnvironmentError("invalid output for LoginView.process_post") # pragma: no cover
return self.common() return self.common()
def process_post(self): def process_post(self):
"""
Analyse the POST request:
* check that the LoginTicket is valid
* check that the user sumited credentials are valid
"""
if not self.check_lt(): if not self.check_lt():
values = self.request.POST.copy() values = self.request.POST.copy()
# if not set a new LT and fail # if not set a new LT and fail
@ -280,6 +291,7 @@ class LoginView(View, LogoutMixin):
return self.USER_ALREADY_LOGGED return self.USER_ALREADY_LOGGED
def init_get(self, request): def init_get(self, request):
"""Initialize GET received parameters"""
self.request = request self.request = request
self.service = request.GET.get('service') self.service = request.GET.get('service')
self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False") self.renew = bool(request.GET.get('renew') and request.GET['renew'] != "False")
@ -294,15 +306,16 @@ class LoginView(View, LogoutMixin):
return self.common() return self.common()
def process_get(self): def process_get(self):
# generate a new LT if none is present """Analyse the GET request"""
self.request.session['lt'] = self.request.session.get('lt', []) + [utils.gen_lt()] # generate a new LT
self.gen_lt()
if not self.request.session.get("authenticated") or self.renew: if not self.request.session.get("authenticated") or self.renew:
self.init_form() self.init_form()
return self.USER_NOT_AUTHENTICATED return self.USER_NOT_AUTHENTICATED
return self.USER_AUTHENTICATED return self.USER_AUTHENTICATED
def init_form(self, values=None): def init_form(self, values=None):
"""Initialization of the good form depending of POST and GET parameters"""
self.form = forms.UserCredential( self.form = forms.UserCredential(
values, values,
initial={ initial={

View File

@ -52,7 +52,7 @@ MIDDLEWARE_CLASSES = [
'django.middleware.locale.LocaleMiddleware', 'django.middleware.locale.LocaleMiddleware',
] ]
ROOT_URLCONF = 'cas_server.urls' ROOT_URLCONF = 'urls_tests'
# Database # Database
# https://docs.djangoproject.com/en/1.9/ref/settings/#databases # https://docs.djangoproject.com/en/1.9/ref/settings/#databases

22
urls_tests.py Normal file
View File

@ -0,0 +1,22 @@
"""cas URL Configuration
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/1.9/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: url(r'^$', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: url(r'^$', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.conf.urls import url, include, include
2. Add a URL to urlpatterns: url(r'^blog/', include('blog.urls'))
"""
from django.conf.urls import url, include
from django.contrib import admin
urlpatterns = [
url(r'^admin/', admin.site.urls),
url(r'^', include('cas_server.urls', namespace='cas_server')),
]