Add some docstrings and comments

This commit is contained in:
Valentin Samir 2016-06-29 00:25:09 +02:00
parent 3e53429feb
commit 6972ad7536
7 changed files with 21 additions and 5 deletions

View File

@ -8,5 +8,5 @@
# Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. # Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# #
# (c) 2015 Valentin Samir # (c) 2015 Valentin Samir
"""A django CAS server application"""
default_app_config = 'cas_server.apps.CasAppConfig' default_app_config = 'cas_server.apps.CasAppConfig'

View File

@ -1,7 +1,9 @@
"""django config module"""
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.apps import AppConfig from django.apps import AppConfig
class CasAppConfig(AppConfig): class CasAppConfig(AppConfig):
"""django CAS application config class"""
name = 'cas_server' name = 'cas_server'
verbose_name = _('Central Authentication Service') verbose_name = _('Central Authentication Service')

View File

@ -21,6 +21,7 @@ except ImportError:
class AuthUser(object): class AuthUser(object):
"""Authentication base class"""
def __init__(self, username): def __init__(self, username):
self.username = username self.username = username

View File

@ -19,6 +19,7 @@ import cas_server.models as models
class WarnForm(forms.Form): class WarnForm(forms.Form):
"""Form used on warn page before emiting a ticket"""
service = forms.CharField(widget=forms.HiddenInput(), required=False) service = forms.CharField(widget=forms.HiddenInput(), required=False)
renew = forms.BooleanField(widget=forms.HiddenInput(), required=False) renew = forms.BooleanField(widget=forms.HiddenInput(), required=False)
gateway = forms.CharField(widget=forms.HiddenInput(), required=False) gateway = forms.CharField(widget=forms.HiddenInput(), required=False)

View File

@ -47,6 +47,7 @@ class User(models.Model):
@classmethod @classmethod
def clean_old_entries(cls): def clean_old_entries(cls):
"""Remove users inactive since more that SESSION_COOKIE_AGE"""
users = cls.objects.filter( users = cls.objects.filter(
date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE)) date__lt=(timezone.now() - timedelta(seconds=settings.SESSION_COOKIE_AGE))
) )
@ -56,6 +57,7 @@ class User(models.Model):
@classmethod @classmethod
def clean_deleted_sessions(cls): def clean_deleted_sessions(cls):
"""Remove user where the session do not exists anymore"""
for user in cls.objects.all(): for user in cls.objects.all():
if not SessionStore(session_key=user.session_key).get('authenticated'): if not SessionStore(session_key=user.session_key).get('authenticated'):
user.logout() user.logout()
@ -141,6 +143,7 @@ class User(models.Model):
class ServicePatternException(Exception): class ServicePatternException(Exception):
"""Base exception of exceptions raised in the ServicePattern model"""
pass pass

View File

@ -1,3 +1,4 @@
"""Tests module"""
from .default_settings import settings from .default_settings import settings
import django import django
@ -83,7 +84,7 @@ class CheckPasswordCase(TestCase):
"""Generate random bytes string that will be used ass passwords""" """Generate random bytes string that will be used ass passwords"""
self.password1 = utils.gen_saml_id() self.password1 = utils.gen_saml_id()
self.password2 = utils.gen_saml_id() self.password2 = utils.gen_saml_id()
if not isinstance(self.password1, bytes): if not isinstance(self.password1, bytes): # pragma: no cover executed only in python3
self.password1 = self.password1.encode("utf8") self.password1 = self.password1.encode("utf8")
self.password2 = self.password2.encode("utf8") self.password2 = self.password2.encode("utf8")
@ -403,7 +404,7 @@ class LoginTestCase(TestCase, BaseServicePattern):
response = client.get("/login") response = client.get("/login")
self.assert_login_failed(client, response, code=302) self.assert_login_failed(client, response, code=302)
if django.VERSION < (1, 9): if django.VERSION < (1, 9): # pragma: no cover coverage is computed with dango 1.9
self.assertEqual(response["Location"], "http://testserver/login") self.assertEqual(response["Location"], "http://testserver/login")
else: else:
self.assertEqual(response["Location"], "/login?") self.assertEqual(response["Location"], "/login?")
@ -572,7 +573,7 @@ class LoginTestCase(TestCase, BaseServicePattern):
@override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser') @override_settings(CAS_AUTH_CLASS='cas_server.auth.TestAuthUser')
class LogoutTestCase(TestCase): class LogoutTestCase(TestCase):
"""test fot the logout view"""
def test_logout(self): def test_logout(self):
"""logout is idempotent""" """logout is idempotent"""
client = Client() client = Client()
@ -693,7 +694,7 @@ class LogoutTestCase(TestCase):
response = client.get('/logout') response = client.get('/logout')
self.assertEqual(response.status_code, 302) self.assertEqual(response.status_code, 302)
if django.VERSION < (1, 9): if django.VERSION < (1, 9): # pragma: no cover coverage is computed with dango 1.9
self.assertEqual(response["Location"], "http://testserver/login") self.assertEqual(response["Location"], "http://testserver/login")
else: else:
self.assertEqual(response["Location"], "/login") self.assertEqual(response["Location"], "/login")

View File

@ -30,11 +30,13 @@ from six.moves.urllib.parse import urlparse, urlunparse, parse_qsl, urlencode
def context(params): def context(params):
"""Function that add somes variable to the context before template rendering"""
params["settings"] = settings params["settings"] = settings
return params return params
def json_response(request, data): def json_response(request, data):
"""Wrapper dumping `data` to a json and sending it to the user with an HttpResponse"""
data["messages"] = [] data["messages"] = []
for msg in messages.get_messages(request): for msg in messages.get_messages(request):
data["messages"].append({'message': msg.message, 'level': msg.level_tag}) data["messages"].append({'message': msg.message, 'level': msg.level_tag})
@ -64,6 +66,7 @@ def redirect_params(url_name, params=None):
def reverse_params(url_name, params=None, **kwargs): def reverse_params(url_name, params=None, **kwargs):
"""compule the reverse url or `url_name` and add GET parameters from `params` to it"""
url = reverse(url_name, **kwargs) url = reverse(url_name, **kwargs)
params = urlencode(params if params else {}) params = urlencode(params if params else {})
return url + "?%s" % params return url + "?%s" % params
@ -152,6 +155,7 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
PARAMS = {} PARAMS = {}
def do_GET(self): def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(200) self.send_response(200)
self.send_header(b"Content-type", "text/plain") self.send_header(b"Content-type", "text/plain")
self.end_headers() self.end_headers()
@ -161,15 +165,18 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
PGTUrlHandler.PARAMS.update(params) PGTUrlHandler.PARAMS.update(params)
def log_message(self, *args): def log_message(self, *args):
"""silent any log message"""
return return
@classmethod @classmethod
def run(cls): def run(cls):
"""Run a BaseHTTPServer using this class as handler"""
server_class = BaseHTTPServer.HTTPServer server_class = BaseHTTPServer.HTTPServer
httpd = server_class(("127.0.0.1", 0), cls) httpd = server_class(("127.0.0.1", 0), cls)
(host, port) = httpd.socket.getsockname() (host, port) = httpd.socket.getsockname()
def lauch(): def lauch():
"""routine to lauch in a background thread"""
httpd.handle_request() httpd.handle_request()
httpd.server_close() httpd.server_close()
@ -182,6 +189,7 @@ class PGTUrlHandler(BaseHTTPServer.BaseHTTPRequestHandler):
class PGTUrlHandler404(PGTUrlHandler): class PGTUrlHandler404(PGTUrlHandler):
"""A simple http server that always return 404 not found. Used in unit tests""" """A simple http server that always return 404 not found. Used in unit tests"""
def do_GET(self): def do_GET(self):
"""Called on a GET request on the BaseHTTPServer"""
self.send_response(404) self.send_response(404)
self.send_header(b"Content-type", "text/plain") self.send_header(b"Content-type", "text/plain")
self.end_headers() self.end_headers()