From 50781dba18f1a7b49c1422c9154339ed4d663ce9 Mon Sep 17 00:00:00 2001 From: Valentin Samir Date: Sun, 21 Jun 2015 18:56:16 +0200 Subject: [PATCH] add some tests --- .travis.yml | 2 + cas_server/forms.py | 4 +- cas_server/models.py | 2 +- cas_server/utils.py | 30 +++++++++----- cas_server/views.py | 8 ++-- tests/dummy.py | 76 +++++++++++++++++++++++++++++++++- tests/test_proxy.py | 52 +++++++++++++++++++++++ tests/test_validate_service.py | 12 ++---- tests/test_views_auth.py | 19 ++++----- tests/test_views_login.py | 21 ++++------ tests/test_views_logout.py | 18 ++------ tests/test_views_validate.py | 15 +++---- tox.ini | 14 +++++++ 13 files changed, 195 insertions(+), 78 deletions(-) create mode 100644 tests/test_proxy.py diff --git a/.travis.yml b/.travis.yml index 62e69e3..3d9d10c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,8 @@ env: matrix: - TOX_ENV=py27-django17 - TOX_ENV=py27-django18 + - TOX_ENV=py34-django17 + - TOX_ENV=py34-django18 - TOX_ENV=flake8 cache: directories: diff --git a/cas_server/forms.py b/cas_server/forms.py index 4df2e03..58871b0 100644 --- a/cas_server/forms.py +++ b/cas_server/forms.py @@ -14,8 +14,8 @@ from .default_settings import settings from django import forms from django.utils.translation import ugettext_lazy as _ -import utils -import models +import cas_server.utils as utils +import cas_server.models as models class UserCredential(forms.Form): diff --git a/cas_server/models.py b/cas_server/models.py index b2d9813..6a21d12 100644 --- a/cas_server/models.py +++ b/cas_server/models.py @@ -27,7 +27,7 @@ from datetime import timedelta from concurrent.futures import ThreadPoolExecutor from requests_futures.sessions import FuturesSession -import utils +import cas_server.utils as utils SessionStore = import_module(settings.SESSION_ENGINE).SessionStore diff --git a/cas_server/utils.py b/cas_server/utils.py index 0cf8b59..f2b94ad 100644 --- a/cas_server/utils.py +++ b/cas_server/utils.py @@ -16,12 +16,17 @@ from django.utils.importlib import import_module from django.core.urlresolvers import reverse from django.http import HttpResponseRedirect -import urlparse -import urllib import random import string +try: + from urlparse import urlparse, urlunparse, parse_qsl + from urllib import urlencode +except ImportError: + from urllib.parse import urlparse, urlunparse, parse_qsl, urlencode + + def import_attr(path): """transform a python module.attr path to the attr""" if not isinstance(path, str): @@ -33,26 +38,29 @@ def import_attr(path): def redirect_params(url_name, params=None): """Redirect to `url_name` with `params` as querystring""" url = reverse(url_name) - params = urllib.urlencode(params if params else {}) + params = urlencode(params if params else {}) return HttpResponseRedirect(url + "?%s" % params) def update_url(url, params): """update params in the `url` query string""" - if isinstance(url, unicode): + if not isinstance(url, bytes): url = url.encode('utf-8') - for key, value in params.items(): - if isinstance(key, unicode): + for key, value in list(params.items()): + if not isinstance(key, bytes): del params[key] key = key.encode('utf-8') - if isinstance(value, unicode): + if not isinstance(value, bytes): value = value.encode('utf-8') params[key] = value - url_parts = list(urlparse.urlparse(url)) - query = dict(urlparse.parse_qsl(url_parts[4])) + url_parts = list(urlparse(url)) + query = dict(parse_qsl(url_parts[4])) query.update(params) - url_parts[4] = urllib.urlencode(query) - return urlparse.urlunparse(url_parts).decode('utf-8') + url_parts[4] = urlencode(query) + for i in range(len(url_parts)): + if not isinstance(url_parts[i], bytes): + url_parts[i] = url_parts[i].encode('utf-8') + return urlunparse(url_parts).decode('utf-8') def unpack_nested_exception(error): diff --git a/cas_server/views.py b/cas_server/views.py index cc3ee2b..dd1cc4d 100644 --- a/cas_server/views.py +++ b/cas_server/views.py @@ -26,9 +26,9 @@ import requests from lxml import etree from datetime import timedelta -import utils -import forms -import models +import cas_server.utils as utils +import cas_server.forms as forms +import cas_server.models as models from .models import ServiceTicket, ProxyTicket, ProxyGrantingTicket from .models import ServicePattern @@ -633,7 +633,7 @@ class Proxy(View): self.target_service, pattern, renew=False) - pticket.proxies.create(url=ticket.service) + models.Proxy.objects.create(proxy_ticket=pticket, url=ticket.service) return render( self.request, "cas_server/proxy.xml", diff --git a/tests/dummy.py b/tests/dummy.py index 95ee0fe..3534140 100644 --- a/tests/dummy.py +++ b/tests/dummy.py @@ -1,3 +1,4 @@ +import functools from cas_server import models class DummyUserManager(object): @@ -10,6 +11,75 @@ class DummyUserManager(object): else: raise models.User.DoesNotExist() + +def dummy(*args, **kwds): + pass + +def dummy_service_pattern(**kwargs): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwds): + service_validate = models.ServicePattern.validate + models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern(**kwargs)) + ret = func(*args, **kwds) + models.ServicePattern.validate = service_validate + return ret + return wrapper + return decorator + +def dummy_user(username, session_key): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwds): + user_manager = models.User.objects + user_save = models.User.save + user_delete = models.User.delete + models.User.objects = DummyUserManager(username, session_key) + models.User.save = dummy + models.User.delete = dummy + ret = func(*args, **kwds) + models.User.objects = user_manager + models.User.save = user_save + models.User.delete = user_delete + return ret + return wrapper + return decorator + +def dummy_ticket(ticket_class, service, ticket): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwds): + ticket_manager = ticket_class.objects + ticket_save = ticket_class.save + ticket_delete = ticket_class.delete + ticket_class.objects = DummyTicketManager(ticket_class, service, ticket) + ticket_class.save = dummy + ticket_class.delete = dummy + ret = func(*args, **kwds) + ticket_class.objects = ticket_manager + ticket_class.save = ticket_save + ticket_class.delete = ticket_delete + return ret + return wrapper + return decorator + + +def dummy_proxy(func): + @functools.wraps(func) + def wrapper(*args, **kwds): + proxy_manager = models.Proxy.objects + models.Proxy.objects = DummyProxyManager() + ret = func(*args, **kwds) + models.Proxy.objects = proxy_manager + return ret + return wrapper + +class DummyProxyManager(object): + def create(self, **kwargs): + for field in models.Proxy._meta.fields: + field.allow_unsaved_instance_assignment = True + return models.Proxy(**kwargs) + class DummyTicketManager(object): def __init__(self, ticket_class, service, ticket): self.ticket_class = ticket_class @@ -17,7 +87,7 @@ class DummyTicketManager(object): self.ticket = ticket def create(self, **kwargs): - for field in models.ServiceTicket._meta.fields: + for field in self.ticket_class._meta.fields: field.allow_unsaved_instance_assignment = True return self.ticket_class(**kwargs) @@ -25,6 +95,8 @@ class DummyTicketManager(object): return DummyQuerySet() def get(self, **kwargs): + for field in self.ticket_class._meta.fields: + field.allow_unsaved_instance_assignment = True if 'value' in kwargs: if kwargs['value'] != self.ticket: raise self.ticket_class.DoesNotExist() @@ -41,7 +113,7 @@ class DummyTicketManager(object): for field in models.ServiceTicket._meta.fields: field.allow_unsaved_instance_assignment = True - for key in kwargs.keys(): + for key in list(kwargs): if '__' in key: del kwargs[key] kwargs['attributs'] = {'mail': 'test@example.com'} diff --git a/tests/test_proxy.py b/tests/test_proxy.py new file mode 100644 index 0000000..963d834 --- /dev/null +++ b/tests/test_proxy.py @@ -0,0 +1,52 @@ +from __future__ import absolute_import +from tests.init import * + +from django.test import RequestFactory + +import os +import pytest +from lxml import etree +from cas_server.views import ValidateService, Proxy +from cas_server import models + +from tests.dummy import * + +@pytest.mark.django_db +@dummy_ticket(models.ProxyGrantingTicket, '', "PGT-random") +@dummy_service_pattern(proxy=True) +@dummy_user(username="test", session_key="test_session") +@dummy_ticket(models.ProxyTicket, "https://www.example.com", "PT-random") +@dummy_proxy +def test_proxy_ok(): + factory = RequestFactory() + request = factory.get('/proxy?pgt=PGT-random&targetService=https://www.example.com') + + request.session = DummySession() + + proxy = Proxy() + response = proxy.get(request) + + assert response.status_code == 200 + + root = etree.fromstring(response.content) + proxy_tickets = root.xpath("//cas:proxyTicket", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + + assert len(proxy_tickets) == 1 + + factory = RequestFactory() + request = factory.get('/proxyValidate?ticket=PT-random&service=https://www.example.com') + + validate = ValidateService() + validate.allow_proxy_ticket = True + response = validate.get(request) + + assert response.status_code == 200 + + root = etree.fromstring(response.content) + users = root.xpath("//cas:user", namespaces={'cas': "http://www.yale.edu/tp/cas"}) + + assert len(users) == 1 + assert users[0].text == "test" + + + diff --git a/tests/test_validate_service.py b/tests/test_validate_service.py index b89b8fb..940e23b 100644 --- a/tests/test_validate_service.py +++ b/tests/test_validate_service.py @@ -12,15 +12,13 @@ from cas_server import models from .dummy import * @pytest.mark.django_db +@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") def test_validate_service_view_ok(): factory = RequestFactory() request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example.com') request.session = DummySession() - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random") - models.ServiceTicket.save = lambda x:None - validate = ValidateService() validate.allow_proxy_ticket = False response = validate.get(request) @@ -47,15 +45,13 @@ def test_validate_service_view_ok(): @pytest.mark.django_db +@dummy_ticket(models.ServiceTicket, 'https://www.example2.com', "ST-random") def test_validate_service_view_badservice(): factory = RequestFactory() request = factory.get('/serviceValidate?ticket=ST-random&service=https://www.example1.com') request.session = DummySession() - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example2.com', "ST-random") - models.ServiceTicket.save = lambda x:None - validate = ValidateService() validate.allow_proxy_ticket = False response = validate.get(request) @@ -70,15 +66,13 @@ def test_validate_service_view_badservice(): assert error[0].attrib['code'] == 'INVALID_SERVICE' @pytest.mark.django_db +@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random2") def test_validate_service_view_badticket(): factory = RequestFactory() request = factory.get('/serviceValidate?ticket=ST-random1&service=https://www.example.com') request.session = DummySession() - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random2") - models.ServiceTicket.save = lambda x:None - validate = ValidateService() validate.allow_proxy_ticket = False response = validate.get(request) diff --git a/tests/test_views_auth.py b/tests/test_views_auth.py index 133cb85..4b4a9eb 100644 --- a/tests/test_views_auth.py +++ b/tests/test_views_auth.py @@ -14,36 +14,33 @@ from .dummy import * settings.CAS_AUTH_SHARED_SECRET = "test" @pytest.mark.django_db +@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") +@dummy_user(username="test", session_key="test_session") +@dummy_service_pattern() def test_auth_view_goodpass(): factory = RequestFactory() request = factory.post('/auth', {'username':'test', 'password':'test', 'service':'https://www.example.com', 'secret':'test'}) request.session = DummySession() - models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key) - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random") - models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern()) - auth = Auth() response = auth.post(request) assert response.status_code == 200 - assert response.content == "yes\n" - + assert response.content == b"yes\n" +@dummy_service_pattern() +@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") +@dummy_user(username="test", session_key="test_session") def test_auth_view_badpass(): factory = RequestFactory() request = factory.post('/auth', {'username':'test', 'password':'badpass', 'service':'https://www.example.com', 'secret':'test'}) request.session = DummySession() - models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key) - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random") - models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern()) - auth = Auth() response = auth.post(request) assert response.status_code == 200 - assert response.content == "no\n" + assert response.content == b"no\n" diff --git a/tests/test_views_login.py b/tests/test_views_login.py index 7dc7203..3b7d580 100644 --- a/tests/test_views_login.py +++ b/tests/test_views_login.py @@ -92,6 +92,7 @@ def test_view_login_get_unauth(): assert response.status_code == 200 @pytest.mark.django_db +@dummy_user(username="test", session_key="test_session") def test_view_login_get_auth(): factory = RequestFactory() request = factory.post('/login') @@ -107,14 +108,15 @@ def test_view_login_get_auth(): assert ret == LoginView.USER_AUTHENTICATED - models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key) - login = LoginView() response = login.get(request) assert response.status_code == 200 @pytest.mark.django_db +@dummy_service_pattern() +@dummy_user(username="test", session_key="test_session") +@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") def test_view_login_get_auth_service(): factory = RequestFactory() request = factory.post('/login?service=https://www.example.com') @@ -130,12 +132,6 @@ def test_view_login_get_auth_service(): assert ret == LoginView.USER_AUTHENTICATED - models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key) - models.User.save = lambda x:None - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random") - models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern()) - models.ServiceTicket.save = lambda x:None - login = LoginView() response = login.get(request) @@ -143,6 +139,9 @@ def test_view_login_get_auth_service(): assert response['Location'].startswith('https://www.example.com?ticket=ST-') @pytest.mark.django_db +@dummy_service_pattern() +@dummy_user(username="test", session_key="test_session") +@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") def test_view_login_get_auth_service_warn(): factory = RequestFactory() request = factory.post('/login?service=https://www.example.com') @@ -158,12 +157,6 @@ def test_view_login_get_auth_service_warn(): assert ret == LoginView.USER_AUTHENTICATED - models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key) - models.User.save = lambda x:None - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random") - models.ServicePattern.validate = classmethod(lambda x,y: models.ServicePattern()) - models.ServiceTicket.save = lambda x:None - login = LoginView() response = login.get(request) diff --git a/tests/test_views_logout.py b/tests/test_views_logout.py index 9fa26f4..03410bd 100644 --- a/tests/test_views_logout.py +++ b/tests/test_views_logout.py @@ -13,6 +13,7 @@ from .dummy import * @pytest.mark.django_db +@dummy_user(username="test", session_key="test_session") def test_logout_view(): factory = RequestFactory() request = factory.get('/logout') @@ -23,21 +24,17 @@ def test_logout_view(): request.session["username"] = "test" request.session["warn"] = False - models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key) - dlist = [None] - models.User.delete = lambda x:dlist.pop() - logout = LogoutView() response = logout.get(request) assert response.status_code == 200 - assert dlist == [] assert not request.session.get("authenticated") assert not request.session.get("username") assert not request.session.get("warn") @pytest.mark.django_db +@dummy_user(username="test", session_key="test_session") def test_logout_view_url(): factory = RequestFactory() request = factory.get('/logout?url=https://www.example.com') @@ -48,16 +45,11 @@ def test_logout_view_url(): request.session["username"] = "test" request.session["warn"] = False - models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key) - dlist = [None] - models.User.delete = lambda x:dlist.pop() - logout = LogoutView() response = logout.get(request) assert response.status_code == 302 assert response['Location'] == 'https://www.example.com' - assert dlist == [] assert not request.session.get("authenticated") assert not request.session.get("username") assert not request.session.get("warn") @@ -65,6 +57,7 @@ def test_logout_view_url(): @pytest.mark.django_db +@dummy_user(username="test", session_key="test_session") def test_logout_view_service(): factory = RequestFactory() request = factory.get('/logout?service=https://www.example.com') @@ -75,16 +68,11 @@ def test_logout_view_service(): request.session["username"] = "test" request.session["warn"] = False - models.User.objects = DummyUserManager(username="test", session_key=request.session.session_key) - dlist = [None] - models.User.delete = lambda x:dlist.pop() - logout = LogoutView() response = logout.get(request) assert response.status_code == 302 assert response['Location'] == 'https://www.example.com' - assert dlist == [] assert not request.session.get("authenticated") assert not request.session.get("username") assert not request.session.get("warn") diff --git a/tests/test_views_validate.py b/tests/test_views_validate.py index 5f71056..0aec568 100644 --- a/tests/test_views_validate.py +++ b/tests/test_views_validate.py @@ -12,50 +12,47 @@ from cas_server import models from .dummy import * @pytest.mark.django_db +@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") def test_validate_view_ok(): factory = RequestFactory() request = factory.get('/validate?ticket=ST-random&service=https://www.example.com') request.session = DummySession() - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random") - validate = Validate() response = validate.get(request) assert response.status_code == 200 - assert response.content == "yes\n" + assert response.content == b"yes\n" @pytest.mark.django_db +@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random") def test_validate_view_badservice(): factory = RequestFactory() request = factory.get('/validate?ticket=ST-random&service=https://www.example2.com') request.session = DummySession() - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random") - validate = Validate() response = validate.get(request) assert response.status_code == 200 - assert response.content == "no\n" + assert response.content == b"no\n" @pytest.mark.django_db +@dummy_ticket(models.ServiceTicket, 'https://www.example.com', "ST-random1") def test_validate_view_badticket(): factory = RequestFactory() request = factory.get('/validate?ticket=ST-random2&service=https://www.example.com') request.session = DummySession() - models.ServiceTicket.objects = DummyTicketManager(models.ServiceTicket, 'https://www.example.com', "ST-random1") - validate = Validate() response = validate.get(request) assert response.status_code == 200 - assert response.content == "no\n" + assert response.content == b"no\n" diff --git a/tox.ini b/tox.ini index 61adaeb..ea1e77e 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,8 @@ envlist= py27-django17, py27-django18, + py34-django17, + py34-django18, flake8, [flake8] @@ -27,6 +29,18 @@ deps = Django>=1.8,<1.9 {[base]deps} +[testenv:py34-django17] +basepython=python3.4 +deps = + Django>=1.7,<1.8 + {[base]deps} + +[testenv:py34-django18] +basepython=python3.4 +deps = + Django>=1.8,<1.9 + {[base]deps} + [testenv:flake8] basepython=python deps=flake8